Further refactoring of ACL logic
This commit is contained in:
parent
697fed7f84
commit
9278cd14d7
@ -89,7 +89,7 @@ func (st *ACLState) ApplyChange(changeId string, change *pb.ACLChange) error {
|
||||
|
||||
for _, ch := range change.GetAclData().GetAclContent() {
|
||||
if err := st.applyChange(changeId, ch); err != nil {
|
||||
log.Infof("error while applying changes: %v; ignore", err)
|
||||
//log.Infof("error while applying changes: %v; ignore", err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +82,7 @@ func (sb *ACLStateBuilder) BuildBefore(beforeId string) (*ACLState, bool, error)
|
||||
if err == nil {
|
||||
startChange = c
|
||||
} else if err != ErrDocumentForbidden {
|
||||
log.Errorf("marking change %s as invalid: %v", c.Id, err)
|
||||
//log.Errorf("marking change %s as invalid: %v", c.Id, err)
|
||||
sb.tree.RemoveInvalidChange(c.Id)
|
||||
}
|
||||
}()
|
||||
|
||||
@ -5,10 +5,8 @@ import (
|
||||
"fmt"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/data/pb"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/data/threadmodels"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/slice"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/textileio/go-threads/core/thread"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -41,7 +39,7 @@ func (tb *ACLTreeBuilder) loadChange(id string) (ch *Change, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
record, err := tb.thread.GetChange(ctx, id)
|
||||
change, err := tb.thread.GetChange(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -49,11 +47,11 @@ func (tb *ACLTreeBuilder) loadChange(id string) (ch *Change, err error) {
|
||||
aclChange := new(pb.ACLChange)
|
||||
|
||||
// TODO: think what should we do with such cases, because this can be used by attacker to break our tree
|
||||
if err = proto.Unmarshal(record.Signed.Payload, aclChange); err != nil {
|
||||
if err = proto.Unmarshal(change.Payload, aclChange); err != nil {
|
||||
return
|
||||
}
|
||||
var verified bool
|
||||
verified, err = tb.verify(aclChange.Identity, record.Signed.Payload, record.Signed.Signature)
|
||||
verified, err = tb.verify(aclChange.Identity, change.Payload, change.Signature)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -80,44 +78,10 @@ func (tb *ACLTreeBuilder) verify(identity string, payload, signature []byte) (is
|
||||
return identityKey.Verify(payload, signature)
|
||||
}
|
||||
|
||||
func (tb *ACLTreeBuilder) getLogs() (logs []threadmodels.ThreadLog, err error) {
|
||||
// TODO: Add beforeId building logic
|
||||
logs, err = tb.thread.GetLogs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetLogs error: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("build tree: logs: %v", logs)
|
||||
if len(logs) == 0 || len(logs) == 1 && len(logs[0].Head) <= 1 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
var nonEmptyLogs = logs[:0]
|
||||
for _, l := range logs {
|
||||
if len(l.Head) == 0 {
|
||||
continue
|
||||
}
|
||||
if ch, err := tb.loadChange(l.Head); err != nil {
|
||||
log.Errorf("loading head %s of the log %s failed: %v", l.Head, l.ID, err)
|
||||
} else {
|
||||
tb.logHeads[l.ID] = ch
|
||||
}
|
||||
nonEmptyLogs = append(nonEmptyLogs, l)
|
||||
}
|
||||
return nonEmptyLogs, nil
|
||||
}
|
||||
|
||||
func (tb *ACLTreeBuilder) Build() (*Tree, error) {
|
||||
logs, err := tb.getLogs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
heads := tb.thread.Heads()
|
||||
|
||||
heads, err := tb.getACLHeads(logs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get acl heads error: %v", err)
|
||||
}
|
||||
|
||||
if err = tb.buildTreeFromStart(heads); err != nil {
|
||||
if err := tb.buildTreeFromStart(heads); err != nil {
|
||||
return nil, fmt.Errorf("buildTree error: %v", err)
|
||||
}
|
||||
tb.cache = nil
|
||||
@ -169,48 +133,6 @@ func (tb *ACLTreeBuilder) dfsFromStart(stack []string) (buf []*Change, possibleR
|
||||
return buf, possibleRoots, nil
|
||||
}
|
||||
|
||||
func (tb *ACLTreeBuilder) getPrecedingACLHeads(head string) ([]string, error) {
|
||||
headChange, err := tb.loadChange(head)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if headChange.Content.GetAclData() != nil {
|
||||
return []string{head}, nil
|
||||
} else {
|
||||
return headChange.PreviousIds, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (tb *ACLTreeBuilder) getACLHeads(logs []threadmodels.ThreadLog) (aclTreeHeads []string, err error) {
|
||||
sort.Slice(logs, func(i, j int) bool {
|
||||
return logs[i].ID < logs[j].ID
|
||||
})
|
||||
|
||||
// get acl tree heads from log heads
|
||||
for _, l := range logs {
|
||||
if slice.FindPos(aclTreeHeads, l.Head) != -1 { // do not scan known heads
|
||||
continue
|
||||
}
|
||||
precedingHeads, err := tb.getPrecedingACLHeads(l.Head)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, head := range precedingHeads {
|
||||
if slice.FindPos(aclTreeHeads, l.Head) != -1 {
|
||||
continue
|
||||
}
|
||||
aclTreeHeads = append(aclTreeHeads, head)
|
||||
}
|
||||
}
|
||||
|
||||
if len(aclTreeHeads) == 0 {
|
||||
return nil, fmt.Errorf("no usable ACL heads in thread")
|
||||
}
|
||||
return aclTreeHeads, nil
|
||||
}
|
||||
|
||||
func (tb *ACLTreeBuilder) getRoot(possibleRoots []*Change) (*Change, error) {
|
||||
threadId, err := thread.Decode(tb.thread.ID())
|
||||
if err != nil {
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/data/pb"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/textileio/go-threads/crypto/symmetric"
|
||||
)
|
||||
|
||||
@ -15,9 +14,8 @@ type Change struct {
|
||||
PreviousIds []string
|
||||
Id string
|
||||
SnapshotId string
|
||||
LogHeads map[string]string
|
||||
IsSnapshot bool
|
||||
DecryptedDocumentChange *pb.ACLChangeChangeData
|
||||
DecryptedDocumentChange []byte
|
||||
|
||||
Content *pb.ACLChange
|
||||
}
|
||||
@ -27,17 +25,12 @@ func (ch *Change) DecryptContents(key *symmetric.Key) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var changesData pb.ACLChangeChangeData
|
||||
decrypted, err := key.Decrypt(ch.Content.ChangesData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decrypt changes data: %w", err)
|
||||
}
|
||||
|
||||
err = proto.Unmarshal(decrypted, &changesData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to umarshall into ChangesData: %w", err)
|
||||
}
|
||||
ch.DecryptedDocumentChange = &changesData
|
||||
ch.DecryptedDocumentChange = decrypted
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -53,7 +46,6 @@ func NewChange(id string, ch *pb.ACLChange) (*Change, error) {
|
||||
Content: ch,
|
||||
SnapshotId: ch.SnapshotBaseId,
|
||||
IsSnapshot: ch.GetAclData().GetAclSnapshot() != nil,
|
||||
LogHeads: ch.GetLogHeads(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -65,6 +57,5 @@ func NewACLChange(id string, ch *pb.ACLChange) (*Change, error) {
|
||||
Content: ch,
|
||||
SnapshotId: ch.SnapshotBaseId,
|
||||
IsSnapshot: ch.GetAclData().GetAclSnapshot() != nil,
|
||||
LogHeads: ch.GetLogHeads(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -68,6 +68,10 @@ func (t *ThreadBuilder) GetKeychain() *Keychain {
|
||||
// at the same time this guy can add some random folks which are not in space
|
||||
// but we should compare this against space in the future
|
||||
|
||||
func (t *ThreadBuilder) Heads() []string {
|
||||
return t.heads
|
||||
}
|
||||
|
||||
func (t *ThreadBuilder) GetChange(ctx context.Context, recordID string) (*threadmodels.RawChange, error) {
|
||||
rec := t.allChanges[recordID]
|
||||
|
||||
|
||||
@ -7,6 +7,8 @@ import (
|
||||
|
||||
type Thread interface {
|
||||
ID() string
|
||||
Heads() []string
|
||||
// TODO: add ACL heads
|
||||
GetChange(ctx context.Context, recordID string) (*RawChange, error)
|
||||
PushChange(payload proto.Marshaler) (id string, err error)
|
||||
}
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
package threadmodels
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/lib/core/smartblock"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -11,7 +10,7 @@ func TestCreateACLThreadIDVerify(t *testing.T) {
|
||||
t.Fatalf("should not return error after generating key pair: %v", err)
|
||||
}
|
||||
|
||||
thread, err := CreateACLThreadID(pubKey, smartblock.SmartBlockTypeWorkspace)
|
||||
thread, err := CreateACLThreadID(pubKey, 1)
|
||||
if err != nil {
|
||||
t.Fatalf("should not return error after generating thread: %v", err)
|
||||
}
|
||||
|
||||
@ -6,16 +6,17 @@ import (
|
||||
"fmt"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/data/pb"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/data/threadmodels"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/lib/logging"
|
||||
//"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/lib/logging"
|
||||
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/slice"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/prometheus/common/log"
|
||||
"github.com/textileio/go-threads/core/thread"
|
||||
"sort"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logging.Logger("anytype-data")
|
||||
//log = logging.Logger("anytype-data")
|
||||
ErrEmpty = errors.New("logs empty")
|
||||
)
|
||||
|
||||
@ -48,7 +49,7 @@ func (tb *TreeBuilder) loadChange(id string) (ch *Change, err error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer cancel()
|
||||
|
||||
record, err := tb.thread.GetChange(ctx, id)
|
||||
change, err := tb.thread.GetChange(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -56,11 +57,11 @@ func (tb *TreeBuilder) loadChange(id string) (ch *Change, err error) {
|
||||
aclChange := new(pb.ACLChange)
|
||||
|
||||
// TODO: think what should we do with such cases, because this can be used by attacker to break our tree
|
||||
if err = proto.Unmarshal(record.Signed.Payload, aclChange); err != nil {
|
||||
if err = proto.Unmarshal(change.Payload, aclChange); err != nil {
|
||||
return
|
||||
}
|
||||
var verified bool
|
||||
verified, err = tb.verify(aclChange.Identity, record.Signed.Payload, record.Signed.Signature)
|
||||
verified, err = tb.verify(aclChange.Identity, change.Payload, change.Signature)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -87,40 +88,8 @@ func (tb *TreeBuilder) verify(identity string, payload, signature []byte) (isVer
|
||||
return identityKey.Verify(payload, signature)
|
||||
}
|
||||
|
||||
func (tb *TreeBuilder) getLogs() (logs []threadmodels.ThreadLog, err error) {
|
||||
// TODO: Add beforeId building logic
|
||||
logs, err = tb.thread.GetLogs()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetLogs error: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("build tree: logs: %v", logs)
|
||||
if len(logs) == 0 || len(logs) == 1 && len(logs[0].Head) <= 1 {
|
||||
return nil, ErrEmpty
|
||||
}
|
||||
var nonEmptyLogs = logs[:0]
|
||||
for _, l := range logs {
|
||||
if len(l.Head) == 0 {
|
||||
continue
|
||||
}
|
||||
if ch, err := tb.loadChange(l.Head); err != nil {
|
||||
log.Errorf("loading head %s of the log %s failed: %v", l.Head, l.ID, err)
|
||||
} else {
|
||||
tb.logHeads[l.ID] = ch
|
||||
}
|
||||
nonEmptyLogs = append(nonEmptyLogs, l)
|
||||
}
|
||||
return nonEmptyLogs, nil
|
||||
}
|
||||
|
||||
func (tb *TreeBuilder) Build(fromStart bool) (*Tree, error) {
|
||||
logs, err := tb.getLogs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO: check if this should be changed if we are building from start
|
||||
heads, err := tb.getActualHeads(logs)
|
||||
heads, err := tb.getActualHeads()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get acl heads error: %v", err)
|
||||
}
|
||||
|
||||
120
util/slice/slice.go
Normal file
120
util/slice/slice.go
Normal file
@ -0,0 +1,120 @@
|
||||
package slice
|
||||
|
||||
import (
|
||||
"hash/fnv"
|
||||
"math/rand"
|
||||
"sort"
|
||||
)
|
||||
|
||||
func DifferenceRemovedAdded(a, b []string) (removed []string, added []string) {
|
||||
var amap = map[string]struct{}{}
|
||||
var bmap = map[string]struct{}{}
|
||||
|
||||
for _, item := range a {
|
||||
amap[item] = struct{}{}
|
||||
}
|
||||
|
||||
for _, item := range b {
|
||||
if _, exists := amap[item]; !exists {
|
||||
added = append(added, item)
|
||||
}
|
||||
bmap[item] = struct{}{}
|
||||
}
|
||||
|
||||
for _, item := range a {
|
||||
if _, exists := bmap[item]; !exists {
|
||||
removed = append(removed, item)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func FindPos(s []string, v string) int {
|
||||
for i, sv := range s {
|
||||
if sv == v {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
// Difference returns the elements in `a` that aren't in `b`.
|
||||
func Difference(a, b []string) []string {
|
||||
var diff = make([]string, 0, len(a))
|
||||
for _, a1 := range a {
|
||||
if FindPos(b, a1) == -1 {
|
||||
diff = append(diff, a1)
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
|
||||
func Insert(s []string, pos int, v ...string) []string {
|
||||
if len(s) <= pos {
|
||||
return append(s, v...)
|
||||
}
|
||||
if pos == 0 {
|
||||
return append(v, s[pos:]...)
|
||||
}
|
||||
return append(s[:pos], append(v, s[pos:]...)...)
|
||||
}
|
||||
|
||||
// Remove reuses provided slice capacity. Provided s slice should not be used after without reassigning to the func return!
|
||||
func Remove(s []string, v string) []string {
|
||||
var n int
|
||||
for _, x := range s {
|
||||
if x != v {
|
||||
s[n] = x
|
||||
n++
|
||||
}
|
||||
}
|
||||
return s[:n]
|
||||
}
|
||||
|
||||
func Filter(vals []string, cond func(string) bool) []string {
|
||||
var result = make([]string, 0, len(vals))
|
||||
for i := range vals {
|
||||
if cond(vals[i]) {
|
||||
result = append(result, vals[i])
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func GetRandomString(s []string, seed string) string {
|
||||
rand.Seed(int64(hash(seed)))
|
||||
return s[rand.Intn(len(s))]
|
||||
}
|
||||
|
||||
func hash(s string) uint64 {
|
||||
h := fnv.New64a()
|
||||
h.Write([]byte(s))
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func SortedEquals(s1, s2 []string) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
for i := range s1 {
|
||||
if s1[i] != s2[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func UnsortedEquals(s1, s2 []string) bool {
|
||||
if len(s1) != len(s2) {
|
||||
return false
|
||||
}
|
||||
|
||||
s1Sorted := make([]string, len(s1))
|
||||
s2Sorted := make([]string, len(s2))
|
||||
copy(s1Sorted, s1)
|
||||
copy(s2Sorted, s2)
|
||||
sort.Strings(s1Sorted)
|
||||
sort.Strings(s2Sorted)
|
||||
|
||||
return SortedEquals(s1Sorted, s2Sorted)
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user