From 2fe9f8c295ccd61c8b0d1cda7f1078291d565c49 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 13 Feb 2023 13:45:07 +0300 Subject: [PATCH] handshake: return identity --- net/secureservice/handshake/handshake.go | 48 ++-- net/secureservice/handshake/handshake_test.go | 217 +++++++++++------- 2 files changed, 161 insertions(+), 104 deletions(-) diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 8c13cb97..72fbed02 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -47,59 +47,59 @@ var handshakePool = &sync.Pool{New: func() any { type CredentialChecker interface { MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials - CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) + CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) } -func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) { +func OutgoingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { h := newHandshake() defer h.release() h.conn = sc localCred := cc.MakeCredentials(sc) if err = h.writeCredentials(localCred); err != nil { h.tryWriteErrAndClose(err) - return err + return } msg, err := h.readMsg() if err != nil { h.tryWriteErrAndClose(err) - return err + return } if msg.ack != nil { if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return ErrPeerDeclinedCredentials + return nil, ErrPeerDeclinedCredentials } - return handshakeError{e: msg.ack.Error} + return nil, handshakeError{e: msg.ack.Error} } - if err = cc.CheckCredential(sc, msg.cred); err != nil { + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { h.tryWriteErrAndClose(err) - return err + return } if err = h.writeAck(handshakeproto.Error_Null); err != nil { h.tryWriteErrAndClose(err) - return err + return nil, err } msg, err = h.readMsg() if err != nil { h.tryWriteErrAndClose(err) - return err + return nil, err } if msg.ack == nil { err = ErrUnexpectedPayload h.tryWriteErrAndClose(err) - return err + return nil, err } if msg.ack.Error == handshakeproto.Error_Null { - return nil + return identity, nil } else { _ = h.conn.Close() - return handshakeError{e: msg.ack.Error} + return nil, handshakeError{e: msg.ack.Error} } } -func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) { +func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { h := newHandshake() defer h.release() h.conn = sc @@ -107,40 +107,40 @@ func IncomingHandshake(sc sec.SecureConn, cc CredentialChecker) (err error) { msg, err := h.readMsg() if err != nil { h.tryWriteErrAndClose(err) - return err + return } if msg.ack != nil { - return ErrUnexpectedPayload + return nil, ErrUnexpectedPayload } - if err = cc.CheckCredential(sc, msg.cred); err != nil { + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { h.tryWriteErrAndClose(err) - return err + return } if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { h.tryWriteErrAndClose(err) - return err + return nil, err } msg, err = h.readMsg() if err != nil { h.tryWriteErrAndClose(err) - return err + return nil, err } if msg.ack == nil { err = ErrUnexpectedPayload h.tryWriteErrAndClose(err) - return err + return nil, err } if msg.ack.Error != handshakeproto.Error_Null { if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return ErrPeerDeclinedCredentials + return nil, ErrPeerDeclinedCredentials } - return handshakeError{e: msg.ack.Error} + return nil, handshakeError{e: msg.ack.Error} } if err = h.writeAck(handshakeproto.Error_Null); err != nil { h.tryWriteErrAndClose(err) - return err + return nil, err } return } diff --git a/net/secureservice/handshake/handshake_test.go b/net/secureservice/handshake/handshake_test.go index 63292428..d8bd8f21 100644 --- a/net/secureservice/handshake/handshake_test.go +++ b/net/secureservice/handshake/handshake_test.go @@ -18,17 +18,23 @@ import ( var noVerifyChecker = &testCredChecker{ makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, - checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) { - return + checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { + return []byte("identity"), nil }, } +type handshakeRes struct { + identity []byte + err error +} + func TestOutgoingHandshake(t *testing.T) { t.Run("success", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -36,7 +42,8 @@ func TestOutgoingHandshake(t *testing.T) { msg, err := h.readMsg() require.NoError(t, err) require.Nil(t, msg.ack) - require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred)) + _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + require.NoError(t, err) // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack @@ -45,23 +52,27 @@ func TestOutgoingHandshake(t *testing.T) { require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack require.NoError(t, h.writeAck(handshakeproto.Error_Null)) - resErr := <-hanshareResCh - assert.NoError(t, resErr) + res := <-handshakeResCh + assert.NotEmpty(t, res.identity) + assert.NoError(t, res.err) }) t.Run("write cred err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("read cred err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -69,13 +80,15 @@ func TestOutgoingHandshake(t *testing.T) { _, err := h.readMsg() require.NoError(t, err) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("ack err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -83,13 +96,15 @@ func TestOutgoingHandshake(t *testing.T) { _, err := h.readMsg() require.NoError(t, err) require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) - require.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error()) + res := <-handshakeResCh + require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error()) }) t.Run("cred err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := OutgoingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -100,13 +115,15 @@ func TestOutgoingHandshake(t *testing.T) { msg, err := h.readMsg() require.NoError(t, err) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) - require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error()) + res := <-handshakeResCh + require.EqualError(t, res.err, ErrInvalidCredentials.Error()) }) t.Run("write ack err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -116,13 +133,15 @@ func TestOutgoingHandshake(t *testing.T) { // write credentials and close conn require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("read ack err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -135,13 +154,15 @@ func TestOutgoingHandshake(t *testing.T) { _, err = h.readMsg() require.NoError(t, err) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("write cred instead ack", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -158,13 +179,15 @@ func TestOutgoingHandshake(t *testing.T) { msg, err := h.readMsg() require.NoError(t, err) assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("final ack error", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -172,7 +195,8 @@ func TestOutgoingHandshake(t *testing.T) { msg, err := h.readMsg() require.NoError(t, err) require.Nil(t, msg.ack) - require.NoError(t, noVerifyChecker.CheckCredential(c2, msg.cred)) + _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + require.NoError(t, err) // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack @@ -181,17 +205,18 @@ func TestOutgoingHandshake(t *testing.T) { require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload)) - resErr := <-hanshareResCh - assert.Error(t, resErr) + res := <-handshakeResCh + require.Error(t, res.err) }) } func TestIncomingHandshake(t *testing.T) { t.Run("success", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -208,47 +233,56 @@ func TestIncomingHandshake(t *testing.T) { msg, err = h.readMsg() require.NoError(t, err) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) - require.NoError(t, <-hanshareResCh) + res := <-handshakeResCh + assert.NotEmpty(t, res.identity) + require.NoError(t, res.err) }) t.Run("write cred err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("read cred err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials and close conn require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("write ack instead cred", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write ack instead cred require.NoError(t, h.writeAck(handshakeproto.Error_Null)) - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("invalid cred", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := IncomingHandshake(c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -260,13 +294,15 @@ func TestIncomingHandshake(t *testing.T) { require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) - require.EqualError(t, <-hanshareResCh, ErrInvalidCredentials.Error()) + res := <-handshakeResCh + require.EqualError(t, res.err, ErrInvalidCredentials.Error()) }) t.Run("write cred instead ack", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -280,13 +316,15 @@ func TestIncomingHandshake(t *testing.T) { // expect ack with error msg, err := h.readMsg() require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("read ack err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -297,13 +335,15 @@ func TestIncomingHandshake(t *testing.T) { require.NoError(t, err) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) t.Run("write ack with invalid", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -317,13 +357,15 @@ func TestIncomingHandshake(t *testing.T) { // write ack require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials)) - assert.EqualError(t, <-hanshareResCh, ErrPeerDeclinedCredentials.Error()) + res := <-handshakeResCh + assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error()) }) t.Run("write ack with err", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -337,13 +379,15 @@ func TestIncomingHandshake(t *testing.T) { // write ack require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected)) - assert.EqualError(t, <-hanshareResCh, ErrUnexpected.Error()) + res := <-handshakeResCh + assert.EqualError(t, res.err, ErrUnexpected.Error()) }) t.Run("final ack error", func(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 @@ -357,53 +401,65 @@ func TestIncomingHandshake(t *testing.T) { // write ack and close conn require.NoError(t, h.writeAck(handshakeproto.Error_Null)) _ = c2.Close() - require.Error(t, <-hanshareResCh) + res := <-handshakeResCh + require.Error(t, res.err) }) } func TestNotAHandshakeMessage(t *testing.T) { c1, c2 := newConnPair(t) - var hanshareResCh = make(chan error, 1) + var handshakeResCh = make(chan handshakeRes, 1) go func() { - hanshareResCh <- IncomingHandshake(c1, noVerifyChecker) + identity, err := IncomingHandshake(c1, noVerifyChecker) + handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 _, err := c2.Write([]byte("some unexpected bytes")) require.Error(t, err) - assert.EqualError(t, <-hanshareResCh, ErrGotNotAHandshakeMessage.Error()) + res := <-handshakeResCh + assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) } func TestEndToEnd(t *testing.T) { c1, c2 := newConnPair(t) var ( - inRes = make(chan error, 1) - outRes = make(chan error, 1) + inResCh = make(chan handshakeRes, 1) + outResCh = make(chan handshakeRes, 1) ) st := time.Now() go func() { - outRes <- OutgoingHandshake(c1, noVerifyChecker) + identity, err := OutgoingHandshake(c1, noVerifyChecker) + outResCh <- handshakeRes{identity: identity, err: err} }() go func() { - inRes <- IncomingHandshake(c2, noVerifyChecker) + identity, err := IncomingHandshake(c2, noVerifyChecker) + inResCh <- handshakeRes{identity: identity, err: err} }() - assert.NoError(t, <-outRes) - assert.NoError(t, <-inRes) + + outRes := <-outResCh + assert.NoError(t, outRes.err) + assert.NotEmpty(t, outRes.identity) + + inRes := <-inResCh + assert.NoError(t, inRes.err) + assert.NotEmpty(t, inRes.identity) t.Log("dur", time.Since(st)) } func BenchmarkHandshake(b *testing.B) { c1, c2 := newConnPair(b) var ( - inRes = make(chan error) - outRes = make(chan error) + inRes = make(chan struct{}) + outRes = make(chan struct{}) done = make(chan struct{}) ) defer close(done) go func() { for { + _, _ = OutgoingHandshake(c1, noVerifyChecker) select { - case outRes <- OutgoingHandshake(c1, noVerifyChecker): + case outRes <- struct{}{}: case <-done: return } @@ -411,8 +467,9 @@ func BenchmarkHandshake(b *testing.B) { }() go func() { for { + _, _ = IncomingHandshake(c2, noVerifyChecker) select { - case inRes <- IncomingHandshake(c2, noVerifyChecker): + case inRes <- struct{}{}: case <-done: return } @@ -430,7 +487,7 @@ func BenchmarkHandshake(b *testing.B) { type testCredChecker struct { makeCred *handshakeproto.Credentials - checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) + checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) checkErr error } @@ -438,14 +495,14 @@ func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cre return t.makeCred } -func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (err error) { +func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { if t.checkErr != nil { - return t.checkErr + return nil, t.checkErr } if t.checkCred != nil { return t.checkCred(sc, cred) } - return nil + return nil, nil } func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {