diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 29d9e199..141d433c 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -25,6 +25,7 @@ func (sr *stream) write(msg drpc.Message) (err error) { var queueId string if qId, ok := msg.(MessageQueueId); ok { queueId = qId.MessageQueueId() + msg = qId.DrpcMessage() } return sr.queue.Add(sr.stream.Context(), queueId, msg) } diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 1230ed88..442a5097 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -27,6 +27,7 @@ type PeerGetter func(ctx context.Context) (peers []peer.Peer, err error) type MessageQueueId interface { MessageQueueId() string + DrpcMessage() drpc.Message } // StreamPool keeps and read streams @@ -363,3 +364,21 @@ func removeStream(m map[string][]uint32, key string, streamId uint32) { m[key] = streamIds } } + +// WithQueueId wraps the message and adds queueId +func WithQueueId(msg drpc.Message, queueId string) drpc.Message { + return &messageWithQueueId{queueId: queueId, Message: msg} +} + +type messageWithQueueId struct { + drpc.Message + queueId string +} + +func (m messageWithQueueId) MessageQueueId() string { + return m.queueId +} + +func (m messageWithQueueId) DrpcMessage() drpc.Message { + return m.Message +} diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 7f1a7fe4..8facd481 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -39,7 +39,7 @@ func TestStreamPool_AddStream(t *testing.T) { require.NoError(t, fx.AddStream(s2, "space2", "common")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) - require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2")) + require.NoError(t, fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common")) var serverResults []string