Merge branch 'yamux' into new-sync-protocol

This commit is contained in:
Sergey Cherepanov 2023-06-07 13:34:31 +02:00
commit 5a8c69e557
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
14 changed files with 218 additions and 176 deletions

12
go.mod
View File

@ -25,12 +25,12 @@ require (
github.com/ipfs/go-ipld-format v0.4.0 github.com/ipfs/go-ipld-format v0.4.0
github.com/ipfs/go-merkledag v0.10.0 github.com/ipfs/go-merkledag v0.10.0
github.com/ipfs/go-unixfs v0.4.6 github.com/ipfs/go-unixfs v0.4.6
github.com/libp2p/go-libp2p v0.27.3 github.com/libp2p/go-libp2p v0.27.5
github.com/mr-tron/base58 v1.2.0 github.com/mr-tron/base58 v1.2.0
github.com/multiformats/go-multibase v0.2.0 github.com/multiformats/go-multibase v0.2.0
github.com/multiformats/go-multihash v0.2.2 github.com/multiformats/go-multihash v0.2.2
github.com/prometheus/client_golang v1.15.1 github.com/prometheus/client_golang v1.15.1
github.com/stretchr/testify v1.8.3 github.com/stretchr/testify v1.8.4
github.com/tyler-smith/go-bip39 v1.1.0 github.com/tyler-smith/go-bip39 v1.1.0
github.com/zeebo/blake3 v0.2.3 github.com/zeebo/blake3 v0.2.3
go.uber.org/atomic v1.11.0 go.uber.org/atomic v1.11.0
@ -56,7 +56,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/huin/goupnp v1.2.0 // indirect github.com/huin/goupnp v1.2.0 // indirect
github.com/ipfs/bbloom v0.0.4 // indirect github.com/ipfs/bbloom v0.0.4 // indirect
@ -89,7 +89,7 @@ require (
github.com/multiformats/go-multicodec v0.9.0 // indirect github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect github.com/multiformats/go-varint v0.0.7 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
@ -97,7 +97,7 @@ require (
github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.0 // indirect github.com/prometheus/procfs v0.10.0 // indirect
github.com/quic-go/quic-go v0.34.0 // indirect github.com/quic-go/quic-go v0.35.1 // indirect
github.com/quic-go/webtransport-go v0.5.3 // indirect github.com/quic-go/webtransport-go v0.5.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect
@ -109,7 +109,7 @@ require (
golang.org/x/image v0.6.0 // indirect golang.org/x/image v0.6.0 // indirect
golang.org/x/sync v0.2.0 // indirect golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.8.0 // indirect
golang.org/x/tools v0.9.1 // indirect golang.org/x/tools v0.9.3 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
lukechampine.com/blake3 v1.2.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect

24
go.sum
View File

@ -67,8 +67,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -188,8 +188,8 @@ github.com/libp2p/go-buffer-pool v0.0.2/go.mod h1:MvaB6xw5vOrDl8rYZGLFdKAuk/hRoR
github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8= github.com/libp2p/go-buffer-pool v0.1.0 h1:oK4mSFcQz7cTQIfqbe4MIj9gLW+mnanjyFtc6cdF0Y8=
github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg= github.com/libp2p/go-buffer-pool v0.1.0/go.mod h1:N+vh8gMqimBzdKkSMVuydVDq+UV5QTWy5HSiZacSbPg=
github.com/libp2p/go-cidranger v1.1.0 h1:ewPN8EZ0dd1LSnrtuwd4709PXVcITVeuwbag38yPW7c= github.com/libp2p/go-cidranger v1.1.0 h1:ewPN8EZ0dd1LSnrtuwd4709PXVcITVeuwbag38yPW7c=
github.com/libp2p/go-libp2p v0.27.3 h1:tkV/zm3KCZ4R5er9Xcs2pt0YNB4JH0iBfGAtHJdLHRs= github.com/libp2p/go-libp2p v0.27.5 h1:KwA7pXKXpz8hG6Cr1fMA7UkgleogcwQj0sxl5qquWRg=
github.com/libp2p/go-libp2p v0.27.3/go.mod h1:FAvvfQa/YOShUYdiSS03IR9OXzkcJXwcNA2FUCh9ImE= github.com/libp2p/go-libp2p v0.27.5/go.mod h1:oMfQGTb9CHnrOuSM6yMmyK2lXz3qIhnkn2+oK3B1Y2g=
github.com/libp2p/go-libp2p-asn-util v0.3.0 h1:gMDcMyYiZKkocGXDQ5nsUQyquC9+H+iLEQHwOCZ7s8s= github.com/libp2p/go-libp2p-asn-util v0.3.0 h1:gMDcMyYiZKkocGXDQ5nsUQyquC9+H+iLEQHwOCZ7s8s=
github.com/libp2p/go-libp2p-record v0.2.0 h1:oiNUOCWno2BFuxt3my4i1frNrt7PerzB3queqa1NkQ0= github.com/libp2p/go-libp2p-record v0.2.0 h1:oiNUOCWno2BFuxt3my4i1frNrt7PerzB3queqa1NkQ0=
github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA= github.com/libp2p/go-libp2p-testing v0.12.0 h1:EPvBb4kKMWO29qP4mZGyhVzUyR25dvfUIK5WDu6iPUA=
@ -244,8 +244,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -268,8 +268,8 @@ github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU= github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU=
github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU= github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -287,8 +287,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8= github.com/tyler-smith/go-bip39 v1.1.0 h1:5eUemwrMargf3BSLRRCalXT93Ns6pQJIjYQN2nyfOP8=
github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U= github.com/tyler-smith/go-bip39 v1.1.0/go.mod h1:gUYDtqQw1JS3ZJ8UWVcGTGqqr6YIN3CWg+kkNaLt55U=
github.com/warpfork/go-testmark v0.11.0 h1:J6LnV8KpceDvo7spaNU4+DauH2n1x+6RaO2rJrmpQ9U= github.com/warpfork/go-testmark v0.11.0 h1:J6LnV8KpceDvo7spaNU4+DauH2n1x+6RaO2rJrmpQ9U=
@ -415,8 +415,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -17,7 +17,7 @@ type TimeoutConn struct {
timeout time.Duration timeout time.Duration
} }
func NewConn(conn net.Conn, timeout time.Duration) *TimeoutConn { func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn {
return &TimeoutConn{conn, timeout} return &TimeoutConn{conn, timeout}
} }

View File

@ -5,7 +5,6 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -19,11 +18,11 @@ type noVerifyChecker struct {
cred *handshakeproto.Credentials cred *handshakeproto.Credentials
} }
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
return n.cred return n.cred
} }
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != n.cred.Version { if cred.Version != n.cred.Version {
return nil, handshake.ErrIncompatibleVersion return nil, handshake.ErrIncompatibleVersion
} }
@ -42,8 +41,8 @@ type peerSignVerifier struct {
account *accountdata.AccountKeys account *accountdata.AccountKeys
} }
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String())) sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId))
if err != nil { if err != nil {
log.Warn("can't sign identity credentials", zap.Error(err)) log.Warn("can't sign identity credentials", zap.Error(err))
} }
@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr
} }
} }
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != p.protoVersion { if cred.Version != p.protoVersion {
return nil, handshake.ErrIncompatibleVersion return nil, handshake.ErrIncompatibleVersion
} }
@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro
if err != nil { if err != nil {
return nil, handshake.ErrInvalidCredentials return nil, handshake.ErrInvalidCredentials
} }
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign) ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,13 +4,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/testutil/accounttest" "github.com/anyproto/any-sync/testutil/accounttest"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net"
"testing" "testing"
) )
@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1) cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(0, a2) cc2 := newPeerSignVerifier(0, a2)
c1 := newTestSC(a2.PeerId) c1 := a2.PeerId
c2 := newTestSC(a1.PeerId) c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1) cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2) cr2 := cc2.MakeCredentials(c2)
@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1) cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(1, a2) cc2 := newPeerSignVerifier(1, a2)
c1 := newTestSC(a2.PeerId) c1 := a2.PeerId
c2 := newTestSC(a1.PeerId) c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1) cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2) cr2 := cc2.MakeCredentials(c2)
@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys {
require.NoError(t, as.Init(nil)) require.NoError(t, as.Init(nil))
return as.Account() return as.Account()
} }
func newTestSC(peerId string) sec.SecureConn {
pid, _ := peer.Decode(peerId)
return &testSc{
ID: pid,
}
}
type testSc struct {
net.Conn
peer.ID
}
func (t *testSc) LocalPeer() peer.ID {
return ""
}
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
return nil
}
func (t *testSc) RemotePeer() peer.ID {
return t.ID
}
func (t *testSc) RemotePublicKey() crypto.PubKey {
return nil
}
func (t *testSc) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View File

@ -3,32 +3,36 @@ package handshake
import ( import (
"context" "context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec" "io"
) )
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var (
resIdentity []byte
resErr error
)
go func() { go func() {
defer close(done) defer close(done)
identity, err = outgoingHandshake(h, sc, cc) resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc)
}() }()
select { select {
case <-done: case <-done:
return return resIdentity, resErr
case <-ctx.Done(): case <-ctx.Done():
_ = sc.Close() _ = conn.Close()
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = conn
localCred := cc.MakeCredentials(sc) localCred := cc.MakeCredentials(peerId)
if err = h.writeCredentials(localCred); err != nil { if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
@ -45,7 +49,7 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
return nil, HandshakeError{e: msg.ack.Error} return nil, HandshakeError{e: msg.ack.Error}
} }
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
@ -68,40 +72,44 @@ func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (i
} }
} }
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var (
resIdentity []byte
resError error
)
go func() { go func() {
defer close(done) defer close(done)
identity, err = incomingHandshake(h, sc, cc) resIdentity, resError = incomingHandshake(h, conn, peerId, cc)
}() }()
select { select {
case <-done: case <-done:
return return resIdentity, resError
case <-ctx.Done(): case <-ctx.Done():
_ = sc.Close() _ = conn.Close()
return nil, ctx.Err() return nil, ctx.Err()
} }
} }
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = conn
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
if err != nil { if err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return return
} }
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
h.tryWriteErrAndClose(err) h.tryWriteErrAndClose(err)
return nil, err return nil, err
} }

View File

@ -7,7 +7,6 @@ import (
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
@ -17,7 +16,7 @@ import (
var noVerifyChecker = &testCredChecker{ var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil return []byte("identity"), nil
}, },
} }
@ -32,7 +31,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -40,10 +39,10 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message // receive credential message
msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto) msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err) require.NoError(t, err)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred) _, err = noVerifyChecker.CheckCredential("p1", msg.cred)
require.NoError(t, err) require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack // receive ack
msg, err = h.readMsg(msgTypeAck) msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
@ -58,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -69,7 +68,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -85,7 +84,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -101,7 +100,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -109,7 +108,7 @@ func TestOutgoingHandshake(t *testing.T) {
// receive credential message // receive credential message
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
msg, err := h.readMsg(msgTypeAck) msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
@ -120,7 +119,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -129,7 +128,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
@ -138,7 +137,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -147,7 +146,7 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack and close conn // read ack and close conn
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
@ -159,7 +158,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -168,12 +167,12 @@ func TestOutgoingHandshake(t *testing.T) {
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack // read ack
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.Error(t, err) require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
@ -183,7 +182,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -192,10 +191,10 @@ func TestOutgoingHandshake(t *testing.T) {
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred) _, err = noVerifyChecker.CheckCredential("", msg.cred)
require.NoError(t, err) require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack // receive ack
msg, err = h.readMsg(msgTypeAck) msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
@ -211,7 +210,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker) identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -234,13 +233,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -260,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -271,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
@ -286,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -300,13 +299,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error // except ack with error
msg, err := h.readMsg(msgTypeAck) msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
@ -320,13 +319,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error // except ack with error
msg, err := h.readMsg(msgTypeAck) msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
@ -340,18 +339,18 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred // read cred
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// expect EOF // expect EOF
_, err = h.readMsg(msgTypeAck) _, err = h.readMsg(msgTypeAck)
require.Error(t, err) require.Error(t, err)
@ -362,13 +361,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred and close conn // read cred and close conn
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -381,13 +380,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -403,13 +402,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -425,13 +424,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg(msgTypeCred) msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -448,13 +447,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker) identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
_, err := h.readMsg(msgTypeCred) _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
@ -472,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -491,11 +490,11 @@ func TestEndToEnd(t *testing.T) {
) )
st := time.Now() st := time.Now()
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err} outResCh <- handshakeRes{identity: identity, err: err}
}() }()
go func() { go func() {
identity, err := IncomingHandshake(nil, c2, noVerifyChecker) identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err} inResCh <- handshakeRes{identity: identity, err: err}
}() }()
@ -519,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done) defer close(done)
go func() { go func() {
for { for {
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker) _, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker)
select { select {
case outRes <- struct{}{}: case outRes <- struct{}{}:
case <-done: case <-done:
@ -529,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) {
}() }()
go func() { go func() {
for { for {
_, _ = IncomingHandshake(nil, c2, noVerifyChecker) _, _ = IncomingHandshake(nil, c2, "", noVerifyChecker)
select { select {
case inRes <- struct{}{}: case inRes <- struct{}{}:
case <-done: case <-done:
@ -549,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct { type testCredChecker struct {
makeCred *handshakeproto.Credentials makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error checkErr error
} }
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials {
return t.makeCred return t.makeCred
} }
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil { if t.checkErr != nil {
return nil, t.checkErr return nil, t.checkErr
} }
if t.checkCred != nil { if t.checkCred != nil {
return t.checkCred(sc, cred) return t.checkCred(peerId, cred)
} }
return nil, nil return nil, nil
} }

View File

@ -4,10 +4,8 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"io" "io"
"net"
"sync" "sync"
) )
@ -65,8 +63,8 @@ var handshakePool = &sync.Pool{New: func() any {
}} }}
type CredentialChecker interface { type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials MakeCredentials(remotePeerId string) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
} }
func newHandshake() *handshake { func newHandshake() *handshake {
@ -74,7 +72,7 @@ func newHandshake() *handshake {
} }
type handshake struct { type handshake struct {
conn net.Conn conn io.ReadWriteCloser
remoteCred *handshakeproto.Credentials remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack remoteAck *handshakeproto.Ack

View File

@ -2,6 +2,7 @@ package secureservice
import ( import (
"context" "context"
"crypto/tls"
commonaccount "github.com/anyproto/any-sync/accountservice" commonaccount "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
@ -10,9 +11,9 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/sec"
libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls"
"go.uber.org/zap" "go.uber.org/zap"
"io"
"net" "net"
) )
@ -25,8 +26,10 @@ func New() SecureService {
} }
type SecureService interface { type SecureService interface {
SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error)
HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error)
ServerTlsConfig() (*tls.Config, error)
app.Component app.Component
} }
@ -75,28 +78,31 @@ func (s *secureService) Name() (name string) {
return CName return CName
} }
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err = s.p2pTr.SecureInbound(ctx, conn, "") sc, err := s.p2pTr.SecureInbound(ctx, conn, "")
if err != nil { if err != nil {
return nil, nil, handshake.HandshakeError{ return nil, handshake.HandshakeError{
Err: err, Err: err,
} }
} }
return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String())
}
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker) func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) {
identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
cctx = context.Background() cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) cctx = peer.CtxWithPeerId(cctx, peerId)
cctx = peer.CtxWithIdentity(cctx, identity) cctx = peer.CtxWithIdentity(cctx, identity)
return return
} }
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) {
sc, err = s.p2pTr.SecureOutbound(ctx, conn, "") sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil { if err != nil {
return nil, nil, handshake.HandshakeError{Err: err} return nil, handshake.HandshakeError{Err: err}
} }
peerId := sc.RemotePeer().String() peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.NodeTypes(peerId) confTypes := s.nodeconf.NodeTypes(peerId)
@ -106,12 +112,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx
} else { } else {
checker = s.noVerifyChecker checker = s.noVerifyChecker
} }
identity, err := handshake.OutgoingHandshake(ctx, sc, checker) identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker)
if err != nil { if err != nil {
return nil, nil, err return nil, err
} }
cctx = context.Background() cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
cctx = peer.CtxWithIdentity(cctx, identity) cctx = peer.CtxWithIdentity(cctx, identity)
return cctx, sc, nil return cctx, nil
}
func (s *secureService) ServerTlsConfig() (*tls.Config, error) {
p2pIdn, err := libp2ptls.NewIdentity(s.key)
if err != nil {
return nil, err
}
conf, _ := p2pIdn.ConfigForPeer("")
conf.NextProtos = []string{"anysync"}
return conf, nil
} }

View File

@ -32,18 +32,17 @@ func TestHandshake(t *testing.T) {
resCh := make(chan acceptRes) resCh := make(chan acceptRes)
go func() { go func() {
var ar acceptRes var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar resCh <- ar
}() }()
fxC := newFixture(t, nc, nc.GetAccountService(1), 0) fxC := newFixture(t, nc, nc.GetAccountService(1), 0)
defer fxC.Finish(t) defer fxC.Finish(t)
cctx, secConn, err := fxC.SecureOutbound(ctx, cc) cctx, err := fxC.SecureOutbound(ctx, cc)
require.NoError(t, err) require.NoError(t, err)
ctxPeerId, err := peer.CtxPeerId(cctx) ctxPeerId, err := peer.CtxPeerId(cctx)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String())
assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId) assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId)
res := <-resCh res := <-resCh
require.NoError(t, res.err) require.NoError(t, res.err)
@ -70,12 +69,12 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
resCh := make(chan acceptRes) resCh := make(chan acceptRes)
go func() { go func() {
var ar acceptRes var ar acceptRes
ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) ar.ctx, ar.err = fxS.SecureInbound(ctx, sc)
resCh <- ar resCh <- ar
}() }()
fxC := newFixture(t, nc, nc.GetAccountService(1), 1) fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
defer fxC.Finish(t) defer fxC.Finish(t)
_, _, err := fxC.SecureOutbound(ctx, cc) _, err := fxC.SecureOutbound(ctx, cc)
require.Equal(t, handshake.ErrIncompatibleVersion, err) require.Equal(t, handshake.ErrIncompatibleVersion, err)
res := <-resCh res := <-resCh
require.Equal(t, handshake.ErrIncompatibleVersion, res.err) require.Equal(t, handshake.ErrIncompatibleVersion, res.err)

View File

@ -8,5 +8,4 @@ type Config struct {
ListenAddrs []string `yaml:"listenAddrs"` ListenAddrs []string `yaml:"listenAddrs"`
WriteTimeoutSec int `yaml:"writeTimeoutSec"` WriteTimeoutSec int `yaml:"writeTimeoutSec"`
DialTimeoutSec int `yaml:"dialTimeoutSec"` DialTimeoutSec int `yaml:"dialTimeoutSec"`
MaxStreams int `yaml:"maxStreams"`
} }

View File

@ -26,7 +26,10 @@ type yamuxConn struct {
} }
func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) { func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) {
return y.Session.Open() if conn, err = y.Session.Open(); err != nil {
return
}
return
} }
func (y *yamuxConn) LastUsage() time.Time { func (y *yamuxConn) LastUsage() time.Time {
@ -46,6 +49,7 @@ func (y *yamuxConn) Accept() (conn net.Conn, err error) {
if err == yamux.ErrSessionShutdown { if err == yamux.ErrSessionShutdown {
err = transport.ErrConnClosed err = transport.ErrConnClosed
} }
return
} }
return return
} }

View File

@ -43,9 +43,6 @@ func (y *yamuxTransport) Init(a *app.App) (err error) {
y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService) y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService)
y.conf = a.MustComponent("config").(configGetter).GetYamux() y.conf = a.MustComponent("config").(configGetter).GetYamux()
y.yamuxConf = yamux.DefaultConfig() y.yamuxConf = yamux.DefaultConfig()
if y.conf.MaxStreams > 0 {
y.yamuxConf.AcceptBacklog = y.conf.MaxStreams
}
y.yamuxConf.EnableKeepAlive = false y.yamuxConf.EnableKeepAlive = false
y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second
y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second
@ -86,12 +83,12 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
} }
ctx, cancel := context.WithTimeout(ctx, dialTimeout) ctx, cancel := context.WithTimeout(ctx, dialTimeout)
defer cancel() defer cancel()
cctx, sc, err := y.secure.SecureOutbound(ctx, conn) cctx, err := y.secure.SecureOutbound(ctx, conn)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
luc := connutil.NewLastUsageConn(sc) luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Client(luc, y.yamuxConf) sess, err := yamux.Client(luc, y.yamuxConf)
if err != nil { if err != nil {
return return
@ -132,12 +129,12 @@ func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) {
func (y *yamuxTransport) accept(conn net.Conn) { func (y *yamuxTransport) accept(conn net.Conn) {
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second) ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second)
defer cancel() defer cancel()
cctx, sc, err := y.secure.SecureInbound(ctx, conn) cctx, err := y.secure.SecureInbound(ctx, conn)
if err != nil { if err != nil {
log.Warn("incoming connection handshake error", zap.Error(err)) log.Warn("incoming connection handshake error", zap.Error(err))
return return
} }
luc := connutil.NewLastUsageConn(sc) luc := connutil.NewLastUsageConn(conn)
sess, err := yamux.Server(luc, y.yamuxConf) sess, err := yamux.Server(luc, y.yamuxConf)
if err != nil { if err != nil {
log.Warn("incoming connection yamux session error", zap.Error(err)) log.Warn("incoming connection yamux session error", zap.Error(err))

View File

@ -14,7 +14,10 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"io" "io"
"net"
"sync"
"testing" "testing"
"time"
) )
var ctx = context.Background() var ctx = context.Background()
@ -28,7 +31,7 @@ func TestYamuxTransport_Dial(t *testing.T) {
mcC, err := fxC.Dial(ctx, fxS.addr) mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, fxS.accepter.mcs, 1) require.Len(t, fxS.accepter.mcs, 1)
mcS := fxS.accepter.mcs[0] mcS := <-fxS.accepter.mcs
var ( var (
sData string sData string
@ -63,6 +66,64 @@ func TestYamuxTransport_Dial(t *testing.T) {
assert.NoError(t, copyErr) assert.NoError(t, copyErr)
} }
// no deadline - 69100 rps
// common write deadline - 66700 rps
// subconn write deadline - 67100 rps
func TestWriteBench(t *testing.T) {
t.Skip()
var (
numSubConn = 10
numWrites = 100000
)
fxS := newFixture(t)
defer fxS.finish(t)
fxC := newFixture(t)
defer fxC.finish(t)
mcC, err := fxC.Dial(ctx, fxS.addr)
require.NoError(t, err)
mcS := <-fxS.accepter.mcs
go func() {
for i := 0; i < numSubConn; i++ {
conn, err := mcS.Accept()
require.NoError(t, err)
go func(sc net.Conn) {
var b = make([]byte, 1024)
for {
n, _ := sc.Read(b)
if n > 0 {
sc.Write(b[:n])
} else {
break
}
}
}(conn)
}
}()
var wg sync.WaitGroup
wg.Add(numSubConn)
st := time.Now()
for i := 0; i < numSubConn; i++ {
conn, err := mcC.Open(ctx)
require.NoError(t, err)
go func(sc net.Conn) {
defer sc.Close()
defer wg.Done()
for j := 0; j < numWrites; j++ {
var b = []byte("some data some data some data some data some data some data some data some data some data")
sc.Write(b)
sc.Read(b)
}
}(conn)
}
wg.Wait()
dur := time.Since(st)
t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds())
}
type fixture struct { type fixture struct {
*yamuxTransport *yamuxTransport
a *app.App a *app.App
@ -78,7 +139,7 @@ func newFixture(t *testing.T) *fixture {
yamuxTransport: New().(*yamuxTransport), yamuxTransport: New().(*yamuxTransport),
ctrl: gomock.NewController(t), ctrl: gomock.NewController(t),
acc: &accounttest.AccountTestService{}, acc: &accounttest.AccountTestService{},
accepter: &testAccepter{}, accepter: &testAccepter{mcs: make(chan transport.MultiConn, 100)},
a: new(app.App), a: new(app.App),
} }
@ -112,17 +173,16 @@ func (c *testConf) GetYamux() Config {
ListenAddrs: []string{"127.0.0.1:0"}, ListenAddrs: []string{"127.0.0.1:0"},
WriteTimeoutSec: 10, WriteTimeoutSec: 10,
DialTimeoutSec: 10, DialTimeoutSec: 10,
MaxStreams: 1024,
} }
} }
type testAccepter struct { type testAccepter struct {
err error err error
mcs []transport.MultiConn mcs chan transport.MultiConn
} }
func (t *testAccepter) Accept(mc transport.MultiConn) (err error) { func (t *testAccepter) Accept(mc transport.MultiConn) (err error) {
t.mcs = append(t.mcs, mc) t.mcs <- mc
return t.err return t.err
} }