From 4881c052aad45daa4751b840cabaf8165ad54ff7 Mon Sep 17 00:00:00 2001 From: mcrakhman Date: Mon, 7 Nov 2022 13:36:10 +0100 Subject: [PATCH] Fix tests and change peer id context logic --- common/commonspace/diffservice/diffsyncer.go | 3 +- common/commonspace/syncservice/streampool.go | 14 +---- .../syncservice/streampool_test.go | 61 +++++++++---------- common/commonspace/syncservice/syncservice.go | 6 +- common/commonspace/synctree/synctree.go | 4 +- common/net/peer/context.go | 32 ++++++++++ common/net/peer/peer.go | 5 -- common/net/rpc/rpctest/pool.go | 8 +-- common/net/rpc/rpctest/server.go | 58 +----------------- common/net/secure/context.go | 28 --------- common/net/secure/listener.go | 3 +- .../acl/aclrecordproto/protos/aclrecord.proto | 14 +---- common/pkg/acl/list/aclrecordbuilder_test.go | 4 +- common/pkg/acl/list/list_test.go | 4 -- common/pkg/acl/list/mock_list/mock_list.go | 18 +++--- 15 files changed, 92 insertions(+), 170 deletions(-) create mode 100644 common/net/peer/context.go delete mode 100644 common/net/secure/context.go diff --git a/common/commonspace/diffservice/diffsyncer.go b/common/commonspace/diffservice/diffsyncer.go index a30100b2..f9ca9f26 100644 --- a/common/commonspace/diffservice/diffsyncer.go +++ b/common/commonspace/diffservice/diffsyncer.go @@ -11,7 +11,6 @@ import ( "github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ldiff" "go.uber.org/zap" - "storj.io/drpc/drpcctx" "time" ) @@ -76,7 +75,7 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) return d.sendPushSpaceRequest(ctx, cl) } - ctx = context.WithValue(ctx, drpcctx.TransportKey{}, p.Secure()) + ctx = peer.CtxWithPeerId(ctx, p.Id()) d.pingTreesInCache(ctx, newIds) d.pingTreesInCache(ctx, changedIds) d.pingTreesInCache(ctx, removedIds) diff --git a/common/commonspace/syncservice/streampool.go b/common/commonspace/syncservice/streampool.go index f5e21a20..a2bbe4ce 100644 --- a/common/commonspace/syncservice/streampool.go +++ b/common/commonspace/syncservice/streampool.go @@ -5,9 +5,8 @@ import ( "errors" "fmt" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache" - "github.com/libp2p/go-libp2p/core/sec" - "storj.io/drpc/drpcctx" "sync" "sync/atomic" "time" @@ -190,7 +189,7 @@ func (s *streamPool) AddAndReadStreamAsync(stream spacesyncproto.SpaceStream) { func (s *streamPool) AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error) { s.Lock() - peerId, err := GetPeerIdFromStreamContext(stream.Context()) + peerId, err := peer.CtxPeerId(stream.Context()) if err != nil { s.Unlock() return @@ -277,15 +276,6 @@ func (s *streamPool) removePeer(peerId string) (err error) { return } -func GetPeerIdFromStreamContext(ctx context.Context) (string, error) { - conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn) - if !ok { - return "", fmt.Errorf("incorrect connection type in stream") - } - - return conn.RemotePeer().String(), nil -} - func genStreamPoolKey(peerId, treeId string, counter uint64) string { return fmt.Sprintf("%s.%s.%d", peerId, treeId, counter) } diff --git a/common/commonspace/syncservice/streampool_test.go b/common/commonspace/syncservice/streampool_test.go index a9cc2bae..2d2ff512 100644 --- a/common/commonspace/syncservice/streampool_test.go +++ b/common/commonspace/syncservice/streampool_test.go @@ -3,9 +3,9 @@ package syncservice import ( "context" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest" "github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto" - "github.com/libp2p/go-libp2p/core/peer" "github.com/stretchr/testify/require" "testing" "time" @@ -53,24 +53,23 @@ type fixture struct { clientStream spacesyncproto.DRPCSpace_StreamStream serverStream spacesyncproto.DRPCSpace_StreamStream pool *streamPool - localId peer.ID - remoteId peer.ID + clientId string + serverId string } -func newFixture(t *testing.T, localId, remoteId peer.ID, handler MessageHandler) *fixture { +func newFixture(t *testing.T, clientId, serverId string, handler MessageHandler) *fixture { fx := &fixture{ testServer: &testServer{}, drpcTS: rpctest.NewTestServer(), - localId: localId, - remoteId: remoteId, + clientId: clientId, + serverId: serverId, } fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1) require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer)) - clientWrapper := rpctest.NewSecConnWrapper(nil, nil, localId, remoteId) - fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.DialWrapConn(nil, clientWrapper)) + fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.Dial(peer.CtxWithPeerId(context.Background(), clientId))) var err error - fx.clientStream, err = fx.client.Stream(context.Background()) + fx.clientStream, err = fx.client.Stream(peer.CtxWithPeerId(context.Background(), serverId)) require.NoError(t, err) fx.serverStream = fx.testServer.waitStream(t) fx.pool = newStreamPool(handler).(*streamPool) @@ -87,14 +86,14 @@ func (fx *fixture) run(t *testing.T) chan error { time.Sleep(time.Millisecond * 10) fx.pool.Lock() - require.Equal(t, fx.pool.peerStreams[fx.remoteId.String()], fx.clientStream) + require.Equal(t, fx.pool.peerStreams[fx.serverId], fx.clientStream) fx.pool.Unlock() return waitCh } func TestStreamPool_AddAndReadStreamAsync(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("client close", func(t *testing.T) { fx := newFixture(t, "", remId, nil) @@ -105,7 +104,7 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) t.Run("server close", func(t *testing.T) { fx := newFixture(t, "", remId, nil) @@ -116,12 +115,12 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) } func TestStreamPool_Close(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("client close", func(t *testing.T) { fx := newFixture(t, "", remId, nil) @@ -160,7 +159,7 @@ func TestStreamPool_Close(t *testing.T) { } func TestStreamPool_ReceiveMessage(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("pool receive message from server", func(t *testing.T) { objectId := "objectId" msg := &spacesyncproto.ObjectSyncMessage{ @@ -182,23 +181,23 @@ func TestStreamPool_ReceiveMessage(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) } func TestStreamPool_HasActiveStream(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("pool has active stream", func(t *testing.T) { fx := newFixture(t, "", remId, nil) waitCh := fx.run(t) - require.True(t, fx.pool.HasActiveStream(remId.String())) + require.True(t, fx.pool.HasActiveStream(remId)) err := fx.clientStream.Close() require.NoError(t, err) err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) t.Run("pool has no active stream", func(t *testing.T) { fx := newFixture(t, "", remId, nil) @@ -207,13 +206,13 @@ func TestStreamPool_HasActiveStream(t *testing.T) { require.NoError(t, err) err = <-waitCh require.Error(t, err) - require.False(t, fx.pool.HasActiveStream(remId.String())) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.False(t, fx.pool.HasActiveStream(remId)) + require.Nil(t, fx.pool.peerStreams[remId]) }) } func TestStreamPool_SendAsync(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("pool send async to server", func(t *testing.T) { objectId := "objectId" msg := &spacesyncproto.ObjectSyncMessage{ @@ -229,7 +228,7 @@ func TestStreamPool_SendAsync(t *testing.T) { }() waitCh := fx.run(t) - err := fx.pool.SendAsync([]string{remId.String()}, msg) + err := fx.pool.SendAsync([]string{remId}, msg) require.NoError(t, err) <-recvChan err = fx.clientStream.Close() @@ -237,12 +236,12 @@ func TestStreamPool_SendAsync(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) } func TestStreamPool_SendSync(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("pool send sync to server", func(t *testing.T) { objectId := "objectId" payload := []byte("payload") @@ -260,7 +259,7 @@ func TestStreamPool_SendSync(t *testing.T) { require.NoError(t, err) }() waitCh := fx.run(t) - res, err := fx.pool.SendSync(remId.String(), msg) + res, err := fx.pool.SendSync(remId, msg) require.NoError(t, err) require.Equal(t, payload, res.Payload) err = fx.clientStream.Close() @@ -268,7 +267,7 @@ func TestStreamPool_SendSync(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) t.Run("pool send sync timeout", func(t *testing.T) { @@ -285,19 +284,19 @@ func TestStreamPool_SendSync(t *testing.T) { require.NotEmpty(t, message.ReplyId) }() waitCh := fx.run(t) - _, err := fx.pool.SendSync(remId.String(), msg) + _, err := fx.pool.SendSync(remId, msg) require.Equal(t, ErrSyncTimeout, err) err = fx.clientStream.Close() require.NoError(t, err) err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) } func TestStreamPool_BroadcastAsync(t *testing.T) { - remId := peer.ID("remoteId") + remId := "remoteId" t.Run("pool broadcast async to server", func(t *testing.T) { objectId := "objectId" msg := &spacesyncproto.ObjectSyncMessage{ @@ -321,6 +320,6 @@ func TestStreamPool_BroadcastAsync(t *testing.T) { err = <-waitCh require.Error(t, err) - require.Nil(t, fx.pool.peerStreams[remId.String()]) + require.Nil(t, fx.pool.peerStreams[remId]) }) } diff --git a/common/commonspace/syncservice/syncservice.go b/common/commonspace/syncservice/syncservice.go index 38de40f8..983f5698 100644 --- a/common/commonspace/syncservice/syncservice.go +++ b/common/commonspace/syncservice/syncservice.go @@ -97,11 +97,11 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) { if err != nil { return } - for _, peer := range respPeers { - if s.streamPool.HasActiveStream(peer.Id()) { + for _, p := range respPeers { + if s.streamPool.HasActiveStream(p.Id()) { continue } - stream, err := s.clientFactory.Client(peer).Stream(ctx) + stream, err := s.clientFactory.Client(p).Stream(ctx) if err != nil { err = rpcerr.Unwrap(err) log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err) diff --git a/common/commonspace/synctree/synctree.go b/common/commonspace/synctree/synctree.go index 909ea82c..21a017b8 100644 --- a/common/commonspace/synctree/synctree.go +++ b/common/commonspace/synctree/synctree.go @@ -9,6 +9,7 @@ import ( "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/synchandler" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/synctree/updatelistener" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/list" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage" @@ -109,8 +110,7 @@ func CreateSyncTree(ctx context.Context, deps CreateDeps) (t tree.ObjectTree, er func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t tree.ObjectTree, err error) { getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) { - // TODO: add empty context handling (when this is not happening due to head update) - peerId, err := syncservice.GetPeerIdFromStreamContext(ctx) + peerId, err := peer.CtxPeerId(ctx) if err != nil { return } diff --git a/common/net/peer/context.go b/common/net/peer/context.go new file mode 100644 index 00000000..bf3b79f3 --- /dev/null +++ b/common/net/peer/context.go @@ -0,0 +1,32 @@ +package peer + +import ( + "context" + "errors" + "github.com/libp2p/go-libp2p/core/sec" + "storj.io/drpc/drpcctx" +) + +type contextKey uint + +const ( + contextKeyPeerId contextKey = iota +) + +var ErrPeerIdNotFoundInContext = errors.New("peer id not found in context") + +// CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key +func CtxPeerId(ctx context.Context) (string, error) { + if peerId, ok := ctx.Value(contextKeyPeerId).(string); ok { + return peerId, nil + } + if conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn); ok { + return conn.RemotePeer().String(), nil + } + return "", ErrPeerIdNotFoundInContext +} + +// CtxWithPeerId sets peer id in the context +func CtxWithPeerId(ctx context.Context, peerId string) context.Context { + return context.WithValue(ctx, contextKeyPeerId, peerId) +} diff --git a/common/net/peer/peer.go b/common/net/peer/peer.go index 3dfe517b..6056b0b9 100644 --- a/common/net/peer/peer.go +++ b/common/net/peer/peer.go @@ -21,7 +21,6 @@ type Peer interface { Id() string LastUsage() time.Time UpdateLastUsage() - Secure() sec.SecureConn drpc.Conn } @@ -36,10 +35,6 @@ func (p *peer) Id() string { return p.id } -func (p *peer) Secure() sec.SecureConn { - return p.sc -} - func (p *peer) LastUsage() time.Time { select { case <-p.Closed(): diff --git a/common/net/rpc/rpctest/pool.go b/common/net/rpc/rpctest/pool.go index 7e73b226..549e8712 100644 --- a/common/net/rpc/rpctest/pool.go +++ b/common/net/rpc/rpctest/pool.go @@ -36,7 +36,7 @@ func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) { if t.ts == nil { return nil, ErrCantConnect } - return &testPeer{id: id, Conn: t.ts.Dial()}, nil + return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil } func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) { @@ -45,7 +45,7 @@ func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) { if t.ts == nil { return nil, ErrCantConnect } - return &testPeer{id: id, Conn: t.ts.Dial()}, nil + return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil } func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { @@ -54,7 +54,7 @@ func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, e if t.ts == nil { return nil, ErrCantConnect } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil + return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil } func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { @@ -63,7 +63,7 @@ func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, if t.ts == nil { return nil, ErrCantConnect } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil + return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil } func (t *TestPool) Init(a *app.App) (err error) { diff --git a/common/net/rpc/rpctest/server.go b/common/net/rpc/rpctest/server.go index 134ce6cb..9a36d288 100644 --- a/common/net/rpc/rpctest/server.go +++ b/common/net/rpc/rpctest/server.go @@ -2,8 +2,6 @@ package rpctest import ( "context" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/peer" "net" "storj.io/drpc" "storj.io/drpc/drpcconn" @@ -11,48 +9,6 @@ import ( "storj.io/drpc/drpcserver" ) -type SecConnMock struct { - net.Conn - localPrivKey crypto.PrivKey - remotePubKey crypto.PubKey - localId peer.ID - remoteId peer.ID -} - -func (s *SecConnMock) LocalPeer() peer.ID { - return s.localId -} - -func (s *SecConnMock) LocalPrivateKey() crypto.PrivKey { - return s.localPrivKey -} - -func (s *SecConnMock) RemotePeer() peer.ID { - return s.remoteId -} - -func (s *SecConnMock) RemotePublicKey() crypto.PubKey { - return s.remotePubKey -} - -type ConnWrapper func(conn net.Conn) net.Conn - -func NewSecConnWrapper( - localPrivKey crypto.PrivKey, - remotePubKey crypto.PubKey, - localId peer.ID, - remoteId peer.ID) ConnWrapper { - return func(conn net.Conn) net.Conn { - return &SecConnMock{ - Conn: conn, - localPrivKey: localPrivKey, - remotePubKey: remotePubKey, - localId: localId, - remoteId: remoteId, - } - } -} - func NewTestServer() *TesServer { ts := &TesServer{ Mux: drpcmux.New(), @@ -66,18 +22,8 @@ type TesServer struct { *drpcserver.Server } -func (ts *TesServer) Dial() drpc.Conn { - return ts.DialWrapConn(nil, nil) -} - -func (ts *TesServer) DialWrapConn(serverWrapper ConnWrapper, clientWrapper ConnWrapper) drpc.Conn { +func (ts *TesServer) Dial(ctx context.Context) drpc.Conn { sc, cc := net.Pipe() - if serverWrapper != nil { - sc = serverWrapper(sc) - } - if clientWrapper != nil { - cc = clientWrapper(cc) - } - go ts.Server.ServeOne(context.Background(), sc) + go ts.Server.ServeOne(ctx, sc) return drpcconn.New(cc) } diff --git a/common/net/secure/context.go b/common/net/secure/context.go deleted file mode 100644 index e22b3b00..00000000 --- a/common/net/secure/context.go +++ /dev/null @@ -1,28 +0,0 @@ -package secure - -import ( - "context" - "errors" - "github.com/libp2p/go-libp2p/core/sec" -) - -var ( - ErrSecureConnNotFoundInContext = errors.New("secure connection not found in context") -) - -type contextKey uint - -const ( - contextKeySecureConn contextKey = iota -) - -func CtxSecureConn(ctx context.Context) (sec.SecureConn, error) { - if conn, ok := ctx.Value(contextKeySecureConn).(sec.SecureConn); ok { - return conn, nil - } - return nil, ErrSecureConnNotFoundInContext -} - -func ctxWithSecureConn(ctx context.Context, conn sec.SecureConn) context.Context { - return context.WithValue(ctx, contextKeySecureConn, conn) -} diff --git a/common/net/secure/listener.go b/common/net/secure/listener.go index b7c31b26..db16470f 100644 --- a/common/net/secure/listener.go +++ b/common/net/secure/listener.go @@ -2,6 +2,7 @@ package secure import ( "context" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer" "github.com/libp2p/go-libp2p/core/crypto" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" "net" @@ -45,6 +46,6 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C if err != nil { return nil, nil, HandshakeError(err) } - ctx = ctxWithSecureConn(ctx, secure) + ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String()) return ctx, secure, nil } diff --git a/common/pkg/acl/aclrecordproto/protos/aclrecord.proto b/common/pkg/acl/aclrecordproto/protos/aclrecord.proto index ca527e04..51f74853 100644 --- a/common/pkg/acl/aclrecordproto/protos/aclrecord.proto +++ b/common/pkg/acl/aclrecordproto/protos/aclrecord.proto @@ -65,31 +65,24 @@ message ACLUserAdd { ACLUserPermissions permissions = 4; } -// accept key, encrypt key, invite id -// GetSpace(id) -> ... (space header + acl root) -> diff -// Join(ACLJoinRecord) -> Ok - message ACLUserInvite { bytes acceptPublicKey = 1; - // TODO: change to read key - bytes encryptPublicKey = 2; + uint64 encryptSymKeyHash = 2; repeated bytes encryptedReadKeys = 3; ACLUserPermissions permissions = 4; - // TODO: either derive inviteId from pub keys or think if it is possible to just use ACL record id - string inviteId = 5; } message ACLUserJoin { bytes identity = 1; bytes encryptionKey = 2; bytes acceptSignature = 3; - string inviteId = 4; + bytes acceptPubKey = 4; repeated bytes encryptedReadKeys = 5; } message ACLUserRemove { bytes identity = 1; - repeated ACLReadKeyReplace readKeyReplaces = 3; + repeated ACLReadKeyReplace readKeyReplaces = 2; } message ACLReadKeyReplace { @@ -109,7 +102,6 @@ enum ACLUserPermissions { Reader = 2; } - message ACLSyncMessage { ACLSyncContentValue content = 2; } diff --git a/common/pkg/acl/list/aclrecordbuilder_test.go b/common/pkg/acl/list/aclrecordbuilder_test.go index 42a9cbb7..a804bcb4 100644 --- a/common/pkg/acl/list/aclrecordbuilder_test.go +++ b/common/pkg/acl/list/aclrecordbuilder_test.go @@ -1,7 +1,6 @@ package list import ( - "context" account "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/account" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/common" @@ -44,7 +43,8 @@ func TestAclRecordBuilder_BuildUserJoin(t *testing.T) { Payload: marshalledJoin, Id: id, } - err = aclList.AddRawRecords(context.Background(), []*aclrecordproto.RawACLRecordWithId{rawRec}) + res, err := aclList.AddRawRecord(rawRec) + require.True(t, res) require.NoError(t, err) require.Equal(t, aclrecordproto.ACLUserPermissions_Writer, aclList.ACLState().UserStates()[identity].Permissions) } diff --git a/common/pkg/acl/list/list_test.go b/common/pkg/acl/list/list_test.go index 965d6221..c4effdae 100644 --- a/common/pkg/acl/list/list_test.go +++ b/common/pkg/acl/list/list_test.go @@ -89,7 +89,3 @@ func TestAclList_ACLState_UserJoinAndRemove(t *testing.T) { _, err = aclList.ACLState().PermissionsAtRecord(records[3].Id, idB) assert.Error(t, err, "B should have no permissions at record 3, because user should be removed") } - -func TestAclList_AddRawRecord(t *testing.T) { - -} diff --git a/common/pkg/acl/list/mock_list/mock_list.go b/common/pkg/acl/list/mock_list/mock_list.go index 4e1c0762..c70c183b 100644 --- a/common/pkg/acl/list/mock_list/mock_list.go +++ b/common/pkg/acl/list/mock_list/mock_list.go @@ -5,7 +5,6 @@ package mock_list import ( - context "context" reflect "reflect" aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" @@ -50,18 +49,19 @@ func (mr *MockACLListMockRecorder) ACLState() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ACLState", reflect.TypeOf((*MockACLList)(nil).ACLState)) } -// AddRawRecords mocks base method. -func (m *MockACLList) AddRawRecords(arg0 context.Context, arg1 []*aclrecordproto.RawACLRecordWithId) error { +// AddRawRecord mocks base method. +func (m *MockACLList) AddRawRecord(arg0 *aclrecordproto.RawACLRecordWithId) (bool, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "AddRawRecords", arg0, arg1) - ret0, _ := ret[0].(error) - return ret0 + ret := m.ctrl.Call(m, "AddRawRecord", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 } -// AddRawRecords indicates an expected call of AddRawRecords. -func (mr *MockACLListMockRecorder) AddRawRecords(arg0, arg1 interface{}) *gomock.Call { +// AddRawRecord indicates an expected call of AddRawRecord. +func (mr *MockACLListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockACLList)(nil).AddRawRecords), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockACLList)(nil).AddRawRecord), arg0) } // Close mocks base method.