From 8770da4abf1adae8a9f92c031422a95e7a2c8bed Mon Sep 17 00:00:00 2001 From: Sergey Cherepanov Date: Wed, 28 Jun 2023 10:11:21 +0200 Subject: [PATCH] fix race in proto handshake --- net/secureservice/handshake/proto.go | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go index 45e95ab5..6cd82110 100644 --- a/net/secureservice/handshake/proto.go +++ b/net/secureservice/handshake/proto.go @@ -11,19 +11,20 @@ type ProtoChecker struct { AllowedProtoTypes []handshakeproto.ProtoType } -func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { +func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) error { if ctx == nil { ctx = context.Background() } h := newHandshake() done := make(chan struct{}) + var err error go func() { defer close(done) err = outgoingProtoHandshake(h, conn, pt) }() select { case <-done: - return + return err case <-ctx.Done(): _ = conn.Close() return ctx.Err() @@ -54,19 +55,23 @@ func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.Proto return HandshakeError{e: msg.ack.Error} } -func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { +func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (handshakeproto.ProtoType, error) { if ctx == nil { ctx = context.Background() } h := newHandshake() done := make(chan struct{}) + var ( + protoType handshakeproto.ProtoType + err error + ) go func() { defer close(done) protoType, err = incomingProtoHandshake(h, conn, pt) }() select { case <-done: - return + return protoType, err case <-ctx.Done(): _ = conn.Close() return 0, ctx.Err()