From c43ac9eb84187419dd3fb2844ff71c605d2a2f27 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Wed, 31 May 2023 16:22:49 +0200 Subject: [PATCH] handshake proto + common handshake fixes --- commonfile/fileproto/file_drpc.pb.go | 2 +- .../spacesyncproto/spacesync_drpc.pb.go | 6 +- .../coordinatorproto/coordinator_drpc.pb.go | 2 +- net/secureservice/handshake/credential.go | 125 ++++++++++ .../{handshake_test.go => credential_test.go} | 62 +++-- net/secureservice/handshake/handshake.go | 203 ++++----------- .../handshake/handshakeproto/handshake.pb.go | 235 ++++++++++++++++-- .../handshakeproto/protos/handshake.proto | 18 ++ net/secureservice/handshake/proto.go | 97 ++++++++ net/secureservice/handshake/proto_test.go | 121 +++++++++ .../testservice/testservice_drpc.pb.go | 6 +- 11 files changed, 664 insertions(+), 213 deletions(-) create mode 100644 net/secureservice/handshake/credential.go rename net/secureservice/handshake/{handshake_test.go => credential_test.go} (94%) create mode 100644 net/secureservice/handshake/proto.go create mode 100644 net/secureservice/handshake/proto_test.go diff --git a/commonfile/fileproto/file_drpc.pb.go b/commonfile/fileproto/file_drpc.pb.go index 2f9ee69d..a03c22cd 100644 --- a/commonfile/fileproto/file_drpc.pb.go +++ b/commonfile/fileproto/file_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonfile/fileproto/protos/file.proto package fileproto diff --git a/commonspace/spacesyncproto/spacesync_drpc.pb.go b/commonspace/spacesyncproto/spacesync_drpc.pb.go index 2c82a645..55b3fed4 100644 --- a/commonspace/spacesyncproto/spacesync_drpc.pb.go +++ b/commonspace/spacesyncproto/spacesync_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonspace/spacesyncproto/protos/spacesync.proto package spacesyncproto @@ -102,6 +102,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct { drpc.Stream } +func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error { return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) } diff --git a/coordinator/coordinatorproto/coordinator_drpc.pb.go b/coordinator/coordinatorproto/coordinator_drpc.pb.go index 75e73a7b..0ed69ea2 100644 --- a/coordinator/coordinatorproto/coordinator_drpc.pb.go +++ b/coordinator/coordinatorproto/coordinator_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: coordinator/coordinatorproto/protos/coordinator.proto package coordinatorproto diff --git a/net/secureservice/handshake/credential.go b/net/secureservice/handshake/credential.go new file mode 100644 index 00000000..06108928 --- /dev/null +++ b/net/secureservice/handshake/credential.go @@ -0,0 +1,125 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/libp2p/go-libp2p/core/sec" +) + +func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = outgoingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = sc + localCred := cc.MakeCredentials(sc) + if err = h.writeCredentials(localCred); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck, msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack != nil { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error == handshakeproto.Error_Null { + return identity, nil + } else { + _ = h.conn.Close() + return nil, HandshakeError{e: msg.ack.Error} + } +} + +func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + identity, err = incomingHandshake(h, sc, cc) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return nil, ctx.Err() + } +} + +func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = sc + + msg, err := h.readMsg(msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error != handshakeproto.Error_Null { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + return +} diff --git a/net/secureservice/handshake/handshake_test.go b/net/secureservice/handshake/credential_test.go similarity index 94% rename from net/secureservice/handshake/handshake_test.go rename to net/secureservice/handshake/credential_test.go index 0d8b16d7..6a34f9cb 100644 --- a/net/secureservice/handshake/handshake_test.go +++ b/net/secureservice/handshake/credential_test.go @@ -38,15 +38,14 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto) require.NoError(t, err) - require.Nil(t, msg.ack) _, err = noVerifyChecker.CheckCredential(c2, msg.cred) require.NoError(t, err) // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -76,7 +75,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -92,7 +91,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) res := <-handshakeResCh @@ -108,10 +107,10 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) res := <-handshakeResCh @@ -127,7 +126,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials and close conn require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) @@ -145,12 +144,12 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read ack and close conn - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -166,18 +165,17 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read ack - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) // write cred instead ack require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() - require.NoError(t, err) - assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -191,7 +189,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) _, err = noVerifyChecker.CheckCredential(c2, msg.cred) @@ -199,7 +197,7 @@ func TestOutgoingHandshake(t *testing.T) { // send credential message require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -219,7 +217,7 @@ func TestOutgoingHandshake(t *testing.T) { h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -244,14 +242,14 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) // write ack require.NoError(t, h.writeAck(handshakeproto.Error_Null)) // wait ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) res := <-handshakeResCh @@ -310,7 +308,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) @@ -330,7 +328,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) @@ -350,13 +348,13 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read cred - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write cred instead ack require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - // expect ack with error - msg, err := h.readMsg() - require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + // expect EOF + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -372,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // read cred and close conn - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() @@ -391,7 +389,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -413,7 +411,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -435,7 +433,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -458,7 +456,7 @@ func TestIncomingHandshake(t *testing.T) { // write credentials require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) // wait credentials - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -482,7 +480,7 @@ func TestNotAHandshakeMessage(t *testing.T) { _, err := c2.Write([]byte("some unexpected bytes")) require.Error(t, err) res := <-handshakeResCh - assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) + assert.Error(t, res.err) } func TestEndToEnd(t *testing.T) { diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 04a9de72..39e6656f 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -1,7 +1,6 @@ package handshake import ( - "context" "encoding/binary" "errors" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" @@ -14,8 +13,17 @@ import ( const headerSize = 5 // 1 byte for type + 4 byte for uint32 size const ( - msgTypeCred = byte(1) - msgTypeAck = byte(2) + msgTypeCred = byte(1) + msgTypeAck = byte(2) + msgTypeProto = byte(3) + + sizeLimit = 200 * 1024 // 200 Kb +) + +var ( + credMsgTypes = []byte{msgTypeCred, msgTypeAck} + protoMsgTypes = []byte{msgTypeProto, msgTypeAck} + protoMsgTypesAck = []byte{msgTypeAck} ) type HandshakeError struct { @@ -38,17 +46,20 @@ var ( ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} - ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto} + ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")} - ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") + ErrGotUnexpectedMessage = errors.New("go not a handshake message") ) var handshakePool = &sync.Pool{New: func() any { return &handshake{ - remoteCred: &handshakeproto.Credentials{}, - remoteAck: &handshakeproto.Ack{}, - localAck: &handshakeproto.Ack{}, - buf: make([]byte, 0, 1024), + remoteCred: &handshakeproto.Credentials{}, + remoteAck: &handshakeproto.Ack{}, + localAck: &handshakeproto.Ack{}, + remoteProto: &handshakeproto.Proto{}, + buf: make([]byte, 0, 1024), } }} @@ -57,147 +68,17 @@ type CredentialChecker interface { CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) } -func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = outgoingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - localCred := cc.MakeCredentials(sc) - if err = h.writeCredentials(localCred); err != nil { - h.tryWriteErrAndClose(err) - return - } - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error == handshakeproto.Error_Null { - return identity, nil - } else { - _ = h.conn.Close() - return nil, HandshakeError{e: msg.ack.Error} - } -} - -func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = incomingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - return nil, ErrUnexpectedPayload - } - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error != handshakeproto.Error_Null { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - return -} - func newHandshake() *handshake { return handshakePool.Get().(*handshake) } type handshake struct { - conn sec.SecureConn - remoteCred *handshakeproto.Credentials - remoteAck *handshakeproto.Ack - localAck *handshakeproto.Ack - buf []byte + conn sec.SecureConn + remoteCred *handshakeproto.Credentials + remoteProto *handshakeproto.Proto + remoteAck *handshakeproto.Ack + localAck *handshakeproto.Ack + buf []byte } func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) { @@ -209,8 +90,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro return h.writeData(msgTypeCred, n) } +func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) { + h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize] + n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:]) + if err != nil { + return err + } + return h.writeData(msgTypeProto, n) +} + func (h *handshake) tryWriteErrAndClose(err error) { - if err == ErrGotNotAHandshakeMessage { + if err == ErrUnexpectedPayload { // if we got unexpected message - just close the connection _ = h.conn.Close() return @@ -243,21 +133,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) { } type message struct { - cred *handshakeproto.Credentials - ack *handshakeproto.Ack + cred *handshakeproto.Credentials + proto *handshakeproto.Proto + ack *handshakeproto.Ack } -func (h *handshake) readMsg() (msg message, err error) { +func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) { h.buf = slices.Grow(h.buf, headerSize)[:headerSize] if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil { return } tp := h.buf[0] - if tp != msgTypeCred && tp != msgTypeAck { - err = ErrGotNotAHandshakeMessage + if !slices.Contains(allowedTypes, tp) { + err = ErrUnexpectedPayload return } size := binary.LittleEndian.Uint32(h.buf[1:headerSize]) + if size > sizeLimit { + err = ErrGotUnexpectedMessage + return + } h.buf = slices.Grow(h.buf, int(size))[:size] if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil { return @@ -273,6 +168,11 @@ func (h *handshake) readMsg() (msg message, err error) { return } msg.ack = h.remoteAck + case msgTypeProto: + if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil { + return + } + msg.proto = h.remoteProto } return } @@ -284,5 +184,6 @@ func (h *handshake) release() { h.remoteAck.Error = 0 h.remoteCred.Type = 0 h.remoteCred.Payload = h.remoteCred.Payload[:0] + h.remoteProto.Proto = 0 handshakePool.Put(h) } diff --git a/net/secureservice/handshake/handshakeproto/handshake.pb.go b/net/secureservice/handshake/handshakeproto/handshake.pb.go index 3d868ef0..e9d6dfcb 100644 --- a/net/secureservice/handshake/handshakeproto/handshake.pb.go +++ b/net/secureservice/handshake/handshakeproto/handshake.pb.go @@ -59,6 +59,7 @@ const ( Error_SkipVerifyNotAllowed Error = 4 Error_DeadlineExceeded Error = 5 Error_IncompatibleVersion Error = 6 + Error_IncompatibleProto Error = 7 ) var Error_name = map[int32]string{ @@ -69,6 +70,7 @@ var Error_name = map[int32]string{ 4: "SkipVerifyNotAllowed", 5: "DeadlineExceeded", 6: "IncompatibleVersion", + 7: "IncompatibleProto", } var Error_value = map[string]int32{ @@ -79,6 +81,7 @@ var Error_value = map[string]int32{ "SkipVerifyNotAllowed": 4, "DeadlineExceeded": 5, "IncompatibleVersion": 6, + "IncompatibleProto": 7, } func (x Error) String() string { @@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) { return fileDescriptor_60283fc75f020893, []int{1} } +type ProtoType int32 + +const ( + ProtoType_DRPC ProtoType = 0 +) + +var ProtoType_name = map[int32]string{ + 0: "DRPC", +} + +var ProtoType_value = map[string]int32{ + "DRPC": 0, +} + +func (x ProtoType) String() string { + return proto.EnumName(ProtoType_name, int32(x)) +} + +func (ProtoType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{2} +} + type Credentials struct { Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` @@ -247,12 +272,58 @@ func (m *Ack) GetError() Error { return Error_Null } +type Proto struct { + Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"` +} + +func (m *Proto) Reset() { *m = Proto{} } +func (m *Proto) String() string { return proto.CompactTextString(m) } +func (*Proto) ProtoMessage() {} +func (*Proto) Descriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{3} +} +func (m *Proto) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Proto.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Proto) XXX_Merge(src proto.Message) { + xxx_messageInfo_Proto.Merge(m, src) +} +func (m *Proto) XXX_Size() int { + return m.Size() +} +func (m *Proto) XXX_DiscardUnknown() { + xxx_messageInfo_Proto.DiscardUnknown(m) +} + +var xxx_messageInfo_Proto proto.InternalMessageInfo + +func (m *Proto) GetProto() ProtoType { + if m != nil { + return m.Proto + } + return ProtoType_DRPC +} + func init() { proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value) proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value) + proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value) proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials") proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds") proto.RegisterType((*Ack)(nil), "anyHandshake.Ack") + proto.RegisterType((*Proto)(nil), "anyHandshake.Proto") } func init() { @@ -260,32 +331,35 @@ func init() { } var fileDescriptor_60283fc75f020893 = []byte{ - // 395 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31, - 0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53, - 0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5, - 0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88, - 0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65, - 0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23, - 0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60, - 0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d, - 0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd, - 0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77, - 0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba, - 0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91, - 0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08, - 0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d, - 0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1, - 0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78, - 0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f, - 0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70, - 0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc, - 0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80, - 0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34, - 0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe, - 0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96, - 0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00, - 0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00, + // 439 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40, + 0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7, + 0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5, + 0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f, + 0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf, + 0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe, + 0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c, + 0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed, + 0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef, + 0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65, + 0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8, + 0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46, + 0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72, + 0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27, + 0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05, + 0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e, + 0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64, + 0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04, + 0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73, + 0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62, + 0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81, + 0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53, + 0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85, + 0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3, + 0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd, + 0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c, + 0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53, + 0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00, } func (m *Credentials) Marshal() (dAtA []byte, err error) { @@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Proto) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Proto) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Proto != 0 { + i = encodeVarintHandshake(dAtA, i, uint64(m.Proto)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { offset -= sovHandshake(v) base := offset @@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) { return n } +func (m *Proto) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Proto != 0 { + n += 1 + sovHandshake(uint64(m.Proto)) + } + return n +} + func sovHandshake(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error { } return nil } +func (m *Proto) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Proto: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType) + } + m.Proto = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Proto |= ProtoType(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipHandshake(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipHandshake(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/net/secureservice/handshake/handshakeproto/protos/handshake.proto b/net/secureservice/handshake/handshakeproto/protos/handshake.proto index cca5822e..1ea66b28 100644 --- a/net/secureservice/handshake/handshakeproto/protos/handshake.proto +++ b/net/secureservice/handshake/handshakeproto/protos/handshake.proto @@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto"; /* +CREDENTIALS HANDSHAKE + Alice opens a new connection with Bob 1. TLS handshake done successfully; both sides know local and remote peer identifiers. @@ -68,4 +70,20 @@ enum Error { SkipVerifyNotAllowed = 4; DeadlineExceeded = 5; IncompatibleVersion = 6; + IncompatibleProto = 7; +} + + +/* + +PROTO HANDSHAKE + + */ + +message Proto { + ProtoType proto = 1; +} + +enum ProtoType { + DRPC = 0; } \ No newline at end of file diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go new file mode 100644 index 00000000..1e133069 --- /dev/null +++ b/net/secureservice/handshake/proto.go @@ -0,0 +1,97 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/libp2p/go-libp2p/core/sec" + "golang.org/x/exp/slices" +) + +type ProtoChecker struct { + AllowedProtoTypes []handshakeproto.ProtoType +} + +func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + err = outgoingProtoHandshake(h, sc, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return ctx.Err() + } +} + +func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { + defer h.release() + h.conn = sc + localProto := &handshakeproto.Proto{ + Proto: pt, + } + if err = h.writeProto(localProto); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack.Error == handshakeproto.Error_IncompatibleProto { + return ErrRemoteIncompatibleProto + } + if msg.ack.Error == handshakeproto.Error_Null { + return nil + } + return HandshakeError{e: msg.ack.Error} +} + +func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + protoType, err = incomingProtoHandshake(h, sc, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = sc.Close() + return 0, ctx.Err() + } +} + +func incomingProtoHandshake(h *handshake, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + defer h.release() + h.conn = sc + + msg, err := h.readMsg(msgTypeProto) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) { + err = ErrIncompatibleProto + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return 0, err + } else { + return msg.proto.Proto, nil + } +} diff --git a/net/secureservice/handshake/proto_test.go b/net/secureservice/handshake/proto_test.go new file mode 100644 index 00000000..f689e372 --- /dev/null +++ b/net/secureservice/handshake/proto_test.go @@ -0,0 +1,121 @@ +package handshake + +import ( + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +type protoRes struct { + protoType handshakeproto.ProtoType + err error +} + +func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker { + return ProtoChecker{AllowedProtoTypes: types} +} +func TestIncomingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) + res := <-protoResCh + require.NoError(t, res.err) + assert.Equal(t, handshakeproto.ProtoType(1), res.protoType) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error) + res := <-protoResCh + require.Error(t, res.err, ErrIncompatibleProto.Error()) + }) +} + +func TestOutgoingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_Null)) + + res := <-protoResCh + assert.NoError(t, res.err) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto)) + + res := <-protoResCh + assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error()) + }) +} + +func TestEndToEndProto(t *testing.T) { + c1, c2 := newConnPair(t) + var ( + inResCh = make(chan protoRes, 1) + outResCh = make(chan protoRes, 1) + ) + st := time.Now() + go func() { + err := OutgoingProtoHandshake(nil, c1, 0) + outResCh <- protoRes{err: err} + }() + go func() { + protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1)) + inResCh <- protoRes{protoType: protoType, err: err} + }() + + outRes := <-outResCh + assert.NoError(t, outRes.err) + + inRes := <-inResCh + assert.NoError(t, inRes.err) + assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType) + t.Log("dur", time.Since(st)) +} diff --git a/net/streampool/testservice/testservice_drpc.pb.go b/net/streampool/testservice/testservice_drpc.pb.go index f50fdbe7..cfe5bce9 100644 --- a/net/streampool/testservice/testservice_drpc.pb.go +++ b/net/streampool/testservice/testservice_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: net/streampool/testservice/protos/testservice.proto package testservice @@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct { drpc.Stream } +func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error { return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{}) }