Test stream pool
This commit is contained in:
parent
70f3929f44
commit
ddae2be2a1
@ -5,30 +5,34 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/ocache"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"storj.io/drpc/drpcctx"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrEmptyPeer = errors.New("don't have such a peer")
|
||||
var ErrStreamClosed = errors.New("stream is already closed")
|
||||
|
||||
const maxSimultaneousOperationsPerStream = 10
|
||||
var maxSimultaneousOperationsPerStream = 10
|
||||
var syncWaitPeriod = 2 * time.Second
|
||||
|
||||
var ErrSyncTimeout = errors.New("too long wait on sync receive")
|
||||
|
||||
// StreamPool can be made generic to work with different streams
|
||||
type StreamPool interface {
|
||||
Sender
|
||||
ocache.ObjectLastUsage
|
||||
AddAndReadStreamSync(stream spacesyncproto.SpaceStream) (err error)
|
||||
AddAndReadStreamAsync(stream spacesyncproto.SpaceStream)
|
||||
HasActiveStream(peerId string) bool
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type Sender interface {
|
||||
SendSync(peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
|
||||
SendAsync(peers []string, message *spacesyncproto.ObjectSyncMessage) (err error)
|
||||
BroadcastAsync(message *spacesyncproto.ObjectSyncMessage) (err error)
|
||||
|
||||
HasActiveStream(peerId string) bool
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type MessageHandler func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error)
|
||||
@ -44,15 +48,23 @@ type streamPool struct {
|
||||
wg *sync.WaitGroup
|
||||
waiters map[string]responseWaiter
|
||||
waitersMx sync.Mutex
|
||||
counter uint64
|
||||
counter atomic.Uint64
|
||||
lastUsage atomic.Int64
|
||||
}
|
||||
|
||||
func newStreamPool(messageHandler MessageHandler) StreamPool {
|
||||
return &streamPool{
|
||||
s := &streamPool{
|
||||
peerStreams: make(map[string]spacesyncproto.SpaceStream),
|
||||
messageHandler: messageHandler,
|
||||
waiters: make(map[string]responseWaiter),
|
||||
wg: &sync.WaitGroup{},
|
||||
}
|
||||
s.lastUsage.Store(time.Now().Unix())
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *streamPool) LastUsage() time.Time {
|
||||
return time.Unix(s.lastUsage.Load(), 0)
|
||||
}
|
||||
|
||||
func (s *streamPool) HasActiveStream(peerId string) (res bool) {
|
||||
@ -65,26 +77,39 @@ func (s *streamPool) HasActiveStream(peerId string) (res bool) {
|
||||
func (s *streamPool) SendSync(
|
||||
peerId string,
|
||||
msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
|
||||
newCounter := atomic.AddUint64(&s.counter, 1)
|
||||
msg.TrackingId = genStreamPoolKey(peerId, msg.TreeId, newCounter)
|
||||
newCounter := s.counter.Add(1)
|
||||
msg.ReplyId = genStreamPoolKey(peerId, msg.ObjectId, newCounter)
|
||||
|
||||
s.waitersMx.Lock()
|
||||
waiter := responseWaiter{
|
||||
ch: make(chan *spacesyncproto.ObjectSyncMessage),
|
||||
ch: make(chan *spacesyncproto.ObjectSyncMessage, 1),
|
||||
}
|
||||
s.waiters[msg.TrackingId] = waiter
|
||||
s.waiters[msg.ReplyId] = waiter
|
||||
s.waitersMx.Unlock()
|
||||
|
||||
err = s.SendAsync([]string{peerId}, msg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
delay := time.NewTimer(syncWaitPeriod)
|
||||
select {
|
||||
case <-delay.C:
|
||||
s.waitersMx.Lock()
|
||||
delete(s.waiters, msg.ReplyId)
|
||||
s.waitersMx.Unlock()
|
||||
|
||||
reply = <-waiter.ch
|
||||
log.With("replyId", msg.ReplyId).Error("time elapsed when waiting")
|
||||
err = ErrSyncTimeout
|
||||
case reply = <-waiter.ch:
|
||||
if !delay.Stop() {
|
||||
<-delay.C
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamPool) SendAsync(peers []string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
s.lastUsage.Store(time.Now().Unix())
|
||||
getStreams := func() (streams []spacesyncproto.SpaceStream) {
|
||||
for _, pId := range peers {
|
||||
stream, err := s.getOrDeleteStream(pId)
|
||||
@ -100,10 +125,13 @@ func (s *streamPool) SendAsync(peers []string, message *spacesyncproto.ObjectSyn
|
||||
streams := getStreams()
|
||||
s.Unlock()
|
||||
|
||||
log.With("objectId", message.ObjectId).
|
||||
Debugf("sending message to %d peers", len(streams))
|
||||
for _, s := range streams {
|
||||
if len(peers) == 1 {
|
||||
err = s.Send(message)
|
||||
}
|
||||
if len(peers) != 1 {
|
||||
err = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -145,6 +173,8 @@ Loop:
|
||||
|
||||
func (s *streamPool) BroadcastAsync(message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
streams := s.getAllStreams()
|
||||
log.With("objectId", message.ObjectId).
|
||||
Debugf("broadcasting message to %d peers", len(streams))
|
||||
for _, stream := range streams {
|
||||
if err = stream.Send(message); err != nil {
|
||||
// TODO: add logging
|
||||
@ -191,28 +221,33 @@ func (s *streamPool) readPeerLoop(peerId string, stream spacesyncproto.SpaceStre
|
||||
}
|
||||
|
||||
process := func(msg *spacesyncproto.ObjectSyncMessage) {
|
||||
if msg.TrackingId == "" {
|
||||
s.lastUsage.Store(time.Now().Unix())
|
||||
if msg.ReplyId == "" {
|
||||
s.messageHandler(stream.Context(), peerId, msg)
|
||||
return
|
||||
}
|
||||
|
||||
log.With("replyId", msg.ReplyId).Debug("getting message with reply id")
|
||||
s.waitersMx.Lock()
|
||||
waiter, exists := s.waiters[msg.TrackingId]
|
||||
waiter, exists := s.waiters[msg.ReplyId]
|
||||
|
||||
if !exists {
|
||||
log.With("replyId", msg.ReplyId).Debug("reply id not exists")
|
||||
s.waitersMx.Unlock()
|
||||
s.messageHandler(stream.Context(), peerId, msg)
|
||||
return
|
||||
}
|
||||
log.With("replyId", msg.ReplyId).Debug("reply id exists")
|
||||
|
||||
delete(s.waiters, msg.TrackingId)
|
||||
delete(s.waiters, msg.ReplyId)
|
||||
s.waitersMx.Unlock()
|
||||
waiter.ch <- msg
|
||||
}
|
||||
|
||||
Loop:
|
||||
for {
|
||||
msg, err := stream.Recv()
|
||||
var msg *spacesyncproto.ObjectSyncMessage
|
||||
msg, err = stream.Recv()
|
||||
s.lastUsage.Store(time.Now().Unix())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
@ -226,7 +261,8 @@ Loop:
|
||||
limiter <- struct{}{}
|
||||
}()
|
||||
}
|
||||
return s.removePeer(peerId)
|
||||
s.removePeer(peerId)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamPool) removePeer(peerId string) (err error) {
|
||||
|
||||
339
common/commonspace/syncservice/streampool_test.go
Normal file
339
common/commonspace/syncservice/streampool_test.go
Normal file
@ -0,0 +1,339 @@
|
||||
package syncservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/common/net/rpc/rpctest"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/consensus/consensusproto"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/stretchr/testify/require"
|
||||
"storj.io/drpc"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testPeer struct {
|
||||
id string
|
||||
drpc.Conn
|
||||
}
|
||||
|
||||
func (t testPeer) Id() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t testPeer) LastUsage() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (t testPeer) UpdateLastUsage() {}
|
||||
|
||||
type testServer struct {
|
||||
stream chan spacesyncproto.DRPCSpace_StreamStream
|
||||
addLog func(ctx context.Context, req *consensusproto.AddLogRequest) error
|
||||
addRecord func(ctx context.Context, req *consensusproto.AddRecordRequest) error
|
||||
releaseStream chan error
|
||||
watchErrOnce bool
|
||||
}
|
||||
|
||||
func (t *testServer) HeadSync(ctx context.Context, request *spacesyncproto.HeadSyncRequest) (*spacesyncproto.HeadSyncResponse, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *testServer) PushSpace(ctx context.Context, request *spacesyncproto.PushSpaceRequest) (*spacesyncproto.PushSpaceResponse, error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *testServer) Stream(stream spacesyncproto.DRPCSpace_StreamStream) error {
|
||||
t.stream <- stream
|
||||
return <-t.releaseStream
|
||||
}
|
||||
|
||||
func (t *testServer) waitStream(test *testing.T) spacesyncproto.DRPCSpace_StreamStream {
|
||||
select {
|
||||
case <-time.After(time.Second * 5):
|
||||
test.Fatalf("waiteStream timeout")
|
||||
case st := <-t.stream:
|
||||
return st
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type fixture struct {
|
||||
testServer *testServer
|
||||
drpcTS *rpctest.TesServer
|
||||
client spacesyncproto.DRPCSpaceClient
|
||||
clientStream spacesyncproto.DRPCSpace_StreamStream
|
||||
serverStream spacesyncproto.DRPCSpace_StreamStream
|
||||
pool *streamPool
|
||||
localId peer.ID
|
||||
remoteId peer.ID
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T, localId, remoteId peer.ID, handler MessageHandler) *fixture {
|
||||
fx := &fixture{
|
||||
testServer: &testServer{},
|
||||
drpcTS: rpctest.NewTestServer(),
|
||||
localId: localId,
|
||||
remoteId: remoteId,
|
||||
}
|
||||
fx.testServer.stream = make(chan spacesyncproto.DRPCSpace_StreamStream, 1)
|
||||
require.NoError(t, spacesyncproto.DRPCRegisterSpace(fx.drpcTS.Mux, fx.testServer))
|
||||
clientWrapper := rpctest.NewSecConnWrapper(nil, nil, localId, remoteId)
|
||||
p := &testPeer{id: localId.String(), Conn: fx.drpcTS.DialWrapConn(nil, clientWrapper)}
|
||||
fx.client = spacesyncproto.NewDRPCSpaceClient(p)
|
||||
|
||||
var err error
|
||||
fx.clientStream, err = fx.client.Stream(context.Background())
|
||||
require.NoError(t, err)
|
||||
fx.serverStream = fx.testServer.waitStream(t)
|
||||
fx.pool = newStreamPool(handler).(*streamPool)
|
||||
|
||||
return fx
|
||||
}
|
||||
|
||||
func (fx *fixture) run(t *testing.T) chan error {
|
||||
waitCh := make(chan error)
|
||||
go func() {
|
||||
err := fx.pool.AddAndReadStreamSync(fx.clientStream)
|
||||
waitCh <- err
|
||||
}()
|
||||
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
fx.pool.Lock()
|
||||
require.Equal(t, fx.pool.peerStreams[fx.remoteId.String()], fx.clientStream)
|
||||
fx.pool.Unlock()
|
||||
|
||||
return waitCh
|
||||
}
|
||||
|
||||
func TestStreamPool_AddAndReadStreamAsync(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
|
||||
t.Run("client close", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
waitCh := fx.run(t)
|
||||
|
||||
err := fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
t.Run("server close", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
waitCh := fx.run(t)
|
||||
|
||||
err := fx.serverStream.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = <-waitCh
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_Close(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
|
||||
t.Run("client close", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
fx.run(t)
|
||||
var events []string
|
||||
recvChan := make(chan struct{})
|
||||
go func() {
|
||||
fx.pool.Close()
|
||||
events = append(events, "pool_close")
|
||||
recvChan <- struct{}{}
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond) //err = <-waitCh
|
||||
events = append(events, "stream_close")
|
||||
err := fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
<-recvChan
|
||||
require.Equal(t, []string{"stream_close", "pool_close"}, events)
|
||||
})
|
||||
t.Run("server close", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
fx.run(t)
|
||||
var events []string
|
||||
recvChan := make(chan struct{})
|
||||
go func() {
|
||||
fx.pool.Close()
|
||||
events = append(events, "pool_close")
|
||||
recvChan <- struct{}{}
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond) //err = <-waitCh
|
||||
events = append(events, "stream_close")
|
||||
err := fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
<-recvChan
|
||||
require.Equal(t, []string{"stream_close", "pool_close"}, events)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_ReceiveMessage(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
t.Run("pool receive message from server", func(t *testing.T) {
|
||||
objectId := "objectId"
|
||||
msg := &spacesyncproto.ObjectSyncMessage{
|
||||
ObjectId: objectId,
|
||||
}
|
||||
recvChan := make(chan struct{})
|
||||
fx := newFixture(t, "", remId, func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
require.Equal(t, msg, message)
|
||||
recvChan <- struct{}{}
|
||||
return nil
|
||||
})
|
||||
waitCh := fx.run(t)
|
||||
|
||||
err := fx.serverStream.Send(msg)
|
||||
require.NoError(t, err)
|
||||
<-recvChan
|
||||
err = fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_HasActiveStream(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
t.Run("pool has active stream", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
waitCh := fx.run(t)
|
||||
require.True(t, fx.pool.HasActiveStream(remId.String()))
|
||||
|
||||
err := fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
t.Run("pool has no active stream", func(t *testing.T) {
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
waitCh := fx.run(t)
|
||||
err := fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
require.Error(t, err)
|
||||
require.False(t, fx.pool.HasActiveStream(remId.String()))
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_SendAsync(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
t.Run("pool send async to server", func(t *testing.T) {
|
||||
objectId := "objectId"
|
||||
msg := &spacesyncproto.ObjectSyncMessage{
|
||||
ObjectId: objectId,
|
||||
}
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
recvChan := make(chan struct{})
|
||||
go func() {
|
||||
message, err := fx.serverStream.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, message)
|
||||
recvChan <- struct{}{}
|
||||
}()
|
||||
waitCh := fx.run(t)
|
||||
|
||||
err := fx.pool.SendAsync([]string{remId.String()}, msg)
|
||||
require.NoError(t, err)
|
||||
<-recvChan
|
||||
err = fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_SendSync(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
t.Run("pool send sync to server", func(t *testing.T) {
|
||||
objectId := "objectId"
|
||||
payload := []byte("payload")
|
||||
msg := &spacesyncproto.ObjectSyncMessage{
|
||||
ObjectId: objectId,
|
||||
}
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
go func() {
|
||||
message, err := fx.serverStream.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg.ObjectId, message.ObjectId)
|
||||
require.NotEmpty(t, message.ReplyId)
|
||||
message.Payload = payload
|
||||
err = fx.serverStream.Send(message)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
waitCh := fx.run(t)
|
||||
res, err := fx.pool.SendSync(remId.String(), msg)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, payload, res.Payload)
|
||||
err = fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
|
||||
t.Run("pool send sync timeout", func(t *testing.T) {
|
||||
objectId := "objectId"
|
||||
msg := &spacesyncproto.ObjectSyncMessage{
|
||||
ObjectId: objectId,
|
||||
}
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
syncWaitPeriod = time.Millisecond * 30
|
||||
go func() {
|
||||
message, err := fx.serverStream.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg.ObjectId, message.ObjectId)
|
||||
require.NotEmpty(t, message.ReplyId)
|
||||
}()
|
||||
waitCh := fx.run(t)
|
||||
_, err := fx.pool.SendSync(remId.String(), msg)
|
||||
require.Equal(t, ErrSyncTimeout, err)
|
||||
err = fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
|
||||
func TestStreamPool_BroadcastAsync(t *testing.T) {
|
||||
remId := peer.ID("remoteId")
|
||||
t.Run("pool broadcast async to server", func(t *testing.T) {
|
||||
objectId := "objectId"
|
||||
msg := &spacesyncproto.ObjectSyncMessage{
|
||||
ObjectId: objectId,
|
||||
}
|
||||
fx := newFixture(t, "", remId, nil)
|
||||
recvChan := make(chan struct{})
|
||||
go func() {
|
||||
message, err := fx.serverStream.Recv()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, msg, message)
|
||||
recvChan <- struct{}{}
|
||||
}()
|
||||
waitCh := fx.run(t)
|
||||
|
||||
err := fx.pool.BroadcastAsync(msg)
|
||||
require.NoError(t, err)
|
||||
<-recvChan
|
||||
err = fx.clientStream.Close()
|
||||
require.NoError(t, err)
|
||||
err = <-waitCh
|
||||
|
||||
require.Error(t, err)
|
||||
require.Nil(t, fx.pool.peerStreams[remId.String()])
|
||||
})
|
||||
}
|
||||
@ -104,7 +104,7 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) {
|
||||
stream, err := s.clientFactory.Client(peer).Stream(ctx)
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
log.With("spaceId", s.spaceId).Errorf("failed to open stream: %v", err)
|
||||
log.With("spaceId", s.spaceId).Errorf("failed to open clientStream: %v", err)
|
||||
// so here probably the request is failed because there is no such space,
|
||||
// but diffService should handle such cases by sending pushSpace
|
||||
continue
|
||||
@ -113,7 +113,7 @@ func (s *syncService) responsibleStreamCheckLoop(ctx context.Context) {
|
||||
err = stream.Send(&spacesyncproto.ObjectSyncMessage{SpaceId: s.spaceId})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
log.With("spaceId", s.spaceId).Errorf("failed to send first message to stream: %v", err)
|
||||
log.With("spaceId", s.spaceId).Errorf("failed to send first message to clientStream: %v", err)
|
||||
continue
|
||||
}
|
||||
s.streamPool.AddAndReadStreamAsync(stream)
|
||||
|
||||
@ -2,6 +2,8 @@ package rpctest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
@ -9,6 +11,48 @@ import (
|
||||
"storj.io/drpc/drpcserver"
|
||||
)
|
||||
|
||||
type SecConnMock struct {
|
||||
net.Conn
|
||||
localPrivKey crypto.PrivKey
|
||||
remotePubKey crypto.PubKey
|
||||
localId peer.ID
|
||||
remoteId peer.ID
|
||||
}
|
||||
|
||||
func (s *SecConnMock) LocalPeer() peer.ID {
|
||||
return s.localId
|
||||
}
|
||||
|
||||
func (s *SecConnMock) LocalPrivateKey() crypto.PrivKey {
|
||||
return s.localPrivKey
|
||||
}
|
||||
|
||||
func (s *SecConnMock) RemotePeer() peer.ID {
|
||||
return s.remoteId
|
||||
}
|
||||
|
||||
func (s *SecConnMock) RemotePublicKey() crypto.PubKey {
|
||||
return s.remotePubKey
|
||||
}
|
||||
|
||||
type ConnWrapper func(conn net.Conn) net.Conn
|
||||
|
||||
func NewSecConnWrapper(
|
||||
localPrivKey crypto.PrivKey,
|
||||
remotePubKey crypto.PubKey,
|
||||
localId peer.ID,
|
||||
remoteId peer.ID) ConnWrapper {
|
||||
return func(conn net.Conn) net.Conn {
|
||||
return &SecConnMock{
|
||||
Conn: conn,
|
||||
localPrivKey: localPrivKey,
|
||||
remotePubKey: remotePubKey,
|
||||
localId: localId,
|
||||
remoteId: remoteId,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func NewTestServer() *TesServer {
|
||||
ts := &TesServer{
|
||||
Mux: drpcmux.New(),
|
||||
@ -23,7 +67,17 @@ type TesServer struct {
|
||||
}
|
||||
|
||||
func (ts *TesServer) Dial() drpc.Conn {
|
||||
return ts.DialWrapConn(nil, nil)
|
||||
}
|
||||
|
||||
func (ts *TesServer) DialWrapConn(serverWrapper ConnWrapper, clientWrapper ConnWrapper) drpc.Conn {
|
||||
sc, cc := net.Pipe()
|
||||
if serverWrapper != nil {
|
||||
sc = serverWrapper(sc)
|
||||
}
|
||||
if clientWrapper != nil {
|
||||
cc = clientWrapper(cc)
|
||||
}
|
||||
go ts.Server.ServeOne(context.Background(), sc)
|
||||
return drpcconn.New(cc)
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user