Fix tests and change peer id context logic
parent b0b4e5b721
commit 4881c052aa
@@ -11,7 +11,6 @@ import (
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ldiff"
 	"go.uber.org/zap"
-	"storj.io/drpc/drpcctx"
 	"time"
 )
 
@@ -76,7 +75,7 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error)
 		return d.sendPushSpaceRequest(ctx, cl)
 	}
 
-	ctx = context.WithValue(ctx, drpcctx.TransportKey{}, p.Secure())
+	ctx = peer.CtxWithPeerId(ctx, p.Id())
 	d.pingTreesInCache(ctx, newIds)
 	d.pingTreesInCache(ctx, changedIds)
 	d.pingTreesInCache(ctx, removedIds)
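To make the changed context line concrete, here is a hedged before/after sketch of how downstream code recovers the peer id (the package and function names in the sketch are illustrative; peer.CtxPeerId comes from the new common/net/peer/context.go added later in this commit):

package sketch

import (
	"context"

	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
	"github.com/libp2p/go-libp2p/core/sec"
	"storj.io/drpc/drpcctx"
)

// peerIdBefore: with the old line, consumers had to type-assert the whole
// secure connection out of the context to learn who they were talking to.
func peerIdBefore(ctx context.Context) (string, bool) {
	conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn)
	if !ok {
		return "", false
	}
	return conn.RemotePeer().String(), true
}

// peerIdAfter: with the new line, only the id string is stored and read back.
func peerIdAfter(ctx context.Context) (string, error) {
	return peer.CtxPeerId(ctx)
}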
@@ -5,9 +5,8 @@ import (
 	"errors"
 	"fmt"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
+	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache"
-	"github.com/libp2p/go-libp2p/core/sec"
-	"storj.io/drpc/drpcctx"
 	"sync"
 	"sync/atomic"
 	"time"
@@ -190,7 +189,7 @@ func (s *streamPool) AddAndReadStreamAsync(stream spacesyncproto.SpaceStream) {
 
 func (s *streamPool) AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error) {
 	s.Lock()
-	peerId, err := GetPeerIdFromStreamContext(stream.Context())
+	peerId, err := peer.CtxPeerId(stream.Context())
 	if err != nil {
 		s.Unlock()
 		return
@@ -277,15 +276,6 @@ func (s *streamPool) removePeer(peerId string) (err error) {
 	return
 }
 
-func GetPeerIdFromStreamContext(ctx context.Context) (string, error) {
-	conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn)
-	if !ok {
-		return "", fmt.Errorf("incorrect connection type in stream")
-	}
-
-	return conn.RemotePeer().String(), nil
-}
-
 func genStreamPoolKey(peerId, treeId string, counter uint64) string {
 	return fmt.Sprintf("%s.%s.%d", peerId, treeId, counter)
 }
@@ -3,9 +3,9 @@ package syncservice
 import (
 	"context"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
+	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto"
-	"github.com/libp2p/go-libp2p/core/peer"
 	"github.com/stretchr/testify/require"
 	"testing"
 	"time"
@@ -53,24 +53,23 @@ type fixture struct {
 	clientStream spacesyncproto.DRPCSpace_StreamStream
 	serverStream spacesyncproto.DRPCSpace_StreamStream
 	pool         *streamPool
-	localId      peer.ID
-	remoteId     peer.ID
+	clientId     string
+	serverId     string
 }
 
-func newFixture(t *testing.T, localId, remoteId peer.ID, handler MessageHandler) *fixture {
+func newFixture(t *testing.T, clientId, serverId string, handler MessageHandler) *fixture {
 	fx := &fixture{
 		testServer: &testServer{},
 		drpcTS:     rpctest.NewTestServer(),
-		localId:    localId,
-		remoteId:   remoteId,
+		clientId:   clientId,
+		serverId:   serverId,
 	}
 	fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1)
 	require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer))
-	clientWrapper := rpctest.NewSecConnWrapper(nil, nil, localId, remoteId)
-	fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.DialWrapConn(nil, clientWrapper))
+	fx.client = spacesyncproto.NewDRPCSpaceClient(fx.drpcTS.Dial(peer.CtxWithPeerId(context.Background(), clientId)))
 
 	var err error
-	fx.clientStream, err = fx.client.Stream(context.Background())
+	fx.clientStream, err = fx.client.Stream(peer.CtxWithPeerId(context.Background(), serverId))
 	require.NoError(t, err)
 	fx.serverStream = fx.testServer.waitStream(t)
 	fx.pool = newStreamPool(handler).(*streamPool)
@@ -87,14 +86,14 @@ func (fx *fixture) run(t *testing.T) chan error {
 
 	time.Sleep(time.Millisecond * 10)
 	fx.pool.Lock()
-	require.Equal(t, fx.pool.peerStreams[fx.remoteId.String()], fx.clientStream)
+	require.Equal(t, fx.pool.peerStreams[fx.serverId], fx.clientStream)
 	fx.pool.Unlock()
 
 	return waitCh
 }
 
 func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 
 	t.Run("client close", func(t *testing.T) {
 		fx := newFixture(t, "", remId, nil)
@@ -105,7 +104,7 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 	t.Run("server close", func(t *testing.T) {
 		fx := newFixture(t, "", remId, nil)
@@ -116,12 +115,12 @@ func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
 
 		err = <-waitCh
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
 
 func TestStreamPool_Close(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 
 	t.Run("client close", func(t *testing.T) {
 		fx := newFixture(t, "", remId, nil)
@@ -160,7 +159,7 @@ func TestStreamPool_Close(t *testing.T) {
 }
 
 func TestStreamPool_ReceiveMessage(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 	t.Run("pool receive message from server", func(t *testing.T) {
 		objectId := "objectId"
 		msg := &spacesyncproto.ObjectSyncMessage{
@@ -182,23 +181,23 @@ func TestStreamPool_ReceiveMessage(t *testing.T) {
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
 
 func TestStreamPool_HasActiveStream(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 	t.Run("pool has active stream", func(t *testing.T) {
 		fx := newFixture(t, "", remId, nil)
 		waitCh := fx.run(t)
-		require.True(t, fx.pool.HasActiveStream(remId.String()))
+		require.True(t, fx.pool.HasActiveStream(remId))
 
 		err := fx.clientStream.Close()
 		require.NoError(t, err)
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 	t.Run("pool has no active stream", func(t *testing.T) {
 		fx := newFixture(t, "", remId, nil)
@@ -207,13 +206,13 @@ func TestStreamPool_HasActiveStream(t *testing.T) {
 		require.NoError(t, err)
 		err = <-waitCh
 		require.Error(t, err)
-		require.False(t, fx.pool.HasActiveStream(remId.String()))
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.False(t, fx.pool.HasActiveStream(remId))
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
 
 func TestStreamPool_SendAsync(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 	t.Run("pool send async to server", func(t *testing.T) {
 		objectId := "objectId"
 		msg := &spacesyncproto.ObjectSyncMessage{
@@ -229,7 +228,7 @@ func TestStreamPool_SendAsync(t *testing.T) {
 		}()
 		waitCh := fx.run(t)
 
-		err := fx.pool.SendAsync([]string{remId.String()}, msg)
+		err := fx.pool.SendAsync([]string{remId}, msg)
 		require.NoError(t, err)
 		<-recvChan
 		err = fx.clientStream.Close()
@@ -237,12 +236,12 @@ func TestStreamPool_SendAsync(t *testing.T) {
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
 
 func TestStreamPool_SendSync(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 	t.Run("pool send sync to server", func(t *testing.T) {
 		objectId := "objectId"
 		payload := []byte("payload")
@@ -260,7 +259,7 @@ func TestStreamPool_SendSync(t *testing.T) {
 			require.NoError(t, err)
 		}()
 		waitCh := fx.run(t)
-		res, err := fx.pool.SendSync(remId.String(), msg)
+		res, err := fx.pool.SendSync(remId, msg)
 		require.NoError(t, err)
 		require.Equal(t, payload, res.Payload)
 		err = fx.clientStream.Close()
@@ -268,7 +267,7 @@ func TestStreamPool_SendSync(t *testing.T) {
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 
 	t.Run("pool send sync timeout", func(t *testing.T) {
@@ -285,19 +284,19 @@ func TestStreamPool_SendSync(t *testing.T) {
 			require.NotEmpty(t, message.ReplyId)
 		}()
 		waitCh := fx.run(t)
-		_, err := fx.pool.SendSync(remId.String(), msg)
+		_, err := fx.pool.SendSync(remId, msg)
 		require.Equal(t, ErrSyncTimeout, err)
 		err = fx.clientStream.Close()
 		require.NoError(t, err)
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
 
 func TestStreamPool_BroadcastAsync(t *testing.T) {
-	remId := peer.ID("remoteId")
+	remId := "remoteId"
 	t.Run("pool broadcast async to server", func(t *testing.T) {
 		objectId := "objectId"
 		msg := &spacesyncproto.ObjectSyncMessage{
@@ -321,6 +320,6 @@ func TestStreamPool_BroadcastAsync(t *testing.T) {
 		err = <-waitCh
 
 		require.Error(t, err)
-		require.Nil(t, fx.pool.peerStreams[remId.String()])
+		require.Nil(t, fx.pool.peerStreams[remId])
 	})
 }
@@ -97,11 +97,11 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) {
 		if err != nil {
 			return
 		}
-		for _, peer := range respPeers {
-			if s.streamPool.HasActiveStream(peer.Id()) {
+		for _, p := range respPeers {
+			if s.streamPool.HasActiveStream(p.Id()) {
 				continue
 			}
-			stream, err := s.clientFactory.Client(peer).Stream(ctx)
+			stream, err := s.clientFactory.Client(p).Stream(ctx)
 			if err != nil {
 				err = rpcerr.Unwrap(err)
 				log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err)
@@ -9,6 +9,7 @@ import (
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/synchandler"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/synctree/updatelistener"
+	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/nodeconf"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/list"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage"
@@ -109,8 +110,7 @@ func CreateSyncTree(ctx context.Context, deps CreateDeps) (t tree.ObjectTree, er
 
 func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t tree.ObjectTree, err error) {
 	getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) {
-		// TODO: add empty context handling (when this is not happening due to head update)
-		peerId, err := syncservice.GetPeerIdFromStreamContext(ctx)
+		peerId, err := peer.CtxPeerId(ctx)
 		if err != nil {
 			return
 		}
common/net/peer/context.go (new file, 32 lines)
@@ -0,0 +1,32 @@
+package peer
+
+import (
+	"context"
+	"errors"
+	"github.com/libp2p/go-libp2p/core/sec"
+	"storj.io/drpc/drpcctx"
+)
+
+type contextKey uint
+
+const (
+	contextKeyPeerId contextKey = iota
+)
+
+var ErrPeerIdNotFoundInContext = errors.New("peer id not found in context")
+
+// CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key
+func CtxPeerId(ctx context.Context) (string, error) {
+	if peerId, ok := ctx.Value(contextKeyPeerId).(string); ok {
+		return peerId, nil
+	}
+	if conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn); ok {
+		return conn.RemotePeer().String(), nil
+	}
+	return "", ErrPeerIdNotFoundInContext
+}
+
+// CtxWithPeerId sets peer id in the context
+func CtxWithPeerId(ctx context.Context, peerId string) context.Context {
+	return context.WithValue(ctx, contextKeyPeerId, peerId)
+}
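A minimal usage sketch of the two helpers above (illustrative snippet, not part of the commit; the peer id string is made up):

package main

import (
	"context"
	"fmt"

	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
)

func main() {
	// Tag an outgoing context with an explicit peer id.
	ctx := peer.CtxWithPeerId(context.Background(), "examplePeerId")

	// CtxPeerId prefers the explicit key; a context coming from the TLS
	// listener carries only the DRPC transport key and falls back to
	// conn.RemotePeer().String().
	peerId, err := peer.CtxPeerId(ctx)
	fmt.Println(peerId, err) // examplePeerId <nil>

	// A bare context yields the sentinel error.
	_, err = peer.CtxPeerId(context.Background())
	fmt.Println(err == peer.ErrPeerIdNotFoundInContext) // true
}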
@@ -21,7 +21,6 @@ type Peer interface {
 	Id() string
 	LastUsage() time.Time
 	UpdateLastUsage()
-	Secure() sec.SecureConn
 	drpc.Conn
 }
 
@@ -36,10 +35,6 @@ func (p *peer) Id() string {
 	return p.id
 }
 
-func (p *peer) Secure() sec.SecureConn {
-	return p.sc
-}
-
 func (p *peer) LastUsage() time.Time {
 	select {
 	case <-p.Closed():
@@ -36,7 +36,7 @@ func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
 	if t.ts == nil {
 		return nil, ErrCantConnect
 	}
-	return &testPeer{id: id, Conn: t.ts.Dial()}, nil
+	return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
 }
 
 func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
@@ -45,7 +45,7 @@ func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
 	if t.ts == nil {
 		return nil, ErrCantConnect
 	}
-	return &testPeer{id: id, Conn: t.ts.Dial()}, nil
+	return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
 }
 
 func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@@ -54,7 +54,7 @@ func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, e
 	if t.ts == nil {
 		return nil, ErrCantConnect
 	}
-	return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil
+	return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
 }
 
 func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
@@ -63,7 +63,7 @@ func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer,
 	if t.ts == nil {
 		return nil, ErrCantConnect
 	}
-	return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial()}, nil
+	return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
 }
 
 func (t *TestPool) Init(a *app.App) (err error) {
@@ -2,8 +2,6 @@ package rpctest
 
 import (
 	"context"
-	"github.com/libp2p/go-libp2p/core/crypto"
-	"github.com/libp2p/go-libp2p/core/peer"
 	"net"
 	"storj.io/drpc"
 	"storj.io/drpc/drpcconn"
@@ -11,48 +9,6 @@ import (
 	"storj.io/drpc/drpcserver"
 )
 
-type SecConnMock struct {
-	net.Conn
-	localPrivKey crypto.PrivKey
-	remotePubKey crypto.PubKey
-	localId      peer.ID
-	remoteId     peer.ID
-}
-
-func (s *SecConnMock) LocalPeer() peer.ID {
-	return s.localId
-}
-
-func (s *SecConnMock) LocalPrivateKey() crypto.PrivKey {
-	return s.localPrivKey
-}
-
-func (s *SecConnMock) RemotePeer() peer.ID {
-	return s.remoteId
-}
-
-func (s *SecConnMock) RemotePublicKey() crypto.PubKey {
-	return s.remotePubKey
-}
-
-type ConnWrapper func(conn net.Conn) net.Conn
-
-func NewSecConnWrapper(
-	localPrivKey crypto.PrivKey,
-	remotePubKey crypto.PubKey,
-	localId peer.ID,
-	remoteId peer.ID) ConnWrapper {
-	return func(conn net.Conn) net.Conn {
-		return &SecConnMock{
-			Conn:         conn,
-			localPrivKey: localPrivKey,
-			remotePubKey: remotePubKey,
-			localId:      localId,
-			remoteId:     remoteId,
-		}
-	}
-}
-
 func NewTestServer() *TesServer {
 	ts := &TesServer{
 		Mux: drpcmux.New(),
@@ -66,18 +22,8 @@ type TesServer struct {
 	*drpcserver.Server
 }
 
-func (ts *TesServer) Dial() drpc.Conn {
-	return ts.DialWrapConn(nil, nil)
-}
-
-func (ts *TesServer) DialWrapConn(serverWrapper ConnWrapper, clientWrapper ConnWrapper) drpc.Conn {
+func (ts *TesServer) Dial(ctx context.Context) drpc.Conn {
 	sc, cc := net.Pipe()
-	if serverWrapper != nil {
-		sc = serverWrapper(sc)
-	}
-	if clientWrapper != nil {
-		cc = clientWrapper(cc)
-	}
-	go ts.Server.ServeOne(context.Background(), sc)
+	go ts.Server.ServeOne(ctx, sc)
 	return drpcconn.New(cc)
 }
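A hedged sketch of how a test can use the simplified dialer (NewTestServer, Mux and Dial(ctx) are taken from the file above, CtxWithPeerId from the new peer package; the handler-registration comment is illustrative):

package rpctest_test

import (
	"context"

	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest"
)

func dialSketch() {
	ts := rpctest.NewTestServer()
	// DRPC handlers would be registered on ts.Mux here, e.g.
	// spacesyncproto.DRPCRegisterSpace(ts.Mux, srv).

	// The peer identity now travels in the dial context instead of being
	// faked with a SecConnMock wrapper around the net.Pipe connection.
	ctx := peer.CtxWithPeerId(context.Background(), "clientPeerId")
	conn := ts.Dial(ctx)
	defer conn.Close()
}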
@@ -1,28 +0,0 @@
-package secure
-
-import (
-	"context"
-	"errors"
-	"github.com/libp2p/go-libp2p/core/sec"
-)
-
-var (
-	ErrSecureConnNotFoundInContext = errors.New("secure connection not found in context")
-)
-
-type contextKey uint
-
-const (
-	contextKeySecureConn contextKey = iota
-)
-
-func CtxSecureConn(ctx context.Context) (sec.SecureConn, error) {
-	if conn, ok := ctx.Value(contextKeySecureConn).(sec.SecureConn); ok {
-		return conn, nil
-	}
-	return nil, ErrSecureConnNotFoundInContext
-}
-
-func ctxWithSecureConn(ctx context.Context, conn sec.SecureConn) context.Context {
-	return context.WithValue(ctx, contextKeySecureConn, conn)
-}
@@ -2,6 +2,7 @@ package secure
 
 import (
 	"context"
+	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/peer"
 	"github.com/libp2p/go-libp2p/core/crypto"
 	libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
 	"net"
@@ -45,6 +46,6 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
 	if err != nil {
 		return nil, nil, HandshakeError(err)
 	}
-	ctx = ctxWithSecureConn(ctx, secure)
+	ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
 	return ctx, secure, nil
 }
@@ -65,31 +65,24 @@ message ACLUserAdd {
     ACLUserPermissions permissions = 4;
 }
 
-// accept key, encrypt key, invite id
-// GetSpace(id) -> ... (space header + acl root) -> diff
-// Join(ACLJoinRecord) -> Ok
-
 message ACLUserInvite {
     bytes acceptPublicKey = 1;
-    // TODO: change to read key
-    bytes encryptPublicKey = 2;
+    uint64 encryptSymKeyHash = 2;
     repeated bytes encryptedReadKeys = 3;
     ACLUserPermissions permissions = 4;
-    // TODO: either derive inviteId from pub keys or think if it is possible to just use ACL record id
-    string inviteId = 5;
 }
 
 message ACLUserJoin {
     bytes identity = 1;
     bytes encryptionKey = 2;
     bytes acceptSignature = 3;
-    string inviteId = 4;
+    bytes acceptPubKey = 4;
     repeated bytes encryptedReadKeys = 5;
 }
 
 message ACLUserRemove {
     bytes identity = 1;
-    repeated ACLReadKeyReplace readKeyReplaces = 3;
+    repeated ACLReadKeyReplace readKeyReplaces = 2;
 }
 
 message ACLReadKeyReplace {
@@ -109,7 +102,6 @@ enum ACLUserPermissions {
     Reader = 2;
 }
 
-
 message ACLSyncMessage {
     ACLSyncContentValue content = 2;
 }
@@ -1,7 +1,6 @@
 package list
 
 import (
-	"context"
 	account "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/account"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
 	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/common"
@@ -44,7 +43,8 @@ func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
 		Payload: marshalledJoin,
 		Id:      id,
 	}
-	err = aclList.AddRawRecords(context.Background(), []*aclrecordproto.RawACLRecordWithId{rawRec})
+	res, err := aclList.AddRawRecord(rawRec)
+	require.True(t, res)
 	require.NoError(t, err)
 	require.Equal(t, aclrecordproto.ACLUserPermissions_Writer, aclList.ACLState().UserStates()[identity].Permissions)
 }
@@ -89,7 +89,3 @@ func TestAclList_ACLState_UserJoinAndRemove(t *testing.T) {
 	_, err = aclList.ACLState().PermissionsAtRecord(records[3].Id, idB)
 	assert.Error(t, err, "B should have no permissions at record 3, because user should be removed")
 }
-
-func TestAclList_AddRawRecord(t *testing.T) {
-
-}
@@ -5,7 +5,6 @@
 package mock_list
 
 import (
-	context "context"
 	reflect "reflect"
 
 	aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
@@ -50,18 +49,19 @@ func (mr *MockACLListMockRecorder) ACLState() *gomock.Call {
 	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ACLState", reflect.TypeOf((*MockACLList)(nil).ACLState))
 }
 
-// AddRawRecords mocks base method.
-func (m *MockACLList) AddRawRecords(arg0 context.Context, arg1 []*aclrecordproto.RawACLRecordWithId) error {
+// AddRawRecord mocks base method.
+func (m *MockACLList) AddRawRecord(arg0 *aclrecordproto.RawACLRecordWithId) (bool, error) {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "AddRawRecords", arg0, arg1)
-	ret0, _ := ret[0].(error)
-	return ret0
+	ret := m.ctrl.Call(m, "AddRawRecord", arg0)
+	ret0, _ := ret[0].(bool)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
 }
 
-// AddRawRecords indicates an expected call of AddRawRecords.
-func (mr *MockACLListMockRecorder) AddRawRecords(arg0, arg1 interface{}) *gomock.Call {
+// AddRawRecord indicates an expected call of AddRawRecord.
+func (mr *MockACLListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockACLList)(nil).AddRawRecords), arg0, arg1)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockACLList)(nil).AddRawRecord), arg0)
 }
 
 // Close mocks base method.
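For tests built against this regenerated mock, a hedged sketch of the updated expectation shape (standard mockgen/gomock boilerplate is assumed, including the NewMockACLList constructor and the github.com/golang/mock/gomock import path; the record value is illustrative):

package mock_list

import (
	"testing"

	"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto"
	"github.com/golang/mock/gomock"
	"github.com/stretchr/testify/require"
)

func TestAddRawRecordExpectation(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	aclListMock := NewMockACLList(ctrl)
	// AddRawRecord now takes a single raw record and returns (bool, error).
	aclListMock.EXPECT().
		AddRawRecord(gomock.Any()).
		Return(true, nil)

	added, err := aclListMock.AddRawRecord(&aclrecordproto.RawACLRecordWithId{})
	require.True(t, added)
	require.NoError(t, err)
}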