From 0c0a501aad69256f017c1d51e3a982571ee0a05d Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 5 Jun 2023 20:39:09 +0200 Subject: [PATCH 1/4] peer outgoing proto handshake + test multiconn + streampool tests --- net/peer/peer.go | 20 ++++- net/rpc/rpctest/peer.go | 30 ++++++++ net/rpc/rpctest/pool.go | 122 ------------------------------ net/rpc/rpctest/server.go | 22 +++--- net/streampool/streampool_test.go | 51 ++++++++----- net/transport/yamux/conn.go | 9 +++ net/transport/yamux/yamux.go | 14 +--- 7 files changed, 97 insertions(+), 171 deletions(-) create mode 100644 net/rpc/rpctest/peer.go delete mode 100644 net/rpc/rpctest/pool.go diff --git a/net/peer/peer.go b/net/peer/peer.go index 879f6f9b..ab0ae236 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -38,6 +38,7 @@ func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { type Peer interface { Id() string + Context() context.Context AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) ReleaseDrpcConn(conn drpc.Conn) @@ -70,19 +71,20 @@ func (p *peer) Id() string { func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { p.mu.Lock() - defer p.mu.Unlock() if len(p.inactive) == 0 { - conn, err := p.Open(ctx) + p.mu.Unlock() + dconn, err := p.openDrpcConn(ctx) if err != nil { return nil, err } - dconn := drpcconn.New(conn) + p.mu.Lock() p.inactive = append(p.inactive, dconn) } idx := len(p.inactive) - 1 res := p.inactive[idx] p.inactive = p.inactive[:idx] p.active[res] = struct{}{} + p.mu.Unlock() return res, nil } @@ -105,6 +107,18 @@ func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error return do(conn) } +func (p *peer) openDrpcConn(ctx context.Context) (dconn drpc.Conn, err error) { + conn, err := p.Open(ctx) + if err != nil { + return nil, err + } + if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil { + return nil, err + } + dconn = drpcconn.New(conn) + return +} + func (p *peer) acceptLoop() { var exitErr error defer func() { diff --git a/net/rpc/rpctest/peer.go b/net/rpc/rpctest/peer.go new file mode 100644 index 00000000..a5fef8be --- /dev/null +++ b/net/rpc/rpctest/peer.go @@ -0,0 +1,30 @@ +package rpctest + +import ( + "context" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/transport" + yamux2 "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/hashicorp/yamux" + "net" +) + +func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) { + sc, cc := net.Pipe() + var servConn = make(chan transport.MultiConn, 1) + go func() { + sess, err := yamux.Server(sc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess) + }() + sess, err := yamux.Client(cc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess) + serv = <-servConn + return +} diff --git a/net/rpc/rpctest/pool.go b/net/rpc/rpctest/pool.go deleted file mode 100644 index a22a88cc..00000000 --- a/net/rpc/rpctest/pool.go +++ /dev/null @@ -1,122 +0,0 @@ -package rpctest - -import ( - "context" - "errors" - "github.com/anyproto/any-sync/app" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/pool" - "math/rand" - "storj.io/drpc" - "sync" - "time" -) - -var ErrCantConnect = errors.New("can't connect to test server") - -func NewTestPool() *TestPool { - return &TestPool{ - peers: map[string]peer.Peer{}, - } -} - -type TestPool struct { - ts *TesServer - peers map[string]peer.Peer - mu sync.Mutex -} - -func (t *TestPool) WithServer(ts *TesServer) *TestPool { - t.mu.Lock() - defer t.mu.Unlock() - t.ts = ts - return t -} - -func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - if p, ok := t.peers[id]; ok { - return p, nil - } - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) { - return t.Get(ctx, id) -} - -func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - for _, peerId := range peerIds { - if p, ok := t.peers[peerId]; ok { - return p, nil - } - } - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) NewPool(name string) pool.Pool { - return t -} - -func (t *TestPool) AddPeer(p peer.Peer) { - t.mu.Lock() - defer t.mu.Unlock() - t.peers[p.Id()] = p -} - -func (t *TestPool) Init(a *app.App) (err error) { - return nil -} - -func (t *TestPool) Name() (name string) { - return pool.CName -} - -func (t *TestPool) Run(ctx context.Context) (err error) { - return nil -} - -func (t *TestPool) Close(ctx context.Context) (err error) { - return nil -} - -type testPeer struct { - id string - drpc.Conn -} - -func (t testPeer) Addr() string { - return "" -} - -func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { - return true, t.Close() -} - -func (t testPeer) Id() string { - return t.id -} - -func (t testPeer) LastUsage() time.Time { - return time.Now() -} - -func (t testPeer) UpdateLastUsage() {} diff --git a/net/rpc/rpctest/server.go b/net/rpc/rpctest/server.go index 4731a8e2..053187f4 100644 --- a/net/rpc/rpctest/server.go +++ b/net/rpc/rpctest/server.go @@ -5,43 +5,39 @@ import ( "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/net/rpc/server" "net" - "storj.io/drpc" - "storj.io/drpc/drpcconn" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" ) -func NewTestServer() *TesServer { - ts := &TesServer{ +func NewTestServer() *TestServer { + ts := &TestServer{ Mux: drpcmux.New(), } ts.Server = drpcserver.New(ts.Mux) return ts } -type TesServer struct { +type TestServer struct { *drpcmux.Mux *drpcserver.Server } -func (ts *TesServer) Init(a *app.App) (err error) { +func (ts *TestServer) Init(a *app.App) (err error) { return nil } -func (ts *TesServer) Name() (name string) { +func (ts *TestServer) Name() (name string) { return server.CName } -func (ts *TesServer) Run(ctx context.Context) (err error) { +func (ts *TestServer) Run(ctx context.Context) (err error) { return nil } -func (ts *TesServer) Close(ctx context.Context) (err error) { +func (ts *TestServer) Close(ctx context.Context) (err error) { return nil } -func (ts *TesServer) Dial(ctx context.Context) drpc.Conn { - sc, cc := net.Pipe() - go ts.Server.ServeOne(ctx, sc) - return drpcconn.New(cc) +func (s *TestServer) ServeConn(ctx context.Context, conn net.Conn) (err error) { + return s.Server.ServeOne(ctx, conn) } diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 75d51059..6a053bd4 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -18,17 +18,25 @@ import ( var ctx = context.Background() +func makePeerPair(t *testing.T, fx *fixture, peerId string) (pS, pC peer.Peer) { + mcS, mcC := rpctest.MultiConnPair(peerId+"server", peerId) + pS, err := peer.NewPeer(mcS, fx.ts) + require.NoError(t, err) + pC, err = peer.NewPeer(mcC, fx.ts) + require.NoError(t, err) + return +} + func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) { - p, err := fx.tp.Dial(ctx, peerId) + _, pC := makePeerPair(t, fx, peerId) + drpcConn, err := pC.AcquireDrpcConn(ctx) require.NoError(t, err) - ctx = peer.CtxWithPeerId(ctx, peerId) - s, err := testservice.NewDRPCTestClient(p).TestStream(ctx) + st, err = testservice.NewDRPCTestClient(drpcConn).TestStream(pC.Context()) require.NoError(t, err) - return s, p + return st, pC } func TestStreamPool_AddStream(t *testing.T) { - t.Run("broadcast incoming", func(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) @@ -85,11 +93,10 @@ func TestStreamPool_Send(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) + pS, _ := makePeerPair(t, fx, "p1") require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { - return []peer.Peer{p}, nil + return []peer.Peer{pS}, nil })) var msg *testservice.StreamMessage @@ -100,12 +107,12 @@ func TestStreamPool_Send(t *testing.T) { } assert.Equal(t, "should open stream", msg.ReqData) }) + t.Run("parallel open stream", func(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) + pS, _ := makePeerPair(t, fx, "p1") fx.th.streamOpenDelay = time.Second / 3 @@ -113,7 +120,7 @@ func TestStreamPool_Send(t *testing.T) { for i := 0; i < numMsgs; i++ { go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { - return []peer.Peer{p}, nil + return []peer.Peer{pS}, nil })) } @@ -134,9 +141,8 @@ func TestStreamPool_Send(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) - _ = p.Close() + pS, _ := makePeerPair(t, fx, "p1") + _ = pS.Close() fx.th.streamOpenDelay = time.Second / 3 @@ -147,11 +153,12 @@ func TestStreamPool_Send(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"})) + assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, pS, &testservice.StreamMessage{ReqData: "should open stream"})) }() } wg.Wait() }) + } func TestStreamPool_SendById(t *testing.T) { @@ -196,10 +203,9 @@ func TestStreamPool_Tags(t *testing.T) { func newFixture(t *testing.T) *fixture { fx := &fixture{} - ts := rpctest.NewTestServer() + fx.ts = rpctest.NewTestServer() fx.tsh = &testServerHandler{receiveCh: make(chan *testservice.StreamMessage, 100)} - require.NoError(t, testservice.DRPCRegisterTest(ts, fx.tsh)) - fx.tp = rpctest.NewTestPool().WithServer(ts) + require.NoError(t, testservice.DRPCRegisterTest(fx.ts, fx.tsh)) fx.th = &testHandler{} fx.StreamPool = New().NewStreamPool(fx.th, StreamConfig{ SendQueueSize: 10, @@ -211,14 +217,13 @@ func newFixture(t *testing.T) *fixture { type fixture struct { StreamPool - tp *rpctest.TestPool th *testHandler tsh *testServerHandler + ts *rpctest.TestServer } func (fx *fixture) Finish(t *testing.T) { require.NoError(t, fx.Close()) - require.NoError(t, fx.tp.Close(ctx)) } type testHandler struct { @@ -231,7 +236,11 @@ func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc. if t.streamOpenDelay > 0 { time.Sleep(t.streamOpenDelay) } - stream, err = testservice.NewDRPCTestClient(p).TestStream(ctx) + conn, err := p.AcquireDrpcConn(ctx) + if err != nil { + return + } + stream, err = testservice.NewDRPCTestClient(conn).TestStream(p.Context()) return } diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go index d563df34..5aa66162 100644 --- a/net/transport/yamux/conn.go +++ b/net/transport/yamux/conn.go @@ -9,6 +9,15 @@ import ( "time" ) +func NewMultiConn(cctx context.Context, luConn *connutil.LastUsageConn, addr string, sess *yamux.Session) transport.MultiConn { + return &yamuxConn{ + ctx: cctx, + luConn: luConn, + addr: addr, + Session: sess, + } +} + type yamuxConn struct { ctx context.Context luConn *connutil.LastUsageConn diff --git a/net/transport/yamux/yamux.go b/net/transport/yamux/yamux.go index 44729392..305fcb92 100644 --- a/net/transport/yamux/yamux.go +++ b/net/transport/yamux/yamux.go @@ -96,12 +96,7 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu if err != nil { return } - mc = &yamuxConn{ - ctx: cctx, - luConn: luc, - Session: sess, - addr: addr, - } + mc = NewMultiConn(cctx, luc, addr, sess) return } @@ -148,12 +143,7 @@ func (y *yamuxTransport) accept(conn net.Conn) { log.Warn("incoming connection yamux session error", zap.Error(err)) return } - mc := &yamuxConn{ - ctx: cctx, - luConn: luc, - Session: sess, - addr: conn.RemoteAddr().String(), - } + mc := NewMultiConn(cctx, luc, conn.RemoteAddr().String(), sess) if err = y.accepter.Accept(mc); err != nil { log.Warn("connection accept error", zap.Error(err)) } From 9d4945c733d7995058add7f518626e27aba39098 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 5 Jun 2023 20:39:18 +0200 Subject: [PATCH 2/4] fix --- app/ocache/metrics.go | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/app/ocache/metrics.go b/app/ocache/metrics.go index b520dff2..c2dc04c5 100644 --- a/app/ocache/metrics.go +++ b/app/ocache/metrics.go @@ -6,6 +6,9 @@ import ( ) func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option { + if reg == nil { + return nil + } if subsystem == "" { subsystem = "cache" } @@ -13,9 +16,7 @@ func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Optio subSplit := strings.Split(subsystem, ".") namespace = strings.Join(nameSplit, "_") subsystem = strings.Join(subSplit, "_") - if reg == nil { - return nil - } + return func(cache *oCache) { cache.metrics = &metrics{ hit: prometheus.NewCounter(prometheus.CounterOpts{ From af9d71d16ef446f0698c8017e463a0f6be993c23 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 5 Jun 2023 21:23:41 +0200 Subject: [PATCH 3/4] peer subConn gc --- net/peer/peer.go | 68 +++++++++++++++++++++++++++++++++++++------ net/peer/peer_test.go | 50 +++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+), 9 deletions(-) diff --git a/net/peer/peer.go b/net/peer/peer.go index ab0ae236..7f42260c 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -4,6 +4,7 @@ import ( "context" "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/ocache" + "github.com/anyproto/any-sync/net/connutil" "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/transport" @@ -25,7 +26,7 @@ type connCtrl interface { func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { ctx := mc.Context() pr := &peer{ - active: map[drpc.Conn]struct{}{}, + active: map[*subConn]struct{}{}, MultiConn: mc, ctrl: ctrl, } @@ -51,14 +52,19 @@ type Peer interface { ocache.Object } +type subConn struct { + drpc.Conn + *connutil.LastUsageConn +} + type peer struct { id string ctrl connCtrl // drpc conn pool - inactive []drpc.Conn - active map[drpc.Conn]struct{} + inactive []*subConn + active map[*subConn]struct{} mu sync.Mutex @@ -91,10 +97,14 @@ func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { p.mu.Lock() defer p.mu.Unlock() - if _, ok := p.active[conn]; ok { - delete(p.active, conn) + sc, ok := conn.(*subConn) + if !ok { + return } - p.inactive = append(p.inactive, conn) + if _, ok = p.active[sc]; ok { + delete(p.active, sc) + } + p.inactive = append(p.inactive, sc) return } @@ -107,7 +117,7 @@ func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error return do(conn) } -func (p *peer) openDrpcConn(ctx context.Context) (dconn drpc.Conn, err error) { +func (p *peer) openDrpcConn(ctx context.Context) (dconn *subConn, err error) { conn, err := p.Open(ctx) if err != nil { return nil, err @@ -115,8 +125,11 @@ func (p *peer) openDrpcConn(ctx context.Context) (dconn drpc.Conn, err error) { if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil { return nil, err } - dconn = drpcconn.New(conn) - return + tconn := connutil.NewLastUsageConn(conn) + return &subConn{ + Conn: drpcconn.New(tconn), + LastUsageConn: tconn, + }, nil } func (p *peer) acceptLoop() { @@ -159,12 +172,49 @@ func (p *peer) serve(conn net.Conn) (err error) { } func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { + p.gc(objectTTL) if time.Now().Sub(p.LastUsage()) < objectTTL { return false, nil } return true, p.Close() } +func (p *peer) gc(ttl time.Duration) { + p.mu.Lock() + defer p.mu.Unlock() + minLastUsage := time.Now().Add(-ttl) + var hasClosed bool + for i, in := range p.inactive { + select { + case <-in.Closed(): + p.inactive[i] = nil + hasClosed = true + default: + } + if in.LastUsage().Before(minLastUsage) { + _ = in.Close() + p.inactive[i] = nil + hasClosed = true + } + } + if hasClosed { + inactive := p.inactive + p.inactive = p.inactive[:0] + for _, in := range inactive { + if in != nil { + p.inactive = append(p.inactive, in) + } + } + } + for act := range p.active { + select { + case <-act.Closed(): + delete(p.active, act) + default: + } + } +} + func (p *peer) Close() (err error) { log.Debug("peer close", zap.String("peerId", p.id)) return p.MultiConn.Close() diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go index 00a14252..ac1ff28b 100644 --- a/net/peer/peer_test.go +++ b/net/peer/peer_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "io" "net" + _ "net/http/pprof" "testing" "time" ) @@ -20,6 +21,9 @@ func TestPeer_AcquireDrpcConn(t *testing.T) { fx := newFixture(t, "p1") defer fx.finish() in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() defer out.Close() fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) dc, err := fx.AcquireDrpcConn(ctx) @@ -77,6 +81,52 @@ func TestPeer_TryClose(t *testing.T) { require.NoError(t, err) assert.True(t, res) }) + t.Run("gc", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + now := time.Now() + fx.mc.EXPECT().LastUsage().Return(now.Add(time.Millisecond * 100)) + + // make one inactive + in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dc, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + + // make one active but closed + in2, out2 := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out2, defaultProtoChecker) + }() + defer out2.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in2, nil) + dc2, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + _ = dc2.Close() + + // make one inactive and closed + in3, out3 := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out3, defaultProtoChecker) + }() + defer out3.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in3, nil) + dc3, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + fx.ReleaseDrpcConn(dc3) + _ = dc3.Close() + fx.ReleaseDrpcConn(dc) + + time.Sleep(time.Millisecond * 100) + + res, err := fx.TryClose(time.Millisecond * 50) + require.NoError(t, err) + assert.False(t, res) + }) } type acceptedConn struct { From 0b4f08fbefb5833213c4ec22fd4bb38a7264f9eb Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 5 Jun 2023 21:34:23 +0200 Subject: [PATCH 4/4] remove stream multiqueue --- net/streampool/stream.go | 31 +++++++++++++++++-------------- net/streampool/streampool.go | 28 ++++++++-------------------- net/streampool/streampool_test.go | 2 +- 3 files changed, 26 insertions(+), 35 deletions(-) diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 5dff0cb9..1d59af4b 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -3,7 +3,7 @@ package streampool import ( "context" "github.com/anyproto/any-sync/app/logger" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "storj.io/drpc" "sync/atomic" @@ -17,17 +17,12 @@ type stream struct { streamId uint32 closed atomic.Bool l logger.CtxLogger - queue multiqueue.MultiQueue[drpc.Message] + queue *mb.MB[drpc.Message] tags []string } func (sr *stream) write(msg drpc.Message) (err error) { - var queueId string - if qId, ok := msg.(MessageQueueId); ok { - queueId = qId.MessageQueueId() - msg = qId.DrpcMessage() - } - return sr.queue.Add(sr.stream.Context(), queueId, msg) + return sr.queue.Add(sr.stream.Context(), msg) } func (sr *stream) readLoop() error { @@ -50,13 +45,21 @@ func (sr *stream) readLoop() error { } } -func (sr *stream) writeToStream(msg drpc.Message) { - if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { - sr.l.Warn("msg send error", zap.Error(err)) - sr.streamClose() - return +func (sr *stream) writeLoop() { + for { + msg, err := sr.queue.WaitOne(sr.peerCtx) + if err != nil { + if err != mb.ErrClosed { + sr.streamClose() + } + return + } + if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { + sr.l.Warn("msg send error", zap.Error(err)) + sr.streamClose() + return + } } - return } func (sr *stream) streamClose() { diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 59ee6e4c..50a020b3 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" @@ -74,6 +74,9 @@ func (s *streamPool) ReadStream(drpcStream drpc.Stream, tags ...string) error { if err != nil { return err } + go func() { + st.writeLoop() + }() return st.readLoop() } @@ -85,6 +88,9 @@ func (s *streamPool) AddStream(drpcStream drpc.Stream, tags ...string) error { go func() { _ = st.readLoop() }() + go func() { + st.writeLoop() + }() return nil } @@ -122,7 +128,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, l: log.With(zap.String("peerId", peerId), zap.Uint32("streamId", streamId)), tags: tags, } - st.queue = multiqueue.New[drpc.Message](st.writeToStream, s.writeQueueSize) + st.queue = mb.New[drpc.Message](s.writeQueueSize) s.streams[streamId] = st s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId) for _, tag := range tags { @@ -364,21 +370,3 @@ func removeStream(m map[string][]uint32, key string, streamId uint32) { m[key] = streamIds } } - -// WithQueueId wraps the message and adds queueId -func WithQueueId(msg drpc.Message, queueId string) drpc.Message { - return &messageWithQueueId{queueId: queueId, Message: msg} -} - -type messageWithQueueId struct { - drpc.Message - queueId string -} - -func (m messageWithQueueId) MessageQueueId() string { - return m.queueId -} - -func (m messageWithQueueId) DrpcMessage() drpc.Message { - return m.Message -} diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 6a053bd4..d4a05de3 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -47,7 +47,7 @@ func TestStreamPool_AddStream(t *testing.T) { require.NoError(t, fx.AddStream(s2, "space2", "common")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) - require.NoError(t, fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2")) + require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common")) var serverResults []string