diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 0a2b4419..a846ef51 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -21,6 +21,9 @@ type StreamHandler interface { NewReadMessage() drpc.Message } +// PeerGetter should dial or return cached peers +type PeerGetter func(ctx context.Context) (peers []peer.Peer, err error) + // StreamPool keeps and read streams type StreamPool interface { // AddStream adds new outgoing stream into the pool @@ -28,7 +31,7 @@ type StreamPool interface { // ReadStream adds new incoming stream and synchronously read it ReadStream(peerId string, stream drpc.Stream, tags ...string) (err error) // Send sends a message to given peers. A stream will be opened if it is not cached before. Works async. - Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error) + Send(ctx context.Context, msg drpc.Message, target PeerGetter) (err error) // SendById sends a message to given peerIds. Works only if stream exists SendById(ctx context.Context, msg drpc.Message, peerIds ...string) (err error) // Broadcast sends a message to all peers with given tags. Works async. @@ -95,7 +98,7 @@ func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...st return st } -func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error) { +func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peerGetter PeerGetter) (err error) { var sendOneFunc = func(sp peer.Peer) func() { return func() { if e := s.sendOne(ctx, sp, msg); e != nil { @@ -105,13 +108,17 @@ func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.P } } } - - for _, p := range peers { - if err = s.exec.Add(ctx, sendOneFunc(p)); err != nil { - return + return s.exec.Add(ctx, func() { + peers, dialErr := peerGetter(ctx) + if dialErr != nil { + log.InfoCtx(ctx, "can't get peers", zap.Error(dialErr)) } - } - return + for _, p := range peers { + if err = s.exec.Add(ctx, sendOneFunc(p)); err != nil { + return + } + } + }) } func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ...string) (err error) { diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 671b2aa2..c67377cc 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -66,7 +66,9 @@ func TestStreamPool_AddStream(t *testing.T) { defer s1.Close() fx.AddStream("p1", s1, "space1", "common") - require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "test"}, p1)) + require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "test"}, func(ctx context.Context) (peers []peer.Peer, err error) { + return []peer.Peer{p1}, nil + })) var msg *testservice.StreamMessage select { case msg = <-fx.tsh.receiveCh: @@ -85,7 +87,9 @@ func TestStreamPool_Send(t *testing.T) { p, err := fx.tp.Dial(ctx, "p1") require.NoError(t, err) - require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, p)) + require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { + return []peer.Peer{p}, nil + })) var msg *testservice.StreamMessage select { @@ -107,7 +111,9 @@ func TestStreamPool_Send(t *testing.T) { var numMsgs = 5 for i := 0; i < numMsgs; i++ { - go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, p)) + go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { + return []peer.Peer{p}, nil + })) } var msgs []*testservice.StreamMessage