From d04e55bc9c092921dca32b04a7203c06c1ad0cda Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 22 May 2023 17:57:00 +0200 Subject: [PATCH] propagate handshake error + NetworkCompatibilityStatus method --- net/config.go | 6 + net/dialer/dialer.go | 4 + net/pool/pool.go | 21 ++- net/pool/pool_test.go | 14 +- net/secureservice/handshake/handshake.go | 24 ++-- net/streampool/streampool.go | 4 +- nodeconf/mock_nodeconf/mock_nodeconf.go | 14 ++ nodeconf/nodeconf.go | 1 - nodeconf/service.go | 43 +++++- nodeconf/service_test.go | 175 +++++++++++++++++++++++ testutil/accounttest/accountservice.go | 10 -- testutil/testnodeconf/testnodeconf.go | 11 +- 12 files changed, 289 insertions(+), 38 deletions(-) create mode 100644 nodeconf/service_test.go diff --git a/net/config.go b/net/config.go index b0cdf564..333261dd 100644 --- a/net/config.go +++ b/net/config.go @@ -1,5 +1,11 @@ package net +import "errors" + +var ( + ErrUnableToConnect = errors.New("unable to connect") +) + type ConfigGetter interface { GetNet() Config } diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go index d862e78f..1992e6e5 100644 --- a/net/dialer/dialer.go +++ b/net/dialer/dialer.go @@ -9,6 +9,7 @@ import ( net2 "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/peer" "github.com/anytypeio/any-sync/net/secureservice" + "github.com/anytypeio/any-sync/net/secureservice/handshake" "github.com/anytypeio/any-sync/net/timeoutconn" "github.com/anytypeio/any-sync/nodeconf" "github.com/libp2p/go-libp2p/core/sec" @@ -120,6 +121,9 @@ func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc. timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) sc, err = d.transport.SecureOutbound(ctx, timeoutConn) if err != nil { + if he, ok := err.(handshake.HandshakeError); ok { + return nil, nil, he + } return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) } if peerId != sc.RemotePeer().String() { diff --git a/net/pool/pool.go b/net/pool/pool.go index b7e391c6..81698e5a 100644 --- a/net/pool/pool.go +++ b/net/pool/pool.go @@ -2,18 +2,15 @@ package pool import ( "context" - "errors" "github.com/anytypeio/any-sync/app/ocache" + "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/dialer" "github.com/anytypeio/any-sync/net/peer" + "github.com/anytypeio/any-sync/net/secureservice/handshake" "go.uber.org/zap" "math/rand" ) -var ( - ErrUnableToConnect = errors.New("unable to connect") -) - // Pool creates and caches outgoing connection type Pool interface { // Get lookups to peer in existing connections or creates and cache new one @@ -76,14 +73,19 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error peerIds[i], peerIds[j] = peerIds[j], peerIds[i] }) // connecting + var lastErr error for _, peerId := range peerIds { if v, err := p.cache.Get(ctx, peerId); err == nil { return v.(peer.Peer), nil } else { log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) + lastErr = err } } - return nil, ErrUnableToConnect + if _, ok := lastErr.(handshake.HandshakeError); !ok { + lastErr = net.ErrUnableToConnect + } + return nil, lastErr } func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { @@ -92,14 +94,19 @@ func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, erro peerIds[i], peerIds[j] = peerIds[j], peerIds[i] }) // connecting + var lastErr error for _, peerId := range peerIds { if v, err := p.dialer.Dial(ctx, peerId); err == nil { return v.(peer.Peer), nil } else { log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) + lastErr = err } } - return nil, ErrUnableToConnect + if _, ok := lastErr.(handshake.HandshakeError); !ok { + lastErr = net.ErrUnableToConnect + } + return nil, lastErr } func (p *pool) Close(ctx context.Context) (err error) { diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index 262a59f6..40d33cfa 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -5,8 +5,10 @@ import ( "errors" "fmt" "github.com/anytypeio/any-sync/app" + "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/dialer" "github.com/anytypeio/any-sync/net/peer" + "github.com/anytypeio/any-sync/net/secureservice/handshake" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "storj.io/drpc" @@ -116,7 +118,17 @@ func TestPool_GetOneOf(t *testing.T) { return nil, fmt.Errorf("persistent error") } p, err := fx.GetOneOf(ctx, []string{"3", "2", "1"}) - assert.Equal(t, ErrUnableToConnect, err) + assert.Equal(t, net.ErrUnableToConnect, err) + assert.Nil(t, p) + }) + t.Run("handshake error", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + fx.Dialer.dial = func(ctx context.Context, peerId string) (peer peer.Peer, err error) { + return nil, handshake.ErrIncompatibleVersion + } + p, err := fx.GetOneOf(ctx, []string{"3", "2", "1"}) + assert.Equal(t, handshake.ErrIncompatibleVersion, err) assert.Nil(t, p) }) } diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 044311a9..1faef9e9 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -18,23 +18,23 @@ const ( msgTypeAck = byte(2) ) -type handshakeError struct { +type HandshakeError struct { e handshakeproto.Error } -func (he handshakeError) Error() string { +func (he HandshakeError) Error() string { return he.e.String() } var ( - ErrUnexpectedPayload = handshakeError{handshakeproto.Error_UnexpectedPayload} - ErrDeadlineExceeded = handshakeError{handshakeproto.Error_DeadlineExceeded} - ErrInvalidCredentials = handshakeError{handshakeproto.Error_InvalidCredentials} + ErrUnexpectedPayload = HandshakeError{handshakeproto.Error_UnexpectedPayload} + ErrDeadlineExceeded = HandshakeError{handshakeproto.Error_DeadlineExceeded} + ErrInvalidCredentials = HandshakeError{handshakeproto.Error_InvalidCredentials} ErrPeerDeclinedCredentials = errors.New("remote peer declined the credentials") - ErrSkipVerifyNotAllowed = handshakeError{handshakeproto.Error_SkipVerifyNotAllowed} - ErrUnexpected = handshakeError{handshakeproto.Error_Unexpected} + ErrSkipVerifyNotAllowed = HandshakeError{handshakeproto.Error_SkipVerifyNotAllowed} + ErrUnexpected = HandshakeError{handshakeproto.Error_Unexpected} - ErrIncompatibleVersion = handshakeError{handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleVersion = HandshakeError{handshakeproto.Error_IncompatibleVersion} ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") ) @@ -89,7 +89,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i if msg.ack.Error == handshakeproto.Error_InvalidCredentials { return nil, ErrPeerDeclinedCredentials } - return nil, handshakeError{e: msg.ack.Error} + return nil, HandshakeError{e: msg.ack.Error} } if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { @@ -116,7 +116,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i return identity, nil } else { _ = h.conn.Close() - return nil, handshakeError{e: msg.ack.Error} + return nil, HandshakeError{e: msg.ack.Error} } } @@ -175,7 +175,7 @@ func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i if msg.ack.Error == handshakeproto.Error_InvalidCredentials { return nil, ErrPeerDeclinedCredentials } - return nil, handshakeError{e: msg.ack.Error} + return nil, HandshakeError{e: msg.ack.Error} } if err = h.writeAck(handshakeproto.Error_Null); err != nil { h.tryWriteErrAndClose(err) @@ -212,7 +212,7 @@ func (h *handshake) tryWriteErrAndClose(err error) { return } var ackErr handshakeproto.Error - if he, ok := err.(handshakeError); ok { + if he, ok := err.(HandshakeError); ok { ackErr = he.e } else { ackErr = handshakeproto.Error_Unexpected diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 442a5097..42e3a0cb 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -2,8 +2,8 @@ package streampool import ( "fmt" + "github.com/anytypeio/any-sync/net" "github.com/anytypeio/any-sync/net/peer" - "github.com/anytypeio/any-sync/net/pool" "github.com/anytypeio/any-sync/util/multiqueue" "go.uber.org/zap" "golang.org/x/exp/slices" @@ -172,7 +172,7 @@ func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ... } } if len(streamsByPeer) == 0 { - return pool.ErrUnableToConnect + return net.ErrUnableToConnect } return } diff --git a/nodeconf/mock_nodeconf/mock_nodeconf.go b/nodeconf/mock_nodeconf/mock_nodeconf.go index b422c8a9..dc261709 100644 --- a/nodeconf/mock_nodeconf/mock_nodeconf.go +++ b/nodeconf/mock_nodeconf/mock_nodeconf.go @@ -177,6 +177,20 @@ func (mr *MockServiceMockRecorder) Name() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockService)(nil).Name)) } +// NetworkCompatibilityStatus mocks base method. +func (m *MockService) NetworkCompatibilityStatus() nodeconf.NetworkCompatibilityStatus { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NetworkCompatibilityStatus") + ret0, _ := ret[0].(nodeconf.NetworkCompatibilityStatus) + return ret0 +} + +// NetworkCompatibilityStatus indicates an expected call of NetworkCompatibilityStatus. +func (mr *MockServiceMockRecorder) NetworkCompatibilityStatus() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NetworkCompatibilityStatus", reflect.TypeOf((*MockService)(nil).NetworkCompatibilityStatus)) +} + // NodeIds mocks base method. func (m *MockService) NodeIds(arg0 string) []string { m.ctrl.T.Helper() diff --git a/nodeconf/nodeconf.go b/nodeconf/nodeconf.go index ac665b9f..c49893d5 100644 --- a/nodeconf/nodeconf.go +++ b/nodeconf/nodeconf.go @@ -1,4 +1,3 @@ -//go:generate mockgen -destination mock_nodeconf/mock_nodeconf.go github.com/anytypeio/any-sync/nodeconf Service package nodeconf import ( diff --git a/nodeconf/service.go b/nodeconf/service.go index 47f9e39d..024392a1 100644 --- a/nodeconf/service.go +++ b/nodeconf/service.go @@ -1,3 +1,4 @@ +//go:generate mockgen -destination mock_nodeconf/mock_nodeconf.go github.com/anytypeio/any-sync/nodeconf Service package nodeconf import ( @@ -5,6 +6,8 @@ import ( commonaccount "github.com/anytypeio/any-sync/accountservice" "github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app/logger" + "github.com/anytypeio/any-sync/net" + "github.com/anytypeio/any-sync/net/secureservice/handshake" "github.com/anytypeio/any-sync/util/periodicsync" "github.com/anytypeio/go-chash" "go.uber.org/zap" @@ -20,12 +23,22 @@ const ( var log = logger.NewNamed(CName) +type NetworkCompatibilityStatus int + +const ( + NetworkCompatibilityStatusUnknown NetworkCompatibilityStatus = iota + NetworkCompatibilityStatusOk + NetworkCompatibilityStatusError + NetworkCompatibilityStatusIncompatible +) + func New() Service { return new(service) } type Service interface { NodeConf + NetworkCompatibilityStatus() NetworkCompatibilityStatus app.ComponentRunnable } @@ -37,6 +50,8 @@ type service struct { last NodeConf mu sync.RWMutex sync periodicsync.PeriodicSync + + compatibilityStatus NetworkCompatibilityStatus } func (s *service) Init(a *app.App) (err error) { @@ -75,10 +90,19 @@ func (s *service) Run(_ context.Context) (err error) { return } +func (s *service) NetworkCompatibilityStatus() NetworkCompatibilityStatus { + s.mu.RLock() + defer s.mu.RUnlock() + return s.compatibilityStatus +} + func (s *service) updateConfiguration(ctx context.Context) (err error) { last, err := s.source.GetLast(ctx, s.Configuration().Id) if err != nil { + s.setCompatibilityStatusByErr(err) return + } else { + s.setCompatibilityStatusByErr(nil) } if err = s.store.SaveLast(ctx, last); err != nil { return @@ -137,6 +161,21 @@ func (s *service) setLastConfiguration(c Configuration) (err error) { return } +func (s *service) setCompatibilityStatusByErr(err error) { + s.mu.Lock() + defer s.mu.Unlock() + switch err { + case nil: + s.compatibilityStatus = NetworkCompatibilityStatusOk + case handshake.ErrIncompatibleVersion: + s.compatibilityStatus = NetworkCompatibilityStatusIncompatible + case net.ErrUnableToConnect: + s.compatibilityStatus = NetworkCompatibilityStatusUnknown + default: + s.compatibilityStatus = NetworkCompatibilityStatusError + } +} + func (s *service) Id() string { s.mu.RLock() defer s.mu.RUnlock() @@ -204,6 +243,8 @@ func (s *service) NodeTypes(nodeId string) []NodeType { } func (s *service) Close(ctx context.Context) (err error) { - s.sync.Close() + if s.sync != nil { + s.sync.Close() + } return } diff --git a/nodeconf/service_test.go b/nodeconf/service_test.go new file mode 100644 index 00000000..5ed8154c --- /dev/null +++ b/nodeconf/service_test.go @@ -0,0 +1,175 @@ +package nodeconf + +import ( + "context" + "errors" + "github.com/anytypeio/any-sync/app" + "github.com/anytypeio/any-sync/net" + "github.com/anytypeio/any-sync/net/secureservice/handshake" + "github.com/anytypeio/any-sync/testutil/accounttest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "sync" + "testing" + "time" +) + +var ctx = context.Background() + +func TestService_NetworkCompatibilityStatus(t *testing.T) { + t.Run("unknown", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.testSource.call = func() (c Configuration, e error) { + e = net.ErrUnableToConnect + return + } + fx.run(t) + time.Sleep(time.Millisecond * 10) + assert.Equal(t, NetworkCompatibilityStatusUnknown, fx.NetworkCompatibilityStatus()) + }) + t.Run("incompatible", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.testSource.err = handshake.ErrIncompatibleVersion + fx.run(t) + time.Sleep(time.Millisecond * 10) + assert.Equal(t, NetworkCompatibilityStatusIncompatible, fx.NetworkCompatibilityStatus()) + }) + t.Run("error", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.testSource.call = func() (c Configuration, e error) { + e = errors.New("some error") + return + } + fx.run(t) + time.Sleep(time.Millisecond * 10) + assert.Equal(t, NetworkCompatibilityStatusError, fx.NetworkCompatibilityStatus()) + }) + t.Run("ok", func(t *testing.T) { + fx := newFixture(t) + defer fx.finish(t) + fx.run(t) + time.Sleep(time.Millisecond * 10) + assert.Equal(t, NetworkCompatibilityStatusOk, fx.NetworkCompatibilityStatus()) + }) +} + +func newFixture(t *testing.T) *fixture { + fx := &fixture{ + Service: New(), + a: new(app.App), + testStore: &testStore{}, + testSource: &testSource{}, + testConf: newTestConf(), + } + fx.a.Register(fx.testConf).Register(&accounttest.AccountTestService{}).Register(fx.Service).Register(fx.testSource).Register(fx.testStore) + return fx +} + +type fixture struct { + Service + a *app.App + testStore *testStore + testSource *testSource + testConf *testConf +} + +func (fx *fixture) run(t *testing.T) { + require.NoError(t, fx.a.Start(ctx)) +} + +func (fx *fixture) finish(t *testing.T) { + require.NoError(t, fx.a.Close(ctx)) +} + +type testSource struct { + conf Configuration + err error + call func() (Configuration, error) +} + +func (t *testSource) Init(a *app.App) error { return nil } +func (t *testSource) Name() string { return CNameSource } + +func (t *testSource) GetLast(ctx context.Context, currentId string) (c Configuration, err error) { + if t.call != nil { + return t.call() + } + return t.conf, t.err +} + +type testStore struct { + conf *Configuration + mu sync.Mutex +} + +func (t *testStore) Init(a *app.App) error { return nil } +func (t *testStore) Name() string { return CNameStore } + +func (t *testStore) GetLast(ctx context.Context, netId string) (c Configuration, err error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.conf != nil { + return *t.conf, nil + } else { + err = ErrConfigurationNotFound + } + return +} + +func (t *testStore) SaveLast(ctx context.Context, c Configuration) (err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.conf = &c + return +} + +type testConf struct { + Configuration +} + +func (t *testConf) Init(a *app.App) error { return nil } +func (t *testConf) Name() string { return "config" } + +func (t *testConf) GetNodeConf() Configuration { + return t.Configuration +} + +func newTestConf() *testConf { + return &testConf{ + Configuration{ + Id: "test", + NetworkId: "testNetwork", + Nodes: []Node{ + { + PeerId: "12D3KooWKLCajM89S8unbt3tgGbRLgmiWnFZT3adn9A5pQciBSLa", + Addresses: []string{"127.0.0.1:4830"}, + Types: []NodeType{NodeTypeCoordinator}, + }, + { + PeerId: "12D3KooWKnXTtbveMDUFfeSqR5dt9a4JW66tZQXG7C7PdDh3vqGu", + Addresses: []string{"127.0.0.1:4730"}, + Types: []NodeType{NodeTypeTree}, + }, + { + PeerId: "12D3KooWKgVN2kW8xw5Uvm2sLUnkeUNQYAvcWvF58maTzev7FjPi", + Addresses: []string{"127.0.0.1:4731"}, + Types: []NodeType{NodeTypeTree}, + }, + { + PeerId: "12D3KooWCUPYuMnQhu9yREJgQyjcz8zWY83rZGmDLwb9YR6QkbZX", + Addresses: []string{"127.0.0.1:4732"}, + Types: []NodeType{NodeTypeTree}, + }, + { + PeerId: "12D3KooWQxiZ5a7vcy4DTJa8Gy1eVUmwb5ojN4SrJC9Rjxzigw6C", + Addresses: []string{"127.0.0.1:4733"}, + Types: []NodeType{NodeTypeFile}, + }, + }, + CreationTime: time.Now(), + }, + } +} diff --git a/testutil/accounttest/accountservice.go b/testutil/accounttest/accountservice.go index 3c9f6e11..610f640b 100644 --- a/testutil/accounttest/accountservice.go +++ b/testutil/accounttest/accountservice.go @@ -4,8 +4,6 @@ import ( accountService "github.com/anytypeio/any-sync/accountservice" "github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/commonspace/object/accountdata" - "github.com/anytypeio/any-sync/nodeconf" - "github.com/anytypeio/any-sync/nodeconf/nodeconfstore" "github.com/anytypeio/any-sync/util/crypto" ) @@ -47,11 +45,3 @@ func (s *AccountTestService) Name() (name string) { func (s *AccountTestService) Account() *accountdata.AccountKeys { return s.acc } - -func (s *AccountTestService) NodeConf(addrs []string) nodeconfstore.NodeConfig { - return nodeconfstore.NodeConfig{ - PeerId: s.acc.PeerId, - Addresses: addrs, - Types: []nodeconf.NodeType{nodeconf.NodeTypeTree}, - } -} diff --git a/testutil/testnodeconf/testnodeconf.go b/testutil/testnodeconf/testnodeconf.go index 118fcfdd..a0188369 100644 --- a/testutil/testnodeconf/testnodeconf.go +++ b/testutil/testnodeconf/testnodeconf.go @@ -3,7 +3,7 @@ package testnodeconf import ( "github.com/anytypeio/any-sync/accountservice" "github.com/anytypeio/any-sync/app" - "github.com/anytypeio/any-sync/nodeconf/nodeconfstore" + "github.com/anytypeio/any-sync/nodeconf" "github.com/anytypeio/any-sync/testutil/accounttest" ) @@ -17,14 +17,17 @@ func GenNodeConfig(num int) (conf *Config) { if err := ac.Init(nil); err != nil { panic(err) } - conf.nodes = append(conf.nodes, ac.NodeConf(nil)) + conf.nodes.Nodes = append(conf.nodes.Nodes, nodeconf.Node{ + PeerId: ac.Account().PeerId, + Types: []nodeconf.NodeType{nodeconf.NodeTypeTree}, + }) conf.configs = append(conf.configs, ac) } return conf } type Config struct { - nodes []nodeconfstore.NodeConfig + nodes nodeconf.Configuration configs []*accounttest.AccountTestService } @@ -35,7 +38,7 @@ func (c *Config) GetNodesConfId() string { return "test" } -func (c *Config) GetNodes() []nodeconfstore.NodeConfig { +func (c *Config) GetNodeConf() nodeconf.Configuration { return c.nodes }