diff --git a/common/pkg/acl/storage/inmemory.go b/common/pkg/acl/storage/inmemory.go index 5c39e94d..d3d42b1e 100644 --- a/common/pkg/acl/storage/inmemory.go +++ b/common/pkg/acl/storage/inmemory.go @@ -9,8 +9,10 @@ import ( ) type inMemoryACLListStorage struct { - records []*aclrecordproto.RawACLRecordWithId id string + root *aclrecordproto.RawACLRecordWithId + head string + records map[string]*aclrecordproto.RawACLRecordWithId sync.RWMutex } @@ -18,48 +20,63 @@ type inMemoryACLListStorage struct { func NewInMemoryACLListStorage( id string, records []*aclrecordproto.RawACLRecordWithId) (ListStorage, error) { + + allRecords := make(map[string]*aclrecordproto.RawACLRecordWithId) + for _, ch := range records { + allRecords[ch.Id] = ch + } + root := records[0] + head := records[len(records)-1] + return &inMemoryACLListStorage{ - id: id, - records: records, + id: root.Id, + root: root, + head: head.Id, + records: allRecords, RWMutex: sync.RWMutex{}, }, nil } -func (i *inMemoryACLListStorage) Root() (*aclrecordproto.RawACLRecordWithId, error) { - i.RLock() - defer i.RUnlock() - return i.records[0], nil +func (t *inMemoryACLListStorage) ID() string { + t.RLock() + defer t.RUnlock() + return t.id } -func (i *inMemoryACLListStorage) SetHead(headId string) error { - panic("implement me") +func (t *inMemoryACLListStorage) Root() (*aclrecordproto.RawACLRecordWithId, error) { + t.RLock() + defer t.RUnlock() + return t.root, nil } -func (i *inMemoryACLListStorage) Head() (string, error) { - i.RLock() - defer i.RUnlock() - return i.records[len(i.records)-1].Id, nil +func (t *inMemoryACLListStorage) Head() (string, error) { + t.RLock() + defer t.RUnlock() + return t.head, nil } -func (i *inMemoryACLListStorage) GetRawRecord(ctx context.Context, id string) (*aclrecordproto.RawACLRecordWithId, error) { - i.RLock() - defer i.RUnlock() - for _, rec := range i.records { - if rec.Id == id { - return rec, nil - } +func (t *inMemoryACLListStorage) SetHead(head string) error { + t.Lock() + defer t.Unlock() + t.head = head + return nil +} + +func (t *inMemoryACLListStorage) AddRawRecord(ctx context.Context, record *aclrecordproto.RawACLRecordWithId) error { + t.Lock() + defer t.Unlock() + // TODO: better to do deep copy + t.records[record.Id] = record + return nil +} + +func (t *inMemoryACLListStorage) GetRawRecord(ctx context.Context, recordId string) (*aclrecordproto.RawACLRecordWithId, error) { + t.RLock() + defer t.RUnlock() + if res, exists := t.records[recordId]; exists { + return res, nil } - return nil, fmt.Errorf("no such record") -} - -func (i *inMemoryACLListStorage) AddRawRecord(ctx context.Context, rec *aclrecordproto.RawACLRecordWithId) error { - panic("implement me") -} - -func (i *inMemoryACLListStorage) ID() string { - i.RLock() - defer i.RUnlock() - return i.id + return nil, fmt.Errorf("could not get record with id: %s", recordId) } type inMemoryTreeStorage struct { diff --git a/common/pkg/acl/testutils/acllistbuilder/liststoragebuilder.go b/common/pkg/acl/testutils/acllistbuilder/liststoragebuilder.go index dad7271b..25542e93 100644 --- a/common/pkg/acl/testutils/acllistbuilder/liststoragebuilder.go +++ b/common/pkg/acl/testutils/acllistbuilder/liststoragebuilder.go @@ -3,7 +3,7 @@ package acllistbuilder import ( "context" "fmt" - aclrecordproto2 "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" + aclrecordproto "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/aclrecordproto" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/storage" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/pkg/acl/testutils/yamltests" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/util/cid" @@ -20,20 +20,12 @@ import ( ) type ACLListStorageBuilder struct { - aclList string - records []*aclrecordproto2.ACLRecord - rawRecords []*aclrecordproto2.RawACLRecordWithId - indexes map[string]int - keychain *YAMLKeychain - rawRoot *aclrecordproto2.RawACLRecordWithId - root *aclrecordproto2.ACLRoot - id string + storage.ListStorage + keychain *YAMLKeychain } func NewACLListStorageBuilder(keychain *YAMLKeychain) *ACLListStorageBuilder { return &ACLListStorageBuilder{ - records: make([]*aclrecordproto2.ACLRecord, 0), - indexes: make(map[string]int), keychain: keychain, } } @@ -61,7 +53,7 @@ func NewACLListStorageBuilderFromFile(file string) (*ACLListStorageBuilder, erro return tb, nil } -func (t *ACLListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) *aclrecordproto2.RawACLRecordWithId { +func (t *ACLListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) *aclrecordproto.RawACLRecordWithId { protoMarshalled, err := rec.Marshal() if err != nil { panic("should be able to marshal final acl message!") @@ -72,7 +64,7 @@ func (t *ACLListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) panic("should be able to sign final acl message!") } - rawRec := &aclrecordproto2.RawACLRecord{ + rawRec := &aclrecordproto.RawACLRecord{ Payload: protoMarshalled, Signature: signature, } @@ -84,85 +76,53 @@ func (t *ACLListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) id, _ := cid.NewCIDFromBytes(rawMarshalled) - return &aclrecordproto2.RawACLRecordWithId{ + return &aclrecordproto.RawACLRecordWithId{ Payload: rawMarshalled, Id: id, } } -func (t *ACLListStorageBuilder) Head() (string, error) { - l := len(t.records) - if l > 0 { - return t.rawRecords[l-1].Id, nil - } - return t.rawRoot.Id, nil -} - -func (t *ACLListStorageBuilder) SetHead(headId string) error { - panic("SetHead is not implemented") -} - -func (t *ACLListStorageBuilder) Root() (*aclrecordproto2.RawACLRecordWithId, error) { - return t.rawRoot, nil -} - -func (t *ACLListStorageBuilder) GetRawRecord(ctx context.Context, id string) (*aclrecordproto2.RawACLRecordWithId, error) { - recIdx, ok := t.indexes[id] - if !ok { - if id == t.rawRoot.Id { - return t.rawRoot, nil - } - return nil, fmt.Errorf("no such record") - } - return t.rawRecords[recIdx], nil -} - -func (t *ACLListStorageBuilder) AddRawRecord(ctx context.Context, rec *aclrecordproto2.RawACLRecordWithId) error { - panic("implement me") -} - -func (t *ACLListStorageBuilder) ID() string { - return t.id -} - -func (t *ACLListStorageBuilder) GetRawRecords() []*aclrecordproto2.RawACLRecordWithId { - return t.rawRecords -} - func (t *ACLListStorageBuilder) GetKeychain() *YAMLKeychain { return t.keychain } -func (t *ACLListStorageBuilder) Parse(tree *YMLList) { +func (t *ACLListStorageBuilder) Parse(l *YMLList) { // Just to clarify - we are generating new identities for the ones that // are specified in the yml file, because our identities should be Ed25519 // the same thing is happening for the encryption keys - t.keychain.ParseKeys(&tree.Keys) - t.parseRoot(tree.Root) - prevId := t.id - for idx, rec := range tree.Records { + t.keychain.ParseKeys(&l.Keys) + rawRoot := t.parseRoot(l.Root) + var err error + t.ListStorage, err = storage.NewInMemoryACLListStorage(rawRoot.Id, []*aclrecordproto.RawACLRecordWithId{rawRoot}) + if err != nil { + panic(err) + } + prevId := rawRoot.Id + for _, rec := range l.Records { newRecord := t.parseRecord(rec, prevId) rawRecord := t.createRaw(newRecord, newRecord.Identity) - t.records = append(t.records, newRecord) - t.rawRecords = append(t.rawRecords, rawRecord) - t.indexes[rawRecord.Id] = idx + err = t.AddRawRecord(context.Background(), rawRecord) + if err != nil { + panic(err) + } prevId = rawRecord.Id } + t.SetHead(prevId) } -func (t *ACLListStorageBuilder) parseRecord(rec *Record, prevId string) *aclrecordproto2.ACLRecord { +func (t *ACLListStorageBuilder) parseRecord(rec *Record, prevId string) *aclrecordproto.ACLRecord { k := t.keychain.GetKey(rec.ReadKey).(*SymKey) - var aclChangeContents []*aclrecordproto2.ACLContentValue + var aclChangeContents []*aclrecordproto.ACLContentValue for _, ch := range rec.AclChanges { aclChangeContent := t.parseACLChange(ch) aclChangeContents = append(aclChangeContents, aclChangeContent) } - data := &aclrecordproto2.ACLData{ + data := &aclrecordproto.ACLData{ AclContent: aclChangeContents, } bytes, _ := data.Marshal() - return &aclrecordproto2.ACLRecord{ + return &aclrecordproto.ACLRecord{ PrevId: prevId, Identity: []byte(t.keychain.GetIdentity(rec.Identity)), Data: bytes, @@ -171,7 +131,7 @@ func (t *ACLListStorageBuilder) parseRecord(rec *Record, prevId string) *aclreco } } -func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecordproto2.ACLContentValue) { +func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecordproto.ACLContentValue) { switch { case ch.UserAdd != nil: add := ch.UserAdd @@ -179,9 +139,9 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord encKey := t.keychain.GetKey(add.EncryptionKey).(encryptionkey.PrivKey) rawKey, _ := encKey.GetPublic().Raw() - convCh = &aclrecordproto2.ACLContentValue{ - Value: &aclrecordproto2.ACLContentValue_UserAdd{ - UserAdd: &aclrecordproto2.ACLUserAdd{ + convCh = &aclrecordproto.ACLContentValue{ + Value: &aclrecordproto.ACLContentValue_UserAdd{ + UserAdd: &aclrecordproto.ACLUserAdd{ Identity: []byte(t.keychain.GetIdentity(add.Identity)), EncryptionKey: rawKey, EncryptedReadKeys: t.encryptReadKeysWithPubKey(add.EncryptedReadKeys, encKey), @@ -203,9 +163,9 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord } acceptPubKey, _ := signKey.GetPublic().Raw() - convCh = &aclrecordproto2.ACLContentValue{ - Value: &aclrecordproto2.ACLContentValue_UserJoin{ - UserJoin: &aclrecordproto2.ACLUserJoin{ + convCh = &aclrecordproto.ACLContentValue{ + Value: &aclrecordproto.ACLContentValue_UserJoin{ + UserJoin: &aclrecordproto.ACLUserJoin{ Identity: []byte(t.keychain.GetIdentity(join.Identity)), EncryptionKey: rawKey, AcceptSignature: signature, @@ -220,9 +180,9 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord hash := t.keychain.GetKey(invite.EncryptionKey).(*SymKey).Hash encKey := t.keychain.ReadKeysByHash[hash] - convCh = &aclrecordproto2.ACLContentValue{ - Value: &aclrecordproto2.ACLContentValue_UserInvite{ - UserInvite: &aclrecordproto2.ACLUserInvite{ + convCh = &aclrecordproto.ACLContentValue{ + Value: &aclrecordproto.ACLContentValue_UserInvite{ + UserInvite: &aclrecordproto.ACLUserInvite{ AcceptPublicKey: rawAcceptKey, EncryptSymKeyHash: hash, EncryptedReadKeys: t.encryptReadKeysWithSymKey(invite.EncryptedReadKeys, encKey.Key), @@ -233,9 +193,9 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord case ch.UserPermissionChange != nil: permissionChange := ch.UserPermissionChange - convCh = &aclrecordproto2.ACLContentValue{ - Value: &aclrecordproto2.ACLContentValue_UserPermissionChange{ - UserPermissionChange: &aclrecordproto2.ACLUserPermissionChange{ + convCh = &aclrecordproto.ACLContentValue{ + Value: &aclrecordproto.ACLContentValue_UserPermissionChange{ + UserPermissionChange: &aclrecordproto.ACLUserPermissionChange{ Identity: []byte(t.keychain.GetIdentity(permissionChange.Identity)), Permissions: t.convertPermission(permissionChange.Permission), }, @@ -246,7 +206,7 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord newReadKey := t.keychain.GetKey(remove.NewReadKey).(*SymKey) - var replaces []*aclrecordproto2.ACLReadKeyReplace + var replaces []*aclrecordproto.ACLReadKeyReplace for _, id := range remove.IdentitiesLeft { encKey := t.keychain.EncryptionKeysByYAMLIdentity[id] rawEncKey, _ := encKey.GetPublic().Raw() @@ -254,16 +214,16 @@ func (t *ACLListStorageBuilder) parseACLChange(ch *ACLChange) (convCh *aclrecord if err != nil { panic(err) } - replaces = append(replaces, &aclrecordproto2.ACLReadKeyReplace{ + replaces = append(replaces, &aclrecordproto.ACLReadKeyReplace{ Identity: []byte(t.keychain.GetIdentity(id)), EncryptionKey: rawEncKey, EncryptedReadKey: encReadKey, }) } - convCh = &aclrecordproto2.ACLContentValue{ - Value: &aclrecordproto2.ACLContentValue_UserRemove{ - UserRemove: &aclrecordproto2.ACLUserRemove{ + convCh = &aclrecordproto.ACLContentValue{ + Value: &aclrecordproto.ACLContentValue_UserRemove{ + UserRemove: &aclrecordproto.ACLUserRemove{ Identity: []byte(t.keychain.GetIdentity(remove.RemovedIdentity)), ReadKeyReplaces: replaces, }, @@ -303,36 +263,30 @@ func (t *ACLListStorageBuilder) encryptReadKeysWithSymKey(keys []string, key *sy return } -func (t *ACLListStorageBuilder) convertPermission(perm string) aclrecordproto2.ACLUserPermissions { +func (t *ACLListStorageBuilder) convertPermission(perm string) aclrecordproto.ACLUserPermissions { switch perm { case "admin": - return aclrecordproto2.ACLUserPermissions_Admin + return aclrecordproto.ACLUserPermissions_Admin case "writer": - return aclrecordproto2.ACLUserPermissions_Writer + return aclrecordproto.ACLUserPermissions_Writer case "reader": - return aclrecordproto2.ACLUserPermissions_Reader + return aclrecordproto.ACLUserPermissions_Reader default: panic(fmt.Sprintf("incorrect permission: %s", perm)) } } -func (t *ACLListStorageBuilder) traverseFromHead(f func(rec *aclrecordproto2.ACLRecord, id string) error) (err error) { - for i := len(t.records) - 1; i >= 0; i-- { - err = f(t.records[i], t.rawRecords[i].Id) - if err != nil { - return err - } - } - return nil +func (t *ACLListStorageBuilder) traverseFromHead(f func(rec *aclrecordproto.ACLRecord, id string) error) (err error) { + panic("this was removed, add if needed") } -func (t *ACLListStorageBuilder) parseRoot(root *Root) { +func (t *ACLListStorageBuilder) parseRoot(root *Root) (rawRoot *aclrecordproto.RawACLRecordWithId) { rawSignKey, _ := t.keychain.SigningKeysByYAMLIdentity[root.Identity].GetPublic().Raw() rawEncKey, _ := t.keychain.EncryptionKeysByYAMLIdentity[root.Identity].GetPublic().Raw() - readKey, _ := aclrecordproto2.ACLReadKeyDerive(rawSignKey, rawEncKey) + readKey, _ := aclrecordproto.ACLReadKeyDerive(rawSignKey, rawEncKey) hasher := fnv.New64() hasher.Write(readKey.Bytes()) - t.root = &aclrecordproto2.ACLRoot{ + aclRoot := &aclrecordproto.ACLRoot{ Identity: rawSignKey, EncryptionKey: rawEncKey, SpaceId: root.SpaceId, @@ -340,6 +294,5 @@ func (t *ACLListStorageBuilder) parseRoot(root *Root) { DerivationScheme: "scheme", CurrentReadKeyHash: hasher.Sum64(), } - t.rawRoot = t.createRaw(t.root, rawSignKey) - t.id = t.rawRoot.Id + return t.createRaw(aclRoot, rawSignKey) }