use net.Conn for a proto handshake

This commit is contained in:
Sergey Cherepanov 2023-05-31 19:24:23 +02:00
parent c43ac9eb84
commit 00c582e157
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
2 changed files with 13 additions and 12 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"io"
"net"
"sync"
)
@ -73,7 +74,7 @@ func newHandshake() *handshake {
}
type handshake struct {
conn sec.SecureConn
conn net.Conn
remoteCred *handshakeproto.Credentials
remoteProto *handshakeproto.Proto
remoteAck *handshakeproto.Ack

View File

@ -3,15 +3,15 @@ package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices"
"net"
)
type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType
}
func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) {
func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
if ctx == nil {
ctx = context.Background()
}
@ -19,20 +19,20 @@ func OutgoingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt handshake
done := make(chan struct{})
go func() {
defer close(done)
err = outgoingProtoHandshake(h, sc, pt)
err = outgoingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return ctx.Err()
}
}
func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.ProtoType) (err error) {
func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
defer h.release()
h.conn = sc
h.conn = conn
localProto := &handshakeproto.Proto{
Proto: pt,
}
@ -54,7 +54,7 @@ func outgoingProtoHandshake(h *handshake, sc sec.SecureConn, pt handshakeproto.P
return HandshakeError{e: msg.ack.Error}
}
func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
if ctx == nil {
ctx = context.Background()
}
@ -62,20 +62,20 @@ func IncomingProtoHandshake(ctx context.Context, sc sec.SecureConn, pt ProtoChec
done := make(chan struct{})
go func() {
defer close(done)
protoType, err = incomingProtoHandshake(h, sc, pt)
protoType, err = incomingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
_ = conn.Close()
return 0, ctx.Err()
}
}
func incomingProtoHandshake(h *handshake, sc sec.SecureConn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
defer h.release()
h.conn = sc
h.conn = conn
msg, err := h.readMsg(msgTypeProto)
if err != nil {