2023-05-22 18:00:30 +02:00

289 lines
6.9 KiB
Go

package handshake
import (
"context"
"encoding/binary"
"errors"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"sync"
)
// headerSize is the length in bytes of the frame header that precedes every
// handshake message: 1 byte for the message type + 4 bytes for the uint32
// payload size.
const headerSize = 1 + 4

// Message type tags stored in the first byte of the frame header.
const (
	msgTypeCred byte = 1 // payload is a marshaled Credentials message
	msgTypeAck  byte = 2 // payload is a marshaled Ack message
)
// HandshakeError is the error type produced by the handshake. It carries
// either a plain Go error (err) or a protocol-level error code (e).
type HandshakeError struct {
	// err, when non-nil, supplies the message returned by Error.
	err error
	// e is the protocol error code; Error falls back to its String form when
	// err is nil.
	e handshakeproto.Error
}
// Error implements the error interface. The wrapped Go error takes precedence;
// without one, the textual form of the protocol error code is returned.
func (he HandshakeError) Error() string {
	if he.err == nil {
		return he.e.String()
	}
	return he.err.Error()
}
// Errors returned by the handshake. The HandshakeError values carrying a
// protocol code (e) can be reported to the remote peer inside an ack;
// ErrPeerDeclinedCredentials carries only a local message.
var (
	ErrUnexpectedPayload       = HandshakeError{e: handshakeproto.Error_UnexpectedPayload}
	ErrDeadlineExceeded        = HandshakeError{e: handshakeproto.Error_DeadlineExceeded}
	ErrInvalidCredentials      = HandshakeError{e: handshakeproto.Error_InvalidCredentials}
	ErrPeerDeclinedCredentials = HandshakeError{err: errors.New("remote peer declined the credentials")}
	ErrSkipVerifyNotAllowed    = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
	ErrUnexpected              = HandshakeError{e: handshakeproto.Error_Unexpected}
	ErrIncompatibleVersion     = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}

	// ErrGotNotAHandshakeMessage is returned when the received frame does not
	// carry a known handshake message type. The message previously read
	// "go not ..." (typo).
	ErrGotNotAHandshakeMessage = errors.New("got not a handshake message")
)
// handshakePool recycles handshake state objects. A fresh object comes with
// its protobuf messages allocated and an empty scratch buffer of 1024 bytes
// capacity.
var handshakePool = &sync.Pool{
	New: func() any {
		fresh := new(handshake)
		fresh.remoteCred = &handshakeproto.Credentials{}
		fresh.remoteAck = &handshakeproto.Ack{}
		fresh.localAck = &handshakeproto.Ack{}
		fresh.buf = make([]byte, 0, 1024)
		return fresh
	},
}
// CredentialChecker produces the local credentials that are sent to the remote
// peer and validates the credentials received from it.
type CredentialChecker interface {
	// MakeCredentials returns the credentials the handshake writes to the peer
	// on the other side of sc.
	MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
	// CheckCredential validates cred as received over sc. A non-nil err aborts
	// the handshake; otherwise identity is what the handshake functions return
	// to their caller once the handshake completes.
	CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
}
// OutgoingHandshake runs the initiating side of the handshake over sc (see
// outgoingHandshake) and returns the identity reported by cc.CheckCredential
// for the remote peer.
//
// A nil ctx is treated as context.Background(). If ctx is done before the
// handshake finishes, sc is closed and ctx.Err() is returned.
//
// The handshake runs in its own goroutine. Its result is handed over through a
// channel rather than by writing the named results from the goroutine: the
// latter races with the return statement of the cancellation branch.
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
	if ctx == nil {
		ctx = context.Background()
	}
	type result struct {
		identity []byte
		err      error
	}
	h := newHandshake()
	// Buffered so the worker never blocks on send when the caller has already
	// returned because of cancellation.
	resCh := make(chan result, 1)
	go func() {
		id, hErr := outgoingHandshake(h, sc, cc)
		resCh <- result{identity: id, err: hErr}
	}()
	select {
	case res := <-resCh:
		return res.identity, res.err
	case <-ctx.Done():
		// Closing the connection is expected to fail any pending read/write in
		// the worker so it can finish and release h.
		_ = sc.Close()
		return nil, ctx.Err()
	}
}
// outgoingHandshake performs the initiator flow on sc using the state h:
//
//  1. send local credentials;
//  2. read the peer's reply: an ack here means the peer rejected us,
//     otherwise it must be the peer's credentials, which are checked with cc;
//  3. send a success ack;
//  4. read the peer's final ack.
//
// h is returned to the pool when the function exits. On every failure path
// the connection is closed and a nil identity is returned; where the failure
// was detected locally the peer is first notified via tryWriteErrAndClose.
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
	defer h.release()
	h.conn = sc

	localCred := cc.MakeCredentials(sc)
	if err = h.writeCredentials(localCred); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}

	msg, err := h.readMsg()
	if err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack != nil {
		// The peer answered with an ack instead of credentials: it declined.
		// Nothing more can happen on this connection, so close it, as is done
		// for a failing final ack below.
		_ = h.conn.Close()
		if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
			return nil, ErrPeerDeclinedCredentials
		}
		return nil, HandshakeError{e: msg.ack.Error}
	}

	if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if err = h.writeAck(handshakeproto.Error_Null); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}

	msg, err = h.readMsg()
	if err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack == nil {
		err = ErrUnexpectedPayload
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack.Error != handshakeproto.Error_Null {
		_ = h.conn.Close()
		return nil, HandshakeError{e: msg.ack.Error}
	}
	return identity, nil
}
// IncomingHandshake runs the accepting side of the handshake over sc (see
// incomingHandshake) and returns the identity reported by cc.CheckCredential
// for the remote peer.
//
// A nil ctx is treated as context.Background(). If ctx is done before the
// handshake finishes, sc is closed and ctx.Err() is returned.
//
// The handshake runs in its own goroutine. Its result is handed over through a
// channel rather than by writing the named results from the goroutine: the
// latter races with the return statement of the cancellation branch.
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
	if ctx == nil {
		ctx = context.Background()
	}
	type result struct {
		identity []byte
		err      error
	}
	h := newHandshake()
	// Buffered so the worker never blocks on send when the caller has already
	// returned because of cancellation.
	resCh := make(chan result, 1)
	go func() {
		id, hErr := incomingHandshake(h, sc, cc)
		resCh <- result{identity: id, err: hErr}
	}()
	select {
	case res := <-resCh:
		return res.identity, res.err
	case <-ctx.Done():
		// Closing the connection is expected to fail any pending read/write in
		// the worker so it can finish and release h.
		_ = sc.Close()
		return nil, ctx.Err()
	}
}
// incomingHandshake performs the acceptor flow on sc using the state h:
//
//  1. read the peer's credentials and check them with cc;
//  2. send local credentials;
//  3. read the peer's ack: a non-Null error code means the peer declined us;
//  4. send the final success ack.
//
// h is returned to the pool when the function exits. On every failure path
// the connection is closed and a nil identity is returned; where the failure
// was detected locally the peer is first notified via tryWriteErrAndClose.
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
	defer h.release()
	h.conn = sc

	msg, err := h.readMsg()
	if err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack != nil {
		// Credentials must come first. Handle the wrong message type the same
		// way as the missing-ack case below: notify the peer and close.
		err = ErrUnexpectedPayload
		h.tryWriteErrAndClose(err)
		return nil, err
	}

	if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}

	msg, err = h.readMsg()
	if err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack == nil {
		err = ErrUnexpectedPayload
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	if msg.ack.Error != handshakeproto.Error_Null {
		// The peer declined; nothing more can happen on this connection.
		_ = h.conn.Close()
		if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
			return nil, ErrPeerDeclinedCredentials
		}
		return nil, HandshakeError{e: msg.ack.Error}
	}

	if err = h.writeAck(handshakeproto.Error_Null); err != nil {
		h.tryWriteErrAndClose(err)
		return nil, err
	}
	return identity, nil
}
// newHandshake takes a handshake state object from handshakePool.
func newHandshake() *handshake {
	pooled := handshakePool.Get()
	return pooled.(*handshake)
}
// handshake holds the state of a single handshake run. Instances are obtained
// from handshakePool and handed back by release.
type handshake struct {
	// conn is the connection the handshake runs over; it is set when a
	// handshake starts and cleared by release.
	conn sec.SecureConn
	// remoteCred is the reusable message readMsg unmarshals received
	// credentials into.
	remoteCred *handshakeproto.Credentials
	// remoteAck is the reusable message readMsg unmarshals a received ack into.
	remoteAck *handshakeproto.Ack
	// localAck is the reusable ack message filled in and marshaled by writeAck.
	localAck *handshakeproto.Ack
	// buf is the scratch buffer used for framing on both the read and the
	// write path.
	buf []byte
}
// writeCredentials marshals cred into the scratch buffer right behind the
// space reserved for the frame header and sends it as a msgTypeCred frame.
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
	frameLen := cred.Size() + headerSize
	h.buf = slices.Grow(h.buf, frameLen)[:frameLen]
	payloadLen, err := cred.MarshalToSizedBuffer(h.buf[headerSize:])
	if err == nil {
		err = h.writeData(msgTypeCred, payloadLen)
	}
	return err
}
// tryWriteErrAndClose makes a best-effort attempt to tell the remote peer why
// the handshake failed and then closes the connection.
//
// If err is ErrGotNotAHandshakeMessage the connection is closed without
// writing anything. Otherwise an ack is sent carrying the protocol code of the
// HandshakeError found in err's chain, or Error_Unexpected when there is none.
// Errors from writing and closing are deliberately ignored: the handshake has
// already failed and the caller reports the original error.
func (h *handshake) tryWriteErrAndClose(err error) {
	if errors.Is(err, ErrGotNotAHandshakeMessage) {
		// if we got unexpected message - just close the connection
		_ = h.conn.Close()
		return
	}
	ackErr := handshakeproto.Error_Unexpected
	var he HandshakeError
	// An ack carrying Error_Null is what the handshake sends on success, so a
	// HandshakeError without a protocol code (e.g. one wrapping only a Go
	// error) must not be reported with it; fall back to Error_Unexpected.
	if errors.As(err, &he) && he.e != handshakeproto.Error_Null {
		ackErr = he.e
	}
	_ = h.writeAck(ackErr)
	_ = h.conn.Close()
}
// writeAck stores ackErr in the reusable local ack message, marshals it into
// the scratch buffer behind the frame header and sends it as a msgTypeAck
// frame.
func (h *handshake) writeAck(ackErr handshakeproto.Error) (err error) {
	ack := h.localAck
	ack.Error = ackErr
	frameLen := ack.Size() + headerSize
	h.buf = slices.Grow(h.buf, frameLen)[:frameLen]
	payloadLen, err := ack.MarshalTo(h.buf[headerSize:])
	if err == nil {
		err = h.writeData(msgTypeAck, payloadLen)
	}
	return err
}
// writeData fills in the frame header (type tag tp followed by size as a
// little-endian uint32) at the start of the scratch buffer and writes the
// header plus size payload bytes to the connection. The payload must already
// be in the buffer right after the header.
func (h *handshake) writeData(tp byte, size int) (err error) {
	h.buf[0] = tp
	sizeField := h.buf[1:headerSize]
	binary.LittleEndian.PutUint32(sizeField, uint32(size))
	if _, err = h.conn.Write(h.buf[:headerSize+size]); err != nil {
		return err
	}
	return nil
}
// message is the result of readMsg. On success exactly one field is set,
// depending on the type tag of the received frame.
type message struct {
	// cred is set when a msgTypeCred frame was received.
	cred *handshakeproto.Credentials
	// ack is set when a msgTypeAck frame was received.
	ack *handshakeproto.Ack
}
// readMsg reads one frame from the connection and decodes it.
//
// It reads the headerSize-byte header (type tag + little-endian uint32 payload
// size), then the payload, and unmarshals the payload into the reusable
// remoteCred or remoteAck message according to the type tag. The returned
// message points at those reusable objects.
//
// ErrGotNotAHandshakeMessage is returned when the type tag is unknown or the
// announced payload size exceeds the limit below; read and unmarshal errors
// are returned as is.
func (h *handshake) readMsg() (msg message, err error) {
	// maxPayloadSize bounds the payload size announced by the remote peer.
	// The size comes straight off the wire, so without a bound a peer could
	// make us allocate up to 4 GiB (or overflow int on 32-bit platforms) with
	// a five-byte header. 200 KiB is intended to be far above any legitimate
	// handshake message.
	const maxPayloadSize = 200 * 1024

	h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
	if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
		return message{}, err
	}
	tp := h.buf[0]
	if tp != msgTypeCred && tp != msgTypeAck {
		return message{}, ErrGotNotAHandshakeMessage
	}
	size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
	if size > maxPayloadSize {
		// Treated like a non-handshake frame: callers close the connection
		// without replying.
		return message{}, ErrGotNotAHandshakeMessage
	}
	h.buf = slices.Grow(h.buf, int(size))[:size]
	if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
		return message{}, err
	}
	switch tp {
	case msgTypeCred:
		if err = h.remoteCred.Unmarshal(h.buf[:size]); err != nil {
			return message{}, err
		}
		msg.cred = h.remoteCred
	case msgTypeAck:
		if err = h.remoteAck.Unmarshal(h.buf[:size]); err != nil {
			return message{}, err
		}
		msg.ack = h.remoteAck
	}
	return msg, nil
}
// release clears the per-handshake state and puts h back into handshakePool.
// Buffers are truncated rather than dropped so their capacity is kept for the
// next use. h must not be used after this call.
func (h *handshake) release() {
	// Reset the reusable protobuf messages.
	cred := h.remoteCred
	cred.Type = 0
	cred.Payload = cred.Payload[:0]
	h.remoteAck.Error = 0
	h.localAck.Error = 0

	// Drop the connection reference and empty the scratch buffer.
	h.conn = nil
	h.buf = h.buf[:0]

	handshakePool.Put(h)
}