simplify drpc server + peer accept loop

Sergey Cherepanov 2023-05-31 20:24:07 +02:00
parent 00c582e157
commit fb007211f0
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
9 changed files with 185 additions and 241 deletions
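The commit moves connection acceptance from the server to the peer: the DRPC server no longer listens on TCP itself (BaseDrpcServer is deleted below); instead every transport.MultiConn gets its own accept loop that handshakes each sub-connection and hands it to a connCtrl, which the DRPC server implements via ServeConn. The toy program below sketches that shape. Only the connCtrl interface mirrors the code added in this commit; the TCP listener, echoCtrl, and the rest are simplified stand-ins, not any-sync APIs.

```go
package main

import (
	"context"
	"fmt"
	"io"
	"net"
)

// connCtrl mirrors the interface introduced in peer.go: anything that can
// serve one raw connection. In any-sync this is the DRPC server component.
type connCtrl interface {
	ServeConn(ctx context.Context, conn net.Conn) (err error)
}

// echoCtrl is a hypothetical stand-in for server.DRPCServer: it serves a
// connection by echoing what it reads, then closes it.
type echoCtrl struct{}

func (echoCtrl) ServeConn(ctx context.Context, conn net.Conn) error {
	defer conn.Close()
	buf := make([]byte, 64)
	n, err := conn.Read(buf)
	if err != nil {
		return err
	}
	_, err = conn.Write(buf[:n])
	return err
}

// acceptLoop plays the role of peer.acceptLoop: accept sub-connections until
// the listener closes and serve each one on its own goroutine.
func acceptLoop(ctx context.Context, l net.Listener, ctrl connCtrl) {
	for {
		conn, err := l.Accept()
		if err != nil {
			return // the real loop treats transport.ErrConnClosed as a clean exit
		}
		go func() { _ = ctrl.ServeConn(ctx, conn) }()
	}
}

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer l.Close()
	go acceptLoop(context.Background(), l, echoCtrl{})

	c, err := net.Dial("tcp", l.Addr().String())
	if err != nil {
		panic(err)
	}
	defer c.Close()
	fmt.Fprint(c, "ping")
	reply := make([]byte, 4)
	if _, err := io.ReadFull(c, reply); err != nil {
		panic(err)
	}
	fmt.Println(string(reply)) // ping
}
```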

View File

@@ -2,28 +2,37 @@ package peer
import (
"context"
+"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/app/ocache"
+"github.com/anyproto/any-sync/net/secureservice/handshake"
+"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport"
+"go.uber.org/zap"
+"io"
+"net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"sync"
"time"
-"github.com/anyproto/any-sync/app/logger"
-"go.uber.org/zap"
)
var log = logger.NewNamed("common.net.peer")
-func NewPeer(mc transport.MultiConn) (p Peer, err error) {
+type connCtrl interface {
+ServeConn(ctx context.Context, conn net.Conn) (err error)
+}
+func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
ctx := mc.Context()
pr := &peer{
active: map[drpc.Conn]struct{}{},
MultiConn: mc,
+ctrl: ctrl,
}
if pr.id, err = CtxPeerId(ctx); err != nil {
return
}
+go pr.acceptLoop()
return pr, nil
}
@@ -43,9 +52,12 @@ type Peer interface {
type peer struct {
id string
ctrl connCtrl
// drpc conn pool
inactive []drpc.Conn
active map[drpc.Conn]struct{}
mu sync.Mutex
transport.MultiConn
@@ -83,16 +95,49 @@ func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
return
}
func (p *peer) acceptLoop() {
var exitErr error
defer func() {
if exitErr != transport.ErrConnClosed {
log.Warn("accept error: close connection", zap.Error(exitErr))
_ = p.MultiConn.Close()
}
}()
for {
conn, err := p.Accept()
if err != nil {
exitErr = err
return
}
go func() {
serveErr := p.serve(conn)
if serveErr != io.EOF && serveErr != transport.ErrConnClosed {
log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr))
}
}()
}
}
var defaultProtoChecker = handshake.ProtoChecker{
AllowedProtoTypes: []handshakeproto.ProtoType{
handshakeproto.ProtoType_DRPC,
},
}
func (p *peer) serve(conn net.Conn) (err error) {
hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20)
if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil {
cancel()
return
}
cancel()
return p.ctrl.ServeConn(p.Context(), conn)
}
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
if time.Now().Sub(p.LastUsage()) < objectTTL {
return false, nil
}
p.mu.Lock()
if len(p.active) > 0 {
p.mu.Unlock()
return false, nil
}
p.mu.Unlock()
return true, p.Close()
}
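serve gives the remote side 20 seconds to complete an incoming proto handshake that admits only ProtoType_DRPC before the connection reaches ctrl.ServeConn, so a dialing peer has to run the matching outgoing handshake first, as TestPeerAccept in the next file does. A minimal pairing of the two halves over an in-memory pipe, using only the handshake calls that appear in this diff:

```go
package main

import (
	"context"
	"fmt"
	"net"

	"github.com/anyproto/any-sync/net/secureservice/handshake"
	"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
)

func main() {
	in, out := net.Pipe()
	defer in.Close()
	defer out.Close()

	// dialing side: announce that this sub-connection will speak DRPC
	outErr := make(chan error, 1)
	go func() {
		outErr <- handshake.OutgoingProtoHandshake(context.Background(), out, handshakeproto.ProtoType_DRPC)
	}()

	// accepting side: the same check peer.serve performs before ctrl.ServeConn
	checker := handshake.ProtoChecker{
		AllowedProtoTypes: []handshakeproto.ProtoType{handshakeproto.ProtoType_DRPC},
	}
	if _, err := handshake.IncomingProtoHandshake(context.Background(), in, checker); err != nil {
		fmt.Println("incoming handshake failed:", err)
		return
	}
	fmt.Println("proto negotiated; outgoing side err:", <-outErr)
}
```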

View File

@@ -2,12 +2,16 @@ package peer
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport/mock_transport"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net"
"testing"
"time"
)
var ctx = context.Background()
@@ -38,14 +42,63 @@ func TestPeer_AcquireDrpcConn(t *testing.T) {
assert.Len(t, fx.inactive, 0)
}
func TestPeerAccept(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
in, out := net.Pipe()
defer out.Close()
var outHandshakeCh = make(chan error)
go func() {
outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC)
}()
fx.acceptCh <- acceptedConn{conn: in}
cn := <-fx.testCtrl.serveConn
assert.Equal(t, in, cn)
assert.NoError(t, <-outHandshakeCh)
}
func TestPeer_TryClose(t *testing.T) {
t.Run("ttl", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now()
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.False(t, res)
})
t.Run("close", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now().Add(-time.Second * 2)
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.True(t, res)
})
}
type acceptedConn struct {
conn net.Conn
err error
}
func newFixture(t *testing.T, peerId string) *fixture {
fx := &fixture{
ctrl: gomock.NewController(t),
acceptCh: make(chan acceptedConn),
testCtrl: newTesCtrl(),
}
fx.mc = mock_transport.NewMockMultiConn(fx.ctrl)
ctx := CtxWithPeerId(context.Background(), peerId)
fx.mc.EXPECT().Context().Return(ctx).AnyTimes()
-p, err := NewPeer(fx.mc)
+fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
+ac := <-fx.acceptCh
+return ac.conn, ac.err
+}).AnyTimes()
+fx.mc.EXPECT().Close().AnyTimes()
+p, err := NewPeer(fx.mc, fx.testCtrl)
require.NoError(t, err)
fx.peer = p.(*peer)
return fx
@@ -55,8 +108,30 @@ type fixture struct {
*peer
ctrl *gomock.Controller
mc *mock_transport.MockMultiConn
acceptCh chan acceptedConn
testCtrl *testCtrl
}
func (fx *fixture) finish() {
fx.testCtrl.close()
fx.ctrl.Finish()
}
func newTesCtrl() *testCtrl {
return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)}
}
type testCtrl struct {
serveConn chan net.Conn
closeCh chan struct{}
}
func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
t.serveConn <- conn
<-t.closeCh
return io.EOF
}
func (t *testCtrl) close() {
close(t.closeCh)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/server"
"github.com/anyproto/any-sync/net/transport"
"github.com/anyproto/any-sync/net/transport/yamux"
"github.com/anyproto/any-sync/nodeconf"
@@ -38,6 +39,7 @@ type peerService struct {
nodeConf nodeconf.NodeConf
peerAddrs map[string][]string
pool pool.Pool
server server.DRPCServer
mu sync.RWMutex
}
@@ -45,6 +47,7 @@ func (p *peerService) Init(a *app.App) (err error) {
p.yamux = a.MustComponent(yamux.CName).(transport.Transport)
p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
p.pool = a.MustComponent(pool.CName).(pool.Pool)
p.server = a.MustComponent(server.CName).(server.DRPCServer)
p.peerAddrs = map[string][]string{}
return nil
}
@@ -75,11 +78,11 @@ func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, er
if err != nil {
return
}
-return peer.NewPeer(mc)
+return peer.NewPeer(mc, p.server)
}
func (p *peerService) Accept(mc transport.MultiConn) (err error) {
-pr, err := peer.NewPeer(mc)
+pr, err := peer.NewPeer(mc, p.server)
if err != nil {
return err
}
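peerService now injects the DRPC server component into peer.NewPeer for both dialed and accepted MultiConns, so inbound sub-connections end up in the same ServeConn either way. The only method peer needs from the server is ServeConn; a hypothetical compile-time assertion (not part of the commit) that would spell out that contract:

```go
// Hypothetical check, not in the commit: server.DRPCServer must keep
// satisfying the connCtrl shape that peer.NewPeer expects.
package peerservice // package name assumed

import (
	"context"
	"net"

	"github.com/anyproto/any-sync/net/rpc/server"
)

var _ interface {
	ServeConn(ctx context.Context, conn net.Conn) (err error)
} = (server.DRPCServer)(nil)
```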

View File

@@ -1,134 +0,0 @@
package server
import (
"context"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs"
"go.uber.org/zap"
"io"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcwire"
"time"
)
type BaseDrpcServer struct {
drpcServer *drpcserver.Server
transport secureservice.SecureService
listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func()
*drpcmux.Mux
}
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type Params struct {
BufferSizeMb int
ListenAddrs []string
Wrapper DRPCHandlerWrapper
TimeoutMillis int
Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
}
func NewBaseDrpcServer() *BaseDrpcServer {
return &BaseDrpcServer{Mux: drpcmux.New()}
}
func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)},
}})
s.handshake = params.Handshake
ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs {
list, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listeners = append(s.listeners, list)
go s.serve(ctx, list)
}
return
}
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started")
defer func() {
l.Debug("drpc listener stopped")
}()
for {
select {
case <-ctx.Done():
return
default:
}
conn, err := lis.Accept()
if err != nil {
if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err))
select {
case <-time.After(time.Second):
case <-ctx.Done():
return
}
continue
}
l.Error("listener accept error", zap.Error(err))
return
}
go s.serveConn(conn)
}
}
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
return
}
if sc, ok := conn.(sec.SecureConn); ok {
ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String())
}
}
ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String())
l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
l.Debug("connection closed")
} else {
l.Warn("serve connection error", zap.Error(err))
}
}
}
func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) {
for _, list := range s.listeners {
addrs = append(addrs, list.Addr())
}
return
}
func (s *BaseDrpcServer) Close(ctx context.Context) (err error) {
if s.cancel != nil {
s.cancel()
}
for _, l := range s.listeners {
if e := l.Close(); e != nil {
log.Warn("close listener error", zap.Error(e))
}
}
return
}

View File

@@ -6,11 +6,13 @@ import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/metric"
anyNet "github.com/anyproto/any-sync/net"
-"github.com/anyproto/any-sync/net/secureservice"
-"github.com/libp2p/go-libp2p/core/sec"
+"go.uber.org/zap"
"net"
"storj.io/drpc"
-"time"
+"storj.io/drpc/drpcmanager"
+"storj.io/drpc/drpcmux"
+"storj.io/drpc/drpcserver"
+"storj.io/drpc/drpcwire"
)
const CName = "common.net.drpcserver"
@@ -18,49 +20,46 @@ const CName = "common.net.drpcserver"
var log = logger.NewNamed(CName)
func New() DRPCServer {
-return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()}
+return &drpcServer{}
}
type DRPCServer interface {
-app.ComponentRunnable
+ServeConn(ctx context.Context, conn net.Conn) (err error)
+app.Component
drpc.Mux
}
type drpcServer struct {
+drpcServer *drpcserver.Server
+*drpcmux.Mux
config anyNet.Config
metric metric.Metric
-transport secureservice.SecureService
-*BaseDrpcServer
}
-func (s *drpcServer) Init(a *app.App) (err error) {
-s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
-s.metric = a.MustComponent(metric.CName).(metric.Metric)
-s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
-return nil
-}
+type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
func (s *drpcServer) Name() (name string) {
return CName
}
-func (s *drpcServer) Run(ctx context.Context) (err error) {
-params := Params{
-BufferSizeMb: s.config.Stream.MaxMsgSizeMb,
-TimeoutMillis: s.config.Stream.TimeoutMilliseconds,
-ListenAddrs: s.config.Server.ListenAddrs,
-Wrapper: func(handler drpc.Handler) drpc.Handler {
-return s.metric.WrapDRPCHandler(handler)
-},
-Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) {
-ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-defer cancel()
-return s.transport.SecureInbound(ctx, conn)
-},
+func (s *drpcServer) Init(a *app.App) (err error) {
+s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
+s.metric, _ = a.Component(metric.CName).(metric.Metric)
+s.Mux = drpcmux.New()
+var handler drpc.Handler
+handler = s
+if s.metric != nil {
+handler = s.metric.WrapDRPCHandler(s)
}
-return s.BaseDrpcServer.Run(ctx, params)
+s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{
+Reader: drpcwire.ReaderOptions{MaximumBufferSize: s.config.Stream.MaxMsgSizeMb * (1 << 20)},
+}})
+return
}
-func (s *drpcServer) Close(ctx context.Context) (err error) {
-return s.BaseDrpcServer.Close(ctx)
+func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
+l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
+l.Debug("drpc serve peer")
+return s.drpcServer.ServeOne(ctx, conn)
}
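With BaseDrpcServer gone this component no longer owns listeners: Init only builds a drpcserver.Server over the embedded mux (wrapped with the metric handler when that component is registered), and connections arrive from the peer layer through ServeConn. A minimal sketch of that serving path using the same storj.io/drpc calls; the mux here is empty, where a real service would attach handlers through its generated DRPCRegister* function:

```go
package main

import (
	"context"
	"fmt"
	"net"

	"storj.io/drpc/drpcmux"
	"storj.io/drpc/drpcserver"
)

func main() {
	// What Init builds: a drpc server over the component's mux.
	mux := drpcmux.New() // real services register handlers here via generated code
	srv := drpcserver.New(mux)

	// What ServeConn does for each accepted sub-connection: serve it until it closes.
	serverSide, clientSide := net.Pipe()
	done := make(chan error, 1)
	go func() { done <- srv.ServeOne(context.Background(), serverSide) }()

	// Closing the peer's end ends ServeOne, like a closed sub-connection would.
	clientSide.Close()
	fmt.Println("ServeOne returned:", <-done)
}
```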

View File

@@ -1,18 +0,0 @@
//go:build !windows
package server
import (
"errors"
"net"
)
// isTemporary checks if an error is temporary.
func isTemporary(err error) bool {
var nErr net.Error
if errors.As(err, &nErr) {
return nErr.Temporary()
}
return false
}

View File

@@ -1,41 +0,0 @@
//go:build windows
package server
import (
"errors"
"net"
"os"
"syscall"
)
const (
_WSAEMFILE syscall.Errno = 10024
_WSAENETRESET syscall.Errno = 10052
_WSAENOBUFS syscall.Errno = 10055
)
// isTemporary checks if an error is temporary.
// see related go issue for more detail: https://go-review.googlesource.com/c/go/+/208537/
func isTemporary(err error) bool {
var nErr net.Error
if !errors.As(err, &nErr) {
return false
}
if nErr.Temporary() {
return true
}
var sErr *os.SyscallError
if errors.As(err, &sErr) {
switch sErr.Err {
case _WSAENETRESET,
_WSAEMFILE,
_WSAENOBUFS:
return true
}
}
return false
}

View File

@@ -3,10 +3,15 @@ package transport
import (
"context"
"errors"
"net"
"time"
)
var (
ErrConnClosed = errors.New("connection closed")
)
// Transport is a common interface for a network transport
type Transport interface {
// SetAccepter sets accepter that will be called for new connections

View File

@@ -3,6 +3,7 @@ package yamux
import (
"context"
"github.com/anyproto/any-sync/net/connutil"
"github.com/anyproto/any-sync/net/transport"
"github.com/hashicorp/yamux"
"net"
"time"
@@ -30,3 +31,12 @@ func (y *yamuxConn) Context() context.Context {
func (y *yamuxConn) Addr() string {
return y.addr
}
func (y *yamuxConn) Accept() (conn net.Conn, err error) {
if conn, err = y.Session.Accept(); err != nil {
if err == yamux.ErrSessionShutdown {
err = transport.ErrConnClosed
}
}
return
}
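The wrapper maps yamux.ErrSessionShutdown to transport.ErrConnClosed, which is what lets peer.acceptLoop treat a closed session as a clean exit instead of logging a warning. A small standalone check of that mapping, assuming yamux's behaviour that Accept on a closed session returns ErrSessionShutdown; errConnClosed stands in for transport.ErrConnClosed:

```go
package main

import (
	"errors"
	"fmt"
	"net"

	"github.com/hashicorp/yamux"
)

var errConnClosed = errors.New("connection closed") // stand-in for transport.ErrConnClosed

func main() {
	a, b := net.Pipe()
	defer b.Close()

	sess, err := yamux.Server(a, nil) // nil config = yamux defaults
	if err != nil {
		panic(err)
	}
	_ = sess.Close()

	_, err = sess.Accept()
	if err == yamux.ErrSessionShutdown { // the same mapping yamuxConn.Accept applies
		err = errConnClosed
	}
	fmt.Println(errors.Is(err, errConnClosed)) // true: the accept loop treats this as a clean shutdown
}
```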