From 46de85a7f824379972c01f87c443784c154bedfc Mon Sep 17 00:00:00 2001 From: mcrakhman Date: Tue, 25 Apr 2023 20:55:00 +0200 Subject: [PATCH] Check peerId when dialling --- net/dialer/dialer.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go index 00cdb996..4d85005b 100644 --- a/net/dialer/dialer.go +++ b/net/dialer/dialer.go @@ -24,7 +24,10 @@ import ( const CName = "common.net.dialer" -var ErrArrdsNotFound = errors.New("addrs for peer not found") +var ( + ErrAddrsNotFound = errors.New("addrs for peer not found") + ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id") +) var log = logger.NewNamed(CName) @@ -80,7 +83,7 @@ func (d *dialer) getPeerAddrs(peerId string) ([]string, error) { } addrs, ok := d.peerAddrs[peerId] if !ok || len(addrs) == 0 { - return nil, ErrArrdsNotFound + return nil, ErrAddrsNotFound } return addrs, nil } @@ -103,7 +106,7 @@ func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err erro ) log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) for _, addr := range addrs { - conn, sc, err = d.handshake(ctx, addr) + conn, sc, err = d.handshake(ctx, addr, peerId) if err != nil { log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) } else { @@ -116,7 +119,7 @@ func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err erro return peer.NewPeer(sc, conn), nil } -func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc sec.SecureConn, err error) { +func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) { st := time.Now() // TODO: move dial timeout to config tcpConn, err := net.DialTimeout("tcp", addr, time.Second*3) @@ -129,6 +132,9 @@ func (d *dialer) handshake(ctx context.Context, addr string) (conn drpc.Conn, sc if err != nil { return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) } + if peerId != sc.RemotePeer().String() { + return nil, nil, ErrPeerIdIsUnexpected + } log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr)) conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{ Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)},