diff --git a/net/peer/limiter.go b/net/peer/limiter.go new file mode 100644 index 00000000..4c8e1589 --- /dev/null +++ b/net/peer/limiter.go @@ -0,0 +1,18 @@ +package peer + +import ( + "time" +) + +type limiter struct { + startThreshold int + slowDownStep time.Duration +} + +func (l limiter) wait(count int) <-chan time.Time { + if count > l.startThreshold { + wait := l.slowDownStep * time.Duration(count-l.startThreshold) + return time.After(wait) + } + return nil +} diff --git a/net/peer/peer.go b/net/peer/peer.go index c8a54a43..9ff7f023 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -19,6 +19,7 @@ import ( "storj.io/drpc/drpcstream" "storj.io/drpc/drpcwire" "sync" + "sync/atomic" "time" ) @@ -35,8 +36,15 @@ func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { active: map[*subConn]struct{}{}, MultiConn: mc, ctrl: ctrl, - created: time.Now(), + limiter: limiter{ + // start throttling after 10 sub conns + startThreshold: 10, + slowDownStep: time.Millisecond * 100, + }, + subConnRelease: make(chan drpc.Conn), + created: time.Now(), } + pr.acceptCtx, pr.acceptCtxCancel = context.WithCancel(context.Background()) if pr.id, err = CtxPeerId(ctx); err != nil { return } @@ -70,13 +78,22 @@ type peer struct { ctrl connCtrl // drpc conn pool - inactive []*subConn - active map[*subConn]struct{} + // outgoing + inactive []*subConn + active map[*subConn]struct{} + subConnRelease chan drpc.Conn + openingWaitCount atomic.Int32 + + incomingCount atomic.Int32 + acceptCtx context.Context + + acceptCtxCancel context.CancelFunc + + limiter limiter mu sync.Mutex created time.Time - transport.MultiConn } @@ -87,7 +104,20 @@ func (p *peer) Id() string { func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { p.mu.Lock() if len(p.inactive) == 0 { + wait := p.limiter.wait(len(p.active) + int(p.openingWaitCount.Load())) p.mu.Unlock() + if wait != nil { + p.openingWaitCount.Add(1) + defer p.openingWaitCount.Add(-1) + // throttle new connection opening + select { + case <-ctx.Done(): + return nil, ctx.Err() + case dconn := <-p.subConnRelease: + return dconn, nil + case <-wait: + } + } dconn, err := p.openDrpcConn(ctx) if err != nil { return nil, err @@ -110,6 +140,21 @@ func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { } func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { + // do nothing if it's closed connection + select { + case <-conn.Closed(): + return + default: + } + + // try to send this connection to acquire if anyone is waiting for it + select { + case p.subConnRelease <- conn: + return + default: + } + + // return to pool p.mu.Lock() defer p.mu.Unlock() sc, ok := conn.(*subConn) @@ -162,12 +207,21 @@ func (p *peer) acceptLoop() { } }() for { + if wait := p.limiter.wait(int(p.incomingCount.Load())); wait != nil { + select { + case <-wait: + case <-p.acceptCtx.Done(): + return + } + } conn, err := p.Accept() if err != nil { exitErr = err return } go func() { + p.incomingCount.Add(1) + defer p.incomingCount.Add(-1) serveErr := p.serve(conn) if serveErr != io.EOF && serveErr != transport.ErrConnClosed { log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr)) diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go index ac06f8d6..7aa8428c 100644 --- a/net/peer/peer_test.go +++ b/net/peer/peer_test.go @@ -12,6 +12,8 @@ import ( "io" "net" _ "net/http/pprof" + "storj.io/drpc" + "storj.io/drpc/drpcconn" "testing" "time" ) @@ -19,32 +21,86 @@ import ( var ctx = context.Background() func TestPeer_AcquireDrpcConn(t *testing.T) { + t.Run("generic", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dc, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + defer dc.Close() + + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) + + fx.ReleaseDrpcConn(dc) + + assert.Len(t, fx.active, 0) + assert.Len(t, fx.inactive, 1) + + dc, err = fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) + }) + t.Run("closed sub conn", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + + closedIn, _ := net.Pipe() + dc := drpcconn.New(closedIn) + fx.ReleaseDrpcConn(&subConn{Conn: dc}) + dc.Close() + + in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + _, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + }) +} + +func TestPeer_DrpcConn_OpenThrottling(t *testing.T) { fx := newFixture(t, "p1") defer fx.finish() - in, out := net.Pipe() + + acquire := func() (func(), drpc.Conn, error) { + in, out := net.Pipe() + go func() { + _, err := handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + require.NoError(t, err) + }() + + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dconn, err := fx.AcquireDrpcConn(ctx) + return func() { out.Close() }, dconn, err + } + + var conCount = fx.limiter.startThreshold + 3 + var conns []drpc.Conn + for i := 0; i < conCount; i++ { + cc, dc, err := acquire() + require.NoError(t, err) + defer cc() + conns = append(conns, dc) + } + go func() { - handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + time.Sleep(fx.limiter.slowDownStep) + fx.ReleaseDrpcConn(conns[0]) + conns = conns[1:] }() - defer out.Close() - fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) - dc, err := fx.AcquireDrpcConn(ctx) + _, err := fx.AcquireDrpcConn(ctx) require.NoError(t, err) - assert.NotEmpty(t, dc) - defer dc.Close() - - assert.Len(t, fx.active, 1) - assert.Len(t, fx.inactive, 0) - - fx.ReleaseDrpcConn(dc) - - assert.Len(t, fx.active, 0) - assert.Len(t, fx.inactive, 1) - - dc, err = fx.AcquireDrpcConn(ctx) - require.NoError(t, err) - assert.NotEmpty(t, dc) - assert.Len(t, fx.active, 1) - assert.Len(t, fx.inactive, 0) } func TestPeerAccept(t *testing.T) { @@ -63,6 +119,26 @@ func TestPeerAccept(t *testing.T) { assert.NoError(t, <-outHandshakeCh) } +func TestPeer_DrpcConn_AcceptThrottling(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + + var conCount = fx.limiter.startThreshold + 3 + for i := 0; i < conCount; i++ { + in, out := net.Pipe() + defer out.Close() + + var outHandshakeCh = make(chan error) + go func() { + outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC) + }() + fx.acceptCh <- acceptedConn{conn: in} + cn := <-fx.testCtrl.serveConn + assert.Equal(t, in, cn) + assert.NoError(t, <-outHandshakeCh) + } +} + func TestPeer_TryClose(t *testing.T) { t.Run("not close in first minute", func(t *testing.T) { fx := newFixture(t, "p1") diff --git a/net/pool/poolservice.go b/net/pool/poolservice.go index 2f84e5d0..0c574eb4 100644 --- a/net/pool/poolservice.go +++ b/net/pool/poolservice.go @@ -49,8 +49,8 @@ func (p *poolService) Init(a *app.App) (err error) { return p.dialer.Dial(ctx, id) }, ocache.WithLogger(log.Sugar()), - ocache.WithGCPeriod(time.Minute), - ocache.WithTTL(time.Minute*5), + ocache.WithGCPeriod(time.Minute/2), + ocache.WithTTL(time.Minute), ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"), ) p.pool.incoming = ocache.New( @@ -58,8 +58,8 @@ func (p *poolService) Init(a *app.App) (err error) { return nil, ocache.ErrNotExists }, ocache.WithLogger(log.Sugar()), - ocache.WithGCPeriod(time.Minute), - ocache.WithTTL(time.Minute*5), + ocache.WithGCPeriod(time.Minute/2), + ocache.WithTTL(time.Minute), ocache.WithPrometheus(p.metricReg, "netpool", "incoming"), ) return nil diff --git a/net/rpc/rpctest/multiconntest/multiconntest.go b/net/rpc/rpctest/multiconntest/multiconntest.go new file mode 100644 index 00000000..a99d1083 --- /dev/null +++ b/net/rpc/rpctest/multiconntest/multiconntest.go @@ -0,0 +1,29 @@ +package multiconntest + +import ( + "context" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/transport" + yamux2 "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/hashicorp/yamux" + "net" +) + +func MultiConnPair(peerServCtx, peerClientCtx context.Context) (serv, client transport.MultiConn) { + sc, cc := net.Pipe() + var servConn = make(chan transport.MultiConn, 1) + go func() { + sess, err := yamux.Server(sc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + servConn <- yamux2.NewMultiConn(peerServCtx, connutil.NewLastUsageConn(sc), "", sess) + }() + sess, err := yamux.Client(cc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + client = yamux2.NewMultiConn(peerClientCtx, connutil.NewLastUsageConn(cc), "", sess) + serv = <-servConn + return +} diff --git a/net/rpc/rpctest/peer.go b/net/rpc/rpctest/peer.go index a5fef8be..902547bb 100644 --- a/net/rpc/rpctest/peer.go +++ b/net/rpc/rpctest/peer.go @@ -2,29 +2,11 @@ package rpctest import ( "context" - "github.com/anyproto/any-sync/net/connutil" "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/rpc/rpctest/multiconntest" "github.com/anyproto/any-sync/net/transport" - yamux2 "github.com/anyproto/any-sync/net/transport/yamux" - "github.com/hashicorp/yamux" - "net" ) func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) { - sc, cc := net.Pipe() - var servConn = make(chan transport.MultiConn, 1) - go func() { - sess, err := yamux.Server(sc, yamux.DefaultConfig()) - if err != nil { - panic(err) - } - servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess) - }() - sess, err := yamux.Client(cc, yamux.DefaultConfig()) - if err != nil { - panic(err) - } - client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess) - serv = <-servConn - return + return multiconntest.MultiConnPair(peer.CtxWithPeerId(context.Background(), peerIdServ), peer.CtxWithPeerId(context.Background(), peerIdClient)) } diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go index 0013f2bb..6a473796 100644 --- a/net/transport/yamux/conn.go +++ b/net/transport/yamux/conn.go @@ -6,6 +6,7 @@ import ( "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/transport" "github.com/hashicorp/yamux" + "io" "net" "time" ) @@ -48,7 +49,7 @@ func (y *yamuxConn) Addr() string { func (y *yamuxConn) Accept() (conn net.Conn, err error) { if conn, err = y.Session.Accept(); err != nil { - if err == yamux.ErrSessionShutdown { + if err == yamux.ErrSessionShutdown || err == io.EOF { err = transport.ErrConnClosed } return diff --git a/net/transport/yamux/yamux_test.go b/net/transport/yamux/yamux_test.go index 02e1c322..9b209054 100644 --- a/net/transport/yamux/yamux_test.go +++ b/net/transport/yamux/yamux_test.go @@ -30,8 +30,12 @@ func TestYamuxTransport_Dial(t *testing.T) { mcC, err := fxC.Dial(ctx, fxS.addr) require.NoError(t, err) - require.Len(t, fxS.accepter.mcs, 1) - mcS := <-fxS.accepter.mcs + var mcS transport.MultiConn + select { + case mcS = <-fxS.accepter.mcs: + case <-time.After(time.Second * 5): + require.True(t, false, "timeout") + } var ( sData string @@ -69,11 +73,11 @@ func TestYamuxTransport_Dial(t *testing.T) { // no deadline - 69100 rps // common write deadline - 66700 rps // subconn write deadline - 67100 rps -func TestWriteBench(t *testing.T) { +func TestWriteBenchReuse(t *testing.T) { t.Skip() var ( numSubConn = 10 - numWrites = 100000 + numWrites = 10000 ) fxS := newFixture(t) @@ -124,6 +128,63 @@ func TestWriteBench(t *testing.T) { t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds()) } +func TestWriteBenchNew(t *testing.T) { + t.Skip() + var ( + numSubConn = 10 + numWrites = 10000 + ) + + fxS := newFixture(t) + defer fxS.finish(t) + fxC := newFixture(t) + defer fxC.finish(t) + + mcC, err := fxC.Dial(ctx, fxS.addr) + require.NoError(t, err) + mcS := <-fxS.accepter.mcs + + go func() { + for i := 0; i < numSubConn; i++ { + require.NoError(t, err) + go func() { + var b = make([]byte, 1024) + for { + conn, _ := mcS.Accept() + n, _ := conn.Read(b) + if n > 0 { + conn.Write(b[:n]) + } else { + _ = conn.Close() + break + } + conn.Close() + } + }() + } + }() + + var wg sync.WaitGroup + wg.Add(numSubConn) + st := time.Now() + for i := 0; i < numSubConn; i++ { + go func() { + defer wg.Done() + for j := 0; j < numWrites; j++ { + sc, err := mcC.Open(ctx) + require.NoError(t, err) + var b = []byte("some data some data some data some data some data some data some data some data some data") + sc.Write(b) + sc.Read(b) + sc.Close() + } + }() + } + wg.Wait() + dur := time.Since(st) + t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds()) +} + type fixture struct { *yamuxTransport a *app.App