handshake with ctx

This commit is contained in:
Sergey Cherepanov 2023-02-15 21:29:37 +03:00 committed by Mikhail Iudin
parent 862d5fe693
commit e93812cdcc
No known key found for this signature in database
GPG Key ID: FAAAA8BAABDFF1C0
4 changed files with 123 additions and 31 deletions

View File

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
@ -50,8 +51,26 @@ type CredentialChecker interface {
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
} }
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake() h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = sc
localCred := cc.MakeCredentials(sc) localCred := cc.MakeCredentials(sc)
@ -99,8 +118,26 @@ func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte
} }
} }
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake() h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = sc

View File

@ -1,6 +1,7 @@
package handshake package handshake
import ( import (
"context"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey" "github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
peer2 "github.com/anytypeio/any-sync/util/peer" peer2 "github.com/anytypeio/any-sync/util/peer"
@ -11,11 +12,16 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
"net/http"
_ "net/http/pprof" _ "net/http/pprof"
"testing" "testing"
"time" "time"
) )
func init() {
go http.ListenAndServe(":6060", nil)
}
var noVerifyChecker = &testCredChecker{ var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
@ -33,7 +39,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -60,7 +66,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -71,7 +77,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -87,7 +93,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -103,7 +109,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -122,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -140,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -161,7 +167,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -186,7 +192,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -208,6 +214,28 @@ func TestOutgoingHandshake(t *testing.T) {
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
assert.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
} }
func TestIncomingHandshake(t *testing.T) { func TestIncomingHandshake(t *testing.T) {
@ -215,7 +243,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -241,7 +269,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -252,7 +280,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -267,7 +295,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -281,7 +309,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -301,7 +329,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -323,7 +351,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -342,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -364,7 +392,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -386,7 +414,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -404,13 +432,36 @@ func TestIncomingHandshake(t *testing.T) {
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
require.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
} }
func TestNotAHandshakeMessage(t *testing.T) { func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -429,11 +480,11 @@ func TestEndToEnd(t *testing.T) {
) )
st := time.Now() st := time.Now()
go func() { go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err} outResCh <- handshakeRes{identity: identity, err: err}
}() }()
go func() { go func() {
identity, err := IncomingHandshake(c2, noVerifyChecker) identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err} inResCh <- handshakeRes{identity: identity, err: err}
}() }()
@ -457,7 +508,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done) defer close(done)
go func() { go func() {
for { for {
_, _ = OutgoingHandshake(c1, noVerifyChecker) _, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
select { select {
case outRes <- struct{}{}: case outRes <- struct{}{}:
case <-done: case <-done:
@ -467,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
}() }()
go func() { go func() {
for { for {
_, _ = IncomingHandshake(c2, noVerifyChecker) _, _ = IncomingHandshake(nil, c2, noVerifyChecker)
select { select {
case inRes <- struct{}{}: case inRes <- struct{}{}:
case <-done: case <-done:

View File

@ -37,7 +37,7 @@ func New() SecureService {
} }
type SecureService interface { type SecureService interface {
TLSListener(lis net.Listener, timeoutMillis int) ContextListener TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
BasicListener(lis net.Listener, timeoutMillis int) ContextListener BasicListener(lis net.Listener, timeoutMillis int) ContextListener
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
app.Component app.Component
@ -76,8 +76,12 @@ func (s *secureService) Name() (name string) {
return CName return CName
} }
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener { func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
return newTLSListener(s.key, lis, timeoutMillis) cc := s.noVerifyChecker
if identityHandshake {
cc = s.peerSignVerifier
}
return newTLSListener(cc, s.key, lis, timeoutMillis)
} }
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener { func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
@ -98,7 +102,7 @@ func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureC
checker = s.noVerifyChecker checker = s.noVerifyChecker
} }
// ignore identity for outgoing connection because we don't need it at this moment // ignore identity for outgoing connection because we don't need it at this moment
_, err = handshake.OutgoingHandshake(sc, checker) _, err = handshake.OutgoingHandshake(ctx, sc, checker)
if err != nil { if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
} }

View File

@ -57,7 +57,7 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
err: err, err: err,
} }
} }
identity, err := handshake.IncomingHandshake(secure, p.cc) identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
if err != nil { if err != nil {
return nil, nil, HandshakeError{ return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(), remoteAddr: conn.RemoteAddr().String(),