From fb007211f0e64ff63bab051ae4510dc501db4818 Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Wed, 31 May 2023 20:24:07 +0200 Subject: [PATCH] simplify drpc server + peer accept loop --- net/peer/peer.go | 67 ++++++++++++++--- net/peer/peer_test.go | 83 +++++++++++++++++++- net/peerservice/peerservice.go | 7 +- net/rpc/server/baseserver.go | 134 --------------------------------- net/rpc/server/drpcserver.go | 61 ++++++++------- net/rpc/server/util.go | 18 ----- net/rpc/server/util_windows.go | 41 ---------- net/transport/transport.go | 5 ++ net/transport/yamux/conn.go | 10 +++ 9 files changed, 185 insertions(+), 241 deletions(-) delete mode 100644 net/rpc/server/baseserver.go delete mode 100644 net/rpc/server/util.go delete mode 100644 net/rpc/server/util_windows.go diff --git a/net/peer/peer.go b/net/peer/peer.go index df8a2799..9c8e0ddc 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -2,28 +2,37 @@ package peer import ( "context" + "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/ocache" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/transport" + "go.uber.org/zap" + "io" + "net" "storj.io/drpc" "storj.io/drpc/drpcconn" "sync" "time" - - "github.com/anyproto/any-sync/app/logger" - "go.uber.org/zap" ) var log = logger.NewNamed("common.net.peer") -func NewPeer(mc transport.MultiConn) (p Peer, err error) { +type connCtrl interface { + ServeConn(ctx context.Context, conn net.Conn) (err error) +} + +func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { ctx := mc.Context() pr := &peer{ active: map[drpc.Conn]struct{}{}, MultiConn: mc, + ctrl: ctrl, } if pr.id, err = CtxPeerId(ctx); err != nil { return } + go pr.acceptLoop() return pr, nil } @@ -43,10 +52,13 @@ type Peer interface { type peer struct { id string + ctrl connCtrl + // drpc conn pool inactive []drpc.Conn active map[drpc.Conn]struct{} - mu sync.Mutex + + mu sync.Mutex transport.MultiConn } @@ -83,16 +95,49 @@ func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { return } +func (p *peer) acceptLoop() { + var exitErr error + defer func() { + if exitErr != transport.ErrConnClosed { + log.Warn("accept error: close connection", zap.Error(exitErr)) + _ = p.MultiConn.Close() + } + }() + for { + conn, err := p.Accept() + if err != nil { + exitErr = err + return + } + go func() { + serveErr := p.serve(conn) + if serveErr != io.EOF && serveErr != transport.ErrConnClosed { + log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr)) + } + }() + } +} + +var defaultProtoChecker = handshake.ProtoChecker{ + AllowedProtoTypes: []handshakeproto.ProtoType{ + handshakeproto.ProtoType_DRPC, + }, +} + +func (p *peer) serve(conn net.Conn) (err error) { + hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20) + if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil { + cancel() + return + } + cancel() + return p.ctrl.ServeConn(p.Context(), conn) +} + func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { if time.Now().Sub(p.LastUsage()) < objectTTL { return false, nil } - p.mu.Lock() - if len(p.active) > 0 { - p.mu.Unlock() - return false, nil - } - p.mu.Unlock() return true, p.Close() } diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go index 41efc45d..00a14252 100644 --- a/net/peer/peer_test.go +++ b/net/peer/peer_test.go @@ -2,12 +2,16 @@ package peer import ( "context" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/transport/mock_transport" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "io" "net" "testing" + "time" ) var ctx = context.Background() @@ -38,14 +42,63 @@ func TestPeer_AcquireDrpcConn(t *testing.T) { assert.Len(t, fx.inactive, 0) } +func TestPeerAccept(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + defer out.Close() + + var outHandshakeCh = make(chan error) + go func() { + outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC) + }() + fx.acceptCh <- acceptedConn{conn: in} + cn := <-fx.testCtrl.serveConn + assert.Equal(t, in, cn) + assert.NoError(t, <-outHandshakeCh) +} + +func TestPeer_TryClose(t *testing.T) { + t.Run("ttl", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now() + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.False(t, res) + }) + t.Run("close", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now().Add(-time.Second * 2) + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.True(t, res) + }) +} + +type acceptedConn struct { + conn net.Conn + err error +} + func newFixture(t *testing.T, peerId string) *fixture { fx := &fixture{ - ctrl: gomock.NewController(t), + ctrl: gomock.NewController(t), + acceptCh: make(chan acceptedConn), + testCtrl: newTesCtrl(), } fx.mc = mock_transport.NewMockMultiConn(fx.ctrl) ctx := CtxWithPeerId(context.Background(), peerId) fx.mc.EXPECT().Context().Return(ctx).AnyTimes() - p, err := NewPeer(fx.mc) + fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) { + ac := <-fx.acceptCh + return ac.conn, ac.err + }).AnyTimes() + fx.mc.EXPECT().Close().AnyTimes() + p, err := NewPeer(fx.mc, fx.testCtrl) require.NoError(t, err) fx.peer = p.(*peer) return fx @@ -53,10 +106,32 @@ func newFixture(t *testing.T, peerId string) *fixture { type fixture struct { *peer - ctrl *gomock.Controller - mc *mock_transport.MockMultiConn + ctrl *gomock.Controller + mc *mock_transport.MockMultiConn + acceptCh chan acceptedConn + testCtrl *testCtrl } func (fx *fixture) finish() { + fx.testCtrl.close() fx.ctrl.Finish() } + +func newTesCtrl() *testCtrl { + return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)} +} + +type testCtrl struct { + serveConn chan net.Conn + closeCh chan struct{} +} + +func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) { + t.serveConn <- conn + <-t.closeCh + return io.EOF +} + +func (t *testCtrl) close() { + close(t.closeCh) +} diff --git a/net/peerservice/peerservice.go b/net/peerservice/peerservice.go index a83e6785..e0691733 100644 --- a/net/peerservice/peerservice.go +++ b/net/peerservice/peerservice.go @@ -7,6 +7,7 @@ import ( "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/pool" + "github.com/anyproto/any-sync/net/rpc/server" "github.com/anyproto/any-sync/net/transport" "github.com/anyproto/any-sync/net/transport/yamux" "github.com/anyproto/any-sync/nodeconf" @@ -38,6 +39,7 @@ type peerService struct { nodeConf nodeconf.NodeConf peerAddrs map[string][]string pool pool.Pool + server server.DRPCServer mu sync.RWMutex } @@ -45,6 +47,7 @@ func (p *peerService) Init(a *app.App) (err error) { p.yamux = a.MustComponent(yamux.CName).(transport.Transport) p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) p.pool = a.MustComponent(pool.CName).(pool.Pool) + p.server = a.MustComponent(server.CName).(server.DRPCServer) p.peerAddrs = map[string][]string{} return nil } @@ -75,11 +78,11 @@ func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, er if err != nil { return } - return peer.NewPeer(mc) + return peer.NewPeer(mc, p.server) } func (p *peerService) Accept(mc transport.MultiConn) (err error) { - pr, err := peer.NewPeer(mc) + pr, err := peer.NewPeer(mc, p.server) if err != nil { return err } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go deleted file mode 100644 index cb3047ed..00000000 --- a/net/rpc/server/baseserver.go +++ /dev/null @@ -1,134 +0,0 @@ -package server - -import ( - "context" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" - "github.com/zeebo/errs" - "go.uber.org/zap" - "io" - "net" - "storj.io/drpc" - "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - "storj.io/drpc/drpcwire" - "time" -) - -type BaseDrpcServer struct { - drpcServer *drpcserver.Server - transport secureservice.SecureService - listeners []net.Listener - handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) - cancel func() - *drpcmux.Mux -} - -type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler - -type Params struct { - BufferSizeMb int - ListenAddrs []string - Wrapper DRPCHandlerWrapper - TimeoutMillis int - Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) -} - -func NewBaseDrpcServer() *BaseDrpcServer { - return &BaseDrpcServer{Mux: drpcmux.New()} -} - -func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) { - s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{ - Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)}, - }}) - s.handshake = params.Handshake - ctx, s.cancel = context.WithCancel(ctx) - for _, addr := range params.ListenAddrs { - list, err := net.Listen("tcp", addr) - if err != nil { - return err - } - s.listeners = append(s.listeners, list) - go s.serve(ctx, list) - } - return -} - -func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) { - l := log.With(zap.String("localAddr", lis.Addr().String())) - l.Info("drpc listener started") - defer func() { - l.Debug("drpc listener stopped") - }() - for { - select { - case <-ctx.Done(): - return - default: - } - conn, err := lis.Accept() - if err != nil { - if isTemporary(err) { - l.Debug("listener temporary accept error", zap.Error(err)) - select { - case <-time.After(time.Second): - case <-ctx.Done(): - return - } - continue - } - l.Error("listener accept error", zap.Error(err)) - return - } - go s.serveConn(conn) - } -} - -func (s *BaseDrpcServer) serveConn(conn net.Conn) { - l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) - var ( - ctx = context.Background() - err error - ) - if s.handshake != nil { - ctx, conn, err = s.handshake(conn) - if err != nil { - l.Info("handshake error", zap.Error(err)) - return - } - if sc, ok := conn.(sec.SecureConn); ok { - ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String()) - } - } - ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String()) - l.Debug("connection opened") - if err := s.drpcServer.ServeOne(ctx, conn); err != nil { - if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { - l.Debug("connection closed") - } else { - l.Warn("serve connection error", zap.Error(err)) - } - } -} - -func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) { - for _, list := range s.listeners { - addrs = append(addrs, list.Addr()) - } - return -} - -func (s *BaseDrpcServer) Close(ctx context.Context) (err error) { - if s.cancel != nil { - s.cancel() - } - for _, l := range s.listeners { - if e := l.Close(); e != nil { - log.Warn("close listener error", zap.Error(e)) - } - } - return -} diff --git a/net/rpc/server/drpcserver.go b/net/rpc/server/drpcserver.go index 1874d16a..2b061515 100644 --- a/net/rpc/server/drpcserver.go +++ b/net/rpc/server/drpcserver.go @@ -6,11 +6,13 @@ import ( "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/metric" anyNet "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" + "go.uber.org/zap" "net" "storj.io/drpc" - "time" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpcwire" ) const CName = "common.net.drpcserver" @@ -18,49 +20,46 @@ const CName = "common.net.drpcserver" var log = logger.NewNamed(CName) func New() DRPCServer { - return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()} + return &drpcServer{} } type DRPCServer interface { - app.ComponentRunnable + ServeConn(ctx context.Context, conn net.Conn) (err error) + app.Component drpc.Mux } type drpcServer struct { - config anyNet.Config - metric metric.Metric - transport secureservice.SecureService - *BaseDrpcServer + drpcServer *drpcserver.Server + *drpcmux.Mux + config anyNet.Config + metric metric.Metric } -func (s *drpcServer) Init(a *app.App) (err error) { - s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() - s.metric = a.MustComponent(metric.CName).(metric.Metric) - s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - return nil -} +type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler func (s *drpcServer) Name() (name string) { return CName } -func (s *drpcServer) Run(ctx context.Context) (err error) { - params := Params{ - BufferSizeMb: s.config.Stream.MaxMsgSizeMb, - TimeoutMillis: s.config.Stream.TimeoutMilliseconds, - ListenAddrs: s.config.Server.ListenAddrs, - Wrapper: func(handler drpc.Handler) drpc.Handler { - return s.metric.WrapDRPCHandler(handler) - }, - Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - return s.transport.SecureInbound(ctx, conn) - }, +func (s *drpcServer) Init(a *app.App) (err error) { + s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() + s.metric, _ = a.Component(metric.CName).(metric.Metric) + s.Mux = drpcmux.New() + + var handler drpc.Handler + handler = s + if s.metric != nil { + handler = s.metric.WrapDRPCHandler(s) } - return s.BaseDrpcServer.Run(ctx, params) + s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)}, + }}) + return } -func (s *drpcServer) Close(ctx context.Context) (err error) { - return s.BaseDrpcServer.Close(ctx) +func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) { + l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) + l.Debug("drpc serve peer") + return s.drpcServer.ServeOne(ctx, conn) } diff --git a/net/rpc/server/util.go b/net/rpc/server/util.go deleted file mode 100644 index 5852288a..00000000 --- a/net/rpc/server/util.go +++ /dev/null @@ -1,18 +0,0 @@ -//go:build !windows - -package server - -import ( - "errors" - "net" -) - -// isTemporary checks if an error is temporary. -func isTemporary(err error) bool { - var nErr net.Error - if errors.As(err, &nErr) { - return nErr.Temporary() - } - - return false -} diff --git a/net/rpc/server/util_windows.go b/net/rpc/server/util_windows.go deleted file mode 100644 index efef2915..00000000 --- a/net/rpc/server/util_windows.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:build windows - -package server - -import ( - "errors" - "net" - "os" - "syscall" -) - -const ( - _WSAEMFILE syscall.Errno = 10024 - _WSAENETRESET syscall.Errno = 10052 - _WSAENOBUFS syscall.Errno = 10055 -) - -// isTemporary checks if an error is temporary. -// see related go issue for more detail: https://go-review.googlesource.com/c/go/+/208537/ -func isTemporary(err error) bool { - var nErr net.Error - if !errors.As(err, &nErr) { - return false - } - - if nErr.Temporary() { - return true - } - - var sErr *os.SyscallError - if errors.As(err, &sErr) { - switch sErr.Err { - case _WSAENETRESET, - _WSAEMFILE, - _WSAENOBUFS: - return true - } - } - - return false -} diff --git a/net/transport/transport.go b/net/transport/transport.go index 7b247793..2dab5348 100644 --- a/net/transport/transport.go +++ b/net/transport/transport.go @@ -3,10 +3,15 @@ package transport import ( "context" + "errors" "net" "time" ) +var ( + ErrConnClosed = errors.New("connection closed") +) + // Transport is a common interface for a network transport type Transport interface { // SetAccepter sets accepter that will be called for new connections diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go index 541b8257..d563df34 100644 --- a/net/transport/yamux/conn.go +++ b/net/transport/yamux/conn.go @@ -3,6 +3,7 @@ package yamux import ( "context" "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/transport" "github.com/hashicorp/yamux" "net" "time" @@ -30,3 +31,12 @@ func (y *yamuxConn) Context() context.Context { func (y *yamuxConn) Addr() string { return y.addr } + +func (y *yamuxConn) Accept() (conn net.Conn, err error) { + if conn, err = y.Session.Accept(); err != nil { + if err == yamux.ErrSessionShutdown { + err = transport.ErrConnClosed + } + } + return +}