handshake: return identity

This commit is contained in:
Sergey Cherepanov 2023-02-13 13:45:07 +03:00
parent e8a7ad7476
commit 2fe9f8c295
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
2 changed files with 161 additions and 104 deletions

View File

@ -47,59 +47,59 @@ var handshakePool = &sync.Pool{New: func() any {
type CredentialChecker interface { type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
} }
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) { func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
h := newHandshake() h := newHandshake()
defer h.release() defer h.release()
h.conn = sc h.conn = sc
localCred := cc.MakeCredentials(sc) localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil { if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return
} }
msg, err := h.readMsg() msg, err := h.readMsg()
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return
} }
if msg.ack != nil { if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials { if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return ErrPeerDeclinedCredentials return nil, ErrPeerDeclinedCredentials
} }
return handshakeError{e: msg.ack.Error} return nil, handshakeError{e: msg.ack.Error}
} }
if err = cc.CheckCredential(sc, msg.cred); err != nil { if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return
} }
if err = h.writeAck(handshakeproto.Error_Null); err != nil { if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
msg, err = h.readMsg() msg, err = h.readMsg()
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
if msg.ack == nil { if msg.ack == nil {
err = ErrUnexpectedPayload err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
if msg.ack.Error == handshakeproto.Error_Null { if msg.ack.Error == handshakeproto.Error_Null {
return nil return identity, nil
} else { } else {
_ = h.conn.Close() _ = h.conn.Close()
return handshakeError{e: msg.ack.Error} return nil, handshakeError{e: msg.ack.Error}
} }
} }
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) { func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
h := newHandshake() h := newHandshake()
defer h.release() defer h.release()
h.conn = sc h.conn = sc
@ -107,40 +107,40 @@ func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
msg, err := h.readMsg() msg, err := h.readMsg()
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return
} }
if msg.ack != nil { if msg.ack != nil {
return ErrUnexpectedPayload return nil, ErrUnexpectedPayload
} }
if err = cc.CheckCredential(sc, msg.cred); err != nil { if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return
} }
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
msg, err = h.readMsg() msg, err = h.readMsg()
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
if msg.ack == nil { if msg.ack == nil {
err = ErrUnexpectedPayload err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
if msg.ack.Error != handshakeproto.Error_Null { if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials { if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return ErrPeerDeclinedCredentials return nil, ErrPeerDeclinedCredentials
} }
return handshakeError{e: msg.ack.Error} return nil, handshakeError{e: msg.ack.Error}
} }
if err = h.writeAck(handshakeproto.Error_Null); err != nil { if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return err return nil, err
} }
return return
} }

View File

@ -18,17 +18,23 @@ import (
var noVerifyChecker = &testCredChecker{ var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) { checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
return return []byte("identity"), nil
}, },
} }
type handshakeRes struct {
identity []byte
err error
}
func TestOutgoingHandshake(t *testing.T) { func TestOutgoingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -36,7 +42,8 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg() msg, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred)) _, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack // receive ack
@ -45,23 +52,27 @@ func TestOutgoingHandshake(t *testing.T) {
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
resErr := <-hanshareResCh res := <-handshakeResCh
assert.NoError(t, resErr) assert.NotEmpty(t, res.identity)
assert.NoError(t, res.err)
}) })
t.Run("write cred err", func(t *testing.T) { t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("read cred err", func(t *testing.T) { t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -69,13 +80,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg() _, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("ack err", func(t *testing.T) { t.Run("ack err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -83,13 +96,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg() _, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
require.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error()) res := <-handshakeResCh
require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
}) })
t.Run("cred err", func(t *testing.T) { t.Run("cred err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -100,13 +115,15 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg() msg, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error()) res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
}) })
t.Run("write ack err", func(t *testing.T) { t.Run("write ack err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -116,13 +133,15 @@ func TestOutgoingHandshake(t *testing.T) {
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("read ack err", func(t *testing.T) { t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -135,13 +154,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err = h.readMsg() _, err = h.readMsg()
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("write cred instead ack", func(t *testing.T) { t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -158,13 +179,15 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg() msg, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("final ack error", func(t *testing.T) { t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -172,7 +195,8 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg() msg, err := h.readMsg()
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred)) _, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack // receive ack
@ -181,17 +205,18 @@ func TestOutgoingHandshake(t *testing.T) {
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload)) require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
resErr := <-hanshareResCh res := <-handshakeResCh
assert.Error(t, resErr) require.Error(t, res.err)
}) })
} }
func TestIncomingHandshake(t *testing.T) { func TestIncomingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) { t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -208,47 +233,56 @@ func TestIncomingHandshake(t *testing.T) {
msg, err = h.readMsg() msg, err = h.readMsg()
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
require.NoError(t, <-hanshareResCh) res := <-handshakeResCh
assert.NotEmpty(t, res.identity)
require.NoError(t, res.err)
}) })
t.Run("write cred err", func(t *testing.T) { t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("read cred err", func(t *testing.T) { t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("write ack instead cred", func(t *testing.T) { t.Run("write ack instead cred", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write ack instead cred // write ack instead cred
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("invalid cred", func(t *testing.T) { t.Run("invalid cred", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -260,13 +294,15 @@ func TestIncomingHandshake(t *testing.T) {
require.Nil(t, msg.cred) require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error()) res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
}) })
t.Run("write cred instead ack", func(t *testing.T) { t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -280,13 +316,15 @@ func TestIncomingHandshake(t *testing.T) {
// expect ack with error // expect ack with error
msg, err := h.readMsg() msg, err := h.readMsg()
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("read ack err", func(t *testing.T) { t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -297,13 +335,15 @@ func TestIncomingHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
t.Run("write ack with invalid", func(t *testing.T) { t.Run("write ack with invalid", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -317,13 +357,15 @@ func TestIncomingHandshake(t *testing.T) {
// write ack // write ack
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials)) require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
assert.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error()) res := <-handshakeResCh
assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
}) })
t.Run("write ack with err", func(t *testing.T) { t.Run("write ack with err", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -337,13 +379,15 @@ func TestIncomingHandshake(t *testing.T) {
// write ack // write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected)) require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
assert.EqualError(t, <-hanshareResCh, ErrUnexpected.Error()) res := <-handshakeResCh
assert.EqualError(t, res.err, ErrUnexpected.Error())
}) })
t.Run("final ack error", func(t *testing.T) { t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -357,53 +401,65 @@ func TestIncomingHandshake(t *testing.T) {
// write ack and close conn // write ack and close conn
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
_ = c2.Close() _ = c2.Close()
require.Error(t, <-hanshareResCh) res := <-handshakeResCh
require.Error(t, res.err)
}) })
} }
func TestNotAHandshakeMessage(t *testing.T) { func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
_, err := c2.Write([]byte("some unexpected bytes")) _, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err) require.Error(t, err)
assert.EqualError(t, <-hanshareResCh, ErrGotNotAHandshakeMessage.Error()) res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
} }
func TestEndToEnd(t *testing.T) { func TestEndToEnd(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var ( var (
inRes = make(chan error, 1) inResCh = make(chan handshakeRes, 1)
outRes = make(chan error, 1) outResCh = make(chan handshakeRes, 1)
) )
st := time.Now() st := time.Now()
go func() { go func() {
outRes <- OutgoingHandshake(c1, noVerifyChecker) identity, err := OutgoingHandshake(c1, noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}() }()
go func() { go func() {
inRes <- IncomingHandshake(c2, noVerifyChecker) identity, err := IncomingHandshake(c2, noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}() }()
assert.NoError(t, <-outRes)
assert.NoError(t, <-inRes) outRes := <-outResCh
assert.NoError(t, outRes.err)
assert.NotEmpty(t, outRes.identity)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.NotEmpty(t, inRes.identity)
t.Log("dur", time.Since(st)) t.Log("dur", time.Since(st))
} }
func BenchmarkHandshake(b *testing.B) { func BenchmarkHandshake(b *testing.B) {
c1, c2 := newConnPair(b) c1, c2 := newConnPair(b)
var ( var (
inRes = make(chan error) inRes = make(chan struct{})
outRes = make(chan error) outRes = make(chan struct{})
done = make(chan struct{}) done = make(chan struct{})
) )
defer close(done) defer close(done)
go func() { go func() {
for { for {
_, _ = OutgoingHandshake(c1, noVerifyChecker)
select { select {
case outRes <- OutgoingHandshake(c1, noVerifyChecker): case outRes <- struct{}{}:
case <-done: case <-done:
return return
} }
@ -411,8 +467,9 @@ func BenchmarkHandshake(b *testing.B) {
}() }()
go func() { go func() {
for { for {
_, _ = IncomingHandshake(c2, noVerifyChecker)
select { select {
case inRes <- IncomingHandshake(c2, noVerifyChecker): case inRes <- struct{}{}:
case <-done: case <-done:
return return
} }
@ -430,7 +487,7 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct { type testCredChecker struct {
makeCred *handshakeproto.Credentials makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error checkErr error
} }
@ -438,14 +495,14 @@ func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cre
return t.makeCred return t.makeCred
} }
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) { func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil { if t.checkErr != nil {
return t.checkErr return nil, t.checkErr
} }
if t.checkCred != nil { if t.checkCred != nil {
return t.checkCred(sc, cred) return t.checkCred(sc, cred)
} }
return nil return nil, nil
} }
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) { func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {