peer outgoing proto handshake + test multiconn + streampool tests

Sergey Cherepanov 2023-06-05 20:39:09 +02:00
parent 96768adaae
commit 0c0a501aad
7 changed files with 97 additions and 171 deletions


@@ -38,6 +38,7 @@ func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
 type Peer interface {
 	Id() string
+	Context() context.Context
 	AcquireDrpcConn(ctx context.Context) (drpc.Conn, error)
 	ReleaseDrpcConn(conn drpc.Conn)
@@ -70,19 +71,20 @@ func (p *peer) Id() string {
 func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
 	p.mu.Lock()
-	defer p.mu.Unlock()
 	if len(p.inactive) == 0 {
-		conn, err := p.Open(ctx)
+		p.mu.Unlock()
+		dconn, err := p.openDrpcConn(ctx)
 		if err != nil {
 			return nil, err
 		}
-		dconn := drpcconn.New(conn)
+		p.mu.Lock()
 		p.inactive = append(p.inactive, dconn)
 	}
 	idx := len(p.inactive) - 1
 	res := p.inactive[idx]
 	p.inactive = p.inactive[:idx]
 	p.active[res] = struct{}{}
+	p.mu.Unlock()
 	return res, nil
 }
@@ -105,6 +107,18 @@ func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error
 	return do(conn)
 }
 
+func (p *peer) openDrpcConn(ctx context.Context) (dconn drpc.Conn, err error) {
+	conn, err := p.Open(ctx)
+	if err != nil {
+		return nil, err
+	}
+	if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil {
+		return nil, err
+	}
+	dconn = drpcconn.New(conn)
+	return
+}
+
 func (p *peer) acceptLoop() {
 	var exitErr error
 	defer func() {
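
Caller-side view of the change, as a minimal sketch (illustrative code, not part of this commit; the package and helper name are hypothetical): with the outgoing proto handshake folded into openDrpcConn, every drpc.Conn handed out by AcquireDrpcConn has already negotiated ProtoType_DRPC, so callers keep the same acquire/use/release pattern that DoDrpc wraps.

package example

import (
	"context"

	"github.com/anyproto/any-sync/net/peer"
	"storj.io/drpc"
)

// withDrpcConn is an illustrative sketch, not part of this commit.
// It mirrors the acquire/use/release pattern that peer.DoDrpc wraps: the conn
// returned by AcquireDrpcConn is a sub-connection that has already completed
// the outgoing proto handshake, so the callback can issue drpc requests on it.
func withDrpcConn(ctx context.Context, p peer.Peer, do func(conn drpc.Conn) error) error {
	conn, err := p.AcquireDrpcConn(ctx)
	if err != nil {
		return err
	}
	// return the conn to the peer's inactive pool when done
	defer p.ReleaseDrpcConn(conn)
	return do(conn)
}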

net/rpc/rpctest/peer.go (new file)

@@ -0,0 +1,30 @@
+package rpctest
+
+import (
+	"context"
+	"github.com/anyproto/any-sync/net/connutil"
+	"github.com/anyproto/any-sync/net/peer"
+	"github.com/anyproto/any-sync/net/transport"
+	yamux2 "github.com/anyproto/any-sync/net/transport/yamux"
+	"github.com/hashicorp/yamux"
+	"net"
+)
+
+func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) {
+	sc, cc := net.Pipe()
+	var servConn = make(chan transport.MultiConn, 1)
+	go func() {
+		sess, err := yamux.Server(sc, yamux.DefaultConfig())
+		if err != nil {
+			panic(err)
+		}
+		servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess)
+	}()
+	sess, err := yamux.Client(cc, yamux.DefaultConfig())
+	if err != nil {
+		panic(err)
+	}
+	client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess)
+	serv = <-servConn
+	return
+}
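
A hedged sketch of driving the in-memory pair directly (illustrative code, not part of this commit; it assumes transport.MultiConn exposes Open(ctx) and Accept() for sub-connections, which is how the peer code in this commit uses it):

package example

import (
	"context"
	"io"

	"github.com/anyproto/any-sync/net/rpc/rpctest"
)

// roundTrip is an illustrative sketch, not part of this commit.
// It opens a sub-connection from the client MultiConn and accepts it on the
// server MultiConn, sending a tiny payload across the in-memory pipe.
func roundTrip() (string, error) {
	serv, client := rpctest.MultiConnPair("p1server", "p1")
	defer serv.Close()
	defer client.Close()

	go func() {
		// client side: open a yamux sub-connection and write a payload
		sub, err := client.Open(context.Background())
		if err != nil {
			return
		}
		defer sub.Close()
		sub.Write([]byte("ping"))
	}()

	// server side: accept the sub-connection and read the payload back
	sub, err := serv.Accept()
	if err != nil {
		return "", err
	}
	defer sub.Close()
	buf := make([]byte, 4)
	if _, err = io.ReadFull(sub, buf); err != nil {
		return "", err
	}
	return string(buf), nil // "ping"
}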


@ -1,122 +0,0 @@
package rpctest
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"math/rand"
"storj.io/drpc"
"sync"
"time"
)
var ErrCantConnect = errors.New("can't connect to test server")
func NewTestPool() *TestPool {
return &TestPool{
peers: map[string]peer.Peer{},
}
}
type TestPool struct {
ts *TesServer
peers map[string]peer.Peer
mu sync.Mutex
}
func (t *TestPool) WithServer(ts *TesServer) *TestPool {
t.mu.Lock()
defer t.mu.Unlock()
t.ts = ts
return t
}
func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
if p, ok := t.peers[id]; ok {
return p, nil
}
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
return t.Get(ctx, id)
}
func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
for _, peerId := range peerIds {
if p, ok := t.peers[peerId]; ok {
return p, nil
}
}
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) NewPool(name string) pool.Pool {
return t
}
func (t *TestPool) AddPeer(p peer.Peer) {
t.mu.Lock()
defer t.mu.Unlock()
t.peers[p.Id()] = p
}
func (t *TestPool) Init(a *app.App) (err error) {
return nil
}
func (t *TestPool) Name() (name string) {
return pool.CName
}
func (t *TestPool) Run(ctx context.Context) (err error) {
return nil
}
func (t *TestPool) Close(ctx context.Context) (err error) {
return nil
}
type testPeer struct {
id string
drpc.Conn
}
func (t testPeer) Addr() string {
return ""
}
func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, t.Close()
}
func (t testPeer) Id() string {
return t.id
}
func (t testPeer) LastUsage() time.Time {
return time.Now()
}
func (t testPeer) UpdateLastUsage() {}


@@ -5,43 +5,39 @@ import (
 	"github.com/anyproto/any-sync/app"
 	"github.com/anyproto/any-sync/net/rpc/server"
 	"net"
-	"storj.io/drpc"
-	"storj.io/drpc/drpcconn"
 	"storj.io/drpc/drpcmux"
 	"storj.io/drpc/drpcserver"
 )
 
-func NewTestServer() *TesServer {
-	ts := &TesServer{
+func NewTestServer() *TestServer {
+	ts := &TestServer{
 		Mux: drpcmux.New(),
 	}
 	ts.Server = drpcserver.New(ts.Mux)
 	return ts
 }
 
-type TesServer struct {
+type TestServer struct {
 	*drpcmux.Mux
 	*drpcserver.Server
 }
 
-func (ts *TesServer) Init(a *app.App) (err error) {
+func (ts *TestServer) Init(a *app.App) (err error) {
 	return nil
 }
 
-func (ts *TesServer) Name() (name string) {
+func (ts *TestServer) Name() (name string) {
 	return server.CName
 }
 
-func (ts *TesServer) Run(ctx context.Context) (err error) {
+func (ts *TestServer) Run(ctx context.Context) (err error) {
 	return nil
 }
 
-func (ts *TesServer) Close(ctx context.Context) (err error) {
+func (ts *TestServer) Close(ctx context.Context) (err error) {
 	return nil
 }
 
-func (ts *TesServer) Dial(ctx context.Context) drpc.Conn {
-	sc, cc := net.Pipe()
-	go ts.Server.ServeOne(ctx, sc)
-	return drpcconn.New(cc)
+func (s *TestServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
+	return s.Server.ServeOne(ctx, conn)
 }
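
Tests that previously relied on the removed Dial helper can rebuild it on top of the new ServeConn; a minimal sketch (illustrative code, not part of this commit; the package and helper name are hypothetical):

package example

import (
	"context"
	"net"

	"github.com/anyproto/any-sync/net/rpc/rpctest"
	"storj.io/drpc/drpcconn"
)

// dialTestServer is an illustrative sketch, not part of this commit.
// It serves one end of an in-memory pipe with the registered drpc handlers and
// returns the other end wrapped as a client conn, which is what the removed
// Dial helper used to do.
func dialTestServer(ctx context.Context, ts *rpctest.TestServer) *drpcconn.Conn {
	sc, cc := net.Pipe()
	go ts.ServeConn(ctx, sc) // serve errors are ignored in this sketch
	return drpcconn.New(cc)
}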


@@ -18,17 +18,25 @@ import (
 var ctx = context.Background()
 
+func makePeerPair(t *testing.T, fx *fixture, peerId string) (pS, pC peer.Peer) {
+	mcS, mcC := rpctest.MultiConnPair(peerId+"server", peerId)
+	pS, err := peer.NewPeer(mcS, fx.ts)
+	require.NoError(t, err)
+	pC, err = peer.NewPeer(mcC, fx.ts)
+	require.NoError(t, err)
+	return
+}
+
 func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
-	p, err := fx.tp.Dial(ctx, peerId)
+	_, pC := makePeerPair(t, fx, peerId)
+	drpcConn, err := pC.AcquireDrpcConn(ctx)
 	require.NoError(t, err)
-	ctx = peer.CtxWithPeerId(ctx, peerId)
-	s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
+	st, err = testservice.NewDRPCTestClient(drpcConn).TestStream(pC.Context())
 	require.NoError(t, err)
-	return s, p
+	return st, pC
 }
 
 func TestStreamPool_AddStream(t *testing.T) {
 	t.Run("broadcast incoming", func(t *testing.T) {
 		fx := newFixture(t)
 		defer fx.Finish(t)
@@ -85,11 +93,10 @@ func TestStreamPool_Send(t *testing.T) {
 		fx := newFixture(t)
 		defer fx.Finish(t)
 
-		p, err := fx.tp.Dial(ctx, "p1")
-		require.NoError(t, err)
+		pS, _ := makePeerPair(t, fx, "p1")
 
 		require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
-			return []peer.Peer{p}, nil
+			return []peer.Peer{pS}, nil
 		}))
 
 		var msg *testservice.StreamMessage
@@ -100,12 +107,12 @@ func TestStreamPool_Send(t *testing.T) {
 		}
 		assert.Equal(t, "should open stream", msg.ReqData)
 	})
 	t.Run("parallel open stream", func(t *testing.T) {
 		fx := newFixture(t)
 		defer fx.Finish(t)
 
-		p, err := fx.tp.Dial(ctx, "p1")
-		require.NoError(t, err)
+		pS, _ := makePeerPair(t, fx, "p1")
 
 		fx.th.streamOpenDelay = time.Second / 3
@@ -113,7 +120,7 @@ func TestStreamPool_Send(t *testing.T) {
 		for i := 0; i < numMsgs; i++ {
 			go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
-				return []peer.Peer{p}, nil
+				return []peer.Peer{pS}, nil
 			}))
 		}
@@ -134,9 +141,8 @@ func TestStreamPool_Send(t *testing.T) {
 		fx := newFixture(t)
 		defer fx.Finish(t)
 
-		p, err := fx.tp.Dial(ctx, "p1")
-		require.NoError(t, err)
-		_ = p.Close()
+		pS, _ := makePeerPair(t, fx, "p1")
+		_ = pS.Close()
 
 		fx.th.streamOpenDelay = time.Second / 3
@@ -147,11 +153,12 @@ func TestStreamPool_Send(t *testing.T) {
 			wg.Add(1)
 			go func() {
 				defer wg.Done()
-				assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"}))
+				assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, pS, &testservice.StreamMessage{ReqData: "should open stream"}))
 			}()
 		}
 		wg.Wait()
 	})
 }
 
 func TestStreamPool_SendById(t *testing.T) {
@@ -196,10 +203,9 @@ func TestStreamPool_Tags(t *testing.T) {
 
 func newFixture(t *testing.T) *fixture {
 	fx := &fixture{}
-	ts := rpctest.NewTestServer()
+	fx.ts = rpctest.NewTestServer()
 	fx.tsh = &testServerHandler{receiveCh: make(chan *testservice.StreamMessage, 100)}
-	require.NoError(t, testservice.DRPCRegisterTest(ts, fx.tsh))
-	fx.tp = rpctest.NewTestPool().WithServer(ts)
+	require.NoError(t, testservice.DRPCRegisterTest(fx.ts, fx.tsh))
 	fx.th = &testHandler{}
 	fx.StreamPool = New().NewStreamPool(fx.th, StreamConfig{
 		SendQueueSize: 10,
@@ -211,14 +217,13 @@ func newFixture(t *testing.T) *fixture {
 
 type fixture struct {
 	StreamPool
-	tp  *rpctest.TestPool
 	th  *testHandler
 	tsh *testServerHandler
+	ts  *rpctest.TestServer
 }
 
 func (fx *fixture) Finish(t *testing.T) {
 	require.NoError(t, fx.Close())
-	require.NoError(t, fx.tp.Close(ctx))
 }
 
 type testHandler struct {
@@ -231,7 +236,11 @@ func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc.
 	if t.streamOpenDelay > 0 {
 		time.Sleep(t.streamOpenDelay)
 	}
-	stream, err = testservice.NewDRPCTestClient(p).TestStream(ctx)
+	conn, err := p.AcquireDrpcConn(ctx)
+	if err != nil {
+		return
+	}
+	stream, err = testservice.NewDRPCTestClient(conn).TestStream(p.Context())
 	return
 }


@@ -9,6 +9,15 @@ import (
 	"time"
 )
 
+func NewMultiConn(cctx context.Context, luConn *connutil.LastUsageConn, addr string, sess *yamux.Session) transport.MultiConn {
+	return &yamuxConn{
+		ctx:     cctx,
+		luConn:  luConn,
+		addr:    addr,
+		Session: sess,
+	}
+}
+
 type yamuxConn struct {
 	ctx    context.Context
 	luConn *connutil.LastUsageConn


@@ -96,12 +96,7 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
 	if err != nil {
 		return
 	}
-	mc = &yamuxConn{
-		ctx:     cctx,
-		luConn:  luc,
-		Session: sess,
-		addr:    addr,
-	}
+	mc = NewMultiConn(cctx, luc, addr, sess)
 	return
 }
@@ -148,12 +143,7 @@ func (y *yamuxTransport) accept(conn net.Conn) {
 		log.Warn("incoming connection yamux session error", zap.Error(err))
 		return
 	}
-	mc := &yamuxConn{
-		ctx:     cctx,
-		luConn:  luc,
-		Session: sess,
-		addr:    conn.RemoteAddr().String(),
-	}
+	mc := NewMultiConn(cctx, luc, conn.RemoteAddr().String(), sess)
 	if err = y.accepter.Accept(mc); err != nil {
 		log.Warn("connection accept error", zap.Error(err))
 	}