any-sync/net/streampool/streampool_test.go
2023-05-15 20:01:52 +02:00

274 lines
7.1 KiB
Go

package streampool
import (
"fmt"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/rpc/rpctest"
"github.com/anytypeio/any-sync/net/streampool/testservice"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
"sort"
"storj.io/drpc"
"sync"
"sync/atomic"
"testing"
"time"
)
// ctx is the package-level base context used by the tests and helpers in this
// file. It starts out as context.Background().
var ctx = context.Background()
// newClientStream dials peerId through the fixture's test pool and opens a
// TestStream client on the resulting peer. The stream is opened with a context
// derived from the shared ctx that additionally carries peerId (see
// peer.CtxWithPeerId). A dial or stream-open error fails the test immediately.
//
// It returns the opened client stream and the dialed peer.
func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
	t.Helper()

	p, err := fx.tp.Dial(ctx, peerId)
	require.NoError(t, err)

	// Keep the derived context local. Assigning it back to the package-level
	// ctx would leak this call's peer id into every later user of ctx and
	// stack peer ids on top of each other across calls.
	peerCtx := peer.CtxWithPeerId(ctx, peerId)

	st, err = testservice.NewDRPCTestClient(p).TestStream(peerCtx)
	require.NoError(t, err)
	return st, p
}
// TestStreamPool_AddStream covers streams that are opened by the client and
// then registered in the pool with AddStream.
func TestStreamPool_AddStream(t *testing.T) {
	// Two streams are registered with tags {space1, common} and
	// {space2, common}. After broadcasting one message per tag the server is
	// expected to have received four messages in total: "common" twice and
	// "space1" / "space2" once each.
	t.Run("broadcast incoming", func(t *testing.T) {
		fx := newFixture(t)
		defer fx.Finish(t)

		firstStream, _ := newClientStream(t, fx, "p1")
		require.NoError(t, fx.AddStream(firstStream, "space1", "common"))
		secondStream, _ := newClientStream(t, fx, "p2")
		require.NoError(t, fx.AddStream(secondStream, "space2", "common"))

		err := fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")
		require.NoError(t, err)
		err = fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2")
		require.NoError(t, err)
		err = fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common")
		require.NoError(t, err)

		const expectedCount = 4
		var received []string
		for n := 0; n < expectedCount; n++ {
			// Each message gets its own one-second budget.
			select {
			case <-time.After(time.Second):
				require.NoError(t, fmt.Errorf("timeout"))
			case incoming := <-fx.tsh.receiveCh:
				received = append(received, incoming.ReqData)
			}
		}

		// Arrival order is not asserted, so compare the sorted payloads.
		sort.Strings(received)
		want := []string{"common", "common", "space1", "space2"}
		assert.Equal(t, want, received)

		assert.NoError(t, firstStream.Close())
		assert.NoError(t, secondStream.Close())
	})

	// A message sent to a peer whose stream was registered with AddStream must
	// arrive at the server.
	t.Run("send incoming", func(t *testing.T) {
		fx := newFixture(t)
		defer fx.Finish(t)

		clientStream, dialed := newClientStream(t, fx, "p1")
		defer clientStream.Close()
		require.NoError(t, fx.AddStream(clientStream, "space1", "common"))

		peerGetter := func(ctx context.Context) (peers []peer.Peer, err error) {
			return []peer.Peer{dialed}, nil
		}
		require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "test"}, peerGetter))

		var received *testservice.StreamMessage
		select {
		case <-time.After(time.Second):
			require.NoError(t, fmt.Errorf("timeout"))
		case received = <-fx.tsh.receiveCh:
		}
		assert.Equal(t, "test", received.ReqData)
	})
}
// TestStreamPool_Send covers Send for peers that have no stream registered in
// the pool yet, so the pool has to open one through the handler's OpenStream.
func TestStreamPool_Send(t *testing.T) {
	// A single Send to a freshly dialed peer must reach the server.
	t.Run("open stream", func(t *testing.T) {
		fx := newFixture(t)
		defer fx.Finish(t)
		p, err := fx.tp.Dial(ctx, "p1")
		require.NoError(t, err)
		require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
			return []peer.Peer{p}, nil
		}))
		var msg *testservice.StreamMessage
		select {
		case msg = <-fx.tsh.receiveCh:
		case <-time.After(time.Second):
			require.NoError(t, fmt.Errorf("timeout"))
		}
		assert.Equal(t, "should open stream", msg.ReqData)
	})

	// Several concurrent Sends to the same peer, with a slowed-down
	// OpenStream, must all be delivered over a single server-side stream.
	t.Run("parallel open stream", func(t *testing.T) {
		fx := newFixture(t)
		defer fx.Finish(t)
		p, err := fx.tp.Dial(ctx, "p1")
		require.NoError(t, err)
		fx.th.streamOpenDelay = time.Second / 3
		var numMsgs = 5
		var wg sync.WaitGroup
		for i := 0; i < numMsgs; i++ {
			wg.Add(1)
			// Call Send inside the closure. With `go f(g())` the argument g()
			// is evaluated in the calling goroutine, which would make the
			// sends sequential.
			go func() {
				defer wg.Done()
				// assert, not require: FailNow must only be called from the
				// goroutine running the test function.
				assert.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
					return []peer.Peer{p}, nil
				}))
			}()
		}
		// Do not let the senders outlive the subtest.
		wg.Wait()
		var msgs []*testservice.StreamMessage
		for i := 0; i < numMsgs; i++ {
			select {
			case msg := <-fx.tsh.receiveCh:
				msgs = append(msgs, msg)
			case <-time.After(time.Second):
				require.NoError(t, fmt.Errorf("timeout"))
			}
		}
		assert.Len(t, msgs, numMsgs)
		// make sure that we have only one stream
		assert.Equal(t, int32(1), fx.tsh.streamsCount.Load())
	})

	// Concurrent sendOne calls against a peer that has already been closed
	// must each return an error.
	t.Run("parallel open stream error", func(t *testing.T) {
		fx := newFixture(t)
		defer fx.Finish(t)
		p, err := fx.tp.Dial(ctx, "p1")
		require.NoError(t, err)
		// The peer is closed on purpose to provoke the failure; the close
		// error itself is irrelevant to this test.
		_ = p.Close()
		fx.th.streamOpenDelay = time.Second / 3
		var numMsgs = 5
		var wg sync.WaitGroup
		for i := 0; i < numMsgs; i++ {
			wg.Add(1)
			go func() {
				defer wg.Done()
				assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"}))
			}()
		}
		wg.Wait()
	})
}
// TestStreamPool_SendById registers a client stream for peer "p1" and checks
// that a message sent with SendById to that peer id arrives at the server.
func TestStreamPool_SendById(t *testing.T) {
	fx := newFixture(t)
	defer fx.Finish(t)

	clientStream, _ := newClientStream(t, fx, "p1")
	defer clientStream.Close()
	require.NoError(t, fx.AddStream(clientStream, "space1", "common"))

	outgoing := &testservice.StreamMessage{ReqData: "test"}
	require.NoError(t, fx.SendById(ctx, outgoing, "p1"))

	var received *testservice.StreamMessage
	select {
	case <-time.After(time.Second):
		require.NoError(t, fmt.Errorf("timeout"))
	case received = <-fx.tsh.receiveCh:
	}
	assert.Equal(t, "test", received.ReqData)
}
// TestStreamPool_Tags checks tag bookkeeping through AddTagsCtx and
// RemoveTagsCtx on two registered streams:
//   - adding "t3" twice in one call to the first stream leaves exactly one
//     entry, stream id 1, under "t3";
//   - removing "t2" from the second stream leaves no stream ids under "t2".
//
// NOTE(review): the test relies on the pool assigning stream ids 1 and 2 in
// registration order — confirm against the pool implementation.
func TestStreamPool_Tags(t *testing.T) {
	fx := newFixture(t)
	defer fx.Finish(t)
	s1, _ := newClientStream(t, fx, "p1")
	defer s1.Close()
	require.NoError(t, fx.AddStream(s1, "t1"))
	s2, _ := newClientStream(t, fx, "p2")
	// Close the second stream; closing s1 a second time here would leave s2
	// open.
	defer s2.Close()
	require.NoError(t, fx.AddStream(s2, "t2"))

	err := fx.AddTagsCtx(streamCtx(ctx, 1, "p1"), "t3", "t3")
	require.NoError(t, err)
	assert.Equal(t, []uint32{1}, fx.StreamPool.(*streamPool).streamIdsByTag["t3"])

	err = fx.RemoveTagsCtx(streamCtx(ctx, 2, "p2"), "t2")
	require.NoError(t, err)
	assert.Len(t, fx.StreamPool.(*streamPool).streamIdsByTag["t2"], 0)
}
// newFixture builds the test harness: a test RPC server with a
// testServerHandler registered on it, a test pool connected to that server,
// and a StreamPool driven by a fresh testHandler. A registration error fails
// the test immediately.
func newFixture(t *testing.T) *fixture {
	server := rpctest.NewTestServer()
	serverHandler := &testServerHandler{
		// Buffered so the server handler can queue received messages while
		// the test is not reading yet.
		receiveCh: make(chan *testservice.StreamMessage, 100),
	}
	require.NoError(t, testservice.DRPCRegisterTest(server, serverHandler))

	testPool := rpctest.NewTestPool().WithServer(server)
	clientHandler := &testHandler{}
	conf := StreamConfig{
		SendQueueSize:    10,
		DialQueueWorkers: 1,
		DialQueueSize:    10,
	}
	streams := New().NewStreamPool(clientHandler, conf)

	return &fixture{
		StreamPool: streams,
		tp:         testPool,
		th:         clientHandler,
		tsh:        serverHandler,
	}
}
// fixture bundles everything a test needs. The StreamPool under test is
// embedded, so its methods can be called directly on the fixture.
type fixture struct {
	StreamPool

	// tsh is the handler registered on the test server; tests read the
	// messages it received from tsh.receiveCh.
	tsh *testServerHandler
	// th is the handler handed to the StreamPool.
	th *testHandler
	// tp is the test pool used to dial peers.
	tp *rpctest.TestPool
}
// Finish shuts the fixture down: first the embedded StreamPool, then the test
// pool. The test fails immediately if either close reports an error.
func (fx *fixture) Finish(t *testing.T) {
	streamPoolErr := fx.Close()
	require.NoError(t, streamPoolErr)

	testPoolErr := fx.tp.Close(ctx)
	require.NoError(t, testPoolErr)
}
// testHandler is the handler passed to the StreamPool in tests.
type testHandler struct {
	// streamOpenDelay, when positive, is slept at the start of OpenStream.
	streamOpenDelay time.Duration

	// mu guards incomingMessages.
	mu sync.Mutex
	// incomingMessages collects every message passed to HandleMessage.
	incomingMessages []drpc.Message
}
// OpenStream opens a TestStream client on p, after sleeping for
// streamOpenDelay when that is positive. It never returns tags.
func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc.Stream, tags []string, err error) {
	if delay := t.streamOpenDelay; delay > 0 {
		time.Sleep(delay)
	}
	client := testservice.NewDRPCTestClient(p)
	stream, err = client.TestStream(ctx)
	return stream, nil, err
}
// HandleMessage appends msg to incomingMessages under the handler's mutex.
// It always returns nil; ctx and peerId are ignored.
func (t *testHandler) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) {
	t.mu.Lock()
	t.incomingMessages = append(t.incomingMessages, msg)
	t.mu.Unlock()
	return nil
}
// DRPCEncoding returns EncodingProto.
func (t *testHandler) DRPCEncoding() drpc.Encoding {
	var enc drpc.Encoding = EncodingProto
	return enc
}
// NewReadMessage returns a fresh, empty testservice.StreamMessage.
func (t *testHandler) NewReadMessage() drpc.Message {
	return &testservice.StreamMessage{}
}
// testServerHandler is the server side of the test service.
type testServerHandler struct {
	// streamsCount is incremented when TestStream starts and decremented when
	// it returns.
	streamsCount atomic.Int32
	// receiveCh receives every message read by TestStream.
	receiveCh chan *testservice.StreamMessage
	// mu is not used by the code visible in this file.
	mu sync.Mutex
}
// TestStream serves one server-side test stream. Every received message is
// pushed to receiveCh and then sent back on the same stream. The loop only
// ends when Recv or Send fails, and that error is returned.
//
// streamsCount is incremented on entry and decremented on return.
func (t *testServerHandler) TestStream(st testservice.DRPCTest_TestStreamStream) error {
	t.streamsCount.Add(1)
	defer t.streamsCount.Add(-1)
	// The loop has no condition and no break, so the function can only leave
	// through the error returns below; no trailing return is needed.
	for {
		msg, err := st.Recv()
		if err != nil {
			return err
		}
		t.receiveCh <- msg
		if err = st.Send(msg); err != nil {
			return err
		}
	}
}