handshake: return identity

This commit is contained in:
Sergey Cherepanov 2023-02-13 13:45:07 +03:00
parent e8a7ad7476
commit 2fe9f8c295
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
2 changed files with 161 additions and 104 deletions

View File

@ -47,59 +47,59 @@ var handshakePool = &sync.Pool{New: func() any {
type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
h := newHandshake()
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return err
return
}
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return err
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return ErrPeerDeclinedCredentials
return nil, ErrPeerDeclinedCredentials
}
return handshakeError{e: msg.ack.Error}
return nil, handshakeError{e: msg.ack.Error}
}
if err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return err
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return err
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return err
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return err
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return nil
return identity, nil
} else {
_ = h.conn.Close()
return handshakeError{e: msg.ack.Error}
return nil, handshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
h := newHandshake()
defer h.release()
h.conn = sc
@ -107,40 +107,40 @@ func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) {
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return err
return
}
if msg.ack != nil {
return ErrUnexpectedPayload
return nil, ErrUnexpectedPayload
}
if err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return err
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return err
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return err
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return err
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return ErrPeerDeclinedCredentials
return nil, ErrPeerDeclinedCredentials
}
return handshakeError{e: msg.ack.Error}
return nil, handshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return err
return nil, err
}
return
}

View File

@ -18,17 +18,23 @@ import (
var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
return
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil
},
}
type handshakeRes struct {
identity []byte
err error
}
func TestOutgoingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -36,7 +42,8 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
@ -45,23 +52,27 @@ func TestOutgoingHandshake(t *testing.T) {
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
resErr := <-hanshareResCh
assert.NoError(t, resErr)
res := <-handshakeResCh
assert.NotEmpty(t, res.identity)
assert.NoError(t, res.err)
})
t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -69,13 +80,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg()
require.NoError(t, err)
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -83,13 +96,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg()
require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
require.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
res := <-handshakeResCh
require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
})
t.Run("cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -100,13 +115,15 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg()
require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
})
t.Run("write ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -116,13 +133,15 @@ func TestOutgoingHandshake(t *testing.T) {
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -135,13 +154,15 @@ func TestOutgoingHandshake(t *testing.T) {
_, err = h.readMsg()
require.NoError(t, err)
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -158,13 +179,15 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg()
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -172,7 +195,8 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred))
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
@ -181,17 +205,18 @@ func TestOutgoingHandshake(t *testing.T) {
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
resErr := <-hanshareResCh
assert.Error(t, resErr)
res := <-handshakeResCh
require.Error(t, res.err)
})
}
func TestIncomingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -208,47 +233,56 @@ func TestIncomingHandshake(t *testing.T) {
msg, err = h.readMsg()
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
require.NoError(t, <-hanshareResCh)
res := <-handshakeResCh
assert.NotEmpty(t, res.identity)
require.NoError(t, res.err)
})
t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write ack instead cred", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write ack instead cred
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("invalid cred", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -260,13 +294,15 @@ func TestIncomingHandshake(t *testing.T) {
require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error())
res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
})
t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -280,13 +316,15 @@ func TestIncomingHandshake(t *testing.T) {
// expect ack with error
msg, err := h.readMsg()
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -297,13 +335,15 @@ func TestIncomingHandshake(t *testing.T) {
require.NoError(t, err)
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write ack with invalid", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -317,13 +357,15 @@ func TestIncomingHandshake(t *testing.T) {
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
assert.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error())
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
})
t.Run("write ack with err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -337,13 +379,15 @@ func TestIncomingHandshake(t *testing.T) {
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
assert.EqualError(t, <-hanshareResCh, ErrUnexpected.Error())
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrUnexpected.Error())
})
t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
@ -357,53 +401,65 @@ func TestIncomingHandshake(t *testing.T) {
// write ack and close conn
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
_ = c2.Close()
require.Error(t, <-hanshareResCh)
res := <-handshakeResCh
require.Error(t, res.err)
})
}
func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t)
var hanshareResCh = make(chan error, 1)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
hanshareResCh <- IncomingHandshake(c1, noVerifyChecker)
identity, err := IncomingHandshake(c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
_, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err)
assert.EqualError(t, <-hanshareResCh, ErrGotNotAHandshakeMessage.Error())
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
}
func TestEndToEnd(t *testing.T) {
c1, c2 := newConnPair(t)
var (
inRes = make(chan error, 1)
outRes = make(chan error, 1)
inResCh = make(chan handshakeRes, 1)
outResCh = make(chan handshakeRes, 1)
)
st := time.Now()
go func() {
outRes <- OutgoingHandshake(c1, noVerifyChecker)
identity, err := OutgoingHandshake(c1, noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}()
go func() {
inRes <- IncomingHandshake(c2, noVerifyChecker)
identity, err := IncomingHandshake(c2, noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}()
assert.NoError(t, <-outRes)
assert.NoError(t, <-inRes)
outRes := <-outResCh
assert.NoError(t, outRes.err)
assert.NotEmpty(t, outRes.identity)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.NotEmpty(t, inRes.identity)
t.Log("dur", time.Since(st))
}
func BenchmarkHandshake(b *testing.B) {
c1, c2 := newConnPair(b)
var (
inRes = make(chan error)
outRes = make(chan error)
inRes = make(chan struct{})
outRes = make(chan struct{})
done = make(chan struct{})
)
defer close(done)
go func() {
for {
_, _ = OutgoingHandshake(c1, noVerifyChecker)
select {
case outRes <- OutgoingHandshake(c1, noVerifyChecker):
case outRes <- struct{}{}:
case <-done:
return
}
@ -411,8 +467,9 @@ func BenchmarkHandshake(b *testing.B) {
}()
go func() {
for {
_, _ = IncomingHandshake(c2, noVerifyChecker)
select {
case inRes <- IncomingHandshake(c2, noVerifyChecker):
case inRes <- struct{}{}:
case <-done:
return
}
@ -430,7 +487,7 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct {
makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error)
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error
}
@ -438,14 +495,14 @@ func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cre
return t.makeCred
}
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) {
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil {
return t.checkErr
return nil, t.checkErr
}
if t.checkCred != nil {
return t.checkCred(sc, cred)
}
return nil
return nil, nil
}
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {