simplify drpc server + peer accept loop
This commit is contained in:
parent
00c582e157
commit
fb007211f0
@ -2,28 +2,37 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/app/ocache"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcconn"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed("common.net.peer")
|
||||
|
||||
func NewPeer(mc transport.MultiConn) (p Peer, err error) {
|
||||
type connCtrl interface {
|
||||
ServeConn(ctx context.Context, conn net.Conn) (err error)
|
||||
}
|
||||
|
||||
func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
|
||||
ctx := mc.Context()
|
||||
pr := &peer{
|
||||
active: map[drpc.Conn]struct{}{},
|
||||
MultiConn: mc,
|
||||
ctrl: ctrl,
|
||||
}
|
||||
if pr.id, err = CtxPeerId(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
go pr.acceptLoop()
|
||||
return pr, nil
|
||||
}
|
||||
|
||||
@ -43,9 +52,12 @@ type Peer interface {
|
||||
type peer struct {
|
||||
id string
|
||||
|
||||
ctrl connCtrl
|
||||
|
||||
// drpc conn pool
|
||||
inactive []drpc.Conn
|
||||
active map[drpc.Conn]struct{}
|
||||
|
||||
mu sync.Mutex
|
||||
|
||||
transport.MultiConn
|
||||
@ -83,16 +95,49 @@ func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func (p *peer) acceptLoop() {
|
||||
var exitErr error
|
||||
defer func() {
|
||||
if exitErr != transport.ErrConnClosed {
|
||||
log.Warn("accept error: close connection", zap.Error(exitErr))
|
||||
_ = p.MultiConn.Close()
|
||||
}
|
||||
}()
|
||||
for {
|
||||
conn, err := p.Accept()
|
||||
if err != nil {
|
||||
exitErr = err
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
serveErr := p.serve(conn)
|
||||
if serveErr != io.EOF && serveErr != transport.ErrConnClosed {
|
||||
log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr))
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
var defaultProtoChecker = handshake.ProtoChecker{
|
||||
AllowedProtoTypes: []handshakeproto.ProtoType{
|
||||
handshakeproto.ProtoType_DRPC,
|
||||
},
|
||||
}
|
||||
|
||||
func (p *peer) serve(conn net.Conn) (err error) {
|
||||
hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20)
|
||||
if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil {
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
cancel()
|
||||
return p.ctrl.ServeConn(p.Context(), conn)
|
||||
}
|
||||
|
||||
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
if time.Now().Sub(p.LastUsage()) < objectTTL {
|
||||
return false, nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
if len(p.active) > 0 {
|
||||
p.mu.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
p.mu.Unlock()
|
||||
return true, p.Close()
|
||||
}
|
||||
|
||||
|
||||
@ -2,12 +2,16 @@ package peer
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake"
|
||||
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
|
||||
"github.com/anyproto/any-sync/net/transport/mock_transport"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
@ -38,14 +42,63 @@ func TestPeer_AcquireDrpcConn(t *testing.T) {
|
||||
assert.Len(t, fx.inactive, 0)
|
||||
}
|
||||
|
||||
func TestPeerAccept(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
in, out := net.Pipe()
|
||||
defer out.Close()
|
||||
|
||||
var outHandshakeCh = make(chan error)
|
||||
go func() {
|
||||
outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC)
|
||||
}()
|
||||
fx.acceptCh <- acceptedConn{conn: in}
|
||||
cn := <-fx.testCtrl.serveConn
|
||||
assert.Equal(t, in, cn)
|
||||
assert.NoError(t, <-outHandshakeCh)
|
||||
}
|
||||
|
||||
func TestPeer_TryClose(t *testing.T) {
|
||||
t.Run("ttl", func(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
lu := time.Now()
|
||||
fx.mc.EXPECT().LastUsage().Return(lu)
|
||||
res, err := fx.TryClose(time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, res)
|
||||
})
|
||||
t.Run("close", func(t *testing.T) {
|
||||
fx := newFixture(t, "p1")
|
||||
defer fx.finish()
|
||||
lu := time.Now().Add(-time.Second * 2)
|
||||
fx.mc.EXPECT().LastUsage().Return(lu)
|
||||
res, err := fx.TryClose(time.Second)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, res)
|
||||
})
|
||||
}
|
||||
|
||||
type acceptedConn struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T, peerId string) *fixture {
|
||||
fx := &fixture{
|
||||
ctrl: gomock.NewController(t),
|
||||
acceptCh: make(chan acceptedConn),
|
||||
testCtrl: newTesCtrl(),
|
||||
}
|
||||
fx.mc = mock_transport.NewMockMultiConn(fx.ctrl)
|
||||
ctx := CtxWithPeerId(context.Background(), peerId)
|
||||
fx.mc.EXPECT().Context().Return(ctx).AnyTimes()
|
||||
p, err := NewPeer(fx.mc)
|
||||
fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
|
||||
ac := <-fx.acceptCh
|
||||
return ac.conn, ac.err
|
||||
}).AnyTimes()
|
||||
fx.mc.EXPECT().Close().AnyTimes()
|
||||
p, err := NewPeer(fx.mc, fx.testCtrl)
|
||||
require.NoError(t, err)
|
||||
fx.peer = p.(*peer)
|
||||
return fx
|
||||
@ -55,8 +108,30 @@ type fixture struct {
|
||||
*peer
|
||||
ctrl *gomock.Controller
|
||||
mc *mock_transport.MockMultiConn
|
||||
acceptCh chan acceptedConn
|
||||
testCtrl *testCtrl
|
||||
}
|
||||
|
||||
func (fx *fixture) finish() {
|
||||
fx.testCtrl.close()
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func newTesCtrl() *testCtrl {
|
||||
return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)}
|
||||
}
|
||||
|
||||
type testCtrl struct {
|
||||
serveConn chan net.Conn
|
||||
closeCh chan struct{}
|
||||
}
|
||||
|
||||
func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
t.serveConn <- conn
|
||||
<-t.closeCh
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (t *testCtrl) close() {
|
||||
close(t.closeCh)
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/pool"
|
||||
"github.com/anyproto/any-sync/net/rpc/server"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/anyproto/any-sync/net/transport/yamux"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
@ -38,6 +39,7 @@ type peerService struct {
|
||||
nodeConf nodeconf.NodeConf
|
||||
peerAddrs map[string][]string
|
||||
pool pool.Pool
|
||||
server server.DRPCServer
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
@ -45,6 +47,7 @@ func (p *peerService) Init(a *app.App) (err error) {
|
||||
p.yamux = a.MustComponent(yamux.CName).(transport.Transport)
|
||||
p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
|
||||
p.pool = a.MustComponent(pool.CName).(pool.Pool)
|
||||
p.server = a.MustComponent(server.CName).(server.DRPCServer)
|
||||
p.peerAddrs = map[string][]string{}
|
||||
return nil
|
||||
}
|
||||
@ -75,11 +78,11 @@ func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, er
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return peer.NewPeer(mc)
|
||||
return peer.NewPeer(mc, p.server)
|
||||
}
|
||||
|
||||
func (p *peerService) Accept(mc transport.MultiConn) (err error) {
|
||||
pr, err := peer.NewPeer(mc)
|
||||
pr, err := peer.NewPeer(mc, p.server)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1,134 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/zeebo/errs"
|
||||
"go.uber.org/zap"
|
||||
"io"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"storj.io/drpc/drpcwire"
|
||||
"time"
|
||||
)
|
||||
|
||||
type BaseDrpcServer struct {
|
||||
drpcServer *drpcserver.Server
|
||||
transport secureservice.SecureService
|
||||
listeners []net.Listener
|
||||
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||
cancel func()
|
||||
*drpcmux.Mux
|
||||
}
|
||||
|
||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||
|
||||
type Params struct {
|
||||
BufferSizeMb int
|
||||
ListenAddrs []string
|
||||
Wrapper DRPCHandlerWrapper
|
||||
TimeoutMillis int
|
||||
Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
|
||||
}
|
||||
|
||||
func NewBaseDrpcServer() *BaseDrpcServer {
|
||||
return &BaseDrpcServer{Mux: drpcmux.New()}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
|
||||
s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{
|
||||
Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)},
|
||||
}})
|
||||
s.handshake = params.Handshake
|
||||
ctx, s.cancel = context.WithCancel(ctx)
|
||||
for _, addr := range params.ListenAddrs {
|
||||
list, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.listeners = append(s.listeners, list)
|
||||
go s.serve(ctx, list)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
|
||||
l := log.With(zap.String("localAddr", lis.Addr().String()))
|
||||
l.Info("drpc listener started")
|
||||
defer func() {
|
||||
l.Debug("drpc listener stopped")
|
||||
}()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
conn, err := lis.Accept()
|
||||
if err != nil {
|
||||
if isTemporary(err) {
|
||||
l.Debug("listener temporary accept error", zap.Error(err))
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
l.Error("listener accept error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
go s.serveConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
|
||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||
var (
|
||||
ctx = context.Background()
|
||||
err error
|
||||
)
|
||||
if s.handshake != nil {
|
||||
ctx, conn, err = s.handshake(conn)
|
||||
if err != nil {
|
||||
l.Info("handshake error", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if sc, ok := conn.(sec.SecureConn); ok {
|
||||
ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String())
|
||||
}
|
||||
}
|
||||
ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String())
|
||||
l.Debug("connection opened")
|
||||
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
|
||||
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
|
||||
l.Debug("connection closed")
|
||||
} else {
|
||||
l.Warn("serve connection error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) {
|
||||
for _, list := range s.listeners {
|
||||
addrs = append(addrs, list.Addr())
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *BaseDrpcServer) Close(ctx context.Context) (err error) {
|
||||
if s.cancel != nil {
|
||||
s.cancel()
|
||||
}
|
||||
for _, l := range s.listeners {
|
||||
if e := l.Close(); e != nil {
|
||||
log.Warn("close listener error", zap.Error(e))
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -6,11 +6,13 @@ import (
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/metric"
|
||||
anyNet "github.com/anyproto/any-sync/net"
|
||||
"github.com/anyproto/any-sync/net/secureservice"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"go.uber.org/zap"
|
||||
"net"
|
||||
"storj.io/drpc"
|
||||
"time"
|
||||
"storj.io/drpc/drpcmanager"
|
||||
"storj.io/drpc/drpcmux"
|
||||
"storj.io/drpc/drpcserver"
|
||||
"storj.io/drpc/drpcwire"
|
||||
)
|
||||
|
||||
const CName = "common.net.drpcserver"
|
||||
@ -18,49 +20,46 @@ const CName = "common.net.drpcserver"
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
func New() DRPCServer {
|
||||
return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()}
|
||||
return &drpcServer{}
|
||||
}
|
||||
|
||||
type DRPCServer interface {
|
||||
app.ComponentRunnable
|
||||
ServeConn(ctx context.Context, conn net.Conn) (err error)
|
||||
app.Component
|
||||
drpc.Mux
|
||||
}
|
||||
|
||||
type drpcServer struct {
|
||||
drpcServer *drpcserver.Server
|
||||
*drpcmux.Mux
|
||||
config anyNet.Config
|
||||
metric metric.Metric
|
||||
transport secureservice.SecureService
|
||||
*BaseDrpcServer
|
||||
}
|
||||
|
||||
func (s *drpcServer) Init(a *app.App) (err error) {
|
||||
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
|
||||
s.metric = a.MustComponent(metric.CName).(metric.Metric)
|
||||
s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
|
||||
return nil
|
||||
}
|
||||
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
|
||||
|
||||
func (s *drpcServer) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (s *drpcServer) Run(ctx context.Context) (err error) {
|
||||
params := Params{
|
||||
BufferSizeMb: s.config.Stream.MaxMsgSizeMb,
|
||||
TimeoutMillis: s.config.Stream.TimeoutMilliseconds,
|
||||
ListenAddrs: s.config.Server.ListenAddrs,
|
||||
Wrapper: func(handler drpc.Handler) drpc.Handler {
|
||||
return s.metric.WrapDRPCHandler(handler)
|
||||
},
|
||||
Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer cancel()
|
||||
return s.transport.SecureInbound(ctx, conn)
|
||||
},
|
||||
func (s *drpcServer) Init(a *app.App) (err error) {
|
||||
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
|
||||
s.metric, _ = a.Component(metric.CName).(metric.Metric)
|
||||
s.Mux = drpcmux.New()
|
||||
|
||||
var handler drpc.Handler
|
||||
handler = s
|
||||
if s.metric != nil {
|
||||
handler = s.metric.WrapDRPCHandler(s)
|
||||
}
|
||||
return s.BaseDrpcServer.Run(ctx, params)
|
||||
s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{
|
||||
Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)},
|
||||
}})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *drpcServer) Close(ctx context.Context) (err error) {
|
||||
return s.BaseDrpcServer.Close(ctx)
|
||||
func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
|
||||
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
|
||||
l.Debug("drpc serve peer")
|
||||
return s.drpcServer.ServeOne(ctx, conn)
|
||||
}
|
||||
|
||||
@ -1,18 +0,0 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
// isTemporary checks if an error is temporary.
|
||||
func isTemporary(err error) bool {
|
||||
var nErr net.Error
|
||||
if errors.As(err, &nErr) {
|
||||
return nErr.Temporary()
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@ -1,41 +0,0 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
_WSAEMFILE syscall.Errno = 10024
|
||||
_WSAENETRESET syscall.Errno = 10052
|
||||
_WSAENOBUFS syscall.Errno = 10055
|
||||
)
|
||||
|
||||
// isTemporary checks if an error is temporary.
|
||||
// see related go issue for more detail: https://go-review.googlesource.com/c/go/+/208537/
|
||||
func isTemporary(err error) bool {
|
||||
var nErr net.Error
|
||||
if !errors.As(err, &nErr) {
|
||||
return false
|
||||
}
|
||||
|
||||
if nErr.Temporary() {
|
||||
return true
|
||||
}
|
||||
|
||||
var sErr *os.SyscallError
|
||||
if errors.As(err, &sErr) {
|
||||
switch sErr.Err {
|
||||
case _WSAENETRESET,
|
||||
_WSAEMFILE,
|
||||
_WSAENOBUFS:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
@ -3,10 +3,15 @@ package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrConnClosed = errors.New("connection closed")
|
||||
)
|
||||
|
||||
// Transport is a common interface for a network transport
|
||||
type Transport interface {
|
||||
// SetAccepter sets accepter that will be called for new connections
|
||||
|
||||
@ -3,6 +3,7 @@ package yamux
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/net/connutil"
|
||||
"github.com/anyproto/any-sync/net/transport"
|
||||
"github.com/hashicorp/yamux"
|
||||
"net"
|
||||
"time"
|
||||
@ -30,3 +31,12 @@ func (y *yamuxConn) Context() context.Context {
|
||||
func (y *yamuxConn) Addr() string {
|
||||
return y.addr
|
||||
}
|
||||
|
||||
func (y *yamuxConn) Accept() (conn net.Conn, err error) {
|
||||
if conn, err = y.Session.Accept(); err != nil {
|
||||
if err == yamux.ErrSessionShutdown {
|
||||
err = transport.ErrConnClosed
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user