Merge pull request #24 from anytypeio/any-handshake

Any handshake
This commit is contained in:
Sergey Cherepanov 2023-02-17 11:15:40 +03:00 committed by GitHub
commit 96f399293d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 2184 additions and 124 deletions

View File

@ -16,6 +16,7 @@ proto:
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonspace/spacesyncproto/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonspace/spacesyncproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonfile/fileproto/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonfile/fileproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto
protoc --gogofaster_out=:. net/secureservice/handshake/handshakeproto/protos/*.proto
deps: deps:
go mod download go mod download

View File

@ -74,6 +74,9 @@ func (d *dialer) SetPeerAddrs(peerId string, addrs []string) {
} }
func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) { func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) {
var ctxCancel context.CancelFunc
ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10)
defer ctxCancel()
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
@ -109,7 +112,7 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc
} }
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
sc, err = d.transport.TLSConn(ctx, timeoutConn) sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
} }

View File

@ -11,9 +11,13 @@ type contextKey uint
const ( const (
contextKeyPeerId contextKey = iota contextKeyPeerId contextKey = iota
contextKeyIdentity
) )
var ErrPeerIdNotFoundInContext = errors.New("peer id not found in context") var (
ErrPeerIdNotFoundInContext = errors.New("peer id not found in context")
ErrIdentityNotFoundInContext = errors.New("identity not found in context")
)
// CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key // CtxPeerId first tries to get peer id under our own key, but if it is not found tries to get through DRPC key
func CtxPeerId(ctx context.Context) (string, error) { func CtxPeerId(ctx context.Context) (string, error) {
@ -30,3 +34,16 @@ func CtxPeerId(ctx context.Context) (string, error) {
func CtxWithPeerId(ctx context.Context, peerId string) context.Context { func CtxWithPeerId(ctx context.Context, peerId string) context.Context {
return context.WithValue(ctx, contextKeyPeerId, peerId) return context.WithValue(ctx, contextKeyPeerId, peerId)
} }
// CtxIdentity returns identity from context
func CtxIdentity(ctx context.Context) ([]byte, error) {
if identity, ok := ctx.Value(contextKeyIdentity).([]byte); ok {
return identity, nil
}
return nil, ErrIdentityNotFoundInContext
}
// CtxWithIdentity sets identity in the context
func CtxWithIdentity(ctx context.Context, identity []byte) context.Context {
return context.WithValue(ctx, contextKeyIdentity, identity)
}

View File

@ -3,6 +3,7 @@ package server
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/net/secureservice" "github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs" "github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"io" "io"
@ -18,19 +19,18 @@ import (
type BaseDrpcServer struct { type BaseDrpcServer struct {
drpcServer *drpcserver.Server drpcServer *drpcserver.Server
transport secureservice.SecureService transport secureservice.SecureService
listeners []secureservice.ContextListener listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func() cancel func()
*drpcmux.Mux *drpcmux.Mux
} }
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type ListenerConverter func(listener net.Listener, timeoutMillis int) secureservice.ContextListener
type Params struct { type Params struct {
BufferSizeMb int BufferSizeMb int
ListenAddrs []string ListenAddrs []string
Wrapper DRPCHandlerWrapper Wrapper DRPCHandlerWrapper
Converter ListenerConverter
TimeoutMillis int TimeoutMillis int
} }
@ -44,18 +44,17 @@ func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
}}) }})
ctx, s.cancel = context.WithCancel(ctx) ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs { for _, addr := range params.ListenAddrs {
tcpList, err := net.Listen("tcp", addr) list, err := net.Listen("tcp", addr)
if err != nil { if err != nil {
return err return err
} }
tlsList := params.Converter(tcpList, params.TimeoutMillis) s.listeners = append(s.listeners, list)
s.listeners = append(s.listeners, tlsList) go s.serve(ctx, list)
go s.serve(ctx, tlsList)
} }
return return
} }
func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextListener) { func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String())) l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started") l.Info("drpc listener started")
defer func() { defer func() {
@ -67,7 +66,7 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
return return
default: default:
} }
cctx, conn, err := lis.Accept(ctx) conn, err := lis.Accept()
if err != nil { if err != nil {
if isTemporary(err) { if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err)) l.Debug("listener temporary accept error", zap.Error(err))
@ -85,12 +84,23 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis
l.Error("listener accept error", zap.Error(err)) l.Error("listener accept error", zap.Error(err))
return return
} }
go s.serveConn(cctx, conn) go s.serveConn(conn)
} }
} }
func (s *BaseDrpcServer) serveConn(ctx context.Context, conn net.Conn) { func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
}
}
l.Debug("connection opened") l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil { if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {

View File

@ -5,10 +5,13 @@ import (
"github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/metric" "github.com/anytypeio/any-sync/metric"
"github.com/anytypeio/any-sync/net" anyNet "github.com/anytypeio/any-sync/net"
"github.com/anytypeio/any-sync/net/secureservice" "github.com/anytypeio/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"net"
"storj.io/drpc" "storj.io/drpc"
"time"
) )
const CName = "common.net.drpcserver" const CName = "common.net.drpcserver"
@ -25,14 +28,14 @@ type DRPCServer interface {
} }
type drpcServer struct { type drpcServer struct {
config net.Config config anyNet.Config
metric metric.Metric metric metric.Metric
transport secureservice.SecureService transport secureservice.SecureService
*BaseDrpcServer *BaseDrpcServer
} }
func (s *drpcServer) Init(a *app.App) (err error) { func (s *drpcServer) Init(a *app.App) (err error) {
s.config = a.MustComponent("config").(net.ConfigGetter).GetNet() s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
s.metric = a.MustComponent(metric.CName).(metric.Metric) s.metric = a.MustComponent(metric.CName).(metric.Metric)
s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
return nil return nil
@ -67,7 +70,11 @@ func (s *drpcServer) Run(ctx context.Context) (err error) {
SummaryVec: histVec, SummaryVec: histVec,
} }
}, },
Converter: s.transport.TLSListener, }
s.handshake = func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.transport.SecureInbound(ctx, conn)
} }
return s.BaseDrpcServer.Run(ctx, params) return s.BaseDrpcServer.Run(ctx, params)
} }

View File

@ -1,26 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/timeoutconn"
"net"
"time"
)
type basicListener struct {
net.Listener
timeoutMillis int
}
func newBasicListener(listener net.Listener, timeoutMillis int) ContextListener {
return &basicListener{listener, timeoutMillis}
}
func (b *basicListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := b.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(b.timeoutMillis)*time.Millisecond)
return ctx, timeoutConn, err
}

View File

@ -0,0 +1,72 @@
package secureservice
import (
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
)
func newNoVerifyChecker() handshake.CredentialChecker {
return &noVerifyChecker{cred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}}
}
type noVerifyChecker struct {
cred *handshakeproto.Credentials
}
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
return n.cred
}
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
return nil, nil
}
func newPeerSignVerifier(account *accountdata.AccountData) handshake.CredentialChecker {
return &peerSignVerifier{account: account}
}
type peerSignVerifier struct {
account *accountdata.AccountData
}
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String()))
if err != nil {
log.Warn("can't sign identity credentials", zap.Error(err))
}
msg := &handshakeproto.PayloadSignedPeerIds{
Identity: p.account.Identity,
Sign: sign,
}
payload, _ := msg.Marshal()
return &handshakeproto.Credentials{
Type: handshakeproto.CredentialsType_SignedPeerIds,
Payload: payload,
}
}
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Type != handshakeproto.CredentialsType_SignedPeerIds {
return nil, handshake.ErrSkipVerifyNotAllowed
}
var msg = &handshakeproto.PayloadSignedPeerIds{}
if err = msg.Unmarshal(cred.Payload); err != nil {
return nil, handshake.ErrUnexpectedPayload
}
pubKey, err := signingkey.NewSigningEd25519PubKeyFromBytes(msg.Identity)
if err != nil {
return nil, handshake.ErrInvalidCredentials
}
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign)
if err != nil {
return nil, err
}
if !ok {
return nil, handshake.ErrInvalidCredentials
}
return msg.Identity, nil
}

View File

@ -0,0 +1,77 @@
package secureservice
import (
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/testutil/accounttest"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"testing"
)
func TestPeerSignVerifier_CheckCredential(t *testing.T) {
a1 := newTestAccData(t)
a2 := newTestAccData(t)
cc1 := newPeerSignVerifier(a1)
cc2 := newPeerSignVerifier(a2)
c1 := newTestSC(a2.PeerId)
c2 := newTestSC(a1.PeerId)
cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2)
id1, err := cc1.CheckCredential(c1, cr2)
assert.NoError(t, err)
assert.Equal(t, a2.Identity, id1)
id2, err := cc2.CheckCredential(c2, cr1)
assert.NoError(t, err)
assert.Equal(t, a1.Identity, id2)
_, err = cc1.CheckCredential(c1, cr1)
assert.EqualError(t, err, handshake.ErrInvalidCredentials.Error())
}
func newTestAccData(t *testing.T) *accountdata.AccountData {
as := accounttest.AccountTestService{}
require.NoError(t, as.Init(nil))
return as.Account()
}
func newTestSC(peerId string) sec.SecureConn {
pid, _ := peer.Decode(peerId)
return &testSc{
ID: pid,
}
}
type testSc struct {
net.Conn
peer.ID
}
func (t *testSc) LocalPeer() peer.ID {
return ""
}
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
return nil
}
func (t *testSc) RemotePeer() peer.ID {
return t.ID
}
func (t *testSc) RemotePublicKey() crypto.PubKey {
return nil
}
func (t *testSc) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View File

@ -0,0 +1,282 @@
package handshake
import (
"context"
"encoding/binary"
"errors"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"sync"
)
const headerSize = 5 // 1 byte for type + 4 byte for uint32 size
const (
msgTypeCred = byte(1)
msgTypeAck = byte(2)
)
type handshakeError struct {
e handshakeproto.Error
}
func (he handshakeError) Error() string {
return he.e.String()
}
var (
ErrUnexpectedPayload = handshakeError{handshakeproto.Error_UnexpectedPayload}
ErrDeadlineExceeded = handshakeError{handshakeproto.Error_DeadlineExceeded}
ErrInvalidCredentials = handshakeError{handshakeproto.Error_InvalidCredentials}
ErrPeerDeclinedCredentials = errors.New("remote peer declined the credentials")
ErrSkipVerifyNotAllowed = handshakeError{handshakeproto.Error_SkipVerifyNotAllowed}
ErrUnexpected = handshakeError{handshakeproto.Error_Unexpected}
ErrGotNotAHandshakeMessage = errors.New("go not a handshake message")
)
var handshakePool = &sync.Pool{New: func() any {
return &handshake{
remoteCred: &handshakeproto.Credentials{},
remoteAck: &handshakeproto.Ack{},
localAck: &handshakeproto.Ack{},
buf: make([]byte, 0, 1024),
}
}}
type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, handshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, handshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
return nil, ErrUnexpectedPayload
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, handshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}
func newHandshake() *handshake {
return handshakePool.Get().(*handshake)
}
type handshake struct {
conn sec.SecureConn
remoteCred *handshakeproto.Credentials
remoteAck *handshakeproto.Ack
localAck *handshakeproto.Ack
buf []byte
}
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
h.buf = slices.Grow(h.buf, cred.Size()+headerSize)[:cred.Size()+headerSize]
n, err := cred.MarshalToSizedBuffer(h.buf[headerSize:])
if err != nil {
return err
}
return h.writeData(msgTypeCred, n)
}
func (h *handshake) tryWriteErrAndClose(err error) {
if err == ErrGotNotAHandshakeMessage {
// if we got unexpected message - just close the connection
_ = h.conn.Close()
return
}
var ackErr handshakeproto.Error
if he, ok := err.(handshakeError); ok {
ackErr = he.e
} else {
ackErr = handshakeproto.Error_Unexpected
}
_ = h.writeAck(ackErr)
_ = h.conn.Close()
}
func (h *handshake) writeAck(ackErr handshakeproto.Error) (err error) {
h.localAck.Error = ackErr
h.buf = slices.Grow(h.buf, h.localAck.Size()+headerSize)[:h.localAck.Size()+headerSize]
n, err := h.localAck.MarshalTo(h.buf[headerSize:])
if err != nil {
return err
}
return h.writeData(msgTypeAck, n)
}
func (h *handshake) writeData(tp byte, size int) (err error) {
h.buf[0] = tp
binary.LittleEndian.PutUint32(h.buf[1:headerSize], uint32(size))
_, err = h.conn.Write(h.buf[:size+headerSize])
return err
}
type message struct {
cred *handshakeproto.Credentials
ack *handshakeproto.Ack
}
func (h *handshake) readMsg() (msg message, err error) {
h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
return
}
tp := h.buf[0]
if tp != msgTypeCred && tp != msgTypeAck {
err = ErrGotNotAHandshakeMessage
return
}
size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
h.buf = slices.Grow(h.buf, int(size))[:size]
if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
return
}
switch tp {
case msgTypeCred:
if err = h.remoteCred.Unmarshal(h.buf[:size]); err != nil {
return
}
msg.cred = h.remoteCred
case msgTypeAck:
if err = h.remoteAck.Unmarshal(h.buf[:size]); err != nil {
return
}
msg.ack = h.remoteAck
}
return
}
func (h *handshake) release() {
h.buf = h.buf[:0]
h.conn = nil
h.localAck.Error = 0
h.remoteAck.Error = 0
h.remoteCred.Type = 0
h.remoteCred.Payload = h.remoteCred.Payload[:0]
handshakePool.Put(h)
}

View File

@ -0,0 +1,615 @@
package handshake
import (
"context"
"github.com/anytypeio/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
peer2 "github.com/anytypeio/any-sync/util/peer"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"net/http"
_ "net/http/pprof"
"testing"
"time"
)
func init() {
go http.ListenAndServe(":6060", nil)
}
var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil
},
}
type handshakeRes struct {
identity []byte
err error
}
func TestOutgoingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
msg, err = h.readMsg()
require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-handshakeResCh
assert.NotEmpty(t, res.identity)
assert.NoError(t, res.err)
})
t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
res := <-handshakeResCh
require.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
})
t.Run("cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg()
require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
})
t.Run("write ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack and close conn
_, err = h.readMsg()
require.NoError(t, err)
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read ack
_, err = h.readMsg()
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
msg, err := h.readMsg()
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// receive ack
msg, err = h.readMsg()
require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack
require.NoError(t, h.writeAck(handshakeproto.Error_UnexpectedPayload))
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// receive credential message
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
assert.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
}
func TestIncomingHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
// wait ack
msg, err = h.readMsg()
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-handshakeResCh
assert.NotEmpty(t, res.identity)
require.NoError(t, res.err)
})
t.Run("write cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read cred err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write ack instead cred", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write ack instead cred
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("invalid cred", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// except ack with error
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
res := <-handshakeResCh
require.EqualError(t, res.err, ErrInvalidCredentials.Error())
})
t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred
_, err := h.readMsg()
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// expect ack with error
msg, err := h.readMsg()
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("read ack err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// read cred and close conn
_, err := h.readMsg()
require.NoError(t, err)
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("write ack with invalid", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_InvalidCredentials))
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrPeerDeclinedCredentials.Error())
})
t.Run("write ack with err", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Unexpected))
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrUnexpected.Error())
})
t.Run("final ack error", func(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
msg, err := h.readMsg()
require.NoError(t, err)
require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack and close conn
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
})
t.Run("context cancel", func(t *testing.T) {
var ctx, ctxCancel = context.WithCancel(context.Background())
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
// wait credentials
_, err := h.readMsg()
require.NoError(t, err)
ctxCancel()
res := <-handshakeResCh
require.EqualError(t, res.err, context.Canceled.Error())
_, err = c2.Read(make([]byte, 10))
assert.Error(t, err)
_, err = c2.Write(make([]byte, 10))
assert.Error(t, err)
})
}
func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
_, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err)
res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error())
}
func TestEndToEnd(t *testing.T) {
c1, c2 := newConnPair(t)
var (
inResCh = make(chan handshakeRes, 1)
outResCh = make(chan handshakeRes, 1)
)
st := time.Now()
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}()
go func() {
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}()
outRes := <-outResCh
assert.NoError(t, outRes.err)
assert.NotEmpty(t, outRes.identity)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.NotEmpty(t, inRes.identity)
t.Log("dur", time.Since(st))
}
func BenchmarkHandshake(b *testing.B) {
c1, c2 := newConnPair(b)
var (
inRes = make(chan struct{})
outRes = make(chan struct{})
done = make(chan struct{})
)
defer close(done)
go func() {
for {
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
select {
case outRes <- struct{}{}:
case <-done:
return
}
}
}()
go func() {
for {
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
select {
case inRes <- struct{}{}:
case <-done:
return
}
}
}()
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
<-outRes
<-inRes
}
}
type testCredChecker struct {
makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error
}
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
return t.makeCred
}
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil {
return nil, t.checkErr
}
if t.checkCred != nil {
return t.checkCred(sc, cred)
}
return nil, nil
}
func newConnPair(t require.TestingT) (sc1, sc2 *secConn) {
c1, c2 := net.Pipe()
sk1, _, err := signingkey.GenerateRandomEd25519KeyPair()
require.NoError(t, err)
sk1b, err := sk1.Raw()
signKey1, err := crypto.UnmarshalEd25519PrivateKey(sk1b)
require.NoError(t, err)
sk2, _, err := signingkey.GenerateRandomEd25519KeyPair()
require.NoError(t, err)
sk2b, err := sk2.Raw()
signKey2, err := crypto.UnmarshalEd25519PrivateKey(sk2b)
require.NoError(t, err)
peerId1, err := peer2.IdFromSigningPubKey(sk1.GetPublic())
require.NoError(t, err)
peerId2, err := peer2.IdFromSigningPubKey(sk2.GetPublic())
require.NoError(t, err)
sc1 = &secConn{
Conn: c1,
localKey: signKey1,
remotePeer: peerId2,
}
sc2 = &secConn{
Conn: c2,
localKey: signKey2,
remotePeer: peerId1,
}
return
}
type secConn struct {
net.Conn
localKey crypto.PrivKey
remotePeer peer.ID
}
func (s *secConn) LocalPeer() peer.ID {
skB, _ := s.localKey.Raw()
sk, _ := signingkey.NewSigningEd25519PubKeyFromBytes(skB)
lp, _ := peer2.IdFromSigningPubKey(sk)
return lp
}
func (s *secConn) LocalPrivateKey() crypto.PrivKey {
return s.localKey
}
func (s *secConn) RemotePeer() peer.ID {
return s.remotePeer
}
func (s *secConn) RemotePublicKey() crypto.PubKey {
return nil
}
func (s *secConn) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View File

@ -0,0 +1,813 @@
// Code generated by protoc-gen-gogo. DO NOT EDIT.
// source: net/secureservice/handshake/handshakeproto/protos/handshake.proto
package handshakeproto
import (
fmt "fmt"
proto "github.com/gogo/protobuf/proto"
io "io"
math "math"
math_bits "math/bits"
)
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.GoGoProtoPackageIsVersion3 // please upgrade the proto package
type CredentialsType int32
const (
// SkipVerify using when identity is not required, for example in p2p cases
CredentialsType_SkipVerify CredentialsType = 0
// SignedPeerIds using a payload containing PayloadSignedPeerIds message
CredentialsType_SignedPeerIds CredentialsType = 1
)
var CredentialsType_name = map[int32]string{
0: "SkipVerify",
1: "SignedPeerIds",
}
var CredentialsType_value = map[string]int32{
"SkipVerify": 0,
"SignedPeerIds": 1,
}
func (x CredentialsType) String() string {
return proto.EnumName(CredentialsType_name, int32(x))
}
func (CredentialsType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{0}
}
type Error int32
const (
Error_Null Error = 0
Error_Unexpected Error = 1
Error_InvalidCredentials Error = 2
Error_UnexpectedPayload Error = 3
Error_SkipVerifyNotAllowed Error = 4
Error_DeadlineExceeded Error = 5
)
var Error_name = map[int32]string{
0: "Null",
1: "Unexpected",
2: "InvalidCredentials",
3: "UnexpectedPayload",
4: "SkipVerifyNotAllowed",
5: "DeadlineExceeded",
}
var Error_value = map[string]int32{
"Null": 0,
"Unexpected": 1,
"InvalidCredentials": 2,
"UnexpectedPayload": 3,
"SkipVerifyNotAllowed": 4,
"DeadlineExceeded": 5,
}
func (x Error) String() string {
return proto.EnumName(Error_name, int32(x))
}
func (Error) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{1}
}
type Credentials struct {
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
}
func (m *Credentials) Reset() { *m = Credentials{} }
func (m *Credentials) String() string { return proto.CompactTextString(m) }
func (*Credentials) ProtoMessage() {}
func (*Credentials) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{0}
}
func (m *Credentials) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Credentials) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Credentials.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Credentials) XXX_Merge(src proto.Message) {
xxx_messageInfo_Credentials.Merge(m, src)
}
func (m *Credentials) XXX_Size() int {
return m.Size()
}
func (m *Credentials) XXX_DiscardUnknown() {
xxx_messageInfo_Credentials.DiscardUnknown(m)
}
var xxx_messageInfo_Credentials proto.InternalMessageInfo
func (m *Credentials) GetType() CredentialsType {
if m != nil {
return m.Type
}
return CredentialsType_SkipVerify
}
func (m *Credentials) GetPayload() []byte {
if m != nil {
return m.Payload
}
return nil
}
type PayloadSignedPeerIds struct {
// account identity
Identity []byte `protobuf:"bytes,1,opt,name=identity,proto3" json:"identity,omitempty"`
// sign of (localPeerId + remotePeerId)
Sign []byte `protobuf:"bytes,2,opt,name=sign,proto3" json:"sign,omitempty"`
}
func (m *PayloadSignedPeerIds) Reset() { *m = PayloadSignedPeerIds{} }
func (m *PayloadSignedPeerIds) String() string { return proto.CompactTextString(m) }
func (*PayloadSignedPeerIds) ProtoMessage() {}
func (*PayloadSignedPeerIds) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{1}
}
func (m *PayloadSignedPeerIds) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *PayloadSignedPeerIds) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_PayloadSignedPeerIds.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *PayloadSignedPeerIds) XXX_Merge(src proto.Message) {
xxx_messageInfo_PayloadSignedPeerIds.Merge(m, src)
}
func (m *PayloadSignedPeerIds) XXX_Size() int {
return m.Size()
}
func (m *PayloadSignedPeerIds) XXX_DiscardUnknown() {
xxx_messageInfo_PayloadSignedPeerIds.DiscardUnknown(m)
}
var xxx_messageInfo_PayloadSignedPeerIds proto.InternalMessageInfo
func (m *PayloadSignedPeerIds) GetIdentity() []byte {
if m != nil {
return m.Identity
}
return nil
}
func (m *PayloadSignedPeerIds) GetSign() []byte {
if m != nil {
return m.Sign
}
return nil
}
type Ack struct {
Error Error `protobuf:"varint,1,opt,name=error,proto3,enum=anyHandshake.Error" json:"error,omitempty"`
}
func (m *Ack) Reset() { *m = Ack{} }
func (m *Ack) String() string { return proto.CompactTextString(m) }
func (*Ack) ProtoMessage() {}
func (*Ack) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{2}
}
func (m *Ack) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Ack) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Ack.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Ack) XXX_Merge(src proto.Message) {
xxx_messageInfo_Ack.Merge(m, src)
}
func (m *Ack) XXX_Size() int {
return m.Size()
}
func (m *Ack) XXX_DiscardUnknown() {
xxx_messageInfo_Ack.DiscardUnknown(m)
}
var xxx_messageInfo_Ack proto.InternalMessageInfo
func (m *Ack) GetError() Error {
if m != nil {
return m.Error
}
return Error_Null
}
func init() {
proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value)
proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value)
proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials")
proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds")
proto.RegisterType((*Ack)(nil), "anyHandshake.Ack")
}
func init() {
proto.RegisterFile("net/secureservice/handshake/handshakeproto/protos/handshake.proto", fileDescriptor_60283fc75f020893)
}
var fileDescriptor_60283fc75f020893 = []byte{
// 362 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x51, 0xcd, 0x8e, 0xda, 0x30,
0x18, 0x8c, 0x21, 0xb4, 0xe8, 0x2b, 0xa5, 0xc6, 0xa5, 0x55, 0x54, 0xa9, 0x11, 0xe2, 0x44, 0x39,
0x40, 0xff, 0x5e, 0x80, 0x16, 0xaa, 0x72, 0x41, 0x28, 0xb4, 0x3d, 0x70, 0x73, 0xe3, 0xaf, 0x60,
0x61, 0x39, 0x91, 0x13, 0x28, 0xb9, 0xed, 0x23, 0xec, 0x63, 0xed, 0x91, 0xe3, 0x1e, 0x57, 0xf0,
0x22, 0x2b, 0x0c, 0x2c, 0x61, 0x4f, 0x7b, 0xb1, 0xe7, 0x1b, 0x8f, 0x67, 0xc6, 0x32, 0xf4, 0x34,
0xa6, 0xdd, 0x04, 0xc3, 0xa5, 0xc1, 0x04, 0xcd, 0x4a, 0x86, 0xd8, 0x9d, 0x73, 0x2d, 0x92, 0x39,
0x5f, 0xe4, 0x50, 0x6c, 0xa2, 0x34, 0xea, 0xda, 0x35, 0x39, 0xb3, 0x1d, 0x4b, 0xb0, 0x0a, 0xd7,
0xd9, 0xcf, 0x13, 0xd7, 0x9c, 0xc2, 0x8b, 0xef, 0x06, 0x05, 0xea, 0x54, 0x72, 0x95, 0xb0, 0x4f,
0xe0, 0xa6, 0x59, 0x8c, 0x1e, 0x69, 0x90, 0x56, 0xf5, 0xf3, 0xfb, 0x4e, 0x5e, 0xdb, 0xc9, 0x09,
0x7f, 0x65, 0x31, 0x06, 0x56, 0xca, 0x3c, 0x78, 0x1e, 0xf3, 0x4c, 0x45, 0x5c, 0x78, 0x85, 0x06,
0x69, 0x55, 0x82, 0xd3, 0xd8, 0xfc, 0x01, 0xf5, 0xf1, 0x01, 0x4e, 0xe4, 0x4c, 0xa3, 0x18, 0x23,
0x9a, 0xa1, 0x48, 0xd8, 0x3b, 0x28, 0x4b, 0x6b, 0x94, 0x66, 0x36, 0xa8, 0x12, 0x3c, 0xcc, 0x8c,
0x81, 0x9b, 0xc8, 0x99, 0x3e, 0x5a, 0x59, 0xdc, 0xfc, 0x08, 0xc5, 0x5e, 0xb8, 0x60, 0x1f, 0xa0,
0x84, 0xc6, 0x44, 0xe6, 0x58, 0xee, 0xf5, 0x65, 0xb9, 0xc1, 0xfe, 0x28, 0x38, 0x28, 0xda, 0x5f,
0xe1, 0xd5, 0xa3, 0xb2, 0xac, 0x0a, 0x30, 0x59, 0xc8, 0xf8, 0x0f, 0x1a, 0xf9, 0x2f, 0xa3, 0x0e,
0xab, 0xc1, 0xcb, 0x8b, 0x56, 0x94, 0xb4, 0xaf, 0x08, 0x94, 0xac, 0x0d, 0x2b, 0x83, 0x3b, 0x5a,
0x2a, 0x45, 0x9d, 0xfd, 0xb5, 0xdf, 0x1a, 0xd7, 0x31, 0x86, 0x29, 0x0a, 0x4a, 0xd8, 0x5b, 0x60,
0x43, 0xbd, 0xe2, 0x4a, 0x8a, 0x5c, 0x00, 0x2d, 0xb0, 0x37, 0x50, 0x3b, 0xeb, 0x8e, 0xaf, 0xa6,
0x45, 0xe6, 0x41, 0xfd, 0x9c, 0x3a, 0x8a, 0xd2, 0x9e, 0x52, 0xd1, 0x7f, 0x14, 0xd4, 0x65, 0x75,
0xa0, 0x7d, 0xe4, 0x42, 0x49, 0x8d, 0x83, 0x75, 0x88, 0x28, 0x50, 0xd0, 0xd2, 0xb7, 0xfe, 0xcd,
0xd6, 0x27, 0x9b, 0xad, 0x4f, 0xee, 0xb6, 0x3e, 0xb9, 0xde, 0xf9, 0xce, 0x66, 0xe7, 0x3b, 0xb7,
0x3b, 0xdf, 0x99, 0xb6, 0x9f, 0xfe, 0xf3, 0x7f, 0x9f, 0xd9, 0xed, 0xcb, 0x7d, 0x00, 0x00, 0x00,
0xff, 0xff, 0xa0, 0x48, 0xdf, 0x7a, 0x2e, 0x02, 0x00, 0x00,
}
func (m *Credentials) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Credentials) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Credentials) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Payload) > 0 {
i -= len(m.Payload)
copy(dAtA[i:], m.Payload)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Payload)))
i--
dAtA[i] = 0x12
}
if m.Type != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Type))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func (m *PayloadSignedPeerIds) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *PayloadSignedPeerIds) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *PayloadSignedPeerIds) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if len(m.Sign) > 0 {
i -= len(m.Sign)
copy(dAtA[i:], m.Sign)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Sign)))
i--
dAtA[i] = 0x12
}
if len(m.Identity) > 0 {
i -= len(m.Identity)
copy(dAtA[i:], m.Identity)
i = encodeVarintHandshake(dAtA, i, uint64(len(m.Identity)))
i--
dAtA[i] = 0xa
}
return len(dAtA) - i, nil
}
func (m *Ack) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Ack) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.Error != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Error))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
offset -= sovHandshake(v)
base := offset
for v >= 1<<7 {
dAtA[offset] = uint8(v&0x7f | 0x80)
v >>= 7
offset++
}
dAtA[offset] = uint8(v)
return base
}
func (m *Credentials) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Type != 0 {
n += 1 + sovHandshake(uint64(m.Type))
}
l = len(m.Payload)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
return n
}
func (m *PayloadSignedPeerIds) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
l = len(m.Identity)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
l = len(m.Sign)
if l > 0 {
n += 1 + l + sovHandshake(uint64(l))
}
return n
}
func (m *Ack) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Error != 0 {
n += 1 + sovHandshake(uint64(m.Error))
}
return n
}
func sovHandshake(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7
}
func sozHandshake(x uint64) (n int) {
return sovHandshake(uint64((x << 1) ^ uint64((int64(x) >> 63))))
}
func (m *Credentials) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Credentials: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Credentials: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Type", wireType)
}
m.Type = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Type |= CredentialsType(b&0x7F) << shift
if b < 0x80 {
break
}
}
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Payload", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Payload = append(m.Payload[:0], dAtA[iNdEx:postIndex]...)
if m.Payload == nil {
m.Payload = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *PayloadSignedPeerIds) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: PayloadSignedPeerIds: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: PayloadSignedPeerIds: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Identity", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Identity = append(m.Identity[:0], dAtA[iNdEx:postIndex]...)
if m.Identity == nil {
m.Identity = []byte{}
}
iNdEx = postIndex
case 2:
if wireType != 2 {
return fmt.Errorf("proto: wrong wireType = %d for field Sign", wireType)
}
var byteLen int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
byteLen |= int(b&0x7F) << shift
if b < 0x80 {
break
}
}
if byteLen < 0 {
return ErrInvalidLengthHandshake
}
postIndex := iNdEx + byteLen
if postIndex < 0 {
return ErrInvalidLengthHandshake
}
if postIndex > l {
return io.ErrUnexpectedEOF
}
m.Sign = append(m.Sign[:0], dAtA[iNdEx:postIndex]...)
if m.Sign == nil {
m.Sign = []byte{}
}
iNdEx = postIndex
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func (m *Ack) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Ack: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Ack: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Error", wireType)
}
m.Error = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Error |= Error(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipHandshake(dAtA []byte) (n int, err error) {
l := len(dAtA)
iNdEx := 0
depth := 0
for iNdEx < l {
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= (uint64(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
wireType := int(wire & 0x7)
switch wireType {
case 0:
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
iNdEx++
if dAtA[iNdEx-1] < 0x80 {
break
}
}
case 1:
iNdEx += 8
case 2:
var length int
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return 0, ErrIntOverflowHandshake
}
if iNdEx >= l {
return 0, io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
length |= (int(b) & 0x7F) << shift
if b < 0x80 {
break
}
}
if length < 0 {
return 0, ErrInvalidLengthHandshake
}
iNdEx += length
case 3:
depth++
case 4:
if depth == 0 {
return 0, ErrUnexpectedEndOfGroupHandshake
}
depth--
case 5:
iNdEx += 4
default:
return 0, fmt.Errorf("proto: illegal wireType %d", wireType)
}
if iNdEx < 0 {
return 0, ErrInvalidLengthHandshake
}
if depth == 0 {
return iNdEx, nil
}
}
return 0, io.ErrUnexpectedEOF
}
var (
ErrInvalidLengthHandshake = fmt.Errorf("proto: negative length found during unmarshaling")
ErrIntOverflowHandshake = fmt.Errorf("proto: integer overflow")
ErrUnexpectedEndOfGroupHandshake = fmt.Errorf("proto: unexpected end of group")
)

View File

@ -0,0 +1,69 @@
syntax = "proto3";
package anyHandshake;
option go_package = "net/secureservice/handshake/handshakeproto";
/*
Alice opens a new connection with Bob
1. TLS handshake done successfully; both sides know local and remote peer identifiers.
2. Alice sends a Credentials message to Bob
3. Bob receives Alice's message and validates her credentials
3.1 If credentials are valid, Bob sends his credentials to Alice
3.2 If credentials are invalid, Bob sends an Ack message with an error and closes the connection
4. Alice receives Bob's message
4.1 If it is a credentials message, Alice validates it
4.1.1 If credentials are valid, Alice sends Ack message with error=Null
4.1.2 If credentials are invalid, Alice sends an Ack message with an error and closes the connection
4.2 If it is an Ack message, Alice has an error about why the handshake was unsuccessful
5. Bob receives an Ack message from Alice
5.1 If error == Null, Bob sends Ack with error=Null to Alice - handshake successful
5.2 If error != Null, Bob has an error about why the handshake was unsuccessful
Successful handshake scheme:
Alice -> [CREDENTIALS] -> Bob
Bob -> [CREDENTIALS] -> Alice
Alice -> [Ack:Error=Null] -> Bob
Bob -> [Ack:Error=Null] -> Alice
*/
message Credentials {
CredentialsType type = 1;
bytes payload = 2;
}
enum CredentialsType {
// SkipVerify using when identity is not required, for example in p2p cases
SkipVerify = 0;
// SignedPeerIds using a payload containing PayloadSignedPeerIds message
SignedPeerIds = 1;
}
message PayloadSignedPeerIds {
// account identity
bytes identity = 1;
// sign of (localPeerId + remotePeerId)
bytes sign = 2;
}
message Ack {
Error error = 1;
}
enum Error {
Null = 0;
Unexpected = 1;
InvalidCredentials = 2;
UnexpectedPayload = 3;
SkipVerifyNotAllowed = 4;
DeadlineExceeded = 5;
}

View File

@ -5,6 +5,10 @@ import (
commonaccount "github.com/anytypeio/any-sync/accountservice" commonaccount "github.com/anytypeio/any-sync/accountservice"
"github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/secureservice/handshake"
"github.com/anytypeio/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
@ -34,14 +38,20 @@ func New() SecureService {
} }
type SecureService interface { type SecureService interface {
TLSListener(lis net.Listener, timeoutMillis int) ContextListener SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
BasicListener(lis net.Listener, timeoutMillis int) ContextListener SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error)
app.Component app.Component
} }
type secureService struct { type secureService struct {
p2pTr *libp2ptls.Transport
account *accountdata.AccountData
key crypto.PrivKey key crypto.PrivKey
nodeconf nodeconf.Service
noVerifyChecker handshake.CredentialChecker
peerSignVerifier handshake.CredentialChecker
inboundChecker handshake.CredentialChecker
} }
func (s *secureService) Init(a *app.App) (err error) { func (s *secureService) Init(a *app.App) (err error) {
@ -54,8 +64,23 @@ func (s *secureService) Init(a *app.App) (err error) {
return return
} }
log.Info("secure service init", zap.String("peerId", account.Account().PeerId)) s.noVerifyChecker = newNoVerifyChecker()
s.peerSignVerifier = newPeerSignVerifier(account.Account())
s.nodeconf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
s.inboundChecker = s.noVerifyChecker
confTypes := s.nodeconf.GetLast().NodeTypes(account.Account().PeerId)
if len(confTypes) > 0 {
// require identity verification if we are node
s.inboundChecker = s.peerSignVerifier
}
if s.p2pTr, err = libp2ptls.New(libp2ptls.ID, s.key, nil); err != nil {
return
}
log.Info("secure service init", zap.String("peerId", account.Account().PeerId))
return nil return nil
} }
@ -63,18 +88,45 @@ func (s *secureService) Name() (name string) {
return CName return CName
} }
func (s *secureService) TLSListener(lis net.Listener, timeoutMillis int) ContextListener { func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
return newTLSListener(s.key, lis, timeoutMillis) sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
}
func (s *secureService) BasicListener(lis net.Listener, timeoutMillis int) ContextListener {
return newBasicListener(lis, timeoutMillis)
}
func (s *secureService) TLSConn(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
tr, err := libp2ptls.New(libp2ptls.ID, s.key, nil)
if err != nil { if err != nil {
return nil, err return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
} }
return tr.SecureOutbound(ctx, conn, "") }
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return
}
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
}
peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.GetLast().NodeTypes(peerId)
var checker handshake.CredentialChecker
if len(confTypes) > 0 {
checker = s.peerSignVerifier
} else {
checker = s.noVerifyChecker
}
// ignore identity for outgoing connection because we don't need it at this moment
_, err = handshake.OutgoingHandshake(ctx, sc, checker)
if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()}
}
return sc, nil
} }

View File

@ -0,0 +1,73 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/accountservice"
"github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/nodeconf"
"github.com/anytypeio/any-sync/testutil/testnodeconf"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"testing"
)
var ctx = context.Background()
func TestHandshake(t *testing.T) {
nc := testnodeconf.GenNodeConfig(2)
fxS := newFixture(t, nc, nc.GetAccountService(0))
defer fxS.Finish(t)
sc, cc := net.Pipe()
type acceptRes struct {
ctx context.Context
conn net.Conn
err error
}
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1))
defer fxC.Finish(t)
secConn, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
res := <-resCh
require.NoError(t, res.err)
peerId, err := peer.CtxPeerId(res.ctx)
require.NoError(t, err)
accId, err := peer.CtxIdentity(res.ctx)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(1).Account().PeerId, peerId)
assert.Equal(t, nc.GetAccountService(1).Account().Identity, accId)
}
func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Service) *fixture {
fx := &fixture{
secureService: New().(*secureService),
acc: acc,
a: new(app.App),
}
fx.a.Register(fx.acc).Register(nc).Register(nodeconf.New()).Register(fx.secureService)
require.NoError(t, fx.a.Start(ctx))
return fx
}
type fixture struct {
*secureService
a *app.App
acc accountservice.Service
}
func (fx *fixture) Finish(t *testing.T) {
require.NoError(t, fx.a.Close(ctx))
}

View File

@ -1,59 +0,0 @@
package secureservice
import (
"context"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/timeoutconn"
"github.com/libp2p/go-libp2p/core/crypto"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"net"
"time"
)
type ContextListener interface {
// Accept works like net.Listener accept but add context
Accept(ctx context.Context) (context.Context, net.Conn, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
Close() error
// Addr returns the listener's network address.
Addr() net.Addr
}
func newTLSListener(key crypto.PrivKey, lis net.Listener, timeoutMillis int) ContextListener {
tr, _ := libp2ptls.New(libp2ptls.ID, key, nil)
return &tlsListener{
tr: tr,
Listener: lis,
timeoutMillis: timeoutMillis,
}
}
type tlsListener struct {
net.Listener
tr *libp2ptls.Transport
timeoutMillis int
}
func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, error) {
conn, err := p.Listener.Accept()
if err != nil {
return nil, nil, err
}
timeoutConn := timeoutconn.NewConn(conn, time.Duration(p.timeoutMillis)*time.Millisecond)
return p.upgradeConn(ctx, timeoutConn)
}
func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) {
secure, err := p.tr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, HandshakeError{
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
}
ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String())
return ctx, secure, nil
}

View File

@ -23,6 +23,8 @@ type Configuration interface {
CHash() chash.CHash CHash() chash.CHash
// Partition returns partition number by spaceId // Partition returns partition number by spaceId
Partition(spaceId string) (part int) Partition(spaceId string) (part int)
// NodeTypes returns list of known nodeTypes by nodeId, if node not registered in configuration will return empty list
NodeTypes(nodeId string) []NodeType
} }
type configuration struct { type configuration struct {
@ -82,6 +84,15 @@ func (c *configuration) Partition(spaceId string) (part int) {
return c.chash.GetPartition(ReplKey(spaceId)) return c.chash.GetPartition(ReplKey(spaceId))
} }
func (c *configuration) NodeTypes(nodeId string) []NodeType {
for _, m := range c.allMembers {
if m.PeerId == nodeId {
return m.Types
}
}
return nil
}
func ReplKey(spaceId string) (replKey string) { func ReplKey(spaceId string) (replKey string) {
if i := strings.LastIndex(spaceId, "."); i != -1 { if i := strings.LastIndex(spaceId, "."); i != -1 {
return spaceId[i+1:] return spaceId[i+1:]

View File

@ -74,9 +74,6 @@ func (s *service) Init(a *app.App) (err error) {
} }
members = append(members, member) members = append(members, member)
} }
if n.PeerId == s.accountId {
continue
}
if n.HasType(NodeTypeConsensus) { if n.HasType(NodeTypeConsensus) {
fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId) fileConfig.consensusPeers = append(fileConfig.consensusPeers, n.PeerId)
} }

View File

@ -34,15 +34,21 @@ func (s *AccountTestService) Init(a *app.App) (err error) {
return return
} }
peerId, err := peer.IdFromSigningPubKey(signKey.GetPublic()) peerKey, _, err := signingkey.GenerateRandomEd25519KeyPair()
if err != nil {
return err
}
peerId, err := peer.IdFromSigningPubKey(peerKey.GetPublic())
if err != nil { if err != nil {
return err return err
} }
s.acc = &accountdata.AccountData{ s.acc = &accountdata.AccountData{
PeerId: peerId.String(),
Identity: ident, Identity: ident,
PeerKey: peerKey,
SignKey: signKey, SignKey: signKey,
EncKey: encKey, EncKey: encKey,
PeerId: peerId.String(),
} }
return nil return nil
} }

View File

@ -0,0 +1,40 @@
package testnodeconf
import (
"github.com/anytypeio/any-sync/accountservice"
"github.com/anytypeio/any-sync/app"
"github.com/anytypeio/any-sync/nodeconf"
"github.com/anytypeio/any-sync/testutil/accounttest"
)
func GenNodeConfig(num int) (conf *Config) {
conf = &Config{}
if num <= 0 {
num = 1
}
for i := 0; i < num; i++ {
ac := &accounttest.AccountTestService{}
if err := ac.Init(nil); err != nil {
panic(err)
}
conf.nodes = append(conf.nodes, ac.NodeConf(nil))
conf.configs = append(conf.configs, ac)
}
return conf
}
type Config struct {
nodes []nodeconf.NodeConfig
configs []*accounttest.AccountTestService
}
func (c *Config) Init(a *app.App) (err error) { return }
func (c *Config) Name() string { return "config" }
func (c *Config) GetNodes() []nodeconf.NodeConfig {
return c.nodes
}
func (c *Config) GetAccountService(idx int) accountservice.Service {
return c.configs[idx]
}