Merge branch 'yamux' into new-sync-protocol

This commit is contained in:
Sergey Cherepanov 2023-06-07 13:34:31 +02:00
commit 5a8c69e557
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
14 changed files with 218 additions and 176 deletions

12
go.mod
View File

@ -25,12 +25,12 @@ require (
github.com/ipfs/go-ipld-format v0.4.0
github.com/ipfs/go-merkledag v0.10.0
github.com/ipfs/go-unixfs v0.4.6
github.com/libp2p/go-libp2p v0.27.3
github.com/libp2p/go-libp2p v0.27.5
github.com/mr-tron/base58 v1.2.0
github.com/multiformats/go-multibase v0.2.0
github.com/multiformats/go-multihash v0.2.2
github.com/prometheus/client_golang v1.15.1
github.com/stretchr/testify v1.8.3
github.com/stretchr/testify v1.8.4
github.com/tyler-smith/go-bip39 v1.1.0
github.com/zeebo/blake3 v0.2.3
go.uber.org/atomic v1.11.0
@ -56,7 +56,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/huin/goupnp v1.2.0 // indirect
github.com/ipfs/bbloom v0.0.4 // indirect
@ -89,7 +89,7 @@ require (
github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multistream v0.4.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect
github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
@ -97,7 +97,7 @@ require (
github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.0 // indirect
github.com/quic-go/quic-go v0.34.0 // indirect
github.com/quic-go/quic-go v0.35.1 // indirect
github.com/quic-go/webtransport-go v0.5.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect
@ -109,7 +109,7 @@ require (
golang.org/x/image v0.6.0 // indirect
golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/tools v0.9.1 // indirect
golang.org/x/tools v0.9.3 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.30.0 // indirect
lukechampine.com/blake3 v1.2.1 // indirect

24
go.sum
View File

@ -67,8 +67,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -188,8 +188,8 @@ github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoR
github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8=
github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg=
github.com/libp2p/go-cidranger v1.1.0 h1:ewPN8EZ0dd1LSnrtuwd4709PXVcITVeuwbag38yPW7c=
github.com/libp2p/go-libp2p v0.27.3 h1:tkV/zm3KCZ4R5er9Xcs2pt0YNB4JH0iBfGAtHJdLHRs=
github.com/libp2p/go-libp2p v0.27.3/go.mod h1:FAvvfQa/YOShUYdiSS03IR9OXzkcJXwcNA2FUCh9ImE=
github.com/libp2p/go-libp2p v0.27.5 h1:KwA7pXKXpz8hG6Cr1fMA7UkgleogcwQj0sxl5qquWRg=
github.com/libp2p/go-libp2p v0.27.5/go.mod h1:oMfQGTb9CHnrOuSM6yMmyK2lXz3qIhnkn2+oK3B1Y2g=
github.com/libp2p/go-libp2p-asn-util v0.3.0 h1:gMDcMyYiZKkocGXDQ5nsUQyquC9+H+iLEQHwOCZ7s8s=
github.com/libp2p/go-libp2p-record v0.2.0 h1:oiNUOCWno2BFuxt3my4i1frNrt7PerzB3queqa1NkQ0=
github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA=
@ -244,8 +244,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k=
github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -268,8 +268,8 @@ github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU=
github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -287,8 +287,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8=
github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U=
github.com/warpfork/go-testmark v0.11.0 h1:J6LnV8KpceDvo7spaNU4+DauH2n1x+6RaO2rJrmpQ9U=
@ -415,8 +415,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -17,7 +17,7 @@ type TimeoutConn struct {
timeout time.Duration
}
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn {
func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn {
return &TimeoutConn{conn, timeout}
}

View File

@ -5,7 +5,6 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
)
@ -19,11 +18,11 @@ type noVerifyChecker struct {
cred *handshakeproto.Credentials
}
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
return n.cred
}
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != n.cred.Version {
return nil, handshake.ErrIncompatibleVersion
}
@ -42,8 +41,8 @@ type peerSignVerifier struct {
account *accountdata.AccountKeys
}
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String()))
func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId))
if err != nil {
log.Warn("can't sign identity credentials", zap.Error(err))
}
@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr
}
}
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != p.protoVersion {
return nil, handshake.ErrIncompatibleVersion
}
@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro
if err != nil {
return nil, handshake.ErrInvalidCredentials
}
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign)
ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
if err != nil {
return nil, err
}

View File

@ -4,13 +4,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/testutil/accounttest"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
"testing"
)
@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(0, a2)
c1 := newTestSC(a2.PeerId)
c2 := newTestSC(a1.PeerId)
c1 := a2.PeerId
c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2)
@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(1, a2)
c1 := newTestSC(a2.PeerId)
c2 := newTestSC(a1.PeerId)
c1 := a2.PeerId
c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2)
@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys {
require.NoError(t, as.Init(nil))
return as.Account()
}
func newTestSC(peerId string) sec.SecureConn {
pid, _ := peer.Decode(peerId)
return &testSc{
ID: pid,
}
}
type testSc struct {
net.Conn
peer.ID
}
func (t *testSc) LocalPeer() peer.ID {
return ""
}
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
return nil
}
func (t *testSc) RemotePeer() peer.ID {
return t.ID
}
func (t *testSc) RemotePublicKey() crypto.PubKey {
return nil
}
func (t *testSc) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View File

@ -3,32 +3,36 @@ package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"io"
)
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
var (
resIdentity []byte
resErr error
)
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return
return resIdentity, resErr
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
h.conn = conn
localCred := cc.MakeCredentials(peerId)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
@ -45,7 +49,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
@ -68,40 +72,44 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
var (
resIdentity []byte
resError error
)
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
resIdentity, resError = incomingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return
return resIdentity, resError
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
h.conn = conn
msg, err := h.readMsg(msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}

View File

@ -7,7 +7,6 @@ import (
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net"
@ -17,7 +16,7 @@ import (
var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil
},
}
@ -32,7 +31,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -40,10 +39,10 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
_, err = noVerifyChecker.CheckCredential("p1", msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -58,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -69,7 +68,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -85,7 +84,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -101,7 +100,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -109,7 +108,7 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
@ -120,7 +119,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -129,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
@ -138,7 +137,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -147,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack and close conn
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -159,7 +158,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -168,12 +167,12 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack
_, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
res := <-handshakeResCh
@ -183,7 +182,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -192,10 +191,10 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
_, err = noVerifyChecker.CheckCredential("", msg.cred)
require.NoError(t, err)
// send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack
msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -211,7 +210,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker)
identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -234,13 +233,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -260,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
_ = c2.Close()
@ -271,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close()
res := <-handshakeResCh
require.Error(t, res.err)
@ -286,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -300,13 +299,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -320,13 +319,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
@ -340,18 +339,18 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
// write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// expect EOF
_, err = h.readMsg(msgTypeAck)
require.Error(t, err)
@ -362,13 +361,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred and close conn
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -381,13 +380,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -403,13 +402,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -425,13 +424,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -448,13 +447,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker)
identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
h.conn = c2
// write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2)))
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials
_, err := h.readMsg(msgTypeCred)
require.NoError(t, err)
@ -472,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1)
go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker)
identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err}
}()
h := newHandshake()
@ -491,11 +490,11 @@ func TestEndToEnd(t *testing.T) {
)
st := time.Now()
go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker)
identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err}
}()
go func() {
identity, err := IncomingHandshake(nil, c2, noVerifyChecker)
identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err}
}()
@ -519,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done)
go func() {
for {
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker)
_, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker)
select {
case outRes <- struct{}{}:
case <-done:
@ -529,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) {
}()
go func() {
for {
_, _ = IncomingHandshake(nil, c2, noVerifyChecker)
_, _ = IncomingHandshake(nil, c2, "", noVerifyChecker)
select {
case inRes <- struct{}{}:
case <-done:
@ -549,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct {
makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error
}
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials {
func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials {
return t.makeCred
}
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) {
func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil {
return nil, t.checkErr
}
if t.checkCred != nil {
return t.checkCred(sc, cred)
return t.checkCred(peerId, cred)
}
return nil, nil
}

View File

@ -4,10 +4,8 @@ import (
"encoding/binary"
"errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"net"
"sync"
)
@ -65,8 +63,8 @@ var handshakePool = &sync.Pool{New: func() any {
}}
type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error)
MakeCredentials(remotePeerId string) *handshakeproto.Credentials
CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func newHandshake() *handshake {
@ -74,7 +72,7 @@ func newHandshake() *handshake {
}
type handshake struct {
conn net.Conn
conn io.ReadWriteCloser
remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack

View File

@ -2,6 +2,7 @@ package secureservice
import (
"context"
"crypto/tls"
commonaccount "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
@ -10,9 +11,9 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/sec"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"go.uber.org/zap"
"io"
"net"
)
@ -25,8 +26,10 @@ func New() SecureService {
}
type SecureService interface {
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error)
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error)
ServerTlsConfig() (*tls.Config, error)
app.Component
}
@ -75,28 +78,31 @@ func (s *secureService) Name() (name string) {
return CName
}
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err := s.p2pTr.SecureInbound(ctx, conn, "")
if err != nil {
return nil, nil, handshake.HandshakeError{
return nil, handshake.HandshakeError{
Err: err,
}
}
return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String())
}
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) {
identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker)
if err != nil {
return nil, nil, err
return nil, err
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithPeerId(cctx, peerId)
cctx = peer.CtxWithIdentity(cctx, identity)
return
}
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "")
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil {
return nil, nil, handshake.HandshakeError{Err: err}
return nil, handshake.HandshakeError{Err: err}
}
peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.NodeTypes(peerId)
@ -106,12 +112,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx
} else {
checker = s.noVerifyChecker
}
identity, err := handshake.OutgoingHandshake(ctx, sc, checker)
identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker)
if err != nil {
return nil, nil, err
return nil, err
}
cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity)
return cctx, sc, nil
return cctx, nil
}
func (s *secureService) ServerTlsConfig() (*tls.Config, error) {
p2pIdn, err := libp2ptls.NewIdentity(s.key)
if err != nil {
return nil, err
}
conf, _ := p2pIdn.ConfigForPeer("")
conf.NextProtos = []string{"anysync"}
return conf, nil
}

View File

@ -32,18 +32,17 @@ func TestHandshake(t *testing.T) {
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
defer fxC.Finish(t)
cctx, secConn, err := fxC.SecureOutbound(ctx, cc)
cctx, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err)
ctxPeerId, err := peer.CtxPeerId(cctx)
require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
res := <-resCh
require.NoError(t, res.err)
@ -70,12 +69,12 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
resCh := make(chan acceptRes)
go func() {
var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc)
ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar
}()
fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
defer fxC.Finish(t)
_, _, err := fxC.SecureOutbound(ctx, cc)
_, err := fxC.SecureOutbound(ctx, cc)
require.Equal(t, handshake.ErrIncompatibleVersion, err)
res := <-resCh
require.Equal(t, handshake.ErrIncompatibleVersion, res.err)

View File

@ -8,5 +8,4 @@ type Config struct {
ListenAddrs []string `yaml:"listenAddrs"`
WriteTimeoutSec int `yaml:"writeTimeoutSec"`
DialTimeoutSec int `yaml:"dialTimeoutSec"`
MaxStreams int `yaml:"maxStreams"`
}

View File

@ -26,7 +26,10 @@ type yamuxConn struct {
}
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
return y.Session.Open()
if conn, err = y.Session.Open(); err != nil {
return
}
return
}
func (y *yamuxConn) LastUsage() time.Time {
@ -46,6 +49,7 @@ func (y *yamuxConn) Accept() (conn net.Conn, err error) {
if err == yamux.ErrSessionShutdown {
err = transport.ErrConnClosed
}
return
}
return
}

View File

@ -43,9 +43,6 @@ func (y *yamuxTransport) Init(a *app.App) (err error) {
y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
y.conf = a.MustComponent("config").(configGetter).GetYamux()
y.yamuxConf = yamux.DefaultConfig()
if y.conf.MaxStreams > 0 {
y.yamuxConf.AcceptBacklog = y.conf.MaxStreams
}
y.yamuxConf.EnableKeepAlive = false
y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
@ -86,12 +83,12 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
}
ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel()
cctx, sc, err := y.secure.SecureOutbound(ctx, conn)
cctx, err := y.secure.SecureOutbound(ctx, conn)
if err != nil {
_ = conn.Close()
return nil, err
}
luc := connutil.NewLastUsageConn(sc)
luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Client(luc, y.yamuxConf)
if err != nil {
return
@ -132,12 +129,12 @@ func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
func (y *yamuxTransport) accept(conn net.Conn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
defer cancel()
cctx, sc, err := y.secure.SecureInbound(ctx, conn)
cctx, err := y.secure.SecureInbound(ctx, conn)
if err != nil {
log.Warn("incoming connection handshake error", zap.Error(err))
return
}
luc := connutil.NewLastUsageConn(sc)
luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Server(luc, y.yamuxConf)
if err != nil {
log.Warn("incoming connection yamux session error", zap.Error(err))

View File

@ -14,7 +14,10 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net"
"sync"
"testing"
"time"
)
var ctx = context.Background()
@ -28,7 +31,7 @@ func TestYamuxTransport_Dial(t *testing.T) {
mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err)
require.Len(t, fxS.accepter.mcs, 1)
mcS := fxS.accepter.mcs[0]
mcS := <-fxS.accepter.mcs
var (
sData string
@ -63,6 +66,64 @@ func TestYamuxTransport_Dial(t *testing.T) {
assert.NoError(t, copyErr)
}
// no deadline - 69100 rps
// common write deadline - 66700 rps
// subconn write deadline - 67100 rps
func TestWriteBench(t *testing.T) {
t.Skip()
var (
numSubConn = 10
numWrites = 100000
)
fxS := newFixture(t)
defer fxS.finish(t)
fxC := newFixture(t)
defer fxC.finish(t)
mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err)
mcS := <-fxS.accepter.mcs
go func() {
for i := 0; i < numSubConn; i++ {
conn, err := mcS.Accept()
require.NoError(t, err)
go func(sc net.Conn) {
var b = make([]byte, 1024)
for {
n, _ := sc.Read(b)
if n > 0 {
sc.Write(b[:n])
} else {
break
}
}
}(conn)
}
}()
var wg sync.WaitGroup
wg.Add(numSubConn)
st := time.Now()
for i := 0; i < numSubConn; i++ {
conn, err := mcC.Open(ctx)
require.NoError(t, err)
go func(sc net.Conn) {
defer sc.Close()
defer wg.Done()
for j := 0; j < numWrites; j++ {
var b = []byte("some data some data some data some data some data some data some data some data some data")
sc.Write(b)
sc.Read(b)
}
}(conn)
}
wg.Wait()
dur := time.Since(st)
t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds())
}
type fixture struct {
*yamuxTransport
a *app.App
@ -78,7 +139,7 @@ func newFixture(t *testing.T) *fixture {
yamuxTransport: New().(*yamuxTransport),
ctrl: gomock.NewController(t),
acc: &accounttest.AccountTestService{},
accepter: &testAccepter{},
accepter: &testAccepter{mcs: make(chan transport.MultiConn, 100)},
a: new(app.App),
}
@ -112,17 +173,16 @@ func (c *testConf) GetYamux() Config {
ListenAddrs: []string{"127.0.0.1:0"},
WriteTimeoutSec: 10,
DialTimeoutSec: 10,
MaxStreams: 1024,
}
}
type testAccepter struct {
err error
mcs []transport.MultiConn
mcs chan transport.MultiConn
}
func (t *testAccepter) Accept(mc transport.MultiConn) (err error) {
t.mcs = append(t.mcs, mc)
t.mcs <- mc
return t.err
}