handshake proto + common handshake fixes

This commit is contained in:
Sergey Cherepanov 2023-05-31 16:22:49 +02:00
parent 553ed3a64b
commit c43ac9eb84
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
11 changed files with 664 additions and 213 deletions

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: commonfile/fileproto/protos/file.proto // source: commonfile/fileproto/protos/file.proto
package fileproto package fileproto

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: commonspace/spacesyncproto/protos/spacesync.proto // source: commonspace/spacesyncproto/protos/spacesync.proto
package spacesyncproto package spacesyncproto
@ -102,6 +102,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct {
drpc.Stream drpc.Stream
} }
func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream {
return x.Stream
}
func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error { func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error {
return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
} }

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: coordinator/coordinatorproto/protos/coordinator.proto // source: coordinator/coordinatorproto/protos/coordinator.proto
package coordinatorproto package coordinatorproto

View File

@ -0,0 +1,125 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
)
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck, msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg(msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}

View File

@ -38,15 +38,14 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred) _, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err) require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack // receive ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
@ -76,7 +75,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -92,7 +91,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
res := <-handshakeResCh res := <-handshakeResCh
@ -108,10 +107,10 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
@ -127,7 +126,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
@ -145,12 +144,12 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack and close conn // read ack and close conn
_, err = h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -166,18 +165,17 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack // read ack
_, err = h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.Error(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
@ -191,7 +189,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred) _, err = noVerifyChecker.CheckCredential(c2, msg.cred)
@ -199,7 +197,7 @@ func TestOutgoingHandshake(t *testing.T) {
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack // receive ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
@ -219,7 +217,7 @@ func TestOutgoingHandshake(t *testing.T) {
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
ctxCancel() ctxCancel()
res := <-handshakeResCh res := <-handshakeResCh
@ -244,14 +242,14 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack // write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
// wait ack // wait ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
@ -310,7 +308,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// except ack with error // except ack with error
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.cred) require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
@ -330,7 +328,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// except ack with error // except ack with error
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.cred) require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error)
@ -350,13 +348,13 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred // read cred
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// expect ack with error // expect EOF
msg, err := h.readMsg() _, err = h.readMsg(msgTypeAck)
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
@ -372,7 +370,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred and close conn // read cred and close conn
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
@ -391,7 +389,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -413,7 +411,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -435,7 +433,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -458,7 +456,7 @@ func TestIncomingHandshake(t *testing.T) {
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials // wait credentials
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
ctxCancel() ctxCancel()
res := <-handshakeResCh res := <-handshakeResCh
@ -482,7 +480,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
_, err := c2.Write([]byte("some unexpected bytes")) _, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err) require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) assert.Error(t, res.err)
} }
func TestEndToEnd(t *testing.T) { func TestEndToEnd(t *testing.T) {

View File

@ -1,7 +1,6 @@
package handshake package handshake
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
@ -14,8 +13,17 @@ import (
const headerSize = 5 // 1 byte for type + 4 byte for uint32 size const headerSize = 5 // 1 byte for type + 4 byte for uint32 size
const ( const (
msgTypeCred = byte(1) msgTypeCred = byte(1)
msgTypeAck = byte(2) msgTypeAck = byte(2)
msgTypeProto = byte(3)
sizeLimit = 200 * 1024 // 200 Kb
)
var (
credMsgTypes = []byte{msgTypeCred, msgTypeAck}
protoMsgTypes = []byte{msgTypeProto, msgTypeAck}
protoMsgTypesAck = []byte{msgTypeAck}
) )
type HandshakeError struct { type HandshakeError struct {
@ -38,17 +46,20 @@ var (
ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected}
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto}
ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")}
ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") ErrGotUnexpectedMessage = errors.New("go not a handshake message")
) )
var handshakePool = &sync.Pool{New: func() any { var handshakePool = &sync.Pool{New: func() any {
return &handshake{ return &handshake{
remoteCred: &handshakeproto.Credentials{}, remoteCred: &handshakeproto.Credentials{},
remoteAck: &handshakeproto.Ack{}, remoteAck: &handshakeproto.Ack{},
localAck: &handshakeproto.Ack{}, localAck: &handshakeproto.Ack{},
buf: make([]byte, 0, 1024), remoteProto: &handshakeproto.Proto{},
buf: make([]byte, 0, 1024),
} }
}} }}
@ -57,147 +68,17 @@ type CredentialChecker interface {
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
} }
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
return nil, ErrUnexpectedPayload
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}
func newHandshake() *handshake { func newHandshake() *handshake {
return handshakePool.Get().(*handshake) return handshakePool.Get().(*handshake)
} }
type handshake struct { type handshake struct {
conn sec.SecureConn conn sec.SecureConn
remoteCred *handshakeproto.Credentials remoteCred *handshakeproto.Credentials
remoteAck *handshakeproto.Ack remoteProto *handshakeproto.Proto
localAck *handshakeproto.Ack remoteAck *handshakeproto.Ack
buf []byte localAck *handshakeproto.Ack
buf []byte
} }
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) { func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
@ -209,8 +90,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro
return h.writeData(msgTypeCred, n) return h.writeData(msgTypeCred, n)
} }
func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) {
h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize]
n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:])
if err != nil {
return err
}
return h.writeData(msgTypeProto, n)
}
func (h *handshake) tryWriteErrAndClose(err error) { func (h *handshake) tryWriteErrAndClose(err error) {
if err == ErrGotNotAHandshakeMessage { if err == ErrUnexpectedPayload {
// if we got unexpected message - just close the connection // if we got unexpected message - just close the connection
_ = h.conn.Close() _ = h.conn.Close()
return return
@ -243,21 +133,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) {
} }
type message struct { type message struct {
cred *handshakeproto.Credentials cred *handshakeproto.Credentials
ack *handshakeproto.Ack proto *handshakeproto.Proto
ack *handshakeproto.Ack
} }
func (h *handshake) readMsg() (msg message, err error) { func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) {
h.buf = slices.Grow(h.buf, headerSize)[:headerSize] h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil { if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
return return
} }
tp := h.buf[0] tp := h.buf[0]
if tp != msgTypeCred && tp != msgTypeAck { if !slices.Contains(allowedTypes, tp) {
err = ErrGotNotAHandshakeMessage err = ErrUnexpectedPayload
return return
} }
size := binary.LittleEndian.Uint32(h.buf[1:headerSize]) size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
if size > sizeLimit {
err = ErrGotUnexpectedMessage
return
}
h.buf = slices.Grow(h.buf, int(size))[:size] h.buf = slices.Grow(h.buf, int(size))[:size]
if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil { if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
return return
@ -273,6 +168,11 @@ func (h *handshake) readMsg() (msg message, err error) {
return return
} }
msg.ack = h.remoteAck msg.ack = h.remoteAck
case msgTypeProto:
if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil {
return
}
msg.proto = h.remoteProto
} }
return return
} }
@ -284,5 +184,6 @@ func (h *handshake) release() {
h.remoteAck.Error = 0 h.remoteAck.Error = 0
h.remoteCred.Type = 0 h.remoteCred.Type = 0
h.remoteCred.Payload = h.remoteCred.Payload[:0] h.remoteCred.Payload = h.remoteCred.Payload[:0]
h.remoteProto.Proto = 0
handshakePool.Put(h) handshakePool.Put(h)
} }

View File

@ -59,6 +59,7 @@ const (
Error_SkipVerifyNotAllowed Error = 4 Error_SkipVerifyNotAllowed Error = 4
Error_DeadlineExceeded Error = 5 Error_DeadlineExceeded Error = 5
Error_IncompatibleVersion Error = 6 Error_IncompatibleVersion Error = 6
Error_IncompatibleProto Error = 7
) )
var Error_name = map[int32]string{ var Error_name = map[int32]string{
@ -69,6 +70,7 @@ var Error_name = map[int32]string{
4: "SkipVerifyNotAllowed", 4: "SkipVerifyNotAllowed",
5: "DeadlineExceeded", 5: "DeadlineExceeded",
6: "IncompatibleVersion", 6: "IncompatibleVersion",
7: "IncompatibleProto",
} }
var Error_value = map[string]int32{ var Error_value = map[string]int32{
@ -79,6 +81,7 @@ var Error_value = map[string]int32{
"SkipVerifyNotAllowed": 4, "SkipVerifyNotAllowed": 4,
"DeadlineExceeded": 5, "DeadlineExceeded": 5,
"IncompatibleVersion": 6, "IncompatibleVersion": 6,
"IncompatibleProto": 7,
} }
func (x Error) String() string { func (x Error) String() string {
@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{1} return fileDescriptor_60283fc75f020893, []int{1}
} }
type ProtoType int32
const (
ProtoType_DRPC ProtoType = 0
)
var ProtoType_name = map[int32]string{
0: "DRPC",
}
var ProtoType_value = map[string]int32{
"DRPC": 0,
}
func (x ProtoType) String() string {
return proto.EnumName(ProtoType_name, int32(x))
}
func (ProtoType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{2}
}
type Credentials struct { type Credentials struct {
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
@ -247,12 +272,58 @@ func (m *Ack) GetError() Error {
return Error_Null return Error_Null
} }
type Proto struct {
Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"`
}
func (m *Proto) Reset() { *m = Proto{} }
func (m *Proto) String() string { return proto.CompactTextString(m) }
func (*Proto) ProtoMessage() {}
func (*Proto) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{3}
}
func (m *Proto) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Proto.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Proto) XXX_Merge(src proto.Message) {
xxx_messageInfo_Proto.Merge(m, src)
}
func (m *Proto) XXX_Size() int {
return m.Size()
}
func (m *Proto) XXX_DiscardUnknown() {
xxx_messageInfo_Proto.DiscardUnknown(m)
}
var xxx_messageInfo_Proto proto.InternalMessageInfo
func (m *Proto) GetProto() ProtoType {
if m != nil {
return m.Proto
}
return ProtoType_DRPC
}
func init() { func init() {
proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value) proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value)
proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value) proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value)
proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value)
proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials") proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials")
proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds") proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds")
proto.RegisterType((*Ack)(nil), "anyHandshake.Ack") proto.RegisterType((*Ack)(nil), "anyHandshake.Ack")
proto.RegisterType((*Proto)(nil), "anyHandshake.Proto")
} }
func init() { func init() {
@ -260,32 +331,35 @@ func init() {
} }
var fileDescriptor_60283fc75f020893 = []byte{ var fileDescriptor_60283fc75f020893 = []byte{
// 395 bytes of a gzipped FileDescriptorProto // 439 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40,
0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53, 0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7,
0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5, 0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5,
0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88, 0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f,
0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65, 0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf,
0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23, 0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe,
0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60, 0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c,
0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d, 0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed,
0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd, 0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef,
0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77, 0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65,
0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba, 0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8,
0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91, 0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46,
0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08, 0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72,
0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d, 0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27,
0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1, 0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05,
0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78, 0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e,
0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f, 0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64,
0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70, 0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04,
0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc, 0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73,
0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80, 0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62,
0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34, 0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81,
0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe, 0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53,
0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96, 0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85,
0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00, 0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3,
0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00, 0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd,
0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c,
0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53,
0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00,
} }
func (m *Credentials) Marshal() (dAtA []byte, err error) { func (m *Credentials) Marshal() (dAtA []byte, err error) {
@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
return len(dAtA) - i, nil return len(dAtA) - i, nil
} }
func (m *Proto) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Proto) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.Proto != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Proto))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
offset -= sovHandshake(v) offset -= sovHandshake(v)
base := offset base := offset
@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) {
return n return n
} }
func (m *Proto) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Proto != 0 {
n += 1 + sovHandshake(uint64(m.Proto))
}
return n
}
func sovHandshake(x uint64) (n int) { func sovHandshake(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7 return (math_bits.Len64(x|1) + 6) / 7
} }
@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
} }
return nil return nil
} }
func (m *Proto) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Proto: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType)
}
m.Proto = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Proto |= ProtoType(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipHandshake(dAtA []byte) (n int, err error) { func skipHandshake(dAtA []byte) (n int, err error) {
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0

View File

@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto";
/* /*
CREDENTIALS HANDSHAKE
Alice opens a new connection with Bob Alice opens a new connection with Bob
1. TLS handshake done successfully; both sides know local and remote peer identifiers. 1. TLS handshake done successfully; both sides know local and remote peer identifiers.
@ -68,4 +70,20 @@ enum Error {
SkipVerifyNotAllowed = 4; SkipVerifyNotAllowed = 4;
DeadlineExceeded = 5; DeadlineExceeded = 5;
IncompatibleVersion = 6; IncompatibleVersion = 6;
IncompatibleProto = 7;
}
/*
PROTO HANDSHAKE
*/
message Proto {
ProtoType proto = 1;
}
enum ProtoType {
DRPC = 0;
} }

View File

@ -0,0 +1,97 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
)
type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType
}
func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
err = outgoingProtoHandshake(h, sc, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return ctx.Err()
}
}
func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) {
defer h.release()
h.conn = sc
localProto := &handshakeproto.Proto{
Proto: pt,
}
if err = h.writeProto(localProto); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack.Error == handshakeproto.Error_IncompatibleProto {
return ErrRemoteIncompatibleProto
}
if msg.ack.Error == handshakeproto.Error_Null {
return nil
}
return HandshakeError{e: msg.ack.Error}
}
func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
protoType, err = incomingProtoHandshake(h, sc, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return 0, ctx.Err()
}
}
func incomingProtoHandshake(h *handshake, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg(msgTypeProto)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) {
err = ErrIncompatibleProto
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return 0, err
} else {
return msg.proto.Proto, nil
}
}

View File

@ -0,0 +1,121 @@
package handshake
import (
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
type protoRes struct {
protoType handshakeproto.ProtoType
err error
}
func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker {
return ProtoChecker{AllowedProtoTypes: types}
}
func TestIncomingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-protoResCh
require.NoError(t, res.err)
assert.Equal(t, handshakeproto.ProtoType(1), res.protoType)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error)
res := <-protoResCh
require.Error(t, res.err, ErrIncompatibleProto.Error())
})
}
func TestOutgoingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-protoResCh
assert.NoError(t, res.err)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto))
res := <-protoResCh
assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error())
})
}
func TestEndToEndProto(t *testing.T) {
c1, c2 := newConnPair(t)
var (
inResCh = make(chan protoRes, 1)
outResCh = make(chan protoRes, 1)
)
st := time.Now()
go func() {
err := OutgoingProtoHandshake(nil, c1, 0)
outResCh <- protoRes{err: err}
}()
go func() {
protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1))
inResCh <- protoRes{protoType: protoType, err: err}
}()
outRes := <-outResCh
assert.NoError(t, outRes.err)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType)
t.Log("dur", time.Since(st))
}

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: net/streampool/testservice/protos/testservice.proto // source: net/streampool/testservice/protos/testservice.proto
package testservice package testservice
@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct {
drpc.Stream drpc.Stream
} }
func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream {
return x.Stream
}
func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error { func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error {
return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{}) return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{})
} }