refactor handshake err

This commit is contained in:
Sergey Cherepanov 2023-05-22 19:17:20 +02:00
parent cfca99cf19
commit 7a778a5c9a
No known key found for this signature in database
GPG Key ID: 87F8EDE8FBDF637C
4 changed files with 12 additions and 29 deletions

View File

@ -19,13 +19,13 @@ const (
) )
type HandshakeError struct { type HandshakeError struct {
err error Err error
e handshakeproto.Error e handshakeproto.Error
} }
func (he HandshakeError) Error() string { func (he HandshakeError) Error() string {
if he.err != nil { if he.Err != nil {
return he.err.Error() return he.Err.Error()
} }
return he.e.String() return he.e.String()
} }
@ -34,7 +34,7 @@ var (
ErrUnexpectedPayload = HandshakeError{e: handshakeproto.Error_UnexpectedPayload} ErrUnexpectedPayload = HandshakeError{e: handshakeproto.Error_UnexpectedPayload}
ErrDeadlineExceeded = HandshakeError{e: handshakeproto.Error_DeadlineExceeded} ErrDeadlineExceeded = HandshakeError{e: handshakeproto.Error_DeadlineExceeded}
ErrInvalidCredentials = HandshakeError{e: handshakeproto.Error_InvalidCredentials} ErrInvalidCredentials = HandshakeError{e: handshakeproto.Error_InvalidCredentials}
ErrPeerDeclinedCredentials = HandshakeError{err: errors.New("remote peer declined the credentials")} ErrPeerDeclinedCredentials = HandshakeError{Err: errors.New("remote peer declined the credentials")}
ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected}

View File

@ -336,7 +336,7 @@ func TestIncomingHandshake(t *testing.T) {
require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
require.EqualError(t, res.err, ErrIncompatibleVersion.Error()) assert.Equal(t, res.err, ErrIncompatibleVersion)
}) })
t.Run("write cred instead ack", func(t *testing.T) { t.Run("write cred instead ack", func(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)

View File

@ -16,19 +16,6 @@ import (
"net" "net"
) )
type HandshakeError struct {
remoteAddr string
err error
}
func (he HandshakeError) RemoteAddr() string {
return he.remoteAddr
}
func (he HandshakeError) Error() string {
return he.err.Error()
}
const CName = "common.net.secure" const CName = "common.net.secure"
var log = logger.NewNamed(CName) var log = logger.NewNamed(CName)
@ -91,18 +78,14 @@ func (s *secureService) Name() (name string) {
func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) {
sc, err = s.p2pTr.SecureInbound(ctx, conn, "") sc, err = s.p2pTr.SecureInbound(ctx, conn, "")
if err != nil { if err != nil {
return nil, nil, HandshakeError{ return nil, nil, handshake.HandshakeError{
remoteAddr: conn.RemoteAddr().String(), Err: err,
err: err,
} }
} }
identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker) identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker)
if err != nil { if err != nil {
return nil, nil, HandshakeError{ return nil, nil, err
remoteAddr: conn.RemoteAddr().String(),
err: err,
}
} }
cctx = context.Background() cctx = context.Background()
cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String())
@ -113,7 +96,7 @@ func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx
func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) { func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) {
sc, err := s.p2pTr.SecureOutbound(ctx, conn, "") sc, err := s.p2pTr.SecureOutbound(ctx, conn, "")
if err != nil { if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} return nil, handshake.HandshakeError{Err: err}
} }
peerId := sc.RemotePeer().String() peerId := sc.RemotePeer().String()
confTypes := s.nodeconf.NodeTypes(peerId) confTypes := s.nodeconf.NodeTypes(peerId)
@ -126,7 +109,7 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.
// ignore identity for outgoing connection because we don't need it at this moment // ignore identity for outgoing connection because we don't need it at this moment
_, err = handshake.OutgoingHandshake(ctx, sc, checker) _, err = handshake.OutgoingHandshake(ctx, sc, checker)
if err != nil { if err != nil {
return nil, HandshakeError{err: err, remoteAddr: conn.RemoteAddr().String()} return nil, err
} }
return sc, nil return sc, nil
} }

View File

@ -73,9 +73,9 @@ func TestHandshakeIncompatibleVersion(t *testing.T) {
fxC := newFixture(t, nc, nc.GetAccountService(1), 1) fxC := newFixture(t, nc, nc.GetAccountService(1), 1)
defer fxC.Finish(t) defer fxC.Finish(t)
_, err := fxC.SecureOutbound(ctx, cc) _, err := fxC.SecureOutbound(ctx, cc)
require.EqualError(t, err, handshake.ErrIncompatibleVersion.Error()) require.Equal(t, handshake.ErrIncompatibleVersion, err)
res := <-resCh res := <-resCh
require.EqualError(t, res.err, handshake.ErrIncompatibleVersion.Error()) require.Equal(t, handshake.ErrIncompatibleVersion, res.err)
} }
func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Service, protoVersion uint32) *fixture { func newFixture(t *testing.T, nc *testnodeconf.Config, acc accountservice.Service, protoVersion uint32) *fixture {