Fix tree append logic, validtate only new changes and check unattached changes which were added at this round

This commit is contained in:
mcrakhman 2022-09-11 18:45:47 +02:00
parent 74ebcc5616
commit 639cc302d2
No known key found for this signature in database
GPG Key ID: DED12CFEF5B8396B
8 changed files with 225 additions and 138 deletions

View File

@ -28,13 +28,14 @@ type ChangeBuilder interface {
ConvertFromRaw(rawChange *aclpb.RawChange) (ch *Change, err error) ConvertFromRaw(rawChange *aclpb.RawChange) (ch *Change, err error)
ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch *Change, err error) ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch *Change, err error)
BuildContent(payload BuilderContent) (ch *Change, raw *aclpb.RawChange, err error) BuildContent(payload BuilderContent) (ch *Change, raw *aclpb.RawChange, err error)
BuildRaw(ch *Change) (*aclpb.RawChange, error)
} }
type changeBuilder struct { type changeBuilder struct {
keys *common.Keychain keys *common.Keychain
} }
func newChangeBuilder(keys *common.Keychain) *changeBuilder { func newChangeBuilder(keys *common.Keychain) ChangeBuilder {
return &changeBuilder{keys: keys} return &changeBuilder{keys: keys}
} }
@ -125,3 +126,18 @@ func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, raw *a
} }
return return
} }
func (c *changeBuilder) BuildRaw(ch *Change) (raw *aclpb.RawChange, err error) {
var marshalled []byte
marshalled, err = ch.Content.Marshal()
if err != nil {
return
}
raw = &aclpb.RawChange{
Payload: marshalled,
Signature: ch.Signature(),
Id: ch.Id,
}
return
}

View File

@ -7,8 +7,10 @@ import (
) )
type ObjectTreeValidator interface { type ObjectTreeValidator interface {
// ValidateTree should always be entered while holding a read lock on ACLList // ValidateFullTree should always be entered while holding a read lock on ACLList
ValidateTree(tree *Tree, aclList list.ACLList) error ValidateFullTree(tree *Tree, aclList list.ACLList) error
// ValidateNewChanges should always be entered while holding a read lock on ACLList
ValidateNewChanges(tree *Tree, aclList list.ACLList, newChanges []*Change) error
} }
type objectTreeValidator struct{} type objectTreeValidator struct{}
@ -17,42 +19,55 @@ func newTreeValidator() ObjectTreeValidator {
return &objectTreeValidator{} return &objectTreeValidator{}
} }
func (v *objectTreeValidator) ValidateTree(tree *Tree, aclList list.ACLList) (err error) { func (v *objectTreeValidator) ValidateFullTree(tree *Tree, aclList list.ACLList) (err error) {
tree.Iterate(tree.RootId(), func(c *Change) (isContinue bool) {
err = v.validateChange(tree, aclList, c)
return err == nil
})
return err
}
func (v *objectTreeValidator) ValidateNewChanges(tree *Tree, aclList list.ACLList, newChanges []*Change) (err error) {
for _, c := range newChanges {
err = v.validateChange(tree, aclList, c)
if err != nil {
return
}
}
return
}
func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.ACLList, c *Change) (err error) {
var ( var (
perm list.UserPermissionPair perm list.UserPermissionPair
state = aclList.ACLState() state = aclList.ACLState()
) )
// checking if the user could write
perm, err = state.PermissionsAtRecord(c.Content.AclHeadId, c.Content.Identity)
if err != nil {
return
}
tree.Iterate(tree.RootId(), func(c *Change) (isContinue bool) { if perm.Permission != aclpb.ACLChange_Writer && perm.Permission != aclpb.ACLChange_Admin {
// checking if the user could write err = list.ErrInsufficientPermissions
perm, err = state.PermissionsAtRecord(c.Content.AclHeadId, c.Content.Identity) return
}
// checking if the change refers to later acl heads than its previous ids
for _, id := range c.PreviousIds {
prevChange := tree.attached[id]
if prevChange.Content.AclHeadId == c.Content.AclHeadId {
continue
}
var after bool
after, err = aclList.IsAfter(c.Content.AclHeadId, prevChange.Content.AclHeadId)
if err != nil { if err != nil {
return false return
} }
if !after {
if perm.Permission != aclpb.ACLChange_Writer && perm.Permission != aclpb.ACLChange_Admin { err = fmt.Errorf("current acl head id (%s) should be after each of the previous ones (%s)", c.Content.AclHeadId, prevChange.Content.AclHeadId)
err = list.ErrInsufficientPermissions return
return false
} }
}
// checking if the change refers to later acl heads than its previous ids return
for _, id := range c.PreviousIds {
prevChange := tree.attached[id]
if prevChange.Content.AclHeadId == c.Content.AclHeadId {
continue
}
var after bool
after, err = aclList.IsAfter(c.Content.AclHeadId, prevChange.Content.AclHeadId)
if err != nil {
return false
}
if !after {
err = fmt.Errorf("current acl head id (%s) should be after each of the previous ones (%s)", c.Content.AclHeadId, prevChange.Content.AclHeadId)
return false
}
}
return true
})
return err
} }

View File

@ -31,18 +31,12 @@ var (
type AddResultSummary int type AddResultSummary int
const (
AddResultSummaryNothing AddResultSummary = iota
AddResultSummaryAppend
AddResultSummaryRebuild
)
type AddResult struct { type AddResult struct {
OldHeads []string OldHeads []string
Heads []string Heads []string
Added []*aclpb.RawChange Added []*aclpb.RawChange
Summary AddResultSummary Mode Mode
} }
type ChangeIterateFunc = func(change *Change) bool type ChangeIterateFunc = func(change *Change) bool
@ -198,7 +192,9 @@ func (ot *objectTree) rebuildFromStorage(newChanges []*Change) (err error) {
// but obviously they are not roots, because of the way how we construct the tree // but obviously they are not roots, because of the way how we construct the tree
ot.tree.clearPossibleRoots() ot.tree.clearPossibleRoots()
return ot.validateTree() // it is a good question whether we need to validate everything
// because maybe we can trust the stuff that is already in the storage
return ot.validateTree(nil)
} }
func (ot *objectTree) ID() string { func (ot *objectTree) ID() string {
@ -324,7 +320,11 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*aclpb.Ra
// filtering changes, verifying and unmarshalling them // filtering changes, verifying and unmarshalling them
for idx, ch := range rawChanges { for idx, ch := range rawChanges {
if ot.HasChange(ch.Id) { // not unmarshalling the changes if they were already added either as unattached or attached
if _, exists := ot.tree.attached[ch.Id]; exists {
continue
}
if _, exists := ot.tree.unAttached[ch.Id]; exists {
continue continue
} }
@ -346,29 +346,54 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*aclpb.Ra
addResult = AddResult{ addResult = AddResult{
OldHeads: prevHeadsCopy, OldHeads: prevHeadsCopy,
Heads: prevHeadsCopy, Heads: prevHeadsCopy,
Summary: AddResultSummaryNothing, Mode: Nothing,
} }
return return
} }
// returns changes that we added to the tree // returns changes that we added to the tree as attached this round
getAddedChanges := func() []*aclpb.RawChange { // they can include not only the changes that were added now,
var added []*aclpb.RawChange // but also the changes that were previously in the tree
getAddedChanges := func(toConvert []*Change) (added []*aclpb.RawChange, err error) {
alreadyConverted := make(map[*Change]struct{})
// first we see if we have already unmarshalled those changes
for _, idx := range ot.notSeenIdxBuf { for _, idx := range ot.notSeenIdxBuf {
rawChange := rawChanges[idx] rawChange := rawChanges[idx]
if _, exists := ot.tree.attached[rawChange.Id]; exists { if ch, exists := ot.tree.attached[rawChange.Id]; exists {
if len(toConvert) != 0 {
alreadyConverted[ch] = struct{}{}
}
added = append(added, rawChange) added = append(added, rawChange)
} }
} }
return added // this will happen in case we called rebuild from storage
// or if all the changes that we added were contained in current add request
// (this what would happen in most cases)
if len(toConvert) == 0 || len(added) == len(toConvert) {
return
}
// but in some cases it may happen that the changes that were added this round
// were contained in unattached from previous requests
for _, ch := range toConvert {
// if we got some changes that we need to convert to raw
if _, exists := alreadyConverted[ch]; !exists {
var raw *aclpb.RawChange
raw, err = ot.changeBuilder.BuildRaw(ch)
if err != nil {
return
}
added = append(added, raw)
}
}
return
} }
rollback := func() { rollback := func(changes []*Change) {
for _, ch := range ot.tmpChangesBuf { for _, ch := range changes {
if _, exists := ot.tree.attached[ch.Id]; exists { if _, exists := ot.tree.attached[ch.Id]; exists {
delete(ot.tree.attached, ch.Id) delete(ot.tree.attached, ch.Id)
} else if _, exists := ot.tree.unAttached[ch.Id]; exists {
delete(ot.tree.unAttached, ch.Id)
} }
} }
} }
@ -396,43 +421,56 @@ func (ot *objectTree) addRawChanges(ctx context.Context, rawChanges ...*aclpb.Ra
ot.rebuildFromStorage(nil) ot.rebuildFromStorage(nil)
return return
} }
var added []*aclpb.RawChange
added, err = getAddedChanges(nil)
// we shouldn't get any error in this case
if err != nil {
panic(err)
}
addResult = AddResult{ addResult = AddResult{
OldHeads: prevHeadsCopy, OldHeads: prevHeadsCopy,
Heads: headsCopy(), Heads: headsCopy(),
Added: getAddedChanges(), Added: added,
Summary: AddResultSummaryRebuild, Mode: Rebuild,
} }
return return
} }
} }
// normal mode of operation, where we don't need to rebuild from database // normal mode of operation, where we don't need to rebuild from database
mode = ot.tree.Add(ot.tmpChangesBuf...) mode, treeChangesAdded := ot.tree.Add(ot.tmpChangesBuf...)
switch mode { switch mode {
case Nothing: case Nothing:
addResult = AddResult{ addResult = AddResult{
OldHeads: prevHeadsCopy, OldHeads: prevHeadsCopy,
Heads: prevHeadsCopy, Heads: prevHeadsCopy,
Summary: AddResultSummaryNothing, Mode: mode,
} }
return return
default: default:
// just rebuilding the state from start without reloading everything from tree storage // we need to validate only newly added changes
// as an optimization we could've started from current heads, but I didn't implement that err = ot.validateTree(treeChangesAdded)
err = ot.validateTree()
if err != nil { if err != nil {
rollback() rollback(treeChangesAdded)
err = ErrHasInvalidChanges err = ErrHasInvalidChanges
return return
} }
var added []*aclpb.RawChange
added, err = getAddedChanges(treeChangesAdded)
if err != nil {
// that means that some unattached changes were somehow corrupted in memory
// this shouldn't happen but if that happens, then rebuilding from storage
ot.rebuildFromStorage(nil)
return
}
addResult = AddResult{ addResult = AddResult{
OldHeads: prevHeadsCopy, OldHeads: prevHeadsCopy,
Heads: headsCopy(), Heads: headsCopy(),
Added: getAddedChanges(), Added: added,
Summary: AddResultSummaryAppend, Mode: mode,
} }
} }
return return
@ -478,8 +516,7 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
func (ot *objectTree) HasChange(s string) bool { func (ot *objectTree) HasChange(s string) bool {
_, attachedExists := ot.tree.attached[s] _, attachedExists := ot.tree.attached[s]
_, unattachedExists := ot.tree.unAttached[s] return attachedExists
return attachedExists || unattachedExists
} }
func (ot *objectTree) Heads() []string { func (ot *objectTree) Heads() []string {
@ -552,7 +589,7 @@ func (ot *objectTree) snapshotPathIsActual() bool {
return len(ot.snapshotPath) != 0 && ot.snapshotPath[0] == ot.tree.RootId() return len(ot.snapshotPath) != 0 && ot.snapshotPath[0] == ot.tree.RootId()
} }
func (ot *objectTree) validateTree() error { func (ot *objectTree) validateTree(newChanges []*Change) error {
ot.aclList.RLock() ot.aclList.RLock()
defer ot.aclList.RUnlock() defer ot.aclList.RUnlock()
state := ot.aclList.ACLState() state := ot.aclList.ACLState()
@ -563,8 +600,11 @@ func (ot *objectTree) validateTree() error {
ot.keys[key] = value ot.keys[key] = value
} }
} }
if len(newChanges) == 0 {
return ot.validator.ValidateFullTree(ot.tree, ot.aclList)
}
return ot.validator.ValidateTree(ot.tree, ot.aclList) return ot.validator.ValidateNewChanges(ot.tree, ot.aclList, newChanges)
} }
func (ot *objectTree) DebugDump() (string, error) { func (ot *objectTree) DebugDump() (string, error) {

View File

@ -7,7 +7,6 @@ import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/testutils/acllistbuilder" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/testutils/acllistbuilder"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing" "testing"
@ -43,29 +42,33 @@ func (c *mockChangeCreator) createNewTreeStorage(treeId, aclListId, aclHeadId, f
return treeStorage return treeStorage
} }
type mockChangeBuilder struct{} type mockChangeBuilder struct {
originalBuilder ChangeBuilder
}
func (c *mockChangeBuilder) ConvertFromRaw(rawChange *aclpb.RawChange) (ch *Change, err error) { func (c *mockChangeBuilder) ConvertFromRaw(rawChange *aclpb.RawChange) (ch *Change, err error) {
unmarshalled := &aclpb.Change{} return c.originalBuilder.ConvertFromRaw(rawChange)
err = proto.Unmarshal(rawChange.Payload, unmarshalled)
if err != nil {
return nil, err
}
ch = NewChange(rawChange.Id, unmarshalled, rawChange.Signature)
return
} }
func (c *mockChangeBuilder) ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch *Change, err error) { func (c *mockChangeBuilder) ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch *Change, err error) {
return c.ConvertFromRaw(rawChange) return c.originalBuilder.ConvertFromRaw(rawChange)
} }
func (c *mockChangeBuilder) BuildContent(payload BuilderContent) (ch *Change, raw *aclpb.RawChange, err error) { func (c *mockChangeBuilder) BuildContent(payload BuilderContent) (ch *Change, raw *aclpb.RawChange, err error) {
panic("implement me") panic("implement me")
} }
func (c *mockChangeBuilder) BuildRaw(ch *Change) (raw *aclpb.RawChange, err error) {
return c.originalBuilder.BuildRaw(ch)
}
type mockChangeValidator struct{} type mockChangeValidator struct{}
func (m *mockChangeValidator) ValidateTree(tree *Tree, aclList list.ACLList) error { func (m *mockChangeValidator) ValidateNewChanges(tree *Tree, aclList list.ACLList, newChanges []*Change) error {
return nil
}
func (m *mockChangeValidator) ValidateFullTree(tree *Tree, aclList list.ACLList) error {
return nil return nil
} }
@ -90,7 +93,9 @@ func prepareACLList(t *testing.T) list.ACLList {
func prepareTreeContext(t *testing.T, aclList list.ACLList) testTreeContext { func prepareTreeContext(t *testing.T, aclList list.ACLList) testTreeContext {
changeCreator := &mockChangeCreator{} changeCreator := &mockChangeCreator{}
treeStorage := changeCreator.createNewTreeStorage("treeId", aclList.ID(), aclList.Head().Id, "0") treeStorage := changeCreator.createNewTreeStorage("treeId", aclList.ID(), aclList.Head().Id, "0")
changeBuilder := &mockChangeBuilder{} changeBuilder := &mockChangeBuilder{
originalBuilder: newChangeBuilder(nil),
}
deps := objectTreeDeps{ deps := objectTreeDeps{
changeBuilder: changeBuilder, changeBuilder: changeBuilder,
treeBuilder: newTreeBuilder(treeStorage, changeBuilder), treeBuilder: newTreeBuilder(treeStorage, changeBuilder),
@ -142,7 +147,7 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, []string{"0"}, res.OldHeads) assert.Equal(t, []string{"0"}, res.OldHeads)
assert.Equal(t, []string{"2"}, res.Heads) assert.Equal(t, []string{"2"}, res.Heads)
assert.Equal(t, len(rawChanges), len(res.Added)) assert.Equal(t, len(rawChanges), len(res.Added))
assert.Equal(t, AddResultSummaryAppend, res.Summary) assert.Equal(t, Append, res.Mode)
// check tree heads // check tree heads
assert.Equal(t, []string{"2"}, objTree.Heads()) assert.Equal(t, []string{"2"}, objTree.Heads())
@ -202,7 +207,7 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, []string{"0"}, res.OldHeads) assert.Equal(t, []string{"0"}, res.OldHeads)
assert.Equal(t, []string{"0"}, res.Heads) assert.Equal(t, []string{"0"}, res.Heads)
assert.Equal(t, 0, len(res.Added)) assert.Equal(t, 0, len(res.Added))
assert.Equal(t, AddResultSummaryNothing, res.Summary) assert.Equal(t, Nothing, res.Mode)
// check tree heads // check tree heads
assert.Equal(t, []string{"0"}, objTree.Heads()) assert.Equal(t, []string{"0"}, objTree.Heads())
@ -227,7 +232,7 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, []string{"0"}, res.OldHeads) assert.Equal(t, []string{"0"}, res.OldHeads)
assert.Equal(t, []string{"4"}, res.Heads) assert.Equal(t, []string{"4"}, res.Heads)
assert.Equal(t, len(rawChanges), len(res.Added)) assert.Equal(t, len(rawChanges), len(res.Added))
assert.Equal(t, AddResultSummaryAppend, res.Summary) assert.Equal(t, Append, res.Mode)
// check tree heads // check tree heads
assert.Equal(t, []string{"4"}, objTree.Heads()) assert.Equal(t, []string{"4"}, objTree.Heads())
@ -448,7 +453,7 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, []string{"3"}, res.OldHeads) assert.Equal(t, []string{"3"}, res.OldHeads)
assert.Equal(t, []string{"6"}, res.Heads) assert.Equal(t, []string{"6"}, res.Heads)
assert.Equal(t, len(rawChanges), len(res.Added)) assert.Equal(t, len(rawChanges), len(res.Added))
assert.Equal(t, AddResultSummaryRebuild, res.Summary) assert.Equal(t, Rebuild, res.Mode)
// check tree heads // check tree heads
assert.Equal(t, []string{"6"}, objTree.Heads()) assert.Equal(t, []string{"6"}, objTree.Heads())

View File

@ -37,17 +37,11 @@ func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*aclpb.
convert := func(chs []*Change) (rawChanges []*aclpb.RawChange, err error) { convert := func(chs []*Change) (rawChanges []*aclpb.RawChange, err error) {
for _, ch := range chs { for _, ch := range chs {
var marshalled []byte var raw *aclpb.RawChange
marshalled, err = ch.Content.Marshal() raw, err = r.changeBuilder.BuildRaw(ch)
if err != nil { if err != nil {
return return
} }
raw := &aclpb.RawChange{
Payload: marshalled,
Signature: ch.Signature(),
Id: ch.Id,
}
rawChanges = append(rawChanges, raw) rawChanges = append(rawChanges, raw)
} }
return return

View File

@ -16,11 +16,12 @@ const (
) )
type Tree struct { type Tree struct {
root *Change root *Change
headIds []string headIds []string
metaHeadIds []string lastIteratedHeadId string
attached map[string]*Change metaHeadIds []string
unAttached map[string]*Change attached map[string]*Change
unAttached map[string]*Change
// missed id -> list of dependency ids // missed id -> list of dependency ids
waitList map[string][]string waitList map[string][]string
invalidChanges map[string]struct{} invalidChanges map[string]struct{}
@ -29,6 +30,7 @@ type Tree struct {
// bufs // bufs
visitedBuf []*Change visitedBuf []*Change
stackBuf []*Change stackBuf []*Change
addedBuf []*Change
duplicateEvents int duplicateEvents int
} }
@ -44,7 +46,8 @@ func (t *Tree) Root() *Change {
return t.root return t.root
} }
func (t *Tree) AddFast(changes ...*Change) { func (t *Tree) AddFast(changes ...*Change) []*Change {
t.addedBuf = t.addedBuf[:0]
for _, c := range changes { for _, c := range changes {
// ignore existing // ignore existing
if _, ok := t.attached[c.Id]; ok { if _, ok := t.attached[c.Id]; ok {
@ -55,6 +58,7 @@ func (t *Tree) AddFast(changes ...*Change) {
t.add(c) t.add(c)
} }
t.updateHeads() t.updateHeads()
return t.addedBuf
} }
func (t *Tree) AddMergedHead(c *Change) error { func (t *Tree) AddMergedHead(c *Change) error {
@ -81,11 +85,12 @@ func (t *Tree) AddMergedHead(c *Change) error {
return nil return nil
} }
func (t *Tree) Add(changes ...*Change) (mode Mode) { func (t *Tree) Add(changes ...*Change) (mode Mode, added []*Change) {
t.addedBuf = t.addedBuf[:0]
var ( var (
beforeHeadIds = t.headIds // this is previous head id which should have been iterated last
attached bool lastIteratedHeadId = t.lastIteratedHeadId
empty = t.Len() == 0 empty = t.Len() == 0
) )
for _, c := range changes { for _, c := range changes {
// ignore existing // ignore existing
@ -94,40 +99,43 @@ func (t *Tree) Add(changes ...*Change) (mode Mode) {
} else if _, ok := t.unAttached[c.Id]; ok { } else if _, ok := t.unAttached[c.Id]; ok {
continue continue
} }
if t.add(c) { t.add(c)
attached = true
}
} }
if !attached { if len(t.addedBuf) == 0 {
return Nothing mode = Nothing
return
} }
t.updateHeads() t.updateHeads()
added = t.addedBuf
if empty { if empty {
return Rebuild mode = Rebuild
return
} }
// beforeHeadsIds is definitely not empty, because the tree is not empty // mode is Append for cases when we can safely start iterating from lastIteratedHeadId to build state
stack := make([]*Change, len(beforeHeadIds), len(beforeHeadIds)) // the idea here is that if all new changes have lastIteratedHeadId as previous,
for i, hid := range beforeHeadIds { // then according to topological sorting order they will be looked at later than lastIteratedHeadId
stack[i] = t.attached[hid] //
} // one important consideration is that if some unattached changes were added to the tree
// as a result of adding new changes, then each of these unattached changes
// mode is Append for cases when we can safely start iterating // will also have at least one of new changes as ancestor
// from old heads to append the state // and that means they will also be iterated later than lastIteratedHeadId
mode = Append mode = Append
t.dfsNext(stack, t.dfsNext([]*Change{t.attached[lastIteratedHeadId]},
func(_ *Change) (isContinue bool) { func(_ *Change) (isContinue bool) {
return true return true
}, },
func(_ []*Change) { func(_ []*Change) {
// checking if some new changes were not visited // checking if some new changes were not visited
for _, ch := range changes { for _, ch := range changes {
// if the change was not added, then skipping // if the change was not added to the tree, then skipping
if _, ok := t.attached[ch.Id]; !ok { if _, ok := t.attached[ch.Id]; !ok {
continue continue
} }
// if some new change was not visited, // if some new change was not visited,
// then we can't start from old heads, we need to start from root, so Rebuild // then we can't start from lastIteratedHeadId,
// we need to start from root, so Rebuild
if !ch.visited { if !ch.visited {
mode = Rebuild mode = Rebuild
break break
@ -135,7 +143,7 @@ func (t *Tree) Add(changes ...*Change) (mode Mode) {
} }
}) })
return mode return
} }
// RemoveInvalidChange removes all the changes that are descendants of id // RemoveInvalidChange removes all the changes that are descendants of id
@ -190,6 +198,7 @@ func (t *Tree) add(c *Change) (attached bool) {
if t.root == nil { // first element if t.root == nil { // first element
t.root = c t.root = c
t.lastIteratedHeadId = t.root.Id
t.attached = map[string]*Change{ t.attached = map[string]*Change{
c.Id: c, c.Id: c,
} }
@ -197,6 +206,7 @@ func (t *Tree) add(c *Change) (attached bool) {
t.waitList = make(map[string][]string) t.waitList = make(map[string][]string)
t.invalidChanges = make(map[string]struct{}) t.invalidChanges = make(map[string]struct{})
t.possibleRoots = make([]*Change, 0, 10) t.possibleRoots = make([]*Change, 0, 10)
t.addedBuf = append(t.addedBuf, c)
return true return true
} }
if len(c.PreviousIds) > 1 { if len(c.PreviousIds) > 1 {
@ -238,6 +248,7 @@ func (t *Tree) canAttach(c *Change) (attach bool) {
func (t *Tree) attach(c *Change, newEl bool) { func (t *Tree) attach(c *Change, newEl bool) {
t.attached[c.Id] = c t.attached[c.Id] = c
t.addedBuf = append(t.addedBuf, c)
if !newEl { if !newEl {
delete(t.unAttached, c.Id) delete(t.unAttached, c.Id)
} }
@ -371,16 +382,16 @@ func (t *Tree) dfsNext(stack []*Change, visit func(ch *Change) (isContinue bool)
func (t *Tree) updateHeads() { func (t *Tree) updateHeads() {
var newHeadIds []string var newHeadIds []string
t.dfsNext( t.iterate(t.root, func(c *Change) (isContinue bool) {
[]*Change{t.root}, if len(c.Next) == 0 {
func(ch *Change) (isContinue bool) { newHeadIds = append(newHeadIds, c.Id)
if len(ch.Next) == 0 { }
newHeadIds = append(newHeadIds, ch.Id) return true
} })
return true
},
nil)
t.headIds = newHeadIds t.headIds = newHeadIds
// the lastIteratedHeadId is the id of the head which was iterated last according to the order
t.lastIteratedHeadId = newHeadIds[len(newHeadIds)-1]
// TODO: check why do we need sorting here
sort.Strings(t.headIds) sort.Strings(t.headIds)
} }

View File

@ -29,19 +29,22 @@ func newSnapshot(id, snapshotId string, prevIds ...string) *Change {
func TestTree_Add(t *testing.T) { func TestTree_Add(t *testing.T) {
t.Run("add first el", func(t *testing.T) { t.Run("add first el", func(t *testing.T) {
tr := new(Tree) tr := new(Tree)
assert.Equal(t, Rebuild, tr.Add(newSnapshot("root", ""))) res, _ := tr.Add(newSnapshot("root", ""))
assert.Equal(t, Rebuild, res)
assert.Equal(t, tr.root.Id, "root") assert.Equal(t, tr.root.Id, "root")
assert.Equal(t, []string{"root"}, tr.Heads()) assert.Equal(t, []string{"root"}, tr.Heads())
}) })
t.Run("linear add", func(t *testing.T) { t.Run("linear add", func(t *testing.T) {
tr := new(Tree) tr := new(Tree)
assert.Equal(t, Rebuild, tr.Add( res, _ := tr.Add(
newSnapshot("root", ""), newSnapshot("root", ""),
newChange("one", "root", "root"), newChange("one", "root", "root"),
newChange("two", "root", "one"), newChange("two", "root", "one"),
)) )
assert.Equal(t, Rebuild, res)
assert.Equal(t, []string{"two"}, tr.Heads()) assert.Equal(t, []string{"two"}, tr.Heads())
assert.Equal(t, Append, tr.Add(newChange("three", "root", "two"))) res, _ = tr.Add(newChange("three", "root", "two"))
assert.Equal(t, Append, res)
el := tr.root el := tr.root
var ids []string var ids []string
for el != nil { for el != nil {
@ -57,17 +60,19 @@ func TestTree_Add(t *testing.T) {
}) })
t.Run("branch", func(t *testing.T) { t.Run("branch", func(t *testing.T) {
tr := new(Tree) tr := new(Tree)
assert.Equal(t, Rebuild, tr.Add( res, _ := tr.Add(
newSnapshot("root", ""), newSnapshot("root", ""),
newChange("1", "root", "root"), newChange("1", "root", "root"),
newChange("2", "root", "1"), newChange("2", "root", "1"),
)) )
assert.Equal(t, Rebuild, res)
assert.Equal(t, []string{"2"}, tr.Heads()) assert.Equal(t, []string{"2"}, tr.Heads())
assert.Equal(t, Rebuild, tr.Add( res, _ = tr.Add(
newChange("1.2", "root", "1.1"), newChange("1.2", "root", "1.1"),
newChange("1.3", "root", "1.2"), newChange("1.3", "root", "1.2"),
newChange("1.1", "root", "1"), newChange("1.1", "root", "1"),
)) )
assert.Equal(t, Rebuild, res)
assert.Len(t, tr.attached["1"].Next, 2) assert.Len(t, tr.attached["1"].Next, 2)
assert.Len(t, tr.unAttached, 0) assert.Len(t, tr.unAttached, 0)
assert.Len(t, tr.attached, 6) assert.Len(t, tr.attached, 6)
@ -75,7 +80,7 @@ func TestTree_Add(t *testing.T) {
}) })
t.Run("branch union", func(t *testing.T) { t.Run("branch union", func(t *testing.T) {
tr := new(Tree) tr := new(Tree)
assert.Equal(t, Rebuild, tr.Add( res, _ := tr.Add(
newSnapshot("root", ""), newSnapshot("root", ""),
newChange("1", "root", "root"), newChange("1", "root", "root"),
newChange("2", "root", "1"), newChange("2", "root", "1"),
@ -84,7 +89,8 @@ func TestTree_Add(t *testing.T) {
newChange("1.1", "root", "1"), newChange("1.1", "root", "1"),
newChange("3", "root", "2", "1.3"), newChange("3", "root", "2", "1.3"),
newChange("4", "root", "3"), newChange("4", "root", "3"),
)) )
assert.Equal(t, Rebuild, res)
assert.Len(t, tr.unAttached, 0) assert.Len(t, tr.unAttached, 0)
assert.Len(t, tr.attached, 8) assert.Len(t, tr.attached, 8)
assert.Equal(t, []string{"4"}, tr.Heads()) assert.Equal(t, []string{"4"}, tr.Heads())

View File

@ -51,7 +51,7 @@ func (t *Tree) makeRootAndRemove(start *Change) {
}, },
func(changes []*Change) { func(changes []*Change) {
for _, ch := range changes { for _, ch := range changes {
delete(t.unAttached, ch.Id) delete(t.attached, ch.Id)
} }
}, },
) )