handshake: return identity
This commit is contained in:
parent
e8a7ad7476
commit
2fe9f8c295
@ -47,59 +47,59 @@ var handshakePool = &sync.Pool{New: func() any {
|
|||||||
|
|
||||||
type CredentialChecker interface {
|
type CredentialChecker interface {
|
||||||
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
|
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
|
||||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
|
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
defer h.release()
|
defer h.release()
|
||||||
h.conn = sc
|
h.conn = sc
|
||||||
localCred := cc.MakeCredentials(sc)
|
localCred := cc.MakeCredentials(sc)
|
||||||
if err = h.writeCredentials(localCred); err != nil {
|
if err = h.writeCredentials(localCred); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if msg.ack != nil {
|
if msg.ack != nil {
|
||||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||||
return ErrPeerDeclinedCredentials
|
return nil, ErrPeerDeclinedCredentials
|
||||||
}
|
}
|
||||||
return handshakeError{e: msg.ack.Error}
|
return nil, handshakeError{e: msg.ack.Error}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = cc.CheckCredential(sc, msg.cred); err != nil {
|
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err = h.readMsg()
|
msg, err = h.readMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if msg.ack == nil {
|
if msg.ack == nil {
|
||||||
err = ErrUnexpectedPayload
|
err = ErrUnexpectedPayload
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if msg.ack.Error == handshakeproto.Error_Null {
|
if msg.ack.Error == handshakeproto.Error_Null {
|
||||||
return nil
|
return identity, nil
|
||||||
} else {
|
} else {
|
||||||
_ = h.conn.Close()
|
_ = h.conn.Close()
|
||||||
return handshakeError{e: msg.ack.Error}
|
return nil, handshakeError{e: msg.ack.Error}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
defer h.release()
|
defer h.release()
|
||||||
h.conn = sc
|
h.conn = sc
|
||||||
@ -107,40 +107,40 @@ func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
|
|||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
if msg.ack != nil {
|
if msg.ack != nil {
|
||||||
return ErrUnexpectedPayload
|
return nil, ErrUnexpectedPayload
|
||||||
}
|
}
|
||||||
if err = cc.CheckCredential(sc, msg.cred); err != nil {
|
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
msg, err = h.readMsg()
|
msg, err = h.readMsg()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if msg.ack == nil {
|
if msg.ack == nil {
|
||||||
err = ErrUnexpectedPayload
|
err = ErrUnexpectedPayload
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
if msg.ack.Error != handshakeproto.Error_Null {
|
if msg.ack.Error != handshakeproto.Error_Null {
|
||||||
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
|
||||||
return ErrPeerDeclinedCredentials
|
return nil, ErrPeerDeclinedCredentials
|
||||||
}
|
}
|
||||||
return handshakeError{e: msg.ack.Error}
|
return nil, handshakeError{e: msg.ack.Error}
|
||||||
}
|
}
|
||||||
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
|
||||||
h.tryWriteErrAndClose(err)
|
h.tryWriteErrAndClose(err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@ -18,17 +18,23 @@ import (
|
|||||||
|
|
||||||
var noVerifyChecker = &testCredChecker{
|
var noVerifyChecker = &testCredChecker{
|
||||||
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
||||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
|
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||||
return
|
return []byte("identity"), nil
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type handshakeRes struct {
|
||||||
|
identity []byte
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
func TestOutgoingHandshake(t *testing.T) {
|
func TestOutgoingHandshake(t *testing.T) {
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -36,7 +42,8 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, msg.ack)
|
require.Nil(t, msg.ack)
|
||||||
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
|
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||||
|
require.NoError(t, err)
|
||||||
// send credential message
|
// send credential message
|
||||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||||
// receive ack
|
// receive ack
|
||||||
@ -45,23 +52,27 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||||
// send ack
|
// send ack
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||||
resErr := <-hanshareResCh
|
res := <-handshakeResCh
|
||||||
assert.NoError(t, resErr)
|
assert.NotEmpty(t, res.identity)
|
||||||
|
assert.NoError(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("write cred err", func(t *testing.T) {
|
t.Run("write cred err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("read cred err", func(t *testing.T) {
|
t.Run("read cred err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -69,13 +80,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
_, err := h.readMsg()
|
_, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("ack err", func(t *testing.T) {
|
t.Run("ack err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -83,13 +96,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
_, err := h.readMsg()
|
_, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
|
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
|
||||||
require.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
|
res := <-handshakeResCh
|
||||||
|
require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
|
||||||
})
|
})
|
||||||
t.Run("cred err", func(t *testing.T) {
|
t.Run("cred err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -100,13 +115,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
|
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
|
||||||
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
|
res := <-handshakeResCh
|
||||||
|
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
|
||||||
})
|
})
|
||||||
t.Run("write ack err", func(t *testing.T) {
|
t.Run("write ack err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -116,13 +133,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
// write credentials and close conn
|
// write credentials and close conn
|
||||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("read ack err", func(t *testing.T) {
|
t.Run("read ack err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -135,13 +154,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
_, err = h.readMsg()
|
_, err = h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("write cred instead ack", func(t *testing.T) {
|
t.Run("write cred instead ack", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -158,13 +179,15 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("final ack error", func(t *testing.T) {
|
t.Run("final ack error", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -172,7 +195,8 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Nil(t, msg.ack)
|
require.Nil(t, msg.ack)
|
||||||
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
|
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||||
|
require.NoError(t, err)
|
||||||
// send credential message
|
// send credential message
|
||||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||||
// receive ack
|
// receive ack
|
||||||
@ -181,17 +205,18 @@ func TestOutgoingHandshake(t *testing.T) {
|
|||||||
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||||
// send ack
|
// send ack
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
|
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
|
||||||
resErr := <-hanshareResCh
|
res := <-handshakeResCh
|
||||||
assert.Error(t, resErr)
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIncomingHandshake(t *testing.T) {
|
func TestIncomingHandshake(t *testing.T) {
|
||||||
t.Run("success", func(t *testing.T) {
|
t.Run("success", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -208,47 +233,56 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
msg, err = h.readMsg()
|
msg, err = h.readMsg()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
|
||||||
require.NoError(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
assert.NotEmpty(t, res.identity)
|
||||||
|
require.NoError(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("write cred err", func(t *testing.T) {
|
t.Run("write cred err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("read cred err", func(t *testing.T) {
|
t.Run("read cred err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
// write credentials and close conn
|
// write credentials and close conn
|
||||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("write ack instead cred", func(t *testing.T) {
|
t.Run("write ack instead cred", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
// write ack instead cred
|
// write ack instead cred
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("invalid cred", func(t *testing.T) {
|
t.Run("invalid cred", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -260,13 +294,15 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
require.Nil(t, msg.cred)
|
require.Nil(t, msg.cred)
|
||||||
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
|
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
|
||||||
|
|
||||||
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
|
res := <-handshakeResCh
|
||||||
|
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
|
||||||
})
|
})
|
||||||
t.Run("write cred instead ack", func(t *testing.T) {
|
t.Run("write cred instead ack", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -280,13 +316,15 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
// expect ack with error
|
// expect ack with error
|
||||||
msg, err := h.readMsg()
|
msg, err := h.readMsg()
|
||||||
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("read ack err", func(t *testing.T) {
|
t.Run("read ack err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -297,13 +335,15 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
|
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
t.Run("write ack with invalid", func(t *testing.T) {
|
t.Run("write ack with invalid", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -317,13 +357,15 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
// write ack
|
// write ack
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
|
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
|
||||||
|
|
||||||
assert.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
|
res := <-handshakeResCh
|
||||||
|
assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
|
||||||
})
|
})
|
||||||
t.Run("write ack with err", func(t *testing.T) {
|
t.Run("write ack with err", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -337,13 +379,15 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
// write ack
|
// write ack
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
|
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
|
||||||
|
|
||||||
assert.EqualError(t, <-hanshareResCh, ErrUnexpected.Error())
|
res := <-handshakeResCh
|
||||||
|
assert.EqualError(t, res.err, ErrUnexpected.Error())
|
||||||
})
|
})
|
||||||
t.Run("final ack error", func(t *testing.T) {
|
t.Run("final ack error", func(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
@ -357,53 +401,65 @@ func TestIncomingHandshake(t *testing.T) {
|
|||||||
// write ack and close conn
|
// write ack and close conn
|
||||||
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
|
||||||
_ = c2.Close()
|
_ = c2.Close()
|
||||||
require.Error(t, <-hanshareResCh)
|
res := <-handshakeResCh
|
||||||
|
require.Error(t, res.err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNotAHandshakeMessage(t *testing.T) {
|
func TestNotAHandshakeMessage(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var hanshareResCh = make(chan error, 1)
|
var handshakeResCh = make(chan handshakeRes, 1)
|
||||||
go func() {
|
go func() {
|
||||||
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
|
identity, err := IncomingHandshake(c1, noVerifyChecker)
|
||||||
|
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
h := newHandshake()
|
h := newHandshake()
|
||||||
h.conn = c2
|
h.conn = c2
|
||||||
_, err := c2.Write([]byte("some unexpected bytes"))
|
_, err := c2.Write([]byte("some unexpected bytes"))
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.EqualError(t, <-hanshareResCh, ErrGotNotAHandshakeMessage.Error())
|
res := <-handshakeResCh
|
||||||
|
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEndToEnd(t *testing.T) {
|
func TestEndToEnd(t *testing.T) {
|
||||||
c1, c2 := newConnPair(t)
|
c1, c2 := newConnPair(t)
|
||||||
var (
|
var (
|
||||||
inRes = make(chan error, 1)
|
inResCh = make(chan handshakeRes, 1)
|
||||||
outRes = make(chan error, 1)
|
outResCh = make(chan handshakeRes, 1)
|
||||||
)
|
)
|
||||||
st := time.Now()
|
st := time.Now()
|
||||||
go func() {
|
go func() {
|
||||||
outRes <- OutgoingHandshake(c1, noVerifyChecker)
|
identity, err := OutgoingHandshake(c1, noVerifyChecker)
|
||||||
|
outResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
inRes <- IncomingHandshake(c2, noVerifyChecker)
|
identity, err := IncomingHandshake(c2, noVerifyChecker)
|
||||||
|
inResCh <- handshakeRes{identity: identity, err: err}
|
||||||
}()
|
}()
|
||||||
assert.NoError(t, <-outRes)
|
|
||||||
assert.NoError(t, <-inRes)
|
outRes := <-outResCh
|
||||||
|
assert.NoError(t, outRes.err)
|
||||||
|
assert.NotEmpty(t, outRes.identity)
|
||||||
|
|
||||||
|
inRes := <-inResCh
|
||||||
|
assert.NoError(t, inRes.err)
|
||||||
|
assert.NotEmpty(t, inRes.identity)
|
||||||
t.Log("dur", time.Since(st))
|
t.Log("dur", time.Since(st))
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkHandshake(b *testing.B) {
|
func BenchmarkHandshake(b *testing.B) {
|
||||||
c1, c2 := newConnPair(b)
|
c1, c2 := newConnPair(b)
|
||||||
var (
|
var (
|
||||||
inRes = make(chan error)
|
inRes = make(chan struct{})
|
||||||
outRes = make(chan error)
|
outRes = make(chan struct{})
|
||||||
done = make(chan struct{})
|
done = make(chan struct{})
|
||||||
)
|
)
|
||||||
defer close(done)
|
defer close(done)
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
_, _ = OutgoingHandshake(c1, noVerifyChecker)
|
||||||
select {
|
select {
|
||||||
case outRes <- OutgoingHandshake(c1, noVerifyChecker):
|
case outRes <- struct{}{}:
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -411,8 +467,9 @@ func BenchmarkHandshake(b *testing.B) {
|
|||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
|
_, _ = IncomingHandshake(c2, noVerifyChecker)
|
||||||
select {
|
select {
|
||||||
case inRes <- IncomingHandshake(c2, noVerifyChecker):
|
case inRes <- struct{}{}:
|
||||||
case <-done:
|
case <-done:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -430,7 +487,7 @@ func BenchmarkHandshake(b *testing.B) {
|
|||||||
|
|
||||||
type testCredChecker struct {
|
type testCredChecker struct {
|
||||||
makeCred *handshakeproto.Credentials
|
makeCred *handshakeproto.Credentials
|
||||||
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
|
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||||
checkErr error
|
checkErr error
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -438,14 +495,14 @@ func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cre
|
|||||||
return t.makeCred
|
return t.makeCred
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
|
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||||
if t.checkErr != nil {
|
if t.checkErr != nil {
|
||||||
return t.checkErr
|
return nil, t.checkErr
|
||||||
}
|
}
|
||||||
if t.checkCred != nil {
|
if t.checkCred != nil {
|
||||||
return t.checkCred(sc, cred)
|
return t.checkCred(sc, cred)
|
||||||
}
|
}
|
||||||
return nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {
|
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user