Refactor ACLState a little bit

This commit is contained in:
mcrakhman 2022-09-08 23:20:11 +02:00 committed by Mikhail Iudin
parent bf4e4de38c
commit e384d549a2
No known key found for this signature in database
GPG Key ID: FAAAA8BAABDFF1C0
10 changed files with 115 additions and 185 deletions

View File

@ -1,23 +1,23 @@
package tree package common
import ( import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
) )
type keychain struct { type Keychain struct {
decoder keys.Decoder decoder keys.Decoder
keys map[string]signingkey.PubKey keys map[string]signingkey.PubKey
} }
func newKeychain() *keychain { func NewKeychain() *Keychain {
return &keychain{ return &Keychain{
decoder: signingkey.NewEDPubKeyDecoder(), decoder: signingkey.NewEDPubKeyDecoder(),
keys: make(map[string]signingkey.PubKey), keys: make(map[string]signingkey.PubKey),
} }
} }
func (k *keychain) getOrAdd(identity string) (signingkey.PubKey, error) { func (k *Keychain) GetOrAdd(identity string) (signingkey.PubKey, error) {
if key, exists := k.keys[identity]; exists { if key, exists := k.keys[identity]; exists {
return key, nil return key, nil
} }

View File

@ -1,11 +1,11 @@
package list package list
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/app/logger" "github.com/anytypeio/go-anytype-infrastructure-experiments/app/logger"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/encryptionkey" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/encryptionkey"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
@ -25,6 +25,7 @@ var ErrUserAlreadyExists = errors.New("user already exists")
var ErrNoSuchRecord = errors.New("no such record") var ErrNoSuchRecord = errors.New("no such record")
var ErrInsufficientPermissions = errors.New("insufficient permissions") var ErrInsufficientPermissions = errors.New("insufficient permissions")
var ErrNoReadKey = errors.New("acl state doesn't have a read key") var ErrNoReadKey = errors.New("acl state doesn't have a read key")
var ErrInvalidSignature = errors.New("signature is invalid")
type UserPermissionPair struct { type UserPermissionPair struct {
Identity string Identity string
@ -36,10 +37,14 @@ type ACLState struct {
userReadKeys map[uint64]*symmetric.Key userReadKeys map[uint64]*symmetric.Key
userStates map[string]*aclpb.ACLChangeUserState userStates map[string]*aclpb.ACLChangeUserState
userInvites map[string]*aclpb.ACLChangeUserInvite userInvites map[string]*aclpb.ACLChangeUserInvite
signingPubKeyDecoder keys.Decoder signingPubKeyDecoder keys.Decoder
encryptionKey encryptionkey.PrivKey encryptionKey encryptionkey.PrivKey
identity string identity string
permissionsAtRecord map[string][]UserPermissionPair permissionsAtRecord map[string][]UserPermissionPair
keychain *common.Keychain
} }
func newACLStateWithIdentity( func newACLStateWithIdentity(
@ -54,6 +59,7 @@ func newACLStateWithIdentity(
userInvites: make(map[string]*aclpb.ACLChangeUserInvite), userInvites: make(map[string]*aclpb.ACLChangeUserInvite),
signingPubKeyDecoder: decoder, signingPubKeyDecoder: decoder,
permissionsAtRecord: make(map[string][]UserPermissionPair), permissionsAtRecord: make(map[string][]UserPermissionPair),
keychain: common.NewKeychain(),
} }
} }
@ -64,6 +70,7 @@ func newACLState(decoder keys.Decoder) *ACLState {
userStates: make(map[string]*aclpb.ACLChangeUserState), userStates: make(map[string]*aclpb.ACLChangeUserState),
userInvites: make(map[string]*aclpb.ACLChangeUserInvite), userInvites: make(map[string]*aclpb.ACLChangeUserInvite),
permissionsAtRecord: make(map[string][]UserPermissionPair), permissionsAtRecord: make(map[string][]UserPermissionPair),
keychain: common.NewKeychain(),
} }
} }
@ -99,7 +106,6 @@ func (st *ACLState) PermissionsAtRecord(id string, identity string) (UserPermiss
} }
func (st *ACLState) applyRecord(record *aclpb.Record) (err error) { func (st *ACLState) applyRecord(record *aclpb.Record) (err error) {
// TODO: this should be probably changed
aclData := &aclpb.ACLChangeACLData{} aclData := &aclpb.ACLChangeACLData{}
err = proto.Unmarshal(record.Data, aclData) err = proto.Unmarshal(record.Data, aclData)
@ -107,35 +113,37 @@ func (st *ACLState) applyRecord(record *aclpb.Record) (err error) {
return return
} }
defer func() { err = st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity)
if err != nil { if err != nil {
return return
} }
st.currentReadKeyHash = record.CurrentReadKeyHash
}()
return st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity) st.currentReadKeyHash = record.CurrentReadKeyHash
return
} }
func (st *ACLState) applyChangeAndUpdate(recordWrapper *Record) (err error) { func (st *ACLState) applyChangeAndUpdate(recordWrapper *Record) (err error) {
change := recordWrapper.Content var (
aclData := &aclpb.ACLChangeACLData{} change = recordWrapper.Content
aclData = &aclpb.ACLChangeACLData{}
)
if recordWrapper.ParsedModel != nil { if recordWrapper.Model != nil {
aclData = recordWrapper.ParsedModel.(*aclpb.ACLChangeACLData) aclData = recordWrapper.Model.(*aclpb.ACLChangeACLData)
} else { } else {
err = proto.Unmarshal(change.Data, aclData) err = proto.Unmarshal(change.Data, aclData)
if err != nil { if err != nil {
return return
} }
recordWrapper.ParsedModel = aclData recordWrapper.Model = aclData
} }
err = st.applyChangeData(aclData, recordWrapper.Content.CurrentReadKeyHash, recordWrapper.Content.Identity) err = st.applyChangeData(aclData, recordWrapper.Content.CurrentReadKeyHash, recordWrapper.Content.Identity)
if err != nil { if err != nil {
return err return
} }
// getting all permissions for users at record
var permissions []UserPermissionPair var permissions []UserPermissionPair
for _, state := range st.userStates { for _, state := range st.userStates {
permission := UserPermissionPair{ permission := UserPermissionPair{
@ -144,8 +152,8 @@ func (st *ACLState) applyChangeAndUpdate(recordWrapper *Record) (err error) {
} }
permissions = append(permissions, permission) permissions = append(permissions, permission)
} }
st.permissionsAtRecord[recordWrapper.Id] = permissions st.permissionsAtRecord[recordWrapper.Id] = permissions
log.Infof("adding permissions at record %s", recordWrapper.Id)
return nil return nil
} }
@ -243,7 +251,7 @@ func (st *ACLState) applyUserJoin(ch *aclpb.ACLChangeUserJoin) error {
return fmt.Errorf("verification returned error: %w", err) return fmt.Errorf("verification returned error: %w", err)
} }
if !res { if !res {
return fmt.Errorf("signature is invalid") return ErrInvalidSignature
} }
// if ourselves -> we need to decrypt the read keys // if ourselves -> we need to decrypt the read keys
@ -374,87 +382,6 @@ func (st *ACLState) isUserAdd(data *aclpb.ACLChangeACLData, identity string) boo
return data.GetAclContent() != nil && userAdd != nil && userAdd.GetIdentity() == identity return data.GetAclContent() != nil && userAdd != nil && userAdd.GetIdentity() == identity
} }
func (st *ACLState) getPermissionDecreasedUsers(ch *aclpb.ACLChange) (identities []*aclpb.ACLChangeUserPermissionChange) {
// this should be called after general checks are completed
if ch.GetAclData().GetAclContent() == nil {
return nil
}
contents := ch.GetAclData().GetAclContent()
for _, c := range contents {
if c.GetUserPermissionChange() != nil {
content := c.GetUserPermissionChange()
currentState := st.userStates[content.Identity]
// the comparison works in different direction :-)
if content.Permissions > currentState.Permissions {
identities = append(identities, &aclpb.ACLChangeUserPermissionChange{
Identity: content.Identity,
Permissions: content.Permissions,
})
}
}
if c.GetUserRemove() != nil {
content := c.GetUserRemove()
identities = append(identities, &aclpb.ACLChangeUserPermissionChange{
Identity: content.Identity,
Permissions: aclpb.ACLChange_Removed,
})
}
}
return identities
}
func (st *ACLState) equal(other *ACLState) bool {
if st == nil && other == nil {
return true
}
if st == nil || other == nil {
return false
}
if st.currentReadKeyHash != other.currentReadKeyHash {
return false
}
if st.identity != other.identity {
return false
}
if len(st.userStates) != len(other.userStates) {
return false
}
for _, st := range st.userStates {
otherSt, exists := other.userStates[st.Identity]
if !exists {
return false
}
if st.Permissions != otherSt.Permissions {
return false
}
if bytes.Compare(st.EncryptionKey, otherSt.EncryptionKey) != 0 {
return false
}
}
if len(st.userInvites) != len(other.userInvites) {
return false
}
// TODO: add detailed user invites comparison + compare other stuff
return true
}
func (st *ACLState) GetUserStates() map[string]*aclpb.ACLChangeUserState { func (st *ACLState) GetUserStates() map[string]*aclpb.ACLChangeUserState {
// TODO: we should provide better API that would not allow to change this map from the outside
return st.userStates return st.userStates
} }
func (st *ACLState) isNodeIdentity() bool {
return st.identity == ""
}

View File

@ -124,7 +124,7 @@ func (c *aclChangeBuilder) BuildAndApply() (*Record, []byte, error) {
return nil, nil, err return nil, nil, err
} }
ch := NewRecord(id, aclRecord) ch := NewRecord(id, aclRecord)
ch.ParsedModel = c.aclData ch.Model = c.aclData
ch.Sign = signature ch.Sign = signature
return ch, fullMarshalledChange, nil return ch, fullMarshalledChange, nil

View File

@ -2,17 +2,21 @@ package list
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/account" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/account"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/cid"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys"
"go.uber.org/zap"
"sync" "sync"
) )
type IterFunc = func(record *Record) (IsContinue bool) type IterFunc = func(record *Record) (IsContinue bool)
var ErrIncorrectCID = errors.New("incorrect CID")
type RWLocker interface { type RWLocker interface {
sync.Locker sync.Locker
RLock() RLock()
@ -41,6 +45,7 @@ type aclList struct {
builder *aclStateBuilder builder *aclStateBuilder
aclState *ACLState aclState *ACLState
keychain *common.Keychain
sync.RWMutex sync.RWMutex
} }
@ -54,36 +59,46 @@ func BuildACLList(decoder keys.Decoder, storage storage.ListStorage) (ACLList, e
return buildWithACLStateBuilder(newACLStateBuilder(decoder), storage) return buildWithACLStateBuilder(newACLStateBuilder(decoder), storage)
} }
func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStorage) (ACLList, error) { func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStorage) (list ACLList, err error) {
header, err := storage.Header() header, err := storage.Header()
if err != nil { if err != nil {
return nil, err return
} }
id, err := storage.ID() id, err := storage.ID()
if err != nil { if err != nil {
return nil, err return
} }
rawRecord, err := storage.Head() rawRecord, err := storage.Head()
if err != nil { if err != nil {
return nil, err return
} }
keychain := common.NewKeychain()
record, err := NewFromRawRecord(rawRecord) record, err := NewFromRawRecord(rawRecord)
if err != nil { if err != nil {
return nil, err return
}
err = verifyRecord(keychain, rawRecord, record)
if err != nil {
return
} }
records := []*Record{record} records := []*Record{record}
for record.Content.PrevId != "" { for record.Content.PrevId != "" {
rawRecord, err = storage.GetRawRecord(context.Background(), record.Content.PrevId) rawRecord, err = storage.GetRawRecord(context.Background(), record.Content.PrevId)
if err != nil { if err != nil {
return nil, err return
} }
record, err = NewFromRawRecord(rawRecord) record, err = NewFromRawRecord(rawRecord)
if err != nil { if err != nil {
return nil, err return
}
err = verifyRecord(keychain, rawRecord, record)
if err != nil {
return
} }
records = append(records, record) records = append(records, record)
} }
@ -99,14 +114,12 @@ func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStor
indexes[records[len(records)/2].Id] = len(records) / 2 indexes[records[len(records)/2].Id] = len(records) / 2
} }
log.With(zap.String("head id", records[len(records)-1].Id), zap.String("list id", id)).
Info("building acl tree")
state, err := builder.Build(records) state, err := builder.Build(records)
if err != nil { if err != nil {
return nil, err return
} }
return &aclList{ list = &aclList{
header: header, header: header,
records: records, records: records,
indexes: indexes, indexes: indexes,
@ -114,7 +127,8 @@ func buildWithACLStateBuilder(builder *aclStateBuilder, storage storage.ListStor
aclState: state, aclState: state,
id: id, id: id,
RWMutex: sync.RWMutex{}, RWMutex: sync.RWMutex{},
}, nil }
return
} }
func (a *aclList) Records() []*Record { func (a *aclList) Records() []*Record {
@ -177,3 +191,26 @@ func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
func (a *aclList) Close() (err error) { func (a *aclList) Close() (err error) {
return nil return nil
} }
func verifyRecord(keychain *common.Keychain, rawRecord *aclpb.RawRecord, record *Record) (err error) {
identityKey, err := keychain.GetOrAdd(record.Content.Identity)
if err != nil {
return
}
// verifying signature
res, err := identityKey.Verify(rawRecord.Payload, rawRecord.Signature)
if err != nil {
return
}
if !res {
err = ErrInvalidSignature
return
}
// verifying ID
if !cid.VerifyCID(rawRecord.Payload, rawRecord.Id) {
err = ErrIncorrectCID
}
return
}

View File

@ -8,7 +8,7 @@ import (
type Record struct { type Record struct {
Id string Id string
Content *aclpb.Record Content *aclpb.Record
ParsedModel interface{} Model interface{}
Sign []byte Sign []byte
} }

View File

@ -56,42 +56,6 @@ func (ch *Change) DecryptContents(key *symmetric.Key) error {
return nil return nil
} }
func NewChangeFromRaw(rawChange *aclpb.RawChange) (*Change, error) {
unmarshalled := &aclpb.Change{}
err := proto.Unmarshal(rawChange.Payload, unmarshalled)
if err != nil {
return nil, err
}
ch := NewChange(rawChange.Id, unmarshalled, rawChange.Signature)
return ch, nil
}
func newVerifiedChangeFromRaw(
rawChange *aclpb.RawChange,
kch *keychain) (*Change, error) {
unmarshalled := &aclpb.Change{}
ch, err := NewChangeFromRaw(rawChange)
if err != nil {
return nil, err
}
identityKey, err := kch.getOrAdd(unmarshalled.Identity)
if err != nil {
return nil, err
}
res, err := identityKey.Verify(rawChange.Payload, rawChange.Signature)
if err != nil {
return nil, err
}
if !res {
return nil, ErrIncorrectSignature
}
return ch, nil
}
func NewChange(id string, ch *aclpb.Change, signature []byte) *Change { func NewChange(id string, ch *aclpb.Change, signature []byte) *Change {
return &Change{ return &Change{
Next: nil, Next: nil,

View File

@ -2,6 +2,7 @@ package tree
import ( import (
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/cid" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/cid"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/asymmetric/signingkey"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric"
@ -30,10 +31,10 @@ type ChangeBuilder interface {
} }
type changeBuilder struct { type changeBuilder struct {
keys *keychain keys *common.Keychain
} }
func newChangeBuilder(keys *keychain) *changeBuilder { func newChangeBuilder(keys *common.Keychain) *changeBuilder {
return &changeBuilder{keys: keys} return &changeBuilder{keys: keys}
} }
@ -55,7 +56,7 @@ func (c *changeBuilder) ConvertFromRawAndVerify(rawChange *aclpb.RawChange) (ch
return nil, err return nil, err
} }
identityKey, err := c.keys.getOrAdd(unmarshalled.Identity) identityKey, err := c.keys.GetOrAdd(unmarshalled.Identity)
if err != nil { if err != nil {
return return
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"errors" "errors"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/aclchanges/aclpb"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/common"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/list" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/list"
"github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage" "github.com/anytypeio/go-anytype-infrastructure-experiments/pkg/acl/storage"
"github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric" "github.com/anytypeio/go-anytype-infrastructure-experiments/util/keys/symmetric"
@ -112,7 +113,7 @@ func defaultObjectTreeDeps(
listener ObjectTreeUpdateListener, listener ObjectTreeUpdateListener,
aclList list.ACLList) objectTreeDeps { aclList list.ACLList) objectTreeDeps {
keychain := newKeychain() keychain := common.NewKeychain()
changeBuilder := newChangeBuilder(keychain) changeBuilder := newChangeBuilder(keychain)
treeBuilder := newTreeBuilder(treeStorage, changeBuilder) treeBuilder := newTreeBuilder(treeStorage, changeBuilder)
return objectTreeDeps{ return objectTreeDeps{

View File

@ -235,21 +235,3 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error)
} }
return return
} }
func discardFromSlice[T any](elements []T, isDiscarded func(T) bool) []T {
var (
finishedIdx = 0
currentIdx = 0
)
for currentIdx < len(elements) {
if !isDiscarded(elements[currentIdx]) {
if finishedIdx != currentIdx {
elements[finishedIdx] = elements[currentIdx]
}
finishedIdx++
}
currentIdx++
}
elements = elements[:finishedIdx]
return elements
}

View File

@ -27,3 +27,21 @@ OuterLoop:
} }
return ourPath[i+1], nil return ourPath[i+1], nil
} }
func discardFromSlice[T any](elements []T, isDiscarded func(T) bool) []T {
var (
finishedIdx = 0
currentIdx = 0
)
for currentIdx < len(elements) {
if !isDiscarded(elements[currentIdx]) {
if finishedIdx != currentIdx {
elements[finishedIdx] = elements[currentIdx]
}
finishedIdx++
}
currentIdx++
}
elements = elements[:finishedIdx]
return elements
}