diff --git a/common/commonspace/syncservice/requestfactory.go b/common/commonspace/syncservice/requestfactory.go new file mode 100644 index 00000000..ab19bc6a --- /dev/null +++ b/common/commonspace/syncservice/requestfactory.go @@ -0,0 +1,59 @@ +package syncservice + +import ( + "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" + "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/tree" + "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/treechangeproto" + "github.com/anytypeio/go-anytype-infrastructure-experiments/util/slice" +) + +type RequestFactory interface { + FullSyncRequest(t tree.ObjectTree, theirHeads, theirSnapshotPath []string, trackingId string) (req *spacesyncproto.ObjectSyncMessage, err error) + FullSyncResponse(t tree.ObjectTree, theirHeads, theirSnapshotPath []string, trackingId string) (*spacesyncproto.ObjectSyncMessage, error) +} + +func newRequestFactory() RequestFactory { + return &requestFactory{} +} + +type requestFactory struct{} + +func (r *requestFactory) FullSyncRequest(t tree.ObjectTree, theirHeads, theirSnapshotPath []string, trackingId string) (msg *spacesyncproto.ObjectSyncMessage, err error) { + req := &spacesyncproto.ObjectFullSyncRequest{} + if t == nil { + msg = spacesyncproto.WrapFullRequest(req, t.Header(), t.ID(), trackingId) + return + } + + req.Heads = t.Heads() + req.SnapshotPath = t.SnapshotPath() + + var changesAfterSnapshot []*treechangeproto.RawTreeChangeWithId + changesAfterSnapshot, err = t.ChangesAfterCommonSnapshot(theirSnapshotPath, theirHeads) + if err != nil { + return + } + + req.Changes = changesAfterSnapshot + msg = spacesyncproto.WrapFullRequest(req, t.Header(), t.ID(), trackingId) + return +} + +func (r *requestFactory) FullSyncResponse(t tree.ObjectTree, theirHeads, theirSnapshotPath []string, trackingId string) (msg *spacesyncproto.ObjectSyncMessage, err error) { + resp := &spacesyncproto.ObjectFullSyncResponse{ + Heads: t.Heads(), + SnapshotPath: t.SnapshotPath(), + } + if slice.UnsortedEquals(theirHeads, t.Heads()) { + msg = spacesyncproto.WrapFullResponse(resp, t.Header(), t.ID(), trackingId) + return + } + + ourChanges, err := t.ChangesAfterCommonSnapshot(theirSnapshotPath, theirHeads) + if err != nil { + return + } + resp.Changes = ourChanges + msg = spacesyncproto.WrapFullResponse(resp, t.Header(), t.ID(), trackingId) + return +} diff --git a/common/commonspace/syncservice/synchandler.go b/common/commonspace/syncservice/synchandler.go index 0123cbff..16ce2177 100644 --- a/common/commonspace/syncservice/synchandler.go +++ b/common/commonspace/syncservice/synchandler.go @@ -6,7 +6,6 @@ import ( "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/cache" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/tree" - "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/treechangeproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/slice" ) @@ -14,17 +13,19 @@ type syncHandler struct { spaceId string treeCache cache.TreeCache syncClient SyncClient + factory RequestFactory } type SyncHandler interface { HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) } -func newSyncHandler(spaceId string, treeCache cache.TreeCache, syncClient SyncClient) *syncHandler { +func newSyncHandler(spaceId string, treeCache cache.TreeCache, syncClient SyncClient, factory RequestFactory) *syncHandler { return &syncHandler{ spaceId: spaceId, treeCache: treeCache, syncClient: syncClient, + factory: factory, } } @@ -47,10 +48,7 @@ func (s *syncHandler) handleHeadUpdate( update *spacesyncproto.ObjectHeadUpdate, msg *spacesyncproto.ObjectSyncMessage) (err error) { - var ( - fullRequest *spacesyncproto.ObjectFullSyncRequest - result tree.AddResult - ) + var fullRequest *spacesyncproto.ObjectSyncMessage res, err := s.treeCache.GetTree(ctx, s.spaceId, msg.TreeId) if err != nil { return @@ -62,21 +60,20 @@ func (s *syncHandler) handleHeadUpdate( defer res.Release() defer objTree.Unlock() - if slice.UnsortedEquals(update.Heads, objTree.Heads()) { + if s.alreadyHaveHeads(objTree, update.Heads) { return nil } - result, err = objTree.AddRawChanges(ctx, update.Changes...) + _, err = objTree.AddRawChanges(ctx, update.Changes...) if err != nil { return err } - // if after the heads are equal, or we have them locally - if slice.UnsortedEquals(update.Heads, result.Heads) || objTree.HasChanges(update.Heads...) { + if s.alreadyHaveHeads(objTree, update.Heads) { return nil } - fullRequest, err = s.prepareFullSyncRequest(objTree, update) + fullRequest, err = s.factory.FullSyncRequest(objTree, update.Heads, update.SnapshotPath, msg.TrackingId) if err != nil { return err } @@ -84,8 +81,7 @@ func (s *syncHandler) handleHeadUpdate( }() if fullRequest != nil { - return s.syncClient.SendAsync(senderId, - spacesyncproto.WrapFullRequest(fullRequest, msg.RootChange, msg.TreeId, msg.TrackingId)) + return s.syncClient.SendAsync(senderId, fullRequest) } return } @@ -96,7 +92,7 @@ func (s *syncHandler) handleFullSyncRequest( request *spacesyncproto.ObjectFullSyncRequest, msg *spacesyncproto.ObjectSyncMessage) (err error) { var ( - fullResponse *spacesyncproto.ObjectFullSyncResponse + fullResponse *spacesyncproto.ObjectSyncMessage header = msg.RootChange ) defer func() { @@ -120,20 +116,21 @@ func (s *syncHandler) handleFullSyncRequest( header = objTree.Header() } - _, err = objTree.AddRawChanges(ctx, request.Changes...) - if err != nil { - return err + if !s.alreadyHaveHeads(objTree, request.Heads) { + _, err = objTree.AddRawChanges(ctx, request.Changes...) + if err != nil { + return err + } } - fullResponse, err = s.prepareFullSyncResponse(request.SnapshotPath, request.Heads, objTree) + fullResponse, err = s.factory.FullSyncResponse(objTree, request.Heads, request.SnapshotPath, msg.TrackingId) return err }() if err != nil { return } - return s.syncClient.SendAsync(senderId, - spacesyncproto.WrapFullResponse(fullResponse, header, msg.TreeId, msg.TrackingId)) + return s.syncClient.SendAsync(senderId, fullResponse) } func (s *syncHandler) handleFullSyncResponse( @@ -152,8 +149,7 @@ func (s *syncHandler) handleFullSyncResponse( defer res.Release() defer objTree.Unlock() - // if we already have the heads for whatever reason - if slice.UnsortedEquals(response.Heads, objTree.Heads()) { + if s.alreadyHaveHeads(objTree, response.Heads) { return nil } @@ -164,39 +160,6 @@ func (s *syncHandler) handleFullSyncResponse( return } -func (s *syncHandler) prepareFullSyncRequest( - t tree.ObjectTree, - update *spacesyncproto.ObjectHeadUpdate) (req *spacesyncproto.ObjectFullSyncRequest, err error) { - req = &spacesyncproto.ObjectFullSyncRequest{ - Heads: t.Heads(), - SnapshotPath: t.SnapshotPath(), - } - if len(update.Changes) != 0 { - var changesAfterSnapshot []*treechangeproto.RawTreeChangeWithId - changesAfterSnapshot, err = t.ChangesAfterCommonSnapshot(update.SnapshotPath, update.Heads) - if err != nil { - return - } - req.Changes = changesAfterSnapshot - } - return &spacesyncproto.ObjectFullSyncRequest{ - Heads: t.Heads(), - SnapshotPath: t.SnapshotPath(), - }, nil -} - -func (s *syncHandler) prepareFullSyncResponse( - theirPath, - theirHeads []string, - t tree.ObjectTree) (*spacesyncproto.ObjectFullSyncResponse, error) { - ourChanges, err := t.ChangesAfterCommonSnapshot(theirPath, theirHeads) - if err != nil { - return nil, err - } - - return &spacesyncproto.ObjectFullSyncResponse{ - Heads: t.Heads(), - Changes: ourChanges, - SnapshotPath: t.SnapshotPath(), - }, nil +func (s *syncHandler) alreadyHaveHeads(t tree.ObjectTree, heads []string) bool { + return slice.UnsortedEquals(t.Heads(), heads) || t.HasChanges(heads...) } diff --git a/common/commonspace/syncservice/syncservice.go b/common/commonspace/syncservice/syncservice.go index 64ac67a6..c0a461e1 100644 --- a/common/commonspace/syncservice/syncservice.go +++ b/common/commonspace/syncservice/syncservice.go @@ -50,7 +50,7 @@ func NewSyncService(spaceId string, headNotifiable HeadNotifiable, cache cache.T streamPool := newStreamPool(func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { return syncHandler.HandleMessage(ctx, senderId, message) }) - syncHandler = newSyncHandler(spaceId, cache, streamPool) + syncHandler = newSyncHandler(spaceId, cache, streamPool, newRequestFactory()) return newSyncService( spaceId, headNotifiable,