Fix tests and change peer id context logic

This commit is contained in:
mcrakhman 2022-11-07 13:36:10 +01:00 committed by Mikhail Iudin
parent b0b4e5b721
commit 4881c052aa
No known key found for this signature in database
GPG Key ID: FAAAA8BAABDFF1C0
15 changed files with 92 additions and 170 deletions

View File

@ -11,7 +11,6 @@ import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ldiff" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ldiff"
"go.uber.org/zap" "go.uber.org/zap"
"storj.io/drpc/drpcctx"
"time" "time"
) )
@ -76,7 +75,7 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error)
return d.sendPushSpaceRequest(ctx, cl) return d.sendPushSpaceRequest(ctx, cl)
} }
ctx = context.WithValue(ctx, drpcctx.TransportKey{}, p.Secure()) ctx = peer.CtxWithPeerId(ctx, p.Id())
d.pingTreesInCache(ctx, newIds) d.pingTreesInCache(ctx, newIds)
d.pingTreesInCache(ctx, changedIds) d.pingTreesInCache(ctx, changedIds)
d.pingTreesInCache(ctx, removedIds) d.pingTreesInCache(ctx, removedIds)

View File

@ -5,9 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache"
"github.com/libp2p/go-libp2p/core/sec"
"storj.io/drpc/drpcctx"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@ -190,7 +189,7 @@ func (s *streamPool) AddAndReadStreamAsync(stream spacesyncproto.SpaceStream) {
func (s *streamPool) AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error) { func (s *streamPool) AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error) {
s.Lock() s.Lock()
peerId, err := GetPeerIdFromStreamContext(stream.Context()) peerId, err := peer.CtxPeerId(stream.Context())
if err != nil { if err != nil {
s.Unlock() s.Unlock()
return return
@ -277,15 +276,6 @@ func (s *streamPool) removePeer(peerId string) (err error) {
return return
} }
func GetPeerIdFromStreamContext(ctx context.Context) (string, error) {
conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn)
if !ok {
return "", fmt.Errorf("incorrect connection type in stream")
}
return conn.RemotePeer().String(), nil
}
func genStreamPoolKey(peerId, treeId string, counter uint64) string { func genStreamPoolKey(peerId, treeId string, counter uint64) string {
return fmt.Sprintf("%s.%s.%d", peerId, treeId, counter) return fmt.Sprintf("%s.%s.%d", peerId, treeId, counter)
} }

View File

@ -3,9 +3,9 @@ package syncservice
import ( import (
"context" "context"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest"
"github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
@ -53,24 +53,23 @@ type fixture struct {
clientStream spacesyncproto.DRPCSpace_StreamStream clientStream spacesyncproto.DRPCSpace_StreamStream
serverStream spacesyncproto.DRPCSpace_StreamStream serverStream spacesyncproto.DRPCSpace_StreamStream
pool *streamPool pool *streamPool
localId peer.ID clientId string
remoteId peer.ID serverId string
} }
func newFixture(t *testing.T, localId, remoteId peer.ID, handler MessageHandler) *fixture { func newFixture(t *testing.T, clientId, serverId string, handler MessageHandler) *fixture {
fx := &fixture{ fx := &fixture{
testServer: &testServer{}, testServer: &testServer{},
drpcTS: rpctest.NewTestServer(), drpcTS: rpctest.NewTestServer(),
localId: localId, clientId: clientId,
remoteId: remoteId, serverId: serverId,
} }
fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1) fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1)
require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer)) require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer))
clientWrapper := rpctest.NewSecConnWrapper(nil, nil, localId, remoteId) fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.Dial(peer.CtxWithPeerId(context.Background(), clientId)))
fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.DialWrapConn(nil, clientWrapper))
var err error var err error
fx.clientStream, err = fx.client.Stream(context.Background()) fx.clientStream, err = fx.client.Stream(peer.CtxWithPeerId(context.Background(), serverId))
require.NoError(t, err) require.NoError(t, err)
fx.serverStream = fx.testServer.waitStream(t) fx.serverStream = fx.testServer.waitStream(t)
fx.pool = newStreamPool(handler).(*streamPool) fx.pool = newStreamPool(handler).(*streamPool)
@ -87,14 +86,14 @@ func (fx *fixture) run(t *testing.T) chan error {
time.Sleep(time.Millisecond * 10) time.Sleep(time.Millisecond * 10)
fx.pool.Lock() fx.pool.Lock()
require.Equal(t, fx.pool.peerStreams[fx.remoteId.String()], fx.clientStream) require.Equal(t, fx.pool.peerStreams[fx.serverId], fx.clientStream)
fx.pool.Unlock() fx.pool.Unlock()
return waitCh return waitCh
} }
func TestStreamPool_AddAndReadStreamAsync(t *testing.T) { func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("client close", func(t *testing.T) { t.Run("client close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil) fx := newFixture(t, "", remId, nil)
@ -105,7 +104,7 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
t.Run("server close", func(t *testing.T) { t.Run("server close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil) fx := newFixture(t, "", remId, nil)
@ -116,12 +115,12 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }
func TestStreamPool_Close(t *testing.T) { func TestStreamPool_Close(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("client close", func(t *testing.T) { t.Run("client close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil) fx := newFixture(t, "", remId, nil)
@ -160,7 +159,7 @@ func TestStreamPool_Close(t *testing.T) {
} }
func TestStreamPool_ReceiveMessage(t *testing.T) { func TestStreamPool_ReceiveMessage(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("pool receive message from server", func(t *testing.T) { t.Run("pool receive message from server", func(t *testing.T) {
objectId := "objectId" objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{ msg := &spacesyncproto.ObjectSyncMessage{
@ -182,23 +181,23 @@ func TestStreamPool_ReceiveMessage(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }
func TestStreamPool_HasActiveStream(t *testing.T) { func TestStreamPool_HasActiveStream(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("pool has active stream", func(t *testing.T) { t.Run("pool has active stream", func(t *testing.T) {
fx := newFixture(t, "", remId, nil) fx := newFixture(t, "", remId, nil)
waitCh := fx.run(t) waitCh := fx.run(t)
require.True(t, fx.pool.HasActiveStream(remId.String())) require.True(t, fx.pool.HasActiveStream(remId))
err := fx.clientStream.Close() err := fx.clientStream.Close()
require.NoError(t, err) require.NoError(t, err)
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
t.Run("pool has no active stream", func(t *testing.T) { t.Run("pool has no active stream", func(t *testing.T) {
fx := newFixture(t, "", remId, nil) fx := newFixture(t, "", remId, nil)
@ -207,13 +206,13 @@ func TestStreamPool_HasActiveStream(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.False(t, fx.pool.HasActiveStream(remId.String())) require.False(t, fx.pool.HasActiveStream(remId))
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }
func TestStreamPool_SendAsync(t *testing.T) { func TestStreamPool_SendAsync(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("pool send async to server", func(t *testing.T) { t.Run("pool send async to server", func(t *testing.T) {
objectId := "objectId" objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{ msg := &spacesyncproto.ObjectSyncMessage{
@ -229,7 +228,7 @@ func TestStreamPool_SendAsync(t *testing.T) {
}() }()
waitCh := fx.run(t) waitCh := fx.run(t)
err := fx.pool.SendAsync([]string{remId.String()}, msg) err := fx.pool.SendAsync([]string{remId}, msg)
require.NoError(t, err) require.NoError(t, err)
<-recvChan <-recvChan
err = fx.clientStream.Close() err = fx.clientStream.Close()
@ -237,12 +236,12 @@ func TestStreamPool_SendAsync(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }
func TestStreamPool_SendSync(t *testing.T) { func TestStreamPool_SendSync(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("pool send sync to server", func(t *testing.T) { t.Run("pool send sync to server", func(t *testing.T) {
objectId := "objectId" objectId := "objectId"
payload := []byte("payload") payload := []byte("payload")
@ -260,7 +259,7 @@ func TestStreamPool_SendSync(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
waitCh := fx.run(t) waitCh := fx.run(t)
res, err := fx.pool.SendSync(remId.String(), msg) res, err := fx.pool.SendSync(remId, msg)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, payload, res.Payload) require.Equal(t, payload, res.Payload)
err = fx.clientStream.Close() err = fx.clientStream.Close()
@ -268,7 +267,7 @@ func TestStreamPool_SendSync(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
t.Run("pool send sync timeout", func(t *testing.T) { t.Run("pool send sync timeout", func(t *testing.T) {
@ -285,19 +284,19 @@ func TestStreamPool_SendSync(t *testing.T) {
require.NotEmpty(t, message.ReplyId) require.NotEmpty(t, message.ReplyId)
}() }()
waitCh := fx.run(t) waitCh := fx.run(t)
_, err := fx.pool.SendSync(remId.String(), msg) _, err := fx.pool.SendSync(remId, msg)
require.Equal(t, ErrSyncTimeout, err) require.Equal(t, ErrSyncTimeout, err)
err = fx.clientStream.Close() err = fx.clientStream.Close()
require.NoError(t, err) require.NoError(t, err)
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }
func TestStreamPool_BroadcastAsync(t *testing.T) { func TestStreamPool_BroadcastAsync(t *testing.T) {
remId := peer.ID("remoteId") remId := "remoteId"
t.Run("pool broadcast async to server", func(t *testing.T) { t.Run("pool broadcast async to server", func(t *testing.T) {
objectId := "objectId" objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{ msg := &spacesyncproto.ObjectSyncMessage{
@ -321,6 +320,6 @@ func TestStreamPool_BroadcastAsync(t *testing.T) {
err = <-waitCh err = <-waitCh
require.Error(t, err) require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()]) require.Nil(t, fx.pool.peerStreams[remId])
}) })
} }

View File

@ -97,11 +97,11 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) {
if err != nil { if err != nil {
return return
} }
for _, peer := range respPeers { for _, p := range respPeers {
if s.streamPool.HasActiveStream(peer.Id()) { if s.streamPool.HasActiveStream(p.Id()) {
continue continue
} }
stream, err := s.clientFactory.Client(peer).Stream(ctx) stream, err := s.clientFactory.Client(p).Stream(ctx)
if err != nil { if err != nil {
err = rpcerr.Unwrap(err) err = rpcerr.Unwrap(err)
log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err) log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err)

View File

@ -9,6 +9,7 @@ import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/synchandler" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/synchandler"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/synctree/updatelistener" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/synctree/updatelistener"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/list" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/list"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage"
@ -109,8 +110,7 @@ func CreateSyncTree(ctx context.Context, deps CreateDeps) (t tree.ObjectTree, er
func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t tree.ObjectTree, err error) { func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t tree.ObjectTree, err error) {
getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) { getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) {
// TODO: add empty context handling (when this is not happening due to head update) peerId, err := peer.CtxPeerId(ctx)
peerId, err := syncservice.GetPeerIdFromStreamContext(ctx)
if err != nil { if err != nil {
return return
} }

View File

@ -0,0 +1,32 @@
package peer
import (
"context"
"errors"
"github.com/libp2p/go-libp2p/core/sec"
"storj.io/drpc/drpcctx"
)
type contextKey uint
const (
contextKeyPeerId contextKey = iota
)
var ErrPeerIdNotFoundInContext = errors.New("peer id not found in context")
// CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key
func CtxPeerId(ctx context.Context) (string, error) {
if peerId, ok := ctx.Value(contextKeyPeerId).(string); ok {
return peerId, nil
}
if conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn); ok {
return conn.RemotePeer().String(), nil
}
return "", ErrPeerIdNotFoundInContext
}
// CtxWithPeerId sets peer id in the context
func CtxWithPeerId(ctx context.Context, peerId string) context.Context {
return context.WithValue(ctx, contextKeyPeerId, peerId)
}

View File

@ -21,7 +21,6 @@ type Peer interface {
Id() string Id() string
LastUsage() time.Time LastUsage() time.Time
UpdateLastUsage() UpdateLastUsage()
Secure() sec.SecureConn
drpc.Conn drpc.Conn
} }
@ -36,10 +35,6 @@ func (p *peer) Id() string {
return p.id return p.id
} }
func (p *peer) Secure() sec.SecureConn {
return p.sc
}
func (p *peer) LastUsage() time.Time { func (p *peer) LastUsage() time.Time {
select { select {
case <-p.Closed(): case <-p.Closed():

View File

@ -36,7 +36,7 @@ func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
if t.ts == nil { if t.ts == nil {
return nil, ErrCantConnect return nil, ErrCantConnect
} }
return &testPeer{id: id, Conn: t.ts.Dial()}, nil return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
} }
func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) { func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
@ -45,7 +45,7 @@ func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
if t.ts == nil { if t.ts == nil {
return nil, ErrCantConnect return nil, ErrCantConnect
} }
return &testPeer{id: id, Conn: t.ts.Dial()}, nil return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
} }
func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@ -54,7 +54,7 @@ func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, e
if t.ts == nil { if t.ts == nil {
return nil, ErrCantConnect return nil, ErrCantConnect
} }
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
} }
func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@ -63,7 +63,7 @@ func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer,
if t.ts == nil { if t.ts == nil {
return nil, ErrCantConnect return nil, ErrCantConnect
} }
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
} }
func (t *TestPool) Init(a *app.App) (err error) { func (t *TestPool) Init(a *app.App) (err error) {

View File

@ -2,8 +2,6 @@ package rpctest
import ( import (
"context" "context"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"net" "net"
"storj.io/drpc" "storj.io/drpc"
"storj.io/drpc/drpcconn" "storj.io/drpc/drpcconn"
@ -11,48 +9,6 @@ import (
"storj.io/drpc/drpcserver" "storj.io/drpc/drpcserver"
) )
type SecConnMock struct {
net.Conn
localPrivKey crypto.PrivKey
remotePubKey crypto.PubKey
localId peer.ID
remoteId peer.ID
}
func (s *SecConnMock) LocalPeer() peer.ID {
return s.localId
}
func (s *SecConnMock) LocalPrivateKey() crypto.PrivKey {
return s.localPrivKey
}
func (s *SecConnMock) RemotePeer() peer.ID {
return s.remoteId
}
func (s *SecConnMock) RemotePublicKey() crypto.PubKey {
return s.remotePubKey
}
type ConnWrapper func(conn net.Conn) net.Conn
func NewSecConnWrapper(
localPrivKey crypto.PrivKey,
remotePubKey crypto.PubKey,
localId peer.ID,
remoteId peer.ID) ConnWrapper {
return func(conn net.Conn) net.Conn {
return &SecConnMock{
Conn: conn,
localPrivKey: localPrivKey,
remotePubKey: remotePubKey,
localId: localId,
remoteId: remoteId,
}
}
}
func NewTestServer() *TesServer { func NewTestServer() *TesServer {
ts := &TesServer{ ts := &TesServer{
Mux: drpcmux.New(), Mux: drpcmux.New(),
@ -66,18 +22,8 @@ type TesServer struct {
*drpcserver.Server *drpcserver.Server
} }
func (ts *TesServer) Dial() drpc.Conn { func (ts *TesServer) Dial(ctx context.Context) drpc.Conn {
return ts.DialWrapConn(nil, nil)
}
func (ts *TesServer) DialWrapConn(serverWrapper ConnWrapper, clientWrapper ConnWrapper) drpc.Conn {
sc, cc := net.Pipe() sc, cc := net.Pipe()
if serverWrapper != nil { go ts.Server.ServeOne(ctx, sc)
sc = serverWrapper(sc)
}
if clientWrapper != nil {
cc = clientWrapper(cc)
}
go ts.Server.ServeOne(context.Background(), sc)
return drpcconn.New(cc) return drpcconn.New(cc)
} }

View File

@ -1,28 +0,0 @@
package secure
import (
"context"
"errors"
"github.com/libp2p/go-libp2p/core/sec"
)
var (
ErrSecureConnNotFoundInContext = errors.New("secure connection not found in context")
)
type contextKey uint
const (
contextKeySecureConn contextKey = iota
)
func CtxSecureConn(ctx context.Context) (sec.SecureConn, error) {
if conn, ok := ctx.Value(contextKeySecureConn).(sec.SecureConn); ok {
return conn, nil
}
return nil, ErrSecureConnNotFoundInContext
}
func ctxWithSecureConn(ctx context.Context, conn sec.SecureConn) context.Context {
return context.WithValue(ctx, contextKeySecureConn, conn)
}

View File

@ -2,6 +2,7 @@ package secure
import ( import (
"context" "context"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"net" "net"
@ -45,6 +46,6 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
if err != nil { if err != nil {
return nil, nil, HandshakeError(err) return nil, nil, HandshakeError(err)
} }
ctx = ctxWithSecureConn(ctx, secure) ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
return ctx, secure, nil return ctx, secure, nil
} }

View File

@ -65,31 +65,24 @@ message ACLUserAdd {
ACLUserPermissions permissions = 4; ACLUserPermissions permissions = 4;
} }
// accept key, encrypt key, invite id
// GetSpace(id) -> ... (space header + acl root) -> diff
// Join(ACLJoinRecord) -> Ok
message ACLUserInvite { message ACLUserInvite {
bytes acceptPublicKey = 1; bytes acceptPublicKey = 1;
// TODO: change to read key uint64 encryptSymKeyHash = 2;
bytes encryptPublicKey = 2;
repeated bytes encryptedReadKeys = 3; repeated bytes encryptedReadKeys = 3;
ACLUserPermissions permissions = 4; ACLUserPermissions permissions = 4;
// TODO: either derive inviteId from pub keys or think if it is possible to just use ACL record id
string inviteId = 5;
} }
message ACLUserJoin { message ACLUserJoin {
bytes identity = 1; bytes identity = 1;
bytes encryptionKey = 2; bytes encryptionKey = 2;
bytes acceptSignature = 3; bytes acceptSignature = 3;
string inviteId = 4; bytes acceptPubKey = 4;
repeated bytes encryptedReadKeys = 5; repeated bytes encryptedReadKeys = 5;
} }
message ACLUserRemove { message ACLUserRemove {
bytes identity = 1; bytes identity = 1;
repeated ACLReadKeyReplace readKeyReplaces = 3; repeated ACLReadKeyReplace readKeyReplaces = 2;
} }
message ACLReadKeyReplace { message ACLReadKeyReplace {
@ -109,7 +102,6 @@ enum ACLUserPermissions {
Reader = 2; Reader = 2;
} }
message ACLSyncMessage { message ACLSyncMessage {
ACLSyncContentValue content = 2; ACLSyncContentValue content = 2;
} }

View File

@ -1,7 +1,6 @@
package list package list
import ( import (
"context"
account "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/account" account "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/account"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/common" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/common"
@ -44,7 +43,8 @@ func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
Payload: marshalledJoin, Payload: marshalledJoin,
Id: id, Id: id,
} }
err = aclList.AddRawRecords(context.Background(), []*aclrecordproto.RawACLRecordWithId{rawRec}) res, err := aclList.AddRawRecord(rawRec)
require.True(t, res)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, aclrecordproto.ACLUserPermissions_Writer, aclList.ACLState().UserStates()[identity].Permissions) require.Equal(t, aclrecordproto.ACLUserPermissions_Writer, aclList.ACLState().UserStates()[identity].Permissions)
} }

View File

@ -89,7 +89,3 @@ func TestAclList_ACLState_UserJoinAndRemove(t *testing.T) {
_, err = aclList.ACLState().PermissionsAtRecord(records[3].Id, idB) _, err = aclList.ACLState().PermissionsAtRecord(records[3].Id, idB)
assert.Error(t, err, "B should have no permissions at record 3, because user should be removed") assert.Error(t, err, "B should have no permissions at record 3, because user should be removed")
} }
func TestAclList_AddRawRecord(t *testing.T) {
}

View File

@ -5,7 +5,6 @@
package mock_list package mock_list
import ( import (
context "context"
reflect "reflect" reflect "reflect"
aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
@ -50,18 +49,19 @@ func (mr *MockACLListMockRecorder) ACLState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ACLState", reflect.TypeOf((*MockACLList)(nil).ACLState)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ACLState", reflect.TypeOf((*MockACLList)(nil).ACLState))
} }
// AddRawRecords mocks base method. // AddRawRecord mocks base method.
func (m *MockACLList) AddRawRecords(arg0 context.Context, arg1 []*aclrecordproto.RawACLRecordWithId) error { func (m *MockACLList) AddRawRecord(arg0 *aclrecordproto.RawACLRecordWithId) (bool, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecords", arg0, arg1) ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(bool)
return ret0 ret1, _ := ret[1].(error)
return ret0, ret1
} }
// AddRawRecords indicates an expected call of AddRawRecords. // AddRawRecord indicates an expected call of AddRawRecord.
func (mr *MockACLListMockRecorder) AddRawRecords(arg0, arg1 interface{}) *gomock.Call { func (mr *MockACLListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockACLList)(nil).AddRawRecords), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockACLList)(nil).AddRawRecord), arg0)
} }
// Close mocks base method. // Close mocks base method.