diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go index 24718286..2076f959 100644 --- a/net/dialer/dialer.go +++ b/net/dialer/dialer.go @@ -34,6 +34,7 @@ func New() Dialer { type Dialer interface { Dial(ctx context.Context, peerId string) (peer peer.Peer, err error) UpdateAddrs(addrs map[string][]string) + SetPeerAddrs(peerId string, addrs []string) app.Component } @@ -62,6 +63,15 @@ func (d *dialer) UpdateAddrs(addrs map[string][]string) { d.mu.Unlock() } +func (d *dialer) SetPeerAddrs(peerId string, addrs []string) { + d.mu.Lock() + defer d.mu.Unlock() + if d.peerAddrs == nil { + return + } + d.peerAddrs[peerId] = addrs +} + func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) { d.mu.RLock() defer d.mu.RUnlock() diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index 94fde050..f913333c 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -160,6 +160,10 @@ func (d *dialerMock) UpdateAddrs(addrs map[string][]string) { return } +func (d *dialerMock) SetPeerAddrs(peerId string, addrs []string) { + return +} + func (d *dialerMock) Init(a *app.App) (err error) { return } diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 8319e972..01527d14 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -40,6 +40,8 @@ type StreamPool interface { AddTagsCtx(ctx context.Context, tags ...string) error // RemoveTagsCtx removes tags from stream, stream will be extracted from ctx RemoveTagsCtx(ctx context.Context, tags ...string) error + // Streams gets all streams for specific tags + Streams(tags ...string) (streams []drpc.Stream) // Close closes all streams Close() error } @@ -73,6 +75,17 @@ func (s *streamPool) AddStream(peerId string, drpcStream drpc.Stream, tags ...st }() } +func (s *streamPool) Streams(tags ...string) (streams []drpc.Stream) { + s.mu.Lock() + defer s.mu.Unlock() + for _, tag := range tags { + for _, id := range s.streamIdsByTag[tag] { + streams = append(streams, s.streams[id].stream) + } + } + return +} + func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...string) *stream { s.mu.Lock() defer s.mu.Unlock()