client version to handshake

This commit is contained in:
Sergey Cherepanov 2023-06-13 15:30:10 +02:00
parent ba7cffb51a
commit c753da8def
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
9 changed files with 219 additions and 121 deletions

View File

@ -14,6 +14,7 @@ const (
contextKeyPeerId contextKey = iota contextKeyPeerId contextKey = iota
contextKeyIdentity contextKeyIdentity
contextKeyPeerAddr contextKeyPeerAddr
contextKeyPeerClientVersion
) )
var ( var (
@ -50,6 +51,19 @@ func CtxWithPeerAddr(ctx context.Context, addr string) context.Context {
return context.WithValue(ctx, contextKeyPeerAddr, addr) return context.WithValue(ctx, contextKeyPeerAddr, addr)
} }
// CtxPeerClientVersion returns peer client version
func CtxPeerClientVersion(ctx context.Context) string {
if p, ok := ctx.Value(contextKeyPeerClientVersion).(string); ok {
return p
}
return ""
}
// CtxWithClientVersion sets peer clientVersion to the context
func CtxWithClientVersion(ctx context.Context, addr string) context.Context {
return context.WithValue(ctx, contextKeyPeerClientVersion, addr)
}
// CtxIdentity returns identity from context // CtxIdentity returns identity from context
func CtxIdentity(ctx context.Context) ([]byte, error) { func CtxIdentity(ctx context.Context) ([]byte, error) {
if identity, ok := ctx.Value(contextKeyIdentity).([]byte); ok { if identity, ok := ctx.Value(contextKeyIdentity).([]byte); ok {

View File

@ -22,11 +22,15 @@ func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Cr
return n.cred return n.cred
} }
func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (result handshake.Result, err error) {
if cred.Version != n.cred.Version { if cred.Version != n.cred.Version {
return nil, handshake.ErrIncompatibleVersion err = handshake.ErrIncompatibleVersion
return
} }
return nil, nil return handshake.Result{
ProtoVersion: cred.Version,
ClientVersion: cred.ClientVersion,
}, nil
} }
func newPeerSignVerifier(protoVersion uint32, account *accountdata.AccountKeys) handshake.CredentialChecker { func newPeerSignVerifier(protoVersion uint32, account *accountdata.AccountKeys) handshake.CredentialChecker {
@ -60,27 +64,36 @@ func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.
} }
} }
func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (result handshake.Result, err error) {
if cred.Version != p.protoVersion { if cred.Version != p.protoVersion {
return nil, handshake.ErrIncompatibleVersion err = handshake.ErrIncompatibleVersion
return
} }
if cred.Type != handshakeproto.CredentialsType_SignedPeerIds { if cred.Type != handshakeproto.CredentialsType_SignedPeerIds {
return nil, handshake.ErrSkipVerifyNotAllowed err = handshake.ErrSkipVerifyNotAllowed
return
} }
var msg = &handshakeproto.PayloadSignedPeerIds{} var msg = &handshakeproto.PayloadSignedPeerIds{}
if err = msg.Unmarshal(cred.Payload); err != nil { if err = msg.Unmarshal(cred.Payload); err != nil {
return nil, handshake.ErrUnexpectedPayload err = handshake.ErrUnexpectedPayload
return
} }
pubKey, err := crypto.UnmarshalEd25519PublicKeyProto(msg.Identity) pubKey, err := crypto.UnmarshalEd25519PublicKeyProto(msg.Identity)
if err != nil { if err != nil {
return nil, handshake.ErrInvalidCredentials err = handshake.ErrInvalidCredentials
return
} }
ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign) ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
if err != nil { if err != nil {
return nil, err return
} }
if !ok { if !ok {
return nil, handshake.ErrInvalidCredentials err = handshake.ErrInvalidCredentials
return
} }
return msg.Identity, nil return handshake.Result{
Identity: msg.Identity,
ProtoVersion: cred.Version,
ClientVersion: cred.ClientVersion,
}, nil
} }

View File

@ -23,13 +23,13 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
cr1 := cc1.MakeCredentials(c1) cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2) cr2 := cc2.MakeCredentials(c2)
id1, err := cc1.CheckCredential(c1, cr2) res, err := cc1.CheckCredential(c1, cr2)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, identity2, id1) assert.Equal(t, identity2, res.Identity)
id2, err := cc2.CheckCredential(c2, cr1) res2, err := cc2.CheckCredential(c2, cr1)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, identity1, id2) assert.Equal(t, identity1, res2.Identity)
_, err = cc1.CheckCredential(c1, cr1) _, err = cc1.CheckCredential(c1, cr1)
assert.EqualError(t, err, handshake.ErrInvalidCredentials.Error()) assert.EqualError(t, err, handshake.ErrInvalidCredentials.Error())

View File

@ -6,30 +6,31 @@ import (
"io" "io"
) )
func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (result Result, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var ( var (
resIdentity []byte res Result
resErr error resErr error
) )
go func() { go func() {
defer close(done) defer close(done)
resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc) res, resErr = outgoingHandshake(h, conn, peerId, cc)
}() }()
select { select {
case <-done: case <-done:
return resIdentity, resErr return res, resErr
case <-ctx.Done(): case <-ctx.Done():
_ = conn.Close() _ = conn.Close()
return nil, ctx.Err() err = ctx.Err()
return
} }
} }
func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (result Result, err error) {
defer h.release() defer h.release()
h.conn = conn h.conn = conn
localCred := cc.MakeCredentials(peerId) localCred := cc.MakeCredentials(peerId)
@ -44,58 +45,62 @@ func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc
} }
if msg.ack != nil { if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials { if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials err = ErrPeerDeclinedCredentials
return
} }
return nil, HandshakeError{e: msg.ack.Error} err = HandshakeError{e: msg.ack.Error}
return
} }
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { if result, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
if err = h.writeAck(handshakeproto.Error_Null); err != nil { if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return
} }
msg, err = h.readMsg(msgTypeAck) msg, err = h.readMsg(msgTypeAck)
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return
} }
if msg.ack.Error == handshakeproto.Error_Null { if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil return result, nil
} else { } else {
_ = h.conn.Close() _ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error} err = HandshakeError{e: msg.ack.Error}
return
} }
} }
func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (result Result, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var ( var (
resIdentity []byte res Result
resError error resError error
) )
go func() { go func() {
defer close(done) defer close(done)
resIdentity, resError = incomingHandshake(h, conn, peerId, cc) res, resError = incomingHandshake(h, conn, peerId, cc)
}() }()
select { select {
case <-done: case <-done:
return resIdentity, resError return res, resError
case <-ctx.Done(): case <-ctx.Done():
_ = conn.Close() _ = conn.Close()
return nil, ctx.Err() err = ctx.Err()
return
} }
} }
func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (result Result, err error) {
defer h.release() defer h.release()
h.conn = conn h.conn = conn
@ -104,30 +109,32 @@ func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { if result, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil { if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return
} }
msg, err = h.readMsg(msgTypeAck) msg, err = h.readMsg(msgTypeAck)
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return
} }
if msg.ack.Error != handshakeproto.Error_Null { if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials { if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials err = ErrPeerDeclinedCredentials
return
} }
return nil, HandshakeError{e: msg.ack.Error} err = HandshakeError{e: msg.ack.Error}
return
} }
if err = h.writeAck(handshakeproto.Error_Null); err != nil { if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return
} }
return return
} }

View File

@ -15,14 +15,18 @@ import (
) )
var noVerifyChecker = &testCredChecker{ var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify, ClientVersion: "test:v1.0"},
checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { checkCred: func(peerId string, cred *handshakeproto.Credentials) (res Result, err error) {
return []byte("identity"), nil return Result{
Identity: []byte("identity"),
ProtoVersion: cred.Version,
ClientVersion: cred.ClientVersion,
}, nil
}, },
} }
type handshakeRes struct { type handshakeRes struct {
identity []byte res Result
err error err error
} }
@ -32,7 +36,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -50,7 +54,7 @@ func TestOutgoingHandshake(t *testing.T) {
// send ack // send ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-handshakeResCh res := <-handshakeResCh
assert.NotEmpty(t, res.identity) assert.NotEmpty(t, res.res)
assert.NoError(t, res.err) assert.NoError(t, res.err)
}) })
t.Run("write cred err", func(t *testing.T) { t.Run("write cred err", func(t *testing.T) {
@ -58,7 +62,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -69,7 +73,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -85,7 +89,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -101,7 +105,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -120,7 +124,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -138,7 +142,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -159,7 +163,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -172,9 +176,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = h.writeCredentials(noVerifyChecker.MakeCredentials(""))
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
@ -183,7 +185,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -211,7 +213,7 @@ func TestOutgoingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -234,7 +236,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -252,7 +254,7 @@ func TestIncomingHandshake(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
assert.NotEmpty(t, res.identity) assert.NotEmpty(t, res.res)
require.NoError(t, res.err) require.NoError(t, res.err)
}) })
t.Run("write cred err", func(t *testing.T) { t.Run("write cred err", func(t *testing.T) {
@ -260,7 +262,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -271,7 +273,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -286,7 +288,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -300,7 +302,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -320,7 +322,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -340,7 +342,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -350,7 +352,7 @@ func TestIncomingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = h.writeCredentials(noVerifyChecker.MakeCredentials(""))
// expect EOF // expect EOF
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.Error(t, err) require.Error(t, err)
@ -362,7 +364,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -381,7 +383,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -403,7 +405,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -425,7 +427,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -448,7 +450,7 @@ func TestIncomingHandshake(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker) identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -472,7 +474,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{res: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
@ -491,20 +493,20 @@ func TestEndToEnd(t *testing.T) {
st := time.Now() st := time.Now()
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err} outResCh <- handshakeRes{res: identity, err: err}
}() }()
go func() { go func() {
identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker) identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err} inResCh <- handshakeRes{res: identity, err: err}
}() }()
outRes := <-outResCh outRes := <-outResCh
assert.NoError(t, outRes.err) assert.NoError(t, outRes.err)
assert.NotEmpty(t, outRes.identity) assert.NotEmpty(t, outRes.res)
inRes := <-inResCh inRes := <-inResCh
assert.NoError(t, inRes.err) assert.NoError(t, inRes.err)
assert.NotEmpty(t, inRes.identity) assert.NotEmpty(t, inRes.res)
t.Log("dur", time.Since(st)) t.Log("dur", time.Since(st))
} }
@ -548,7 +550,7 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct { type testCredChecker struct {
makeCred *handshakeproto.Credentials makeCred *handshakeproto.Credentials
checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) checkCred func(peerId string, cred *handshakeproto.Credentials) (res Result, err error)
checkErr error checkErr error
} }
@ -556,14 +558,15 @@ func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credent
return t.makeCred return t.makeCred
} }
func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (res Result, err error) {
if t.checkErr != nil { if t.checkErr != nil {
return nil, t.checkErr err = t.checkErr
return
} }
if t.checkCred != nil { if t.checkCred != nil {
return t.checkCred(peerId, cred) return t.checkCred(peerId, cred)
} }
return nil, nil return
} }
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) { func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {

View File

@ -64,7 +64,13 @@ var handshakePool = &sync.Pool{New: func() any {
type CredentialChecker interface { type CredentialChecker interface {
MakeCredentials(remotePeerId string) *handshakeproto.Credentials MakeCredentials(remotePeerId string) *handshakeproto.Credentials
CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (result Result, err error)
}
type Result struct {
Identity []byte
ProtoVersion uint32
ClientVersion string
} }
func newHandshake() *handshake { func newHandshake() *handshake {

View File

@ -118,6 +118,7 @@ type Credentials struct {
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
Version uint32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"` Version uint32 `protobuf:"varint,3,opt,name=version,proto3" json:"version,omitempty"`
ClientVersion string `protobuf:"bytes,4,opt,name=clientVersion,proto3" json:"clientVersion,omitempty"`
} }
func (m *Credentials) Reset() { *m = Credentials{} } func (m *Credentials) Reset() { *m = Credentials{} }
@ -174,6 +175,13 @@ func (m *Credentials) GetVersion() uint32 {
return 0 return 0
} }
func (m *Credentials) GetClientVersion() string {
if m != nil {
return m.ClientVersion
}
return ""
}
type PayloadSignedPeerIds struct { type PayloadSignedPeerIds struct {
// account identity // account identity
Identity []byte `protobuf:"bytes,1,opt,name=identity,proto3" json:"identity,omitempty"` Identity []byte `protobuf:"bytes,1,opt,name=identity,proto3" json:"identity,omitempty"`
@ -331,35 +339,36 @@ func init() {
} }
var fileDescriptor_60283fc75f020893 = []byte{ var fileDescriptor_60283fc75f020893 = []byte{
// 439 bytes of a gzipped FileDescriptorProto // 457 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x52, 0x4d, 0x6f, 0xd3, 0x40,
0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7, 0x10, 0xf5, 0x36, 0x71, 0x3f, 0x86, 0xa4, 0x6c, 0xb7, 0x29, 0xb5, 0x90, 0xb0, 0xa2, 0x88, 0x43,
0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5, 0x88, 0x44, 0xc2, 0x97, 0xb8, 0x87, 0xa6, 0x88, 0x5c, 0xaa, 0xc8, 0x85, 0x1e, 0xb8, 0x6d, 0xbd,
0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f, 0x43, 0xbb, 0xca, 0xb2, 0xb6, 0xd6, 0xdb, 0x50, 0xff, 0x0b, 0xce, 0xfc, 0x0a, 0x7e, 0x06, 0xc7,
0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf, 0x1e, 0x39, 0xa2, 0xe4, 0x8f, 0x20, 0x6f, 0x9c, 0xc6, 0xe1, 0xc4, 0xc5, 0xde, 0x99, 0xf7, 0x66,
0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe, 0xe7, 0xbd, 0x67, 0xc3, 0x50, 0xa3, 0x1d, 0x64, 0x18, 0xdf, 0x18, 0xcc, 0xd0, 0xcc, 0x64, 0x8c,
0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c, 0x83, 0x6b, 0xae, 0x45, 0x76, 0xcd, 0xa7, 0x95, 0x53, 0x6a, 0x12, 0x9b, 0x0c, 0xdc, 0x33, 0x5b,
0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed, 0x77, 0xfb, 0xae, 0xc1, 0x1a, 0x5c, 0xe7, 0x1f, 0x56, 0xbd, 0xce, 0x0f, 0x02, 0x0f, 0x4e, 0x0c,
0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef, 0x0a, 0xd4, 0x56, 0x72, 0x95, 0xb1, 0x97, 0x50, 0xb7, 0x79, 0x8a, 0x01, 0x69, 0x93, 0xee, 0xfe,
0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65, 0xab, 0x27, 0xfd, 0x2a, 0xb9, 0x5f, 0x21, 0x7e, 0xcc, 0x53, 0x8c, 0x1c, 0x95, 0x05, 0xb0, 0x93,
0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8, 0xf2, 0x5c, 0x25, 0x5c, 0x04, 0x5b, 0x6d, 0xd2, 0x6d, 0x44, 0xab, 0xb2, 0x40, 0x66, 0x68, 0x32,
0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46, 0x99, 0xe8, 0xa0, 0xd6, 0x26, 0xdd, 0x66, 0xb4, 0x2a, 0xd9, 0x53, 0x68, 0xc6, 0x4a, 0xa2, 0xb6,
0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72, 0x17, 0x25, 0x5e, 0x6f, 0x93, 0xee, 0x5e, 0xb4, 0xd9, 0xec, 0xbc, 0x87, 0xd6, 0x64, 0x79, 0xd5,
0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27, 0xb9, 0xbc, 0xd2, 0x28, 0x26, 0x88, 0x66, 0x2c, 0x32, 0xf6, 0x18, 0x76, 0xa5, 0x13, 0x62, 0x73,
0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05, 0x27, 0xb4, 0x11, 0xdd, 0xd7, 0x8c, 0x41, 0x3d, 0x93, 0x57, 0xba, 0x94, 0xe2, 0xce, 0x9d, 0x17,
0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e, 0x50, 0x1b, 0xc6, 0x53, 0xf6, 0x0c, 0x7c, 0x34, 0x26, 0x31, 0xa5, 0xb9, 0xc3, 0x4d, 0x73, 0xa7,
0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64, 0x05, 0x14, 0x2d, 0x19, 0x9d, 0xb7, 0xe0, 0x4f, 0x5c, 0x5a, 0xcf, 0xc1, 0x77, 0xb1, 0x95, 0x33,
0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04, 0xc7, 0x9b, 0x33, 0x8e, 0xe3, 0xa2, 0x58, 0xb2, 0x7a, 0x6f, 0xe0, 0xe1, 0x3f, 0x21, 0xb1, 0x7d,
0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73, 0x80, 0xf3, 0xa9, 0x4c, 0x2f, 0xd0, 0xc8, 0x2f, 0x39, 0xf5, 0xd8, 0x01, 0x34, 0x37, 0xdc, 0x50,
0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62, 0xd2, 0xfb, 0x49, 0xc0, 0x77, 0xeb, 0xd9, 0x2e, 0xd4, 0xcf, 0x6e, 0x94, 0xa2, 0x5e, 0x31, 0xf6,
0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81, 0x49, 0xe3, 0x6d, 0x8a, 0xb1, 0x45, 0x41, 0x09, 0x7b, 0x04, 0x6c, 0xac, 0x67, 0x5c, 0x49, 0x51,
0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53, 0x59, 0x40, 0xb7, 0xd8, 0x11, 0x1c, 0xac, 0x79, 0x65, 0x5a, 0xb4, 0xc6, 0x02, 0x68, 0xad, 0xb7,
0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85, 0x9e, 0x25, 0x76, 0xa8, 0x54, 0xf2, 0x0d, 0x05, 0xad, 0xb3, 0x16, 0xd0, 0x11, 0x72, 0xa1, 0xa4,
0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3, 0xc6, 0xd3, 0xdb, 0x18, 0x51, 0xa0, 0xa0, 0x3e, 0x3b, 0x86, 0xc3, 0xb1, 0x8e, 0x93, 0xaf, 0x29,
0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd, 0xb7, 0xf2, 0x52, 0x61, 0xf9, 0x05, 0xe8, 0x76, 0x71, 0x7f, 0x15, 0x70, 0x8e, 0xe9, 0x4e, 0xef,
0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c, 0x08, 0xf6, 0xee, 0xcd, 0x17, 0xaa, 0x47, 0xd1, 0xe4, 0x84, 0x7a, 0xef, 0x46, 0xbf, 0xe6, 0x21,
0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53, 0xb9, 0x9b, 0x87, 0xe4, 0xcf, 0x3c, 0x24, 0xdf, 0x17, 0xa1, 0x77, 0xb7, 0x08, 0xbd, 0xdf, 0x8b,
0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00, 0xd0, 0xfb, 0xdc, 0xfb, 0xff, 0x3f, 0xf7, 0x72, 0xdb, 0xbd, 0x5e, 0xff, 0x0d, 0x00, 0x00, 0xff,
0xff, 0xe2, 0xfa, 0x40, 0x67, 0xee, 0x02, 0x00, 0x00,
} }
func (m *Credentials) Marshal() (dAtA []byte, err error) { func (m *Credentials) Marshal() (dAtA []byte, err error) {
@ -382,6 +391,13 @@ func (m *Credentials) MarshalToSizedBuffer(dAtA []byte) (int, error) {
_ = i _ = i
var l int var l int
_ = l _ = l
if len(m.ClientVersion) > 0 {
i -= len(m.ClientVersion)
copy(dAtA[i:], m.ClientVersion)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.ClientVersion)))
i--
dAtA[i] = 0x22
}
if m.Version != 0 { if m.Version != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Version)) i = encodeVarintHandshake(dAtA, i, uint64(m.Version))
i-- i--
@ -522,6 +538,10 @@ func (m *Credentials) Size() (n int) {
if m.Version != 0 { if m.Version != 0 {
n += 1 + sovHandshake(uint64(m.Version)) n += 1 + sovHandshake(uint64(m.Version))
} }
l = len(m.ClientVersion)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
return n return n
} }
@ -673,6 +693,38 @@ func (m *Credentials) Unmarshal(dAtA []byte) error {
break break
} }
} }
case 4:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field ClientVersion", wireType)
}
var stringLen uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
stringLen |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
intStringLen := int(stringLen)
if intStringLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + intStringLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.ClientVersion = string(dAtA[iNdEx:postIndex])
iNdEx = postIndex
default: default:
iNdEx = preIndex iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:]) skippy, err := skipHandshake(dAtA[iNdEx:])

View File

@ -39,6 +39,7 @@ message Credentials {
CredentialsType type = 1; CredentialsType type = 1;
bytes payload = 2; bytes payload = 2;
uint32 version = 3; uint32 version = 3;
string clientVersion = 4;
} }
enum CredentialsType { enum CredentialsType {

View File

@ -99,13 +99,14 @@ func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx
} }
func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) { func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) {
identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker) res, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cctx = context.Background() cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, peerId) cctx = peer.CtxWithPeerId(cctx, peerId)
cctx = peer.CtxWithIdentity(cctx, identity) cctx = peer.CtxWithIdentity(cctx, res.Identity)
cctx = peer.CtxWithClientVersion(cctx, res.ClientVersion)
return return
} }
@ -122,13 +123,14 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx
} else { } else {
checker = s.noVerifyChecker checker = s.noVerifyChecker
} }
identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker) res, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker)
if err != nil { if err != nil {
return nil, err return nil, err
} }
cctx = context.Background() cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity) cctx = peer.CtxWithIdentity(cctx, res.Identity)
cctx = peer.CtxWithClientVersion(cctx, res.ClientVersion)
return cctx, nil return cctx, nil
} }