diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 39e6656f..abbafeb5 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -7,6 +7,7 @@ import ( "github.com/libp2p/go-libp2p/core/sec" "golang.org/x/exp/slices" "io" + "net" "sync" ) @@ -73,7 +74,7 @@ func newHandshake() *handshake { } type handshake struct { - conn sec.SecureConn + conn net.Conn remoteCred *handshakeproto.Credentials remoteProto *handshakeproto.Proto remoteAck *handshakeproto.Ack diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go index 1e133069..45e95ab5 100644 --- a/net/secureservice/handshake/proto.go +++ b/net/secureservice/handshake/proto.go @@ -3,15 +3,15 @@ package handshake import ( "context" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" - "github.com/libp2p/go-libp2p/core/sec" "golang.org/x/exp/slices" + "net" ) type ProtoChecker struct { AllowedProtoTypes []handshakeproto.ProtoType } -func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { +func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { if ctx == nil { ctx = context.Background() } @@ -19,20 +19,20 @@ func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshake done := make(chan struct{}) go func() { defer close(done) - err = outgoingProtoHandshake(h, sc, pt) + err = outgoingProtoHandshake(h, conn, pt) }() select { case <-done: return case <-ctx.Done(): - _ = sc.Close() + _ = conn.Close() return ctx.Err() } } -func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) { +func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) { defer h.release() - h.conn = sc + h.conn = conn localProto := &handshakeproto.Proto{ Proto: pt, } @@ -54,7 +54,7 @@ func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.P return HandshakeError{e: msg.ack.Error} } -func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { +func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { if ctx == nil { ctx = context.Background() } @@ -62,20 +62,20 @@ func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChec done := make(chan struct{}) go func() { defer close(done) - protoType, err = incomingProtoHandshake(h, sc, pt) + protoType, err = incomingProtoHandshake(h, conn, pt) }() select { case <-done: return case <-ctx.Done(): - _ = sc.Close() + _ = conn.Close() return 0, ctx.Err() } } -func incomingProtoHandshake(h *handshake, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { +func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { defer h.release() - h.conn = sc + h.conn = conn msg, err := h.readMsg(msgTypeProto) if err != nil {