handshake with ctx
This commit is contained in:
parent
862d5fe693
commit
e93812cdcc
@ -1,6 +1,7 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
@ -50,8 +51,26 @@ type CredentialChecker interface {
|
||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
}
|
||||
|
||||
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = outgoingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
localCred := cc.MakeCredentials(sc)
|
||||
@ -99,8 +118,26 @@ func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte
|
||||
}
|
||||
}
|
||||
|
||||
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = incomingHandshake(h, sc, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
package handshake
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
peer2 "github.com/anytypeio/any-sync/util/peer"
|
||||
@ -11,11 +12,16 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
"net/http"
|
||||
_ "net/http/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
go http.ListenAndServe(":6060", nil)
|
||||
}
|
||||
|
||||
var noVerifyChecker = &testCredChecker{
|
||||
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
@ -33,7 +39,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -60,7 +66,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
@ -71,7 +77,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -87,7 +93,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -103,7 +109,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -122,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -140,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -161,7 +167,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -186,7 +192,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -208,6 +214,28 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("context cancel", func(t *testing.T) {
|
||||
var ctx, ctxCancel = context.WithCancel(context.Background())
|
||||
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// receive credential message
|
||||
_, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
ctxCancel()
|
||||
res := <-handshakeResCh
|
||||
assert.EqualError(t, res.err, context.Canceled.Error())
|
||||
_, err = c2.Read(make([]byte, 10))
|
||||
assert.Error(t, err)
|
||||
_, err = c2.Write(make([]byte, 10))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIncomingHandshake(t *testing.T) {
|
||||
@ -215,7 +243,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -241,7 +269,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
@ -252,7 +280,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -267,7 +295,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -281,7 +309,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -301,7 +329,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -323,7 +351,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -342,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -364,7 +392,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -386,7 +414,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -404,13 +432,36 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("context cancel", func(t *testing.T) {
|
||||
var ctx, ctxCancel = context.WithCancel(context.Background())
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// wait credentials
|
||||
_, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
ctxCancel()
|
||||
res := <-handshakeResCh
|
||||
require.EqualError(t, res.err, context.Canceled.Error())
|
||||
_, err = c2.Read(make([]byte, 10))
|
||||
assert.Error(t, err)
|
||||
_, err = c2.Write(make([]byte, 10))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotAHandshakeMessage(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -429,11 +480,11 @@ func TestEndToEnd(t *testing.T) {
|
||||
)
|
||||
st := time.Now()
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
outResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(c2, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
|
||||
inResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
|
||||
@ -457,7 +508,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
defer close(done)
|
||||
go func() {
|
||||
for {
|
||||
_, _ = OutgoingHandshake(c1, noVerifyChecker)
|
||||
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
select {
|
||||
case outRes <- struct{}{}:
|
||||
case <-done:
|
||||
@ -467,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
_, _ = IncomingHandshake(c2, noVerifyChecker)
|
||||
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
|
||||
select {
|
||||
case inRes <- struct{}{}:
|
||||
case <-done:
|
||||
|
||||
@ -37,7 +37,7 @@ func New() SecureService {
|
||||
}
|
||||
|
||||
type SecureService interface {
|
||||
TLSListener(lis net.Listener, timeoutMillis int) ContextListener
|
||||
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
|
||||
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
|
||||
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||
app.Component
|
||||
@ -76,8 +76,12 @@ func (s *secureService) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener {
|
||||
return newTLSListener(s.key, lis, timeoutMillis)
|
||||
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
|
||||
cc := s.noVerifyChecker
|
||||
if identityHandshake {
|
||||
cc = s.peerSignVerifier
|
||||
}
|
||||
return newTLSListener(cc, s.key, lis, timeoutMillis)
|
||||
}
|
||||
|
||||
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
|
||||
@ -98,7 +102,7 @@ func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureC
|
||||
checker = s.noVerifyChecker
|
||||
}
|
||||
// ignore identity for outgoing connection because we don't need it at this moment
|
||||
_, err = handshake.OutgoingHandshake(sc, checker)
|
||||
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
|
||||
if err != nil {
|
||||
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
||||
}
|
||||
|
||||
@ -57,7 +57,7 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
|
||||
err: err,
|
||||
}
|
||||
}
|
||||
identity, err := handshake.IncomingHandshake(secure, p.cc)
|
||||
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
|
||||
if err != nil {
|
||||
return nil, nil, HandshakeError{
|
||||
remoteAddr: conn.RemoteAddr().String(),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user