handshake with ctx

This commit is contained in:
Sergey Cherepanov 2023-02-15 21:29:37 +03:00 committed by Mikhail Iudin
parent 862d5fe693
commit e93812cdcc
No known key found for this signature in database
GPG Key ID: FAAAA8BAABDFF1C0
4 changed files with 123 additions and 31 deletions

View File

@ -1,6 +1,7 @@
package handshake
import (
"context"
"encoding/binary"
"errors"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
@ -50,8 +51,26 @@ type CredentialChecker interface {
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
@ -99,8 +118,26 @@ func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte
}
}
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc

View File

@ -1,6 +1,7 @@
package handshake
import (
"context"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
peer2 "github.com/anytypeio/any-sync/util/peer"
@ -11,11 +12,16 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"net/http"
_ "net/http/pprof"
"testing"
"time"
)
func init() {
go http.ListenAndServe(":6060", nil)
}
var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
@ -33,7 +39,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -60,7 +66,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -71,7 +77,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -87,7 +93,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -103,7 +109,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -122,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -140,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -161,7 +167,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -186,7 +192,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -208,6 +214,28 @@ func TestOutgoingHandshake(t *testing.T) {
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
assert.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
}
func TestIncomingHandshake(t *testing.T) {
@ -215,7 +243,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -241,7 +269,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -252,7 +280,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -267,7 +295,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -281,7 +309,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -301,7 +329,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -323,7 +351,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -342,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -364,7 +392,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -386,7 +414,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -404,13 +432,36 @@ func TestIncomingHandshake(t *testing.T) {
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
require.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
}
func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -429,11 +480,11 @@ func TestEndToEnd(t *testing.T) {
)
st := time.Now()
go func() {
identity, err := OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}()
go func() {
identity, err := IncomingHandshake(c2, noVerifyChecker)
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}()
@ -457,7 +508,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done)
go func() {
for {
_, _ = OutgoingHandshake(c1, noVerifyChecker)
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
select {
case outRes <- struct{}{}:
case <-done:
@ -467,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
}()
go func() {
for {
_, _ = IncomingHandshake(c2, noVerifyChecker)
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
select {
case inRes <- struct{}{}:
case <-done:

View File

@ -37,7 +37,7 @@ func New() SecureService {
}
type SecureService interface {
TLSListener(lis net.Listener, timeoutMillis int) ContextListener
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
app.Component
@ -76,8 +76,12 @@ func (s *secureService) Name() (name string) {
return CName
}
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener {
return newTLSListener(s.key, lis, timeoutMillis)
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
cc := s.noVerifyChecker
if identityHandshake {
cc = s.peerSignVerifier
}
return newTLSListener(cc, s.key, lis, timeoutMillis)
}
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
@ -98,7 +102,7 @@ func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureC
checker = s.noVerifyChecker
}
// ignore identity for outgoing connection because we don't need it at this moment
_, err = handshake.OutgoingHandshake(sc, checker)
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
}

View File

@ -57,7 +57,7 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
err: err,
}
}
identity, err := handshake.IncomingHandshake(secure, p.cc)
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),