fix race in proto handshake

This commit is contained in:
Sergey Cherepanov 2023-06-28 10:11:21 +02:00
parent 5a02d1c338
commit 8770da4abf
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C

View File

@ -11,19 +11,20 @@ type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType AllowedProtoTypes []handshakeproto.ProtoType
} }
func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) error {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var err error
go func() { go func() {
defer close(done) defer close(done)
err = outgoingProtoHandshake(h, conn, pt) err = outgoingProtoHandshake(h, conn, pt)
}() }()
select { select {
case <-done: case <-done:
return return err
case <-ctx.Done(): case <-ctx.Done():
_ = conn.Close() _ = conn.Close()
return ctx.Err() return ctx.Err()
@ -54,19 +55,23 @@ func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.Proto
return HandshakeError{e: msg.ack.Error} return HandshakeError{e: msg.ack.Error}
} }
func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (handshakeproto.ProtoType, error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
h := newHandshake() h := newHandshake()
done := make(chan struct{}) done := make(chan struct{})
var (
protoType handshakeproto.ProtoType
err error
)
go func() { go func() {
defer close(done) defer close(done)
protoType, err = incomingProtoHandshake(h, conn, pt) protoType, err = incomingProtoHandshake(h, conn, pt)
}() }()
select { select {
case <-done: case <-done:
return return protoType, err
case <-ctx.Done(): case <-ctx.Done():
_ = conn.Close() _ = conn.Close()
return 0, ctx.Err() return 0, ctx.Err()