Fix tests and change peer id context logic

This commit is contained in:
mcrakhman 2022-11-07 13:36:10 +01:00
parent 354ee3b6c7
commit 925005a9de
No known key found for this signature in database
GPG Key ID: DED12CFEF5B8396B
15 changed files with 100 additions and 170 deletions

View File

@ -11,7 +11,6 @@ import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ldiff"
"go.uber.org/zap"
"storj.io/drpc/drpcctx"
"time"
)
@ -76,7 +75,7 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error)
return d.sendPushSpaceRequest(ctx, cl)
}
ctx = context.WithValue(ctx, drpcctx.TransportKey{}, p.Secure())
ctx = peer.CtxWithPeerId(ctx, p.Id())
d.pingTreesInCache(ctx, newIds)
d.pingTreesInCache(ctx, changedIds)
d.pingTreesInCache(ctx, removedIds)

View File

@ -5,9 +5,8 @@ import (
"errors"
"fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache"
"github.com/libp2p/go-libp2p/core/sec"
"storj.io/drpc/drpcctx"
"sync"
"sync/atomic"
"time"
@ -190,7 +189,7 @@ func (s *streamPool) AddAndReadStreamAsync(stream spacesyncproto.SpaceStream) {
func (s *streamPool) AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error) {
s.Lock()
peerId, err := GetPeerIdFromStreamContext(stream.Context())
peerId, err := peer.CtxPeerId(stream.Context())
if err != nil {
s.Unlock()
return
@ -277,15 +276,6 @@ func (s *streamPool) removePeer(peerId string) (err error) {
return
}
func GetPeerIdFromStreamContext(ctx context.Context) (string, error) {
conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn)
if !ok {
return "", fmt.Errorf("incorrect connection type in stream")
}
return conn.RemotePeer().String(), nil
}
func genStreamPoolKey(peerId, treeId string, counter uint64) string {
return fmt.Sprintf("%s.%s.%d", peerId, treeId, counter)
}

View File

@ -3,9 +3,9 @@ package syncservice
import (
"context"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest"
"github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/stretchr/testify/require"
"testing"
"time"
@ -53,24 +53,23 @@ type fixture struct {
clientStream spacesyncproto.DRPCSpace_StreamStream
serverStream spacesyncproto.DRPCSpace_StreamStream
pool *streamPool
localId peer.ID
remoteId peer.ID
clientId string
serverId string
}
func newFixture(t *testing.T, localId, remoteId peer.ID, handler MessageHandler) *fixture {
func newFixture(t *testing.T, clientId, serverId string, handler MessageHandler) *fixture {
fx := &fixture{
testServer: &testServer{},
drpcTS: rpctest.NewTestServer(),
localId: localId,
remoteId: remoteId,
clientId: clientId,
serverId: serverId,
}
fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1)
require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer))
clientWrapper := rpctest.NewSecConnWrapper(nil, nil, localId, remoteId)
fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.DialWrapConn(nil, clientWrapper))
fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.Dial(peer.CtxWithPeerId(context.Background(), clientId)))
var err error
fx.clientStream, err = fx.client.Stream(context.Background())
fx.clientStream, err = fx.client.Stream(peer.CtxWithPeerId(context.Background(), serverId))
require.NoError(t, err)
fx.serverStream = fx.testServer.waitStream(t)
fx.pool = newStreamPool(handler).(*streamPool)
@ -87,14 +86,14 @@ func (fx *fixture) run(t *testing.T) chan error {
time.Sleep(time.Millisecond * 10)
fx.pool.Lock()
require.Equal(t, fx.pool.peerStreams[fx.remoteId.String()], fx.clientStream)
require.Equal(t, fx.pool.peerStreams[fx.serverId], fx.clientStream)
fx.pool.Unlock()
return waitCh
}
func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("client close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil)
@ -105,7 +104,7 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
t.Run("server close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil)
@ -116,12 +115,12 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
}
func TestStreamPool_Close(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("client close", func(t *testing.T) {
fx := newFixture(t, "", remId, nil)
@ -160,7 +159,7 @@ func TestStreamPool_Close(t *testing.T) {
}
func TestStreamPool_ReceiveMessage(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("pool receive message from server", func(t *testing.T) {
objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{
@ -182,23 +181,23 @@ func TestStreamPool_ReceiveMessage(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
}
func TestStreamPool_HasActiveStream(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("pool has active stream", func(t *testing.T) {
fx := newFixture(t, "", remId, nil)
waitCh := fx.run(t)
require.True(t, fx.pool.HasActiveStream(remId.String()))
require.True(t, fx.pool.HasActiveStream(remId))
err := fx.clientStream.Close()
require.NoError(t, err)
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
t.Run("pool has no active stream", func(t *testing.T) {
fx := newFixture(t, "", remId, nil)
@ -207,13 +206,13 @@ func TestStreamPool_HasActiveStream(t *testing.T) {
require.NoError(t, err)
err = <-waitCh
require.Error(t, err)
require.False(t, fx.pool.HasActiveStream(remId.String()))
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.False(t, fx.pool.HasActiveStream(remId))
require.Nil(t, fx.pool.peerStreams[remId])
})
}
func TestStreamPool_SendAsync(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("pool send async to server", func(t *testing.T) {
objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{
@ -229,7 +228,7 @@ func TestStreamPool_SendAsync(t *testing.T) {
}()
waitCh := fx.run(t)
err := fx.pool.SendAsync([]string{remId.String()}, msg)
err := fx.pool.SendAsync([]string{remId}, msg)
require.NoError(t, err)
<-recvChan
err = fx.clientStream.Close()
@ -237,12 +236,12 @@ func TestStreamPool_SendAsync(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
}
func TestStreamPool_SendSync(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("pool send sync to server", func(t *testing.T) {
objectId := "objectId"
payload := []byte("payload")
@ -260,7 +259,7 @@ func TestStreamPool_SendSync(t *testing.T) {
require.NoError(t, err)
}()
waitCh := fx.run(t)
res, err := fx.pool.SendSync(remId.String(), msg)
res, err := fx.pool.SendSync(remId, msg)
require.NoError(t, err)
require.Equal(t, payload, res.Payload)
err = fx.clientStream.Close()
@ -268,7 +267,7 @@ func TestStreamPool_SendSync(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
t.Run("pool send sync timeout", func(t *testing.T) {
@ -285,19 +284,19 @@ func TestStreamPool_SendSync(t *testing.T) {
require.NotEmpty(t, message.ReplyId)
}()
waitCh := fx.run(t)
_, err := fx.pool.SendSync(remId.String(), msg)
_, err := fx.pool.SendSync(remId, msg)
require.Equal(t, ErrSyncTimeout, err)
err = fx.clientStream.Close()
require.NoError(t, err)
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
}
func TestStreamPool_BroadcastAsync(t *testing.T) {
remId := peer.ID("remoteId")
remId := "remoteId"
t.Run("pool broadcast async to server", func(t *testing.T) {
objectId := "objectId"
msg := &spacesyncproto.ObjectSyncMessage{
@ -321,6 +320,6 @@ func TestStreamPool_BroadcastAsync(t *testing.T) {
err = <-waitCh
require.Error(t, err)
require.Nil(t, fx.pool.peerStreams[remId.String()])
require.Nil(t, fx.pool.peerStreams[remId])
})
}

View File

@ -97,11 +97,11 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) {
if err != nil {
return
}
for _, peer := range respPeers {
if s.streamPool.HasActiveStream(peer.Id()) {
for _, p := range respPeers {
if s.streamPool.HasActiveStream(p.Id()) {
continue
}
stream, err := s.clientFactory.Client(peer).Stream(ctx)
stream, err := s.clientFactory.Client(p).Stream(ctx)
if err != nil {
err = rpcerr.Unwrap(err)
log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err)

View File

@ -10,6 +10,7 @@ import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/synchandler"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/synctree/updatelistener"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/list"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage"
@ -111,8 +112,7 @@ func CreateSyncTree(ctx context.Context, deps CreateDeps) (t tree.ObjectTree, er
func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t tree.ObjectTree, err error) {
getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) {
// TODO: add empty context handling (when this is not happening due to head update)
peerId, err := syncservice.GetPeerIdFromStreamContext(ctx)
peerId, err := peer.CtxPeerId(ctx)
if err != nil {
return
}

View File

@ -0,0 +1,32 @@
package peer
import (
"context"
"errors"
"github.com/libp2p/go-libp2p/core/sec"
"storj.io/drpc/drpcctx"
)
type contextKey uint
const (
contextKeyPeerId contextKey = iota
)
var ErrPeerIdNotFoundInContext = errors.New("peer id not found in context")
// CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key
func CtxPeerId(ctx context.Context) (string, error) {
if peerId, ok := ctx.Value(contextKeyPeerId).(string); ok {
return peerId, nil
}
if conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn); ok {
return conn.RemotePeer().String(), nil
}
return "", ErrPeerIdNotFoundInContext
}
// CtxWithPeerId sets peer id in the context
func CtxWithPeerId(ctx context.Context, peerId string) context.Context {
return context.WithValue(ctx, contextKeyPeerId, peerId)
}

View File

@ -21,7 +21,6 @@ type Peer interface {
Id() string
LastUsage() time.Time
UpdateLastUsage()
Secure() sec.SecureConn
drpc.Conn
}
@ -36,10 +35,6 @@ func (p *peer) Id() string {
return p.id
}
func (p *peer) Secure() sec.SecureConn {
return p.sc
}
func (p *peer) LastUsage() time.Time {
select {
case <-p.Closed():

View File

@ -36,7 +36,7 @@ func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: id, Conn: t.ts.Dial()}, nil
return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
@ -45,7 +45,7 @@ func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: id, Conn: t.ts.Dial()}, nil
return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@ -54,7 +54,7 @@ func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, e
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@ -63,7 +63,7 @@ func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer,
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) Init(a *app.App) (err error) {

View File

@ -2,8 +2,6 @@ package rpctest
import (
"context"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/peer"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
@ -11,48 +9,6 @@ import (
"storj.io/drpc/drpcserver"
)
type SecConnMock struct {
net.Conn
localPrivKey crypto.PrivKey
remotePubKey crypto.PubKey
localId peer.ID
remoteId peer.ID
}
func (s *SecConnMock) LocalPeer() peer.ID {
return s.localId
}
func (s *SecConnMock) LocalPrivateKey() crypto.PrivKey {
return s.localPrivKey
}
func (s *SecConnMock) RemotePeer() peer.ID {
return s.remoteId
}
func (s *SecConnMock) RemotePublicKey() crypto.PubKey {
return s.remotePubKey
}
type ConnWrapper func(conn net.Conn) net.Conn
func NewSecConnWrapper(
localPrivKey crypto.PrivKey,
remotePubKey crypto.PubKey,
localId peer.ID,
remoteId peer.ID) ConnWrapper {
return func(conn net.Conn) net.Conn {
return &SecConnMock{
Conn: conn,
localPrivKey: localPrivKey,
remotePubKey: remotePubKey,
localId: localId,
remoteId: remoteId,
}
}
}
func NewTestServer() *TesServer {
ts := &TesServer{
Mux: drpcmux.New(),
@ -66,18 +22,8 @@ type TesServer struct {
*drpcserver.Server
}
func (ts *TesServer) Dial() drpc.Conn {
return ts.DialWrapConn(nil, nil)
}
func (ts *TesServer) DialWrapConn(serverWrapper ConnWrapper, clientWrapper ConnWrapper) drpc.Conn {
func (ts *TesServer) Dial(ctx context.Context) drpc.Conn {
sc, cc := net.Pipe()
if serverWrapper != nil {
sc = serverWrapper(sc)
}
if clientWrapper != nil {
cc = clientWrapper(cc)
}
go ts.Server.ServeOne(context.Background(), sc)
go ts.Server.ServeOne(ctx, sc)
return drpcconn.New(cc)
}

View File

@ -1,28 +0,0 @@
package secure
import (
"context"
"errors"
"github.com/libp2p/go-libp2p/core/sec"
)
var (
ErrSecureConnNotFoundInContext = errors.New("secure connection not found in context")
)
type contextKey uint
const (
contextKeySecureConn contextKey = iota
)
func CtxSecureConn(ctx context.Context) (sec.SecureConn, error) {
if conn, ok := ctx.Value(contextKeySecureConn).(sec.SecureConn); ok {
return conn, nil
}
return nil, ErrSecureConnNotFoundInContext
}
func ctxWithSecureConn(ctx context.Context, conn sec.SecureConn) context.Context {
return context.WithValue(ctx, contextKeySecureConn, conn)
}

View File

@ -2,6 +2,7 @@ package secure
import (
"context"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
"github.com/libp2p/go-libp2p/core/crypto"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"net"
@ -45,6 +46,6 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
if err != nil {
return nil, nil, HandshakeError(err)
}
ctx = ctxWithSecureConn(ctx, secure)
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
return ctx, secure, nil
}

View File

@ -66,23 +66,23 @@ message ACLUserAdd {
}
message ACLUserInvite {
bytes acceptPublicKey = 1;
uint64 encryptSymKeyHash = 2;
repeated bytes encryptedReadKeys = 3;
ACLUserPermissions permissions = 4;
bytes acceptPublicKey = 1;
uint64 encryptSymKeyHash = 2;
repeated bytes encryptedReadKeys = 3;
ACLUserPermissions permissions = 4;
}
message ACLUserJoin {
bytes identity = 1;
bytes encryptionKey = 2;
bytes acceptSignature = 3;
bytes acceptPubKey = 4;
repeated bytes encryptedReadKeys = 5;
bytes identity = 1;
bytes encryptionKey = 2;
bytes acceptSignature = 3;
bytes acceptPubKey = 4;
repeated bytes encryptedReadKeys = 5;
}
message ACLUserRemove {
bytes identity = 1;
repeated ACLReadKeyReplace readKeyReplaces = 2;
bytes identity = 1;
repeated ACLReadKeyReplace readKeyReplaces = 2;
}
message ACLReadKeyReplace {

View File

@ -1,7 +1,6 @@
package list
import (
"context"
account "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/account"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/common"
@ -44,7 +43,8 @@ func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
Payload: marshalledJoin,
Id: id,
}
err = aclList.AddRawRecords(context.Background(), []*aclrecordproto.RawACLRecordWithId{rawRec})
res, err := aclList.AddRawRecord(rawRec)
require.True(t, res)
require.NoError(t, err)
require.Equal(t, aclrecordproto.ACLUserPermissions_Writer, aclList.ACLState().UserStates()[identity].Permissions)
}

View File

@ -89,7 +89,3 @@ func TestAclList_ACLState_UserJoinAndRemove(t *testing.T) {
_, err = aclList.ACLState().PermissionsAtRecord(records[3].Id, idB)
assert.Error(t, err, "B should have no permissions at record 3, because user should be removed")
}
func TestAclList_AddRawRecord(t *testing.T) {
}

View File

@ -5,7 +5,6 @@
package mock_list
import (
context "context"
reflect "reflect"
aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
@ -50,18 +49,19 @@ func (mr *MockACLListMockRecorder) ACLState() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ACLState", reflect.TypeOf((*MockACLList)(nil).ACLState))
}
// AddRawRecords mocks base method.
func (m *MockACLList) AddRawRecords(arg0 context.Context, arg1 []*aclrecordproto.RawACLRecordWithId) error {
// AddRawRecord mocks base method.
func (m *MockACLList) AddRawRecord(arg0 *aclrecordproto.RawACLRecordWithId) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecords", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AddRawRecords indicates an expected call of AddRawRecords.
func (mr *MockACLListMockRecorder) AddRawRecords(arg0, arg1 interface{}) *gomock.Call {
// AddRawRecord indicates an expected call of AddRawRecord.
func (mr *MockACLListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockACLList)(nil).AddRawRecords), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockACLList)(nil).AddRawRecord), arg0)
}
// Close mocks base method.