diff --git a/commonspace/headsync/diffsyncer_test.go b/commonspace/headsync/diffsyncer_test.go index e6558b61..bdbcab58 100644 --- a/commonspace/headsync/diffsyncer_test.go +++ b/commonspace/headsync/diffsyncer_test.go @@ -51,6 +51,10 @@ func (p pushSpaceRequestMatcher) String() string { type mockPeer struct{} +func (m mockPeer) Addr() string { + return "" +} + func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, m.Close() } diff --git a/metric/log_test.go b/metric/log_test.go new file mode 100644 index 00000000..0bad30a0 --- /dev/null +++ b/metric/log_test.go @@ -0,0 +1 @@ +package metric diff --git a/net/peer/context.go b/net/peer/context.go index d759ad01..b753e7d2 100644 --- a/net/peer/context.go +++ b/net/peer/context.go @@ -13,6 +13,7 @@ type contextKey uint const ( contextKeyPeerId contextKey = iota contextKeyIdentity + contextKeyPeerAddr ) var ( @@ -36,6 +37,19 @@ func CtxWithPeerId(ctx context.Context, peerId string) context.Context { return context.WithValue(ctx, contextKeyPeerId, peerId) } +// CtxPeerAddr returns peer address +func CtxPeerAddr(ctx context.Context) string { + if p, ok := ctx.Value(contextKeyPeerAddr).(string); ok { + return p + } + return "" +} + +// CtxWithPeerAddr sets peer address to the context +func CtxWithPeerAddr(ctx context.Context, addr string) context.Context { + return context.WithValue(ctx, contextKeyPeerAddr, addr) +} + // CtxIdentity returns identity from context func CtxIdentity(ctx context.Context) ([]byte, error) { if identity, ok := ctx.Value(contextKeyIdentity).([]byte); ok { diff --git a/net/peer/peer.go b/net/peer/peer.go index 92137930..95ac2b12 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -26,6 +26,7 @@ type Peer interface { Id() string LastUsage() time.Time UpdateLastUsage() + Addr() string TryClose(objectTTL time.Duration) (res bool, err error) drpc.Conn } @@ -86,6 +87,13 @@ func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, p.Close() } +func (p *peer) Addr() string { + if p.sc != nil { + return p.sc.RemoteAddr().String() + } + return "" +} + func (p *peer) Close() (err error) { log.Debug("peer close", zap.String("peerId", p.id)) return p.Conn.Close() diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index ce3876e0..262a59f6 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -184,6 +184,10 @@ type testPeer struct { closed chan struct{} } +func (t *testPeer) Addr() string { + return "" +} + func (t *testPeer) Id() string { return t.id } diff --git a/net/rpc/rpctest/pool.go b/net/rpc/rpctest/pool.go index 630cbb6a..3a9935df 100644 --- a/net/rpc/rpctest/pool.go +++ b/net/rpc/rpctest/pool.go @@ -103,6 +103,10 @@ type testPeer struct { drpc.Conn } +func (t testPeer) Addr() string { + return "" +} + func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, t.Close() } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go index 85d23bf3..91fca952 100644 --- a/net/rpc/server/baseserver.go +++ b/net/rpc/server/baseserver.go @@ -2,6 +2,7 @@ package server import ( "context" + "github.com/anytypeio/any-sync/net/peer" "github.com/anytypeio/any-sync/net/secureservice" "github.com/libp2p/go-libp2p/core/sec" "github.com/zeebo/errs" @@ -98,8 +99,11 @@ func (s *BaseDrpcServer) serveConn(conn net.Conn) { l.Info("handshake error", zap.Error(err)) return } + if sc, ok := conn.(sec.SecureConn); ok { + ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String()) + } } - + ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String()) l.Debug("connection opened") if err := s.drpcServer.ServeOne(ctx, conn); err != nil { if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 065f322e..f2e092c4 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -10,6 +10,7 @@ import ( type stream struct { peerId string + peerCtx context.Context stream drpc.Stream pool *streamPool streamId uint32 @@ -36,7 +37,7 @@ func (sr *stream) readLoop() error { sr.l.Info("msg receive error", zap.Error(err)) return err } - ctx := streamCtx(context.Background(), sr.streamId, sr.peerId) + ctx := streamCtx(sr.peerCtx, sr.streamId, sr.peerId) ctx = logger.CtxWithFields(ctx, zap.String("peerId", sr.peerId)) if err := sr.pool.handler.HandleMessage(ctx, sr.peerId, msg); err != nil { sr.l.Info("msg handle error", zap.Error(err)) diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 84108371..5295cea4 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -27,9 +27,9 @@ type PeerGetter func(ctx context.Context) (peers []peer.Peer, err error) // StreamPool keeps and read streams type StreamPool interface { // AddStream adds new outgoing stream into the pool - AddStream(peerId string, stream drpc.Stream, tags ...string) + AddStream(stream drpc.Stream, tags ...string) (err error) // ReadStream adds new incoming stream and synchronously read it - ReadStream(peerId string, stream drpc.Stream, tags ...string) (err error) + ReadStream(stream drpc.Stream, tags ...string) (err error) // Send sends a message to given peers. A stream will be opened if it is not cached before. Works async. Send(ctx context.Context, msg drpc.Message, target PeerGetter) (err error) // SendById sends a message to given peerIds. Works only if stream exists @@ -63,16 +63,23 @@ type openingProcess struct { err error } -func (s *streamPool) ReadStream(peerId string, drpcStream drpc.Stream, tags ...string) error { - st := s.addStream(peerId, drpcStream, tags...) +func (s *streamPool) ReadStream(drpcStream drpc.Stream, tags ...string) error { + st, err := s.addStream(drpcStream, tags...) + if err != nil { + return err + } return st.readLoop() } -func (s *streamPool) AddStream(peerId string, drpcStream drpc.Stream, tags ...string) { - st := s.addStream(peerId, drpcStream, tags...) +func (s *streamPool) AddStream(drpcStream drpc.Stream, tags ...string) error { + st, err := s.addStream(drpcStream, tags...) + if err != nil { + return err + } go func() { _ = st.readLoop() }() + return nil } func (s *streamPool) Streams(tags ...string) (streams []drpc.Stream) { @@ -86,13 +93,19 @@ func (s *streamPool) Streams(tags ...string) (streams []drpc.Stream) { return } -func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...string) *stream { +func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, error) { + ctx := drpcStream.Context() + peerId, err := peer.CtxPeerId(ctx) + if err != nil { + return nil, err + } s.mu.Lock() defer s.mu.Unlock() s.lastStreamId++ streamId := s.lastStreamId st := &stream{ peerId: peerId, + peerCtx: ctx, stream: drpcStream, pool: s, streamId: streamId, @@ -104,7 +117,7 @@ func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...st for _, tag := range tags { s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], streamId) } - return st + return st, nil } func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peerGetter PeerGetter) (err error) { @@ -241,7 +254,10 @@ func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProces op.err = err return } - s.AddStream(p.Id(), st, tags...) + if err = s.AddStream(st, tags...); err != nil { + op.err = nil + return + } }() return op } diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 2ab15d26..57b2698d 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -21,6 +21,7 @@ var ctx = context.Background() func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) { p, err := fx.tp.Dial(ctx, peerId) require.NoError(t, err) + ctx = peer.CtxWithPeerId(ctx, peerId) s, err := testservice.NewDRPCTestClient(p).TestStream(ctx) require.NoError(t, err) return s, p @@ -33,9 +34,9 @@ func TestStreamPool_AddStream(t *testing.T) { defer fx.Finish(t) s1, _ := newClientStream(t, fx, "p1") - fx.AddStream("p1", s1, "space1", "common") + require.NoError(t, fx.AddStream(s1, "space1", "common")) s2, _ := newClientStream(t, fx, "p2") - fx.AddStream("p2", s2, "space2", "common") + require.NoError(t, fx.AddStream(s2, "space2", "common")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2")) @@ -64,7 +65,7 @@ func TestStreamPool_AddStream(t *testing.T) { s1, p1 := newClientStream(t, fx, "p1") defer s1.Close() - fx.AddStream("p1", s1, "space1", "common") + require.NoError(t, fx.AddStream(s1, "space1", "common")) require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "test"}, func(ctx context.Context) (peers []peer.Peer, err error) { return []peer.Peer{p1}, nil @@ -159,7 +160,7 @@ func TestStreamPool_SendById(t *testing.T) { s1, _ := newClientStream(t, fx, "p1") defer s1.Close() - fx.AddStream("p1", s1, "space1", "common") + require.NoError(t, fx.AddStream(s1, "space1", "common")) require.NoError(t, fx.SendById(ctx, &testservice.StreamMessage{ReqData: "test"}, "p1")) var msg *testservice.StreamMessage @@ -177,11 +178,11 @@ func TestStreamPool_Tags(t *testing.T) { s1, _ := newClientStream(t, fx, "p1") defer s1.Close() - fx.AddStream("p1", s1, "t1") + require.NoError(t, fx.AddStream(s1, "t1")) s2, _ := newClientStream(t, fx, "p2") defer s1.Close() - fx.AddStream("p2", s2, "t2") + require.NoError(t, fx.AddStream(s2, "t2")) err := fx.AddTagsCtx(streamCtx(ctx, 1, "p1"), "t3", "t3") require.NoError(t, err)