use net.Conn for a proto handshake

This commit is contained in:
Sergey Cherepanov 2023-05-31 19:24:23 +02:00
parent c43ac9eb84
commit 00c582e157
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
2 changed files with 13 additions and 12 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/libp2p/go-libp2p/core/sec" "github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"io" "io"
"net"
"sync" "sync"
) )
@ -73,7 +74,7 @@ func newHandshake() *handshake {
} }
type handshake struct { type handshake struct {
conn sec.SecureConn conn net.Conn
remoteCred *handshakeproto.Credentials remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack remoteAck *handshakeproto.Ack

View File

@ -3,15 +3,15 @@ package handshake
import ( import (
"context" "context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"net"
) )
type ProtoChecker struct { type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType AllowedProtoTypes []handshakeproto.ProtoType
} }
func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@ -19,20 +19,20 @@ func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshake
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
err = outgoingProtoHandshake(h, sc, pt) err = outgoingProtoHandshake(h, conn, pt)
}() }()
select { select {
case <-done: case <-done:
return return
case <-ctx.Done(): case <-ctx.Done():
_ = sc.Close() _ = conn.Close()
return ctx.Err() return ctx.Err()
} }
} }
func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = conn
localProto := &handshakeproto.Proto{ localProto := &handshakeproto.Proto{
Proto: pt, Proto: pt,
} }
@ -54,7 +54,7 @@ func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.P
return HandshakeError{e: msg.ack.Error} return HandshakeError{e: msg.ack.Error}
} }
func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
@ -62,20 +62,20 @@ func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChec
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
defer close(done) defer close(done)
protoType, err = incomingProtoHandshake(h, sc, pt) protoType, err = incomingProtoHandshake(h, conn, pt)
}() }()
select { select {
case <-done: case <-done:
return return
case <-ctx.Done(): case <-ctx.Done():
_ = sc.Close() _ = conn.Close()
return 0, ctx.Err() return 0, ctx.Err()
} }
} }
func incomingProtoHandshake(h *handshake, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
defer h.release() defer h.release()
h.conn = sc h.conn = conn
msg, err := h.readMsg(msgTypeProto) msg, err := h.readMsg(msgTypeProto)
if err != nil { if err != nil {