diff --git a/common/commonspace/synctree/synctree.go b/common/commonspace/synctree/synctree.go index e8c69db3..bf991efe 100644 --- a/common/commonspace/synctree/synctree.go +++ b/common/commonspace/synctree/synctree.go @@ -247,11 +247,11 @@ func (s *syncTree) AddContent(ctx context.Context, content tree.SignableChangeCo return } -func (s *syncTree) AddRawChanges(ctx context.Context, changes ...*treechangeproto.RawTreeChangeWithId) (res tree.AddResult, err error) { +func (s *syncTree) AddRawChanges(ctx context.Context, changesPayload tree.RawChangesPayload) (res tree.AddResult, err error) { if err = s.checkAlive(); err != nil { return } - res, err = s.ObjectTree.AddRawChanges(ctx, changes...) + res, err = s.ObjectTree.AddRawChanges(ctx, changesPayload) if err != nil { return } diff --git a/common/commonspace/synctree/synctreehandler.go b/common/commonspace/synctree/synctreehandler.go index 8c119e2a..5939a16a 100644 --- a/common/commonspace/synctree/synctreehandler.go +++ b/common/commonspace/synctree/synctreehandler.go @@ -75,7 +75,10 @@ func (s *syncTreeHandler) handleHeadUpdate( return nil } - _, err = objTree.AddRawChanges(ctx, update.Changes...) + _, err = objTree.AddRawChanges(ctx, tree.RawChangesPayload{ + NewHeads: update.Heads, + RawChanges: update.Changes, + }) if err != nil { return err } @@ -128,7 +131,10 @@ func (s *syncTreeHandler) handleFullSyncRequest( defer objTree.Unlock() if len(request.Changes) != 0 && !s.alreadyHasHeads(objTree, request.Heads) { - _, err = objTree.AddRawChanges(ctx, request.Changes...) + _, err = objTree.AddRawChanges(ctx, tree.RawChangesPayload{ + NewHeads: request.Heads, + RawChanges: request.Changes, + }) if err != nil { return err } @@ -168,7 +174,10 @@ func (s *syncTreeHandler) handleFullSyncResponse( return nil } - _, err = objTree.AddRawChanges(ctx, response.Changes...) + _, err = objTree.AddRawChanges(ctx, tree.RawChangesPayload{ + NewHeads: response.Heads, + RawChanges: response.Changes, + }) return err }() log.With("error", err != nil). diff --git a/common/pkg/acl/tree/objecttree.go b/common/pkg/acl/tree/objecttree.go index bb8dfc6a..8c4e0907 100644 --- a/common/pkg/acl/tree/objecttree.go +++ b/common/pkg/acl/tree/objecttree.go @@ -33,6 +33,11 @@ type AddResult struct { Mode Mode } +type RawChangesPayload struct { + NewHeads []string + RawChanges []*treechangeproto.RawTreeChangeWithId +} + type ChangeIterateFunc = func(change *Change) bool type ChangeConvertFunc = func(decrypted []byte) (any, error) @@ -55,7 +60,7 @@ type ObjectTree interface { Storage() storage.TreeStorage AddContent(ctx context.Context, content SignableChangeContent) (AddResult, error) - AddRawChanges(ctx context.Context, changes ...*treechangeproto.RawTreeChangeWithId) (AddResult, error) + AddRawChanges(ctx context.Context, changes RawChangesPayload) (AddResult, error) Delete() error Close() error @@ -113,10 +118,10 @@ func defaultObjectTreeDeps( } } -func (ot *objectTree) rebuildFromStorage(newChanges []*Change) (err error) { +func (ot *objectTree) rebuildFromStorage(theirHeads []string, newChanges []*Change) (err error) { ot.treeBuilder.Reset() - ot.tree, err = ot.treeBuilder.Build(newChanges) + ot.tree, err = ot.treeBuilder.Build(theirHeads, newChanges) if err != nil { return } @@ -213,8 +218,8 @@ func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt return } -func (ot *objectTree) AddRawChanges(ctx context.Context, rawChanges ...*treechangeproto.RawTreeChangeWithId) (addResult AddResult, err error) { - addResult, err = ot.addRawChanges(ctx, rawChanges...) +func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { + addResult, err = ot.addRawChanges(ctx, changesPayload) if err != nil { return } @@ -235,7 +240,7 @@ func (ot *objectTree) AddRawChanges(ctx context.Context, rawChanges ...*treechan return } -func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechangeproto.RawTreeChangeWithId) (addResult AddResult, err error) { +func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChangesPayload) (addResult AddResult, err error) { // resetting buffers ot.newChangesBuf = ot.newChangesBuf[:0] ot.notSeenIdxBuf = ot.notSeenIdxBuf[:0] @@ -252,7 +257,7 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan prevHeadsCopy := headsCopy() // filtering changes, verifying and unmarshalling them - for idx, ch := range rawChanges { + for idx, ch := range changesPayload.RawChanges { // not unmarshalling the changes if they were already added either as unattached or attached if _, exists := ot.tree.attached[ch.Id]; exists { continue @@ -331,17 +336,17 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan ot.newChangesBuf = discardFromSlice(ot.newChangesBuf, func(ch *Change) bool { return ch == nil }) if shouldRebuildFromStorage { - err = ot.rebuildFromStorage(ot.newChangesBuf) + err = ot.rebuildFromStorage(changesPayload.NewHeads, ot.newChangesBuf) if err != nil { // rebuilding without new changes - ot.rebuildFromStorage(nil) + ot.rebuildFromStorage(nil, nil) return } - addResult, err = ot.createAddResult(prevHeadsCopy, Rebuild, nil, rawChanges) + addResult, err = ot.createAddResult(prevHeadsCopy, Rebuild, nil, changesPayload.RawChanges) if err != nil { // that means that some unattached changes were somehow corrupted in memory // this shouldn't happen but if that happens, then rebuilding from storage - ot.rebuildFromStorage(nil) + ot.rebuildFromStorage(nil, nil) return } return @@ -366,11 +371,11 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*treechan err = ErrHasInvalidChanges return } - addResult, err = ot.createAddResult(prevHeadsCopy, mode, treeChangesAdded, rawChanges) + addResult, err = ot.createAddResult(prevHeadsCopy, mode, treeChangesAdded, changesPayload.RawChanges) if err != nil { // that means that some unattached changes were somehow corrupted in memory // this shouldn't happen but if that happens, then rebuilding from storage - ot.rebuildFromStorage(nil) + ot.rebuildFromStorage(nil, nil) return } return diff --git a/common/pkg/acl/tree/objecttreefactory.go b/common/pkg/acl/tree/objecttreefactory.go index 69cb2acf..e2f7b6db 100644 --- a/common/pkg/acl/tree/objecttreefactory.go +++ b/common/pkg/acl/tree/objecttreefactory.go @@ -107,7 +107,7 @@ func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) { newSnapshotsBuf: make([]*Change, 0, 10), } - err := objTree.rebuildFromStorage(nil) + err := objTree.rebuildFromStorage(nil, nil) if err != nil { return nil, err } diff --git a/common/pkg/acl/tree/treebuilder.go b/common/pkg/acl/tree/treebuilder.go index 1250bfdc..e7997e86 100644 --- a/common/pkg/acl/tree/treebuilder.go +++ b/common/pkg/acl/tree/treebuilder.go @@ -40,27 +40,32 @@ func (tb *treeBuilder) Reset() { tb.tree = &Tree{} } -func (tb *treeBuilder) Build(newChanges []*Change) (*Tree, error) { - var headsAndNewChanges []string +func (tb *treeBuilder) Build(theirHeads []string, newChanges []*Change) (*Tree, error) { + var proposedHeads []string heads, err := tb.treeStorage.Heads() if err != nil { return nil, err } - - headsAndNewChanges = append(headsAndNewChanges, heads...) tb.cache = make(map[string]*Change) + proposedHeads = append(proposedHeads, heads...) + if len(theirHeads) > 0 { + proposedHeads = append(proposedHeads, theirHeads...) + } for _, ch := range newChanges { - headsAndNewChanges = append(headsAndNewChanges, ch.Id) + // we don't know what new heads are, so every change can be head + if len(theirHeads) == 0 { + proposedHeads = append(proposedHeads, ch.Id) + } tb.cache[ch.Id] = ch } - log.With(zap.Strings("heads", heads)).Debug("building tree") - breakpoint, err := tb.findBreakpoint(headsAndNewChanges) + log.With(zap.Strings("heads", proposedHeads)).Debug("building tree") + breakpoint, err := tb.findBreakpoint(proposedHeads) if err != nil { return nil, fmt.Errorf("findBreakpoint error: %v", err) } - if err = tb.buildTree(headsAndNewChanges, breakpoint); err != nil { + if err = tb.buildTree(proposedHeads, breakpoint); err != nil { return nil, fmt.Errorf("buildTree error: %v", err) }