From a898c6fc9ca38ab87d960eafbed32e444caecde2 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Mon, 29 May 2023 16:33:33 +0200 Subject: [PATCH] pool.AddPeer close previous peer --- net/pool/pool.go | 18 +++++++++++++++--- net/pool/pool_test.go | 21 +++++++++++++++++++++ 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/net/pool/pool.go b/net/pool/pool.go index f36318df..37f8328e 100644 --- a/net/pool/pool.go +++ b/net/pool/pool.go @@ -17,7 +17,7 @@ type Pool interface { // GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) // AddPeer adds incoming peer to the pool - AddPeer(p peer.Peer) (err error) + AddPeer(ctx context.Context, p peer.Peer) (err error) } type pool struct { @@ -89,6 +89,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error return nil, lastErr } -func (p *pool) AddPeer(pr peer.Peer) (err error) { - return p.incoming.Add(pr.Id(), pr) +func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) { + if err = p.incoming.Add(pr.Id(), pr); err != nil { + if err == ocache.ErrExists { + // in case when an incoming connection with a peer already exists, we close and remove an existing connection + if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil { + _ = v.Close() + _, _ = p.incoming.Remove(ctx, pr.Id()) + return p.incoming.Add(pr.Id(), pr) + } + } else { + return err + } + } + return } diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index 71bfa1c6..b1cc0087 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -132,6 +132,27 @@ func TestPool_GetOneOf(t *testing.T) { }) } +func TestPool_AddPeer(t *testing.T) { + t.Run("success", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1"))) + }) + t.Run("two peers", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + p1, p2 := newTestPeer("p1"), newTestPeer("p1") + require.NoError(t, fx.AddPeer(ctx, p1)) + require.NoError(t, fx.AddPeer(ctx, p2)) + select { + case <-p1.closed: + default: + assert.Truef(t, false, "peer not closed") + } + }) + +} + func newFixture(t *testing.T) *fixture { fx := &fixture{ Service: New(),