diff --git a/pkg/acl/tree/objecttree_test.go b/pkg/acl/tree/objecttree_test.go index 463dca31..a27c9cc1 100644 --- a/pkg/acl/tree/objecttree_test.go +++ b/pkg/acl/tree/objecttree_test.go @@ -346,7 +346,7 @@ func TestObjectTree(t *testing.T) { }) }) - t.Run("changes after common snapshot simple", func(t *testing.T) { + t.Run("changes after common snapshot db complex", func(t *testing.T) { ctx := prepareTreeContext(t, aclList) changeCreator := ctx.changeCreator objTree := ctx.objTree @@ -355,29 +355,18 @@ func TestObjectTree(t *testing.T) { changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), + changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), + changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), + // main difference from tree example + changeCreator.createRaw("6", aclList.Head().Id, "0", true, "3", "4", "5"), } _, err := objTree.AddRawChanges(context.Background(), rawChanges...) require.NoError(t, err, "adding changes should be without error") - require.Equal(t, "3", objTree.Root().Id) + require.Equal(t, "6", objTree.Root().Id) - t.Run("changes from db", func(t *testing.T) { - changes, err := objTree.ChangesAfterCommonSnapshot([]string{"0"}, []string{}) - require.NoError(t, err, "changes after common snapshot should be without error") - - changeIds := make(map[string]struct{}) - for _, ch := range changes { - changeIds[ch.Id] = struct{}{} - } - - for _, raw := range rawChanges { - _, ok := changeIds[raw.Id] - assert.Equal(t, true, ok) - } - }) - - t.Run("changes from db with empty path", func(t *testing.T) { - changes, err := objTree.ChangesAfterCommonSnapshot([]string{}, []string{}) + t.Run("all changes from db", func(t *testing.T) { + changes, err := objTree.ChangesAfterCommonSnapshot([]string{"3", "0"}, []string{}) require.NoError(t, err, "changes after common snapshot should be without error") changeIds := make(map[string]struct{}) @@ -392,6 +381,44 @@ func TestObjectTree(t *testing.T) { _, ok := changeIds["0"] assert.Equal(t, true, ok) }) + + t.Run("changes from tree db 1", func(t *testing.T) { + changes, err := objTree.ChangesAfterCommonSnapshot([]string{"3", "0"}, []string{"1"}) + require.NoError(t, err, "changes after common snapshot should be without error") + + changeIds := make(map[string]struct{}) + for _, ch := range changes { + changeIds[ch.Id] = struct{}{} + } + + for _, id := range []string{"2", "3", "4", "5", "6"} { + _, ok := changeIds[id] + assert.Equal(t, true, ok) + } + for _, id := range []string{"0", "1"} { + _, ok := changeIds[id] + assert.Equal(t, false, ok) + } + }) + + t.Run("changes from tree db 5", func(t *testing.T) { + changes, err := objTree.ChangesAfterCommonSnapshot([]string{"3", "0"}, []string{"5"}) + require.NoError(t, err, "changes after common snapshot should be without error") + + changeIds := make(map[string]struct{}) + for _, ch := range changes { + changeIds[ch.Id] = struct{}{} + } + + for _, id := range []string{"2", "3", "4", "6"} { + _, ok := changeIds[id] + assert.Equal(t, true, ok) + } + for _, id := range []string{"0", "1", "5"} { + _, ok := changeIds[id] + assert.Equal(t, false, ok) + } + }) }) t.Run("add new changes related to previous snapshot", func(t *testing.T) { diff --git a/pkg/acl/tree/rawloader.go b/pkg/acl/tree/rawloader.go index 0c1956a9..43854a29 100644 --- a/pkg/acl/tree/rawloader.go +++ b/pkg/acl/tree/rawloader.go @@ -19,6 +19,7 @@ type rawChangeLoader struct { type rawCacheEntry struct { change *Change rawChange *aclpb.RawChange + position int } func newRawChangeLoader(treeStorage storage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader { @@ -107,50 +108,56 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi r.cache = nil }() - // updating map - bufPosMap := make(map[string]int) - for _, breakpoint := range breakpoints { - bufPosMap[breakpoint] = -1 + existingBreakpoints := make([]string, 0, len(breakpoints)) + for _, b := range breakpoints { + entry, err := r.loadEntry(b) + if err != nil { + continue + } + entry.position = -1 + r.cache[b] = entry + existingBreakpoints = append(existingBreakpoints, b) } - bufPosMap[commonSnapshot] = -1 + r.cache[commonSnapshot] = rawCacheEntry{position: -1} dfs := func( commonSnapshot string, heads []string, startCounter int, shouldVisit func(counter int, mapExists bool) bool, - visit func(prevCounter int, entry rawCacheEntry) int) bool { + visit func(entry rawCacheEntry) rawCacheEntry) bool { // resetting stack r.idStack = r.idStack[:0] r.idStack = append(r.idStack, heads...) commonSnapshotVisited := false + var err error for len(r.idStack) > 0 { id := r.idStack[len(r.idStack)-1] r.idStack = r.idStack[:len(r.idStack)-1] - cnt, exists := bufPosMap[id] - if !shouldVisit(cnt, exists) { + entry, exists := r.cache[id] + if !shouldVisit(entry.position, exists) { continue } - - // TODO: add proper error handling, we must ignore errors on missing breakpoints though - entry, err := r.loadEntry(id) - if err != nil { - continue + if !exists { + entry, err = r.loadEntry(id) + if err != nil { + continue + } } // setting the counter when we visit - bufPosMap[id] = visit(cnt, entry) + r.cache[id] = visit(entry) for _, prev := range entry.change.PreviousIds { if prev == commonSnapshot { commonSnapshotVisited = true break } - cnt, exists = bufPosMap[prev] - if !shouldVisit(cnt, exists) { + entry, exists = r.cache[prev] + if !shouldVisit(entry.position, exists) { continue } r.idStack = append(r.idStack, prev) @@ -167,9 +174,10 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi func(counter int, mapExists bool) bool { return !mapExists }, - func(_ int, entry rawCacheEntry) int { + func(entry rawCacheEntry) rawCacheEntry { buffer = append(buffer, entry.rawChange) - return len(buffer) - 1 + entry.position = len(buffer) - 1 + return entry }) // checking if we stopped at breakpoints @@ -188,15 +196,16 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi } // marking all visited as nil - dfs(commonSnapshot, breakpoints, len(buffer), + dfs(commonSnapshot, existingBreakpoints, len(buffer), func(counter int, mapExists bool) bool { return !mapExists || counter < len(buffer) }, - func(discardedPosition int, entry rawCacheEntry) int { - if discardedPosition != -1 { - buffer[discardedPosition] = nil + func(entry rawCacheEntry) rawCacheEntry { + if entry.position != -1 { + buffer[entry.position] = nil } - return len(buffer) + 1 + entry.position = len(buffer) + 1 + return entry }) // discarding visited @@ -208,11 +217,6 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi } func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error) { - var ok bool - if entry, ok = r.cache[id]; ok { - return - } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) defer cancel() @@ -229,7 +233,6 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error) change: change, rawChange: rawChange, } - r.cache[id] = entry return } diff --git a/pkg/acl/tree/util.go b/pkg/acl/tree/util.go index 37a08d31..0e6cc7cd 100644 --- a/pkg/acl/tree/util.go +++ b/pkg/acl/tree/util.go @@ -21,6 +21,8 @@ OuterLoop: if ourPath[i] == theirPath[j] { i-- j-- + } else { + break } } return ourPath[i+1], nil