handshake with ctx
This commit is contained in:
parent
862d5fe693
commit
e93812cdcc
@ -1,6 +1,7 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
||||||
@ -50,8 +51,26 @@ type CredentialChecker interface {
|
|||||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
identity, err = outgoingHandshake(h, sc, cc)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = sc.Close()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
defer h.release()
|
defer h.release()
|
||||||
h.conn = sc
|
h.conn = sc
|
||||||
localCred := cc.MakeCredentials(sc)
|
localCred := cc.MakeCredentials(sc)
|
||||||
@ -99,8 +118,26 @@ func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
defer close(done)
|
||||||
|
identity, err = incomingHandshake(h, sc, cc)
|
||||||
|
}()
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = sc.Close()
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
defer h.release()
|
defer h.release()
|
||||||
h.conn = sc
|
h.conn = sc
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
package handshake
|
package handshake
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
|
||||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||||
peer2 "github.com/anytypeio/any-sync/util/peer"
|
peer2 "github.com/anytypeio/any-sync/util/peer"
|
||||||
@ -11,11 +12,16 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
_ "net/http/pprof"
|
_ "net/http/pprof"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
go http.ListenAndServe(":6060", nil)
|
||||||
|
}
|
||||||
|
|
||||||
var noVerifyChecker = &testCredChecker{
|
var noVerifyChecker = &testCredChecker{
|
||||||
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
||||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||||
@ -33,7 +39,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -60,7 +66,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
@ -71,7 +77,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -87,7 +93,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -103,7 +109,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -122,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -140,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -161,7 +167,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -186,7 +192,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -208,6 +214,28 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
res := <-handshakeResCh
|
res := <-handshakeResCh
|
||||||
require.Error(t, res.err)
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
|
t.Run("context cancel", func(t *testing.T) {
|
||||||
|
var ctx, ctxCancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
c1, c2 := newConnPair(t)
|
||||||
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
|
go func() {
|
||||||
|
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
|
}()
|
||||||
|
h := newHandshake()
|
||||||
|
h.conn = c2
|
||||||
|
// receive credential message
|
||||||
|
_, err := h.readMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctxCancel()
|
||||||
|
res := <-handshakeResCh
|
||||||
|
assert.EqualError(t, res.err, context.Canceled.Error())
|
||||||
|
_, err = c2.Read(make([]byte, 10))
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = c2.Write(make([]byte, 10))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIncomingHandshake(t *testing.T) {
|
func TestIncomingHandshake(t *testing.T) {
|
||||||
@ -215,7 +243,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -241,7 +269,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
@ -252,7 +280,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -267,7 +295,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -281,7 +309,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -301,7 +329,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -323,7 +351,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -342,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -364,7 +392,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -386,7 +414,7 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -404,13 +432,36 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
res := <-handshakeResCh
|
res := <-handshakeResCh
|
||||||
require.Error(t, res.err)
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
|
t.Run("context cancel", func(t *testing.T) {
|
||||||
|
var ctx, ctxCancel = context.WithCancel(context.Background())
|
||||||
|
c1, c2 := newConnPair(t)
|
||||||
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
|
go func() {
|
||||||
|
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
|
}()
|
||||||
|
h := newHandshake()
|
||||||
|
h.conn = c2
|
||||||
|
// write credentials
|
||||||
|
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||||
|
// wait credentials
|
||||||
|
_, err := h.readMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
ctxCancel()
|
||||||
|
res := <-handshakeResCh
|
||||||
|
require.EqualError(t, res.err, context.Canceled.Error())
|
||||||
|
_, err = c2.Read(make([]byte, 10))
|
||||||
|
assert.Error(t, err)
|
||||||
|
_, err = c2.Write(make([]byte, 10))
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotAHandshakeMessage(t *testing.T) {
|
func TestNotAHandshakeMessage(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var handshakeResCh = make(chan handshakeRes, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
@ -429,11 +480,11 @@ func TestEndToEnd(t *testing.T) {
|
|||||||
)
|
)
|
||||||
st := time.Now()
|
st := time.Now()
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
outResCh <- handshakeRes{identity: identity, err: err}
|
outResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
identity, err := IncomingHandshake(c2, noVerifyChecker)
|
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
|
||||||
inResCh <- handshakeRes{identity: identity, err: err}
|
inResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -457,7 +508,7 @@ func BenchmarkHandshake(b *testing.B) {
|
|||||||
defer close(done)
|
defer close(done)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
_, _ = OutgoingHandshake(c1, noVerifyChecker)
|
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||||
select {
|
select {
|
||||||
case outRes <- struct{}{}:
|
case outRes <- struct{}{}:
|
||||||
case <-done:
|
case <-done:
|
||||||
@ -467,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
_, _ = IncomingHandshake(c2, noVerifyChecker)
|
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
|
||||||
select {
|
select {
|
||||||
case inRes <- struct{}{}:
|
case inRes <- struct{}{}:
|
||||||
case <-done:
|
case <-done:
|
||||||
|
|||||||
@ -37,7 +37,7 @@ func New() SecureService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type SecureService interface {
|
type SecureService interface {
|
||||||
TLSListener(lis net.Listener, timeoutMillis int) ContextListener
|
TLSListener(lis net.Listener, timeoutMillis int, withIdentityCheck bool) ContextListener
|
||||||
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
|
BasicListener(lis net.Listener, timeoutMillis int) ContextListener
|
||||||
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
|
||||||
app.Component
|
app.Component
|
||||||
@ -76,8 +76,12 @@ func (s *secureService) Name() (name string) {
|
|||||||
return CName
|
return CName
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener {
|
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int, identityHandshake bool) ContextListener {
|
||||||
return newTLSListener(s.key, lis, timeoutMillis)
|
cc := s.noVerifyChecker
|
||||||
|
if identityHandshake {
|
||||||
|
cc = s.peerSignVerifier
|
||||||
|
}
|
||||||
|
return newTLSListener(cc, s.key, lis, timeoutMillis)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
|
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
|
||||||
@ -98,7 +102,7 @@ func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureC
|
|||||||
checker = s.noVerifyChecker
|
checker = s.noVerifyChecker
|
||||||
}
|
}
|
||||||
// ignore identity for outgoing connection because we don't need it at this moment
|
// ignore identity for outgoing connection because we don't need it at this moment
|
||||||
_, err = handshake.OutgoingHandshake(sc, checker)
|
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,7 +57,7 @@ func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.C
|
|||||||
err: err,
|
err: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
identity, err := handshake.IncomingHandshake(secure, p.cc)
|
identity, err := handshake.IncomingHandshake(nil, secure, p.cc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, HandshakeError{
|
return nil, nil, HandshakeError{
|
||||||
remoteAddr: conn.RemoteAddr().String(),
|
remoteAddr: conn.RemoteAddr().String(),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user