diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go index 45e95ab5..6cd82110 100644 --- a/net/secureservice/handshake/proto.go +++ b/net/secureservice/handshake/proto.go @@ -11,19 +11,20 @@ type ProtoChecker struct { AllowedProtoTypes []handshakeproto.ProtoType } -func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { +func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) error { if ctx == nil { ctx = context.Background() } h := newHandshake() done := make(chan struct{}) + var err error go func() { defer close(done) err = outgoingProtoHandshake(h, conn, pt) }() select { case <-done: - return + return err case <-ctx.Done(): _ = conn.Close() return ctx.Err() @@ -54,19 +55,23 @@ func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.Proto return HandshakeError{e: msg.ack.Error} } -func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { +func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (handshakeproto.ProtoType, error) { if ctx == nil { ctx = context.Background() } h := newHandshake() done := make(chan struct{}) + var ( + protoType handshakeproto.ProtoType + err error + ) go func() { defer close(done) protoType, err = incomingProtoHandshake(h, conn, pt) }() select { case <-done: - return + return protoType, err case <-ctx.Done(): _ = conn.Close() return 0, ctx.Err()