Refactor ACLState a little bit

This commit is contained in:
mcrakhman 2022-09-08 23:20:11 +02:00 committed by Mikhail Iudin
parent bf4e4de38c
commit e384d549a2
No known key found for this signature in database
GPG Key ID: FAAAA8BAABDFF1C0
10 changed files with 115 additions and 185 deletions

View File

@ -1,23 +1,23 @@
package tree
package common
import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
)
type keychain struct {
type Keychain struct {
decoder keys.Decoder
keys map[string]signingkey.PubKey
}
func newKeychain() *keychain {
return &keychain{
func NewKeychain() *Keychain {
return &Keychain{
decoder: signingkey.NewEDPubKeyDecoder(),
keys: make(map[string]signingkey.PubKey),
}
}
func (k *keychain) getOrAdd(identity string) (signingkey.PubKey, error) {
func (k *Keychain) GetOrAdd(identity string) (signingkey.PubKey, error) {
if key, exists := k.keys[identity]; exists {
return key, nil
}

View File

@ -1,11 +1,11 @@
package list
import (
"bytes"
"errors"
"fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/app/logger"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/encryptionkey"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
@ -25,6 +25,7 @@ var ErrUserAlreadyExists = errors.New("user already exists")
var ErrNoSuchRecord = errors.New("no such record")
var ErrInsufficientPermissions = errors.New("insufficient permissions")
var ErrNoReadKey = errors.New("acl state doesn't have a read key")
var ErrInvalidSignature = errors.New("signature is invalid")
type UserPermissionPair struct {
Identity string
@ -32,14 +33,18 @@ type UserPermissionPair struct {
}
type ACLState struct {
currentReadKeyHash uint64
userReadKeys map[uint64]*symmetric.Key
userStates map[string]*aclpb.ACLChangeUserState
userInvites map[string]*aclpb.ACLChangeUserInvite
currentReadKeyHash uint64
userReadKeys map[uint64]*symmetric.Key
userStates map[string]*aclpb.ACLChangeUserState
userInvites map[string]*aclpb.ACLChangeUserInvite
signingPubKeyDecoder keys.Decoder
encryptionKey encryptionkey.PrivKey
identity string
permissionsAtRecord map[string][]UserPermissionPair
identity string
permissionsAtRecord map[string][]UserPermissionPair
keychain *common.Keychain
}
func newACLStateWithIdentity(
@ -54,6 +59,7 @@ func newACLStateWithIdentity(
userInvites: make(map[string]*aclpb.ACLChangeUserInvite),
signingPubKeyDecoder: decoder,
permissionsAtRecord: make(map[string][]UserPermissionPair),
keychain: common.NewKeychain(),
}
}
@ -64,6 +70,7 @@ func newACLState(decoder keys.Decoder) *ACLState {
userStates: make(map[string]*aclpb.ACLChangeUserState),
userInvites: make(map[string]*aclpb.ACLChangeUserInvite),
permissionsAtRecord: make(map[string][]UserPermissionPair),
keychain: common.NewKeychain(),
}
}
@ -99,7 +106,6 @@ func (st *ACLState) PermissionsAtRecord(id string, identity string) (UserPermiss
}
func (st *ACLState) applyRecord(record *aclpb.Record) (err error) {
// TODO: this should be probably changed
aclData := &aclpb.ACLChangeACLData{}
err = proto.Unmarshal(record.Data, aclData)
@ -107,35 +113,37 @@ func (st *ACLState) applyRecord(record *aclpb.Record) (err error) {
return
}
defer func() {
if err != nil {
return
}
st.currentReadKeyHash = record.CurrentReadKeyHash
}()
err = st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity)
if err != nil {
return
}
return st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity)
st.currentReadKeyHash = record.CurrentReadKeyHash
return
}
func (st *ACLState) applyChangeAndUpdate(recordWrapper *Record) (err error) {
change := recordWrapper.Content
aclData := &aclpb.ACLChangeACLData{}
var (
change = recordWrapper.Content
aclData = &aclpb.ACLChangeACLData{}
)
if recordWrapper.ParsedModel != nil {
aclData = recordWrapper.ParsedModel.(*aclpb.ACLChangeACLData)
if recordWrapper.Model != nil {
aclData = recordWrapper.Model.(*aclpb.ACLChangeACLData)
} else {
err = proto.Unmarshal(change.Data, aclData)
if err != nil {
return
}
recordWrapper.ParsedModel = aclData
recordWrapper.Model = aclData
}
err = st.applyChangeData(aclData, recordWrapper.Content.CurrentReadKeyHash, recordWrapper.Content.Identity)
if err != nil {
return err
return
}
// getting all permissions for users at record
var permissions []UserPermissionPair
for _, state := range st.userStates {
permission := UserPermissionPair{
@ -144,8 +152,8 @@ func (st *ACLState) applyChangeAndUpdate(recordWrapper *Record) (err error) {
}
permissions = append(permissions, permission)
}
st.permissionsAtRecord[recordWrapper.Id] = permissions
log.Infof("adding permissions at record %s", recordWrapper.Id)
return nil
}
@ -243,7 +251,7 @@ func (st *ACLState) applyUserJoin(ch *aclpb.ACLChangeUserJoin) error {
return fmt.Errorf("verification returned error: %w", err)
}
if !res {
return fmt.Errorf("signature is invalid")
return ErrInvalidSignature
}
// if ourselves -> we need to decrypt the read keys
@ -374,87 +382,6 @@ func (st *ACLState) isUserAdd(data *aclpb.ACLChangeACLData, identity string) boo
return data.GetAclContent() != nil && userAdd != nil && userAdd.GetIdentity() == identity
}
func (st *ACLState) getPermissionDecreasedUsers(ch *aclpb.ACLChange) (identities []*aclpb.ACLChangeUserPermissionChange) {
// this should be called after general checks are completed
if ch.GetAclData().GetAclContent() == nil {
return nil
}
contents := ch.GetAclData().GetAclContent()
for _, c := range contents {
if c.GetUserPermissionChange() != nil {
content := c.GetUserPermissionChange()
currentState := st.userStates[content.Identity]
// the comparison works in different direction :-)
if content.Permissions > currentState.Permissions {
identities = append(identities, &aclpb.ACLChangeUserPermissionChange{
Identity: content.Identity,
Permissions: content.Permissions,
})
}
}
if c.GetUserRemove() != nil {
content := c.GetUserRemove()
identities = append(identities, &aclpb.ACLChangeUserPermissionChange{
Identity: content.Identity,
Permissions: aclpb.ACLChange_Removed,
})
}
}
return identities
}
func (st *ACLState) equal(other *ACLState) bool {
if st == nil && other == nil {
return true
}
if st == nil || other == nil {
return false
}
if st.currentReadKeyHash != other.currentReadKeyHash {
return false
}
if st.identity != other.identity {
return false
}
if len(st.userStates) != len(other.userStates) {
return false
}
for _, st := range st.userStates {
otherSt, exists := other.userStates[st.Identity]
if !exists {
return false
}
if st.Permissions != otherSt.Permissions {
return false
}
if bytes.Compare(st.EncryptionKey, otherSt.EncryptionKey) != 0 {
return false
}
}
if len(st.userInvites) != len(other.userInvites) {
return false
}
// TODO: add detailed user invites comparison + compare other stuff
return true
}
func (st *ACLState) GetUserStates() map[string]*aclpb.ACLChangeUserState {
// TODO: we should provide better API that would not allow to change this map from the outside
return st.userStates
}
func (st *ACLState) isNodeIdentity() bool {
return st.identity == ""
}

View File

@ -124,7 +124,7 @@ func (c *aclChangeBuilder) BuildAndApply() (*Record, []byte, error) {
return nil, nil, err
}
ch := NewRecord(id, aclRecord)
ch.ParsedModel = c.aclData
ch.Model = c.aclData
ch.Sign = signature
return ch, fullMarshalledChange, nil

View File

@ -2,17 +2,21 @@ package list
import (
"context"
"errors"
"fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/account"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/cid"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"go.uber.org/zap"
"sync"
)
type IterFunc = func(record *Record) (IsContinue bool)
var ErrIncorrectCID = errors.New("incorrect CID")
type RWLocker interface {
sync.Locker
RLock()
@ -41,6 +45,7 @@ type aclList struct {
builder *aclStateBuilder
aclState *ACLState
keychain *common.Keychain
sync.RWMutex
}
@ -54,36 +59,46 @@ func BuildACLList(decoder keys.Decoder, storage storage.ListStorage) (ACLList, e
return buildWithACLStateBuilder(newACLStateBuilder(decoder), storage)
}
func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStorage) (ACLList, error) {
func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStorage) (list ACLList, err error) {
header, err := storage.Header()
if err != nil {
return nil, err
return
}
id, err := storage.ID()
if err != nil {
return nil, err
return
}
rawRecord, err := storage.Head()
if err != nil {
return nil, err
return
}
keychain := common.NewKeychain()
record, err := NewFromRawRecord(rawRecord)
if err != nil {
return nil, err
return
}
err = verifyRecord(keychain, rawRecord, record)
if err != nil {
return
}
records := []*Record{record}
for record.Content.PrevId != "" {
rawRecord, err = storage.GetRawRecord(context.Background(), record.Content.PrevId)
if err != nil {
return nil, err
return
}
record, err = NewFromRawRecord(rawRecord)
if err != nil {
return nil, err
return
}
err = verifyRecord(keychain, rawRecord, record)
if err != nil {
return
}
records = append(records, record)
}
@ -99,14 +114,12 @@ func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStor
indexes[records[len(records)/2].Id] = len(records) / 2
}
log.With(zap.String("head id", records[len(records)-1].Id), zap.String("list id", id)).
Info("building acl tree")
state, err := builder.Build(records)
if err != nil {
return nil, err
return
}
return &aclList{
list = &aclList{
header: header,
records: records,
indexes: indexes,
@ -114,7 +127,8 @@ func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStor
aclState: state,
id: id,
RWMutex: sync.RWMutex{},
}, nil
}
return
}
func (a *aclList) Records() []*Record {
@ -177,3 +191,26 @@ func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
func (a *aclList) Close() (err error) {
return nil
}
func verifyRecord(keychain *common.Keychain, rawRecord *aclpb.RawRecord, record *Record) (err error) {
identityKey, err := keychain.GetOrAdd(record.Content.Identity)
if err != nil {
return
}
// verifying signature
res, err := identityKey.Verify(rawRecord.Payload, rawRecord.Signature)
if err != nil {
return
}
if !res {
err = ErrInvalidSignature
return
}
// verifying ID
if !cid.VerifyCID(rawRecord.Payload, rawRecord.Id) {
err = ErrIncorrectCID
}
return
}

View File

@ -6,10 +6,10 @@ import (
)
type Record struct {
Id string
Content *aclpb.Record
ParsedModel interface{}
Sign []byte
Id string
Content *aclpb.Record
Model interface{}
Sign []byte
}
func NewRecord(id string, aclRecord *aclpb.Record) *Record {

View File

@ -56,42 +56,6 @@ func (ch *Change) DecryptContents(key *symmetric.Key) error {
return nil
}
func NewChangeFromRaw(rawChange *aclpb.RawChange) (*Change, error) {
unmarshalled := &aclpb.Change{}
err := proto.Unmarshal(rawChange.Payload, unmarshalled)
if err != nil {
return nil, err
}
ch := NewChange(rawChange.Id, unmarshalled, rawChange.Signature)
return ch, nil
}
func newVerifiedChangeFromRaw(
rawChange *aclpb.RawChange,
kch *keychain) (*Change, error) {
unmarshalled := &aclpb.Change{}
ch, err := NewChangeFromRaw(rawChange)
if err != nil {
return nil, err
}
identityKey, err := kch.getOrAdd(unmarshalled.Identity)
if err != nil {
return nil, err
}
res, err := identityKey.Verify(rawChange.Payload, rawChange.Signature)
if err != nil {
return nil, err
}
if !res {
return nil, ErrIncorrectSignature
}
return ch, nil
}
func NewChange(id string, ch *aclpb.Change, signature []byte) *Change {
return &Change{
Next: nil,

View File

@ -2,6 +2,7 @@ package tree
import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/cid"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric"
@ -30,10 +31,10 @@ type ChangeBuilder interface {
}
type changeBuilder struct {
keys *keychain
keys *common.Keychain
}
func newChangeBuilder(keys *keychain) *changeBuilder {
func newChangeBuilder(keys *common.Keychain) *changeBuilder {
return &changeBuilder{keys: keys}
}
@ -55,7 +56,7 @@ func (c *changeBuilder) ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch
return nil, err
}
identityKey, err := c.keys.getOrAdd(unmarshalled.Identity)
identityKey, err := c.keys.GetOrAdd(unmarshalled.Identity)
if err != nil {
return
}

View File

@ -4,6 +4,7 @@ import (
"context"
"errors"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/list"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric"
@ -112,7 +113,7 @@ func defaultObjectTreeDeps(
listener ObjectTreeUpdateListener,
aclList list.ACLList) objectTreeDeps {
keychain := newKeychain()
keychain := common.NewKeychain()
changeBuilder := newChangeBuilder(keychain)
treeBuilder := newTreeBuilder(treeStorage, changeBuilder)
return objectTreeDeps{

View File

@ -235,21 +235,3 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error)
}
return
}
func discardFromSlice[T any](elements []T, isDiscarded func(T) bool) []T {
var (
finishedIdx = 0
currentIdx = 0
)
for currentIdx < len(elements) {
if !isDiscarded(elements[currentIdx]) {
if finishedIdx != currentIdx {
elements[finishedIdx] = elements[currentIdx]
}
finishedIdx++
}
currentIdx++
}
elements = elements[:finishedIdx]
return elements
}

View File

@ -27,3 +27,21 @@ OuterLoop:
}
return ourPath[i+1], nil
}
func discardFromSlice[T any](elements []T, isDiscarded func(T) bool) []T {
var (
finishedIdx = 0
currentIdx = 0
)
for currentIdx < len(elements) {
if !isDiscarded(elements[currentIdx]) {
if finishedIdx != currentIdx {
elements[finishedIdx] = elements[currentIdx]
}
finishedIdx++
}
currentIdx++
}
elements = elements[:finishedIdx]
return elements
}