Merge branch 'yamux' into new-sync-protocol
This commit is contained in:
commit
5a8c69e557
12
go.mod
12
go.mod
@ -25,12 +25,12 @@ require (
|
||||
github.com/ipfs/go-ipld-format v0.4.0
|
||||
github.com/ipfs/go-merkledag v0.10.0
|
||||
github.com/ipfs/go-unixfs v0.4.6
|
||||
github.com/libp2p/go-libp2p v0.27.3
|
||||
github.com/libp2p/go-libp2p v0.27.5
|
||||
github.com/mr-tron/base58 v1.2.0
|
||||
github.com/multiformats/go-multibase v0.2.0
|
||||
github.com/multiformats/go-multihash v0.2.2
|
||||
github.com/prometheus/client_golang v1.15.1
|
||||
github.com/stretchr/testify v1.8.3
|
||||
github.com/stretchr/testify v1.8.4
|
||||
github.com/tyler-smith/go-bip39 v1.1.0
|
||||
github.com/zeebo/blake3 v0.2.3
|
||||
go.uber.org/atomic v1.11.0
|
||||
@ -56,7 +56,7 @@ require (
|
||||
github.com/go-logr/stdr v1.2.2 // indirect
|
||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
|
||||
github.com/hashicorp/golang-lru v0.5.4 // indirect
|
||||
github.com/huin/goupnp v1.2.0 // indirect
|
||||
github.com/ipfs/bbloom v0.0.4 // indirect
|
||||
@ -89,7 +89,7 @@ require (
|
||||
github.com/multiformats/go-multicodec v0.9.0 // indirect
|
||||
github.com/multiformats/go-multistream v0.4.1 // indirect
|
||||
github.com/multiformats/go-varint v0.0.7 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
|
||||
github.com/onsi/ginkgo/v2 v2.9.7 // indirect
|
||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
@ -97,7 +97,7 @@ require (
|
||||
github.com/prometheus/client_model v0.4.0 // indirect
|
||||
github.com/prometheus/common v0.44.0 // indirect
|
||||
github.com/prometheus/procfs v0.10.0 // indirect
|
||||
github.com/quic-go/quic-go v0.34.0 // indirect
|
||||
github.com/quic-go/quic-go v0.35.1 // indirect
|
||||
github.com/quic-go/webtransport-go v0.5.3 // indirect
|
||||
github.com/spaolacci/murmur3 v1.1.0 // indirect
|
||||
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect
|
||||
@ -109,7 +109,7 @@ require (
|
||||
golang.org/x/image v0.6.0 // indirect
|
||||
golang.org/x/sync v0.2.0 // indirect
|
||||
golang.org/x/sys v0.8.0 // indirect
|
||||
golang.org/x/tools v0.9.1 // indirect
|
||||
golang.org/x/tools v0.9.3 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
lukechampine.com/blake3 v1.2.1 // indirect
|
||||
|
||||
24
go.sum
24
go.sum
@ -67,8 +67,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
|
||||
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8=
|
||||
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
|
||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
@ -188,8 +188,8 @@ github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoR
|
||||
github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8=
|
||||
github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg=
|
||||
github.com/libp2p/go-cidranger v1.1.0 h1:ewPN8EZ0dd1LSnrtuwd4709PXVcITVeuwbag38yPW7c=
|
||||
github.com/libp2p/go-libp2p v0.27.3 h1:tkV/zm3KCZ4R5er9Xcs2pt0YNB4JH0iBfGAtHJdLHRs=
|
||||
github.com/libp2p/go-libp2p v0.27.3/go.mod h1:FAvvfQa/YOShUYdiSS03IR9OXzkcJXwcNA2FUCh9ImE=
|
||||
github.com/libp2p/go-libp2p v0.27.5 h1:KwA7pXKXpz8hG6Cr1fMA7UkgleogcwQj0sxl5qquWRg=
|
||||
github.com/libp2p/go-libp2p v0.27.5/go.mod h1:oMfQGTb9CHnrOuSM6yMmyK2lXz3qIhnkn2+oK3B1Y2g=
|
||||
github.com/libp2p/go-libp2p-asn-util v0.3.0 h1:gMDcMyYiZKkocGXDQ5nsUQyquC9+H+iLEQHwOCZ7s8s=
|
||||
github.com/libp2p/go-libp2p-record v0.2.0 h1:oiNUOCWno2BFuxt3my4i1frNrt7PerzB3queqa1NkQ0=
|
||||
github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA=
|
||||
@ -244,8 +244,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
|
||||
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
|
||||
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
|
||||
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
|
||||
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
|
||||
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
|
||||
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
|
||||
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
|
||||
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@ -268,8 +268,8 @@ github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH
|
||||
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
|
||||
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
|
||||
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
|
||||
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU=
|
||||
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
|
||||
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
|
||||
github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU=
|
||||
github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
@ -287,8 +287,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
|
||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8=
|
||||
github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U=
|
||||
github.com/warpfork/go-testmark v0.11.0 h1:J6LnV8KpceDvo7spaNU4+DauH2n1x+6RaO2rJrmpQ9U=
|
||||
@ -415,8 +415,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
|
||||
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
|
||||
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
||||
@ -17,7 +17,7 @@ type TimeoutConn struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
|
||||
func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn {
|
||||
return &TimeoutConn{conn, timeout}
|
||||
}
|
||||
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
@ -19,11 +18,11 @@ type noVerifyChecker struct {
|
||||
cred *handshakeproto.Credentials
|
||||
}
|
||||
|
||||
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
|
||||
func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
|
||||
return n.cred
|
||||
}
|
||||
|
||||
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
if cred.Version != n.cred.Version {
|
||||
return nil, handshake.ErrIncompatibleVersion
|
||||
}
|
||||
@ -42,8 +41,8 @@ type peerSignVerifier struct {
|
||||
account *accountdata.AccountKeys
|
||||
}
|
||||
|
||||
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
|
||||
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String()))
|
||||
func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
|
||||
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId))
|
||||
if err != nil {
|
||||
log.Warn("can't sign identity credentials", zap.Error(err))
|
||||
}
|
||||
@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr
|
||||
}
|
||||
}
|
||||
|
||||
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
if cred.Version != p.protoVersion {
|
||||
return nil, handshake.ErrIncompatibleVersion
|
||||
}
|
||||
@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro
|
||||
if err != nil {
|
||||
return nil, handshake.ErrInvalidCredentials
|
||||
}
|
||||
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign)
|
||||
ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -4,13 +4,8 @@ import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/testutil/accounttest"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
|
||||
cc1 := newPeerSignVerifier(0, a1)
|
||||
cc2 := newPeerSignVerifier(0, a2)
|
||||
|
||||
c1 := newTestSC(a2.PeerId)
|
||||
c2 := newTestSC(a1.PeerId)
|
||||
c1 := a2.PeerId
|
||||
c2 := a1.PeerId
|
||||
|
||||
cr1 := cc1.MakeCredentials(c1)
|
||||
cr2 := cc2.MakeCredentials(c2)
|
||||
@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) {
|
||||
cc1 := newPeerSignVerifier(0, a1)
|
||||
cc2 := newPeerSignVerifier(1, a2)
|
||||
|
||||
c1 := newTestSC(a2.PeerId)
|
||||
c2 := newTestSC(a1.PeerId)
|
||||
c1 := a2.PeerId
|
||||
c2 := a1.PeerId
|
||||
|
||||
cr1 := cc1.MakeCredentials(c1)
|
||||
cr2 := cc2.MakeCredentials(c2)
|
||||
@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys {
|
||||
require.NoError(t, as.Init(nil))
|
||||
return as.Account()
|
||||
}
|
||||
|
||||
func newTestSC(peerId string) sec.SecureConn {
|
||||
pid, _ := peer.Decode(peerId)
|
||||
return &testSc{
|
||||
ID: pid,
|
||||
}
|
||||
}
|
||||
|
||||
type testSc struct {
|
||||
net.Conn
|
||||
peer.ID
|
||||
}
|
||||
|
||||
func (t *testSc) LocalPeer() peer.ID {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testSc) RemotePeer() peer.ID {
|
||||
return t.ID
|
||||
}
|
||||
|
||||
func (t *testSc) RemotePublicKey() crypto.PubKey {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testSc) ConnState() network.ConnectionState {
|
||||
return network.ConnectionState{}
|
||||
}
|
||||
|
||||
@ -3,32 +3,36 @@ package handshake
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"io"
|
||||
)
|
||||
|
||||
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
resIdentity []byte
|
||||
resErr error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = outgoingHandshake(h, sc, cc)
|
||||
resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
return resIdentity, resErr
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
_ = conn.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
localCred := cc.MakeCredentials(sc)
|
||||
h.conn = conn
|
||||
localCred := cc.MakeCredentials(peerId)
|
||||
if err = h.writeCredentials(localCred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
@ -45,7 +49,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
|
||||
return nil, HandshakeError{e: msg.ack.Error}
|
||||
}
|
||||
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
@ -68,40 +72,44 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
|
||||
}
|
||||
}
|
||||
|
||||
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
h := newHandshake()
|
||||
done := make(chan struct{})
|
||||
var (
|
||||
resIdentity []byte
|
||||
resError error
|
||||
)
|
||||
go func() {
|
||||
defer close(done)
|
||||
identity, err = incomingHandshake(h, sc, cc)
|
||||
resIdentity, resError = incomingHandshake(h, conn, peerId, cc)
|
||||
}()
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
return resIdentity, resError
|
||||
case <-ctx.Done():
|
||||
_ = sc.Close()
|
||||
_ = conn.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
|
||||
func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
|
||||
defer h.release()
|
||||
h.conn = sc
|
||||
h.conn = conn
|
||||
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
if err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
|
||||
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
|
||||
if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
|
||||
h.tryWriteErrAndClose(err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/network"
|
||||
"github.com/libp2p/go-libp2p/core/peer"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"net"
|
||||
@ -17,7 +16,7 @@ import (
|
||||
|
||||
var noVerifyChecker = &testCredChecker{
|
||||
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
|
||||
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
return []byte("identity"), nil
|
||||
},
|
||||
}
|
||||
@ -32,7 +31,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -40,10 +39,10 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
// receive credential message
|
||||
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
|
||||
require.NoError(t, err)
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
_, err = noVerifyChecker.CheckCredential("p1", msg.cred)
|
||||
require.NoError(t, err)
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// receive ack
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
@ -58,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
@ -69,7 +68,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -85,7 +84,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -101,7 +100,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -109,7 +108,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
// receive credential message
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
|
||||
@ -120,7 +119,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -129,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials and close conn
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
_ = c2.Close()
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
@ -138,7 +137,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -147,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// read ack and close conn
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
@ -159,7 +158,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -168,12 +167,12 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// read ack
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
// write cred instead ack
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.Error(t, err)
|
||||
res := <-handshakeResCh
|
||||
@ -183,7 +182,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -192,10 +191,10 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, msg.ack)
|
||||
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
|
||||
_, err = noVerifyChecker.CheckCredential("", msg.cred)
|
||||
require.NoError(t, err)
|
||||
// send credential message
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// receive ack
|
||||
msg, err = h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
@ -211,7 +210,7 @@ func TestOutgoingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -234,13 +233,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -260,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
_ = c2.Close()
|
||||
@ -271,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials and close conn
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
_ = c2.Close()
|
||||
res := <-handshakeResCh
|
||||
require.Error(t, res.err)
|
||||
@ -286,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -300,13 +299,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// except ack with error
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
@ -320,13 +319,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
|
||||
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// except ack with error
|
||||
msg, err := h.readMsg(msgTypeAck)
|
||||
require.NoError(t, err)
|
||||
@ -340,18 +339,18 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// read cred
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
// write cred instead ack
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// expect EOF
|
||||
_, err = h.readMsg(msgTypeAck)
|
||||
require.Error(t, err)
|
||||
@ -362,13 +361,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// read cred and close conn
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -381,13 +380,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -403,13 +402,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -425,13 +424,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// wait credentials
|
||||
msg, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -448,13 +447,13 @@ func TestIncomingHandshake(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
h.conn = c2
|
||||
// write credentials
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
|
||||
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
|
||||
// wait credentials
|
||||
_, err := h.readMsg(msgTypeCred)
|
||||
require.NoError(t, err)
|
||||
@ -472,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
|
||||
c1, c2 := newConnPair(t)
|
||||
var handshakeResCh = make(chan handshakeRes, 1)
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
|
||||
handshakeResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
h := newHandshake()
|
||||
@ -491,11 +490,11 @@ func TestEndToEnd(t *testing.T) {
|
||||
)
|
||||
st := time.Now()
|
||||
go func() {
|
||||
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
outResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
go func() {
|
||||
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
|
||||
identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
|
||||
inResCh <- handshakeRes{identity: identity, err: err}
|
||||
}()
|
||||
|
||||
@ -519,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
defer close(done)
|
||||
go func() {
|
||||
for {
|
||||
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
|
||||
_, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker)
|
||||
select {
|
||||
case outRes <- struct{}{}:
|
||||
case <-done:
|
||||
@ -529,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
}()
|
||||
go func() {
|
||||
for {
|
||||
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
|
||||
_, _ = IncomingHandshake(nil, c2, "", noVerifyChecker)
|
||||
select {
|
||||
case inRes <- struct{}{}:
|
||||
case <-done:
|
||||
@ -549,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) {
|
||||
|
||||
type testCredChecker struct {
|
||||
makeCred *handshakeproto.Credentials
|
||||
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
checkErr error
|
||||
}
|
||||
|
||||
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
|
||||
func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials {
|
||||
return t.makeCred
|
||||
}
|
||||
|
||||
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
|
||||
if t.checkErr != nil {
|
||||
return nil, t.checkErr
|
||||
}
|
||||
if t.checkCred != nil {
|
||||
return t.checkCred(sc, cred)
|
||||
return t.checkCred(peerId, cred)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@ -4,10 +4,8 @@ import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"golang.org/x/exp/slices"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
@ -65,8 +63,8 @@ var handshakePool = &sync.Pool{New: func() any {
|
||||
}}
|
||||
|
||||
type CredentialChecker interface {
|
||||
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
|
||||
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
MakeCredentials(remotePeerId string) *handshakeproto.Credentials
|
||||
CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
|
||||
}
|
||||
|
||||
func newHandshake() *handshake {
|
||||
@ -74,7 +72,7 @@ func newHandshake() *handshake {
|
||||
}
|
||||
|
||||
type handshake struct {
|
||||
conn net.Conn
|
||||
conn io.ReadWriteCloser
|
||||
remoteCred *handshakeproto.Credentials
|
||||
remoteProto *handshakeproto.Proto
|
||||
remoteAck *handshakeproto.Ack
|
||||
|
||||
@ -2,6 +2,7 @@ package secureservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
commonaccount "github.com/anyproto/any-sync/accountservice"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
@ -10,9 +11,9 @@ import (
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/libp2p/go-libp2p/core/crypto"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
@ -25,8 +26,10 @@ func New() SecureService {
|
||||
}
|
||||
|
||||
type SecureService interface {
|
||||
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
|
||||
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
|
||||
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
|
||||
HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error)
|
||||
ServerTlsConfig() (*tls.Config, error)
|
||||
app.Component
|
||||
}
|
||||
|
||||
@ -75,28 +78,31 @@ func (s *secureService) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
|
||||
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
|
||||
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
|
||||
sc, err := s.p2pTr.SecureInbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, nil, handshake.HandshakeError{
|
||||
return nil, handshake.HandshakeError{
|
||||
Err: err,
|
||||
}
|
||||
}
|
||||
return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String())
|
||||
}
|
||||
|
||||
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
|
||||
func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) {
|
||||
identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
cctx = context.Background()
|
||||
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
|
||||
cctx = peer.CtxWithPeerId(cctx, peerId)
|
||||
cctx = peer.CtxWithIdentity(cctx, identity)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
|
||||
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
|
||||
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
|
||||
if err != nil {
|
||||
return nil, nil, handshake.HandshakeError{Err: err}
|
||||
return nil, handshake.HandshakeError{Err: err}
|
||||
}
|
||||
peerId := sc.RemotePeer().String()
|
||||
confTypes := s.nodeconf.NodeTypes(peerId)
|
||||
@ -106,12 +112,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx
|
||||
} else {
|
||||
checker = s.noVerifyChecker
|
||||
}
|
||||
identity, err := handshake.OutgoingHandshake(ctx, sc, checker)
|
||||
identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, err
|
||||
}
|
||||
cctx = context.Background()
|
||||
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
|
||||
cctx = peer.CtxWithIdentity(cctx, identity)
|
||||
return cctx, sc, nil
|
||||
return cctx, nil
|
||||
}
|
||||
|
||||
func (s *secureService) ServerTlsConfig() (*tls.Config, error) {
|
||||
p2pIdn, err := libp2ptls.NewIdentity(s.key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
conf, _ := p2pIdn.ConfigForPeer("")
|
||||
conf.NextProtos = []string{"anysync"}
|
||||
return conf, nil
|
||||
}
|
||||
|
||||
@ -32,18 +32,17 @@ func TestHandshake(t *testing.T) {
|
||||
resCh := make(chan acceptRes)
|
||||
go func() {
|
||||
var ar acceptRes
|
||||
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
|
||||
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
|
||||
resCh <- ar
|
||||
}()
|
||||
|
||||
fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
|
||||
defer fxC.Finish(t)
|
||||
|
||||
cctx, secConn, err := fxC.SecureOutbound(ctx, cc)
|
||||
cctx, err := fxC.SecureOutbound(ctx, cc)
|
||||
require.NoError(t, err)
|
||||
ctxPeerId, err := peer.CtxPeerId(cctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
|
||||
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
|
||||
res := <-resCh
|
||||
require.NoError(t, res.err)
|
||||
@ -70,12 +69,12 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
|
||||
resCh := make(chan acceptRes)
|
||||
go func() {
|
||||
var ar acceptRes
|
||||
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
|
||||
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
|
||||
resCh <- ar
|
||||
}()
|
||||
fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
|
||||
defer fxC.Finish(t)
|
||||
_, _, err := fxC.SecureOutbound(ctx, cc)
|
||||
_, err := fxC.SecureOutbound(ctx, cc)
|
||||
require.Equal(t, handshake.ErrIncompatibleVersion, err)
|
||||
res := <-resCh
|
||||
require.Equal(t, handshake.ErrIncompatibleVersion, res.err)
|
||||
|
||||
@ -8,5 +8,4 @@ type Config struct {
|
||||
ListenAddrs []string `yaml:"listenAddrs"`
|
||||
WriteTimeoutSec int `yaml:"writeTimeoutSec"`
|
||||
DialTimeoutSec int `yaml:"dialTimeoutSec"`
|
||||
MaxStreams int `yaml:"maxStreams"`
|
||||
}
|
||||
|
||||
@ -26,7 +26,10 @@ type yamuxConn struct {
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
|
||||
return y.Session.Open()
|
||||
if conn, err = y.Session.Open(); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (y *yamuxConn) LastUsage() time.Time {
|
||||
@ -46,6 +49,7 @@ func (y *yamuxConn) Accept() (conn net.Conn, err error) {
|
||||
if err == yamux.ErrSessionShutdown {
|
||||
err = transport.ErrConnClosed
|
||||
}
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -43,9 +43,6 @@ func (y *yamuxTransport) Init(a *app.App) (err error) {
|
||||
y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
|
||||
y.conf = a.MustComponent("config").(configGetter).GetYamux()
|
||||
y.yamuxConf = yamux.DefaultConfig()
|
||||
if y.conf.MaxStreams > 0 {
|
||||
y.yamuxConf.AcceptBacklog = y.conf.MaxStreams
|
||||
}
|
||||
y.yamuxConf.EnableKeepAlive = false
|
||||
y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
|
||||
y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
|
||||
@ -86,12 +83,12 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
|
||||
defer cancel()
|
||||
cctx, sc, err := y.secure.SecureOutbound(ctx, conn)
|
||||
cctx, err := y.secure.SecureOutbound(ctx, conn)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
luc := connutil.NewLastUsageConn(sc)
|
||||
luc := connutil.NewLastUsageConn(conn)
|
||||
sess, err := yamux.Client(luc, y.yamuxConf)
|
||||
if err != nil {
|
||||
return
|
||||
@ -132,12 +129,12 @@ func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
|
||||
func (y *yamuxTransport) accept(conn net.Conn) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
|
||||
defer cancel()
|
||||
cctx, sc, err := y.secure.SecureInbound(ctx, conn)
|
||||
cctx, err := y.secure.SecureInbound(ctx, conn)
|
||||
if err != nil {
|
||||
log.Warn("incoming connection handshake error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
luc := connutil.NewLastUsageConn(sc)
|
||||
luc := connutil.NewLastUsageConn(conn)
|
||||
sess, err := yamux.Server(luc, y.yamuxConf)
|
||||
if err != nil {
|
||||
log.Warn("incoming connection yamux session error", zap.Error(err))
|
||||
|
||||
@ -14,7 +14,10 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
@ -28,7 +31,7 @@ func TestYamuxTransport_Dial(t *testing.T) {
|
||||
mcC, err := fxC.Dial(ctx, fxS.addr)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, fxS.accepter.mcs, 1)
|
||||
mcS := fxS.accepter.mcs[0]
|
||||
mcS := <-fxS.accepter.mcs
|
||||
|
||||
var (
|
||||
sData string
|
||||
@ -63,6 +66,64 @@ func TestYamuxTransport_Dial(t *testing.T) {
|
||||
assert.NoError(t, copyErr)
|
||||
}
|
||||
|
||||
// no deadline - 69100 rps
|
||||
// common write deadline - 66700 rps
|
||||
// subconn write deadline - 67100 rps
|
||||
func TestWriteBench(t *testing.T) {
|
||||
t.Skip()
|
||||
var (
|
||||
numSubConn = 10
|
||||
numWrites = 100000
|
||||
)
|
||||
|
||||
fxS := newFixture(t)
|
||||
defer fxS.finish(t)
|
||||
fxC := newFixture(t)
|
||||
defer fxC.finish(t)
|
||||
|
||||
mcC, err := fxC.Dial(ctx, fxS.addr)
|
||||
require.NoError(t, err)
|
||||
mcS := <-fxS.accepter.mcs
|
||||
|
||||
go func() {
|
||||
for i := 0; i < numSubConn; i++ {
|
||||
conn, err := mcS.Accept()
|
||||
require.NoError(t, err)
|
||||
go func(sc net.Conn) {
|
||||
var b = make([]byte, 1024)
|
||||
for {
|
||||
n, _ := sc.Read(b)
|
||||
if n > 0 {
|
||||
sc.Write(b[:n])
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(numSubConn)
|
||||
st := time.Now()
|
||||
for i := 0; i < numSubConn; i++ {
|
||||
conn, err := mcC.Open(ctx)
|
||||
require.NoError(t, err)
|
||||
go func(sc net.Conn) {
|
||||
defer sc.Close()
|
||||
defer wg.Done()
|
||||
for j := 0; j < numWrites; j++ {
|
||||
var b = []byte("some data some data some data some data some data some data some data some data some data")
|
||||
sc.Write(b)
|
||||
sc.Read(b)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
wg.Wait()
|
||||
dur := time.Since(st)
|
||||
t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds())
|
||||
}
|
||||
|
||||
type fixture struct {
|
||||
*yamuxTransport
|
||||
a *app.App
|
||||
@ -78,7 +139,7 @@ func newFixture(t *testing.T) *fixture {
|
||||
yamuxTransport: New().(*yamuxTransport),
|
||||
ctrl: gomock.NewController(t),
|
||||
acc: &accounttest.AccountTestService{},
|
||||
accepter: &testAccepter{},
|
||||
accepter: &testAccepter{mcs: make(chan transport.MultiConn, 100)},
|
||||
a: new(app.App),
|
||||
}
|
||||
|
||||
@ -112,17 +173,16 @@ func (c *testConf) GetYamux() Config {
|
||||
ListenAddrs: []string{"127.0.0.1:0"},
|
||||
WriteTimeoutSec: 10,
|
||||
DialTimeoutSec: 10,
|
||||
MaxStreams: 1024,
|
||||
}
|
||||
}
|
||||
|
||||
type testAccepter struct {
|
||||
err error
|
||||
mcs []transport.MultiConn
|
||||
mcs chan transport.MultiConn
|
||||
}
|
||||
|
||||
func (t *testAccepter) Accept(mc transport.MultiConn) (err error) {
|
||||
t.mcs = append(t.mcs, mc)
|
||||
t.mcs <- mc
|
||||
return t.err
|
||||
}
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user