handshake: return identity
This commit is contained in:
parent
e8a7ad7476
commit
2fe9f8c295
@ -47,59 +47,59 @@ var handshakePool = &sync.Pool{New: func() any {
|
||||
|
||||
type CredentialChecker interface {
|
||||
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
|
||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
|
||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
}
|
||||
|
||||
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
||||
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
h := newHandshake()
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
localCred := cc.MakeCredentials(sc)
|
||||
if err = h.writeCredentials(localCred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return
|
||||
}
|
||||
msg, err := h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return
|
||||
}
|
||||
if msg.ack != nil {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return ErrPeerDeclinedCredentials
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return handshakeError{e: msg.ack.Error}
|
||||
return nil, handshakeError{e: msg.ack.Error}
|
||||
}
|
||||
|
||||
if err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack == nil {
|
||||
err = ErrUnexpectedPayload
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error == handshakeproto.Error_Null {
|
||||
return nil
|
||||
return identity, nil
|
||||
} else {
|
||||
_ = h.conn.Close()
|
||||
return handshakeError{e: msg.ack.Error}
|
||||
return nil, handshakeError{e: msg.ack.Error}
|
||||
}
|
||||
}
|
||||
|
||||
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
||||
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
h := newHandshake()
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
@ -107,40 +107,40 @@ func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
||||
msg, err := h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return
|
||||
}
|
||||
if msg.ack != nil {
|
||||
return ErrUnexpectedPayload
|
||||
return nil, ErrUnexpectedPayload
|
||||
}
|
||||
if err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err = h.readMsg()
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack == nil {
|
||||
err = ErrUnexpectedPayload
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if msg.ack.Error != handshakeproto.Error_Null {
|
||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||
return ErrPeerDeclinedCredentials
|
||||
return nil, ErrPeerDeclinedCredentials
|
||||
}
|
||||
return handshakeError{e: msg.ack.Error}
|
||||
return nil, handshakeError{e: msg.ack.Error}
|
||||
}
|
||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -18,17 +18,23 @@ import (
|
||||
|
||||
var noVerifyChecker = &testCredChecker{
|
||||
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
|
||||
return
|
||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
return []byte("identity"), nil
|
||||
},
|
||||
}
|
||||
|
||||
type handshakeRes struct {
|
||||
identity []byte
|
||||
err error
|
||||
}
|
||||
|
||||
func TestOutgoingHandshake(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -36,7 +42,8 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
msg, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
require.NoError(t, err)
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// receive ack
|
||||
@ -45,23 +52,27 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
// send ack
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||
resErr := <-hanshareResCh
|
||||
assert.NoError(t, resErr)
|
||||
res := <-handshakeResCh
|
||||
assert.NotEmpty(t, res.identity)
|
||||
assert.NoError(t, res.err)
|
||||
})
|
||||
t.Run("write cred err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("read cred err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -69,13 +80,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("ack err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -83,13 +96,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
|
||||
require.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
|
||||
res := <-handshakeResCh
|
||||
require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
|
||||
})
|
||||
t.Run("cred err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -100,13 +115,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
msg, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
|
||||
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
|
||||
res := <-handshakeResCh
|
||||
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
|
||||
})
|
||||
t.Run("write ack err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -116,13 +133,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
// write credentials and close conn
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("read ack err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -135,13 +154,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err = h.readMsg()
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("write cred instead ack", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -158,13 +179,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
msg, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("final ack error", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -172,7 +195,8 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
msg, err := h.readMsg()
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
require.NoError(t, err)
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
// receive ack
|
||||
@ -181,17 +205,18 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
// send ack
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
|
||||
resErr := <-hanshareResCh
|
||||
assert.Error(t, resErr)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestIncomingHandshake(t *testing.T) {
|
||||
t.Run("success", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -208,47 +233,56 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
msg, err = h.readMsg()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||
require.NoError(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
assert.NotEmpty(t, res.identity)
|
||||
require.NoError(t, res.err)
|
||||
})
|
||||
t.Run("write cred err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("read cred err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials and close conn
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("write ack instead cred", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write ack instead cred
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("invalid cred", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -260,13 +294,15 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
require.Nil(t, msg.cred)
|
||||
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
|
||||
|
||||
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
|
||||
res := <-handshakeResCh
|
||||
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
|
||||
})
|
||||
t.Run("write cred instead ack", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -280,13 +316,15 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// expect ack with error
|
||||
msg, err := h.readMsg()
|
||||
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("read ack err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -297,13 +335,15 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
_ = c2.Close()
|
||||
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
t.Run("write ack with invalid", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -317,13 +357,15 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write ack
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
|
||||
|
||||
assert.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
|
||||
res := <-handshakeResCh
|
||||
assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
|
||||
})
|
||||
t.Run("write ack with err", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -337,13 +379,15 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write ack
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
|
||||
|
||||
assert.EqualError(t, <-hanshareResCh, ErrUnexpected.Error())
|
||||
res := <-handshakeResCh
|
||||
assert.EqualError(t, res.err, ErrUnexpected.Error())
|
||||
})
|
||||
t.Run("final ack error", func(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
@ -357,53 +401,65 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
// write ack and close conn
|
||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||
_ = c2.Close()
|
||||
require.Error(t, <-hanshareResCh)
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestNotAHandshakeMessage(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var hanshareResCh = make(chan error, 1)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
_, err := c2.Write([]byte("some unexpected bytes"))
|
||||
require.Error(t, err)
|
||||
assert.EqualError(t, <-hanshareResCh, ErrGotNotAHandshakeMessage.Error())
|
||||
res := <-handshakeResCh
|
||||
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
|
||||
}
|
||||
|
||||
func TestEndToEnd(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var (
|
||||
inRes = make(chan error, 1)
|
||||
outRes = make(chan error, 1)
|
||||
inResCh = make(chan handshakeRes, 1)
|
||||
outResCh = make(chan handshakeRes, 1)
|
||||
)
|
||||
st := time.Now()
|
||||
go func() {
|
||||
outRes <- OutgoingHandshake(c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||
outResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
go func() {
|
||||
inRes <- IncomingHandshake(c2, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(c2, noVerifyChecker)
|
||||
inResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
assert.NoError(t, <-outRes)
|
||||
assert.NoError(t, <-inRes)
|
||||
|
||||
outRes := <-outResCh
|
||||
assert.NoError(t, outRes.err)
|
||||
assert.NotEmpty(t, outRes.identity)
|
||||
|
||||
inRes := <-inResCh
|
||||
assert.NoError(t, inRes.err)
|
||||
assert.NotEmpty(t, inRes.identity)
|
||||
t.Log("dur", time.Since(st))
|
||||
}
|
||||
|
||||
func BenchmarkHandshake(b *testing.B) {
|
||||
c1, c2 := newConnPair(b)
|
||||
var (
|
||||
inRes = make(chan error)
|
||||
outRes = make(chan error)
|
||||
inRes = make(chan struct{})
|
||||
outRes = make(chan struct{})
|
||||
done = make(chan struct{})
|
||||
)
|
||||
defer close(done)
|
||||
go func() {
|
||||
for {
|
||||
_, _ = OutgoingHandshake(c1, noVerifyChecker)
|
||||
select {
|
||||
case outRes <- OutgoingHandshake(c1, noVerifyChecker):
|
||||
case outRes <- struct{}{}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
@ -411,8 +467,9 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
_, _ = IncomingHandshake(c2, noVerifyChecker)
|
||||
select {
|
||||
case inRes <- IncomingHandshake(c2, noVerifyChecker):
|
||||
case inRes <- struct{}{}:
|
||||
case <-done:
|
||||
return
|
||||
}
|
||||
@ -430,7 +487,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
|
||||
type testCredChecker struct {
|
||||
makeCred *handshakeproto.Credentials
|
||||
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
|
||||
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
checkErr error
|
||||
}
|
||||
|
||||
@ -438,14 +495,14 @@ func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cre
|
||||
return t.makeCred
|
||||
}
|
||||
|
||||
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
|
||||
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
if t.checkErr != nil {
|
||||
return t.checkErr
|
||||
return nil, t.checkErr
|
||||
}
|
||||
if t.checkCred != nil {
|
||||
return t.checkCred(sc, cred)
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user