diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 72fbed02..b1d2eabf 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -1,6 +1,7 @@ package handshake import ( + "context" "encoding/binary" "errors" "github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto" @@ -50,8 +51,26 @@ type CredentialChecker interface { CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) } -func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = outgoingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { defer h.release() h.conn = sc localCred := cc.MakeCredentials(sc) @@ -99,8 +118,26 @@ func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte } } -func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { +func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = incomingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { defer h.release() h.conn = sc diff --git a/net/secureservice/handshake/handshake_test.go b/net/secureservice/handshake/handshake_test.go index d8bd8f21..e32a9362 100644 --- a/net/secureservice/handshake/handshake_test.go +++ b/net/secureservice/handshake/handshake_test.go @@ -1,6 +1,7 @@ package handshake import ( + "context" "github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey" peer2 "github.com/anytypeio/any-sync/util/peer" @@ -11,11 +12,16 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net" + "net/http" _ "net/http/pprof" "testing" "time" ) +func init() { + go http.ListenAndServe(":6060", nil) +} + var noVerifyChecker = &testCredChecker{ makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { @@ -33,7 +39,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -60,7 +66,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -71,7 +77,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -87,7 +93,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -103,7 +109,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -122,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -140,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -161,7 +167,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -186,7 +192,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -208,6 +214,28 @@ func TestOutgoingHandshake(t *testing.T) { res := <-handshakeResCh require.Error(t, res.err) }) + t.Run("context cancel", func(t *testing.T) { + var ctx, ctxCancel = context.WithCancel(context.Background()) + + c1, c2 := newConnPair(t) + var handshakeResCh = make(chan handshakeRes, 1) + go func() { + identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} + }() + h := newHandshake() + h.conn = c2 + // receive credential message + _, err := h.readMsg() + require.NoError(t, err) + ctxCancel() + res := <-handshakeResCh + assert.EqualError(t, res.err, context.Canceled.Error()) + _, err = c2.Read(make([]byte, 10)) + assert.Error(t, err) + _, err = c2.Write(make([]byte, 10)) + assert.Error(t, err) + }) } func TestIncomingHandshake(t *testing.T) { @@ -215,7 +243,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -241,7 +269,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -252,7 +280,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -267,7 +295,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -281,7 +309,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -301,7 +329,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -323,7 +351,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -342,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -364,7 +392,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -386,7 +414,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -404,13 +432,36 @@ func TestIncomingHandshake(t *testing.T) { res := <-handshakeResCh require.Error(t, res.err) }) + t.Run("context cancel", func(t *testing.T) { + var ctx, ctxCancel = context.WithCancel(context.Background()) + c1, c2 := newConnPair(t) + var handshakeResCh = make(chan handshakeRes, 1) + go func() { + identity, err := IncomingHandshake(ctx, c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} + }() + h := newHandshake() + h.conn = c2 + // write credentials + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + // wait credentials + _, err := h.readMsg() + require.NoError(t, err) + ctxCancel() + res := <-handshakeResCh + require.EqualError(t, res.err, context.Canceled.Error()) + _, err = c2.Read(make([]byte, 10)) + assert.Error(t, err) + _, err = c2.Write(make([]byte, 10)) + assert.Error(t, err) + }) } func TestNotAHandshakeMessage(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -429,11 +480,11 @@ func TestEndToEnd(t *testing.T) { ) st := time.Now() go func() { - identity, err := OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) outResCh <- handshakeRes{identity: identity, err: err} }() go func() { - identity, err := IncomingHandshake(c2, noVerifyChecker) + identity, err := IncomingHandshake(nil, c2, noVerifyChecker) inResCh <- handshakeRes{identity: identity, err: err} }() @@ -457,7 +508,7 @@ func BenchmarkHandshake(b *testing.B) { defer close(done) go func() { for { - _, _ = OutgoingHandshake(c1, noVerifyChecker) + _, _ = OutgoingHandshake(nil, c1, noVerifyChecker) select { case outRes <- struct{}{}: case <-done: @@ -467,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) { }() go func() { for { - _, _ = IncomingHandshake(c2, noVerifyChecker) + _, _ = IncomingHandshake(nil, c2, noVerifyChecker) select { case inRes <- struct{}{}: case <-done: diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index a6cb7544..8f1f4a4d 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -37,7 +37,7 @@ func New() SecureService { } type SecureService interface { - TLSListener(lis net.Listener, timeoutMillis int) ContextListener + TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener BasicListener(lis net.Listener, timeoutMillis int) ContextListener TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) app.Component @@ -76,8 +76,12 @@ func (s *secureService) Name() (name string) { return CName } -func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener { - return newTLSListener(s.key, lis, timeoutMillis) +func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener { + cc := s.noVerifyChecker + if identityHandshake { + cc = s.peerSignVerifier + } + return newTLSListener(cc, s.key, lis, timeoutMillis) } func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener { @@ -98,7 +102,7 @@ func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureC checker = s.noVerifyChecker } // ignore identity for outgoing connection because we don't need it at this moment - _, err = handshake.OutgoingHandshake(sc, checker) + _, err = handshake.OutgoingHandshake(ctx, sc, checker) if err != nil { return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} } diff --git a/net/secureservice/tlslistener.go b/net/secureservice/tlslistener.go index 6abf5742..65c13fe4 100644 --- a/net/secureservice/tlslistener.go +++ b/net/secureservice/tlslistener.go @@ -57,7 +57,7 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C err: err, } } - identity, err := handshake.IncomingHandshake(secure, p.cc) + identity, err := handshake.IncomingHandshake(nil, secure, p.cc) if err != nil { return nil, nil, HandshakeError{ remoteAddr: conn.RemoteAddr().String(),