diff --git a/commonspace/object/tree/synctree/queuedclient.go b/commonspace/object/tree/synctree/queuedclient.go deleted file mode 100644 index 49ea3922..00000000 --- a/commonspace/object/tree/synctree/queuedclient.go +++ /dev/null @@ -1,37 +0,0 @@ -package synctree - -import ( - "context" - "github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto" - "github.com/anytypeio/any-sync/commonspace/objectsync" -) - -type queuedClient struct { - SyncClient - queue objectsync.ActionQueue -} - -func newQueuedClient(client SyncClient, queue objectsync.ActionQueue) SyncClient { - return &queuedClient{ - SyncClient: client, - queue: queue, - } -} - -func (q *queuedClient) Broadcast(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) { - return q.queue.Send(func() error { - return q.SyncClient.Broadcast(ctx, message) - }) -} - -func (q *queuedClient) SendWithReply(ctx context.Context, peerId string, message *treechangeproto.TreeSyncMessage, replyId string) (err error) { - return q.queue.Send(func() error { - return q.SyncClient.SendWithReply(ctx, peerId, message, replyId) - }) -} - -func (q *queuedClient) BroadcastAsyncOrSendResponsible(ctx context.Context, message *treechangeproto.TreeSyncMessage) (err error) { - return q.queue.Send(func() error { - return q.SyncClient.BroadcastAsyncOrSendResponsible(ctx, message) - }) -} diff --git a/commonspace/object/tree/synctree/synctree.go b/commonspace/object/tree/synctree/synctree.go index 2b7e27f6..b517397f 100644 --- a/commonspace/object/tree/synctree/synctree.go +++ b/commonspace/object/tree/synctree/synctree.go @@ -52,7 +52,7 @@ type syncTree struct { var log = logger.NewNamed("commonspace.synctree").Sugar() var buildObjectTree = objecttree.BuildObjectTree -var createSyncClient = newWrappedSyncClient +var createSyncClient = newSyncClient type BuildDeps struct { SpaceId string @@ -68,15 +68,6 @@ type BuildDeps struct { WaitTreeRemoteSync bool } -func newWrappedSyncClient( - spaceId string, - factory RequestFactory, - objectSync objectsync.ObjectSync, - configuration nodeconf.Configuration) SyncClient { - syncClient := newSyncClient(spaceId, objectSync.MessagePool(), factory, configuration) - return newQueuedClient(syncClient, objectSync.ActionQueue()) -} - func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t SyncTree, err error) { getTreeRemote := func() (msg *treechangeproto.TreeSyncMessage, err error) { peerId, err := peer.CtxPeerId(ctx) @@ -182,8 +173,8 @@ func buildSyncTree(ctx context.Context, isFirstBuild bool, deps BuildDeps) (t Sy } syncClient := createSyncClient( deps.SpaceId, + deps.ObjectSync.MessagePool(), sharedFactory, - deps.ObjectSync, deps.Configuration) syncTree := &syncTree{ ObjectTree: objTree, diff --git a/commonspace/object/tree/synctree/synctree_test.go b/commonspace/object/tree/synctree/synctree_test.go index 09d27d25..8854a702 100644 --- a/commonspace/object/tree/synctree/synctree_test.go +++ b/commonspace/object/tree/synctree/synctree_test.go @@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) { updateListenerMock.EXPECT().Update(tr) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil) + syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil) res, err := tr.AddRawChanges(ctx, payload) require.NoError(t, err) require.Equal(t, expectedRes, res) @@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) { updateListenerMock.EXPECT().Rebuild(tr) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil) + syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil) res, err := tr.AddRawChanges(ctx, payload) require.NoError(t, err) require.Equal(t, expectedRes, res) @@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) { Return(expectedRes, nil) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().BroadcastAsync(gomock.Eq(headUpdate)).Return(nil) + syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)).Return(nil) res, err := tr.AddContent(ctx, content) require.NoError(t, err) require.Equal(t, expectedRes, res) diff --git a/commonspace/object/tree/synctree/synctreehandler_test.go b/commonspace/object/tree/synctree/synctreehandler_test.go index 061adc0a..9b2e1120 100644 --- a/commonspace/object/tree/synctree/synctreehandler_test.go +++ b/commonspace/object/tree/synctree/synctreehandler_test.go @@ -38,7 +38,7 @@ type syncHandlerFixture struct { ctrl *gomock.Controller syncClientMock *mock_synctree.MockSyncClient objectTreeMock *testObjTreeMock - receiveQueueMock *mock_synctree.MockReceiveQueue + receiveQueueMock ReceiveQueue syncHandler *syncTreeHandler } @@ -47,19 +47,19 @@ func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture { ctrl := gomock.NewController(t) syncClientMock := mock_synctree.NewMockSyncClient(ctrl) objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) - receiveQueueMock := mock_synctree.NewMockReceiveQueue(ctrl) + receiveQueue := newReceiveQueue(5) syncHandler := &syncTreeHandler{ objTree: objectTreeMock, syncClient: syncClientMock, - queue: receiveQueueMock, + queue: receiveQueue, syncStatus: syncstatus.NewNoOpSyncStatus(), } return &syncHandlerFixture{ ctrl: ctrl, syncClientMock: syncClientMock, objectTreeMock: objectTreeMock, - receiveQueueMock: receiveQueueMock, + receiveQueueMock: receiveQueue, syncHandler: syncHandler, } } @@ -84,10 +84,7 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") - - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2) @@ -101,7 +98,6 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2", "h1"}) fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -118,10 +114,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fullRequest := &treechangeproto.TreeSyncMessage{} - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() @@ -136,9 +130,8 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { fx.syncClientMock.EXPECT(). CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). Return(fullRequest, nil) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq("")) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq("")) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -155,14 +148,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -179,19 +169,16 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fullRequest := &treechangeproto.TreeSyncMessage{} - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() fx.syncClientMock.EXPECT(). CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). Return(fullRequest, nil) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq("")) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullRequest), gomock.Eq("")) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -208,14 +195,11 @@ func TestSyncHandler_HandleHeadUpdate(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -237,10 +221,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fullResponse := &treechangeproto.TreeSyncMessage{} - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Header().Return(nil) @@ -255,9 +237,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { fx.syncClientMock.EXPECT(). CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq("")) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq("")) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -274,10 +255,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fullResponse := &treechangeproto.TreeSyncMessage{} - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) fx.objectTreeMock.EXPECT(). Id().AnyTimes().Return(treeId) @@ -288,9 +267,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { fx.syncClientMock.EXPECT(). CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq("")) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq("")) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -307,10 +285,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId) fullResponse := &treechangeproto.TreeSyncMessage{} - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil) fx.objectTreeMock.EXPECT(). Id().AnyTimes().Return(treeId) @@ -318,9 +294,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { fx.syncClientMock.EXPECT(). CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId)) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(fullResponse), gomock.Eq(replyId)) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -337,9 +312,7 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, "") - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), "").Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, "", nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, "") fx.objectTreeMock.EXPECT(). Id().AnyTimes().Return(treeId) @@ -356,9 +329,8 @@ func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, })). Return(objecttree.AddResult{}, fmt.Errorf("")) - fx.syncClientMock.EXPECT().SendAsync(gomock.Eq(senderId), gomock.Any(), gomock.Eq("")) + fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Any(), gomock.Eq("")) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.Error(t, err) }) @@ -381,9 +353,7 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId) - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT(). @@ -399,7 +369,6 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) { })). Return(objecttree.AddResult{}, nil) - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) @@ -417,16 +386,13 @@ func TestSyncHandler_HandleFullSyncResponse(t *testing.T) { SnapshotPath: []string{"h1"}, } treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) - objectMsg, _ := marshallTreeMessage(treeMsg, treeId, replyId) - fx.receiveQueueMock.EXPECT().AddMessage(senderId, gomock.Eq(treeMsg), replyId).Return(false) - fx.receiveQueueMock.EXPECT().GetMessage(senderId).Return(treeMsg, replyId, nil) + objectMsg, _ := marshallTreeMessage(treeMsg, "spaceId", treeId, replyId) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT(). Heads(). Return([]string{"h1"}).AnyTimes() - fx.receiveQueueMock.EXPECT().ClearQueue(senderId) err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) require.NoError(t, err) }) diff --git a/commonspace/objectsync/actionqueue.go b/commonspace/objectsync/actionqueue.go deleted file mode 100644 index 65d5e3af..00000000 --- a/commonspace/objectsync/actionqueue.go +++ /dev/null @@ -1,78 +0,0 @@ -package objectsync - -import ( - "context" - "github.com/cheggaaa/mb/v3" - "go.uber.org/zap" -) - -type ActionFunc func() error - -type ActionQueue interface { - Send(action ActionFunc) (err error) - Run() - Close() -} - -type actionQueue struct { - batcher *mb.MB[ActionFunc] - maxReaders int - maxQueueLen int - readers chan struct{} -} - -func NewDefaultActionQueue() ActionQueue { - return NewActionQueue(10, 200) -} - -func NewActionQueue(maxReaders int, maxQueueLen int) ActionQueue { - return &actionQueue{ - batcher: mb.New[ActionFunc](maxQueueLen), - maxReaders: maxReaders, - maxQueueLen: maxQueueLen, - } -} - -func (q *actionQueue) Send(action ActionFunc) (err error) { - log.Debug("adding action to batcher") - err = q.batcher.TryAdd(action) - if err == nil { - return - } - log.With(zap.Error(err)).Debug("queue returned error") - actions := q.batcher.GetAll() - actions = append(actions[len(actions)/2:], action) - return q.batcher.Add(context.Background(), actions...) -} - -func (q *actionQueue) Run() { - log.Debug("running the queue") - q.readers = make(chan struct{}, q.maxReaders) - for i := 0; i < q.maxReaders; i++ { - go q.startReading() - } -} - -func (q *actionQueue) startReading() { - defer func() { - q.readers <- struct{}{} - }() - for { - action, err := q.batcher.WaitOne(context.Background()) - if err != nil { - return - } - err = action() - if err != nil { - log.With(zap.Error(err)).Debug("action errored out") - } - } -} - -func (q *actionQueue) Close() { - log.Debug("closing the queue") - q.batcher.Close() - for i := 0; i < q.maxReaders; i++ { - <-q.readers - } -} diff --git a/commonspace/objectsync/actionqueue_test.go b/commonspace/objectsync/actionqueue_test.go deleted file mode 100644 index eef4a952..00000000 --- a/commonspace/objectsync/actionqueue_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package objectsync - -import ( - "fmt" - "github.com/stretchr/testify/require" - "sync/atomic" - "testing" -) - -func TestActionQueue_Send(t *testing.T) { - maxReaders := 41 - maxLen := 93 - - queue := NewActionQueue(maxReaders, maxLen).(*actionQueue) - counter := atomic.Int32{} - expectedCounter := int32(maxReaders + (maxLen+1)/2 + 1) - blocker := make(chan struct{}, expectedCounter) - waiter := make(chan struct{}, expectedCounter) - increase := func() error { - counter.Add(1) - waiter <- struct{}{} - <-blocker - return nil - } - - queue.Run() - // sending maxReaders messages, so the goroutines will block on `blocker` channel - for i := 0; i < maxReaders; i++ { - queue.Send(increase) - } - // waiting until they all make progress - for i := 0; i < maxReaders; i++ { - <-waiter - } - fmt.Println(counter.Load()) - // check that queue is empty - require.Equal(t, queue.batcher.Len(), 0) - // making queue to overflow while readers are blocked - for i := 0; i < maxLen+1; i++ { - queue.Send(increase) - } - // check that queue was halved after overflow - require.Equal(t, (maxLen+1)/2+1, queue.batcher.Len()) - // unblocking maxReaders waiting + then we should also unblock the new readers to do a bit more readings - for i := 0; i < int(expectedCounter); i++ { - blocker <- struct{}{} - } - // waiting for all readers to finish adding - for i := 0; i < int(expectedCounter)-maxReaders; i++ { - <-waiter - } - queue.Close() - require.Equal(t, expectedCounter, counter.Load()) -} diff --git a/commonspace/objectsync/mock_objectsync/mock_objectsync.go b/commonspace/objectsync/mock_objectsync/mock_objectsync.go deleted file mode 100644 index f87f06f3..00000000 --- a/commonspace/objectsync/mock_objectsync/mock_objectsync.go +++ /dev/null @@ -1,73 +0,0 @@ -// Code generated by MockGen. DO NOT EDIT. -// Source: github.com/anytypeio/any-sync/commonspace/objectsync (interfaces: ActionQueue) - -// Package mock_objectsync is a generated GoMock package. -package mock_objectsync - -import ( - reflect "reflect" - - objectsync "github.com/anytypeio/any-sync/commonspace/objectsync" - gomock "github.com/golang/mock/gomock" -) - -// MockActionQueue is a mock of ActionQueue interface. -type MockActionQueue struct { - ctrl *gomock.Controller - recorder *MockActionQueueMockRecorder -} - -// MockActionQueueMockRecorder is the mock recorder for MockActionQueue. -type MockActionQueueMockRecorder struct { - mock *MockActionQueue -} - -// NewMockActionQueue creates a new mock instance. -func NewMockActionQueue(ctrl *gomock.Controller) *MockActionQueue { - mock := &MockActionQueue{ctrl: ctrl} - mock.recorder = &MockActionQueueMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockActionQueue) EXPECT() *MockActionQueueMockRecorder { - return m.recorder -} - -// Close mocks base method. -func (m *MockActionQueue) Close() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Close") -} - -// Close indicates an expected call of Close. -func (mr *MockActionQueueMockRecorder) Close() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockActionQueue)(nil).Close)) -} - -// Run mocks base method. -func (m *MockActionQueue) Run() { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Run") -} - -// Run indicates an expected call of Run. -func (mr *MockActionQueueMockRecorder) Run() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockActionQueue)(nil).Run)) -} - -// Send mocks base method. -func (m *MockActionQueue) Send(arg0 objectsync.ActionFunc) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Send", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Send indicates an expected call of Send. -func (mr *MockActionQueueMockRecorder) Send(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Send", reflect.TypeOf((*MockActionQueue)(nil).Send), arg0) -} diff --git a/commonspace/objectsync/msgpool.go b/commonspace/objectsync/msgpool.go index fc629e5e..0389b508 100644 --- a/commonspace/objectsync/msgpool.go +++ b/commonspace/objectsync/msgpool.go @@ -9,6 +9,7 @@ import ( "strings" "sync" "sync/atomic" + "time" ) type StreamManager interface { @@ -37,7 +38,6 @@ type messagePool struct { waiters map[string]responseWaiter waitersMx sync.Mutex counter atomic.Uint64 - queue ActionQueue } func newMessagePool(streamManager StreamManager, messageHandler MessageHandler) MessagePool { @@ -45,15 +45,17 @@ func newMessagePool(streamManager StreamManager, messageHandler MessageHandler) StreamManager: streamManager, messageHandler: messageHandler, waiters: make(map[string]responseWaiter), - queue: NewDefaultActionQueue(), } return s } func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, time.Second*10) + defer cancel() newCounter := s.counter.Add(1) msg.ReplyId = genReplyKey(peerId, msg.ObjectId, newCounter) - + log.Info("mpool sendSync", zap.String("replyId", msg.ReplyId)) s.waitersMx.Lock() waiter := responseWaiter{ ch: make(chan *spacesyncproto.ObjectSyncMessage, 1), @@ -81,19 +83,14 @@ func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyn func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { if msg.ReplyId != "" { + log.Info("mpool receive reply", zap.String("replyId", msg.ReplyId)) // we got reply, send it to waiter if s.stopWaiter(msg) { return } log.With(zap.String("replyId", msg.ReplyId)).Debug("reply id does not exist") - return } - return s.queue.Send(func() error { - if e := s.messageHandler(ctx, senderId, msg); e != nil { - log.Info("handle message error", zap.Error(e)) - } - return nil - }) + return s.messageHandler(ctx, senderId, msg) } func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool { diff --git a/commonspace/objectsync/objectsync.go b/commonspace/objectsync/objectsync.go index d9c9cc27..5a97cd1b 100644 --- a/commonspace/objectsync/objectsync.go +++ b/commonspace/objectsync/objectsync.go @@ -1,4 +1,3 @@ -//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anytypeio/any-sync/commonspace/objectsync ActionQueue package objectsync import ( @@ -18,58 +17,57 @@ type ObjectSync interface { ocache.ObjectLastUsage synchandler.SyncHandler MessagePool() MessagePool - ActionQueue() ActionQueue - Init(getter syncobjectgetter.SyncObjectGetter) + Init() Close() (err error) } type objectSync struct { spaceId string - streamPool MessagePool + messagePool MessagePool objectGetter syncobjectgetter.SyncObjectGetter - actionQueue ActionQueue syncCtx context.Context cancelSync context.CancelFunc } -func NewObjectSync(streamManager StreamManager, spaceId string) (objectSync ObjectSync) { - msgPool := newMessagePool(streamManager, func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { - return objectSync.HandleMessage(ctx, senderId, message) - }) +func NewObjectSync( + spaceId string, + streamManager StreamManager, + objectGetter syncobjectgetter.SyncObjectGetter) ObjectSync { syncCtx, cancel := context.WithCancel(context.Background()) - objectSync = newObjectSync( + os := newObjectSync( spaceId, - msgPool, + objectGetter, syncCtx, cancel) - return + msgPool := newMessagePool(streamManager, os.handleMessage) + os.messagePool = msgPool + return os } func newObjectSync( spaceId string, - streamPool MessagePool, + objectGetter syncobjectgetter.SyncObjectGetter, syncCtx context.Context, cancel context.CancelFunc, ) *objectSync { return &objectSync{ - streamPool: streamPool, - spaceId: spaceId, - syncCtx: syncCtx, - cancelSync: cancel, - actionQueue: NewDefaultActionQueue(), + objectGetter: objectGetter, + spaceId: spaceId, + syncCtx: syncCtx, + cancelSync: cancel, + //actionQueue: NewDefaultActionQueue(), } } -func (s *objectSync) Init(objectGetter syncobjectgetter.SyncObjectGetter) { - s.objectGetter = objectGetter - s.actionQueue.Run() +func (s *objectSync) Init() { + //s.actionQueue.Run() } func (s *objectSync) Close() (err error) { - s.actionQueue.Close() + //s.actionQueue.Close() s.cancelSync() return } @@ -80,7 +78,11 @@ func (s *objectSync) LastUsage() time.Time { } func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { - log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId)).Debug("handling message") + return s.messagePool.HandleMessage(ctx, senderId, message) +} + +func (s *objectSync) handleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { + log.With(zap.String("peerId", senderId), zap.String("objectId", message.ObjectId), zap.String("replyId", message.ReplyId)).Debug("handling message") obj, err := s.objectGetter.GetObject(ctx, message.ObjectId) if err != nil { return @@ -89,9 +91,5 @@ func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message } func (s *objectSync) MessagePool() MessagePool { - return s.streamPool -} - -func (s *objectSync) ActionQueue() ActionQueue { - return s.actionQueue + return s.messagePool } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go index 0f0138a2..4d8e023a 100644 --- a/net/rpc/server/baseserver.go +++ b/net/rpc/server/baseserver.go @@ -78,8 +78,8 @@ func (s *BaseDrpcServer) serve(ctx context.Context, lis secureservice.ContextLis } continue } - if _, ok := err.(secureservice.HandshakeError); ok { - l.Warn("listener handshake error", zap.Error(err)) + if herr, ok := err.(secureservice.HandshakeError); ok { + l.Warn("listener handshake error", zap.Error(herr), zap.String("remoteAddr", herr.RemoteAddr())) continue } l.Error("listener accept error", zap.Error(err)) diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index 3c9bc068..f55ad8da 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -12,7 +12,18 @@ import ( "net" ) -type HandshakeError error +type HandshakeError struct { + remoteAddr string + err error +} + +func (he HandshakeError) RemoteAddr() string { + return he.remoteAddr +} + +func (he HandshakeError) Error() string { + return he.err.Error() +} const CName = "common.net.secure" diff --git a/net/secureservice/tlslistener.go b/net/secureservice/tlslistener.go index ccf9da2d..867ced26 100644 --- a/net/secureservice/tlslistener.go +++ b/net/secureservice/tlslistener.go @@ -49,7 +49,10 @@ func (p *tlsListener) Accept(ctx context.Context) (context.Context, net.Conn, er func (p *tlsListener) upgradeConn(ctx context.Context, conn net.Conn) (context.Context, net.Conn, error) { secure, err := p.tr.SecureInbound(ctx, conn, "") if err != nil { - return nil, nil, HandshakeError(err) + return nil, nil, HandshakeError{ + remoteAddr: conn.RemoteAddr().String(), + err: err, + } } ctx = peer.CtxWithPeerId(ctx, secure.RemotePeer().String()) return ctx, secure, nil diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 41e9307a..86d957f5 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -1,7 +1,6 @@ package streampool import ( - "fmt" "go.uber.org/zap" "storj.io/drpc" "sync/atomic" @@ -18,11 +17,7 @@ type stream struct { } func (sr *stream) write(msg drpc.Message) (err error) { - defer func() { - sr.l.Debug("write", zap.String("msg", msg.(fmt.Stringer).String()), zap.Error(err)) - }() if err = sr.stream.MsgSend(msg, EncodingProto); err != nil { - sr.l.Info("stream write error", zap.Error(err)) sr.streamClose() } return err @@ -38,8 +33,7 @@ func (sr *stream) readLoop() error { sr.l.Info("msg receive error", zap.Error(err)) return err } - sr.l.Debug("read msg", zap.String("msg", msg.(fmt.Stringer).String())) - if err := sr.pool.handler.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil { + if err := sr.pool.HandleMessage(sr.stream.Context(), sr.peerId, msg); err != nil { sr.l.Info("msg handle error", zap.Error(err)) return err } diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 5957e056..f1aa486b 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -1,9 +1,9 @@ package streampool import ( - "fmt" "github.com/anytypeio/any-sync/net/peer" "github.com/anytypeio/any-sync/net/pool" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" @@ -42,12 +42,30 @@ type streamPool struct { streamIdsByPeer map[string][]uint32 streamIdsByTag map[string][]uint32 streams map[uint32]*stream - opening map[string]chan struct{} + opening map[string]*openingProcess exec *sendPool + handleQueue *mb.MB[handleMessage] mu sync.RWMutex lastStreamId uint32 } +type openingProcess struct { + ch chan struct{} + err error +} +type handleMessage struct { + ctx context.Context + msg drpc.Message + peerId string +} + +func (s *streamPool) init() { + // TODO: to config + for i := 0; i < 10; i++ { + go s.handleMessageLoop() + } +} + func (s *streamPool) ReadStream(peerId string, drpcStream drpc.Stream, tags ...string) error { st := s.addStream(peerId, drpcStream, tags...) return st.readLoop() @@ -78,7 +96,6 @@ func (s *streamPool) addStream(peerId string, drpcStream drpc.Stream, tags ...st for _, tag := range tags { s.streamIdsByTag[tag] = append(s.streamIdsByTag[tag], streamId) } - st.l.Debug("stream added", zap.Strings("tags", st.tags)) return st } @@ -87,7 +104,7 @@ func (s *streamPool) Send(ctx context.Context, msg drpc.Message, peers ...peer.P for _, p := range peers { funcs = append(funcs, func() { if e := s.sendOne(ctx, p, msg); e != nil { - log.Info("send peer error", zap.Error(e)) + log.Info("send peer error", zap.Error(e), zap.String("peerId", p.Id())) } }) } @@ -103,12 +120,11 @@ func (s *streamPool) SendById(ctx context.Context, msg drpc.Message, peerIds ... } } s.mu.Unlock() - log.Debug("sendById", zap.String("msg", msg.(fmt.Stringer).String()), zap.Int("streams", len(streams))) var funcs []func() for _, st := range streams { funcs = append(funcs, func() { if e := st.write(msg); e != nil { - log.Debug("sendById write error", zap.Error(e)) + st.l.Debug("sendById write error", zap.Error(e)) } }) } @@ -126,7 +142,7 @@ func (s *streamPool) sendOne(ctx context.Context, p peer.Peer, msg drpc.Message) } for _, st := range streams { if err = st.write(msg); err != nil { - log.Info("stream write error", zap.Error(err)) + st.l.Info("sendOne write error", zap.Error(err)) // continue with next stream continue } else { @@ -144,18 +160,21 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st for _, streamId := range streamIds { streams = append(streams, s.streams[streamId]) } - var openingCh chan struct{} + var op *openingProcess // no cached streams found if len(streams) == 0 { // start opening process - openingCh = s.openStream(ctx, p) + op = s.openStream(ctx, p) } s.mu.Unlock() // not empty openingCh means we should wait for the stream opening and try again - if openingCh != nil { + if op != nil { select { - case <-openingCh: + case <-op.ch: + if op.err != nil { + return nil, op.err + } return s.getStreams(ctx, p) case <-ctx.Done(): return nil, ctx.Err() @@ -164,30 +183,32 @@ func (s *streamPool) getStreams(ctx context.Context, p peer.Peer) (streams []*st return streams, nil } -func (s *streamPool) openStream(ctx context.Context, p peer.Peer) chan struct{} { - if ch, ok := s.opening[p.Id()]; ok { +func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProcess { + if op, ok := s.opening[p.Id()]; ok { // already have an opening process for this stream - return channel - return ch + return op } - ch := make(chan struct{}) - s.opening[p.Id()] = ch + op := &openingProcess{ + ch: make(chan struct{}), + } + s.opening[p.Id()] = op go func() { // start stream opening in separate goroutine to avoid lock whole pool defer func() { s.mu.Lock() defer s.mu.Unlock() - close(ch) + close(op.ch) delete(s.opening, p.Id()) }() // open new stream and add to pool st, tags, err := s.handler.OpenStream(ctx, p) if err != nil { - log.Warn("stream open error", zap.Error(err)) + op.err = err return } s.AddStream(p.Id(), st, tags...) }() - return ch + return op } func (s *streamPool) Broadcast(ctx context.Context, msg drpc.Message, tags ...string) (err error) { @@ -244,6 +265,28 @@ func (s *streamPool) removeStream(streamId uint32) { st.l.Debug("stream removed", zap.Strings("tags", st.tags)) } +func (s *streamPool) HandleMessage(ctx context.Context, peerId string, msg drpc.Message) (err error) { + return s.handleQueue.Add(ctx, handleMessage{ + ctx: ctx, + msg: msg, + peerId: peerId, + }) +} + +func (s *streamPool) handleMessageLoop() { + for { + hm, err := s.handleQueue.WaitOne(context.Background()) + if err != nil { + return + } + go func() { + if err = s.handler.HandleMessage(hm.ctx, hm.peerId, hm.msg); err != nil { + log.Warn("handle message error", zap.Error(err)) + } + }() + } +} + func (s *streamPool) Close() (err error) { return s.exec.Close() } diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 7722173e..f6e478c1 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -18,22 +18,23 @@ import ( var ctx = context.Background() +func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) { + p, err := fx.tp.Dial(ctx, peerId) + require.NoError(t, err) + s, err := testservice.NewDRPCTestClient(p).TestStream(ctx) + require.NoError(t, err) + return s, p +} + func TestStreamPool_AddStream(t *testing.T) { - newClientStream := func(fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) { - p, err := fx.tp.Dial(ctx, peerId) - require.NoError(t, err) - s, err := testservice.NewDRPCTestClient(p).TestStream(ctx) - require.NoError(t, err) - return s, p - } t.Run("broadcast incoming", func(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - s1, _ := newClientStream(fx, "p1") + s1, _ := newClientStream(t, fx, "p1") fx.AddStream("p1", s1, "space1", "common") - s2, _ := newClientStream(fx, "p2") + s2, _ := newClientStream(t, fx, "p2") fx.AddStream("p2", s2, "space2", "common") require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) @@ -61,7 +62,7 @@ func TestStreamPool_AddStream(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - s1, p1 := newClientStream(fx, "p1") + s1, p1 := newClientStream(t, fx, "p1") defer s1.Close() fx.AddStream("p1", s1, "space1", "common") @@ -122,6 +123,46 @@ func TestStreamPool_Send(t *testing.T) { // make sure that we have only one stream assert.Equal(t, int32(1), fx.tsh.streamsCount.Load()) }) + t.Run("parallel open stream error", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish(t) + + p, err := fx.tp.Dial(ctx, "p1") + require.NoError(t, err) + _ = p.Close() + + fx.th.streamOpenDelay = time.Second / 3 + + var numMsgs = 5 + + var wg sync.WaitGroup + for i := 0; i < numMsgs; i++ { + wg.Add(1) + go func() { + defer wg.Done() + assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"})) + }() + } + wg.Wait() + }) +} + +func TestStreamPool_SendById(t *testing.T) { + fx := newFixture(t) + defer fx.Finish(t) + + s1, _ := newClientStream(t, fx, "p1") + defer s1.Close() + fx.AddStream("p1", s1, "space1", "common") + + require.NoError(t, fx.SendById(ctx, &testservice.StreamMessage{ReqData: "test"}, "p1")) + var msg *testservice.StreamMessage + select { + case msg = <-fx.tsh.receiveCh: + case <-time.After(time.Second): + require.NoError(t, fmt.Errorf("timeout")) + } + assert.Equal(t, "test", msg.ReqData) } func newFixture(t *testing.T) *fixture { diff --git a/net/streampool/streampoolservice.go b/net/streampool/streampoolservice.go index 198e5687..c7cc0c8d 100644 --- a/net/streampool/streampoolservice.go +++ b/net/streampool/streampoolservice.go @@ -3,6 +3,7 @@ package streampool import ( "github.com/anytypeio/any-sync/app" "github.com/anytypeio/any-sync/app/logger" + "github.com/cheggaaa/mb/v3" ) const CName = "common.net.streampool" @@ -22,15 +23,17 @@ type service struct { } func (s *service) NewStreamPool(h StreamHandler) StreamPool { - return &streamPool{ + sp := &streamPool{ handler: h, streamIdsByPeer: map[string][]uint32{}, streamIdsByTag: map[string][]uint32{}, streams: map[uint32]*stream{}, - opening: map[string]chan struct{}{}, + opening: map[string]*openingProcess{}, exec: newStreamSender(10, 100), - lastStreamId: 0, + handleQueue: mb.New[handleMessage](100), } + sp.init() + return sp } func (s *service) Init(a *app.App) (err error) {