Merge remote-tracking branch 'origin/yamux' into new-sync-protocol

This commit is contained in:
mcrakhman 2023-06-04 10:43:11 +02:00
commit 248205cddd
No known key found for this signature in database
GPG Key ID: DED12CFEF5B8396B
35 changed files with 1900 additions and 752 deletions

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32
// protoc-gen-go-drpc version: v0.0.33
// source: commonfile/fileproto/protos/file.proto
package fileproto

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32
// protoc-gen-go-drpc version: v0.0.33
// source: commonspace/spacesyncproto/protos/spacesync.proto
package spacesyncproto
@ -103,6 +103,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct {
drpc.Stream
}
func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream {
return x.Stream
}
func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error {
return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
}

View File

@ -10,6 +10,7 @@ import (
"github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/crypto"
"storj.io/drpc"
)
const CName = "common.coordinator.coordinatorclient"
@ -39,42 +40,8 @@ type coordinatorClient struct {
nodeConf nodeconf.Service
}
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
SpaceId: spaceId,
DeletionChangeId: deleteRaw.GetId(),
DeletionChangePayload: deleteRaw.GetRawChange(),
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
status = resp.Payload
return
}
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
SpaceId: spaceId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
status = resp.Payload
return
}
func (c *coordinatorClient) Init(a *app.App) (err error) {
c.pool = a.MustComponent(pool.CName).(pool.Service).NewPool(CName)
c.pool = a.MustComponent(pool.CName).(pool.Service)
c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
return
}
@ -83,8 +50,37 @@ func (c *coordinatorClient) Name() (name string) {
return CName
}
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
SpaceId: spaceId,
DeletionChangeId: deleteRaw.GetId(),
DeletionChangePayload: deleteRaw.GetRawChange(),
})
if err != nil {
return rpcerr.Unwrap(err)
}
status = resp.Payload
return nil
})
return
}
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
SpaceId: spaceId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
status = resp.Payload
return nil
})
return
}
func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
@ -100,54 +96,56 @@ func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayl
if err != nil {
return
}
resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{
SpaceId: payload.SpaceId,
Header: payload.SpaceHeader,
OldIdentity: oldIdentity,
NewIdentitySignature: newSignature,
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{
SpaceId: payload.SpaceId,
Header: payload.SpaceHeader,
OldIdentity: oldIdentity,
NewIdentitySignature: newSignature,
})
if err != nil {
return rpcerr.Unwrap(err)
}
receipt = resp.Receipt
return nil
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return resp.Receipt, nil
}
func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
AccountIdentity: identity,
SpaceId: spaceId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return resp.Limit, nil
}
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
CurrentId: currentId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return
}
func (c *coordinatorClient) client(ctx context.Context) (coordinatorproto.DRPCCoordinatorClient, error) {
func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
AccountIdentity: identity,
SpaceId: spaceId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
limit = resp.Limit
return nil
})
return
}
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
CurrentId: currentId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
return nil
})
return
}
func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error {
p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers())
if err != nil {
return nil, err
return err
}
return coordinatorproto.NewDRPCCoordinatorClient(p), nil
return p.DoDrpc(ctx, func(conn drpc.Conn) error {
return f(coordinatorproto.NewDRPCCoordinatorClient(conn))
})
}

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32
// protoc-gen-go-drpc version: v0.0.33
// source: coordinator/coordinatorproto/protos/coordinator.proto
package coordinatorproto

3
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/gogo/protobuf v1.3.2
github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0
github.com/hashicorp/yamux v0.1.1
github.com/huandu/skiplist v1.2.0
github.com/ipfs/go-block-format v0.1.2
github.com/ipfs/go-blockservice v0.5.2
@ -33,6 +34,7 @@ require (
github.com/tyler-smith/go-bip39 v1.1.0
github.com/zeebo/blake3 v0.2.3
github.com/zeebo/errs v1.3.0
go.uber.org/atomic v1.11.0
go.uber.org/zap v1.24.0
golang.org/x/crypto v0.9.0
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
@ -103,7 +105,6 @@ require (
github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect
go.opentelemetry.io/otel v1.7.0 // indirect
go.opentelemetry.io/otel/trace v1.7.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.6.0 // indirect
golang.org/x/sync v0.2.0 // indirect

2
go.sum
View File

@ -79,6 +79,8 @@ github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfm
github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c=
github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U=
github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw=

View File

@ -1,4 +1,4 @@
package timeoutconn
package connutil
import (
"errors"
@ -10,18 +10,18 @@ import (
"go.uber.org/zap"
)
var log = logger.NewNamed("common.net.timeoutconn")
var log = logger.NewNamed("common.net.connutil")
type Conn struct {
type TimeoutConn struct {
net.Conn
timeout time.Duration
}
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
return &Conn{conn, timeout}
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
return &TimeoutConn{conn, timeout}
}
func (c *Conn) Write(p []byte) (n int, err error) {
func (c *TimeoutConn) Write(p []byte) (n int, err error) {
for {
if c.timeout != 0 {
if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil {

30
net/connutil/usage.go Normal file
View File

@ -0,0 +1,30 @@
package connutil
import (
"go.uber.org/atomic"
"net"
"time"
)
func NewLastUsageConn(conn net.Conn) *LastUsageConn {
return &LastUsageConn{Conn: conn}
}
type LastUsageConn struct {
net.Conn
lastUsage atomic.Time
}
func (c *LastUsageConn) Write(p []byte) (n int, err error) {
c.lastUsage.Store(time.Now())
return c.Conn.Write(p)
}
func (c *LastUsageConn) Read(p []byte) (n int, err error) {
c.lastUsage.Store(time.Now())
return c.Conn.Read(p)
}
func (c *LastUsageConn) LastUsage() time.Time {
return c.lastUsage.Load()
}

View File

@ -1,137 +0,0 @@
package dialer
import (
"context"
"errors"
"fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
net2 "github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/timeoutconn"
"github.com/anyproto/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcwire"
"sync"
"time"
)
const CName = "common.net.dialer"
var (
ErrAddrsNotFound = errors.New("addrs for peer not found")
ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id")
)
var log = logger.NewNamed(CName)
func New() Dialer {
return &dialer{peerAddrs: map[string][]string{}}
}
type Dialer interface {
Dial(ctx context.Context, peerId string) (peer peer.Peer, err error)
SetPeerAddrs(peerId string, addrs []string)
app.Component
}
type dialer struct {
transport secureservice.SecureService
config net2.Config
nodeConf nodeconf.NodeConf
peerAddrs map[string][]string
mu sync.RWMutex
}
func (d *dialer) Init(a *app.App) (err error) {
d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet()
return
}
func (d *dialer) Name() (name string) {
return CName
}
func (d *dialer) SetPeerAddrs(peerId string, addrs []string) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerAddrs[peerId] = addrs
}
func (d *dialer) getPeerAddrs(peerId string) ([]string, error) {
if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok {
return addrs, nil
}
addrs, ok := d.peerAddrs[peerId]
if !ok || len(addrs) == 0 {
return nil, ErrAddrsNotFound
}
return addrs, nil
}
func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) {
var ctxCancel context.CancelFunc
ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10)
defer ctxCancel()
d.mu.RLock()
defer d.mu.RUnlock()
addrs, err := d.getPeerAddrs(peerId)
if err != nil {
return
}
var (
conn drpc.Conn
sc sec.SecureConn
)
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
for _, addr := range addrs {
conn, sc, err = d.handshake(ctx, addr, peerId)
if err != nil {
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
} else {
break
}
}
if err != nil {
return
}
return peer.NewPeer(sc, conn), nil
}
func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) {
st := time.Now()
// TODO: move dial timeout to config
tcpConn, err := net.DialTimeout("tcp", addr, time.Second*15)
if err != nil {
return nil, nil, fmt.Errorf("dialTimeout error: %v; since start: %v", err, time.Since(st))
}
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
if err != nil {
if he, ok := err.(handshake.HandshakeError); ok {
return nil, nil, he
}
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
}
if peerId != sc.RemotePeer().String() {
return nil, nil, ErrPeerIdIsUnexpected
}
log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr))
conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)},
}})
return conn, sc, err
}

View File

@ -2,82 +2,146 @@ package peer
import (
"context"
"sync/atomic"
"time"
"github.com/anyproto/any-sync/app/logger"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport"
"go.uber.org/zap"
"io"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"sync"
"time"
)
var log = logger.NewNamed("common.net.peer")
func NewPeer(sc sec.SecureConn, conn drpc.Conn) Peer {
return &peer{
id: sc.RemotePeer().String(),
lastUsage: time.Now().Unix(),
sc: sc,
Conn: conn,
type connCtrl interface {
ServeConn(ctx context.Context, conn net.Conn) (err error)
}
func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
ctx := mc.Context()
pr := &peer{
active: map[drpc.Conn]struct{}{},
MultiConn: mc,
ctrl: ctrl,
}
if pr.id, err = CtxPeerId(ctx); err != nil {
return
}
go pr.acceptLoop()
return pr, nil
}
type Peer interface {
Id() string
LastUsage() time.Time
UpdateLastUsage()
Addr() string
AcquireDrpcConn(ctx context.Context) (drpc.Conn, error)
ReleaseDrpcConn(conn drpc.Conn)
DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error
IsClosed() bool
TryClose(objectTTL time.Duration) (res bool, err error)
drpc.Conn
ocache.Object
}
type peer struct {
id string
ttl time.Duration
lastUsage int64
sc sec.SecureConn
drpc.Conn
id string
ctrl connCtrl
// drpc conn pool
inactive []drpc.Conn
active map[drpc.Conn]struct{}
mu sync.Mutex
transport.MultiConn
}
func (p *peer) Id() string {
return p.id
}
func (p *peer) LastUsage() time.Time {
select {
case <-p.Closed():
return time.Unix(0, 0)
default:
func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()
if len(p.inactive) == 0 {
conn, err := p.Open(ctx)
if err != nil {
return nil, err
}
dconn := drpcconn.New(conn)
p.inactive = append(p.inactive, dconn)
}
return time.Unix(atomic.LoadInt64(&p.lastUsage), 0)
idx := len(p.inactive) - 1
res := p.inactive[idx]
p.inactive = p.inactive[:idx]
p.active[res] = struct{}{}
return res, nil
}
func (p *peer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
defer p.UpdateLastUsage()
return p.Conn.Invoke(ctx, rpc, enc, in, out)
}
func (p *peer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
defer p.UpdateLastUsage()
return p.Conn.NewStream(ctx, rpc, enc)
}
func (p *peer) Read(b []byte) (n int, err error) {
if n, err = p.sc.Read(b); err == nil {
p.UpdateLastUsage()
func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
p.mu.Lock()
defer p.mu.Unlock()
if _, ok := p.active[conn]; ok {
delete(p.active, conn)
}
p.inactive = append(p.inactive, conn)
return
}
func (p *peer) Write(b []byte) (n int, err error) {
if n, err = p.sc.Write(b); err == nil {
p.UpdateLastUsage()
func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
conn, err := p.AcquireDrpcConn(ctx)
if err != nil {
return err
}
return
defer p.ReleaseDrpcConn(conn)
return do(conn)
}
func (p *peer) UpdateLastUsage() {
atomic.StoreInt64(&p.lastUsage, time.Now().Unix())
func (p *peer) acceptLoop() {
var exitErr error
defer func() {
if exitErr != transport.ErrConnClosed {
log.Warn("accept error: close connection", zap.Error(exitErr))
_ = p.MultiConn.Close()
}
}()
for {
conn, err := p.Accept()
if err != nil {
exitErr = err
return
}
go func() {
serveErr := p.serve(conn)
if serveErr != io.EOF && serveErr != transport.ErrConnClosed {
log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr))
}
}()
}
}
var defaultProtoChecker = handshake.ProtoChecker{
AllowedProtoTypes: []handshakeproto.ProtoType{
handshakeproto.ProtoType_DRPC,
},
}
func (p *peer) serve(conn net.Conn) (err error) {
hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20)
if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil {
cancel()
return
}
cancel()
return p.ctrl.ServeConn(p.Context(), conn)
}
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
@ -87,14 +151,7 @@ func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, p.Close()
}
func (p *peer) Addr() string {
if p.sc != nil {
return p.sc.RemoteAddr().String()
}
return ""
}
func (p *peer) Close() (err error) {
log.Debug("peer close", zap.String("peerId", p.id))
return p.Conn.Close()
return p.MultiConn.Close()
}

137
net/peer/peer_test.go Normal file
View File

@ -0,0 +1,137 @@
package peer
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport/mock_transport"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net"
"testing"
"time"
)
var ctx = context.Background()
func TestPeer_AcquireDrpcConn(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
in, out := net.Pipe()
defer out.Close()
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
dc, err := fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
assert.NotEmpty(t, dc)
defer dc.Close()
assert.Len(t, fx.active, 1)
assert.Len(t, fx.inactive, 0)
fx.ReleaseDrpcConn(dc)
assert.Len(t, fx.active, 0)
assert.Len(t, fx.inactive, 1)
dc, err = fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
assert.NotEmpty(t, dc)
assert.Len(t, fx.active, 1)
assert.Len(t, fx.inactive, 0)
}
func TestPeerAccept(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
in, out := net.Pipe()
defer out.Close()
var outHandshakeCh = make(chan error)
go func() {
outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC)
}()
fx.acceptCh <- acceptedConn{conn: in}
cn := <-fx.testCtrl.serveConn
assert.Equal(t, in, cn)
assert.NoError(t, <-outHandshakeCh)
}
func TestPeer_TryClose(t *testing.T) {
t.Run("ttl", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now()
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.False(t, res)
})
t.Run("close", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now().Add(-time.Second * 2)
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.True(t, res)
})
}
type acceptedConn struct {
conn net.Conn
err error
}
func newFixture(t *testing.T, peerId string) *fixture {
fx := &fixture{
ctrl: gomock.NewController(t),
acceptCh: make(chan acceptedConn),
testCtrl: newTesCtrl(),
}
fx.mc = mock_transport.NewMockMultiConn(fx.ctrl)
ctx := CtxWithPeerId(context.Background(), peerId)
fx.mc.EXPECT().Context().Return(ctx).AnyTimes()
fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
ac := <-fx.acceptCh
return ac.conn, ac.err
}).AnyTimes()
fx.mc.EXPECT().Close().AnyTimes()
p, err := NewPeer(fx.mc, fx.testCtrl)
require.NoError(t, err)
fx.peer = p.(*peer)
return fx
}
type fixture struct {
*peer
ctrl *gomock.Controller
mc *mock_transport.MockMultiConn
acceptCh chan acceptedConn
testCtrl *testCtrl
}
func (fx *fixture) finish() {
fx.testCtrl.close()
fx.ctrl.Finish()
}
func newTesCtrl() *testCtrl {
return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)}
}
type testCtrl struct {
serveConn chan net.Conn
closeCh chan struct{}
}
func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
t.serveConn <- conn
<-t.closeCh
return io.EOF
}
func (t *testCtrl) close() {
close(t.closeCh)
}

View File

@ -0,0 +1,110 @@
package peerservice
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/server"
"github.com/anyproto/any-sync/net/transport"
"github.com/anyproto/any-sync/net/transport/yamux"
"github.com/anyproto/any-sync/nodeconf"
"go.uber.org/zap"
"sync"
)
const CName = "net.peerservice"
var log = logger.NewNamed(CName)
var (
ErrAddrsNotFound = errors.New("addrs for peer not found")
)
func New() PeerService {
return new(peerService)
}
type PeerService interface {
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
SetPeerAddrs(peerId string, addrs []string)
transport.Accepter
app.Component
}
type peerService struct {
yamux transport.Transport
nodeConf nodeconf.NodeConf
peerAddrs map[string][]string
pool pool.Pool
server server.DRPCServer
mu sync.RWMutex
}
func (p *peerService) Init(a *app.App) (err error) {
p.yamux = a.MustComponent(yamux.CName).(transport.Transport)
p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
p.pool = a.MustComponent(pool.CName).(pool.Pool)
p.server = a.MustComponent(server.CName).(server.DRPCServer)
p.peerAddrs = map[string][]string{}
return nil
}
func (p *peerService) Name() (name string) {
return CName
}
func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) {
p.mu.RLock()
defer p.mu.RUnlock()
addrs, err := p.getPeerAddrs(peerId)
if err != nil {
return
}
var mc transport.MultiConn
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
for _, addr := range addrs {
mc, err = p.yamux.Dial(ctx, addr)
if err != nil {
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
} else {
break
}
}
if err != nil {
return
}
return peer.NewPeer(mc, p.server)
}
func (p *peerService) Accept(mc transport.MultiConn) (err error) {
pr, err := peer.NewPeer(mc, p.server)
if err != nil {
return err
}
if err = p.pool.AddPeer(context.Background(), pr); err != nil {
_ = pr.Close()
}
return
}
func (p *peerService) SetPeerAddrs(peerId string, addrs []string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peerAddrs[peerId] = addrs
}
func (p *peerService) getPeerAddrs(peerId string) ([]string, error) {
if addrs, ok := p.nodeConf.PeerAddresses(peerId); ok {
return addrs, nil
}
addrs, ok := p.peerAddrs[peerId]
if !ok || len(addrs) == 0 {
return nil, ErrAddrsNotFound
}
return addrs, nil
}

View File

@ -4,7 +4,6 @@ import (
"context"
"github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/dialer"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"go.uber.org/zap"
@ -13,59 +12,61 @@ import (
// Pool creates and caches outgoing connection
type Pool interface {
// Get lookups to peer in existing connections or creates and cache new one
// Get lookups to peer in existing connections or creates and outgoing new one
Get(ctx context.Context, id string) (peer.Peer, error)
// Dial creates new connection to peer and not use cache
Dial(ctx context.Context, id string) (peer.Peer, error)
// GetOneOf searches at least one existing connection in cache or creates a new one from a randomly selected id from given list
// GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list
GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error)
DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error)
// AddPeer adds incoming peer to the pool
AddPeer(ctx context.Context, p peer.Peer) (err error)
}
type pool struct {
cache ocache.OCache
dialer dialer.Dialer
outgoing ocache.OCache
incoming ocache.OCache
}
func (p *pool) Name() (name string) {
return CName
}
func (p *pool) Run(ctx context.Context) (err error) {
return nil
func (p *pool) Get(ctx context.Context, id string) (pr peer.Peer, err error) {
// if we have incoming connection - try to reuse it
if pr, err = p.get(ctx, p.incoming, id); err != nil {
// or try to get or create outgoing
return p.get(ctx, p.outgoing, id)
}
return
}
func (p *pool) Get(ctx context.Context, id string) (peer.Peer, error) {
v, err := p.cache.Get(ctx, id)
func (p *pool) get(ctx context.Context, source ocache.OCache, id string) (peer.Peer, error) {
v, err := source.Get(ctx, id)
if err != nil {
return nil, err
}
pr := v.(peer.Peer)
select {
case <-pr.Closed():
default:
if !pr.IsClosed() {
return pr, nil
}
_, _ = p.cache.Remove(ctx, id)
_, _ = source.Remove(ctx, id)
return p.Get(ctx, id)
}
func (p *pool) Dial(ctx context.Context, id string) (peer.Peer, error) {
return p.dialer.Dial(ctx, id)
}
func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
// finding existing connection
for _, peerId := range peerIds {
if v, err := p.cache.Pick(ctx, peerId); err == nil {
if v, err := p.incoming.Pick(ctx, peerId); err == nil {
pr := v.(peer.Peer)
select {
case <-pr.Closed():
default:
if !pr.IsClosed() {
return pr, nil
}
_, _ = p.cache.Remove(ctx, peerId)
_, _ = p.incoming.Remove(ctx, peerId)
}
if v, err := p.outgoing.Pick(ctx, peerId); err == nil {
pr := v.(peer.Peer)
if !pr.IsClosed() {
return pr, nil
}
_, _ = p.outgoing.Remove(ctx, peerId)
}
}
// shuffle ids for better consistency
@ -75,8 +76,8 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
// connecting
var lastErr error
for _, peerId := range peerIds {
if v, err := p.cache.Get(ctx, peerId); err == nil {
return v.(peer.Peer), nil
if v, err := p.Get(ctx, peerId); err == nil {
return v, nil
} else {
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err))
lastErr = err
@ -88,27 +89,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
return nil, lastErr
}
func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
// shuffle ids for better consistency
rand.Shuffle(len(peerIds), func(i, j int) {
peerIds[i], peerIds[j] = peerIds[j], peerIds[i]
})
// connecting
var lastErr error
for _, peerId := range peerIds {
if v, err := p.dialer.Dial(ctx, peerId); err == nil {
return v.(peer.Peer), nil
func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) {
if err = p.incoming.Add(pr.Id(), pr); err != nil {
if err == ocache.ErrExists {
// in case when an incoming connection with a peer already exists, we close and remove an existing connection
if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil {
_ = v.Close()
_, _ = p.incoming.Remove(ctx, pr.Id())
return p.incoming.Add(pr.Id(), pr)
}
} else {
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err))
lastErr = err
return err
}
}
if _, ok := lastErr.(handshake.HandshakeError); !ok {
lastErr = net.ErrUnableToConnect
}
return nil, lastErr
}
func (p *pool) Close(ctx context.Context) (err error) {
return p.cache.Close()
return
}

View File

@ -6,11 +6,11 @@ import (
"fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/dialer"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
net2 "net"
"storj.io/drpc"
"testing"
"time"
@ -133,6 +133,27 @@ func TestPool_GetOneOf(t *testing.T) {
})
}
func TestPool_AddPeer(t *testing.T) {
t.Run("success", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish()
require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1")))
})
t.Run("two peers", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish()
p1, p2 := newTestPeer("p1"), newTestPeer("p1")
require.NoError(t, fx.AddPeer(ctx, p1))
require.NoError(t, fx.AddPeer(ctx, p2))
select {
case <-p1.closed:
default:
assert.Truef(t, false, "peer not closed")
}
})
}
func newFixture(t *testing.T) *fixture {
fx := &fixture{
Service: New(),
@ -158,7 +179,7 @@ type fixture struct {
t *testing.T
}
var _ dialer.Dialer = (*dialerMock)(nil)
var _ dialer = (*dialerMock)(nil)
type dialerMock struct {
dial func(ctx context.Context, peerId string) (peer peer.Peer, err error)
@ -181,7 +202,7 @@ func (d *dialerMock) Init(a *app.App) (err error) {
}
func (d *dialerMock) Name() (name string) {
return dialer.CName
return "net.peerservice"
}
func newTestPeer(id string) *testPeer {
@ -196,6 +217,31 @@ type testPeer struct {
closed chan struct{}
}
func (t *testPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
return fmt.Errorf("not implemented")
}
func (t *testPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
return nil, fmt.Errorf("not implemented")
}
func (t *testPeer) ReleaseDrpcConn(conn drpc.Conn) {}
func (t *testPeer) Context() context.Context {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Accept() (conn net2.Conn, err error) {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Open(ctx context.Context) (conn net2.Conn, err error) {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Addr() string {
return ""
}
@ -204,12 +250,6 @@ func (t *testPeer) Id() string {
return t.id
}
func (t *testPeer) LastUsage() time.Time {
return time.Now()
}
func (t *testPeer) UpdateLastUsage() {}
func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, t.Close()
}
@ -224,14 +264,11 @@ func (t *testPeer) Close() error {
return nil
}
func (t *testPeer) Closed() <-chan struct{} {
return t.closed
}
func (t *testPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
return fmt.Errorf("call Invoke on test peer")
}
func (t *testPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
return nil, fmt.Errorf("call NewStream on test peer")
func (t *testPeer) IsClosed() bool {
select {
case <-t.closed:
return true
default:
return false
}
}

View File

@ -6,8 +6,9 @@ import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/dialer"
"github.com/anyproto/any-sync/net/peer"
"github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"time"
)
@ -23,46 +24,54 @@ func New() Service {
type Service interface {
Pool
NewPool(name string) Pool
app.ComponentRunnable
}
type dialer interface {
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
}
type poolService struct {
// default pool
*pool
dialer dialer.Dialer
dialer dialer
metricReg *prometheus.Registry
}
func (p *poolService) Init(a *app.App) (err error) {
p.dialer = a.MustComponent(dialer.CName).(dialer.Dialer)
p.pool = &pool{dialer: p.dialer}
p.dialer = a.MustComponent("net.peerservice").(dialer)
p.pool = &pool{}
if m := a.Component(metric.CName); m != nil {
p.metricReg = m.(metric.Metric).Registry()
}
p.pool.cache = ocache.New(
p.pool.outgoing = ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) {
return p.dialer.Dial(ctx, id)
},
ocache.WithLogger(log.Sugar()),
ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", "default"),
ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"),
)
p.pool.incoming = ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) {
return nil, ocache.ErrNotExists
},
ocache.WithLogger(log.Sugar()),
ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", "incoming"),
)
return nil
}
func (p *poolService) NewPool(name string) Pool {
return &pool{
dialer: p.dialer,
cache: ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) {
return p.dialer.Dial(ctx, id)
},
ocache.WithLogger(log.Sugar()),
ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", name),
),
}
func (p *pool) Run(ctx context.Context) (err error) {
return nil
}
func (p *pool) Close(ctx context.Context) (err error) {
if e := p.incoming.Close(); e != nil {
log.Warn("close incoming cache error", zap.Error(e))
}
return p.outgoing.Close()
}

View File

@ -1,134 +0,0 @@
package server
import (
"context"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs"
"go.uber.org/zap"
"io"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcwire"
"time"
)
type BaseDrpcServer struct {
drpcServer *drpcserver.Server
transport secureservice.SecureService
listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func()
*drpcmux.Mux
}
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type Params struct {
BufferSizeMb int
ListenAddrs []string
Wrapper DRPCHandlerWrapper
TimeoutMillis int
Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
}
func NewBaseDrpcServer() *BaseDrpcServer {
return &BaseDrpcServer{Mux: drpcmux.New()}
}
func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)},
}})
s.handshake = params.Handshake
ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs {
list, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listeners = append(s.listeners, list)
go s.serve(ctx, list)
}
return
}
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started")
defer func() {
l.Debug("drpc listener stopped")
}()
for {
select {
case <-ctx.Done():
return
default:
}
conn, err := lis.Accept()
if err != nil {
if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err))
select {
case <-time.After(time.Second):
case <-ctx.Done():
return
}
continue
}
l.Error("listener accept error", zap.Error(err))
return
}
go s.serveConn(conn)
}
}
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
return
}
if sc, ok := conn.(sec.SecureConn); ok {
ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String())
}
}
ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String())
l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
l.Debug("connection closed")
} else {
l.Warn("serve connection error", zap.Error(err))
}
}
}
func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) {
for _, list := range s.listeners {
addrs = append(addrs, list.Addr())
}
return
}
func (s *BaseDrpcServer) Close(ctx context.Context) (err error) {
if s.cancel != nil {
s.cancel()
}
for _, l := range s.listeners {
if e := l.Close(); e != nil {
log.Warn("close listener error", zap.Error(e))
}
}
return
}

View File

@ -6,11 +6,13 @@ import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/metric"
anyNet "github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
"net"
"storj.io/drpc"
"time"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcwire"
)
const CName = "common.net.drpcserver"
@ -18,49 +20,46 @@ const CName = "common.net.drpcserver"
var log = logger.NewNamed(CName)
func New() DRPCServer {
return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()}
return &drpcServer{}
}
type DRPCServer interface {
app.ComponentRunnable
ServeConn(ctx context.Context, conn net.Conn) (err error)
app.Component
drpc.Mux
}
type drpcServer struct {
config anyNet.Config
metric metric.Metric
transport secureservice.SecureService
*BaseDrpcServer
drpcServer *drpcserver.Server
*drpcmux.Mux
config anyNet.Config
metric metric.Metric
}
func (s *drpcServer) Init(a *app.App) (err error) {
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
s.metric = a.MustComponent(metric.CName).(metric.Metric)
s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
return nil
}
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
func (s *drpcServer) Name() (name string) {
return CName
}
func (s *drpcServer) Run(ctx context.Context) (err error) {
params := Params{
BufferSizeMb: s.config.Stream.MaxMsgSizeMb,
TimeoutMillis: s.config.Stream.TimeoutMilliseconds,
ListenAddrs: s.config.Server.ListenAddrs,
Wrapper: func(handler drpc.Handler) drpc.Handler {
return s.metric.WrapDRPCHandler(handler)
},
Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.transport.SecureInbound(ctx, conn)
},
func (s *drpcServer) Init(a *app.App) (err error) {
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
s.metric, _ = a.Component(metric.CName).(metric.Metric)
s.Mux = drpcmux.New()
var handler drpc.Handler
handler = s
if s.metric != nil {
handler = s.metric.WrapDRPCHandler(s)
}
return s.BaseDrpcServer.Run(ctx, params)
s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)},
}})
return
}
func (s *drpcServer) Close(ctx context.Context) (err error) {
return s.BaseDrpcServer.Close(ctx)
func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
l.Debug("drpc serve peer")
return s.drpcServer.ServeOne(ctx, conn)
}

View File

@ -0,0 +1,125 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
)
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck, msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg(msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}

View File

@ -38,15 +38,14 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
msg, err = h.readMsg()
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
@ -76,7 +75,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
_ = c2.Close()
res := <-handshakeResCh
@ -92,7 +91,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
res := <-handshakeResCh
@ -108,10 +107,10 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
res := <-handshakeResCh
@ -127,7 +126,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
@ -145,12 +144,12 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack and close conn
_, err = h.readMsg()
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
_ = c2.Close()
res := <-handshakeResCh
@ -166,18 +165,17 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack
_, err = h.readMsg()
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg()
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
res := <-handshakeResCh
require.Error(t, res.err)
})
@ -191,7 +189,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
@ -199,7 +197,7 @@ func TestOutgoingHandshake(t *testing.T) {
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
msg, err = h.readMsg()
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
@ -219,7 +217,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
@ -244,14 +242,14 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
// wait ack
msg, err = h.readMsg()
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-handshakeResCh
@ -310,7 +308,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// except ack with error
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
@ -330,7 +328,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// except ack with error
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error)
@ -350,13 +348,13 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// expect ack with error
msg, err := h.readMsg()
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
// expect EOF
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
res := <-handshakeResCh
require.Error(t, res.err)
})
@ -372,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred and close conn
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
_ = c2.Close()
@ -391,7 +389,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -413,7 +411,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -435,7 +433,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -458,7 +456,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
_, err := h.readMsg()
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
@ -482,7 +480,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
_, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err)
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
assert.Error(t, res.err)
}
func TestEndToEnd(t *testing.T) {

View File

@ -1,21 +1,30 @@
package handshake
import (
"context"
"encoding/binary"
"errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"net"
"sync"
)
const headerSize = 5 // 1 byte for type + 4 byte for uint32 size
const (
msgTypeCred = byte(1)
msgTypeAck = byte(2)
msgTypeCred = byte(1)
msgTypeAck = byte(2)
msgTypeProto = byte(3)
sizeLimit = 200 * 1024 // 200 Kb
)
var (
credMsgTypes = []byte{msgTypeCred, msgTypeAck}
protoMsgTypes = []byte{msgTypeProto, msgTypeAck}
protoMsgTypesAck = []byte{msgTypeAck}
)
type HandshakeError struct {
@ -38,17 +47,20 @@ var (
ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected}
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto}
ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")}
ErrGotNotAHandshakeMessage = errors.New("go not a handshake message")
ErrGotUnexpectedMessage = errors.New("go not a handshake message")
)
var handshakePool = &sync.Pool{New: func() any {
return &handshake{
remoteCred: &handshakeproto.Credentials{},
remoteAck: &handshakeproto.Ack{},
localAck: &handshakeproto.Ack{},
buf: make([]byte, 0, 1024),
remoteCred: &handshakeproto.Credentials{},
remoteAck: &handshakeproto.Ack{},
localAck: &handshakeproto.Ack{},
remoteProto: &handshakeproto.Proto{},
buf: make([]byte, 0, 1024),
}
}}
@ -57,147 +69,17 @@ type CredentialChecker interface {
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
return nil, ErrUnexpectedPayload
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}
func newHandshake() *handshake {
return handshakePool.Get().(*handshake)
}
type handshake struct {
conn sec.SecureConn
remoteCred *handshakeproto.Credentials
remoteAck *handshakeproto.Ack
localAck *handshakeproto.Ack
buf []byte
conn net.Conn
remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack
localAck *handshakeproto.Ack
buf []byte
}
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
@ -209,8 +91,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro
return h.writeData(msgTypeCred, n)
}
func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) {
h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize]
n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:])
if err != nil {
return err
}
return h.writeData(msgTypeProto, n)
}
func (h *handshake) tryWriteErrAndClose(err error) {
if err == ErrGotNotAHandshakeMessage {
if err == ErrUnexpectedPayload {
// if we got unexpected message - just close the connection
_ = h.conn.Close()
return
@ -243,21 +134,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) {
}
type message struct {
cred *handshakeproto.Credentials
ack *handshakeproto.Ack
cred *handshakeproto.Credentials
proto *handshakeproto.Proto
ack *handshakeproto.Ack
}
func (h *handshake) readMsg() (msg message, err error) {
func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) {
h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
return
}
tp := h.buf[0]
if tp != msgTypeCred && tp != msgTypeAck {
err = ErrGotNotAHandshakeMessage
if !slices.Contains(allowedTypes, tp) {
err = ErrUnexpectedPayload
return
}
size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
if size > sizeLimit {
err = ErrGotUnexpectedMessage
return
}
h.buf = slices.Grow(h.buf, int(size))[:size]
if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
return
@ -273,6 +169,11 @@ func (h *handshake) readMsg() (msg message, err error) {
return
}
msg.ack = h.remoteAck
case msgTypeProto:
if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil {
return
}
msg.proto = h.remoteProto
}
return
}
@ -284,5 +185,6 @@ func (h *handshake) release() {
h.remoteAck.Error = 0
h.remoteCred.Type = 0
h.remoteCred.Payload = h.remoteCred.Payload[:0]
h.remoteProto.Proto = 0
handshakePool.Put(h)
}

View File

@ -59,6 +59,7 @@ const (
Error_SkipVerifyNotAllowed Error = 4
Error_DeadlineExceeded Error = 5
Error_IncompatibleVersion Error = 6
Error_IncompatibleProto Error = 7
)
var Error_name = map[int32]string{
@ -69,6 +70,7 @@ var Error_name = map[int32]string{
4: "SkipVerifyNotAllowed",
5: "DeadlineExceeded",
6: "IncompatibleVersion",
7: "IncompatibleProto",
}
var Error_value = map[string]int32{
@ -79,6 +81,7 @@ var Error_value = map[string]int32{
"SkipVerifyNotAllowed": 4,
"DeadlineExceeded": 5,
"IncompatibleVersion": 6,
"IncompatibleProto": 7,
}
func (x Error) String() string {
@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{1}
}
type ProtoType int32
const (
ProtoType_DRPC ProtoType = 0
)
var ProtoType_name = map[int32]string{
0: "DRPC",
}
var ProtoType_value = map[string]int32{
"DRPC": 0,
}
func (x ProtoType) String() string {
return proto.EnumName(ProtoType_name, int32(x))
}
func (ProtoType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{2}
}
type Credentials struct {
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
@ -247,12 +272,58 @@ func (m *Ack) GetError() Error {
return Error_Null
}
type Proto struct {
Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"`
}
func (m *Proto) Reset() { *m = Proto{} }
func (m *Proto) String() string { return proto.CompactTextString(m) }
func (*Proto) ProtoMessage() {}
func (*Proto) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{3}
}
func (m *Proto) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Proto.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Proto) XXX_Merge(src proto.Message) {
xxx_messageInfo_Proto.Merge(m, src)
}
func (m *Proto) XXX_Size() int {
return m.Size()
}
func (m *Proto) XXX_DiscardUnknown() {
xxx_messageInfo_Proto.DiscardUnknown(m)
}
var xxx_messageInfo_Proto proto.InternalMessageInfo
func (m *Proto) GetProto() ProtoType {
if m != nil {
return m.Proto
}
return ProtoType_DRPC
}
func init() {
proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value)
proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value)
proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value)
proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials")
proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds")
proto.RegisterType((*Ack)(nil), "anyHandshake.Ack")
proto.RegisterType((*Proto)(nil), "anyHandshake.Proto")
}
func init() {
@ -260,32 +331,35 @@ func init() {
}
var fileDescriptor_60283fc75f020893 = []byte{
// 395 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31,
0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53,
0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5,
0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88,
0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65,
0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23,
0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60,
0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d,
0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd,
0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77,
0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba,
0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91,
0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08,
0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d,
0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1,
0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78,
0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f,
0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70,
0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc,
0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80,
0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34,
0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe,
0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96,
0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00,
0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00,
// 439 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40,
0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7,
0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5,
0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f,
0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf,
0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe,
0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c,
0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed,
0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef,
0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65,
0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8,
0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46,
0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72,
0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27,
0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05,
0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e,
0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64,
0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04,
0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73,
0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62,
0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81,
0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53,
0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85,
0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3,
0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd,
0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c,
0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53,
0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00,
}
func (m *Credentials) Marshal() (dAtA []byte, err error) {
@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
return len(dAtA) - i, nil
}
func (m *Proto) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Proto) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.Proto != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Proto))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
offset -= sovHandshake(v)
base := offset
@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) {
return n
}
func (m *Proto) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Proto != 0 {
n += 1 + sovHandshake(uint64(m.Proto))
}
return n
}
func sovHandshake(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
}
return nil
}
func (m *Proto) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Proto: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType)
}
m.Proto = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Proto |= ProtoType(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipHandshake(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0

View File

@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto";
/*
CREDENTIALS HANDSHAKE
Alice opens a new connection with Bob
1. TLS handshake done successfully; both sides know local and remote peer identifiers.
@ -68,4 +70,20 @@ enum Error {
SkipVerifyNotAllowed = 4;
DeadlineExceeded = 5;
IncompatibleVersion = 6;
IncompatibleProto = 7;
}
/*
PROTO HANDSHAKE
*/
message Proto {
ProtoType proto = 1;
}
enum ProtoType {
DRPC = 0;
}

View File

@ -0,0 +1,97 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"golang.org/x/exp/slices"
"net"
)
type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType
}
func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
err = outgoingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = conn.Close()
return ctx.Err()
}
}
func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
defer h.release()
h.conn = conn
localProto := &handshakeproto.Proto{
Proto: pt,
}
if err = h.writeProto(localProto); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack.Error == handshakeproto.Error_IncompatibleProto {
return ErrRemoteIncompatibleProto
}
if msg.ack.Error == handshakeproto.Error_Null {
return nil
}
return HandshakeError{e: msg.ack.Error}
}
func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
protoType, err = incomingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = conn.Close()
return 0, ctx.Err()
}
}
func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
defer h.release()
h.conn = conn
msg, err := h.readMsg(msgTypeProto)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) {
err = ErrIncompatibleProto
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return 0, err
} else {
return msg.proto.Proto, nil
}
}

View File

@ -0,0 +1,121 @@
package handshake
import (
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
type protoRes struct {
protoType handshakeproto.ProtoType
err error
}
func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker {
return ProtoChecker{AllowedProtoTypes: types}
}
func TestIncomingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-protoResCh
require.NoError(t, res.err)
assert.Equal(t, handshakeproto.ProtoType(1), res.protoType)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error)
res := <-protoResCh
require.Error(t, res.err, ErrIncompatibleProto.Error())
})
}
func TestOutgoingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-protoResCh
assert.NoError(t, res.err)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto))
res := <-protoResCh
assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error())
})
}
func TestEndToEndProto(t *testing.T) {
c1, c2 := newConnPair(t)
var (
inResCh = make(chan protoRes, 1)
outResCh = make(chan protoRes, 1)
)
st := time.Now()
go func() {
err := OutgoingProtoHandshake(nil, c1, 0)
outResCh <- protoRes{err: err}
}()
go func() {
protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1))
inResCh <- protoRes{protoType: protoType, err: err}
}()
outRes := <-outResCh
assert.NoError(t, outRes.err)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType)
t.Log("dur", time.Since(st))
}

View File

@ -25,7 +25,7 @@ func New() SecureService {
}
type SecureService interface {
SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
app.Component
}
@ -93,10 +93,10 @@ func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx
return
}
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil {
return nil, handshake.HandshakeError{Err: err}
return nil, nil, handshake.HandshakeError{Err: err}
}
peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.NodeTypes(peerId)
@ -106,10 +106,12 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.
} else {
checker = s.noVerifyChecker
}
// ignore identity for outgoing connection because we don't need it at this moment
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
identity, err := handshake.OutgoingHandshake(ctx, sc, checker)
if err != nil {
return nil, err
return nil, nil, err
}
return sc, nil
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return cctx, sc, nil
}

View File

@ -39,9 +39,12 @@ func TestHandshake(t *testing.T) {
fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
defer fxC.Finish(t)
secConn, err := fxC.SecureOutbound(ctx, cc)
cctx, secConn, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err)
ctxPeerId, err := peer.CtxPeerId(cctx)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
res := <-resCh
require.NoError(t, res.err)
peerId, err := peer.CtxPeerId(res.ctx)
@ -72,7 +75,7 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
}()
fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
defer fxC.Finish(t)
_, err := fxC.SecureOutbound(ctx, cc)
_, _, err := fxC.SecureOutbound(ctx, cc)
require.Equal(t, handshake.ErrIncompatibleVersion, err)
res := <-resCh
require.Equal(t, handshake.ErrIncompatibleVersion, res.err)

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32
// protoc-gen-go-drpc version: v0.0.33
// source: net/streampool/testservice/protos/testservice.proto
package testservice
@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct {
drpc.Stream
}
func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream {
return x.Stream
}
func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error {
return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{})
}

View File

@ -0,0 +1,188 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/net/transport (interfaces: Transport,MultiConn)
// Package mock_transport is a generated GoMock package.
package mock_transport
import (
context "context"
net "net"
reflect "reflect"
time "time"
transport "github.com/anyproto/any-sync/net/transport"
gomock "github.com/golang/mock/gomock"
)
// MockTransport is a mock of Transport interface.
type MockTransport struct {
ctrl *gomock.Controller
recorder *MockTransportMockRecorder
}
// MockTransportMockRecorder is the mock recorder for MockTransport.
type MockTransportMockRecorder struct {
mock *MockTransport
}
// NewMockTransport creates a new mock instance.
func NewMockTransport(ctrl *gomock.Controller) *MockTransport {
mock := &MockTransport{ctrl: ctrl}
mock.recorder = &MockTransportMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTransport) EXPECT() *MockTransportMockRecorder {
return m.recorder
}
// Dial mocks base method.
func (m *MockTransport) Dial(arg0 context.Context, arg1 string) (transport.MultiConn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Dial", arg0, arg1)
ret0, _ := ret[0].(transport.MultiConn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Dial indicates an expected call of Dial.
func (mr *MockTransportMockRecorder) Dial(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockTransport)(nil).Dial), arg0, arg1)
}
// SetAccepter mocks base method.
func (m *MockTransport) SetAccepter(arg0 transport.Accepter) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetAccepter", arg0)
}
// SetAccepter indicates an expected call of SetAccepter.
func (mr *MockTransportMockRecorder) SetAccepter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccepter", reflect.TypeOf((*MockTransport)(nil).SetAccepter), arg0)
}
// MockMultiConn is a mock of MultiConn interface.
type MockMultiConn struct {
ctrl *gomock.Controller
recorder *MockMultiConnMockRecorder
}
// MockMultiConnMockRecorder is the mock recorder for MockMultiConn.
type MockMultiConnMockRecorder struct {
mock *MockMultiConn
}
// NewMockMultiConn creates a new mock instance.
func NewMockMultiConn(ctrl *gomock.Controller) *MockMultiConn {
mock := &MockMultiConn{ctrl: ctrl}
mock.recorder = &MockMultiConnMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockMultiConn) EXPECT() *MockMultiConnMockRecorder {
return m.recorder
}
// Accept mocks base method.
func (m *MockMultiConn) Accept() (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Accept")
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Accept indicates an expected call of Accept.
func (mr *MockMultiConnMockRecorder) Accept() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockMultiConn)(nil).Accept))
}
// Addr mocks base method.
func (m *MockMultiConn) Addr() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Addr")
ret0, _ := ret[0].(string)
return ret0
}
// Addr indicates an expected call of Addr.
func (mr *MockMultiConnMockRecorder) Addr() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockMultiConn)(nil).Addr))
}
// Close mocks base method.
func (m *MockMultiConn) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockMultiConnMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiConn)(nil).Close))
}
// Context mocks base method.
func (m *MockMultiConn) Context() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context.
func (mr *MockMultiConnMockRecorder) Context() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMultiConn)(nil).Context))
}
// IsClosed mocks base method.
func (m *MockMultiConn) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockMultiConnMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiConn)(nil).IsClosed))
}
// LastUsage mocks base method.
func (m *MockMultiConn) LastUsage() time.Time {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LastUsage")
ret0, _ := ret[0].(time.Time)
return ret0
}
// LastUsage indicates an expected call of LastUsage.
func (mr *MockMultiConnMockRecorder) LastUsage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockMultiConn)(nil).LastUsage))
}
// Open mocks base method.
func (m *MockMultiConn) Open(arg0 context.Context) (net.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Open", arg0)
ret0, _ := ret[0].(net.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Open indicates an expected call of Open.
func (mr *MockMultiConnMockRecorder) Open(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockMultiConn)(nil).Open), arg0)
}

View File

@ -0,0 +1,44 @@
//go:generate mockgen -destination mock_transport/mock_transport.go github.com/anyproto/any-sync/net/transport Transport,MultiConn
package transport
import (
"context"
"errors"
"net"
"time"
)
var (
ErrConnClosed = errors.New("connection closed")
)
// Transport is a common interface for a network transport
type Transport interface {
// SetAccepter sets accepter that will be called for new connections
// this method should be called before app start
SetAccepter(accepter Accepter)
// Dial creates a new connection by given address
Dial(ctx context.Context, addr string) (mc MultiConn, err error)
}
// MultiConn is an object of multiplexing connection containing handshake info
type MultiConn interface {
// Context returns the connection context that contains handshake details
Context() context.Context
// Accept accepts new sub connections
Accept() (conn net.Conn, err error)
// Open opens new sub connection
Open(ctx context.Context) (conn net.Conn, err error)
// LastUsage returns the time of the last connection activity
LastUsage() time.Time
// Addr returns remote peer address
Addr() string
// IsClosed returns true when connection is closed
IsClosed() bool
// Close closes the connection and all sub connections
Close() error
}
type Accepter interface {
Accept(mc MultiConn) (err error)
}

View File

@ -0,0 +1,12 @@
package yamux
type configGetter interface {
GetYamux() Config
}
type Config struct {
ListenAddrs []string `yaml:"listenAddrs"`
WriteTimeoutSec int `yaml:"writeTimeoutSec"`
DialTimeoutSec int `yaml:"dialTimeoutSec"`
MaxStreams int `yaml:"maxStreams"`
}

View File

@ -0,0 +1,42 @@
package yamux
import (
"context"
"github.com/anyproto/any-sync/net/connutil"
"github.com/anyproto/any-sync/net/transport"
"github.com/hashicorp/yamux"
"net"
"time"
)
type yamuxConn struct {
ctx context.Context
luConn *connutil.LastUsageConn
addr string
*yamux.Session
}
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
return y.Session.Open()
}
func (y *yamuxConn) LastUsage() time.Time {
return y.luConn.LastUsage()
}
func (y *yamuxConn) Context() context.Context {
return y.ctx
}
func (y *yamuxConn) Addr() string {
return y.addr
}
func (y *yamuxConn) Accept() (conn net.Conn, err error) {
if conn, err = y.Session.Accept(); err != nil {
if err == yamux.ErrSessionShutdown {
err = transport.ErrConnClosed
}
}
return
}

View File

@ -1,6 +1,6 @@
//go:build !windows
package server
package yamux
import (
"errors"

View File

@ -1,6 +1,6 @@
//go:build windows
package server
package yamux
import (
"errors"

View File

@ -0,0 +1,170 @@
package yamux
import (
"context"
"fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/net/connutil"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/anyproto/any-sync/net/transport"
"github.com/hashicorp/yamux"
"go.uber.org/zap"
"net"
"time"
)
const CName = "net.transport.yamux"
var log = logger.NewNamed(CName)
func New() Yamux {
return new(yamuxTransport)
}
// Yamux implements transport.Transport with tcp+yamux
type Yamux interface {
transport.Transport
app.ComponentRunnable
}
type yamuxTransport struct {
secure secureservice.SecureService
accepter transport.Accepter
conf Config
listeners []net.Listener
listCtx context.Context
listCtxCancel context.CancelFunc
yamuxConf *yamux.Config
}
func (y *yamuxTransport) Init(a *app.App) (err error) {
y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
y.conf = a.MustComponent("config").(configGetter).GetYamux()
y.yamuxConf = yamux.DefaultConfig()
if y.conf.MaxStreams > 0 {
y.yamuxConf.AcceptBacklog = y.conf.MaxStreams
}
y.yamuxConf.EnableKeepAlive = false
y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
return
}
func (y *yamuxTransport) Name() string {
return CName
}
func (y *yamuxTransport) Run(ctx context.Context) (err error) {
if y.accepter == nil {
return fmt.Errorf("can't run service without accepter")
}
for _, listAddr := range y.conf.ListenAddrs {
list, err := net.Listen("tcp", listAddr)
if err != nil {
return err
}
y.listeners = append(y.listeners, list)
}
y.listCtx, y.listCtxCancel = context.WithCancel(context.Background())
for _, list := range y.listeners {
go y.acceptLoop(y.listCtx, list)
}
return
}
func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) {
y.accepter = accepter
}
func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) {
dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
if err != nil {
return nil, err
}
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
cctx, sc, err := y.secure.SecureOutbound(ctx, conn)
if err != nil {
_ = conn.Close()
return nil, err
}
luc := connutil.NewLastUsageConn(sc)
sess, err := yamux.Client(luc, y.yamuxConf)
if err != nil {
return
}
mc = &yamuxConn{
ctx: cctx,
luConn: luc,
Session: sess,
addr: addr,
}
return
}
func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
l := log.With(zap.String("localAddr", list.Addr().String()))
l.Info("yamux listener started")
defer func() {
l.Debug("yamux listener stopped")
}()
for {
conn, err := list.Accept()
if err != nil {
if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err))
select {
case <-time.After(time.Second):
case <-ctx.Done():
return
}
continue
}
if err != net.ErrClosed {
l.Error("listener closed with error", zap.Error(err))
} else {
l.Info("listener closed")
}
return
}
go y.accept(conn)
}
}
func (y *yamuxTransport) accept(conn net.Conn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
defer cancel()
cctx, sc, err := y.secure.SecureInbound(ctx, conn)
if err != nil {
log.Warn("incoming connection handshake error", zap.Error(err))
return
}
luc := connutil.NewLastUsageConn(sc)
sess, err := yamux.Server(luc, y.yamuxConf)
if err != nil {
log.Warn("incoming connection yamux session error", zap.Error(err))
return
}
mc := &yamuxConn{
ctx: cctx,
luConn: luc,
Session: sess,
addr: conn.RemoteAddr().String(),
}
if err = y.accepter.Accept(mc); err != nil {
log.Warn("connection accept error", zap.Error(err))
}
}
func (y *yamuxTransport) Close(ctx context.Context) (err error) {
if y.listCtxCancel != nil {
y.listCtxCancel()
}
for _, l := range y.listeners {
_ = l.Close()
}
return
}

View File

@ -0,0 +1,134 @@
package yamux
import (
"bytes"
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/anyproto/any-sync/net/transport"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
"github.com/anyproto/any-sync/testutil/accounttest"
"github.com/anyproto/any-sync/testutil/testnodeconf"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"testing"
)
var ctx = context.Background()
func TestYamuxTransport_Dial(t *testing.T) {
fxS := newFixture(t)
defer fxS.finish(t)
fxC := newFixture(t)
defer fxC.finish(t)
mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err)
require.Len(t, fxS.accepter.mcs, 1)
mcS := fxS.accepter.mcs[0]
var (
sData string
acceptErr error
copyErr error
done = make(chan struct{})
)
go func() {
defer close(done)
conn, serr := mcS.Accept()
if serr != nil {
acceptErr = serr
return
}
buf := bytes.NewBuffer(nil)
_, copyErr = io.Copy(buf, conn)
sData = buf.String()
return
}()
conn, err := mcC.Open(ctx)
require.NoError(t, err)
data := "some data"
_, err = conn.Write([]byte(data))
require.NoError(t, err)
require.NoError(t, conn.Close())
<-done
assert.NoError(t, acceptErr)
assert.Equal(t, data, sData)
assert.NoError(t, copyErr)
}
type fixture struct {
*yamuxTransport
a *app.App
ctrl *gomock.Controller
mockNodeConf *mock_nodeconf.MockService
acc *accounttest.AccountTestService
accepter *testAccepter
addr string
}
func newFixture(t *testing.T) *fixture {
fx := &fixture{
yamuxTransport: New().(*yamuxTransport),
ctrl: gomock.NewController(t),
acc: &accounttest.AccountTestService{},
accepter: &testAccepter{},
a: new(app.App),
}
fx.mockNodeConf = mock_nodeconf.NewMockService(fx.ctrl)
fx.mockNodeConf.EXPECT().Init(gomock.Any())
fx.mockNodeConf.EXPECT().Name().Return(nodeconf.CName).AnyTimes()
fx.mockNodeConf.EXPECT().Run(ctx)
fx.mockNodeConf.EXPECT().Close(ctx)
fx.mockNodeConf.EXPECT().NodeTypes(gomock.Any()).Return([]nodeconf.NodeType{nodeconf.NodeTypeTree}).AnyTimes()
fx.a.Register(fx.acc).Register(newTestConf()).Register(fx.mockNodeConf).Register(secureservice.New()).Register(fx.yamuxTransport).Register(fx.accepter)
require.NoError(t, fx.a.Start(ctx))
fx.addr = fx.listeners[0].Addr().String()
return fx
}
func (fx *fixture) finish(t *testing.T) {
require.NoError(t, fx.a.Close(ctx))
fx.ctrl.Finish()
}
func newTestConf() *testConf {
return &testConf{testnodeconf.GenNodeConfig(1)}
}
type testConf struct {
*testnodeconf.Config
}
func (c *testConf) GetYamux() Config {
return Config{
ListenAddrs: []string{"127.0.0.1:0"},
WriteTimeoutSec: 10,
DialTimeoutSec: 10,
MaxStreams: 1024,
}
}
type testAccepter struct {
err error
mcs []transport.MultiConn
}
func (t *testAccepter) Accept(mc transport.MultiConn) (err error) {
t.mcs = append(t.mcs, mc)
return t.err
}
func (t *testAccepter) Init(a *app.App) (err error) {
a.MustComponent(CName).(transport.Transport).SetAccepter(t)
return nil
}
func (t *testAccepter) Name() (name string) { return "testAccepter" }