package streampool import ( "github.com/anytypeio/any-sync/net/peer" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" "storj.io/drpc" "sync" ) // StreamHandler handles incoming messages from streams type StreamHandler interface { // OpenStream opens stream with given peer OpenStream(ctx context.Context, p peer.Peer) (stream drpc.Stream, tags []string, err error) // HandleMessage handles incoming message HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) // NewReadMessage creates new empty message for unmarshalling into it NewReadMessage() drpc.Message } // StreamPool keeps and read streams type StreamPool interface { // AddStream adds new incoming stream into the pool AddStream(peerId string, stream drpc.Stream, tags ...string) // Send sends a message to given peers. A stream will be opened if it is not cached before. Works async. Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error) // Broadcast sends a message to all peers with given tags. Works async. Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) // Close closes all streams Close() error } type streamPool struct { handler StreamHandler streamIdsByPeer map[string][]uint32 streamIdsByTag map[string][]uint32 streams map[uint32]*stream opening map[string]chan struct{} exec *sendPool mu sync.RWMutex lastStreamId uint32 } func (s *streamPool) AddStream(peerId string, drpcStream drpc.Stream, tags ...string) { s.mu.Lock() defer s.mu.Unlock() s.lastStreamId++ streamId := s.lastStreamId st := &stream{ peerId: peerId, stream: drpcStream, pool: s, streamId: streamId, l: log.With(zap.String("peerId", peerId), zap.Uint32("streamId", streamId)), tags: tags, } s.streams[streamId] = st s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId) for _, tag := range tags { s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], streamId) } go st.readLoop() } func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error) { var funcs []func() for _, p := range peers { funcs = append(funcs, func() { if e := s.sendOne(ctx, p, msg); e != nil { log.Info("send peer error", zap.Error(e)) } }) } return s.exec.Add(ctx, funcs...) } func (s *streamPool) sendOne(ctx context.Context, p peer.Peer, msg drpc.Message) (err error) { // get all streams relates to peer streams, err := s.getStreams(ctx, p) if err != nil { return } for _, st := range streams { if err = st.write(msg); err != nil { log.Info("stream write error", zap.Error(err)) // continue with next stream continue } else { // stop sending on success break } } return } func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*stream, err error) { s.mu.Lock() // check cached streams streamIds := s.streamIdsByPeer[p.Id()] for _, streamId := range streamIds { streams = append(streams, s.streams[streamId]) } var openingCh chan struct{} // no cached streams found if len(streams) == 0 { // start opening process openingCh = s.openStream(ctx, p) } s.mu.Unlock() // not empty openingCh means we should wait for the stream opening and try again if openingCh != nil { select { case <-openingCh: return s.getStreams(ctx, p) case <-ctx.Done(): return nil, ctx.Err() } } return streams, nil } func (s *streamPool) openStream(ctx context.Context, p peer.Peer) chan struct{} { if ch, ok := s.opening[p.Id()]; ok { // already have an opening process for this stream - return channel return ch } ch := make(chan struct{}) s.opening[p.Id()] = ch go func() { // start stream opening in separate goroutine to avoid lock whole pool defer func() { s.mu.Lock() defer s.mu.Unlock() close(ch) delete(s.opening, p.Id()) }() // open new stream and add to pool st, tags, err := s.handler.OpenStream(ctx, p) if err != nil { log.Warn("stream open error", zap.Error(err)) return } s.AddStream(p.Id(), st, tags...) }() return ch } func (s *streamPool) Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) { s.mu.Lock() var streams []*stream for _, tag := range tags { for _, streamId := range s.streamIdsByTag[tag] { streams = append(streams, s.streams[streamId]) } } s.mu.Unlock() var funcs []func() for _, st := range streams { funcs = append(funcs, func() { if e := st.write(msg); e != nil { log.Debug("broadcast write error", zap.Error(e)) } }) } return s.exec.Add(ctx, funcs...) } func (s *streamPool) removeStream(streamId uint32) { s.mu.Lock() defer s.mu.Unlock() st := s.streams[streamId] if st == nil { log.Fatal("removeStream: stream does not exist", zap.Uint32("streamId", streamId)) } var removeStream = func(m map[string][]uint32, key string) { streamIds := m[key] idx := slices.Index(streamIds, streamId) if idx == -1 { log.Fatal("removeStream: streamId does not exist", zap.Uint32("streamId", streamId)) } streamIds = slices.Delete(streamIds, idx, idx+1) if len(streamIds) == 0 { delete(m, key) } else { m[key] = streamIds } } removeStream(s.streamIdsByPeer, st.peerId) for _, tag := range st.tags { removeStream(s.streamIdsByTag, tag) } delete(s.streams, streamId) } func (s *streamPool) Close() (err error) { return s.exec.Close() }