Merge remote-tracking branch 'origin/yamux' into new-sync-protocol
This commit is contained in:
commit
4c45ad3e67
@ -6,6 +6,9 @@ import (
|
||||
)
|
||||
|
||||
func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option {
|
||||
if reg == nil {
|
||||
return nil
|
||||
}
|
||||
if subsystem == "" {
|
||||
subsystem = "cache"
|
||||
}
|
||||
@ -13,9 +16,7 @@ func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Optio
|
||||
subSplit := strings.Split(subsystem, ".")
|
||||
namespace = strings.Join(nameSplit, "_")
|
||||
subsystem = strings.Join(subSplit, "_")
|
||||
if reg == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return func(cache *oCache) {
|
||||
cache.metrics = &metrics{
|
||||
hit: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
|
||||
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/app/ocache"
|
||||
"github.com/anyproto/any-sync/net/connutil"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
@ -25,7 +26,7 @@ type connCtrl interface {
|
||||
func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
|
||||
ctx := mc.Context()
|
||||
pr := &peer{
|
||||
active: map[drpc.Conn]struct{}{},
|
||||
active: map[*subConn]struct{}{},
|
||||
MultiConn: mc,
|
||||
ctrl: ctrl,
|
||||
}
|
||||
@ -38,6 +39,7 @@ func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
|
||||
|
||||
type Peer interface {
|
||||
Id() string
|
||||
Context() context.Context
|
||||
|
||||
AcquireDrpcConn(ctx context.Context) (drpc.Conn, error)
|
||||
ReleaseDrpcConn(conn drpc.Conn)
|
||||
@ -50,14 +52,19 @@ type Peer interface {
|
||||
ocache.Object
|
||||
}
|
||||
|
||||
type subConn struct {
|
||||
drpc.Conn
|
||||
*connutil.LastUsageConn
|
||||
}
|
||||
|
||||
type peer struct {
|
||||
id string
|
||||
|
||||
ctrl connCtrl
|
||||
|
||||
// drpc conn pool
|
||||
inactive []drpc.Conn
|
||||
active map[drpc.Conn]struct{}
|
||||
inactive []*subConn
|
||||
active map[*subConn]struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
@ -70,29 +77,34 @@ func (p *peer) Id() string {
|
||||
|
||||
func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if len(p.inactive) == 0 {
|
||||
conn, err := p.Open(ctx)
|
||||
p.mu.Unlock()
|
||||
dconn, err := p.openDrpcConn(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dconn := drpcconn.New(conn)
|
||||
p.mu.Lock()
|
||||
p.inactive = append(p.inactive, dconn)
|
||||
}
|
||||
idx := len(p.inactive) - 1
|
||||
res := p.inactive[idx]
|
||||
p.inactive = p.inactive[:idx]
|
||||
p.active[res] = struct{}{}
|
||||
p.mu.Unlock()
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if _, ok := p.active[conn]; ok {
|
||||
delete(p.active, conn)
|
||||
sc, ok := conn.(*subConn)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
p.inactive = append(p.inactive, conn)
|
||||
if _, ok = p.active[sc]; ok {
|
||||
delete(p.active, sc)
|
||||
}
|
||||
p.inactive = append(p.inactive, sc)
|
||||
return
|
||||
}
|
||||
|
||||
@ -105,6 +117,21 @@ func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error
|
||||
return do(conn)
|
||||
}
|
||||
|
||||
func (p *peer) openDrpcConn(ctx context.Context) (dconn *subConn, err error) {
|
||||
conn, err := p.Open(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tconn := connutil.NewLastUsageConn(conn)
|
||||
return &subConn{
|
||||
Conn: drpcconn.New(tconn),
|
||||
LastUsageConn: tconn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *peer) acceptLoop() {
|
||||
var exitErr error
|
||||
defer func() {
|
||||
@ -145,12 +172,49 @@ func (p *peer) serve(conn net.Conn) (err error) {
|
||||
}
|
||||
|
||||
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
p.gc(objectTTL)
|
||||
if time.Now().Sub(p.LastUsage()) < objectTTL {
|
||||
return false, nil
|
||||
}
|
||||
return true, p.Close()
|
||||
}
|
||||
|
||||
func (p *peer) gc(ttl time.Duration) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
minLastUsage := time.Now().Add(-ttl)
|
||||
var hasClosed bool
|
||||
for i, in := range p.inactive {
|
||||
select {
|
||||
case <-in.Closed():
|
||||
p.inactive[i] = nil
|
||||
hasClosed = true
|
||||
default:
|
||||
}
|
||||
if in.LastUsage().Before(minLastUsage) {
|
||||
_ = in.Close()
|
||||
p.inactive[i] = nil
|
||||
hasClosed = true
|
||||
}
|
||||
}
|
||||
if hasClosed {
|
||||
inactive := p.inactive
|
||||
p.inactive = p.inactive[:0]
|
||||
for _, in := range inactive {
|
||||
if in != nil {
|
||||
p.inactive = append(p.inactive, in)
|
||||
}
|
||||
}
|
||||
}
|
||||
for act := range p.active {
|
||||
select {
|
||||
case <-act.Closed():
|
||||
delete(p.active, act)
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *peer) Close() (err error) {
|
||||
log.Debug("peer close", zap.String("peerId", p.id))
|
||||
return p.MultiConn.Close()
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
_ "net/http/pprof"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
@ -20,6 +21,9 @@ func TestPeer_AcquireDrpcConn(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
in, out := net.Pipe()
|
||||
go func() {
|
||||
handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker)
|
||||
}()
|
||||
defer out.Close()
|
||||
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
|
||||
dc, err := fx.AcquireDrpcConn(ctx)
|
||||
@ -77,6 +81,52 @@ func TestPeer_TryClose(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res)
|
||||
})
|
||||
t.Run("gc", func(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
now := time.Now()
|
||||
fx.mc.EXPECT().LastUsage().Return(now.Add(time.Millisecond * 100))
|
||||
|
||||
// make one inactive
|
||||
in, out := net.Pipe()
|
||||
go func() {
|
||||
handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker)
|
||||
}()
|
||||
defer out.Close()
|
||||
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
|
||||
dc, err := fx.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
// make one active but closed
|
||||
in2, out2 := net.Pipe()
|
||||
go func() {
|
||||
handshake.IncomingProtoHandshake(ctx, out2, defaultProtoChecker)
|
||||
}()
|
||||
defer out2.Close()
|
||||
fx.mc.EXPECT().Open(gomock.Any()).Return(in2, nil)
|
||||
dc2, err := fx.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
_ = dc2.Close()
|
||||
|
||||
// make one inactive and closed
|
||||
in3, out3 := net.Pipe()
|
||||
go func() {
|
||||
handshake.IncomingProtoHandshake(ctx, out3, defaultProtoChecker)
|
||||
}()
|
||||
defer out3.Close()
|
||||
fx.mc.EXPECT().Open(gomock.Any()).Return(in3, nil)
|
||||
dc3, err := fx.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
fx.ReleaseDrpcConn(dc3)
|
||||
_ = dc3.Close()
|
||||
fx.ReleaseDrpcConn(dc)
|
||||
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
res, err := fx.TryClose(time.Millisecond * 50)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res)
|
||||
})
|
||||
}
|
||||
|
||||
type acceptedConn struct {
|
||||
|
||||
30
net/rpc/rpctest/peer.go
Normal file
30
net/rpc/rpctest/peer.go
Normal file
@ -0,0 +1,30 @@
|
||||
package rpctest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/connutil"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
yamux2 "github.com/anyproto/any-sync/net/transport/yamux"
|
||||
"github.com/hashicorp/yamux"
|
||||
"net"
|
||||
)
|
||||
|
||||
func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) {
|
||||
sc, cc := net.Pipe()
|
||||
var servConn = make(chan transport.MultiConn, 1)
|
||||
go func() {
|
||||
sess, err := yamux.Server(sc, yamux.DefaultConfig())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess)
|
||||
}()
|
||||
sess, err := yamux.Client(cc, yamux.DefaultConfig())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess)
|
||||
serv = <-servConn
|
||||
return
|
||||
}
|
||||
@ -1,122 +0,0 @@
|
||||
package rpctest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/pool"
|
||||
"math/rand"
|
||||
"storj.io/drpc"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrCantConnect = errors.New("can't connect to test server")
|
||||
|
||||
func NewTestPool() *TestPool {
|
||||
return &TestPool{
|
||||
peers: map[string]peer.Peer{},
|
||||
}
|
||||
}
|
||||
|
||||
type TestPool struct {
|
||||
ts *TesServer
|
||||
peers map[string]peer.Peer
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (t *TestPool) WithServer(ts *TesServer) *TestPool {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.ts = ts
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if p, ok := t.peers[id]; ok {
|
||||
return p, nil
|
||||
}
|
||||
if t.ts == nil {
|
||||
return nil, ErrCantConnect
|
||||
}
|
||||
return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
|
||||
}
|
||||
|
||||
func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
|
||||
return t.Get(ctx, id)
|
||||
}
|
||||
|
||||
func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
for _, peerId := range peerIds {
|
||||
if p, ok := t.peers[peerId]; ok {
|
||||
return p, nil
|
||||
}
|
||||
}
|
||||
if t.ts == nil {
|
||||
return nil, ErrCantConnect
|
||||
}
|
||||
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
|
||||
}
|
||||
|
||||
func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
if t.ts == nil {
|
||||
return nil, ErrCantConnect
|
||||
}
|
||||
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
|
||||
}
|
||||
|
||||
func (t *TestPool) NewPool(name string) pool.Pool {
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *TestPool) AddPeer(p peer.Peer) {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
t.peers[p.Id()] = p
|
||||
}
|
||||
|
||||
func (t *TestPool) Init(a *app.App) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TestPool) Name() (name string) {
|
||||
return pool.CName
|
||||
}
|
||||
|
||||
func (t *TestPool) Run(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *TestPool) Close(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testPeer struct {
|
||||
id string
|
||||
drpc.Conn
|
||||
}
|
||||
|
||||
func (t testPeer) Addr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
return true, t.Close()
|
||||
}
|
||||
|
||||
func (t testPeer) Id() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t testPeer) LastUsage() time.Time {
|
||||
return time.Now()
|
||||
}
|
||||
|
||||
func (t testPeer) UpdateLastUsage() {}
|
||||
@ -5,43 +5,39 @@ import (
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/net/rpc/server"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
)
|
||||
|
||||
func NewTestServer() *TesServer {
|
||||
ts := &TesServer{
|
||||
func NewTestServer() *TestServer {
|
||||
ts := &TestServer{
|
||||
Mux: drpcmux.New(),
|
||||
}
|
||||
ts.Server = drpcserver.New(ts.Mux)
|
||||
return ts
|
||||
}
|
||||
|
||||
type TesServer struct {
|
||||
type TestServer struct {
|
||||
*drpcmux.Mux
|
||||
*drpcserver.Server
|
||||
}
|
||||
|
||||
func (ts *TesServer) Init(a *app.App) (err error) {
|
||||
func (ts *TestServer) Init(a *app.App) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TesServer) Name() (name string) {
|
||||
func (ts *TestServer) Name() (name string) {
|
||||
return server.CName
|
||||
}
|
||||
|
||||
func (ts *TesServer) Run(ctx context.Context) (err error) {
|
||||
func (ts *TestServer) Run(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TesServer) Close(ctx context.Context) (err error) {
|
||||
func (ts *TestServer) Close(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ts *TesServer) Dial(ctx context.Context) drpc.Conn {
|
||||
sc, cc := net.Pipe()
|
||||
go ts.Server.ServeOne(ctx, sc)
|
||||
return drpcconn.New(cc)
|
||||
func (s *TestServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
return s.Server.ServeOne(ctx, conn)
|
||||
}
|
||||
|
||||
@ -3,7 +3,7 @@ package streampool
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/util/multiqueue"
|
||||
"github.com/cheggaaa/mb/v3"
|
||||
"go.uber.org/zap"
|
||||
"storj.io/drpc"
|
||||
"sync/atomic"
|
||||
@ -17,17 +17,12 @@ type stream struct {
|
||||
streamId uint32
|
||||
closed atomic.Bool
|
||||
l logger.CtxLogger
|
||||
queue multiqueue.MultiQueue[drpc.Message]
|
||||
queue *mb.MB[drpc.Message]
|
||||
tags []string
|
||||
}
|
||||
|
||||
func (sr *stream) write(msg drpc.Message) (err error) {
|
||||
var queueId string
|
||||
if qId, ok := msg.(MessageQueueId); ok {
|
||||
queueId = qId.MessageQueueId()
|
||||
msg = qId.DrpcMessage()
|
||||
}
|
||||
return sr.queue.Add(sr.stream.Context(), queueId, msg)
|
||||
return sr.queue.Add(sr.stream.Context(), msg)
|
||||
}
|
||||
|
||||
func (sr *stream) readLoop() error {
|
||||
@ -50,13 +45,21 @@ func (sr *stream) readLoop() error {
|
||||
}
|
||||
}
|
||||
|
||||
func (sr *stream) writeToStream(msg drpc.Message) {
|
||||
if err := sr.stream.MsgSend(msg, EncodingProto); err != nil {
|
||||
sr.l.Warn("msg send error", zap.Error(err))
|
||||
sr.streamClose()
|
||||
return
|
||||
func (sr *stream) writeLoop() {
|
||||
for {
|
||||
msg, err := sr.queue.WaitOne(sr.peerCtx)
|
||||
if err != nil {
|
||||
if err != mb.ErrClosed {
|
||||
sr.streamClose()
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := sr.stream.MsgSend(msg, EncodingProto); err != nil {
|
||||
sr.l.Warn("msg send error", zap.Error(err))
|
||||
sr.streamClose()
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (sr *stream) streamClose() {
|
||||
|
||||
@ -4,7 +4,7 @@ import (
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/util/multiqueue"
|
||||
"github.com/cheggaaa/mb/v3"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/net/context"
|
||||
@ -74,6 +74,9 @@ func (s *streamPool) ReadStream(drpcStream drpc.Stream, tags ...string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
go func() {
|
||||
st.writeLoop()
|
||||
}()
|
||||
return st.readLoop()
|
||||
}
|
||||
|
||||
@ -85,6 +88,9 @@ func (s *streamPool) AddStream(drpcStream drpc.Stream, tags ...string) error {
|
||||
go func() {
|
||||
_ = st.readLoop()
|
||||
}()
|
||||
go func() {
|
||||
st.writeLoop()
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -122,7 +128,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream,
|
||||
l: log.With(zap.String("peerId", peerId), zap.Uint32("streamId", streamId)),
|
||||
tags: tags,
|
||||
}
|
||||
st.queue = multiqueue.New[drpc.Message](st.writeToStream, s.writeQueueSize)
|
||||
st.queue = mb.New[drpc.Message](s.writeQueueSize)
|
||||
s.streams[streamId] = st
|
||||
s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId)
|
||||
for _, tag := range tags {
|
||||
@ -364,21 +370,3 @@ func removeStream(m map[string][]uint32, key string, streamId uint32) {
|
||||
m[key] = streamIds
|
||||
}
|
||||
}
|
||||
|
||||
// WithQueueId wraps the message and adds queueId
|
||||
func WithQueueId(msg drpc.Message, queueId string) drpc.Message {
|
||||
return &messageWithQueueId{queueId: queueId, Message: msg}
|
||||
}
|
||||
|
||||
type messageWithQueueId struct {
|
||||
drpc.Message
|
||||
queueId string
|
||||
}
|
||||
|
||||
func (m messageWithQueueId) MessageQueueId() string {
|
||||
return m.queueId
|
||||
}
|
||||
|
||||
func (m messageWithQueueId) DrpcMessage() drpc.Message {
|
||||
return m.Message
|
||||
}
|
||||
|
||||
@ -18,17 +18,25 @@ import (
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
func makePeerPair(t *testing.T, fx *fixture, peerId string) (pS, pC peer.Peer) {
|
||||
mcS, mcC := rpctest.MultiConnPair(peerId+"server", peerId)
|
||||
pS, err := peer.NewPeer(mcS, fx.ts)
|
||||
require.NoError(t, err)
|
||||
pC, err = peer.NewPeer(mcC, fx.ts)
|
||||
require.NoError(t, err)
|
||||
return
|
||||
}
|
||||
|
||||
func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) {
|
||||
p, err := fx.tp.Dial(ctx, peerId)
|
||||
_, pC := makePeerPair(t, fx, peerId)
|
||||
drpcConn, err := pC.AcquireDrpcConn(ctx)
|
||||
require.NoError(t, err)
|
||||
ctx = peer.CtxWithPeerId(ctx, peerId)
|
||||
s, err := testservice.NewDRPCTestClient(p).TestStream(ctx)
|
||||
st, err = testservice.NewDRPCTestClient(drpcConn).TestStream(pC.Context())
|
||||
require.NoError(t, err)
|
||||
return s, p
|
||||
return st, pC
|
||||
}
|
||||
|
||||
func TestStreamPool_AddStream(t *testing.T) {
|
||||
|
||||
t.Run("broadcast incoming", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
@ -39,7 +47,7 @@ func TestStreamPool_AddStream(t *testing.T) {
|
||||
require.NoError(t, fx.AddStream(s2, "space2", "common"))
|
||||
|
||||
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1"))
|
||||
require.NoError(t, fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2"))
|
||||
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2"))
|
||||
require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common"))
|
||||
|
||||
var serverResults []string
|
||||
@ -85,11 +93,10 @@ func TestStreamPool_Send(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
p, err := fx.tp.Dial(ctx, "p1")
|
||||
require.NoError(t, err)
|
||||
pS, _ := makePeerPair(t, fx, "p1")
|
||||
|
||||
require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
|
||||
return []peer.Peer{p}, nil
|
||||
return []peer.Peer{pS}, nil
|
||||
}))
|
||||
|
||||
var msg *testservice.StreamMessage
|
||||
@ -100,12 +107,12 @@ func TestStreamPool_Send(t *testing.T) {
|
||||
}
|
||||
assert.Equal(t, "should open stream", msg.ReqData)
|
||||
})
|
||||
|
||||
t.Run("parallel open stream", func(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
p, err := fx.tp.Dial(ctx, "p1")
|
||||
require.NoError(t, err)
|
||||
pS, _ := makePeerPair(t, fx, "p1")
|
||||
|
||||
fx.th.streamOpenDelay = time.Second / 3
|
||||
|
||||
@ -113,7 +120,7 @@ func TestStreamPool_Send(t *testing.T) {
|
||||
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) {
|
||||
return []peer.Peer{p}, nil
|
||||
return []peer.Peer{pS}, nil
|
||||
}))
|
||||
}
|
||||
|
||||
@ -134,9 +141,8 @@ func TestStreamPool_Send(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
defer fx.Finish(t)
|
||||
|
||||
p, err := fx.tp.Dial(ctx, "p1")
|
||||
require.NoError(t, err)
|
||||
_ = p.Close()
|
||||
pS, _ := makePeerPair(t, fx, "p1")
|
||||
_ = pS.Close()
|
||||
|
||||
fx.th.streamOpenDelay = time.Second / 3
|
||||
|
||||
@ -147,11 +153,12 @@ func TestStreamPool_Send(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"}))
|
||||
assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, pS, &testservice.StreamMessage{ReqData: "should open stream"}))
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestStreamPool_SendById(t *testing.T) {
|
||||
@ -196,10 +203,9 @@ func TestStreamPool_Tags(t *testing.T) {
|
||||
|
||||
func newFixture(t *testing.T) *fixture {
|
||||
fx := &fixture{}
|
||||
ts := rpctest.NewTestServer()
|
||||
fx.ts = rpctest.NewTestServer()
|
||||
fx.tsh = &testServerHandler{receiveCh: make(chan *testservice.StreamMessage, 100)}
|
||||
require.NoError(t, testservice.DRPCRegisterTest(ts, fx.tsh))
|
||||
fx.tp = rpctest.NewTestPool().WithServer(ts)
|
||||
require.NoError(t, testservice.DRPCRegisterTest(fx.ts, fx.tsh))
|
||||
fx.th = &testHandler{}
|
||||
fx.StreamPool = New().NewStreamPool(fx.th, StreamConfig{
|
||||
SendQueueSize: 10,
|
||||
@ -211,14 +217,13 @@ func newFixture(t *testing.T) *fixture {
|
||||
|
||||
type fixture struct {
|
||||
StreamPool
|
||||
tp *rpctest.TestPool
|
||||
th *testHandler
|
||||
tsh *testServerHandler
|
||||
ts *rpctest.TestServer
|
||||
}
|
||||
|
||||
func (fx *fixture) Finish(t *testing.T) {
|
||||
require.NoError(t, fx.Close())
|
||||
require.NoError(t, fx.tp.Close(ctx))
|
||||
}
|
||||
|
||||
type testHandler struct {
|
||||
@ -231,7 +236,11 @@ func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc.
|
||||
if t.streamOpenDelay > 0 {
|
||||
time.Sleep(t.streamOpenDelay)
|
||||
}
|
||||
stream, err = testservice.NewDRPCTestClient(p).TestStream(ctx)
|
||||
conn, err := p.AcquireDrpcConn(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
stream, err = testservice.NewDRPCTestClient(conn).TestStream(p.Context())
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -9,6 +9,15 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewMultiConn(cctx context.Context, luConn *connutil.LastUsageConn, addr string, sess *yamux.Session) transport.MultiConn {
|
||||
return &yamuxConn{
|
||||
ctx: cctx,
|
||||
luConn: luConn,
|
||||
addr: addr,
|
||||
Session: sess,
|
||||
}
|
||||
}
|
||||
|
||||
type yamuxConn struct {
|
||||
ctx context.Context
|
||||
luConn *connutil.LastUsageConn
|
||||
|
||||
@ -96,12 +96,7 @@ func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.Mu
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
mc = &yamuxConn{
|
||||
ctx: cctx,
|
||||
luConn: luc,
|
||||
Session: sess,
|
||||
addr: addr,
|
||||
}
|
||||
mc = NewMultiConn(cctx, luc, addr, sess)
|
||||
return
|
||||
}
|
||||
|
||||
@ -148,12 +143,7 @@ func (y *yamuxTransport) accept(conn net.Conn) {
|
||||
log.Warn("incoming connection yamux session error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
mc := &yamuxConn{
|
||||
ctx: cctx,
|
||||
luConn: luc,
|
||||
Session: sess,
|
||||
addr: conn.RemoteAddr().String(),
|
||||
}
|
||||
mc := NewMultiConn(cctx, luc, conn.RemoteAddr().String(), sess)
|
||||
if err = y.accepter.Accept(mc); err != nil {
|
||||
log.Warn("connection accept error", zap.Error(err))
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user