diff --git a/common/commonspace/syncservice/mock_syncservice/mock_syncservice.go b/common/commonspace/syncservice/mock_syncservice/mock_syncservice.go new file mode 100644 index 00000000..8438ab64 --- /dev/null +++ b/common/commonspace/syncservice/mock_syncservice/mock_syncservice.go @@ -0,0 +1,78 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice (interfaces: SyncClient) + +// Package mock_syncservice is a generated GoMock package. +package mock_syncservice + +import ( + reflect "reflect" + + spacesyncproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" + gomock "github.com/golang/mock/gomock" +) + +// MockSyncClient is a mock of SyncClient interface. +type MockSyncClient struct { + ctrl *gomock.Controller + recorder *MockSyncClientMockRecorder +} + +// MockSyncClientMockRecorder is the mock recorder for MockSyncClient. +type MockSyncClientMockRecorder struct { + mock *MockSyncClient +} + +// NewMockSyncClient creates a new mock instance. +func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient { + mock := &MockSyncClient{ctrl: ctrl} + mock.recorder = &MockSyncClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder { + return m.recorder +} + +// BroadcastAsync mocks base method. +func (m *MockSyncClient) BroadcastAsync(arg0 *spacesyncproto.ObjectSyncMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BroadcastAsync", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// BroadcastAsync indicates an expected call of BroadcastAsync. +func (mr *MockSyncClientMockRecorder) BroadcastAsync(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BroadcastAsync", reflect.TypeOf((*MockSyncClient)(nil).BroadcastAsync), arg0) +} + +// SendAsync mocks base method. +func (m *MockSyncClient) SendAsync(arg0 string, arg1 *spacesyncproto.ObjectSyncMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendAsync", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendAsync indicates an expected call of SendAsync. +func (mr *MockSyncClientMockRecorder) SendAsync(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendAsync", reflect.TypeOf((*MockSyncClient)(nil).SendAsync), arg0, arg1) +} + +// SendSync mocks base method. +func (m *MockSyncClient) SendSync(arg0 string, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendSync", arg0, arg1) + ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendSync indicates an expected call of SendSync. +func (mr *MockSyncClientMockRecorder) SendSync(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSync", reflect.TypeOf((*MockSyncClient)(nil).SendSync), arg0, arg1) +} diff --git a/common/commonspace/syncservice/synchandler.go b/common/commonspace/syncservice/synchandler.go index 3a0da9fd..0123cbff 100644 --- a/common/commonspace/syncservice/synchandler.go +++ b/common/commonspace/syncservice/synchandler.go @@ -1,3 +1,4 @@ +//go:generate mockgen -destination mock_syncservice/mock_syncservice.go github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice SyncClient package syncservice import ( @@ -31,16 +32,16 @@ func (s *syncHandler) HandleMessage(ctx context.Context, senderId string, msg *s content := msg.GetContent() switch { case content.GetFullSyncRequest() != nil: - return s.HandleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), msg) + return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), msg) case content.GetFullSyncResponse() != nil: - return s.HandleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse(), msg) + return s.handleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse(), msg) case content.GetHeadUpdate() != nil: - return s.HandleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), msg) + return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), msg) } return nil } -func (s *syncHandler) HandleHeadUpdate( +func (s *syncHandler) handleHeadUpdate( ctx context.Context, senderId string, update *spacesyncproto.ObjectHeadUpdate, @@ -70,12 +71,14 @@ func (s *syncHandler) HandleHeadUpdate( return err } - // if we couldn't add all the changes - if len(update.Changes) != len(result.Added) { - fullRequest, err = s.prepareFullSyncRequest(objTree, update) - if err != nil { - return err - } + // if after the heads are equal, or we have them locally + if slice.UnsortedEquals(update.Heads, result.Heads) || objTree.HasChanges(update.Heads...) { + return nil + } + + fullRequest, err = s.prepareFullSyncRequest(objTree, update) + if err != nil { + return err } return nil }() @@ -87,7 +90,7 @@ func (s *syncHandler) HandleHeadUpdate( return } -func (s *syncHandler) HandleFullSyncRequest( +func (s *syncHandler) handleFullSyncRequest( ctx context.Context, senderId string, request *spacesyncproto.ObjectFullSyncRequest, @@ -133,7 +136,7 @@ func (s *syncHandler) HandleFullSyncRequest( spacesyncproto.WrapFullResponse(fullResponse, header, msg.TreeId, msg.TrackingId)) } -func (s *syncHandler) HandleFullSyncResponse( +func (s *syncHandler) handleFullSyncResponse( ctx context.Context, senderId string, response *spacesyncproto.ObjectFullSyncResponse, diff --git a/common/commonspace/syncservice/synchandler_test.go b/common/commonspace/syncservice/synchandler_test.go new file mode 100644 index 00000000..dc145017 --- /dev/null +++ b/common/commonspace/syncservice/synchandler_test.go @@ -0,0 +1,58 @@ +package syncservice + +import ( + "context" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/cache" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/cache/mock_cache" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/syncservice/mock_syncservice" + "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/tree" + mock_tree "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/tree/mock_objecttree" + "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/treechangeproto" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "testing" +) + +type treeContainer struct { + objTree tree.ObjectTree +} + +func (t treeContainer) Tree() tree.ObjectTree { + return t.objTree +} + +func TestSyncHandler_HandleMessage(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ctx := context.Background() + spaceId := "spaceId" + cacheMock := mock_cache.NewMockTreeCache(ctrl) + syncClientMock := mock_syncservice.NewMockSyncClient(ctrl) + objectTreeMock := mock_tree.NewMockObjectTree(ctrl) + + syncHandler := newSyncHandler(spaceId, cacheMock, syncClientMock) + treeId := "treeId" + senderId := "senderId" + chWithId := &treechangeproto.RawTreeChangeWithId{} + headUpdate := &spacesyncproto.ObjectHeadUpdate{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + msg := spacesyncproto.WrapHeadUpdate(headUpdate, chWithId, treeId, "") + cacheMock.EXPECT(). + GetTree(gomock.Any(), spaceId, treeId). + Return(cache.TreeResult{ + Release: func() {}, + TreeContainer: treeContainer{objectTreeMock}, + }, nil) + objectTreeMock.EXPECT().Lock() + objectTreeMock.EXPECT().Heads().Return([]string{"h2"}) + objectTreeMock.EXPECT().AddRawChanges(gomock.Any(), gomock.Eq([]*treechangeproto.RawTreeChangeWithId{chWithId})). + Return(tree.AddResult{}, nil) + objectTreeMock.EXPECT().Unlock() + err := syncHandler.HandleMessage(ctx, senderId, msg) + require.NoError(t, err) +} diff --git a/pkg/acl/storage/inmemory.go b/pkg/acl/storage/inmemory.go index 6ddbea57..487caa04 100644 --- a/pkg/acl/storage/inmemory.go +++ b/pkg/acl/storage/inmemory.go @@ -87,6 +87,11 @@ func NewInMemoryTreeStorage( }, nil } +func (t *inMemoryTreeStorage) HasChange(ctx context.Context, id string) (bool, error) { + _, exists := t.changes[id] + return exists, nil +} + func (t *inMemoryTreeStorage) ID() (string, error) { t.RLock() defer t.RUnlock() diff --git a/pkg/acl/storage/mock_storage/mock_storage.go b/pkg/acl/storage/mock_storage/mock_storage.go index 5fff6944..f20bf468 100644 --- a/pkg/acl/storage/mock_storage/mock_storage.go +++ b/pkg/acl/storage/mock_storage/mock_storage.go @@ -162,6 +162,21 @@ func (mr *MockTreeStorageMockRecorder) GetRawChange(arg0, arg1 interface{}) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRawChange", reflect.TypeOf((*MockTreeStorage)(nil).GetRawChange), arg0, arg1) } +// HasChange mocks base method. +func (m *MockTreeStorage) HasChange(arg0 context.Context, arg1 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasChange", arg0, arg1) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HasChange indicates an expected call of HasChange. +func (mr *MockTreeStorageMockRecorder) HasChange(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasChange", reflect.TypeOf((*MockTreeStorage)(nil).HasChange), arg0, arg1) +} + // Heads mocks base method. func (m *MockTreeStorage) Heads() ([]string, error) { m.ctrl.T.Helper() diff --git a/pkg/acl/storage/treestorage.go b/pkg/acl/storage/treestorage.go index a9c98f1d..25b91e46 100644 --- a/pkg/acl/storage/treestorage.go +++ b/pkg/acl/storage/treestorage.go @@ -13,6 +13,7 @@ type TreeStorage interface { AddRawChange(change *treechangeproto.RawTreeChangeWithId) error GetRawChange(ctx context.Context, id string) (*treechangeproto.RawTreeChangeWithId, error) + HasChange(ctx context.Context, id string) (bool, error) } type TreeStorageCreatorFunc = func(payload TreeStorageCreatePayload) (TreeStorage, error) diff --git a/pkg/acl/tree/mock_objecttree/mock_objecttree.go b/pkg/acl/tree/mock_objecttree/mock_objecttree.go index a4eac719..70afa9d4 100644 --- a/pkg/acl/tree/mock_objecttree/mock_objecttree.go +++ b/pkg/acl/tree/mock_objecttree/mock_objecttree.go @@ -116,18 +116,22 @@ func (mr *MockObjectTreeMockRecorder) DebugDump() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugDump", reflect.TypeOf((*MockObjectTree)(nil).DebugDump)) } -// HasChange mocks base method. -func (m *MockObjectTree) HasChange(arg0 string) bool { +// HasChanges mocks base method. +func (m *MockObjectTree) HasChanges(arg0 ...string) bool { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "HasChange", arg0) + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "HasChanges", varargs...) ret0, _ := ret[0].(bool) return ret0 } -// HasChange indicates an expected call of HasChange. -func (mr *MockObjectTreeMockRecorder) HasChange(arg0 interface{}) *gomock.Call { +// HasChanges indicates an expected call of HasChanges. +func (mr *MockObjectTreeMockRecorder) HasChanges(arg0 ...interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasChange", reflect.TypeOf((*MockObjectTree)(nil).HasChange), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasChanges", reflect.TypeOf((*MockObjectTree)(nil).HasChanges), arg0...) } // Header mocks base method. diff --git a/pkg/acl/tree/objecttree.go b/pkg/acl/tree/objecttree.go index 47305d23..9e6323aa 100644 --- a/pkg/acl/tree/objecttree.go +++ b/pkg/acl/tree/objecttree.go @@ -43,7 +43,7 @@ type ObjectTree interface { Header() *treechangeproto.RawTreeChangeWithId Heads() []string Root() *Change - HasChange(string) bool + HasChanges(...string) bool DebugDump() (string, error) Iterate(convert ChangeConvertFunc, iterate ChangeIterateFunc) error @@ -76,7 +76,7 @@ type objectTree struct { // buffers difSnapshotBuf []*treechangeproto.RawTreeChangeWithId - tmpChangesBuf []*Change + newChangesBuf []*Change newSnapshotsBuf []*Change notSeenIdxBuf []int @@ -227,7 +227,7 @@ func (ot *objectTree) AddRawChanges(ctx context.Context, rawChanges ...*treechan func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechangeproto.RawTreeChangeWithId) (addResult AddResult, err error) { // resetting buffers - ot.tmpChangesBuf = ot.tmpChangesBuf[:0] + ot.newChangesBuf = ot.newChangesBuf[:0] ot.notSeenIdxBuf = ot.notSeenIdxBuf[:0] ot.difSnapshotBuf = ot.difSnapshotBuf[:0] ot.newSnapshotsBuf = ot.newSnapshotsBuf[:0] @@ -247,20 +247,21 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan if _, exists := ot.tree.attached[ch.Id]; exists { continue } - if _, exists := ot.tree.unAttached[ch.Id]; exists { - continue - } var change *Change - change, err = ot.changeBuilder.ConvertFromRaw(ch, true) - if err != nil { - return + if unAttached, exists := ot.tree.unAttached[ch.Id]; exists { + change = unAttached + } else { + change, err = ot.changeBuilder.ConvertFromRaw(ch, true) + if err != nil { + return + } } if change.IsSnapshot { ot.newSnapshotsBuf = append(ot.newSnapshotsBuf, change) } - ot.tmpChangesBuf = append(ot.tmpChangesBuf, change) + ot.newChangesBuf = append(ot.newChangesBuf, change) ot.notSeenIdxBuf = append(ot.notSeenIdxBuf, idx) } @@ -274,6 +275,106 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan return } + rollback := func(changes []*Change) { + for _, ch := range changes { + if _, exists := ot.tree.attached[ch.Id]; exists { + delete(ot.tree.attached, ch.Id) + } + } + } + + // checks if we need to go to database + isOldSnapshot := func(ch *Change) bool { + if ch.SnapshotId == ot.tree.RootId() { + return false + } + for _, sn := range ot.newSnapshotsBuf { + // if change refers to newly received snapshot + if ch.SnapshotId == sn.Id { + return false + } + } + return true + } + + shouldRebuildFromStorage := false + // checking if we have some changes with different snapshot and then rebuilding + for idx, ch := range ot.newChangesBuf { + if isOldSnapshot(ch) { + var exists bool + // checking if it exists in the storage, if yes, then at some point it was added to the tree + // thus we don't need to look at this change + exists, err = ot.treeStorage.HasChange(ctx, ch.Id) + if err != nil { + return + } + if exists { + // marking as nil to delete after + ot.newChangesBuf[idx] = nil + continue + } + // we haven't seen the change, and it refers to old snapshot, so we should rebuild + shouldRebuildFromStorage = true + } + } + // discarding all previously seen changes + ot.newChangesBuf = discardFromSlice(ot.newChangesBuf, func(ch *Change) bool { return ch == nil }) + + if shouldRebuildFromStorage { + err = ot.rebuildFromStorage(ot.newChangesBuf) + if err != nil { + // rebuilding without new changes + ot.rebuildFromStorage(nil) + return + } + addResult, err = ot.createAddResult(prevHeadsCopy, Rebuild, nil, rawChanges) + if err != nil { + // that means that some unattached changes were somehow corrupted in memory + // this shouldn't happen but if that happens, then rebuilding from storage + ot.rebuildFromStorage(nil) + return + } + return + } + + // normal mode of operation, where we don't need to rebuild from database + mode, treeChangesAdded := ot.tree.Add(ot.newChangesBuf...) + switch mode { + case Nothing: + addResult = AddResult{ + OldHeads: prevHeadsCopy, + Heads: prevHeadsCopy, + Mode: mode, + } + return + + default: + // we need to validate only newly added changes + err = ot.validateTree(treeChangesAdded) + if err != nil { + rollback(treeChangesAdded) + err = ErrHasInvalidChanges + return + } + addResult, err = ot.createAddResult(prevHeadsCopy, mode, treeChangesAdded, rawChanges) + if err != nil { + // that means that some unattached changes were somehow corrupted in memory + // this shouldn't happen but if that happens, then rebuilding from storage + ot.rebuildFromStorage(nil) + return + } + return + } + return +} + +func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesAdded []*Change, rawChanges []*treechangeproto.RawTreeChangeWithId) (addResult AddResult, err error) { + headsCopy := func() []string { + newHeads := make([]string, 0, len(ot.tree.Heads())) + newHeads = append(newHeads, ot.tree.Heads()...) + return newHeads + } + // returns changes that we added to the tree as attached this round // they can include not only the changes that were added now, // but also the changes that were previously in the tree @@ -313,88 +414,16 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan return } - rollback := func(changes []*Change) { - for _, ch := range changes { - if _, exists := ot.tree.attached[ch.Id]; exists { - delete(ot.tree.attached, ch.Id) - } - } - } - - // checks if we need to go to database - isOldSnapshot := func(ch *Change) bool { - if ch.SnapshotId == ot.tree.RootId() { - return false - } - for _, sn := range ot.newSnapshotsBuf { - // if change refers to newly received snapshot - if ch.SnapshotId == sn.Id { - return false - } - } - return true - } - - // checking if we have some changes with different snapshot and then rebuilding - for _, ch := range ot.tmpChangesBuf { - if isOldSnapshot(ch) { - err = ot.rebuildFromStorage(ot.tmpChangesBuf) - if err != nil { - // rebuilding without new changes - ot.rebuildFromStorage(nil) - return - } - var added []*treechangeproto.RawTreeChangeWithId - added, err = getAddedChanges(nil) - // we shouldn't get any error in this case - if err != nil { - panic(err) - } - - addResult = AddResult{ - OldHeads: prevHeadsCopy, - Heads: headsCopy(), - Added: added, - Mode: Rebuild, - } - return - } - } - - // normal mode of operation, where we don't need to rebuild from database - mode, treeChangesAdded := ot.tree.Add(ot.tmpChangesBuf...) - switch mode { - case Nothing: - addResult = AddResult{ - OldHeads: prevHeadsCopy, - Heads: prevHeadsCopy, - Mode: mode, - } + var added []*treechangeproto.RawTreeChangeWithId + added, err = getAddedChanges(treeChangesAdded) + if err != nil { return - - default: - // we need to validate only newly added changes - err = ot.validateTree(treeChangesAdded) - if err != nil { - rollback(treeChangesAdded) - err = ErrHasInvalidChanges - return - } - var added []*treechangeproto.RawTreeChangeWithId - added, err = getAddedChanges(treeChangesAdded) - if err != nil { - // that means that some unattached changes were somehow corrupted in memory - // this shouldn't happen but if that happens, then rebuilding from storage - ot.rebuildFromStorage(nil) - return - } - - addResult = AddResult{ - OldHeads: prevHeadsCopy, - Heads: headsCopy(), - Added: added, - Mode: mode, - } + } + addResult = AddResult{ + OldHeads: oldHeads, + Heads: headsCopy(), + Added: added, + Mode: mode, } return } @@ -441,9 +470,28 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate return } -func (ot *objectTree) HasChange(s string) bool { - _, attachedExists := ot.tree.attached[s] - return attachedExists +func (ot *objectTree) HasChanges(chs ...string) bool { + hasChange := func(s string) bool { + _, attachedExists := ot.tree.attached[s] + if attachedExists { + return attachedExists + } + + has, err := ot.treeStorage.HasChange(context.Background(), s) + if err != nil { + return false + } + + return has + } + + for _, ch := range chs { + if !hasChange(ch) { + return false + } + } + + return true } func (ot *objectTree) Heads() []string { diff --git a/pkg/acl/tree/objecttreefactory.go b/pkg/acl/tree/objecttreefactory.go index 61631412..1db77908 100644 --- a/pkg/acl/tree/objecttreefactory.go +++ b/pkg/acl/tree/objecttreefactory.go @@ -104,7 +104,7 @@ func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) { rawChangeLoader: deps.rawChangeLoader, tree: nil, keys: make(map[uint64]*symmetric.Key), - tmpChangesBuf: make([]*Change, 0, 10), + newChangesBuf: make([]*Change, 0, 10), difSnapshotBuf: make([]*treechangeproto.RawTreeChangeWithId, 0, 10), notSeenIdxBuf: make([]int, 0, 10), newSnapshotsBuf: make([]*Change, 0, 10),