diff --git a/commonfile/fileproto/file_drpc.pb.go b/commonfile/fileproto/file_drpc.pb.go index 2f9ee69d..a03c22cd 100644 --- a/commonfile/fileproto/file_drpc.pb.go +++ b/commonfile/fileproto/file_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonfile/fileproto/protos/file.proto package fileproto diff --git a/commonspace/spacesyncproto/spacesync_drpc.pb.go b/commonspace/spacesyncproto/spacesync_drpc.pb.go index f9c7abae..11e5d715 100644 --- a/commonspace/spacesyncproto/spacesync_drpc.pb.go +++ b/commonspace/spacesyncproto/spacesync_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonspace/spacesyncproto/protos/spacesync.proto package spacesyncproto @@ -103,6 +103,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct { drpc.Stream } +func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error { return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) } diff --git a/coordinator/coordinatorclient/coordinatorclient.go b/coordinator/coordinatorclient/coordinatorclient.go index 4847ade7..b804d9a5 100644 --- a/coordinator/coordinatorclient/coordinatorclient.go +++ b/coordinator/coordinatorclient/coordinatorclient.go @@ -10,6 +10,7 @@ import ( "github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/util/crypto" + "storj.io/drpc" ) const CName = "common.coordinator.coordinatorclient" @@ -39,42 +40,8 @@ type coordinatorClient struct { nodeConf nodeconf.Service } -func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{ - SpaceId: spaceId, - DeletionChangeId: deleteRaw.GetId(), - DeletionChangePayload: deleteRaw.GetRawChange(), - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - status = resp.Payload - return -} - -func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{ - SpaceId: spaceId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - status = resp.Payload - return -} - func (c *coordinatorClient) Init(a *app.App) (err error) { - c.pool = a.MustComponent(pool.CName).(pool.Service).NewPool(CName) + c.pool = a.MustComponent(pool.CName).(pool.Service) c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service) return } @@ -83,8 +50,37 @@ func (c *coordinatorClient) Name() (name string) { return CName } +func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{ + SpaceId: spaceId, + DeletionChangeId: deleteRaw.GetId(), + DeletionChangePayload: deleteRaw.GetRawChange(), + }) + if err != nil { + return rpcerr.Unwrap(err) + } + status = resp.Payload + return nil + }) + return +} + +func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{ + SpaceId: spaceId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + status = resp.Payload + return nil + }) + return +} + func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) { - cl, err := c.client(ctx) if err != nil { return } @@ -100,54 +96,56 @@ func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayl if err != nil { return } - resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{ - SpaceId: payload.SpaceId, - Header: payload.SpaceHeader, - OldIdentity: oldIdentity, - NewIdentitySignature: newSignature, + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{ + SpaceId: payload.SpaceId, + Header: payload.SpaceHeader, + OldIdentity: oldIdentity, + NewIdentitySignature: newSignature, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + receipt = resp.Receipt + return nil }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - return resp.Receipt, nil -} - -func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{ - AccountIdentity: identity, - SpaceId: spaceId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - return resp.Limit, nil -} - -func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{ - CurrentId: currentId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } return } -func (c *coordinatorClient) client(ctx context.Context) (coordinatorproto.DRPCCoordinatorClient, error) { +func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{ + AccountIdentity: identity, + SpaceId: spaceId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + limit = resp.Limit + return nil + }) + return +} + +func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{ + CurrentId: currentId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + return nil + }) + return +} + +func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error { p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers()) if err != nil { - return nil, err + return err } - return coordinatorproto.NewDRPCCoordinatorClient(p), nil + return p.DoDrpc(ctx, func(conn drpc.Conn) error { + return f(coordinatorproto.NewDRPCCoordinatorClient(conn)) + }) } diff --git a/coordinator/coordinatorproto/coordinator_drpc.pb.go b/coordinator/coordinatorproto/coordinator_drpc.pb.go index 75e73a7b..0ed69ea2 100644 --- a/coordinator/coordinatorproto/coordinator_drpc.pb.go +++ b/coordinator/coordinatorproto/coordinator_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: coordinator/coordinatorproto/protos/coordinator.proto package coordinatorproto diff --git a/go.mod b/go.mod index 61bc0f44..5dd9ad82 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 + github.com/hashicorp/yamux v0.1.1 github.com/huandu/skiplist v1.2.0 github.com/ipfs/go-block-format v0.1.2 github.com/ipfs/go-blockservice v0.5.2 @@ -33,6 +34,7 @@ require ( github.com/tyler-smith/go-bip39 v1.1.0 github.com/zeebo/blake3 v0.2.3 github.com/zeebo/errs v1.3.0 + go.uber.org/atomic v1.11.0 go.uber.org/zap v1.24.0 golang.org/x/crypto v0.9.0 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 @@ -103,7 +105,6 @@ require ( github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect go.opentelemetry.io/otel v1.7.0 // indirect go.opentelemetry.io/otel/trace v1.7.0 // indirect - go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.6.0 // indirect golang.org/x/sync v0.2.0 // indirect diff --git a/go.sum b/go.sum index 9abbf281..ed9b778b 100644 --- a/go.sum +++ b/go.sum @@ -79,6 +79,8 @@ github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfm github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= +github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw= diff --git a/net/timeoutconn/conn.go b/net/connutil/timeout.go similarity index 82% rename from net/timeoutconn/conn.go rename to net/connutil/timeout.go index 11e80709..381998f9 100644 --- a/net/timeoutconn/conn.go +++ b/net/connutil/timeout.go @@ -1,4 +1,4 @@ -package timeoutconn +package connutil import ( "errors" @@ -10,18 +10,18 @@ import ( "go.uber.org/zap" ) -var log = logger.NewNamed("common.net.timeoutconn") +var log = logger.NewNamed("common.net.connutil") -type Conn struct { +type TimeoutConn struct { net.Conn timeout time.Duration } -func NewConn(conn net.Conn, timeout time.Duration) *Conn { - return &Conn{conn, timeout} +func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn { + return &TimeoutConn{conn, timeout} } -func (c *Conn) Write(p []byte) (n int, err error) { +func (c *TimeoutConn) Write(p []byte) (n int, err error) { for { if c.timeout != 0 { if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil { diff --git a/net/connutil/usage.go b/net/connutil/usage.go new file mode 100644 index 00000000..826d9c74 --- /dev/null +++ b/net/connutil/usage.go @@ -0,0 +1,30 @@ +package connutil + +import ( + "go.uber.org/atomic" + "net" + "time" +) + +func NewLastUsageConn(conn net.Conn) *LastUsageConn { + return &LastUsageConn{Conn: conn} +} + +type LastUsageConn struct { + net.Conn + lastUsage atomic.Time +} + +func (c *LastUsageConn) Write(p []byte) (n int, err error) { + c.lastUsage.Store(time.Now()) + return c.Conn.Write(p) +} + +func (c *LastUsageConn) Read(p []byte) (n int, err error) { + c.lastUsage.Store(time.Now()) + return c.Conn.Read(p) +} + +func (c *LastUsageConn) LastUsage() time.Time { + return c.lastUsage.Load() +} diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go deleted file mode 100644 index aa65da75..00000000 --- a/net/dialer/dialer.go +++ /dev/null @@ -1,137 +0,0 @@ -package dialer - -import ( - "context" - "errors" - "fmt" - "github.com/anyproto/any-sync/app" - "github.com/anyproto/any-sync/app/logger" - net2 "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/anyproto/any-sync/net/secureservice/handshake" - "github.com/anyproto/any-sync/net/timeoutconn" - "github.com/anyproto/any-sync/nodeconf" - "github.com/libp2p/go-libp2p/core/sec" - "go.uber.org/zap" - "net" - "storj.io/drpc" - "storj.io/drpc/drpcconn" - "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcwire" - "sync" - "time" -) - -const CName = "common.net.dialer" - -var ( - ErrAddrsNotFound = errors.New("addrs for peer not found") - ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id") -) - -var log = logger.NewNamed(CName) - -func New() Dialer { - return &dialer{peerAddrs: map[string][]string{}} -} - -type Dialer interface { - Dial(ctx context.Context, peerId string) (peer peer.Peer, err error) - SetPeerAddrs(peerId string, addrs []string) - app.Component -} - -type dialer struct { - transport secureservice.SecureService - config net2.Config - nodeConf nodeconf.NodeConf - peerAddrs map[string][]string - - mu sync.RWMutex -} - -func (d *dialer) Init(a *app.App) (err error) { - d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) - d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet() - return -} - -func (d *dialer) Name() (name string) { - return CName -} - -func (d *dialer) SetPeerAddrs(peerId string, addrs []string) { - d.mu.Lock() - defer d.mu.Unlock() - d.peerAddrs[peerId] = addrs -} - -func (d *dialer) getPeerAddrs(peerId string) ([]string, error) { - if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok { - return addrs, nil - } - addrs, ok := d.peerAddrs[peerId] - if !ok || len(addrs) == 0 { - return nil, ErrAddrsNotFound - } - return addrs, nil -} - -func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) { - var ctxCancel context.CancelFunc - ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10) - defer ctxCancel() - d.mu.RLock() - defer d.mu.RUnlock() - - addrs, err := d.getPeerAddrs(peerId) - if err != nil { - return - } - - var ( - conn drpc.Conn - sc sec.SecureConn - ) - log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) - for _, addr := range addrs { - conn, sc, err = d.handshake(ctx, addr, peerId) - if err != nil { - log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) - } else { - break - } - } - if err != nil { - return - } - return peer.NewPeer(sc, conn), nil -} - -func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) { - st := time.Now() - // TODO: move dial timeout to config - tcpConn, err := net.DialTimeout("tcp", addr, time.Second*15) - if err != nil { - return nil, nil, fmt.Errorf("dialTimeout error: %v; since start: %v", err, time.Since(st)) - } - - timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) - sc, err = d.transport.SecureOutbound(ctx, timeoutConn) - if err != nil { - if he, ok := err.(handshake.HandshakeError); ok { - return nil, nil, he - } - return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) - } - if peerId != sc.RemotePeer().String() { - return nil, nil, ErrPeerIdIsUnexpected - } - log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr)) - conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{ - Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)}, - }}) - return conn, sc, err -} diff --git a/net/peer/peer.go b/net/peer/peer.go index 5bb8022a..879f6f9b 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -2,82 +2,146 @@ package peer import ( "context" - "sync/atomic" - "time" - "github.com/anyproto/any-sync/app/logger" - "github.com/libp2p/go-libp2p/core/sec" + "github.com/anyproto/any-sync/app/ocache" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/anyproto/any-sync/net/transport" "go.uber.org/zap" + "io" + "net" "storj.io/drpc" + "storj.io/drpc/drpcconn" + "sync" + "time" ) var log = logger.NewNamed("common.net.peer") -func NewPeer(sc sec.SecureConn, conn drpc.Conn) Peer { - return &peer{ - id: sc.RemotePeer().String(), - lastUsage: time.Now().Unix(), - sc: sc, - Conn: conn, +type connCtrl interface { + ServeConn(ctx context.Context, conn net.Conn) (err error) +} + +func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { + ctx := mc.Context() + pr := &peer{ + active: map[drpc.Conn]struct{}{}, + MultiConn: mc, + ctrl: ctrl, } + if pr.id, err = CtxPeerId(ctx); err != nil { + return + } + go pr.acceptLoop() + return pr, nil } type Peer interface { Id() string - LastUsage() time.Time - UpdateLastUsage() - Addr() string + + AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) + ReleaseDrpcConn(conn drpc.Conn) + DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error + + IsClosed() bool + TryClose(objectTTL time.Duration) (res bool, err error) - drpc.Conn + + ocache.Object } type peer struct { - id string - ttl time.Duration - lastUsage int64 - sc sec.SecureConn - drpc.Conn + id string + + ctrl connCtrl + + // drpc conn pool + inactive []drpc.Conn + active map[drpc.Conn]struct{} + + mu sync.Mutex + + transport.MultiConn } func (p *peer) Id() string { return p.id } -func (p *peer) LastUsage() time.Time { - select { - case <-p.Closed(): - return time.Unix(0, 0) - default: +func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if len(p.inactive) == 0 { + conn, err := p.Open(ctx) + if err != nil { + return nil, err + } + dconn := drpcconn.New(conn) + p.inactive = append(p.inactive, dconn) } - return time.Unix(atomic.LoadInt64(&p.lastUsage), 0) + idx := len(p.inactive) - 1 + res := p.inactive[idx] + p.inactive = p.inactive[:idx] + p.active[res] = struct{}{} + return res, nil } -func (p *peer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { - defer p.UpdateLastUsage() - return p.Conn.Invoke(ctx, rpc, enc, in, out) -} - -func (p *peer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { - defer p.UpdateLastUsage() - return p.Conn.NewStream(ctx, rpc, enc) -} - -func (p *peer) Read(b []byte) (n int, err error) { - if n, err = p.sc.Read(b); err == nil { - p.UpdateLastUsage() +func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { + p.mu.Lock() + defer p.mu.Unlock() + if _, ok := p.active[conn]; ok { + delete(p.active, conn) } + p.inactive = append(p.inactive, conn) return } -func (p *peer) Write(b []byte) (n int, err error) { - if n, err = p.sc.Write(b); err == nil { - p.UpdateLastUsage() +func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error { + conn, err := p.AcquireDrpcConn(ctx) + if err != nil { + return err } - return + defer p.ReleaseDrpcConn(conn) + return do(conn) } -func (p *peer) UpdateLastUsage() { - atomic.StoreInt64(&p.lastUsage, time.Now().Unix()) +func (p *peer) acceptLoop() { + var exitErr error + defer func() { + if exitErr != transport.ErrConnClosed { + log.Warn("accept error: close connection", zap.Error(exitErr)) + _ = p.MultiConn.Close() + } + }() + for { + conn, err := p.Accept() + if err != nil { + exitErr = err + return + } + go func() { + serveErr := p.serve(conn) + if serveErr != io.EOF && serveErr != transport.ErrConnClosed { + log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr)) + } + }() + } +} + +var defaultProtoChecker = handshake.ProtoChecker{ + AllowedProtoTypes: []handshakeproto.ProtoType{ + handshakeproto.ProtoType_DRPC, + }, +} + +func (p *peer) serve(conn net.Conn) (err error) { + hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20) + if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil { + cancel() + return + } + cancel() + return p.ctrl.ServeConn(p.Context(), conn) } func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { @@ -87,14 +151,7 @@ func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, p.Close() } -func (p *peer) Addr() string { - if p.sc != nil { - return p.sc.RemoteAddr().String() - } - return "" -} - func (p *peer) Close() (err error) { log.Debug("peer close", zap.String("peerId", p.id)) - return p.Conn.Close() + return p.MultiConn.Close() } diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go new file mode 100644 index 00000000..00a14252 --- /dev/null +++ b/net/peer/peer_test.go @@ -0,0 +1,137 @@ +package peer + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/anyproto/any-sync/net/transport/mock_transport" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net" + "testing" + "time" +) + +var ctx = context.Background() + +func TestPeer_AcquireDrpcConn(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dc, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + defer dc.Close() + + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) + + fx.ReleaseDrpcConn(dc) + + assert.Len(t, fx.active, 0) + assert.Len(t, fx.inactive, 1) + + dc, err = fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) +} + +func TestPeerAccept(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + defer out.Close() + + var outHandshakeCh = make(chan error) + go func() { + outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC) + }() + fx.acceptCh <- acceptedConn{conn: in} + cn := <-fx.testCtrl.serveConn + assert.Equal(t, in, cn) + assert.NoError(t, <-outHandshakeCh) +} + +func TestPeer_TryClose(t *testing.T) { + t.Run("ttl", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now() + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.False(t, res) + }) + t.Run("close", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now().Add(-time.Second * 2) + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.True(t, res) + }) +} + +type acceptedConn struct { + conn net.Conn + err error +} + +func newFixture(t *testing.T, peerId string) *fixture { + fx := &fixture{ + ctrl: gomock.NewController(t), + acceptCh: make(chan acceptedConn), + testCtrl: newTesCtrl(), + } + fx.mc = mock_transport.NewMockMultiConn(fx.ctrl) + ctx := CtxWithPeerId(context.Background(), peerId) + fx.mc.EXPECT().Context().Return(ctx).AnyTimes() + fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) { + ac := <-fx.acceptCh + return ac.conn, ac.err + }).AnyTimes() + fx.mc.EXPECT().Close().AnyTimes() + p, err := NewPeer(fx.mc, fx.testCtrl) + require.NoError(t, err) + fx.peer = p.(*peer) + return fx +} + +type fixture struct { + *peer + ctrl *gomock.Controller + mc *mock_transport.MockMultiConn + acceptCh chan acceptedConn + testCtrl *testCtrl +} + +func (fx *fixture) finish() { + fx.testCtrl.close() + fx.ctrl.Finish() +} + +func newTesCtrl() *testCtrl { + return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)} +} + +type testCtrl struct { + serveConn chan net.Conn + closeCh chan struct{} +} + +func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) { + t.serveConn <- conn + <-t.closeCh + return io.EOF +} + +func (t *testCtrl) close() { + close(t.closeCh) +} diff --git a/net/peerservice/peerservice.go b/net/peerservice/peerservice.go new file mode 100644 index 00000000..e0691733 --- /dev/null +++ b/net/peerservice/peerservice.go @@ -0,0 +1,110 @@ +package peerservice + +import ( + "context" + "errors" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/pool" + "github.com/anyproto/any-sync/net/rpc/server" + "github.com/anyproto/any-sync/net/transport" + "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/anyproto/any-sync/nodeconf" + "go.uber.org/zap" + "sync" +) + +const CName = "net.peerservice" + +var log = logger.NewNamed(CName) + +var ( + ErrAddrsNotFound = errors.New("addrs for peer not found") +) + +func New() PeerService { + return new(peerService) +} + +type PeerService interface { + Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) + SetPeerAddrs(peerId string, addrs []string) + transport.Accepter + app.Component +} + +type peerService struct { + yamux transport.Transport + nodeConf nodeconf.NodeConf + peerAddrs map[string][]string + pool pool.Pool + server server.DRPCServer + mu sync.RWMutex +} + +func (p *peerService) Init(a *app.App) (err error) { + p.yamux = a.MustComponent(yamux.CName).(transport.Transport) + p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + p.pool = a.MustComponent(pool.CName).(pool.Pool) + p.server = a.MustComponent(server.CName).(server.DRPCServer) + p.peerAddrs = map[string][]string{} + return nil +} + +func (p *peerService) Name() (name string) { + return CName +} + +func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) { + p.mu.RLock() + defer p.mu.RUnlock() + + addrs, err := p.getPeerAddrs(peerId) + if err != nil { + return + } + + var mc transport.MultiConn + log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) + for _, addr := range addrs { + mc, err = p.yamux.Dial(ctx, addr) + if err != nil { + log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) + } else { + break + } + } + if err != nil { + return + } + return peer.NewPeer(mc, p.server) +} + +func (p *peerService) Accept(mc transport.MultiConn) (err error) { + pr, err := peer.NewPeer(mc, p.server) + if err != nil { + return err + } + if err = p.pool.AddPeer(context.Background(), pr); err != nil { + _ = pr.Close() + } + return +} + +func (p *peerService) SetPeerAddrs(peerId string, addrs []string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peerAddrs[peerId] = addrs +} + +func (p *peerService) getPeerAddrs(peerId string) ([]string, error) { + if addrs, ok := p.nodeConf.PeerAddresses(peerId); ok { + return addrs, nil + } + addrs, ok := p.peerAddrs[peerId] + if !ok || len(addrs) == 0 { + return nil, ErrAddrsNotFound + } + return addrs, nil +} diff --git a/net/pool/pool.go b/net/pool/pool.go index b6c0d7df..37f8328e 100644 --- a/net/pool/pool.go +++ b/net/pool/pool.go @@ -4,7 +4,6 @@ import ( "context" "github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/dialer" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/secureservice/handshake" "go.uber.org/zap" @@ -13,59 +12,61 @@ import ( // Pool creates and caches outgoing connection type Pool interface { - // Get lookups to peer in existing connections or creates and cache new one + // Get lookups to peer in existing connections or creates and outgoing new one Get(ctx context.Context, id string) (peer.Peer, error) - // Dial creates new connection to peer and not use cache - Dial(ctx context.Context, id string) (peer.Peer, error) - // GetOneOf searches at least one existing connection in cache or creates a new one from a randomly selected id from given list + // GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) - - DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) + // AddPeer adds incoming peer to the pool + AddPeer(ctx context.Context, p peer.Peer) (err error) } type pool struct { - cache ocache.OCache - dialer dialer.Dialer + outgoing ocache.OCache + incoming ocache.OCache } func (p *pool) Name() (name string) { return CName } -func (p *pool) Run(ctx context.Context) (err error) { - return nil +func (p *pool) Get(ctx context.Context, id string) (pr peer.Peer, err error) { + // if we have incoming connection - try to reuse it + if pr, err = p.get(ctx, p.incoming, id); err != nil { + // or try to get or create outgoing + return p.get(ctx, p.outgoing, id) + } + return } -func (p *pool) Get(ctx context.Context, id string) (peer.Peer, error) { - v, err := p.cache.Get(ctx, id) +func (p *pool) get(ctx context.Context, source ocache.OCache, id string) (peer.Peer, error) { + v, err := source.Get(ctx, id) if err != nil { return nil, err } pr := v.(peer.Peer) - select { - case <-pr.Closed(): - default: + if !pr.IsClosed() { return pr, nil } - _, _ = p.cache.Remove(ctx, id) + _, _ = source.Remove(ctx, id) return p.Get(ctx, id) } -func (p *pool) Dial(ctx context.Context, id string) (peer.Peer, error) { - return p.dialer.Dial(ctx, id) -} - func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { // finding existing connection for _, peerId := range peerIds { - if v, err := p.cache.Pick(ctx, peerId); err == nil { + if v, err := p.incoming.Pick(ctx, peerId); err == nil { pr := v.(peer.Peer) - select { - case <-pr.Closed(): - default: + if !pr.IsClosed() { return pr, nil } - _, _ = p.cache.Remove(ctx, peerId) + _, _ = p.incoming.Remove(ctx, peerId) + } + if v, err := p.outgoing.Pick(ctx, peerId); err == nil { + pr := v.(peer.Peer) + if !pr.IsClosed() { + return pr, nil + } + _, _ = p.outgoing.Remove(ctx, peerId) } } // shuffle ids for better consistency @@ -75,8 +76,8 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error // connecting var lastErr error for _, peerId := range peerIds { - if v, err := p.cache.Get(ctx, peerId); err == nil { - return v.(peer.Peer), nil + if v, err := p.Get(ctx, peerId); err == nil { + return v, nil } else { log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) lastErr = err @@ -88,27 +89,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error return nil, lastErr } -func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - // shuffle ids for better consistency - rand.Shuffle(len(peerIds), func(i, j int) { - peerIds[i], peerIds[j] = peerIds[j], peerIds[i] - }) - // connecting - var lastErr error - for _, peerId := range peerIds { - if v, err := p.dialer.Dial(ctx, peerId); err == nil { - return v.(peer.Peer), nil +func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) { + if err = p.incoming.Add(pr.Id(), pr); err != nil { + if err == ocache.ErrExists { + // in case when an incoming connection with a peer already exists, we close and remove an existing connection + if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil { + _ = v.Close() + _, _ = p.incoming.Remove(ctx, pr.Id()) + return p.incoming.Add(pr.Id(), pr) + } } else { - log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) - lastErr = err + return err } } - if _, ok := lastErr.(handshake.HandshakeError); !ok { - lastErr = net.ErrUnableToConnect - } - return nil, lastErr -} - -func (p *pool) Close(ctx context.Context) (err error) { - return p.cache.Close() + return } diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index c82533e8..c93c9aec 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -6,11 +6,11 @@ import ( "fmt" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/dialer" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + net2 "net" "storj.io/drpc" "testing" "time" @@ -133,6 +133,27 @@ func TestPool_GetOneOf(t *testing.T) { }) } +func TestPool_AddPeer(t *testing.T) { + t.Run("success", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1"))) + }) + t.Run("two peers", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + p1, p2 := newTestPeer("p1"), newTestPeer("p1") + require.NoError(t, fx.AddPeer(ctx, p1)) + require.NoError(t, fx.AddPeer(ctx, p2)) + select { + case <-p1.closed: + default: + assert.Truef(t, false, "peer not closed") + } + }) + +} + func newFixture(t *testing.T) *fixture { fx := &fixture{ Service: New(), @@ -158,7 +179,7 @@ type fixture struct { t *testing.T } -var _ dialer.Dialer = (*dialerMock)(nil) +var _ dialer = (*dialerMock)(nil) type dialerMock struct { dial func(ctx context.Context, peerId string) (peer peer.Peer, err error) @@ -181,7 +202,7 @@ func (d *dialerMock) Init(a *app.App) (err error) { } func (d *dialerMock) Name() (name string) { - return dialer.CName + return "net.peerservice" } func newTestPeer(id string) *testPeer { @@ -196,6 +217,31 @@ type testPeer struct { closed chan struct{} } +func (t *testPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error { + return fmt.Errorf("not implemented") +} + +func (t *testPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { + return nil, fmt.Errorf("not implemented") +} + +func (t *testPeer) ReleaseDrpcConn(conn drpc.Conn) {} + +func (t *testPeer) Context() context.Context { + //TODO implement me + panic("implement me") +} + +func (t *testPeer) Accept() (conn net2.Conn, err error) { + //TODO implement me + panic("implement me") +} + +func (t *testPeer) Open(ctx context.Context) (conn net2.Conn, err error) { + //TODO implement me + panic("implement me") +} + func (t *testPeer) Addr() string { return "" } @@ -204,12 +250,6 @@ func (t *testPeer) Id() string { return t.id } -func (t *testPeer) LastUsage() time.Time { - return time.Now() -} - -func (t *testPeer) UpdateLastUsage() {} - func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, t.Close() } @@ -224,14 +264,11 @@ func (t *testPeer) Close() error { return nil } -func (t *testPeer) Closed() <-chan struct{} { - return t.closed -} - -func (t *testPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { - return fmt.Errorf("call Invoke on test peer") -} - -func (t *testPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { - return nil, fmt.Errorf("call NewStream on test peer") +func (t *testPeer) IsClosed() bool { + select { + case <-t.closed: + return true + default: + return false + } } diff --git a/net/pool/poolservice.go b/net/pool/poolservice.go index 9b69ae24..2f84e5d0 100644 --- a/net/pool/poolservice.go +++ b/net/pool/poolservice.go @@ -6,8 +6,9 @@ import ( "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/metric" - "github.com/anyproto/any-sync/net/dialer" + "github.com/anyproto/any-sync/net/peer" "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" "time" ) @@ -23,46 +24,54 @@ func New() Service { type Service interface { Pool - NewPool(name string) Pool app.ComponentRunnable } +type dialer interface { + Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) +} + type poolService struct { // default pool *pool - dialer dialer.Dialer + dialer dialer metricReg *prometheus.Registry } func (p *poolService) Init(a *app.App) (err error) { - p.dialer = a.MustComponent(dialer.CName).(dialer.Dialer) - p.pool = &pool{dialer: p.dialer} + p.dialer = a.MustComponent("net.peerservice").(dialer) + p.pool = &pool{} if m := a.Component(metric.CName); m != nil { p.metricReg = m.(metric.Metric).Registry() } - p.pool.cache = ocache.New( + p.pool.outgoing = ocache.New( func(ctx context.Context, id string) (value ocache.Object, err error) { return p.dialer.Dial(ctx, id) }, ocache.WithLogger(log.Sugar()), ocache.WithGCPeriod(time.Minute), ocache.WithTTL(time.Minute*5), - ocache.WithPrometheus(p.metricReg, "netpool", "default"), + ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"), + ) + p.pool.incoming = ocache.New( + func(ctx context.Context, id string) (value ocache.Object, err error) { + return nil, ocache.ErrNotExists + }, + ocache.WithLogger(log.Sugar()), + ocache.WithGCPeriod(time.Minute), + ocache.WithTTL(time.Minute*5), + ocache.WithPrometheus(p.metricReg, "netpool", "incoming"), ) return nil } -func (p *poolService) NewPool(name string) Pool { - return &pool{ - dialer: p.dialer, - cache: ocache.New( - func(ctx context.Context, id string) (value ocache.Object, err error) { - return p.dialer.Dial(ctx, id) - }, - ocache.WithLogger(log.Sugar()), - ocache.WithGCPeriod(time.Minute), - ocache.WithTTL(time.Minute*5), - ocache.WithPrometheus(p.metricReg, "netpool", name), - ), - } +func (p *pool) Run(ctx context.Context) (err error) { + return nil +} + +func (p *pool) Close(ctx context.Context) (err error) { + if e := p.incoming.Close(); e != nil { + log.Warn("close incoming cache error", zap.Error(e)) + } + return p.outgoing.Close() } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go deleted file mode 100644 index cb3047ed..00000000 --- a/net/rpc/server/baseserver.go +++ /dev/null @@ -1,134 +0,0 @@ -package server - -import ( - "context" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" - "github.com/zeebo/errs" - "go.uber.org/zap" - "io" - "net" - "storj.io/drpc" - "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - "storj.io/drpc/drpcwire" - "time" -) - -type BaseDrpcServer struct { - drpcServer *drpcserver.Server - transport secureservice.SecureService - listeners []net.Listener - handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) - cancel func() - *drpcmux.Mux -} - -type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler - -type Params struct { - BufferSizeMb int - ListenAddrs []string - Wrapper DRPCHandlerWrapper - TimeoutMillis int - Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) -} - -func NewBaseDrpcServer() *BaseDrpcServer { - return &BaseDrpcServer{Mux: drpcmux.New()} -} - -func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) { - s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{ - Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)}, - }}) - s.handshake = params.Handshake - ctx, s.cancel = context.WithCancel(ctx) - for _, addr := range params.ListenAddrs { - list, err := net.Listen("tcp", addr) - if err != nil { - return err - } - s.listeners = append(s.listeners, list) - go s.serve(ctx, list) - } - return -} - -func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) { - l := log.With(zap.String("localAddr", lis.Addr().String())) - l.Info("drpc listener started") - defer func() { - l.Debug("drpc listener stopped") - }() - for { - select { - case <-ctx.Done(): - return - default: - } - conn, err := lis.Accept() - if err != nil { - if isTemporary(err) { - l.Debug("listener temporary accept error", zap.Error(err)) - select { - case <-time.After(time.Second): - case <-ctx.Done(): - return - } - continue - } - l.Error("listener accept error", zap.Error(err)) - return - } - go s.serveConn(conn) - } -} - -func (s *BaseDrpcServer) serveConn(conn net.Conn) { - l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) - var ( - ctx = context.Background() - err error - ) - if s.handshake != nil { - ctx, conn, err = s.handshake(conn) - if err != nil { - l.Info("handshake error", zap.Error(err)) - return - } - if sc, ok := conn.(sec.SecureConn); ok { - ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String()) - } - } - ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String()) - l.Debug("connection opened") - if err := s.drpcServer.ServeOne(ctx, conn); err != nil { - if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { - l.Debug("connection closed") - } else { - l.Warn("serve connection error", zap.Error(err)) - } - } -} - -func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) { - for _, list := range s.listeners { - addrs = append(addrs, list.Addr()) - } - return -} - -func (s *BaseDrpcServer) Close(ctx context.Context) (err error) { - if s.cancel != nil { - s.cancel() - } - for _, l := range s.listeners { - if e := l.Close(); e != nil { - log.Warn("close listener error", zap.Error(e)) - } - } - return -} diff --git a/net/rpc/server/drpcserver.go b/net/rpc/server/drpcserver.go index 1874d16a..2b061515 100644 --- a/net/rpc/server/drpcserver.go +++ b/net/rpc/server/drpcserver.go @@ -6,11 +6,13 @@ import ( "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/metric" anyNet "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" + "go.uber.org/zap" "net" "storj.io/drpc" - "time" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpcwire" ) const CName = "common.net.drpcserver" @@ -18,49 +20,46 @@ const CName = "common.net.drpcserver" var log = logger.NewNamed(CName) func New() DRPCServer { - return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()} + return &drpcServer{} } type DRPCServer interface { - app.ComponentRunnable + ServeConn(ctx context.Context, conn net.Conn) (err error) + app.Component drpc.Mux } type drpcServer struct { - config anyNet.Config - metric metric.Metric - transport secureservice.SecureService - *BaseDrpcServer + drpcServer *drpcserver.Server + *drpcmux.Mux + config anyNet.Config + metric metric.Metric } -func (s *drpcServer) Init(a *app.App) (err error) { - s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() - s.metric = a.MustComponent(metric.CName).(metric.Metric) - s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - return nil -} +type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler func (s *drpcServer) Name() (name string) { return CName } -func (s *drpcServer) Run(ctx context.Context) (err error) { - params := Params{ - BufferSizeMb: s.config.Stream.MaxMsgSizeMb, - TimeoutMillis: s.config.Stream.TimeoutMilliseconds, - ListenAddrs: s.config.Server.ListenAddrs, - Wrapper: func(handler drpc.Handler) drpc.Handler { - return s.metric.WrapDRPCHandler(handler) - }, - Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - return s.transport.SecureInbound(ctx, conn) - }, +func (s *drpcServer) Init(a *app.App) (err error) { + s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() + s.metric, _ = a.Component(metric.CName).(metric.Metric) + s.Mux = drpcmux.New() + + var handler drpc.Handler + handler = s + if s.metric != nil { + handler = s.metric.WrapDRPCHandler(s) } - return s.BaseDrpcServer.Run(ctx, params) + s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)}, + }}) + return } -func (s *drpcServer) Close(ctx context.Context) (err error) { - return s.BaseDrpcServer.Close(ctx) +func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) { + l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) + l.Debug("drpc serve peer") + return s.drpcServer.ServeOne(ctx, conn) } diff --git a/net/secureservice/handshake/credential.go b/net/secureservice/handshake/credential.go new file mode 100644 index 00000000..06108928 --- /dev/null +++ b/net/secureservice/handshake/credential.go @@ -0,0 +1,125 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/libp2p/go-libp2p/core/sec" +) + +func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = outgoingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = sc + localCred := cc.MakeCredentials(sc) + if err = h.writeCredentials(localCred); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck, msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack != nil { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error == handshakeproto.Error_Null { + return identity, nil + } else { + _ = h.conn.Close() + return nil, HandshakeError{e: msg.ack.Error} + } +} + +func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = incomingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = sc + + msg, err := h.readMsg(msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error != handshakeproto.Error_Null { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + return +} diff --git a/net/secureservice/handshake/handshake_test.go b/net/secureservice/handshake/credential_test.go similarity index 94% rename from net/secureservice/handshake/handshake_test.go rename to net/secureservice/handshake/credential_test.go index 0d8b16d7..6a34f9cb 100644 --- a/net/secureservice/handshake/handshake_test.go +++ b/net/secureservice/handshake/credential_test.go @@ -38,15 +38,14 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto) require.NoError(t, err) - require.Nil(t, msg.ack) _, err = noVerifyChecker.CheckCredential(c2, msg.cred) require.NoError(t, err) // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -76,7 +75,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -92,7 +91,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) res := <-handshakeResCh @@ -108,10 +107,10 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) res := <-handshakeResCh @@ -127,7 +126,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials and close conn require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) @@ -145,12 +144,12 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read ack and close conn - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -166,18 +165,17 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read ack - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) // write cred instead ack require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() - require.NoError(t, err) - assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -191,7 +189,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) _, err = noVerifyChecker.CheckCredential(c2, msg.cred) @@ -199,7 +197,7 @@ func TestOutgoingHandshake(t *testing.T) { // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -219,7 +217,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -244,14 +242,14 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) // write ack require.NoError(t, h.writeAck(handshakeproto.Error_Null)) // wait ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) res := <-handshakeResCh @@ -310,7 +308,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) @@ -330,7 +328,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) @@ -350,13 +348,13 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read cred - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write cred instead ack require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - // expect ack with error - msg, err := h.readMsg() - require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + // expect EOF + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -372,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read cred and close conn - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() @@ -391,7 +389,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -413,7 +411,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -435,7 +433,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -458,7 +456,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -482,7 +480,7 @@ func TestNotAHandshakeMessage(t *testing.T) { _, err := c2.Write([]byte("some unexpected bytes")) require.Error(t, err) res := <-handshakeResCh - assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) + assert.Error(t, res.err) } func TestEndToEnd(t *testing.T) { diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 04a9de72..abbafeb5 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -1,21 +1,30 @@ package handshake import ( - "context" "encoding/binary" "errors" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/libp2p/go-libp2p/core/sec" "golang.org/x/exp/slices" "io" + "net" "sync" ) const headerSize = 5 // 1 byte for type + 4 byte for uint32 size const ( - msgTypeCred = byte(1) - msgTypeAck = byte(2) + msgTypeCred = byte(1) + msgTypeAck = byte(2) + msgTypeProto = byte(3) + + sizeLimit = 200 * 1024 // 200 Kb +) + +var ( + credMsgTypes = []byte{msgTypeCred, msgTypeAck} + protoMsgTypes = []byte{msgTypeProto, msgTypeAck} + protoMsgTypesAck = []byte{msgTypeAck} ) type HandshakeError struct { @@ -38,17 +47,20 @@ var ( ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} - ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto} + ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")} - ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") + ErrGotUnexpectedMessage = errors.New("go not a handshake message") ) var handshakePool = &sync.Pool{New: func() any { return &handshake{ - remoteCred: &handshakeproto.Credentials{}, - remoteAck: &handshakeproto.Ack{}, - localAck: &handshakeproto.Ack{}, - buf: make([]byte, 0, 1024), + remoteCred: &handshakeproto.Credentials{}, + remoteAck: &handshakeproto.Ack{}, + localAck: &handshakeproto.Ack{}, + remoteProto: &handshakeproto.Proto{}, + buf: make([]byte, 0, 1024), } }} @@ -57,147 +69,17 @@ type CredentialChecker interface { CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) } -func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = outgoingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - localCred := cc.MakeCredentials(sc) - if err = h.writeCredentials(localCred); err != nil { - h.tryWriteErrAndClose(err) - return - } - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error == handshakeproto.Error_Null { - return identity, nil - } else { - _ = h.conn.Close() - return nil, HandshakeError{e: msg.ack.Error} - } -} - -func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = incomingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - return nil, ErrUnexpectedPayload - } - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error != handshakeproto.Error_Null { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - return -} - func newHandshake() *handshake { return handshakePool.Get().(*handshake) } type handshake struct { - conn sec.SecureConn - remoteCred *handshakeproto.Credentials - remoteAck *handshakeproto.Ack - localAck *handshakeproto.Ack - buf []byte + conn net.Conn + remoteCred *handshakeproto.Credentials + remoteProto *handshakeproto.Proto + remoteAck *handshakeproto.Ack + localAck *handshakeproto.Ack + buf []byte } func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) { @@ -209,8 +91,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro return h.writeData(msgTypeCred, n) } +func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) { + h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize] + n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:]) + if err != nil { + return err + } + return h.writeData(msgTypeProto, n) +} + func (h *handshake) tryWriteErrAndClose(err error) { - if err == ErrGotNotAHandshakeMessage { + if err == ErrUnexpectedPayload { // if we got unexpected message - just close the connection _ = h.conn.Close() return @@ -243,21 +134,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) { } type message struct { - cred *handshakeproto.Credentials - ack *handshakeproto.Ack + cred *handshakeproto.Credentials + proto *handshakeproto.Proto + ack *handshakeproto.Ack } -func (h *handshake) readMsg() (msg message, err error) { +func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) { h.buf = slices.Grow(h.buf, headerSize)[:headerSize] if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil { return } tp := h.buf[0] - if tp != msgTypeCred && tp != msgTypeAck { - err = ErrGotNotAHandshakeMessage + if !slices.Contains(allowedTypes, tp) { + err = ErrUnexpectedPayload return } size := binary.LittleEndian.Uint32(h.buf[1:headerSize]) + if size > sizeLimit { + err = ErrGotUnexpectedMessage + return + } h.buf = slices.Grow(h.buf, int(size))[:size] if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil { return @@ -273,6 +169,11 @@ func (h *handshake) readMsg() (msg message, err error) { return } msg.ack = h.remoteAck + case msgTypeProto: + if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil { + return + } + msg.proto = h.remoteProto } return } @@ -284,5 +185,6 @@ func (h *handshake) release() { h.remoteAck.Error = 0 h.remoteCred.Type = 0 h.remoteCred.Payload = h.remoteCred.Payload[:0] + h.remoteProto.Proto = 0 handshakePool.Put(h) } diff --git a/net/secureservice/handshake/handshakeproto/handshake.pb.go b/net/secureservice/handshake/handshakeproto/handshake.pb.go index 3d868ef0..e9d6dfcb 100644 --- a/net/secureservice/handshake/handshakeproto/handshake.pb.go +++ b/net/secureservice/handshake/handshakeproto/handshake.pb.go @@ -59,6 +59,7 @@ const ( Error_SkipVerifyNotAllowed Error = 4 Error_DeadlineExceeded Error = 5 Error_IncompatibleVersion Error = 6 + Error_IncompatibleProto Error = 7 ) var Error_name = map[int32]string{ @@ -69,6 +70,7 @@ var Error_name = map[int32]string{ 4: "SkipVerifyNotAllowed", 5: "DeadlineExceeded", 6: "IncompatibleVersion", + 7: "IncompatibleProto", } var Error_value = map[string]int32{ @@ -79,6 +81,7 @@ var Error_value = map[string]int32{ "SkipVerifyNotAllowed": 4, "DeadlineExceeded": 5, "IncompatibleVersion": 6, + "IncompatibleProto": 7, } func (x Error) String() string { @@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) { return fileDescriptor_60283fc75f020893, []int{1} } +type ProtoType int32 + +const ( + ProtoType_DRPC ProtoType = 0 +) + +var ProtoType_name = map[int32]string{ + 0: "DRPC", +} + +var ProtoType_value = map[string]int32{ + "DRPC": 0, +} + +func (x ProtoType) String() string { + return proto.EnumName(ProtoType_name, int32(x)) +} + +func (ProtoType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{2} +} + type Credentials struct { Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` @@ -247,12 +272,58 @@ func (m *Ack) GetError() Error { return Error_Null } +type Proto struct { + Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"` +} + +func (m *Proto) Reset() { *m = Proto{} } +func (m *Proto) String() string { return proto.CompactTextString(m) } +func (*Proto) ProtoMessage() {} +func (*Proto) Descriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{3} +} +func (m *Proto) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Proto.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Proto) XXX_Merge(src proto.Message) { + xxx_messageInfo_Proto.Merge(m, src) +} +func (m *Proto) XXX_Size() int { + return m.Size() +} +func (m *Proto) XXX_DiscardUnknown() { + xxx_messageInfo_Proto.DiscardUnknown(m) +} + +var xxx_messageInfo_Proto proto.InternalMessageInfo + +func (m *Proto) GetProto() ProtoType { + if m != nil { + return m.Proto + } + return ProtoType_DRPC +} + func init() { proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value) proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value) + proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value) proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials") proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds") proto.RegisterType((*Ack)(nil), "anyHandshake.Ack") + proto.RegisterType((*Proto)(nil), "anyHandshake.Proto") } func init() { @@ -260,32 +331,35 @@ func init() { } var fileDescriptor_60283fc75f020893 = []byte{ - // 395 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31, - 0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53, - 0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5, - 0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88, - 0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65, - 0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23, - 0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60, - 0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d, - 0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd, - 0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77, - 0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba, - 0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91, - 0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08, - 0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d, - 0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1, - 0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78, - 0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f, - 0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70, - 0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc, - 0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80, - 0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34, - 0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe, - 0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96, - 0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00, - 0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00, + // 439 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40, + 0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7, + 0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5, + 0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f, + 0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf, + 0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe, + 0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c, + 0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed, + 0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef, + 0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65, + 0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8, + 0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46, + 0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72, + 0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27, + 0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05, + 0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e, + 0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64, + 0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04, + 0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73, + 0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62, + 0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81, + 0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53, + 0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85, + 0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3, + 0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd, + 0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c, + 0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53, + 0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00, } func (m *Credentials) Marshal() (dAtA []byte, err error) { @@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Proto) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Proto) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Proto != 0 { + i = encodeVarintHandshake(dAtA, i, uint64(m.Proto)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { offset -= sovHandshake(v) base := offset @@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) { return n } +func (m *Proto) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Proto != 0 { + n += 1 + sovHandshake(uint64(m.Proto)) + } + return n +} + func sovHandshake(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error { } return nil } +func (m *Proto) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Proto: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType) + } + m.Proto = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Proto |= ProtoType(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipHandshake(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipHandshake(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/net/secureservice/handshake/handshakeproto/protos/handshake.proto b/net/secureservice/handshake/handshakeproto/protos/handshake.proto index cca5822e..1ea66b28 100644 --- a/net/secureservice/handshake/handshakeproto/protos/handshake.proto +++ b/net/secureservice/handshake/handshakeproto/protos/handshake.proto @@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto"; /* +CREDENTIALS HANDSHAKE + Alice opens a new connection with Bob 1. TLS handshake done successfully; both sides know local and remote peer identifiers. @@ -68,4 +70,20 @@ enum Error { SkipVerifyNotAllowed = 4; DeadlineExceeded = 5; IncompatibleVersion = 6; + IncompatibleProto = 7; +} + + +/* + +PROTO HANDSHAKE + + */ + +message Proto { + ProtoType proto = 1; +} + +enum ProtoType { + DRPC = 0; } \ No newline at end of file diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go new file mode 100644 index 00000000..45e95ab5 --- /dev/null +++ b/net/secureservice/handshake/proto.go @@ -0,0 +1,97 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "golang.org/x/exp/slices" + "net" +) + +type ProtoChecker struct { + AllowedProtoTypes []handshakeproto.ProtoType +} + +func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + err = outgoingProtoHandshake(h, conn, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = conn.Close() + return ctx.Err() + } +} + +func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) { + defer h.release() + h.conn = conn + localProto := &handshakeproto.Proto{ + Proto: pt, + } + if err = h.writeProto(localProto); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack.Error == handshakeproto.Error_IncompatibleProto { + return ErrRemoteIncompatibleProto + } + if msg.ack.Error == handshakeproto.Error_Null { + return nil + } + return HandshakeError{e: msg.ack.Error} +} + +func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + protoType, err = incomingProtoHandshake(h, conn, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = conn.Close() + return 0, ctx.Err() + } +} + +func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + defer h.release() + h.conn = conn + + msg, err := h.readMsg(msgTypeProto) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) { + err = ErrIncompatibleProto + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return 0, err + } else { + return msg.proto.Proto, nil + } +} diff --git a/net/secureservice/handshake/proto_test.go b/net/secureservice/handshake/proto_test.go new file mode 100644 index 00000000..f689e372 --- /dev/null +++ b/net/secureservice/handshake/proto_test.go @@ -0,0 +1,121 @@ +package handshake + +import ( + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +type protoRes struct { + protoType handshakeproto.ProtoType + err error +} + +func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker { + return ProtoChecker{AllowedProtoTypes: types} +} +func TestIncomingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) + res := <-protoResCh + require.NoError(t, res.err) + assert.Equal(t, handshakeproto.ProtoType(1), res.protoType) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error) + res := <-protoResCh + require.Error(t, res.err, ErrIncompatibleProto.Error()) + }) +} + +func TestOutgoingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_Null)) + + res := <-protoResCh + assert.NoError(t, res.err) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto)) + + res := <-protoResCh + assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error()) + }) +} + +func TestEndToEndProto(t *testing.T) { + c1, c2 := newConnPair(t) + var ( + inResCh = make(chan protoRes, 1) + outResCh = make(chan protoRes, 1) + ) + st := time.Now() + go func() { + err := OutgoingProtoHandshake(nil, c1, 0) + outResCh <- protoRes{err: err} + }() + go func() { + protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1)) + inResCh <- protoRes{protoType: protoType, err: err} + }() + + outRes := <-outResCh + assert.NoError(t, outRes.err) + + inRes := <-inResCh + assert.NoError(t, inRes.err) + assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType) + t.Log("dur", time.Since(st)) +} diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index d7d86040..4faca6c0 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -25,7 +25,7 @@ func New() SecureService { } type SecureService interface { - SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) + SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) app.Component } @@ -93,10 +93,10 @@ func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx return } -func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) { - sc, err := s.p2pTr.SecureOutbound(ctx, conn, "") +func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { + sc, err = s.p2pTr.SecureOutbound(ctx, conn, "") if err != nil { - return nil, handshake.HandshakeError{Err: err} + return nil, nil, handshake.HandshakeError{Err: err} } peerId := sc.RemotePeer().String() confTypes := s.nodeconf.NodeTypes(peerId) @@ -106,10 +106,12 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec. } else { checker = s.noVerifyChecker } - // ignore identity for outgoing connection because we don't need it at this moment - _, err = handshake.OutgoingHandshake(ctx, sc, checker) + identity, err := handshake.OutgoingHandshake(ctx, sc, checker) if err != nil { - return nil, err + return nil, nil, err } - return sc, nil + cctx = context.Background() + cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) + cctx = peer.CtxWithIdentity(cctx, identity) + return cctx, sc, nil } diff --git a/net/secureservice/secureservice_test.go b/net/secureservice/secureservice_test.go index 2f435a65..e03b92a4 100644 --- a/net/secureservice/secureservice_test.go +++ b/net/secureservice/secureservice_test.go @@ -39,9 +39,12 @@ func TestHandshake(t *testing.T) { fxC := newFixture(t, nc, nc.GetAccountService(1), 0) defer fxC.Finish(t) - secConn, err := fxC.SecureOutbound(ctx, cc) + cctx, secConn, err := fxC.SecureOutbound(ctx, cc) + require.NoError(t, err) + ctxPeerId, err := peer.CtxPeerId(cctx) require.NoError(t, err) assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) + assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId) res := <-resCh require.NoError(t, res.err) peerId, err := peer.CtxPeerId(res.ctx) @@ -72,7 +75,7 @@ func TestHandshakeIncompatibleVersion(t *testing.T) { }() fxC := newFixture(t, nc, nc.GetAccountService(1), 1) defer fxC.Finish(t) - _, err := fxC.SecureOutbound(ctx, cc) + _, _, err := fxC.SecureOutbound(ctx, cc) require.Equal(t, handshake.ErrIncompatibleVersion, err) res := <-resCh require.Equal(t, handshake.ErrIncompatibleVersion, res.err) diff --git a/net/streampool/testservice/testservice_drpc.pb.go b/net/streampool/testservice/testservice_drpc.pb.go index f50fdbe7..cfe5bce9 100644 --- a/net/streampool/testservice/testservice_drpc.pb.go +++ b/net/streampool/testservice/testservice_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: net/streampool/testservice/protos/testservice.proto package testservice @@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct { drpc.Stream } +func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error { return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{}) } diff --git a/net/transport/mock_transport/mock_transport.go b/net/transport/mock_transport/mock_transport.go new file mode 100644 index 00000000..43f5572c --- /dev/null +++ b/net/transport/mock_transport/mock_transport.go @@ -0,0 +1,188 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/net/transport (interfaces: Transport,MultiConn) + +// Package mock_transport is a generated GoMock package. +package mock_transport + +import ( + context "context" + net "net" + reflect "reflect" + time "time" + + transport "github.com/anyproto/any-sync/net/transport" + gomock "github.com/golang/mock/gomock" +) + +// MockTransport is a mock of Transport interface. +type MockTransport struct { + ctrl *gomock.Controller + recorder *MockTransportMockRecorder +} + +// MockTransportMockRecorder is the mock recorder for MockTransport. +type MockTransportMockRecorder struct { + mock *MockTransport +} + +// NewMockTransport creates a new mock instance. +func NewMockTransport(ctrl *gomock.Controller) *MockTransport { + mock := &MockTransport{ctrl: ctrl} + mock.recorder = &MockTransportMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransport) EXPECT() *MockTransportMockRecorder { + return m.recorder +} + +// Dial mocks base method. +func (m *MockTransport) Dial(arg0 context.Context, arg1 string) (transport.MultiConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dial", arg0, arg1) + ret0, _ := ret[0].(transport.MultiConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dial indicates an expected call of Dial. +func (mr *MockTransportMockRecorder) Dial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockTransport)(nil).Dial), arg0, arg1) +} + +// SetAccepter mocks base method. +func (m *MockTransport) SetAccepter(arg0 transport.Accepter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccepter", arg0) +} + +// SetAccepter indicates an expected call of SetAccepter. +func (mr *MockTransportMockRecorder) SetAccepter(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccepter", reflect.TypeOf((*MockTransport)(nil).SetAccepter), arg0) +} + +// MockMultiConn is a mock of MultiConn interface. +type MockMultiConn struct { + ctrl *gomock.Controller + recorder *MockMultiConnMockRecorder +} + +// MockMultiConnMockRecorder is the mock recorder for MockMultiConn. +type MockMultiConnMockRecorder struct { + mock *MockMultiConn +} + +// NewMockMultiConn creates a new mock instance. +func NewMockMultiConn(ctrl *gomock.Controller) *MockMultiConn { + mock := &MockMultiConn{ctrl: ctrl} + mock.recorder = &MockMultiConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultiConn) EXPECT() *MockMultiConnMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockMultiConn) Accept() (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept") + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockMultiConnMockRecorder) Accept() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockMultiConn)(nil).Accept)) +} + +// Addr mocks base method. +func (m *MockMultiConn) Addr() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(string) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockMultiConnMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockMultiConn)(nil).Addr)) +} + +// Close mocks base method. +func (m *MockMultiConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockMultiConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiConn)(nil).Close)) +} + +// Context mocks base method. +func (m *MockMultiConn) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockMultiConnMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMultiConn)(nil).Context)) +} + +// IsClosed mocks base method. +func (m *MockMultiConn) IsClosed() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClosed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsClosed indicates an expected call of IsClosed. +func (mr *MockMultiConnMockRecorder) IsClosed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiConn)(nil).IsClosed)) +} + +// LastUsage mocks base method. +func (m *MockMultiConn) LastUsage() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastUsage") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// LastUsage indicates an expected call of LastUsage. +func (mr *MockMultiConnMockRecorder) LastUsage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockMultiConn)(nil).LastUsage)) +} + +// Open mocks base method. +func (m *MockMultiConn) Open(arg0 context.Context) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockMultiConnMockRecorder) Open(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockMultiConn)(nil).Open), arg0) +} diff --git a/net/transport/transport.go b/net/transport/transport.go new file mode 100644 index 00000000..2dab5348 --- /dev/null +++ b/net/transport/transport.go @@ -0,0 +1,44 @@ +//go:generate mockgen -destination mock_transport/mock_transport.go github.com/anyproto/any-sync/net/transport Transport,MultiConn +package transport + +import ( + "context" + "errors" + "net" + "time" +) + +var ( + ErrConnClosed = errors.New("connection closed") +) + +// Transport is a common interface for a network transport +type Transport interface { + // SetAccepter sets accepter that will be called for new connections + // this method should be called before app start + SetAccepter(accepter Accepter) + // Dial creates a new connection by given address + Dial(ctx context.Context, addr string) (mc MultiConn, err error) +} + +// MultiConn is an object of multiplexing connection containing handshake info +type MultiConn interface { + // Context returns the connection context that contains handshake details + Context() context.Context + // Accept accepts new sub connections + Accept() (conn net.Conn, err error) + // Open opens new sub connection + Open(ctx context.Context) (conn net.Conn, err error) + // LastUsage returns the time of the last connection activity + LastUsage() time.Time + // Addr returns remote peer address + Addr() string + // IsClosed returns true when connection is closed + IsClosed() bool + // Close closes the connection and all sub connections + Close() error +} + +type Accepter interface { + Accept(mc MultiConn) (err error) +} diff --git a/net/transport/yamux/config.go b/net/transport/yamux/config.go new file mode 100644 index 00000000..38f74d4b --- /dev/null +++ b/net/transport/yamux/config.go @@ -0,0 +1,12 @@ +package yamux + +type configGetter interface { + GetYamux() Config +} + +type Config struct { + ListenAddrs []string `yaml:"listenAddrs"` + WriteTimeoutSec int `yaml:"writeTimeoutSec"` + DialTimeoutSec int `yaml:"dialTimeoutSec"` + MaxStreams int `yaml:"maxStreams"` +} diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go new file mode 100644 index 00000000..d563df34 --- /dev/null +++ b/net/transport/yamux/conn.go @@ -0,0 +1,42 @@ +package yamux + +import ( + "context" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/transport" + "github.com/hashicorp/yamux" + "net" + "time" +) + +type yamuxConn struct { + ctx context.Context + luConn *connutil.LastUsageConn + addr string + *yamux.Session +} + +func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) { + return y.Session.Open() +} + +func (y *yamuxConn) LastUsage() time.Time { + return y.luConn.LastUsage() +} + +func (y *yamuxConn) Context() context.Context { + return y.ctx +} + +func (y *yamuxConn) Addr() string { + return y.addr +} + +func (y *yamuxConn) Accept() (conn net.Conn, err error) { + if conn, err = y.Session.Accept(); err != nil { + if err == yamux.ErrSessionShutdown { + err = transport.ErrConnClosed + } + } + return +} diff --git a/net/rpc/server/util.go b/net/transport/yamux/util.go similarity index 93% rename from net/rpc/server/util.go rename to net/transport/yamux/util.go index 5852288a..c1299d15 100644 --- a/net/rpc/server/util.go +++ b/net/transport/yamux/util.go @@ -1,6 +1,6 @@ //go:build !windows -package server +package yamux import ( "errors" diff --git a/net/rpc/server/util_windows.go b/net/transport/yamux/util_windows.go similarity index 97% rename from net/rpc/server/util_windows.go rename to net/transport/yamux/util_windows.go index efef2915..390524d5 100644 --- a/net/rpc/server/util_windows.go +++ b/net/transport/yamux/util_windows.go @@ -1,6 +1,6 @@ //go:build windows -package server +package yamux import ( "errors" diff --git a/net/transport/yamux/yamux.go b/net/transport/yamux/yamux.go new file mode 100644 index 00000000..44729392 --- /dev/null +++ b/net/transport/yamux/yamux.go @@ -0,0 +1,170 @@ +package yamux + +import ( + "context" + "fmt" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/secureservice" + "github.com/anyproto/any-sync/net/transport" + "github.com/hashicorp/yamux" + "go.uber.org/zap" + "net" + "time" +) + +const CName = "net.transport.yamux" + +var log = logger.NewNamed(CName) + +func New() Yamux { + return new(yamuxTransport) +} + +// Yamux implements transport.Transport with tcp+yamux +type Yamux interface { + transport.Transport + app.ComponentRunnable +} + +type yamuxTransport struct { + secure secureservice.SecureService + accepter transport.Accepter + conf Config + + listeners []net.Listener + listCtx context.Context + listCtxCancel context.CancelFunc + yamuxConf *yamux.Config +} + +func (y *yamuxTransport) Init(a *app.App) (err error) { + y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService) + y.conf = a.MustComponent("config").(configGetter).GetYamux() + y.yamuxConf = yamux.DefaultConfig() + if y.conf.MaxStreams > 0 { + y.yamuxConf.AcceptBacklog = y.conf.MaxStreams + } + y.yamuxConf.EnableKeepAlive = false + y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second + y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second + return +} + +func (y *yamuxTransport) Name() string { + return CName +} + +func (y *yamuxTransport) Run(ctx context.Context) (err error) { + if y.accepter == nil { + return fmt.Errorf("can't run service without accepter") + } + for _, listAddr := range y.conf.ListenAddrs { + list, err := net.Listen("tcp", listAddr) + if err != nil { + return err + } + y.listeners = append(y.listeners, list) + } + y.listCtx, y.listCtxCancel = context.WithCancel(context.Background()) + for _, list := range y.listeners { + go y.acceptLoop(y.listCtx, list) + } + return +} + +func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) { + y.accepter = accepter +} + +func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) { + dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + cctx, sc, err := y.secure.SecureOutbound(ctx, conn) + if err != nil { + _ = conn.Close() + return nil, err + } + luc := connutil.NewLastUsageConn(sc) + sess, err := yamux.Client(luc, y.yamuxConf) + if err != nil { + return + } + mc = &yamuxConn{ + ctx: cctx, + luConn: luc, + Session: sess, + addr: addr, + } + return +} + +func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) { + l := log.With(zap.String("localAddr", list.Addr().String())) + l.Info("yamux listener started") + defer func() { + l.Debug("yamux listener stopped") + }() + for { + conn, err := list.Accept() + if err != nil { + if isTemporary(err) { + l.Debug("listener temporary accept error", zap.Error(err)) + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return + } + continue + } + if err != net.ErrClosed { + l.Error("listener closed with error", zap.Error(err)) + } else { + l.Info("listener closed") + } + return + } + go y.accept(conn) + } +} + +func (y *yamuxTransport) accept(conn net.Conn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second) + defer cancel() + cctx, sc, err := y.secure.SecureInbound(ctx, conn) + if err != nil { + log.Warn("incoming connection handshake error", zap.Error(err)) + return + } + luc := connutil.NewLastUsageConn(sc) + sess, err := yamux.Server(luc, y.yamuxConf) + if err != nil { + log.Warn("incoming connection yamux session error", zap.Error(err)) + return + } + mc := &yamuxConn{ + ctx: cctx, + luConn: luc, + Session: sess, + addr: conn.RemoteAddr().String(), + } + if err = y.accepter.Accept(mc); err != nil { + log.Warn("connection accept error", zap.Error(err)) + } +} + +func (y *yamuxTransport) Close(ctx context.Context) (err error) { + if y.listCtxCancel != nil { + y.listCtxCancel() + } + for _, l := range y.listeners { + _ = l.Close() + } + return +} diff --git a/net/transport/yamux/yamux_test.go b/net/transport/yamux/yamux_test.go new file mode 100644 index 00000000..20efdfce --- /dev/null +++ b/net/transport/yamux/yamux_test.go @@ -0,0 +1,134 @@ +package yamux + +import ( + "bytes" + "context" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/secureservice" + "github.com/anyproto/any-sync/net/transport" + "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/nodeconf/mock_nodeconf" + "github.com/anyproto/any-sync/testutil/accounttest" + "github.com/anyproto/any-sync/testutil/testnodeconf" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "testing" +) + +var ctx = context.Background() + +func TestYamuxTransport_Dial(t *testing.T) { + fxS := newFixture(t) + defer fxS.finish(t) + fxC := newFixture(t) + defer fxC.finish(t) + + mcC, err := fxC.Dial(ctx, fxS.addr) + require.NoError(t, err) + require.Len(t, fxS.accepter.mcs, 1) + mcS := fxS.accepter.mcs[0] + + var ( + sData string + acceptErr error + copyErr error + done = make(chan struct{}) + ) + + go func() { + defer close(done) + conn, serr := mcS.Accept() + if serr != nil { + acceptErr = serr + return + } + buf := bytes.NewBuffer(nil) + _, copyErr = io.Copy(buf, conn) + sData = buf.String() + return + }() + + conn, err := mcC.Open(ctx) + require.NoError(t, err) + data := "some data" + _, err = conn.Write([]byte(data)) + require.NoError(t, err) + require.NoError(t, conn.Close()) + <-done + + assert.NoError(t, acceptErr) + assert.Equal(t, data, sData) + assert.NoError(t, copyErr) +} + +type fixture struct { + *yamuxTransport + a *app.App + ctrl *gomock.Controller + mockNodeConf *mock_nodeconf.MockService + acc *accounttest.AccountTestService + accepter *testAccepter + addr string +} + +func newFixture(t *testing.T) *fixture { + fx := &fixture{ + yamuxTransport: New().(*yamuxTransport), + ctrl: gomock.NewController(t), + acc: &accounttest.AccountTestService{}, + accepter: &testAccepter{}, + a: new(app.App), + } + + fx.mockNodeConf = mock_nodeconf.NewMockService(fx.ctrl) + fx.mockNodeConf.EXPECT().Init(gomock.Any()) + fx.mockNodeConf.EXPECT().Name().Return(nodeconf.CName).AnyTimes() + fx.mockNodeConf.EXPECT().Run(ctx) + fx.mockNodeConf.EXPECT().Close(ctx) + fx.mockNodeConf.EXPECT().NodeTypes(gomock.Any()).Return([]nodeconf.NodeType{nodeconf.NodeTypeTree}).AnyTimes() + fx.a.Register(fx.acc).Register(newTestConf()).Register(fx.mockNodeConf).Register(secureservice.New()).Register(fx.yamuxTransport).Register(fx.accepter) + require.NoError(t, fx.a.Start(ctx)) + fx.addr = fx.listeners[0].Addr().String() + return fx +} + +func (fx *fixture) finish(t *testing.T) { + require.NoError(t, fx.a.Close(ctx)) + fx.ctrl.Finish() +} + +func newTestConf() *testConf { + return &testConf{testnodeconf.GenNodeConfig(1)} +} + +type testConf struct { + *testnodeconf.Config +} + +func (c *testConf) GetYamux() Config { + return Config{ + ListenAddrs: []string{"127.0.0.1:0"}, + WriteTimeoutSec: 10, + DialTimeoutSec: 10, + MaxStreams: 1024, + } +} + +type testAccepter struct { + err error + mcs []transport.MultiConn +} + +func (t *testAccepter) Accept(mc transport.MultiConn) (err error) { + t.mcs = append(t.mcs, mc) + return t.err +} + +func (t *testAccepter) Init(a *app.App) (err error) { + a.MustComponent(CName).(transport.Transport).SetAccepter(t) + return nil +} + +func (t *testAccepter) Name() (name string) { return "testAccepter" }