yamux: AddListener method

This commit is contained in:
Sergey Cherepanov 2023-06-08 14:36:14 +02:00
parent 85cf6b8332
commit 24ce490524
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C

View File

@ -11,6 +11,7 @@ import (
"github.com/hashicorp/yamux" "github.com/hashicorp/yamux"
"go.uber.org/zap" "go.uber.org/zap"
"net" "net"
"sync"
"time" "time"
) )
@ -25,6 +26,7 @@ func New() Yamux {
// Yamux implements transport.Transport with tcp+yamux // Yamux implements transport.Transport with tcp+yamux
type Yamux interface { type Yamux interface {
transport.Transport transport.Transport
AddListener(lis net.Listener)
app.ComponentRunnable app.ComponentRunnable
} }
@ -37,6 +39,7 @@ type yamuxTransport struct {
listCtx context.Context listCtx context.Context
listCtxCancel context.CancelFunc listCtxCancel context.CancelFunc
yamuxConf *yamux.Config yamuxConf *yamux.Config
mu sync.Mutex
} }
func (y *yamuxTransport) Init(a *app.App) (err error) { func (y *yamuxTransport) Init(a *app.App) (err error) {
@ -63,6 +66,8 @@ func (y *yamuxTransport) Run(ctx context.Context) (err error) {
if y.accepter == nil { if y.accepter == nil {
return fmt.Errorf("can't run service without accepter") return fmt.Errorf("can't run service without accepter")
} }
y.mu.Lock()
defer y.mu.Unlock()
for _, listAddr := range y.conf.ListenAddrs { for _, listAddr := range y.conf.ListenAddrs {
list, err := net.Listen("tcp", listAddr) list, err := net.Listen("tcp", listAddr)
if err != nil { if err != nil {
@ -81,6 +86,13 @@ func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) {
y.accepter = accepter y.accepter = accepter
} }
func (y *yamuxTransport) AddListener(lis net.Listener) {
y.mu.Lock()
defer y.mu.Unlock()
y.listeners = append(y.listeners, lis)
go y.acceptLoop(y.listCtx, lis)
}
func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) { func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) {
dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second
conn, err := net.DialTimeout("tcp", addr, dialTimeout) conn, err := net.DialTimeout("tcp", addr, dialTimeout)