Merge remote-tracking branch 'origin/yamux' into new-sync-protocol
This commit is contained in:
commit
248205cddd
@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.32
|
||||
// protoc-gen-go-drpc version: v0.0.33
|
||||
// source: commonfile/fileproto/protos/file.proto
|
||||
|
||||
package fileproto
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.32
|
||||
// protoc-gen-go-drpc version: v0.0.33
|
||||
// source: commonspace/spacesyncproto/protos/spacesync.proto
|
||||
|
||||
package spacesyncproto
|
||||
@ -103,6 +103,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream {
|
||||
return x.Stream
|
||||
}
|
||||
|
||||
func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
|
||||
}
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/anyproto/any-sync/net/rpc/rpcerr"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"storj.io/drpc"
|
||||
)
|
||||
|
||||
const CName = "common.coordinator.coordinatorclient"
|
||||
@ -39,42 +40,8 @@ type coordinatorClient struct {
|
||||
nodeConf nodeconf.Service
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
|
||||
cl, err := c.client(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
|
||||
SpaceId: spaceId,
|
||||
DeletionChangeId: deleteRaw.GetId(),
|
||||
DeletionChangePayload: deleteRaw.GetRawChange(),
|
||||
})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
return
|
||||
}
|
||||
status = resp.Payload
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
|
||||
cl, err := c.client(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
|
||||
SpaceId: spaceId,
|
||||
})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
return
|
||||
}
|
||||
status = resp.Payload
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) Init(a *app.App) (err error) {
|
||||
c.pool = a.MustComponent(pool.CName).(pool.Service).NewPool(CName)
|
||||
c.pool = a.MustComponent(pool.CName).(pool.Service)
|
||||
c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
|
||||
return
|
||||
}
|
||||
@ -83,8 +50,37 @@ func (c *coordinatorClient) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
|
||||
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
|
||||
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
|
||||
SpaceId: spaceId,
|
||||
DeletionChangeId: deleteRaw.GetId(),
|
||||
DeletionChangePayload: deleteRaw.GetRawChange(),
|
||||
})
|
||||
if err != nil {
|
||||
return rpcerr.Unwrap(err)
|
||||
}
|
||||
status = resp.Payload
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
|
||||
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
|
||||
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
|
||||
SpaceId: spaceId,
|
||||
})
|
||||
if err != nil {
|
||||
return rpcerr.Unwrap(err)
|
||||
}
|
||||
status = resp.Payload
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) {
|
||||
cl, err := c.client(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -100,54 +96,56 @@ func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayl
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{
|
||||
SpaceId: payload.SpaceId,
|
||||
Header: payload.SpaceHeader,
|
||||
OldIdentity: oldIdentity,
|
||||
NewIdentitySignature: newSignature,
|
||||
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
|
||||
resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{
|
||||
SpaceId: payload.SpaceId,
|
||||
Header: payload.SpaceHeader,
|
||||
OldIdentity: oldIdentity,
|
||||
NewIdentitySignature: newSignature,
|
||||
})
|
||||
if err != nil {
|
||||
return rpcerr.Unwrap(err)
|
||||
}
|
||||
receipt = resp.Receipt
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
return
|
||||
}
|
||||
return resp.Receipt, nil
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
|
||||
cl, err := c.client(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
|
||||
AccountIdentity: identity,
|
||||
SpaceId: spaceId,
|
||||
})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
return
|
||||
}
|
||||
return resp.Limit, nil
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
|
||||
cl, err := c.client(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
|
||||
CurrentId: currentId,
|
||||
})
|
||||
if err != nil {
|
||||
err = rpcerr.Unwrap(err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) client(ctx context.Context) (coordinatorproto.DRPCCoordinatorClient, error) {
|
||||
func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
|
||||
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
|
||||
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
|
||||
AccountIdentity: identity,
|
||||
SpaceId: spaceId,
|
||||
})
|
||||
if err != nil {
|
||||
return rpcerr.Unwrap(err)
|
||||
}
|
||||
limit = resp.Limit
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
|
||||
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
|
||||
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
|
||||
CurrentId: currentId,
|
||||
})
|
||||
if err != nil {
|
||||
return rpcerr.Unwrap(err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error {
|
||||
p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
return coordinatorproto.NewDRPCCoordinatorClient(p), nil
|
||||
return p.DoDrpc(ctx, func(conn drpc.Conn) error {
|
||||
return f(coordinatorproto.NewDRPCCoordinatorClient(conn))
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.32
|
||||
// protoc-gen-go-drpc version: v0.0.33
|
||||
// source: coordinator/coordinatorproto/protos/coordinator.proto
|
||||
|
||||
package coordinatorproto
|
||||
|
||||
3
go.mod
3
go.mod
@ -14,6 +14,7 @@ require (
|
||||
github.com/gogo/protobuf v1.3.2
|
||||
github.com/golang/mock v1.6.0
|
||||
github.com/google/uuid v1.3.0
|
||||
github.com/hashicorp/yamux v0.1.1
|
||||
github.com/huandu/skiplist v1.2.0
|
||||
github.com/ipfs/go-block-format v0.1.2
|
||||
github.com/ipfs/go-blockservice v0.5.2
|
||||
@ -33,6 +34,7 @@ require (
|
||||
github.com/tyler-smith/go-bip39 v1.1.0
|
||||
github.com/zeebo/blake3 v0.2.3
|
||||
github.com/zeebo/errs v1.3.0
|
||||
go.uber.org/atomic v1.11.0
|
||||
go.uber.org/zap v1.24.0
|
||||
golang.org/x/crypto v0.9.0
|
||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
|
||||
@ -103,7 +105,6 @@ require (
|
||||
github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect
|
||||
go.opentelemetry.io/otel v1.7.0 // indirect
|
||||
go.opentelemetry.io/otel/trace v1.7.0 // indirect
|
||||
go.uber.org/atomic v1.11.0 // indirect
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/image v0.6.0 // indirect
|
||||
golang.org/x/sync v0.2.0 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@ -79,6 +79,8 @@ github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfm
|
||||
github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48=
|
||||
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
|
||||
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
|
||||
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
|
||||
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
|
||||
github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c=
|
||||
github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U=
|
||||
github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw=
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
package timeoutconn
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -10,18 +10,18 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed("common.net.timeoutconn")
|
||||
var log = logger.NewNamed("common.net.connutil")
|
||||
|
||||
type Conn struct {
|
||||
type TimeoutConn struct {
|
||||
net.Conn
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *Conn {
|
||||
return &Conn{conn, timeout}
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
|
||||
return &TimeoutConn{conn, timeout}
|
||||
}
|
||||
|
||||
func (c *Conn) Write(p []byte) (n int, err error) {
|
||||
func (c *TimeoutConn) Write(p []byte) (n int, err error) {
|
||||
for {
|
||||
if c.timeout != 0 {
|
||||
if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil {
|
||||
30
net/connutil/usage.go
Normal file
30
net/connutil/usage.go
Normal file
@ -0,0 +1,30 @@
|
||||
package connutil
|
||||
|
||||
import (
|
||||
"go.uber.org/atomic"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewLastUsageConn(conn net.Conn) *LastUsageConn {
|
||||
return &LastUsageConn{Conn: conn}
|
||||
}
|
||||
|
||||
type LastUsageConn struct {
|
||||
net.Conn
|
||||
lastUsage atomic.Time
|
||||
}
|
||||
|
||||
func (c *LastUsageConn) Write(p []byte) (n int, err error) {
|
||||
c.lastUsage.Store(time.Now())
|
||||
return c.Conn.Write(p)
|
||||
}
|
||||
|
||||
func (c *LastUsageConn) Read(p []byte) (n int, err error) {
|
||||
c.lastUsage.Store(time.Now())
|
||||
return c.Conn.Read(p)
|
||||
}
|
||||
|
||||
func (c *LastUsageConn) LastUsage() time.Time {
|
||||
return c.lastUsage.Load()
|
||||
}
|
||||
@ -1,137 +0,0 @@
|
||||
package dialer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
net2 "github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/timeoutconn"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
"storj.io/drpc/drpcwire"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const CName = "common.net.dialer"
|
||||
|
||||
var (
|
||||
ErrAddrsNotFound = errors.New("addrs for peer not found")
|
||||
ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id")
|
||||
)
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
func New() Dialer {
|
||||
return &dialer{peerAddrs: map[string][]string{}}
|
||||
}
|
||||
|
||||
type Dialer interface {
|
||||
Dial(ctx context.Context, peerId string) (peer peer.Peer, err error)
|
||||
SetPeerAddrs(peerId string, addrs []string)
|
||||
app.Component
|
||||
}
|
||||
|
||||
type dialer struct {
|
||||
transport secureservice.SecureService
|
||||
config net2.Config
|
||||
nodeConf nodeconf.NodeConf
|
||||
peerAddrs map[string][]string
|
||||
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (d *dialer) Init(a *app.App) (err error) {
|
||||
d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
|
||||
d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
|
||||
d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet()
|
||||
return
|
||||
}
|
||||
|
||||
func (d *dialer) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (d *dialer) SetPeerAddrs(peerId string, addrs []string) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.peerAddrs[peerId] = addrs
|
||||
}
|
||||
|
||||
func (d *dialer) getPeerAddrs(peerId string) ([]string, error) {
|
||||
if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok {
|
||||
return addrs, nil
|
||||
}
|
||||
addrs, ok := d.peerAddrs[peerId]
|
||||
if !ok || len(addrs) == 0 {
|
||||
return nil, ErrAddrsNotFound
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) {
|
||||
var ctxCancel context.CancelFunc
|
||||
ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10)
|
||||
defer ctxCancel()
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
addrs, err := d.getPeerAddrs(peerId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
conn drpc.Conn
|
||||
sc sec.SecureConn
|
||||
)
|
||||
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
|
||||
for _, addr := range addrs {
|
||||
conn, sc, err = d.handshake(ctx, addr, peerId)
|
||||
if err != nil {
|
||||
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return peer.NewPeer(sc, conn), nil
|
||||
}
|
||||
|
||||
func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) {
|
||||
st := time.Now()
|
||||
// TODO: move dial timeout to config
|
||||
tcpConn, err := net.DialTimeout("tcp", addr, time.Second*15)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("dialTimeout error: %v; since start: %v", err, time.Since(st))
|
||||
}
|
||||
|
||||
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
|
||||
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
|
||||
if err != nil {
|
||||
if he, ok := err.(handshake.HandshakeError); ok {
|
||||
return nil, nil, he
|
||||
}
|
||||
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
|
||||
}
|
||||
if peerId != sc.RemotePeer().String() {
|
||||
return nil, nil, ErrPeerIdIsUnexpected
|
||||
}
|
||||
log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr))
|
||||
conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{
|
||||
Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)},
|
||||
}})
|
||||
return conn, sc, err
|
||||
}
|
||||
161
net/peer/peer.go
161
net/peer/peer.go
@ -2,82 +2,146 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/anyproto/any-sync/app/ocache"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed("common.net.peer")
|
||||
|
||||
func NewPeer(sc sec.SecureConn, conn drpc.Conn) Peer {
|
||||
return &peer{
|
||||
id: sc.RemotePeer().String(),
|
||||
lastUsage: time.Now().Unix(),
|
||||
sc: sc,
|
||||
Conn: conn,
|
||||
type connCtrl interface {
|
||||
ServeConn(ctx context.Context, conn net.Conn) (err error)
|
||||
}
|
||||
|
||||
func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
|
||||
ctx := mc.Context()
|
||||
pr := &peer{
|
||||
active: map[drpc.Conn]struct{}{},
|
||||
MultiConn: mc,
|
||||
ctrl: ctrl,
|
||||
}
|
||||
if pr.id, err = CtxPeerId(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
go pr.acceptLoop()
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
type Peer interface {
|
||||
Id() string
|
||||
LastUsage() time.Time
|
||||
UpdateLastUsage()
|
||||
Addr() string
|
||||
|
||||
AcquireDrpcConn(ctx context.Context) (drpc.Conn, error)
|
||||
ReleaseDrpcConn(conn drpc.Conn)
|
||||
DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error
|
||||
|
||||
IsClosed() bool
|
||||
|
||||
TryClose(objectTTL time.Duration) (res bool, err error)
|
||||
drpc.Conn
|
||||
|
||||
ocache.Object
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
id string
|
||||
ttl time.Duration
|
||||
lastUsage int64
|
||||
sc sec.SecureConn
|
||||
drpc.Conn
|
||||
id string
|
||||
|
||||
ctrl connCtrl
|
||||
|
||||
// drpc conn pool
|
||||
inactive []drpc.Conn
|
||||
active map[drpc.Conn]struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
transport.MultiConn
|
||||
}
|
||||
|
||||
func (p *peer) Id() string {
|
||||
return p.id
|
||||
}
|
||||
|
||||
func (p *peer) LastUsage() time.Time {
|
||||
select {
|
||||
case <-p.Closed():
|
||||
return time.Unix(0, 0)
|
||||
default:
|
||||
func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if len(p.inactive) == 0 {
|
||||
conn, err := p.Open(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dconn := drpcconn.New(conn)
|
||||
p.inactive = append(p.inactive, dconn)
|
||||
}
|
||||
return time.Unix(atomic.LoadInt64(&p.lastUsage), 0)
|
||||
idx := len(p.inactive) - 1
|
||||
res := p.inactive[idx]
|
||||
p.inactive = p.inactive[:idx]
|
||||
p.active[res] = struct{}{}
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *peer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
|
||||
defer p.UpdateLastUsage()
|
||||
return p.Conn.Invoke(ctx, rpc, enc, in, out)
|
||||
}
|
||||
|
||||
func (p *peer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
|
||||
defer p.UpdateLastUsage()
|
||||
return p.Conn.NewStream(ctx, rpc, enc)
|
||||
}
|
||||
|
||||
func (p *peer) Read(b []byte) (n int, err error) {
|
||||
if n, err = p.sc.Read(b); err == nil {
|
||||
p.UpdateLastUsage()
|
||||
func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if _, ok := p.active[conn]; ok {
|
||||
delete(p.active, conn)
|
||||
}
|
||||
p.inactive = append(p.inactive, conn)
|
||||
return
|
||||
}
|
||||
|
||||
func (p *peer) Write(b []byte) (n int, err error) {
|
||||
if n, err = p.sc.Write(b); err == nil {
|
||||
p.UpdateLastUsage()
|
||||
func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
|
||||
conn, err := p.AcquireDrpcConn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return
|
||||
defer p.ReleaseDrpcConn(conn)
|
||||
return do(conn)
|
||||
}
|
||||
|
||||
func (p *peer) UpdateLastUsage() {
|
||||
atomic.StoreInt64(&p.lastUsage, time.Now().Unix())
|
||||
func (p *peer) acceptLoop() {
|
||||
var exitErr error
|
||||
defer func() {
|
||||
if exitErr != transport.ErrConnClosed {
|
||||
log.Warn("accept error: close connection", zap.Error(exitErr))
|
||||
_ = p.MultiConn.Close()
|
||||
}
|
||||
}()
|
||||
for {
|
||||
conn, err := p.Accept()
|
||||
if err != nil {
|
||||
exitErr = err
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
serveErr := p.serve(conn)
|
||||
if serveErr != io.EOF && serveErr != transport.ErrConnClosed {
|
||||
log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var defaultProtoChecker = handshake.ProtoChecker{
|
||||
AllowedProtoTypes: []handshakeproto.ProtoType{
|
||||
handshakeproto.ProtoType_DRPC,
|
||||
},
|
||||
}
|
||||
|
||||
func (p *peer) serve(conn net.Conn) (err error) {
|
||||
hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20)
|
||||
if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
return p.ctrl.ServeConn(p.Context(), conn)
|
||||
}
|
||||
|
||||
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
@ -87,14 +151,7 @@ func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
return true, p.Close()
|
||||
}
|
||||
|
||||
func (p *peer) Addr() string {
|
||||
if p.sc != nil {
|
||||
return p.sc.RemoteAddr().String()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (p *peer) Close() (err error) {
|
||||
log.Debug("peer close", zap.String("peerId", p.id))
|
||||
return p.Conn.Close()
|
||||
return p.MultiConn.Close()
|
||||
}
|
||||
|
||||
137
net/peer/peer_test.go
Normal file
137
net/peer/peer_test.go
Normal file
@ -0,0 +1,137 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/net/transport/mock_transport"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func TestPeer_AcquireDrpcConn(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
in, out := net.Pipe()
|
||||
defer out.Close()
|
||||
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
|
||||
dc, err := fx.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dc)
|
||||
defer dc.Close()
|
||||
|
||||
assert.Len(t, fx.active, 1)
|
||||
assert.Len(t, fx.inactive, 0)
|
||||
|
||||
fx.ReleaseDrpcConn(dc)
|
||||
|
||||
assert.Len(t, fx.active, 0)
|
||||
assert.Len(t, fx.inactive, 1)
|
||||
|
||||
dc, err = fx.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, dc)
|
||||
assert.Len(t, fx.active, 1)
|
||||
assert.Len(t, fx.inactive, 0)
|
||||
}
|
||||
|
||||
func TestPeerAccept(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
in, out := net.Pipe()
|
||||
defer out.Close()
|
||||
|
||||
var outHandshakeCh = make(chan error)
|
||||
go func() {
|
||||
outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC)
|
||||
}()
|
||||
fx.acceptCh <- acceptedConn{conn: in}
|
||||
cn := <-fx.testCtrl.serveConn
|
||||
assert.Equal(t, in, cn)
|
||||
assert.NoError(t, <-outHandshakeCh)
|
||||
}
|
||||
|
||||
func TestPeer_TryClose(t *testing.T) {
|
||||
t.Run("ttl", func(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
lu := time.Now()
|
||||
fx.mc.EXPECT().LastUsage().Return(lu)
|
||||
res, err := fx.TryClose(time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res)
|
||||
})
|
||||
t.Run("close", func(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
lu := time.Now().Add(-time.Second * 2)
|
||||
fx.mc.EXPECT().LastUsage().Return(lu)
|
||||
res, err := fx.TryClose(time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res)
|
||||
})
|
||||
}
|
||||
|
||||
type acceptedConn struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T, peerId string) *fixture {
|
||||
fx := &fixture{
|
||||
ctrl: gomock.NewController(t),
|
||||
acceptCh: make(chan acceptedConn),
|
||||
testCtrl: newTesCtrl(),
|
||||
}
|
||||
fx.mc = mock_transport.NewMockMultiConn(fx.ctrl)
|
||||
ctx := CtxWithPeerId(context.Background(), peerId)
|
||||
fx.mc.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
|
||||
ac := <-fx.acceptCh
|
||||
return ac.conn, ac.err
|
||||
}).AnyTimes()
|
||||
fx.mc.EXPECT().Close().AnyTimes()
|
||||
p, err := NewPeer(fx.mc, fx.testCtrl)
|
||||
require.NoError(t, err)
|
||||
fx.peer = p.(*peer)
|
||||
return fx
|
||||
}
|
||||
|
||||
type fixture struct {
|
||||
*peer
|
||||
ctrl *gomock.Controller
|
||||
mc *mock_transport.MockMultiConn
|
||||
acceptCh chan acceptedConn
|
||||
testCtrl *testCtrl
|
||||
}
|
||||
|
||||
func (fx *fixture) finish() {
|
||||
fx.testCtrl.close()
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func newTesCtrl() *testCtrl {
|
||||
return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)}
|
||||
}
|
||||
|
||||
type testCtrl struct {
|
||||
serveConn chan net.Conn
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
t.serveConn <- conn
|
||||
<-t.closeCh
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (t *testCtrl) close() {
|
||||
close(t.closeCh)
|
||||
}
|
||||
110
net/peerservice/peerservice.go
Normal file
110
net/peerservice/peerservice.go
Normal file
@ -0,0 +1,110 @@
|
||||
package peerservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/pool"
|
||||
"github.com/anyproto/any-sync/net/rpc/server"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/anyproto/any-sync/net/transport/yamux"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const CName = "net.peerservice"
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
var (
|
||||
ErrAddrsNotFound = errors.New("addrs for peer not found")
|
||||
)
|
||||
|
||||
func New() PeerService {
|
||||
return new(peerService)
|
||||
}
|
||||
|
||||
type PeerService interface {
|
||||
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
|
||||
SetPeerAddrs(peerId string, addrs []string)
|
||||
transport.Accepter
|
||||
app.Component
|
||||
}
|
||||
|
||||
type peerService struct {
|
||||
yamux transport.Transport
|
||||
nodeConf nodeconf.NodeConf
|
||||
peerAddrs map[string][]string
|
||||
pool pool.Pool
|
||||
server server.DRPCServer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func (p *peerService) Init(a *app.App) (err error) {
|
||||
p.yamux = a.MustComponent(yamux.CName).(transport.Transport)
|
||||
p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
|
||||
p.pool = a.MustComponent(pool.CName).(pool.Pool)
|
||||
p.server = a.MustComponent(server.CName).(server.DRPCServer)
|
||||
p.peerAddrs = map[string][]string{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *peerService) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
addrs, err := p.getPeerAddrs(peerId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var mc transport.MultiConn
|
||||
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
|
||||
for _, addr := range addrs {
|
||||
mc, err = p.yamux.Dial(ctx, addr)
|
||||
if err != nil {
|
||||
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return peer.NewPeer(mc, p.server)
|
||||
}
|
||||
|
||||
func (p *peerService) Accept(mc transport.MultiConn) (err error) {
|
||||
pr, err := peer.NewPeer(mc, p.server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = p.pool.AddPeer(context.Background(), pr); err != nil {
|
||||
_ = pr.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *peerService) SetPeerAddrs(peerId string, addrs []string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.peerAddrs[peerId] = addrs
|
||||
}
|
||||
|
||||
func (p *peerService) getPeerAddrs(peerId string) ([]string, error) {
|
||||
if addrs, ok := p.nodeConf.PeerAddresses(peerId); ok {
|
||||
return addrs, nil
|
||||
}
|
||||
addrs, ok := p.peerAddrs[peerId]
|
||||
if !ok || len(addrs) == 0 {
|
||||
return nil, ErrAddrsNotFound
|
||||
}
|
||||
return addrs, nil
|
||||
}
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app/ocache"
|
||||
"github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/dialer"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"go.uber.org/zap"
|
||||
@ -13,59 +12,61 @@ import (
|
||||
|
||||
// Pool creates and caches outgoing connection
|
||||
type Pool interface {
|
||||
// Get lookups to peer in existing connections or creates and cache new one
|
||||
// Get lookups to peer in existing connections or creates and outgoing new one
|
||||
Get(ctx context.Context, id string) (peer.Peer, error)
|
||||
// Dial creates new connection to peer and not use cache
|
||||
Dial(ctx context.Context, id string) (peer.Peer, error)
|
||||
// GetOneOf searches at least one existing connection in cache or creates a new one from a randomly selected id from given list
|
||||
// GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list
|
||||
GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error)
|
||||
|
||||
DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error)
|
||||
// AddPeer adds incoming peer to the pool
|
||||
AddPeer(ctx context.Context, p peer.Peer) (err error)
|
||||
}
|
||||
|
||||
type pool struct {
|
||||
cache ocache.OCache
|
||||
dialer dialer.Dialer
|
||||
outgoing ocache.OCache
|
||||
incoming ocache.OCache
|
||||
}
|
||||
|
||||
func (p *pool) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (p *pool) Run(ctx context.Context) (err error) {
|
||||
return nil
|
||||
func (p *pool) Get(ctx context.Context, id string) (pr peer.Peer, err error) {
|
||||
// if we have incoming connection - try to reuse it
|
||||
if pr, err = p.get(ctx, p.incoming, id); err != nil {
|
||||
// or try to get or create outgoing
|
||||
return p.get(ctx, p.outgoing, id)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (p *pool) Get(ctx context.Context, id string) (peer.Peer, error) {
|
||||
v, err := p.cache.Get(ctx, id)
|
||||
func (p *pool) get(ctx context.Context, source ocache.OCache, id string) (peer.Peer, error) {
|
||||
v, err := source.Get(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pr := v.(peer.Peer)
|
||||
select {
|
||||
case <-pr.Closed():
|
||||
default:
|
||||
if !pr.IsClosed() {
|
||||
return pr, nil
|
||||
}
|
||||
_, _ = p.cache.Remove(ctx, id)
|
||||
_, _ = source.Remove(ctx, id)
|
||||
return p.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (p *pool) Dial(ctx context.Context, id string) (peer.Peer, error) {
|
||||
return p.dialer.Dial(ctx, id)
|
||||
}
|
||||
|
||||
func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
|
||||
// finding existing connection
|
||||
for _, peerId := range peerIds {
|
||||
if v, err := p.cache.Pick(ctx, peerId); err == nil {
|
||||
if v, err := p.incoming.Pick(ctx, peerId); err == nil {
|
||||
pr := v.(peer.Peer)
|
||||
select {
|
||||
case <-pr.Closed():
|
||||
default:
|
||||
if !pr.IsClosed() {
|
||||
return pr, nil
|
||||
}
|
||||
_, _ = p.cache.Remove(ctx, peerId)
|
||||
_, _ = p.incoming.Remove(ctx, peerId)
|
||||
}
|
||||
if v, err := p.outgoing.Pick(ctx, peerId); err == nil {
|
||||
pr := v.(peer.Peer)
|
||||
if !pr.IsClosed() {
|
||||
return pr, nil
|
||||
}
|
||||
_, _ = p.outgoing.Remove(ctx, peerId)
|
||||
}
|
||||
}
|
||||
// shuffle ids for better consistency
|
||||
@ -75,8 +76,8 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
|
||||
// connecting
|
||||
var lastErr error
|
||||
for _, peerId := range peerIds {
|
||||
if v, err := p.cache.Get(ctx, peerId); err == nil {
|
||||
return v.(peer.Peer), nil
|
||||
if v, err := p.Get(ctx, peerId); err == nil {
|
||||
return v, nil
|
||||
} else {
|
||||
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err))
|
||||
lastErr = err
|
||||
@ -88,27 +89,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
|
||||
// shuffle ids for better consistency
|
||||
rand.Shuffle(len(peerIds), func(i, j int) {
|
||||
peerIds[i], peerIds[j] = peerIds[j], peerIds[i]
|
||||
})
|
||||
// connecting
|
||||
var lastErr error
|
||||
for _, peerId := range peerIds {
|
||||
if v, err := p.dialer.Dial(ctx, peerId); err == nil {
|
||||
return v.(peer.Peer), nil
|
||||
func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) {
|
||||
if err = p.incoming.Add(pr.Id(), pr); err != nil {
|
||||
if err == ocache.ErrExists {
|
||||
// in case when an incoming connection with a peer already exists, we close and remove an existing connection
|
||||
if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil {
|
||||
_ = v.Close()
|
||||
_, _ = p.incoming.Remove(ctx, pr.Id())
|
||||
return p.incoming.Add(pr.Id(), pr)
|
||||
}
|
||||
} else {
|
||||
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err))
|
||||
lastErr = err
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, ok := lastErr.(handshake.HandshakeError); !ok {
|
||||
lastErr = net.ErrUnableToConnect
|
||||
}
|
||||
return nil, lastErr
|
||||
}
|
||||
|
||||
func (p *pool) Close(ctx context.Context) (err error) {
|
||||
return p.cache.Close()
|
||||
return
|
||||
}
|
||||
|
||||
@ -6,11 +6,11 @@ import (
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/dialer"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
net2 "net"
|
||||
"storj.io/drpc"
|
||||
"testing"
|
||||
"time"
|
||||
@ -133,6 +133,27 @@ func TestPool_GetOneOf(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestPool_AddPeer(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish()
|
||||
require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1")))
|
||||
})
|
||||
t.Run("two peers", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish()
|
||||
p1, p2 := newTestPeer("p1"), newTestPeer("p1")
|
||||
require.NoError(t, fx.AddPeer(ctx, p1))
|
||||
require.NoError(t, fx.AddPeer(ctx, p2))
|
||||
select {
|
||||
case <-p1.closed:
|
||||
default:
|
||||
assert.Truef(t, false, "peer not closed")
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) *fixture {
|
||||
fx := &fixture{
|
||||
Service: New(),
|
||||
@ -158,7 +179,7 @@ type fixture struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
var _ dialer.Dialer = (*dialerMock)(nil)
|
||||
var _ dialer = (*dialerMock)(nil)
|
||||
|
||||
type dialerMock struct {
|
||||
dial func(ctx context.Context, peerId string) (peer peer.Peer, err error)
|
||||
@ -181,7 +202,7 @@ func (d *dialerMock) Init(a *app.App) (err error) {
|
||||
}
|
||||
|
||||
func (d *dialerMock) Name() (name string) {
|
||||
return dialer.CName
|
||||
return "net.peerservice"
|
||||
}
|
||||
|
||||
func newTestPeer(id string) *testPeer {
|
||||
@ -196,6 +217,31 @@ type testPeer struct {
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func (t *testPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
|
||||
return fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (t *testPeer) ReleaseDrpcConn(conn drpc.Conn) {}
|
||||
|
||||
func (t *testPeer) Context() context.Context {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *testPeer) Accept() (conn net2.Conn, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *testPeer) Open(ctx context.Context) (conn net2.Conn, err error) {
|
||||
//TODO implement me
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (t *testPeer) Addr() string {
|
||||
return ""
|
||||
}
|
||||
@ -204,12 +250,6 @@ func (t *testPeer) Id() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t *testPeer) LastUsage() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (t *testPeer) UpdateLastUsage() {}
|
||||
|
||||
func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
return true, t.Close()
|
||||
}
|
||||
@ -224,14 +264,11 @@ func (t *testPeer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testPeer) Closed() <-chan struct{} {
|
||||
return t.closed
|
||||
}
|
||||
|
||||
func (t *testPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
|
||||
return fmt.Errorf("call Invoke on test peer")
|
||||
}
|
||||
|
||||
func (t *testPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
|
||||
return nil, fmt.Errorf("call NewStream on test peer")
|
||||
func (t *testPeer) IsClosed() bool {
|
||||
select {
|
||||
case <-t.closed:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@ -6,8 +6,9 @@ import (
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/app/ocache"
|
||||
"github.com/anyproto/any-sync/metric"
|
||||
"github.com/anyproto/any-sync/net/dialer"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"go.uber.org/zap"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -23,46 +24,54 @@ func New() Service {
|
||||
|
||||
type Service interface {
|
||||
Pool
|
||||
NewPool(name string) Pool
|
||||
app.ComponentRunnable
|
||||
}
|
||||
|
||||
type dialer interface {
|
||||
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
|
||||
}
|
||||
|
||||
type poolService struct {
|
||||
// default pool
|
||||
*pool
|
||||
dialer dialer.Dialer
|
||||
dialer dialer
|
||||
metricReg *prometheus.Registry
|
||||
}
|
||||
|
||||
func (p *poolService) Init(a *app.App) (err error) {
|
||||
p.dialer = a.MustComponent(dialer.CName).(dialer.Dialer)
|
||||
p.pool = &pool{dialer: p.dialer}
|
||||
p.dialer = a.MustComponent("net.peerservice").(dialer)
|
||||
p.pool = &pool{}
|
||||
if m := a.Component(metric.CName); m != nil {
|
||||
p.metricReg = m.(metric.Metric).Registry()
|
||||
}
|
||||
p.pool.cache = ocache.New(
|
||||
p.pool.outgoing = ocache.New(
|
||||
func(ctx context.Context, id string) (value ocache.Object, err error) {
|
||||
return p.dialer.Dial(ctx, id)
|
||||
},
|
||||
ocache.WithLogger(log.Sugar()),
|
||||
ocache.WithGCPeriod(time.Minute),
|
||||
ocache.WithTTL(time.Minute*5),
|
||||
ocache.WithPrometheus(p.metricReg, "netpool", "default"),
|
||||
ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"),
|
||||
)
|
||||
p.pool.incoming = ocache.New(
|
||||
func(ctx context.Context, id string) (value ocache.Object, err error) {
|
||||
return nil, ocache.ErrNotExists
|
||||
},
|
||||
ocache.WithLogger(log.Sugar()),
|
||||
ocache.WithGCPeriod(time.Minute),
|
||||
ocache.WithTTL(time.Minute*5),
|
||||
ocache.WithPrometheus(p.metricReg, "netpool", "incoming"),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *poolService) NewPool(name string) Pool {
|
||||
return &pool{
|
||||
dialer: p.dialer,
|
||||
cache: ocache.New(
|
||||
func(ctx context.Context, id string) (value ocache.Object, err error) {
|
||||
return p.dialer.Dial(ctx, id)
|
||||
},
|
||||
ocache.WithLogger(log.Sugar()),
|
||||
ocache.WithGCPeriod(time.Minute),
|
||||
ocache.WithTTL(time.Minute*5),
|
||||
ocache.WithPrometheus(p.metricReg, "netpool", name),
|
||||
),
|
||||
}
|
||||
func (p *pool) Run(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pool) Close(ctx context.Context) (err error) {
|
||||
if e := p.incoming.Close(); e != nil {
|
||||
log.Warn("close incoming cache error", zap.Error(e))
|
||||
}
|
||||
return p.outgoing.Close()
|
||||
}
|
||||
|
||||
@ -1,134 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/zeebo/errs"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"storj.io/drpc/drpcwire"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaseDrpcServer struct {
|
||||
drpcServer *drpcserver.Server
|
||||
transport secureservice.SecureService
|
||||
listeners []net.Listener
|
||||
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||
cancel func()
|
||||
*drpcmux.Mux
|
||||
}
|
||||
|
||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||
|
||||
type Params struct {
|
||||
BufferSizeMb int
|
||||
ListenAddrs []string
|
||||
Wrapper DRPCHandlerWrapper
|
||||
TimeoutMillis int
|
||||
Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||
}
|
||||
|
||||
func NewBaseDrpcServer() *BaseDrpcServer {
|
||||
return &BaseDrpcServer{Mux: drpcmux.New()}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
|
||||
s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{
|
||||
Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)},
|
||||
}})
|
||||
s.handshake = params.Handshake
|
||||
ctx, s.cancel = context.WithCancel(ctx)
|
||||
for _, addr := range params.ListenAddrs {
|
||||
list, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listeners = append(s.listeners, list)
|
||||
go s.serve(ctx, list)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
|
||||
l := log.With(zap.String("localAddr", lis.Addr().String()))
|
||||
l.Info("drpc listener started")
|
||||
defer func() {
|
||||
l.Debug("drpc listener stopped")
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
if isTemporary(err) {
|
||||
l.Debug("listener temporary accept error", zap.Error(err))
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
l.Error("listener accept error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
go s.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
|
||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||
var (
|
||||
ctx = context.Background()
|
||||
err error
|
||||
)
|
||||
if s.handshake != nil {
|
||||
ctx, conn, err = s.handshake(conn)
|
||||
if err != nil {
|
||||
l.Info("handshake error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if sc, ok := conn.(sec.SecureConn); ok {
|
||||
ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String())
|
||||
}
|
||||
}
|
||||
ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String())
|
||||
l.Debug("connection opened")
|
||||
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
|
||||
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
|
||||
l.Debug("connection closed")
|
||||
} else {
|
||||
l.Warn("serve connection error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) {
|
||||
for _, list := range s.listeners {
|
||||
addrs = append(addrs, list.Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) Close(ctx context.Context) (err error) {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
for _, l := range s.listeners {
|
||||
if e := l.Close(); e != nil {
|
||||
log.Warn("close listener error", zap.Error(e))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -6,11 +6,13 @@ import (
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/metric"
|
||||
anyNet "github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"time"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"storj.io/drpc/drpcwire"
|
||||
)
|
||||
|
||||
const CName = "common.net.drpcserver"
|
||||
@ -18,49 +20,46 @@ const CName = "common.net.drpcserver"
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
func New() DRPCServer {
|
||||
return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()}
|
||||
return &drpcServer{}
|
||||
}
|
||||
|
||||
type DRPCServer interface {
|
||||
app.ComponentRunnable
|
||||
ServeConn(ctx context.Context, conn net.Conn) (err error)
|
||||
app.Component
|
||||
drpc.Mux
|
||||
}
|
||||
|
||||
type drpcServer struct {
|
||||
config anyNet.Config
|
||||
metric metric.Metric
|
||||
transport secureservice.SecureService
|
||||
*BaseDrpcServer
|
||||
drpcServer *drpcserver.Server
|
||||
*drpcmux.Mux
|
||||
config anyNet.Config
|
||||
metric metric.Metric
|
||||
}
|
||||
|
||||
func (s *drpcServer) Init(a *app.App) (err error) {
|
||||
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
|
||||
s.metric = a.MustComponent(metric.CName).(metric.Metric)
|
||||
s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
|
||||
return nil
|
||||
}
|
||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||
|
||||
func (s *drpcServer) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (s *drpcServer) Run(ctx context.Context) (err error) {
|
||||
params := Params{
|
||||
BufferSizeMb: s.config.Stream.MaxMsgSizeMb,
|
||||
TimeoutMillis: s.config.Stream.TimeoutMilliseconds,
|
||||
ListenAddrs: s.config.Server.ListenAddrs,
|
||||
Wrapper: func(handler drpc.Handler) drpc.Handler {
|
||||
return s.metric.WrapDRPCHandler(handler)
|
||||
},
|
||||
Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
return s.transport.SecureInbound(ctx, conn)
|
||||
},
|
||||
func (s *drpcServer) Init(a *app.App) (err error) {
|
||||
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
|
||||
s.metric, _ = a.Component(metric.CName).(metric.Metric)
|
||||
s.Mux = drpcmux.New()
|
||||
|
||||
var handler drpc.Handler
|
||||
handler = s
|
||||
if s.metric != nil {
|
||||
handler = s.metric.WrapDRPCHandler(s)
|
||||
}
|
||||
return s.BaseDrpcServer.Run(ctx, params)
|
||||
s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{
|
||||
Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)},
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *drpcServer) Close(ctx context.Context) (err error) {
|
||||
return s.BaseDrpcServer.Close(ctx)
|
||||
func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||
l.Debug("drpc serve peer")
|
||||
return s.drpcServer.ServeOne(ctx, conn)
|
||||
}
|
||||
|
||||
125
net/secureservice/handshake/credential.go
Normal file
125
net/secureservice/handshake/credential.go
Normal file
@ -0,0 +1,125 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
)
|
||||
|
||||
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = outgoingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
localCred := cc.MakeCredentials(sc)
|
||||
if err = h.writeCredentials(localCred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
msg, err := h.readMsg(msgTypeAck, msgTypeCred)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if msg.ack != nil {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error == handshakeproto.Error_Null {
|
||||
return identity, nil
|
||||
} else {
|
||||
_ = h.conn.Close()
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
}
|
||||
|
||||
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = incomingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error != handshakeproto.Error_Null {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -38,15 +38,14 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
require.NoError(t, err)
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// receive ack
|
||||
msg, err = h.readMsg()
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
// send ack
|
||||
@ -76,7 +75,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
res := <-handshakeResCh
|
||||
@ -92,7 +91,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
|
||||
res := <-handshakeResCh
|
||||
@ -108,10 +107,10 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
|
||||
res := <-handshakeResCh
|
||||
@ -127,7 +126,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials and close conn
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
@ -145,12 +144,12 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// read ack and close conn
|
||||
_, err = h.readMsg()
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
res := <-handshakeResCh
|
||||
@ -166,18 +165,17 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// read ack
|
||||
_, err = h.readMsg()
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
// write cred instead ack
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
msg, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.Error(t, err)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
@ -191,7 +189,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
@ -199,7 +197,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// receive ack
|
||||
msg, err = h.readMsg()
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
// send ack
|
||||
@ -219,7 +217,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
ctxCancel()
|
||||
res := <-handshakeResCh
|
||||
@ -244,14 +242,14 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
|
||||
// write ack
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||
// wait ack
|
||||
msg, err = h.readMsg()
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
res := <-handshakeResCh
|
||||
@ -310,7 +308,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// except ack with error
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.cred)
|
||||
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
|
||||
@ -330,7 +328,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// except ack with error
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.cred)
|
||||
require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error)
|
||||
@ -350,13 +348,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// read cred
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write cred instead ack
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// expect ack with error
|
||||
msg, err := h.readMsg()
|
||||
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||
// expect EOF
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.Error(t, err)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
@ -372,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// read cred and close conn
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
|
||||
@ -391,7 +389,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
|
||||
@ -413,7 +411,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
|
||||
@ -435,7 +433,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg()
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
|
||||
@ -458,7 +456,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
_, err := h.readMsg()
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
ctxCancel()
|
||||
res := <-handshakeResCh
|
||||
@ -482,7 +480,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
|
||||
_, err := c2.Write([]byte("some unexpected bytes"))
|
||||
require.Error(t, err)
|
||||
res := <-handshakeResCh
|
||||
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
|
||||
assert.Error(t, res.err)
|
||||
}
|
||||
|
||||
func TestEndToEnd(t *testing.T) {
|
||||
@ -1,21 +1,30 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"golang.org/x/exp/slices"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const headerSize = 5 // 1 byte for type + 4 byte for uint32 size
|
||||
|
||||
const (
|
||||
msgTypeCred = byte(1)
|
||||
msgTypeAck = byte(2)
|
||||
msgTypeCred = byte(1)
|
||||
msgTypeAck = byte(2)
|
||||
msgTypeProto = byte(3)
|
||||
|
||||
sizeLimit = 200 * 1024 // 200 Kb
|
||||
)
|
||||
|
||||
var (
|
||||
credMsgTypes = []byte{msgTypeCred, msgTypeAck}
|
||||
protoMsgTypes = []byte{msgTypeProto, msgTypeAck}
|
||||
protoMsgTypesAck = []byte{msgTypeAck}
|
||||
)
|
||||
|
||||
type HandshakeError struct {
|
||||
@ -38,17 +47,20 @@ var (
|
||||
ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
|
||||
ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected}
|
||||
|
||||
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
|
||||
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
|
||||
ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto}
|
||||
ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")}
|
||||
|
||||
ErrGotNotAHandshakeMessage = errors.New("go not a handshake message")
|
||||
ErrGotUnexpectedMessage = errors.New("go not a handshake message")
|
||||
)
|
||||
|
||||
var handshakePool = &sync.Pool{New: func() any {
|
||||
return &handshake{
|
||||
remoteCred: &handshakeproto.Credentials{},
|
||||
remoteAck: &handshakeproto.Ack{},
|
||||
localAck: &handshakeproto.Ack{},
|
||||
buf: make([]byte, 0, 1024),
|
||||
remoteCred: &handshakeproto.Credentials{},
|
||||
remoteAck: &handshakeproto.Ack{},
|
||||
localAck: &handshakeproto.Ack{},
|
||||
remoteProto: &handshakeproto.Proto{},
|
||||
buf: make([]byte, 0, 1024),
|
||||
}
|
||||
}}
|
||||
|
||||
@ -57,147 +69,17 @@ type CredentialChecker interface {
|
||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
}
|
||||
|
||||
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = outgoingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
localCred := cc.MakeCredentials(sc)
|
||||
if err = h.writeCredentials(localCred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
msg, err := h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if msg.ack != nil {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack == nil {
|
||||
err = ErrUnexpectedPayload
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error == handshakeproto.Error_Null {
|
||||
return identity, nil
|
||||
} else {
|
||||
_ = h.conn.Close()
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
}
|
||||
|
||||
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = incomingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
|
||||
msg, err := h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if msg.ack != nil {
|
||||
return nil, ErrUnexpectedPayload
|
||||
}
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack == nil {
|
||||
err = ErrUnexpectedPayload
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error != handshakeproto.Error_Null {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func newHandshake() *handshake {
|
||||
return handshakePool.Get().(*handshake)
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn sec.SecureConn
|
||||
remoteCred *handshakeproto.Credentials
|
||||
remoteAck *handshakeproto.Ack
|
||||
localAck *handshakeproto.Ack
|
||||
buf []byte
|
||||
conn net.Conn
|
||||
remoteCred *handshakeproto.Credentials
|
||||
remoteProto *handshakeproto.Proto
|
||||
remoteAck *handshakeproto.Ack
|
||||
localAck *handshakeproto.Ack
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
|
||||
@ -209,8 +91,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro
|
||||
return h.writeData(msgTypeCred, n)
|
||||
}
|
||||
|
||||
func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) {
|
||||
h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize]
|
||||
n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return h.writeData(msgTypeProto, n)
|
||||
}
|
||||
|
||||
func (h *handshake) tryWriteErrAndClose(err error) {
|
||||
if err == ErrGotNotAHandshakeMessage {
|
||||
if err == ErrUnexpectedPayload {
|
||||
// if we got unexpected message - just close the connection
|
||||
_ = h.conn.Close()
|
||||
return
|
||||
@ -243,21 +134,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) {
|
||||
}
|
||||
|
||||
type message struct {
|
||||
cred *handshakeproto.Credentials
|
||||
ack *handshakeproto.Ack
|
||||
cred *handshakeproto.Credentials
|
||||
proto *handshakeproto.Proto
|
||||
ack *handshakeproto.Ack
|
||||
}
|
||||
|
||||
func (h *handshake) readMsg() (msg message, err error) {
|
||||
func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) {
|
||||
h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
|
||||
if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
|
||||
return
|
||||
}
|
||||
tp := h.buf[0]
|
||||
if tp != msgTypeCred && tp != msgTypeAck {
|
||||
err = ErrGotNotAHandshakeMessage
|
||||
if !slices.Contains(allowedTypes, tp) {
|
||||
err = ErrUnexpectedPayload
|
||||
return
|
||||
}
|
||||
size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
|
||||
if size > sizeLimit {
|
||||
err = ErrGotUnexpectedMessage
|
||||
return
|
||||
}
|
||||
h.buf = slices.Grow(h.buf, int(size))[:size]
|
||||
if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
|
||||
return
|
||||
@ -273,6 +169,11 @@ func (h *handshake) readMsg() (msg message, err error) {
|
||||
return
|
||||
}
|
||||
msg.ack = h.remoteAck
|
||||
case msgTypeProto:
|
||||
if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil {
|
||||
return
|
||||
}
|
||||
msg.proto = h.remoteProto
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -284,5 +185,6 @@ func (h *handshake) release() {
|
||||
h.remoteAck.Error = 0
|
||||
h.remoteCred.Type = 0
|
||||
h.remoteCred.Payload = h.remoteCred.Payload[:0]
|
||||
h.remoteProto.Proto = 0
|
||||
handshakePool.Put(h)
|
||||
}
|
||||
|
||||
@ -59,6 +59,7 @@ const (
|
||||
Error_SkipVerifyNotAllowed Error = 4
|
||||
Error_DeadlineExceeded Error = 5
|
||||
Error_IncompatibleVersion Error = 6
|
||||
Error_IncompatibleProto Error = 7
|
||||
)
|
||||
|
||||
var Error_name = map[int32]string{
|
||||
@ -69,6 +70,7 @@ var Error_name = map[int32]string{
|
||||
4: "SkipVerifyNotAllowed",
|
||||
5: "DeadlineExceeded",
|
||||
6: "IncompatibleVersion",
|
||||
7: "IncompatibleProto",
|
||||
}
|
||||
|
||||
var Error_value = map[string]int32{
|
||||
@ -79,6 +81,7 @@ var Error_value = map[string]int32{
|
||||
"SkipVerifyNotAllowed": 4,
|
||||
"DeadlineExceeded": 5,
|
||||
"IncompatibleVersion": 6,
|
||||
"IncompatibleProto": 7,
|
||||
}
|
||||
|
||||
func (x Error) String() string {
|
||||
@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_60283fc75f020893, []int{1}
|
||||
}
|
||||
|
||||
type ProtoType int32
|
||||
|
||||
const (
|
||||
ProtoType_DRPC ProtoType = 0
|
||||
)
|
||||
|
||||
var ProtoType_name = map[int32]string{
|
||||
0: "DRPC",
|
||||
}
|
||||
|
||||
var ProtoType_value = map[string]int32{
|
||||
"DRPC": 0,
|
||||
}
|
||||
|
||||
func (x ProtoType) String() string {
|
||||
return proto.EnumName(ProtoType_name, int32(x))
|
||||
}
|
||||
|
||||
func (ProtoType) EnumDescriptor() ([]byte, []int) {
|
||||
return fileDescriptor_60283fc75f020893, []int{2}
|
||||
}
|
||||
|
||||
type Credentials struct {
|
||||
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
|
||||
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
|
||||
@ -247,12 +272,58 @@ func (m *Ack) GetError() Error {
|
||||
return Error_Null
|
||||
}
|
||||
|
||||
type Proto struct {
|
||||
Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Proto) Reset() { *m = Proto{} }
|
||||
func (m *Proto) String() string { return proto.CompactTextString(m) }
|
||||
func (*Proto) ProtoMessage() {}
|
||||
func (*Proto) Descriptor() ([]byte, []int) {
|
||||
return fileDescriptor_60283fc75f020893, []int{3}
|
||||
}
|
||||
func (m *Proto) XXX_Unmarshal(b []byte) error {
|
||||
return m.Unmarshal(b)
|
||||
}
|
||||
func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
|
||||
if deterministic {
|
||||
return xxx_messageInfo_Proto.Marshal(b, m, deterministic)
|
||||
} else {
|
||||
b = b[:cap(b)]
|
||||
n, err := m.MarshalToSizedBuffer(b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return b[:n], nil
|
||||
}
|
||||
}
|
||||
func (m *Proto) XXX_Merge(src proto.Message) {
|
||||
xxx_messageInfo_Proto.Merge(m, src)
|
||||
}
|
||||
func (m *Proto) XXX_Size() int {
|
||||
return m.Size()
|
||||
}
|
||||
func (m *Proto) XXX_DiscardUnknown() {
|
||||
xxx_messageInfo_Proto.DiscardUnknown(m)
|
||||
}
|
||||
|
||||
var xxx_messageInfo_Proto proto.InternalMessageInfo
|
||||
|
||||
func (m *Proto) GetProto() ProtoType {
|
||||
if m != nil {
|
||||
return m.Proto
|
||||
}
|
||||
return ProtoType_DRPC
|
||||
}
|
||||
|
||||
func init() {
|
||||
proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value)
|
||||
proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value)
|
||||
proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value)
|
||||
proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials")
|
||||
proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds")
|
||||
proto.RegisterType((*Ack)(nil), "anyHandshake.Ack")
|
||||
proto.RegisterType((*Proto)(nil), "anyHandshake.Proto")
|
||||
}
|
||||
|
||||
func init() {
|
||||
@ -260,32 +331,35 @@ func init() {
|
||||
}
|
||||
|
||||
var fileDescriptor_60283fc75f020893 = []byte{
|
||||
// 395 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31,
|
||||
0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53,
|
||||
0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5,
|
||||
0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88,
|
||||
0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65,
|
||||
0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23,
|
||||
0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60,
|
||||
0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d,
|
||||
0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd,
|
||||
0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77,
|
||||
0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba,
|
||||
0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91,
|
||||
0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08,
|
||||
0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d,
|
||||
0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1,
|
||||
0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78,
|
||||
0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f,
|
||||
0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70,
|
||||
0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc,
|
||||
0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80,
|
||||
0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34,
|
||||
0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe,
|
||||
0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96,
|
||||
0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00,
|
||||
0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00,
|
||||
// 439 bytes of a gzipped FileDescriptorProto
|
||||
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40,
|
||||
0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7,
|
||||
0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5,
|
||||
0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f,
|
||||
0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf,
|
||||
0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe,
|
||||
0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c,
|
||||
0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed,
|
||||
0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef,
|
||||
0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65,
|
||||
0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8,
|
||||
0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46,
|
||||
0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72,
|
||||
0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27,
|
||||
0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05,
|
||||
0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e,
|
||||
0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64,
|
||||
0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04,
|
||||
0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73,
|
||||
0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62,
|
||||
0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81,
|
||||
0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53,
|
||||
0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85,
|
||||
0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3,
|
||||
0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd,
|
||||
0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c,
|
||||
0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53,
|
||||
0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00,
|
||||
}
|
||||
|
||||
func (m *Credentials) Marshal() (dAtA []byte, err error) {
|
||||
@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func (m *Proto) Marshal() (dAtA []byte, err error) {
|
||||
size := m.Size()
|
||||
dAtA = make([]byte, size)
|
||||
n, err := m.MarshalToSizedBuffer(dAtA[:size])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return dAtA[:n], nil
|
||||
}
|
||||
|
||||
func (m *Proto) MarshalTo(dAtA []byte) (int, error) {
|
||||
size := m.Size()
|
||||
return m.MarshalToSizedBuffer(dAtA[:size])
|
||||
}
|
||||
|
||||
func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) {
|
||||
i := len(dAtA)
|
||||
_ = i
|
||||
var l int
|
||||
_ = l
|
||||
if m.Proto != 0 {
|
||||
i = encodeVarintHandshake(dAtA, i, uint64(m.Proto))
|
||||
i--
|
||||
dAtA[i] = 0x8
|
||||
}
|
||||
return len(dAtA) - i, nil
|
||||
}
|
||||
|
||||
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
|
||||
offset -= sovHandshake(v)
|
||||
base := offset
|
||||
@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) {
|
||||
return n
|
||||
}
|
||||
|
||||
func (m *Proto) Size() (n int) {
|
||||
if m == nil {
|
||||
return 0
|
||||
}
|
||||
var l int
|
||||
_ = l
|
||||
if m.Proto != 0 {
|
||||
n += 1 + sovHandshake(uint64(m.Proto))
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
func sovHandshake(x uint64) (n int) {
|
||||
return (math_bits.Len64(x|1) + 6) / 7
|
||||
}
|
||||
@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (m *Proto) Unmarshal(dAtA []byte) error {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
for iNdEx < l {
|
||||
preIndex := iNdEx
|
||||
var wire uint64
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHandshake
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
wire |= uint64(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
fieldNum := int32(wire >> 3)
|
||||
wireType := int(wire & 0x7)
|
||||
if wireType == 4 {
|
||||
return fmt.Errorf("proto: Proto: wiretype end group for non-group")
|
||||
}
|
||||
if fieldNum <= 0 {
|
||||
return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire)
|
||||
}
|
||||
switch fieldNum {
|
||||
case 1:
|
||||
if wireType != 0 {
|
||||
return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType)
|
||||
}
|
||||
m.Proto = 0
|
||||
for shift := uint(0); ; shift += 7 {
|
||||
if shift >= 64 {
|
||||
return ErrIntOverflowHandshake
|
||||
}
|
||||
if iNdEx >= l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
b := dAtA[iNdEx]
|
||||
iNdEx++
|
||||
m.Proto |= ProtoType(b&0x7F) << shift
|
||||
if b < 0x80 {
|
||||
break
|
||||
}
|
||||
}
|
||||
default:
|
||||
iNdEx = preIndex
|
||||
skippy, err := skipHandshake(dAtA[iNdEx:])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if (skippy < 0) || (iNdEx+skippy) < 0 {
|
||||
return ErrInvalidLengthHandshake
|
||||
}
|
||||
if (iNdEx + skippy) > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
iNdEx += skippy
|
||||
}
|
||||
}
|
||||
|
||||
if iNdEx > l {
|
||||
return io.ErrUnexpectedEOF
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func skipHandshake(dAtA []byte) (n int, err error) {
|
||||
l := len(dAtA)
|
||||
iNdEx := 0
|
||||
|
||||
@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto";
|
||||
|
||||
/*
|
||||
|
||||
CREDENTIALS HANDSHAKE
|
||||
|
||||
Alice opens a new connection with Bob
|
||||
1. TLS handshake done successfully; both sides know local and remote peer identifiers.
|
||||
|
||||
@ -68,4 +70,20 @@ enum Error {
|
||||
SkipVerifyNotAllowed = 4;
|
||||
DeadlineExceeded = 5;
|
||||
IncompatibleVersion = 6;
|
||||
IncompatibleProto = 7;
|
||||
}
|
||||
|
||||
|
||||
/*
|
||||
|
||||
PROTO HANDSHAKE
|
||||
|
||||
*/
|
||||
|
||||
message Proto {
|
||||
ProtoType proto = 1;
|
||||
}
|
||||
|
||||
enum ProtoType {
|
||||
DRPC = 0;
|
||||
}
|
||||
97
net/secureservice/handshake/proto.go
Normal file
97
net/secureservice/handshake/proto.go
Normal file
@ -0,0 +1,97 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"golang.org/x/exp/slices"
|
||||
"net"
|
||||
)
|
||||
|
||||
type ProtoChecker struct {
|
||||
AllowedProtoTypes []handshakeproto.ProtoType
|
||||
}
|
||||
|
||||
func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
err = outgoingProtoHandshake(h, conn, pt)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
|
||||
defer h.release()
|
||||
h.conn = conn
|
||||
localProto := &handshakeproto.Proto{
|
||||
Proto: pt,
|
||||
}
|
||||
if err = h.writeProto(localProto); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if msg.ack.Error == handshakeproto.Error_IncompatibleProto {
|
||||
return ErrRemoteIncompatibleProto
|
||||
}
|
||||
if msg.ack.Error == handshakeproto.Error_Null {
|
||||
return nil
|
||||
}
|
||||
return HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
|
||||
func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
protoType, err = incomingProtoHandshake(h, conn, pt)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = conn.Close()
|
||||
return 0, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
|
||||
defer h.release()
|
||||
h.conn = conn
|
||||
|
||||
msg, err := h.readMsg(msgTypeProto)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) {
|
||||
err = ErrIncompatibleProto
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return 0, err
|
||||
} else {
|
||||
return msg.proto.Proto, nil
|
||||
}
|
||||
}
|
||||
121
net/secureservice/handshake/proto_test.go
Normal file
121
net/secureservice/handshake/proto_test.go
Normal file
@ -0,0 +1,121 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type protoRes struct {
|
||||
protoType handshakeproto.ProtoType
|
||||
err error
|
||||
}
|
||||
|
||||
func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker {
|
||||
return ProtoChecker{AllowedProtoTypes: types}
|
||||
}
|
||||
func TestIncomingProtoHandshake(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var protoResCh = make(chan protoRes, 1)
|
||||
go func() {
|
||||
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
|
||||
protoResCh <- protoRes{protoType: protoType, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
|
||||
// write desired proto
|
||||
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)}))
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
res := <-protoResCh
|
||||
require.NoError(t, res.err)
|
||||
assert.Equal(t, handshakeproto.ProtoType(1), res.protoType)
|
||||
})
|
||||
t.Run("incompatible", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var protoResCh = make(chan protoRes, 1)
|
||||
go func() {
|
||||
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
|
||||
protoResCh <- protoRes{protoType: protoType, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
|
||||
// write desired proto
|
||||
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0}))
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error)
|
||||
res := <-protoResCh
|
||||
require.Error(t, res.err, ErrIncompatibleProto.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestOutgoingProtoHandshake(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var protoResCh = make(chan protoRes, 1)
|
||||
go func() {
|
||||
err := OutgoingProtoHandshake(nil, c1, 1)
|
||||
protoResCh <- protoRes{err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
|
||||
msg, err := h.readMsg(msgTypeProto)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||
|
||||
res := <-protoResCh
|
||||
assert.NoError(t, res.err)
|
||||
})
|
||||
t.Run("incompatible", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var protoResCh = make(chan protoRes, 1)
|
||||
go func() {
|
||||
err := OutgoingProtoHandshake(nil, c1, 1)
|
||||
protoResCh <- protoRes{err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
|
||||
msg, err := h.readMsg(msgTypeProto)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto))
|
||||
|
||||
res := <-protoResCh
|
||||
assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error())
|
||||
})
|
||||
}
|
||||
|
||||
func TestEndToEndProto(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var (
|
||||
inResCh = make(chan protoRes, 1)
|
||||
outResCh = make(chan protoRes, 1)
|
||||
)
|
||||
st := time.Now()
|
||||
go func() {
|
||||
err := OutgoingProtoHandshake(nil, c1, 0)
|
||||
outResCh <- protoRes{err: err}
|
||||
}()
|
||||
go func() {
|
||||
protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1))
|
||||
inResCh <- protoRes{protoType: protoType, err: err}
|
||||
}()
|
||||
|
||||
outRes := <-outResCh
|
||||
assert.NoError(t, outRes.err)
|
||||
|
||||
inRes := <-inResCh
|
||||
assert.NoError(t, inRes.err)
|
||||
assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType)
|
||||
t.Log("dur", time.Since(st))
|
||||
}
|
||||
@ -25,7 +25,7 @@ func New() SecureService {
|
||||
}
|
||||
|
||||
type SecureService interface {
|
||||
SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||
app.Component
|
||||
}
|
||||
@ -93,10 +93,10 @@ func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx
|
||||
return
|
||||
}
|
||||
|
||||
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
|
||||
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
|
||||
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, handshake.HandshakeError{Err: err}
|
||||
return nil, nil, handshake.HandshakeError{Err: err}
|
||||
}
|
||||
peerId := sc.RemotePeer().String()
|
||||
confTypes := s.nodeconf.NodeTypes(peerId)
|
||||
@ -106,10 +106,12 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.
|
||||
} else {
|
||||
checker = s.noVerifyChecker
|
||||
}
|
||||
// ignore identity for outgoing connection because we don't need it at this moment
|
||||
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
|
||||
identity, err := handshake.OutgoingHandshake(ctx, sc, checker)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
return sc, nil
|
||||
cctx = context.Background()
|
||||
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
|
||||
cctx = peer.CtxWithIdentity(cctx, identity)
|
||||
return cctx, sc, nil
|
||||
}
|
||||
|
||||
@ -39,9 +39,12 @@ func TestHandshake(t *testing.T) {
|
||||
fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
|
||||
defer fxC.Finish(t)
|
||||
|
||||
secConn, err := fxC.SecureOutbound(ctx, cc)
|
||||
cctx, secConn, err := fxC.SecureOutbound(ctx, cc)
|
||||
require.NoError(t, err)
|
||||
ctxPeerId, err := peer.CtxPeerId(cctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
|
||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
|
||||
res := <-resCh
|
||||
require.NoError(t, res.err)
|
||||
peerId, err := peer.CtxPeerId(res.ctx)
|
||||
@ -72,7 +75,7 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
|
||||
}()
|
||||
fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
|
||||
defer fxC.Finish(t)
|
||||
_, err := fxC.SecureOutbound(ctx, cc)
|
||||
_, _, err := fxC.SecureOutbound(ctx, cc)
|
||||
require.Equal(t, handshake.ErrIncompatibleVersion, err)
|
||||
res := <-resCh
|
||||
require.Equal(t, handshake.ErrIncompatibleVersion, res.err)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.32
|
||||
// protoc-gen-go-drpc version: v0.0.33
|
||||
// source: net/streampool/testservice/protos/testservice.proto
|
||||
|
||||
package testservice
|
||||
@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream {
|
||||
return x.Stream
|
||||
}
|
||||
|
||||
func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{})
|
||||
}
|
||||
|
||||
188
net/transport/mock_transport/mock_transport.go
Normal file
188
net/transport/mock_transport/mock_transport.go
Normal file
@ -0,0 +1,188 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anyproto/any-sync/net/transport (interfaces: Transport,MultiConn)
|
||||
|
||||
// Package mock_transport is a generated GoMock package.
|
||||
package mock_transport
|
||||
|
||||
import (
|
||||
context "context"
|
||||
net "net"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
transport "github.com/anyproto/any-sync/net/transport"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockTransport is a mock of Transport interface.
|
||||
type MockTransport struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockTransportMockRecorder
|
||||
}
|
||||
|
||||
// MockTransportMockRecorder is the mock recorder for MockTransport.
|
||||
type MockTransportMockRecorder struct {
|
||||
mock *MockTransport
|
||||
}
|
||||
|
||||
// NewMockTransport creates a new mock instance.
|
||||
func NewMockTransport(ctrl *gomock.Controller) *MockTransport {
|
||||
mock := &MockTransport{ctrl: ctrl}
|
||||
mock.recorder = &MockTransportMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockTransport) EXPECT() *MockTransportMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Dial mocks base method.
|
||||
func (m *MockTransport) Dial(arg0 context.Context, arg1 string) (transport.MultiConn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Dial", arg0, arg1)
|
||||
ret0, _ := ret[0].(transport.MultiConn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Dial indicates an expected call of Dial.
|
||||
func (mr *MockTransportMockRecorder) Dial(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockTransport)(nil).Dial), arg0, arg1)
|
||||
}
|
||||
|
||||
// SetAccepter mocks base method.
|
||||
func (m *MockTransport) SetAccepter(arg0 transport.Accepter) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetAccepter", arg0)
|
||||
}
|
||||
|
||||
// SetAccepter indicates an expected call of SetAccepter.
|
||||
func (mr *MockTransportMockRecorder) SetAccepter(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccepter", reflect.TypeOf((*MockTransport)(nil).SetAccepter), arg0)
|
||||
}
|
||||
|
||||
// MockMultiConn is a mock of MultiConn interface.
|
||||
type MockMultiConn struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockMultiConnMockRecorder
|
||||
}
|
||||
|
||||
// MockMultiConnMockRecorder is the mock recorder for MockMultiConn.
|
||||
type MockMultiConnMockRecorder struct {
|
||||
mock *MockMultiConn
|
||||
}
|
||||
|
||||
// NewMockMultiConn creates a new mock instance.
|
||||
func NewMockMultiConn(ctrl *gomock.Controller) *MockMultiConn {
|
||||
mock := &MockMultiConn{ctrl: ctrl}
|
||||
mock.recorder = &MockMultiConnMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockMultiConn) EXPECT() *MockMultiConnMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Accept mocks base method.
|
||||
func (m *MockMultiConn) Accept() (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Accept")
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Accept indicates an expected call of Accept.
|
||||
func (mr *MockMultiConnMockRecorder) Accept() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockMultiConn)(nil).Accept))
|
||||
}
|
||||
|
||||
// Addr mocks base method.
|
||||
func (m *MockMultiConn) Addr() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Addr")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Addr indicates an expected call of Addr.
|
||||
func (mr *MockMultiConnMockRecorder) Addr() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockMultiConn)(nil).Addr))
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockMultiConn) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockMultiConnMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiConn)(nil).Close))
|
||||
}
|
||||
|
||||
// Context mocks base method.
|
||||
func (m *MockMultiConn) Context() context.Context {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Context")
|
||||
ret0, _ := ret[0].(context.Context)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Context indicates an expected call of Context.
|
||||
func (mr *MockMultiConnMockRecorder) Context() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMultiConn)(nil).Context))
|
||||
}
|
||||
|
||||
// IsClosed mocks base method.
|
||||
func (m *MockMultiConn) IsClosed() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsClosed")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// IsClosed indicates an expected call of IsClosed.
|
||||
func (mr *MockMultiConnMockRecorder) IsClosed() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiConn)(nil).IsClosed))
|
||||
}
|
||||
|
||||
// LastUsage mocks base method.
|
||||
func (m *MockMultiConn) LastUsage() time.Time {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "LastUsage")
|
||||
ret0, _ := ret[0].(time.Time)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// LastUsage indicates an expected call of LastUsage.
|
||||
func (mr *MockMultiConnMockRecorder) LastUsage() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockMultiConn)(nil).LastUsage))
|
||||
}
|
||||
|
||||
// Open mocks base method.
|
||||
func (m *MockMultiConn) Open(arg0 context.Context) (net.Conn, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Open", arg0)
|
||||
ret0, _ := ret[0].(net.Conn)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Open indicates an expected call of Open.
|
||||
func (mr *MockMultiConnMockRecorder) Open(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockMultiConn)(nil).Open), arg0)
|
||||
}
|
||||
44
net/transport/transport.go
Normal file
44
net/transport/transport.go
Normal file
@ -0,0 +1,44 @@
|
||||
//go:generate mockgen -destination mock_transport/mock_transport.go github.com/anyproto/any-sync/net/transport Transport,MultiConn
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnClosed = errors.New("connection closed")
|
||||
)
|
||||
|
||||
// Transport is a common interface for a network transport
|
||||
type Transport interface {
|
||||
// SetAccepter sets accepter that will be called for new connections
|
||||
// this method should be called before app start
|
||||
SetAccepter(accepter Accepter)
|
||||
// Dial creates a new connection by given address
|
||||
Dial(ctx context.Context, addr string) (mc MultiConn, err error)
|
||||
}
|
||||
|
||||
// MultiConn is an object of multiplexing connection containing handshake info
|
||||
type MultiConn interface {
|
||||
// Context returns the connection context that contains handshake details
|
||||
Context() context.Context
|
||||
// Accept accepts new sub connections
|
||||
Accept() (conn net.Conn, err error)
|
||||
// Open opens new sub connection
|
||||
Open(ctx context.Context) (conn net.Conn, err error)
|
||||
// LastUsage returns the time of the last connection activity
|
||||
LastUsage() time.Time
|
||||
// Addr returns remote peer address
|
||||
Addr() string
|
||||
// IsClosed returns true when connection is closed
|
||||
IsClosed() bool
|
||||
// Close closes the connection and all sub connections
|
||||
Close() error
|
||||
}
|
||||
|
||||
type Accepter interface {
|
||||
Accept(mc MultiConn) (err error)
|
||||
}
|
||||
12
net/transport/yamux/config.go
Normal file
12
net/transport/yamux/config.go
Normal file
@ -0,0 +1,12 @@
|
||||
package yamux
|
||||
|
||||
type configGetter interface {
|
||||
GetYamux() Config
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
ListenAddrs []string `yaml:"listenAddrs"`
|
||||
WriteTimeoutSec int `yaml:"writeTimeoutSec"`
|
||||
DialTimeoutSec int `yaml:"dialTimeoutSec"`
|
||||
MaxStreams int `yaml:"maxStreams"`
|
||||
}
|
||||
42
net/transport/yamux/conn.go
Normal file
42
net/transport/yamux/conn.go
Normal file
@ -0,0 +1,42 @@
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/connutil"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/hashicorp/yamux"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type yamuxConn struct {
|
||||
ctx context.Context
|
||||
luConn *connutil.LastUsageConn
|
||||
addr string
|
||||
*yamux.Session
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
|
||||
return y.Session.Open()
|
||||
}
|
||||
|
||||
func (y *yamuxConn) LastUsage() time.Time {
|
||||
return y.luConn.LastUsage()
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Context() context.Context {
|
||||
return y.ctx
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Addr() string {
|
||||
return y.addr
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Accept() (conn net.Conn, err error) {
|
||||
if conn, err = y.Session.Accept(); err != nil {
|
||||
if err == yamux.ErrSessionShutdown {
|
||||
err = transport.ErrConnClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,6 +1,6 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@ -1,6 +1,6 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"errors"
|
||||
170
net/transport/yamux/yamux.go
Normal file
170
net/transport/yamux/yamux.go
Normal file
@ -0,0 +1,170 @@
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/net/connutil"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/hashicorp/yamux"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const CName = "net.transport.yamux"
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
func New() Yamux {
|
||||
return new(yamuxTransport)
|
||||
}
|
||||
|
||||
// Yamux implements transport.Transport with tcp+yamux
|
||||
type Yamux interface {
|
||||
transport.Transport
|
||||
app.ComponentRunnable
|
||||
}
|
||||
|
||||
type yamuxTransport struct {
|
||||
secure secureservice.SecureService
|
||||
accepter transport.Accepter
|
||||
conf Config
|
||||
|
||||
listeners []net.Listener
|
||||
listCtx context.Context
|
||||
listCtxCancel context.CancelFunc
|
||||
yamuxConf *yamux.Config
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) Init(a *app.App) (err error) {
|
||||
y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
|
||||
y.conf = a.MustComponent("config").(configGetter).GetYamux()
|
||||
y.yamuxConf = yamux.DefaultConfig()
|
||||
if y.conf.MaxStreams > 0 {
|
||||
y.yamuxConf.AcceptBacklog = y.conf.MaxStreams
|
||||
}
|
||||
y.yamuxConf.EnableKeepAlive = false
|
||||
y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
|
||||
y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
|
||||
return
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) Name() string {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) Run(ctx context.Context) (err error) {
|
||||
if y.accepter == nil {
|
||||
return fmt.Errorf("can't run service without accepter")
|
||||
}
|
||||
for _, listAddr := range y.conf.ListenAddrs {
|
||||
list, err := net.Listen("tcp", listAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
y.listeners = append(y.listeners, list)
|
||||
}
|
||||
y.listCtx, y.listCtxCancel = context.WithCancel(context.Background())
|
||||
for _, list := range y.listeners {
|
||||
go y.acceptLoop(y.listCtx, list)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) {
|
||||
y.accepter = accepter
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) {
|
||||
dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second
|
||||
conn, err := net.DialTimeout("tcp", addr, dialTimeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||
defer cancel()
|
||||
cctx, sc, err := y.secure.SecureOutbound(ctx, conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
luc := connutil.NewLastUsageConn(sc)
|
||||
sess, err := yamux.Client(luc, y.yamuxConf)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mc = &yamuxConn{
|
||||
ctx: cctx,
|
||||
luConn: luc,
|
||||
Session: sess,
|
||||
addr: addr,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
|
||||
l := log.With(zap.String("localAddr", list.Addr().String()))
|
||||
l.Info("yamux listener started")
|
||||
defer func() {
|
||||
l.Debug("yamux listener stopped")
|
||||
}()
|
||||
for {
|
||||
conn, err := list.Accept()
|
||||
if err != nil {
|
||||
if isTemporary(err) {
|
||||
l.Debug("listener temporary accept error", zap.Error(err))
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err != net.ErrClosed {
|
||||
l.Error("listener closed with error", zap.Error(err))
|
||||
} else {
|
||||
l.Info("listener closed")
|
||||
}
|
||||
return
|
||||
}
|
||||
go y.accept(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) accept(conn net.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
cctx, sc, err := y.secure.SecureInbound(ctx, conn)
|
||||
if err != nil {
|
||||
log.Warn("incoming connection handshake error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
luc := connutil.NewLastUsageConn(sc)
|
||||
sess, err := yamux.Server(luc, y.yamuxConf)
|
||||
if err != nil {
|
||||
log.Warn("incoming connection yamux session error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
mc := &yamuxConn{
|
||||
ctx: cctx,
|
||||
luConn: luc,
|
||||
Session: sess,
|
||||
addr: conn.RemoteAddr().String(),
|
||||
}
|
||||
if err = y.accepter.Accept(mc); err != nil {
|
||||
log.Warn("connection accept error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (y *yamuxTransport) Close(ctx context.Context) (err error) {
|
||||
if y.listCtxCancel != nil {
|
||||
y.listCtxCancel()
|
||||
}
|
||||
for _, l := range y.listeners {
|
||||
_ = l.Close()
|
||||
}
|
||||
return
|
||||
}
|
||||
134
net/transport/yamux/yamux_test.go
Normal file
134
net/transport/yamux/yamux_test.go
Normal file
@ -0,0 +1,134 @@
|
||||
package yamux
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
|
||||
"github.com/anyproto/any-sync/testutil/accounttest"
|
||||
"github.com/anyproto/any-sync/testutil/testnodeconf"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func TestYamuxTransport_Dial(t *testing.T) {
|
||||
fxS := newFixture(t)
|
||||
defer fxS.finish(t)
|
||||
fxC := newFixture(t)
|
||||
defer fxC.finish(t)
|
||||
|
||||
mcC, err := fxC.Dial(ctx, fxS.addr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, fxS.accepter.mcs, 1)
|
||||
mcS := fxS.accepter.mcs[0]
|
||||
|
||||
var (
|
||||
sData string
|
||||
acceptErr error
|
||||
copyErr error
|
||||
done = make(chan struct{})
|
||||
)
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
conn, serr := mcS.Accept()
|
||||
if serr != nil {
|
||||
acceptErr = serr
|
||||
return
|
||||
}
|
||||
buf := bytes.NewBuffer(nil)
|
||||
_, copyErr = io.Copy(buf, conn)
|
||||
sData = buf.String()
|
||||
return
|
||||
}()
|
||||
|
||||
conn, err := mcC.Open(ctx)
|
||||
require.NoError(t, err)
|
||||
data := "some data"
|
||||
_, err = conn.Write([]byte(data))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, conn.Close())
|
||||
<-done
|
||||
|
||||
assert.NoError(t, acceptErr)
|
||||
assert.Equal(t, data, sData)
|
||||
assert.NoError(t, copyErr)
|
||||
}
|
||||
|
||||
type fixture struct {
|
||||
*yamuxTransport
|
||||
a *app.App
|
||||
ctrl *gomock.Controller
|
||||
mockNodeConf *mock_nodeconf.MockService
|
||||
acc *accounttest.AccountTestService
|
||||
accepter *testAccepter
|
||||
addr string
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) *fixture {
|
||||
fx := &fixture{
|
||||
yamuxTransport: New().(*yamuxTransport),
|
||||
ctrl: gomock.NewController(t),
|
||||
acc: &accounttest.AccountTestService{},
|
||||
accepter: &testAccepter{},
|
||||
a: new(app.App),
|
||||
}
|
||||
|
||||
fx.mockNodeConf = mock_nodeconf.NewMockService(fx.ctrl)
|
||||
fx.mockNodeConf.EXPECT().Init(gomock.Any())
|
||||
fx.mockNodeConf.EXPECT().Name().Return(nodeconf.CName).AnyTimes()
|
||||
fx.mockNodeConf.EXPECT().Run(ctx)
|
||||
fx.mockNodeConf.EXPECT().Close(ctx)
|
||||
fx.mockNodeConf.EXPECT().NodeTypes(gomock.Any()).Return([]nodeconf.NodeType{nodeconf.NodeTypeTree}).AnyTimes()
|
||||
fx.a.Register(fx.acc).Register(newTestConf()).Register(fx.mockNodeConf).Register(secureservice.New()).Register(fx.yamuxTransport).Register(fx.accepter)
|
||||
require.NoError(t, fx.a.Start(ctx))
|
||||
fx.addr = fx.listeners[0].Addr().String()
|
||||
return fx
|
||||
}
|
||||
|
||||
func (fx *fixture) finish(t *testing.T) {
|
||||
require.NoError(t, fx.a.Close(ctx))
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func newTestConf() *testConf {
|
||||
return &testConf{testnodeconf.GenNodeConfig(1)}
|
||||
}
|
||||
|
||||
type testConf struct {
|
||||
*testnodeconf.Config
|
||||
}
|
||||
|
||||
func (c *testConf) GetYamux() Config {
|
||||
return Config{
|
||||
ListenAddrs: []string{"127.0.0.1:0"},
|
||||
WriteTimeoutSec: 10,
|
||||
DialTimeoutSec: 10,
|
||||
MaxStreams: 1024,
|
||||
}
|
||||
}
|
||||
|
||||
type testAccepter struct {
|
||||
err error
|
||||
mcs []transport.MultiConn
|
||||
}
|
||||
|
||||
func (t *testAccepter) Accept(mc transport.MultiConn) (err error) {
|
||||
t.mcs = append(t.mcs, mc)
|
||||
return t.err
|
||||
}
|
||||
|
||||
func (t *testAccepter) Init(a *app.App) (err error) {
|
||||
a.MustComponent(CName).(transport.Transport).SetAccepter(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testAccepter) Name() (name string) { return "testAccepter" }
|
||||
Loading…
x
Reference in New Issue
Block a user