diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go index 7252a4f3..4d85005b 100644 --- a/net/dialer/dialer.go +++ b/net/dialer/dialer.go @@ -24,7 +24,10 @@ import ( const CName = "common.net.dialer" -var ErrArrdsNotFound = errors.New("addrs for peer not found") +var ( + ErrAddrsNotFound = errors.New("addrs for peer not found") + ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id") +) var log = logger.NewNamed(CName) @@ -42,6 +45,7 @@ type Dialer interface { type dialer struct { transport secureservice.SecureService config net2.Config + nodeConf nodeconf.NodeConf peerAddrs map[string][]string mu sync.RWMutex @@ -49,7 +53,7 @@ type dialer struct { func (d *dialer) Init(a *app.App) (err error) { d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - d.peerAddrs = a.MustComponent(nodeconf.CName).(nodeconf.Service).GetLast().Addresses() + d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet() return } @@ -73,6 +77,17 @@ func (d *dialer) SetPeerAddrs(peerId string, addrs []string) { d.peerAddrs[peerId] = addrs } +func (d *dialer) getPeerAddrs(peerId string) ([]string, error) { + if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok { + return addrs, nil + } + addrs, ok := d.peerAddrs[peerId] + if !ok || len(addrs) == 0 { + return nil, ErrAddrsNotFound + } + return addrs, nil +} + func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) { var ctxCancel context.CancelFunc ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10) @@ -80,17 +95,18 @@ func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err erro d.mu.RLock() defer d.mu.RUnlock() - addrs, ok := d.peerAddrs[peerId] - if !ok || len(addrs) == 0 { - return nil, ErrArrdsNotFound + addrs, err := d.getPeerAddrs(peerId) + if err != nil { + return } + var ( conn drpc.Conn sc sec.SecureConn ) log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) for _, addr := range addrs { - conn, sc, err = d.handshake(ctx, addr) + conn, sc, err = d.handshake(ctx, addr, peerId) if err != nil { log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) } else { @@ -103,7 +119,7 @@ func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err erro return peer.NewPeer(sc, conn), nil } -func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc sec.SecureConn, err error) { +func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) { st := time.Now() // TODO: move dial timeout to config tcpConn, err := net.DialTimeout("tcp", addr, time.Second*3) @@ -116,6 +132,9 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc if err != nil { return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) } + if peerId != sc.RemotePeer().String() { + return nil, nil, ErrPeerIdIsUnexpected + } log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr)) conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{ Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)},