any-sync/net/streampool/streampool.go
2023-01-20 14:55:45 +03:00

293 lines
7.5 KiB
Go

package streampool
import (
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/pool"
"github.com/cheggaaa/mb/v3"
"go.uber.org/zap"
"golang.org/x/exp/slices"
"golang.org/x/net/context"
"storj.io/drpc"
"sync"
)
// StreamHandler is the set of callbacks the pool uses to open outgoing streams
// and to process messages read from streams.
type StreamHandler interface {
// OpenStream opens a stream with the given peer and returns it together with
// the tags the pool registers the stream under.
OpenStream(ctx context.Context, p peer.Peer) (stream drpc.Stream, tags []string, err error)
// HandleMessage handles an incoming message received from the peer with the given id.
HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error)
// NewReadMessage creates a new empty message to unmarshal incoming data into.
NewReadMessage() drpc.Message
}
// StreamPool keeps streams and reads them.
type StreamPool interface {
// AddStream adds a new outgoing stream to the pool.
AddStream(peerId string, stream drpc.Stream, tags ...string)
// ReadStream adds a new incoming stream and reads it synchronously.
ReadStream(peerId string, stream drpc.Stream, tags ...string) (err error)
// Send sends a message to the given peers. A stream will be opened if none is cached yet. Works async.
Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error)
// SendById sends a message to the given peerIds. Works only if a stream already exists.
SendById(ctx context.Context, msg drpc.Message, peerIds ...string) (err error)
// Broadcast sends a message to all peers with the given tags. Works async.
Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error)
// Close closes all streams.
// NOTE(review): confirm that the implementation actually closes the streams;
// streamPool.Close does not touch them directly.
Close() error
}
// streamPool indexes streams by id, peer and tag and dispatches outgoing and
// incoming messages for them.
type streamPool struct {
// handler opens outgoing streams and processes incoming messages
handler StreamHandler
// streamIdsByPeer maps a peer id to the ids of its streams
streamIdsByPeer map[string][]uint32
// streamIdsByTag maps a tag to the ids of the streams registered with it
streamIdsByTag map[string][]uint32
// streams holds every registered stream by its id
streams map[uint32]*stream
// opening holds the in-progress stream opening per peer id
opening map[string]*openingProcess
// exec receives the send functions built by Send, SendById and Broadcast
exec *sendPool
// handleQueue carries incoming messages to the handleMessageLoop goroutines
handleQueue *mb.MB[handleMessage]
// mu is held while the maps above and lastStreamId are read or modified
mu sync.RWMutex
// lastStreamId is the most recently assigned stream id
lastStreamId uint32
}
// openingProcess tracks one in-flight attempt to open a stream to a peer.
type openingProcess struct {
// ch is closed when the attempt has finished, successfully or not
ch chan struct{}
// err is set before ch is closed if opening the stream failed
err error
}
// handleMessage is one incoming message queued for the handler, together with
// the context and peer id it was received with.
type handleMessage struct {
// ctx is the context passed on to StreamHandler.HandleMessage
ctx context.Context
// msg is the received message
msg drpc.Message
// peerId identifies the peer the message came from
peerId string
}
// init starts the fixed set of goroutines that consume the handle queue.
func (s *streamPool) init() {
	// TODO: to config
	const handlers = 10
	for started := 0; started != handlers; started++ {
		go s.handleMessageLoop()
	}
}
// ReadStream registers the incoming stream in the pool and then runs its read
// loop on the calling goroutine, returning whatever the read loop returns.
func (s *streamPool) ReadStream(peerId string, drpcStream drpc.Stream, tags ...string) error {
	return s.addStream(peerId, drpcStream, tags...).readLoop()
}
// AddStream registers the outgoing stream in the pool and starts its read loop
// in a separate goroutine; the read loop's result is discarded.
func (s *streamPool) AddStream(peerId string, drpcStream drpc.Stream, tags ...string) {
	registered := s.addStream(peerId, drpcStream, tags...)
	go func(st *stream) {
		_ = st.readLoop()
	}(registered)
}
// addStream assigns the next stream id, wraps drpcStream into a pool stream and
// indexes it by id, by peer and by each of its tags. The pool lock is held for
// the whole operation.
func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...string) *stream {
	s.mu.Lock()
	defer s.mu.Unlock()

	s.lastStreamId++
	id := s.lastStreamId

	// every stream gets its own logger carrying the peer and stream ids
	logger := log.With(zap.String("peerId", peerId), zap.Uint32("streamId", id))
	created := &stream{
		streamId: id,
		peerId:   peerId,
		tags:     tags,
		stream:   drpcStream,
		pool:     s,
		l:        logger,
	}

	s.streams[id] = created
	for _, tag := range tags {
		s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], id)
	}
	s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], id)
	return created
}
// Send queues one send task per peer on the pool's executor and returns the
// executor's error, if any. Each task calls sendOne for its peer; a failure for
// an individual peer is only logged and is not reported to the caller.
func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.Peer) (err error) {
	funcs := make([]func(), 0, len(peers))
	for _, p := range peers {
		// Rebind per iteration: before Go 1.22 the range variable is shared by
		// all iterations, so without this every queued closure would see the
		// last peer once it runs after the loop.
		p := p
		funcs = append(funcs, func() {
			if e := s.sendOne(ctx, p, msg); e != nil {
				log.Info("send peer error", zap.Error(e), zap.String("peerId", p.Id()))
			}
		})
	}
	return s.exec.Add(ctx, funcs...)
}
// SendById queues a write of msg to every stream currently registered for the
// given peer ids. No stream is opened: if none of the peers has a registered
// stream, pool.ErrUnableToConnect is returned. Otherwise the executor's error
// is returned; individual write failures are only logged.
func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ...string) (err error) {
	// the maps are only read here, so a read lock is sufficient
	s.mu.RLock()
	var streams []*stream
	for _, peerId := range peerIds {
		for _, streamId := range s.streamIdsByPeer[peerId] {
			streams = append(streams, s.streams[streamId])
		}
	}
	s.mu.RUnlock()

	if len(streams) == 0 {
		return pool.ErrUnableToConnect
	}

	funcs := make([]func(), 0, len(streams))
	for _, st := range streams {
		// Rebind per iteration: before Go 1.22 the range variable is shared by
		// all iterations, so without this every queued closure would write to
		// the last stream only.
		st := st
		funcs = append(funcs, func() {
			if e := st.write(msg); e != nil {
				st.l.Debug("sendById write error", zap.Error(e))
			}
		})
	}
	return s.exec.Add(ctx, funcs...)
}
// sendOne writes msg to the first stream of the peer that accepts it. Streams
// are obtained through getStreams; each failed write is logged and the next
// stream is tried. The result is nil as soon as one write succeeds, otherwise
// the error of the last attempted write (or the getStreams error).
func (s *streamPool) sendOne(ctx context.Context, p peer.Peer, msg drpc.Message) (err error) {
	candidates, err := s.getStreams(ctx, p)
	if err != nil {
		return err
	}
	for _, candidate := range candidates {
		err = candidate.write(msg)
		if err == nil {
			// delivered - no need to try the remaining streams
			return nil
		}
		candidate.l.Info("sendOne write error", zap.Error(err))
	}
	return err
}
// getStreams returns the streams registered for the peer. When there are none
// it starts (or joins) an opening process for the peer, waits until that
// process finishes or ctx is done, and then looks the streams up again.
// It returns the opening error or ctx.Err() when waiting ends unsuccessfully.
func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*stream, err error) {
	for {
		var pending *openingProcess

		s.mu.Lock()
		for _, id := range s.streamIdsByPeer[p.Id()] {
			streams = append(streams, s.streams[id])
		}
		if len(streams) == 0 {
			// nothing cached - openStream must be called with the lock held
			pending = s.openStream(ctx, p)
		}
		s.mu.Unlock()

		if pending == nil {
			return streams, nil
		}

		select {
		case <-pending.ch:
			if pending.err != nil {
				return nil, pending.err
			}
			// opening finished without error - look the streams up again
		case <-ctx.Done():
			return nil, ctx.Err()
		}
	}
}
// openStream returns the opening process for the peer, creating and starting
// one if none is registered yet. It reads and writes s.opening without
// locking; the visible caller (getStreams) holds s.mu while calling it.
// The actual opening runs in its own goroutine so the pool lock is not held
// while the handler opens the stream.
func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProcess {
	if existing, found := s.opening[p.Id()]; found {
		// an attempt for this peer is already running - share it
		return existing
	}

	proc := &openingProcess{ch: make(chan struct{})}
	s.opening[p.Id()] = proc

	// finish wakes up all waiters and unregisters the attempt
	finish := func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		close(proc.ch)
		delete(s.opening, p.Id())
	}

	go func() {
		defer finish()
		drpcStream, tags, openErr := s.handler.OpenStream(ctx, p)
		if openErr != nil {
			// stored before finish closes proc.ch, so waiters can read it
			proc.err = openErr
			return
		}
		s.AddStream(p.Id(), drpcStream, tags...)
	}()
	return proc
}
// Broadcast queues a write of msg to every stream registered under any of the
// given tags. It returns nil when no stream matches, otherwise the executor's
// error; individual write failures are only logged.
// NOTE(review): a stream registered under several of the given tags is
// collected once per tag and therefore written to more than once - confirm
// whether that is intended.
func (s *streamPool) Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) {
	// the maps are only read here, so a read lock is sufficient
	s.mu.RLock()
	var streams []*stream
	for _, tag := range tags {
		for _, streamId := range s.streamIdsByTag[tag] {
			streams = append(streams, s.streams[streamId])
		}
	}
	s.mu.RUnlock()

	if len(streams) == 0 {
		return nil
	}

	funcs := make([]func(), 0, len(streams))
	for _, st := range streams {
		// Rebind per iteration: before Go 1.22 the range variable is shared by
		// all iterations, so without this every queued closure would write to
		// the last stream only.
		st := st
		funcs = append(funcs, func() {
			if e := st.write(msg); e != nil {
				// use the stream's own logger, as SendById does, so the entry
				// carries the peer and stream ids
				st.l.Debug("broadcast write error", zap.Error(e))
			}
		})
	}
	return s.exec.Add(ctx, funcs...)
}
// removeStream drops the stream with the given id from the pool: from the
// per-peer index, from the index of each of its tags and from the stream map.
// An unknown stream id, or an id missing from one of the indexes, is reported
// with log.Fatal.
func (s *streamPool) removeStream(streamId uint32) {
	s.mu.Lock()
	defer s.mu.Unlock()

	target := s.streams[streamId]
	if target == nil {
		log.Fatal("removeStream: stream does not exist", zap.Uint32("streamId", streamId))
	}

	// unlink removes streamId from index[key] and drops the key once its list is empty
	unlink := func(index map[string][]uint32, key string) {
		ids := index[key]
		pos := slices.Index(ids, streamId)
		if pos == -1 {
			log.Fatal("removeStream: streamId does not exist", zap.Uint32("streamId", streamId))
		}
		if ids = slices.Delete(ids, pos, pos+1); len(ids) != 0 {
			index[key] = ids
			return
		}
		delete(index, key)
	}

	unlink(s.streamIdsByPeer, target.peerId)
	for _, tag := range target.tags {
		unlink(s.streamIdsByTag, tag)
	}
	delete(s.streams, streamId)
	target.l.Debug("stream removed", zap.Strings("tags", target.tags))
}
// HandleMessage puts the incoming message, with its context and peer id, on the
// handle queue and returns the queue's error, if any. The message itself is
// processed later by one of the handleMessageLoop goroutines.
func (s *streamPool) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) {
	queued := handleMessage{ctx: ctx, peerId: peerId, msg: msg}
	err = s.handleQueue.Add(ctx, queued)
	return err
}
// handleMessageLoop takes messages from the handle queue one at a time and
// passes each to the handler in a new goroutine; handler errors are logged.
// The loop ends when waiting on the queue returns an error.
func (s *streamPool) handleMessageLoop() {
	for {
		next, waitErr := s.handleQueue.WaitOne(context.Background())
		if waitErr != nil {
			return
		}
		go func(hm handleMessage) {
			handleErr := s.handler.HandleMessage(hm.ctx, hm.peerId, hm.msg)
			if handleErr != nil {
				log.Warn("handle message error", zap.Error(handleErr))
			}
		}(next)
	}
}
// Close shuts down the send executor and the handle queue. Closing the queue
// is what lets the handleMessageLoop goroutines started by init return: they
// wait on the queue with a background context and only stop when that wait
// reports an error. The executor's error takes precedence; the queue's error
// is returned only when the executor closed cleanly.
// NOTE(review): streams registered in the pool are not closed here - confirm
// whether they are closed elsewhere.
func (s *streamPool) Close() (err error) {
	err = s.exec.Close()
	if queueErr := s.handleQueue.Close(); err == nil {
		err = queueErr
	}
	return err
}