Compare commits

..

No commits in common. "v0.0.5" and "main" have entirely different histories.
v0.0.5 ... main

301 changed files with 39347 additions and 13020 deletions

14
.github/dependabot.yml vendored Normal file
View File

@ -0,0 +1,14 @@
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "gomod"
directory: "/" # Location of package manifests
schedule:
interval: "weekly"
ignore:
- dependency-name: "github.com/anyproto/go-chash"

View File

@ -7,7 +7,7 @@ jobs:
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
env: env:
GOPRIVATE: github.com/anytypeio GOPRIVATE: github.com/anyproto
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- uses: actions/setup-go@v3 - uses: actions/setup-go@v3
@ -17,20 +17,20 @@ jobs:
- name: git config - name: git config
run: git config --global url.https://${{ secrets.ANYTYPE_PAT }}@github.com/.insteadOf https://github.com/ run: git config --global url.https://${{ secrets.ANYTYPE_PAT }}@github.com/.insteadOf https://github.com/
# cache {{ # # cache {{
- id: go-cache-paths # - id: go-cache-paths
run: | # run: |
echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT # echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT # echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- uses: actions/cache@v3 # - uses: actions/cache@v3
with: # with:
path: | # path: |
${{ steps.go-cache-paths.outputs.GOCACHE }} # ${{ steps.go-cache-paths.outputs.GOCACHE }}
${{ steps.go-cache-paths.outputs.GOMODCACHE }} # ${{ steps.go-cache-paths.outputs.GOMODCACHE }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }} # key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
restore-keys: | # restore-keys: |
${{ runner.os }}-go-${{ matrix.go-version }}- # ${{ runner.os }}-go-${{ matrix.go-version }}-
# }} # # }}
- name: deps - name: deps
run: make deps run: make deps

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Any Association
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,5 +1,5 @@
.PHONY: proto test deps .PHONY: proto test deps
export GOPRIVATE=github.com/anytypeio export GOPRIVATE=github.com/anyproto
export PATH:=deps:$(PATH) export PATH:=deps:$(PATH)
proto: proto:
@ -7,15 +7,20 @@ proto:
@$(eval P_ACL_RECORDS_PATH_PB := commonspace/object/acl/aclrecordproto) @$(eval P_ACL_RECORDS_PATH_PB := commonspace/object/acl/aclrecordproto)
@$(eval P_TREE_CHANGES_PATH_PB := commonspace/object/tree/treechangeproto) @$(eval P_TREE_CHANGES_PATH_PB := commonspace/object/tree/treechangeproto)
@$(eval P_ACL_RECORDS := M$(P_ACL_RECORDS_PATH_PB)/protos/aclrecord.proto=github.com/anytypeio/any-sync/$(P_ACL_RECORDS_PATH_PB)) @$(eval P_CRYPTO_PATH_PB := util/crypto/cryptoproto)
@$(eval P_TREE_CHANGES := M$(P_TREE_CHANGES_PATH_PB)/protos/treechange.proto=github.com/anytypeio/any-sync/$(P_TREE_CHANGES_PATH_PB)) @$(eval P_ACL_RECORDS := M$(P_ACL_RECORDS_PATH_PB)/protos/aclrecord.proto=github.com/anyproto/any-sync/$(P_ACL_RECORDS_PATH_PB))
@$(eval P_TREE_CHANGES := M$(P_TREE_CHANGES_PATH_PB)/protos/treechange.proto=github.com/anyproto/any-sync/$(P_TREE_CHANGES_PATH_PB))
protoc --gogofaster_out=:. $(P_ACL_RECORDS_PATH_PB)/protos/*.proto protoc --gogofaster_out=:. $(P_ACL_RECORDS_PATH_PB)/protos/*.proto
protoc --gogofaster_out=:. $(P_TREE_CHANGES_PATH_PB)/protos/*.proto protoc --gogofaster_out=:. $(P_TREE_CHANGES_PATH_PB)/protos/*.proto
protoc --gogofaster_out=:. $(P_CRYPTO_PATH_PB)/protos/*.proto
$(eval PKGMAP := $$(P_TREE_CHANGES),$$(P_ACL_RECORDS)) $(eval PKGMAP := $$(P_TREE_CHANGES),$$(P_ACL_RECORDS))
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonspace/spacesyncproto/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonspace/spacesyncproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonfile/fileproto/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonfile/fileproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto
protoc --gogofaster_out=:. net/secureservice/handshake/handshakeproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. coordinator/coordinatorproto/protos/*.proto
protoc --gogofaster_out=:. --go-drpc_out=protolib=github.com/gogo/protobuf:. consensus/consensusproto/protos/*.proto
deps: deps:
go mod download go mod download

43
README.md Normal file
View File

@ -0,0 +1,43 @@
# Any-Sync
Any-Sync is an open-source protocol designed to create high-performance, local-first, peer-to-peer, end-to-end encrypted applications that facilitate seamless collaboration among multiple users and devices.
By utilizing this protocol, users can rest assured that they retain complete control over their data and digital experience. They are empowered to freely transition between various service providers, or even opt to self-host the applications.
This ensures utmost flexibility and autonomy for users in managing their personal information and digital interactions.
## Introduction
Most existing information management tools are implemented on centralized client-server architecture or designed for an offline-first single-user usage. Either way there are trade-offs for users: they can face restricted freedoms and privacy violations or compromise on the functionality of tools to avoid this.
We believe this goes against fundamental digital freedoms and that a new generation of software is needed that will respect these freedoms, while providing best in-class user experience.
Our goal with `any-sync` is to develop a protocol that will enable the deployment of this software.
Features:
- Conflict-free data replication across multiple devices and agents
- Built-in end-to-end encryption
- Cryptographically verifiable history of changes
- Adoption to frequent operations (high performance)
- Reliable and scalable infrastructure
- Simultaneous support of p2p and remote communication
## Protocol explanation
Plese read the [overview](https://tech.anytype.io/any-sync/overview) of protocol entities and design.
## Implementation
You can find the various parts of the protocol implemented in Go in the following repositories:
- [`any-sync-node`](https://github.com/anyproto/any-sync-node) — implementation of a sync node responsible for storing spaces and objects.
- [`any-sync-filenode`](https://github.com/anyproto/any-sync-filenode) — implementation of a file node responsible for storing files.
- [`any-sync-coordinator`](https://github.com/anyproto/any-sync-coordinator) — implementation of a coordinator node responsible for network configuration management.
## Contribution
Thank you for your desire to develop Anytype together.
Currently, we're not ready to accept PRs, but we will in the nearest future.
Follow us on [Github](https://github.com/anyproto) and join the [Contributors Community](https://github.com/orgs/anyproto/discussions).
---
Made by Any — a Swiss association 🇨🇭
Licensed under [MIT License](./LICENSE).

View File

@ -1,23 +1,22 @@
//go:generate mockgen -destination mock_accountservice/mock_accountservice.go github.com/anytypeio/any-sync/accountservice Service //go:generate mockgen -destination mock_accountservice/mock_accountservice.go github.com/anyproto/any-sync/accountservice Service
package accountservice package accountservice
import ( import (
"github.com/anytypeio/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anytypeio/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
) )
const CName = "common.accountservice" const CName = "common.accountservice"
type Service interface { type Service interface {
app.Component app.Component
Account() *accountdata.AccountData Account() *accountdata.AccountKeys
} }
type Config struct { type Config struct {
PeerId string `yaml:"peerId"` PeerId string `yaml:"peerId"`
PeerKey string `yaml:"peerKey"` PeerKey string `yaml:"peerKey"`
SigningKey string `yaml:"signingKey"` SigningKey string `yaml:"signingKey"`
EncryptionKey string `yaml:"encryptionKey"`
} }
type ConfigGetter interface { type ConfigGetter interface {

View File

@ -1,12 +1,12 @@
package mock_accountservice package mock_accountservice
import ( import (
"github.com/anytypeio/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anytypeio/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/golang/mock/gomock" "go.uber.org/mock/gomock"
) )
func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountData) *MockService { func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountKeys) *MockService {
mock := NewMockService(ctrl) mock := NewMockService(ctrl)
mock.EXPECT().Name().Return(accountservice.CName).AnyTimes() mock.EXPECT().Name().Return(accountservice.CName).AnyTimes()
mock.EXPECT().Init(gomock.Any()).AnyTimes() mock.EXPECT().Init(gomock.Any()).AnyTimes()

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/accountservice (interfaces: Service) // Source: github.com/anyproto/any-sync/accountservice (interfaces: Service)
// Package mock_accountservice is a generated GoMock package. // Package mock_accountservice is a generated GoMock package.
package mock_accountservice package mock_accountservice
@ -7,9 +7,9 @@ package mock_accountservice
import ( import (
reflect "reflect" reflect "reflect"
app "github.com/anytypeio/any-sync/app" app "github.com/anyproto/any-sync/app"
accountdata "github.com/anytypeio/any-sync/commonspace/object/accountdata" accountdata "github.com/anyproto/any-sync/commonspace/object/accountdata"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockService is a mock of Service interface. // MockService is a mock of Service interface.
@ -36,10 +36,10 @@ func (m *MockService) EXPECT() *MockServiceMockRecorder {
} }
// Account mocks base method. // Account mocks base method.
func (m *MockService) Account() *accountdata.AccountData { func (m *MockService) Account() *accountdata.AccountKeys {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Account") ret := m.ctrl.Call(m, "Account")
ret0, _ := ret[0].(*accountdata.AccountData) ret0, _ := ret[0].(*accountdata.AccountKeys)
return ret0 return ret0
} }

View File

@ -4,10 +4,11 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/anytypeio/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"go.uber.org/zap" "go.uber.org/zap"
"os" "os"
"runtime" "runtime"
"runtime/debug"
"strings" "strings"
"sync" "sync"
"time" "time"
@ -15,12 +16,15 @@ import (
var ( var (
// values of this vars will be defined while compilation // values of this vars will be defined while compilation
GitCommit, GitBranch, GitState, GitSummary, BuildDate string AppName, GitCommit, GitBranch, GitState, GitSummary, BuildDate string
name string name string
) )
var ( var (
log = logger.NewNamed("app") log = logger.NewNamed("app")
StopDeadline = time.Minute
StopWarningAfter = time.Second * 10
StartWarningAfter = time.Second * 10
) )
// Component is a minimal interface for a common app.Component // Component is a minimal interface for a common app.Component
@ -51,10 +55,14 @@ type ComponentStatable interface {
// App is the central part of the application // App is the central part of the application
// It contains and manages all components // It contains and manages all components
type App struct { type App struct {
parent *App
components []Component components []Component
mu sync.RWMutex mu sync.RWMutex
startStat StartStat startStat Stat
stopStat Stat
deviceState int deviceState int
versionName string
anySyncVersion string
} }
// Name returns app name // Name returns app name
@ -62,23 +70,47 @@ func (app *App) Name() string {
return name return name
} }
func (app *App) AppName() string {
return AppName
}
// Version return app version // Version return app version
func (app *App) Version() string { func (app *App) Version() string {
return GitSummary return GitSummary
} }
type StartStat struct { // SetVersionName sets the custom application version
func (app *App) SetVersionName(v string) {
app.versionName = v
}
// VersionName returns a string with the settled app version or auto-generated version if it didn't set
func (app *App) VersionName() string {
if app.versionName != "" {
return app.versionName
}
return AppName + ":" + GitSummary + "/any-sync:" + app.AnySyncVersion()
}
type Stat struct {
SpentMsPerComp map[string]int64 SpentMsPerComp map[string]int64
SpentMsTotal int64 SpentMsTotal int64
} }
// StartStat returns total time spent per comp // StartStat returns total time spent per comp for the last Start
func (app *App) StartStat() StartStat { func (app *App) StartStat() Stat {
app.mu.Lock() app.mu.Lock()
defer app.mu.Unlock() defer app.mu.Unlock()
return app.startStat return app.startStat
} }
// StopStat returns total time spent per comp for the last Close
func (app *App) StopStat() Stat {
app.mu.Lock()
defer app.mu.Unlock()
return app.stopStat
}
// VersionDescription return the full info about the build // VersionDescription return the full info about the build
func (app *App) VersionDescription() string { func (app *App) VersionDescription() string {
return VersionDescription() return VersionDescription()
@ -92,6 +124,16 @@ func VersionDescription() string {
return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState) return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState)
} }
// ChildApp creates a child container which has access to parent's components
// It doesn't call Start on any of the parent's components
func (app *App) ChildApp() *App {
return &App{
parent: app,
deviceState: app.deviceState,
anySyncVersion: app.AnySyncVersion(),
}
}
// Register adds service to registry // Register adds service to registry
// All components will be started in the order they were registered // All components will be started in the order they were registered
func (app *App) Register(s Component) *App { func (app *App) Register(s Component) *App {
@ -111,11 +153,15 @@ func (app *App) Register(s Component) *App {
func (app *App) Component(name string) Component { func (app *App) Component(name string) Component {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
for current != nil {
for _, s := range current.components {
if s.Name() == name { if s.Name() == name {
return s return s
} }
} }
current = current.parent
}
return nil return nil
} }
@ -132,11 +178,15 @@ func (app *App) MustComponent(name string) Component {
func MustComponent[i any](app *App) i { func MustComponent[i any](app *App) i {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
for current != nil {
for _, s := range current.components {
if v, ok := s.(i); ok { if v, ok := s.(i); ok {
return v return v
} }
} }
current = current.parent
}
empty := new(i) empty := new(i)
panic(fmt.Errorf("component with interface %T is not found", empty)) panic(fmt.Errorf("component with interface %T is not found", empty))
} }
@ -145,9 +195,13 @@ func MustComponent[i any](app *App) i {
func (app *App) ComponentNames() (names []string) { func (app *App) ComponentNames() (names []string) {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
names = make([]string, len(app.components)) names = make([]string, 0, len(app.components))
for i, c := range app.components { current := app
names[i] = c.Name() for current != nil {
for _, c := range current.components {
names = append(names, c.Name())
}
current = current.parent
} }
return return
} }
@ -158,7 +212,17 @@ func (app *App) Start(ctx context.Context) (err error) {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
app.startStat.SpentMsPerComp = make(map[string]int64) app.startStat.SpentMsPerComp = make(map[string]int64)
var currentComponentStarting string
done := make(chan struct{})
go func() {
select {
case <-done:
return
case <-time.After(StartWarningAfter):
l := statLogger(app.stopStat, log).With(zap.String("in_progress", currentComponentStarting))
l.Warn("components start in progress")
}
}()
closeServices := func(idx int) { closeServices := func(idx int) {
for i := idx; i >= 0; i-- { for i := idx; i >= 0; i-- {
if serviceClose, ok := app.components[i].(ComponentRunnable); ok { if serviceClose, ok := app.components[i].(ComponentRunnable); ok {
@ -172,7 +236,7 @@ func (app *App) Start(ctx context.Context) (err error) {
for i, s := range app.components { for i, s := range app.components {
if err = s.Init(app); err != nil { if err = s.Init(app); err != nil {
closeServices(i) closeServices(i)
return fmt.Errorf("can't init service '%s': %v", s.Name(), err) return fmt.Errorf("can't init service '%s': %w", s.Name(), err)
} }
} }
@ -181,17 +245,32 @@ func (app *App) Start(ctx context.Context) (err error) {
start := time.Now() start := time.Now()
if err = serviceRun.Run(ctx); err != nil { if err = serviceRun.Run(ctx); err != nil {
closeServices(i) closeServices(i)
return fmt.Errorf("can't run service '%s': %v", serviceRun.Name(), err) return fmt.Errorf("can't run service '%s': %w", serviceRun.Name(), err)
} }
spent := time.Since(start).Milliseconds() spent := time.Since(start).Milliseconds()
app.startStat.SpentMsTotal += spent app.startStat.SpentMsTotal += spent
app.startStat.SpentMsPerComp[s.Name()] = spent app.startStat.SpentMsPerComp[s.Name()] = spent
} }
} }
log.Debug("all components started")
close(done)
l := statLogger(app.stopStat, log)
if app.startStat.SpentMsTotal > StartWarningAfter.Milliseconds() {
l.Warn("all components started")
}
l.Debug("all components started")
return return
} }
// IterateComponents iterates over all registered components. It's safe for concurrent use.
func (app *App) IterateComponents(fn func(Component)) {
app.mu.RLock()
defer app.mu.RUnlock()
for _, s := range app.components {
fn(s)
}
}
func stackAllGoroutines() []byte { func stackAllGoroutines() []byte {
buf := make([]byte, 1024) buf := make([]byte, 1024)
for { for {
@ -203,18 +282,41 @@ func stackAllGoroutines() []byte {
} }
} }
func statLogger(stat Stat, ctxLogger logger.CtxLogger) logger.CtxLogger {
l := ctxLogger
for k, v := range stat.SpentMsPerComp {
l = l.With(zap.Int64(k, v))
}
l = l.With(zap.Int64("total", stat.SpentMsTotal))
return l
}
// Close stops the application // Close stops the application
// All components with ComponentRunnable implementation will be closed in the reversed order // All components with ComponentRunnable implementation will be closed in the reversed order
func (app *App) Close(ctx context.Context) error { func (app *App) Close(ctx context.Context) error {
log.Debug("close components...") log.Debug("close components...")
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
app.stopStat.SpentMsPerComp = make(map[string]int64)
var currentComponentStopping string
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
select { select {
case <-done: case <-done:
return return
case <-time.After(time.Minute): case <-time.After(StopWarningAfter):
statLogger(app.stopStat, log).
With(zap.String("in_progress", currentComponentStopping)).
Warn("components close in progress")
}
}()
go func() {
select {
case <-done:
return
case <-time.After(StopDeadline):
_, _ = os.Stderr.Write([]byte("app.Close timeout\n")) _, _ = os.Stderr.Write([]byte("app.Close timeout\n"))
_, _ = os.Stderr.Write(stackAllGoroutines()) _, _ = os.Stderr.Write(stackAllGoroutines())
panic("app.Close timeout") panic("app.Close timeout")
@ -224,16 +326,27 @@ func (app *App) Close(ctx context.Context) error {
var errs []string var errs []string
for i := len(app.components) - 1; i >= 0; i-- { for i := len(app.components) - 1; i >= 0; i-- {
if serviceClose, ok := app.components[i].(ComponentRunnable); ok { if serviceClose, ok := app.components[i].(ComponentRunnable); ok {
start := time.Now()
currentComponentStopping = app.components[i].Name()
if e := serviceClose.Close(ctx); e != nil { if e := serviceClose.Close(ctx); e != nil {
errs = append(errs, fmt.Sprintf("Component '%s' close error: %v", serviceClose.Name(), e)) errs = append(errs, fmt.Sprintf("Component '%s' close error: %v", serviceClose.Name(), e))
} }
spent := time.Since(start).Milliseconds()
app.stopStat.SpentMsTotal += spent
app.stopStat.SpentMsPerComp[app.components[i].Name()] = spent
} }
} }
close(done) close(done)
if len(errs) > 0 { if len(errs) > 0 {
return errors.New(strings.Join(errs, "\n")) return errors.New(strings.Join(errs, "\n"))
} }
log.Debug("all components have been closed")
l := statLogger(app.stopStat, log)
if app.stopStat.SpentMsTotal > StopWarningAfter.Milliseconds() {
l.Warn("all components have been closed")
}
l.Debug("all components have been closed")
return nil return nil
} }
@ -250,3 +363,20 @@ func (app *App) SetDeviceState(state int) {
} }
} }
} }
var onceVersion sync.Once
func (app *App) AnySyncVersion() string {
onceVersion.Do(func() {
info, ok := debug.ReadBuildInfo()
if ok {
for _, mod := range info.Deps {
if mod.Path == "github.com/anyproto/any-sync" {
app.anySyncVersion = mod.Version
break
}
}
}
})
return app.anySyncVersion
}

View File

@ -34,6 +34,40 @@ func TestAppServiceRegistry(t *testing.T) {
names := app.ComponentNames() names := app.ComponentNames()
assert.Equal(t, names, []string{"c1", "r1", "s1"}) assert.Equal(t, names, []string{"c1", "r1", "s1"})
}) })
t.Run("Child MustComponent", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
for _, name := range []string{"c1", "r1", "s1", "x1"} {
assert.NotPanics(t, func() { app.MustComponent(name) }, name)
}
assert.Panics(t, func() { app.MustComponent("not-registered") })
})
t.Run("Child ComponentNames", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
names := app.ComponentNames()
assert.Equal(t, names, []string{"x1", "c1", "r1", "s1"})
})
t.Run("Child override", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeRunnable, "s1", nil, nil))
_ = app.MustComponent("s1").(*testRunnable)
})
}
func TestApp_IterateComponents(t *testing.T) {
app := new(App)
app.Register(newTestService(testTypeRunnable, "c1", nil, nil))
app.Register(newTestService(testTypeRunnable, "r1", nil, nil))
app.Register(newTestService(testTypeComponent, "s1", nil, nil))
var got []string
app.IterateComponents(func(s Component) {
got = append(got, s.Name())
})
assert.ElementsMatch(t, []string{"c1", "r1", "s1"}, got)
} }
func TestAppStart(t *testing.T) { func TestAppStart(t *testing.T) {

View File

@ -1,7 +1,7 @@
// Package ldiff provides a container of elements with fixed id and changeable content. // Package ldiff provides a container of elements with fixed id and changeable content.
// Diff can calculate the difference with another diff container (you can make it remote) with minimum hops and traffic. // Diff can calculate the difference with another diff container (you can make it remote) with minimum hops and traffic.
// //
//go:generate mockgen -destination mock_ldiff/mock_ldiff.go github.com/anytypeio/any-sync/app/ldiff Diff,Remote //go:generate mockgen -destination mock_ldiff/mock_ldiff.go github.com/anyproto/any-sync/app/ldiff Diff,Remote
package ldiff package ldiff
import ( import (
@ -51,7 +51,7 @@ var hashersPool = &sync.Pool{
}, },
} }
var ErrElementNotFound = errors.New("element not found") var ErrElementNotFound = errors.New("ldiff: element not found")
// Element of data // Element of data
type Element struct { type Element struct {
@ -88,10 +88,14 @@ type Diff interface {
Diff(ctx context.Context, dl Remote) (newIds, changedIds, removedIds []string, err error) Diff(ctx context.Context, dl Remote) (newIds, changedIds, removedIds []string, err error)
// Elements retrieves all elements in the Diff // Elements retrieves all elements in the Diff
Elements() []Element Elements() []Element
// Element returns an element by id
Element(id string) (Element, error)
// Ids retrieves ids of all elements in the Diff // Ids retrieves ids of all elements in the Diff
Ids() []string Ids() []string
// Hash returns hash of all elements in the diff // Hash returns hash of all elements in the diff
Hash() string Hash() string
// Len returns count of elements in the diff
Len() int
} }
// Remote interface for using in the Diff // Remote interface for using in the Diff
@ -157,6 +161,12 @@ func (d *diff) Ids() (ids []string) {
return return
} }
func (d *diff) Len() int {
d.mu.RLock()
defer d.mu.RUnlock()
return d.sl.Len()
}
func (d *diff) Elements() (elements []Element) { func (d *diff) Elements() (elements []Element) {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()
@ -172,6 +182,19 @@ func (d *diff) Elements() (elements []Element) {
return return
} }
func (d *diff) Element(id string) (Element, error) {
d.mu.RLock()
defer d.mu.RUnlock()
el := d.sl.Get(&element{Element: Element{Id: id}, hash: xxhash.Sum64([]byte(id))})
if el == nil {
return Element{}, ErrElementNotFound
}
if e, ok := el.Key().(*element); ok {
return e.Element, nil
}
return Element{}, ErrElementNotFound
}
func (d *diff) Hash() string { func (d *diff) Hash() string {
d.mu.RLock() d.mu.RLock()
defer d.mu.RUnlock() defer d.mu.RUnlock()

View File

@ -3,10 +3,11 @@ package ldiff
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"gopkg.in/mgo.v2/bson"
"math" "math"
"sort"
"testing" "testing"
) )
@ -43,7 +44,7 @@ func TestDiff_Diff(t *testing.T) {
d2 := New(16, 16) d2 := New(16, 16)
for i := 0; i < 1000; i++ { for i := 0; i < 1000; i++ {
id := fmt.Sprint(i) id := fmt.Sprint(i)
head := bson.NewObjectId().Hex() head := uuid.NewString()
d1.Set(Element{ d1.Set(Element{
Id: id, Id: id,
Head: head, Head: head,
@ -91,7 +92,7 @@ func TestDiff_Diff(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
d2.Set(Element{ d2.Set(Element{
Id: fmt.Sprint(i), Id: fmt.Sprint(i),
Head: bson.NewObjectId().Hex(), Head: uuid.NewString(),
}) })
} }
@ -107,7 +108,7 @@ func TestDiff_Diff(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
d2.Set(Element{ d2.Set(Element{
Id: fmt.Sprint(i), Id: fmt.Sprint(i),
Head: bson.NewObjectId().Hex(), Head: uuid.NewString(),
}) })
} }
var cancel func() var cancel func()
@ -122,7 +123,7 @@ func BenchmarkDiff_Ranges(b *testing.B) {
d := New(16, 16) d := New(16, 16)
for i := 0; i < 10000; i++ { for i := 0; i < 10000; i++ {
id := fmt.Sprint(i) id := fmt.Sprint(i)
head := bson.NewObjectId().Hex() head := uuid.NewString()
d.Set(Element{ d.Set(Element{
Id: id, Id: id,
Head: head, Head: head,
@ -148,3 +149,51 @@ func TestDiff_Hash(t *testing.T) {
assert.NotEmpty(t, h2) assert.NotEmpty(t, h2)
assert.NotEqual(t, h1, h2) assert.NotEqual(t, h1, h2)
} }
func TestDiff_Element(t *testing.T) {
d := New(16, 16)
for i := 0; i < 10; i++ {
d.Set(Element{Id: fmt.Sprint("id", i), Head: fmt.Sprint("head", i)})
}
_, err := d.Element("not found")
assert.Equal(t, ErrElementNotFound, err)
el, err := d.Element("id5")
require.NoError(t, err)
assert.Equal(t, "head5", el.Head)
d.Set(Element{"id5", "otherHead"})
el, err = d.Element("id5")
require.NoError(t, err)
assert.Equal(t, "otherHead", el.Head)
}
func TestDiff_Ids(t *testing.T) {
d := New(16, 16)
var ids []string
for i := 0; i < 10; i++ {
id := fmt.Sprint("id", i)
d.Set(Element{Id: id, Head: fmt.Sprint("head", i)})
ids = append(ids, id)
}
gotIds := d.Ids()
sort.Strings(gotIds)
assert.Equal(t, ids, gotIds)
assert.Equal(t, len(ids), d.Len())
}
func TestDiff_Elements(t *testing.T) {
d := New(16, 16)
var els []Element
for i := 0; i < 10; i++ {
id := fmt.Sprint("id", i)
el := Element{Id: id, Head: fmt.Sprint("head", i)}
d.Set(el)
els = append(els, el)
}
gotEls := d.Elements()
sort.Slice(gotEls, func(i, j int) bool {
return gotEls[i].Id < gotEls[j].Id
})
assert.Equal(t, els, gotEls)
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/app/ldiff (interfaces: Diff,Remote) // Source: github.com/anyproto/any-sync/app/ldiff (interfaces: Diff,Remote)
// Package mock_ldiff is a generated GoMock package. // Package mock_ldiff is a generated GoMock package.
package mock_ldiff package mock_ldiff
@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
ldiff "github.com/anytypeio/any-sync/app/ldiff" ldiff "github.com/anyproto/any-sync/app/ldiff"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockDiff is a mock of Diff interface. // MockDiff is a mock of Diff interface.
@ -52,6 +52,21 @@ func (mr *MockDiffMockRecorder) Diff(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Diff", reflect.TypeOf((*MockDiff)(nil).Diff), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Diff", reflect.TypeOf((*MockDiff)(nil).Diff), arg0, arg1)
} }
// Element mocks base method.
func (m *MockDiff) Element(arg0 string) (ldiff.Element, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Element", arg0)
ret0, _ := ret[0].(ldiff.Element)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Element indicates an expected call of Element.
func (mr *MockDiffMockRecorder) Element(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Element", reflect.TypeOf((*MockDiff)(nil).Element), arg0)
}
// Elements mocks base method. // Elements mocks base method.
func (m *MockDiff) Elements() []ldiff.Element { func (m *MockDiff) Elements() []ldiff.Element {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -94,6 +109,20 @@ func (mr *MockDiffMockRecorder) Ids() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ids", reflect.TypeOf((*MockDiff)(nil).Ids)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ids", reflect.TypeOf((*MockDiff)(nil).Ids))
} }
// Len mocks base method.
func (m *MockDiff) Len() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(int)
return ret0
}
// Len indicates an expected call of Len.
func (mr *MockDiffMockRecorder) Len() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockDiff)(nil).Len))
}
// Ranges mocks base method. // Ranges mocks base method.
func (m *MockDiff) Ranges(arg0 context.Context, arg1 []ldiff.Range, arg2 []ldiff.RangeResult) ([]ldiff.RangeResult, error) { func (m *MockDiff) Ranges(arg0 context.Context, arg1 []ldiff.Range, arg2 []ldiff.RangeResult) ([]ldiff.RangeResult, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -1,33 +1,124 @@
package logger package logger
import "go.uber.org/zap" import (
"fmt"
"strings"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"github.com/anyproto/any-sync/util/slice"
)
type LogFormat int
const (
ColorizedOutput LogFormat = iota
PlaintextOutput
JSONOutput
)
type NamedLevel struct {
Name string `yaml:"name"`
Level string `yaml:"level"`
}
type Config struct { type Config struct {
Production bool `yaml:"production"` Production bool `yaml:"production"`
DefaultLevel string `yaml:"defaultLevel"` DefaultLevel string `yaml:"defaultLevel"`
NamedLevels map[string]string `yaml:"namedLevels"` Levels []NamedLevel `yaml:"levels"` // first match will be used
AddOutputPaths []string `yaml:"outputPaths"`
DisableStdErr bool `yaml:"disableStdErr"`
Format LogFormat `yaml:"format"`
ZapConfig *zap.Config `yaml:"-"` // optional, if set it will be used instead of other config options
} }
func (l Config) ApplyGlobal() { func (l Config) ApplyGlobal() {
var conf zap.Config var conf zap.Config
if l.ZapConfig != nil {
conf = *l.ZapConfig
} else {
if l.Production { if l.Production {
conf = zap.NewProductionConfig() conf = zap.NewProductionConfig()
} else { } else {
conf = zap.NewDevelopmentConfig() conf = zap.NewDevelopmentConfig()
} }
encConfig := conf.EncoderConfig
switch l.Format {
case PlaintextOutput:
encConfig.EncodeLevel = zapcore.CapitalLevelEncoder
conf.Encoding = "console"
case JSONOutput:
encConfig.MessageKey = "msg"
encConfig.TimeKey = "ts"
encConfig.LevelKey = "level"
encConfig.NameKey = "logger"
encConfig.CallerKey = "caller"
encConfig.EncodeTime = zapcore.ISO8601TimeEncoder
conf.Encoding = "json"
default:
// default is ColorizedOutput
conf.Encoding = "console"
encConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
}
conf.EncoderConfig = encConfig
if len(l.AddOutputPaths) > 0 {
conf.OutputPaths = append(conf.OutputPaths, l.AddOutputPaths...)
}
if l.DisableStdErr {
conf.OutputPaths = slice.Filter(conf.OutputPaths, func(path string) bool {
return path != "stderr"
})
}
if defaultLevel, err := zap.ParseAtomicLevel(l.DefaultLevel); err == nil { if defaultLevel, err := zap.ParseAtomicLevel(l.DefaultLevel); err == nil {
conf.Level = defaultLevel conf.Level = defaultLevel
} }
var levels = make(map[string]zap.AtomicLevel) }
for k, v := range l.NamedLevels { for _, v := range l.Levels {
if lev, err := zap.ParseAtomicLevel(v); err != nil { if lev, err := zap.ParseAtomicLevel(v.Level); err == nil {
levels[k] = lev // we need to have a minimum level of all named loggers for the main logger
if lev.Level() < conf.Level.Level() {
conf.Level.SetLevel(lev.Level())
} }
} }
defaultLogger, err := conf.Build() }
lg, err := conf.Build()
if err != nil { if err != nil {
Default().Fatal("can't build logger", zap.Error(err)) Default().Fatal("can't build logger", zap.Error(err))
} }
SetDefault(defaultLogger) SetDefault(lg)
SetNamedLevels(levels) SetNamedLevels(l.Levels)
}
// LevelsFromStr parses a string of the form "name1=DEBUG;prefix*=WARN;*=ERROR" into a slice of NamedLevel
// it may be useful to parse the log level from the OS env var
func LevelsFromStr(s string) (levels []NamedLevel) {
for _, kv := range strings.Split(s, ";") {
strings.TrimSpace(kv)
parts := strings.Split(kv, "=")
var key, value string
if len(parts) == 1 {
key = "*"
value = parts[0]
_, err := zap.ParseAtomicLevel(value)
if err != nil {
fmt.Printf("Can't parse log level %s: %s\n", parts[0], err.Error())
continue
}
levels = append(levels, NamedLevel{Name: key, Level: value})
} else if len(parts) == 2 {
key = parts[0]
value = parts[1]
}
_, err := zap.ParseAtomicLevel(value)
if err != nil {
fmt.Printf("Can't parse log level %s: %s\n", parts[0], err.Error())
continue
}
levels = append(levels, NamedLevel{Name: key, Level: value})
}
return levels
} }

View File

@ -32,6 +32,7 @@ func CtxGetFields(ctx context.Context) (fields []zap.Field) {
type CtxLogger struct { type CtxLogger struct {
*zap.Logger *zap.Logger
name string
} }
func (cl CtxLogger) DebugCtx(ctx context.Context, msg string, fields ...zap.Field) { func (cl CtxLogger) DebugCtx(ctx context.Context, msg string, fields ...zap.Field) {
@ -51,5 +52,9 @@ func (cl CtxLogger) ErrorCtx(ctx context.Context, msg string, fields ...zap.Fiel
} }
func (cl CtxLogger) With(fields ...zap.Field) CtxLogger { func (cl CtxLogger) With(fields ...zap.Field) CtxLogger {
return CtxLogger{cl.Logger.With(fields...)} return CtxLogger{cl.Logger.With(fields...), cl.name}
}
func (cl CtxLogger) Sugar() *zap.SugaredLogger {
return NewNamedSugared(cl.name)
} }

View File

@ -1,51 +1,137 @@
package logger package logger
import ( import (
"go.uber.org/zap"
"sync" "sync"
"github.com/gobwas/glob"
"go.uber.org/zap"
) )
var ( var (
mu sync.Mutex mu sync.Mutex
defaultLogger *zap.Logger logger *zap.Logger
levels = make(map[string]zap.AtomicLevel) loggerConfig zap.Config
loggers = make(map[string]CtxLogger) namedLevels []namedLevel
namedGlobs = make(map[string]glob.Glob)
namedLoggers = make(map[string]CtxLogger)
namedSugarLoggers = make(map[string]*zap.SugaredLogger)
) )
func init() { type namedLevel struct {
defaultLogger, _ = zap.NewDevelopment() name string
zap.NewProduction() level zap.AtomicLevel
} }
func init() {
loggerConfig = zap.NewDevelopmentConfig()
logger, _ = loggerConfig.Build()
}
// SetDefault replaces the default logger
// you need to call SetNamedLevels after in case you have named loggers,
// otherwise they will use the old logger
func SetDefault(l *zap.Logger) { func SetDefault(l *zap.Logger) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
*defaultLogger = *l *logger = *l
for name, l := range loggers {
*l.Logger = *defaultLogger.Named(name)
}
} }
func SetNamedLevels(l map[string]zap.AtomicLevel) { // SetNamedLevels sets the namedLevels for named loggers
// it also supports glob patterns for names, like "app*"
// can be racy in case there are existing named loggers
// so consider to call only once at the beginning
func SetNamedLevels(nls []NamedLevel) {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
levels = l namedLevels = namedLevels[:0]
var minLevel = logger.Level()
for _, nl := range nls {
l, err := zap.ParseAtomicLevel(nl.Level)
if err != nil {
continue
}
namedLevels = append(namedLevels, namedLevel{name: nl.Name, level: l})
g, err := glob.Compile(nl.Name)
if err == nil {
namedGlobs[nl.Name] = g
}
if l.Level() < minLevel {
minLevel = l.Level()
}
}
if minLevel < logger.Level() {
// recreate logger if the min level is lower than the current min one
loggerConfig.Level = zap.NewAtomicLevelAt(minLevel)
logger, _ = loggerConfig.Build()
}
for name, nl := range namedLoggers {
level := getLevel(name)
newCore := zap.New(logger.Core()).Named(name).WithOptions(
zap.IncreaseLevel(level),
)
*(nl.Logger) = *newCore
}
for name, nl := range namedSugarLoggers {
level := getLevel(name)
newCore := zap.New(logger.Core()).Named(name).WithOptions(
zap.IncreaseLevel(level),
).Sugar()
*(nl) = *newCore
}
} }
func Default() *zap.Logger { func Default() *zap.Logger {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
return defaultLogger return logger
}
// getLevel returns the level for the given name
// it return the first matching name or glob pattern whatever comes first
func getLevel(name string) zap.AtomicLevel {
for _, nl := range namedLevels {
if nl.name == name {
return nl.level
}
if g, ok := namedGlobs[nl.name]; ok && g.Match(name) {
return nl.level
}
}
return zap.NewAtomicLevelAt(logger.Level())
} }
func NewNamed(name string, fields ...zap.Field) CtxLogger { func NewNamed(name string, fields ...zap.Field) CtxLogger {
mu.Lock() mu.Lock()
defer mu.Unlock() defer mu.Unlock()
l := defaultLogger.Named(name)
if len(fields) > 0 { if l, nameExists := namedLoggers[name]; nameExists {
l = l.With(fields...) return l
} }
ctxL := CtxLogger{l}
loggers[name] = ctxL level := getLevel(name)
l := zap.New(logger.Core()).Named(name).WithOptions(zap.IncreaseLevel(level),
zap.Fields(fields...))
ctxL := CtxLogger{Logger: l, name: name}
namedLoggers[name] = ctxL
return ctxL return ctxL
} }
func NewNamedSugared(name string) *zap.SugaredLogger {
mu.Lock()
defer mu.Unlock()
if l, nameExists := namedSugarLoggers[name]; nameExists {
return l
}
level := getLevel(name)
l := zap.New(logger.Core()).Named(name).Sugar().WithOptions(zap.IncreaseLevel(level))
namedSugarLoggers[name] = l
return l
}

150
app/logger/log_test.go Normal file
View File

@ -0,0 +1,150 @@
package logger
import (
"reflect"
"testing"
"go.uber.org/zap"
)
func Test_getLevel1(t *testing.T) {
SetNamedLevels([]NamedLevel{
{Name: "app", Level: "debug"},
{Name: "app*", Level: "info"},
{Name: "app.sub", Level: "warn"},
{Name: "*", Level: "fatal"},
})
tests := []struct {
name string
want zap.AtomicLevel
}{
{
name: "app",
want: zap.NewAtomicLevelAt(zap.DebugLevel),
},
{
name: "app.aaa",
want: zap.NewAtomicLevelAt(zap.InfoLevel),
},
{
name: "app.sub",
want: zap.NewAtomicLevelAt(zap.InfoLevel),
},
{
name: "random",
want: zap.NewAtomicLevelAt(zap.FatalLevel),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getLevel() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getLevel2(t *testing.T) {
SetNamedLevels([]NamedLevel{
{Name: "*", Level: "ERROR"},
{Name: "app", Level: "info"},
{Name: "app.sub", Level: "warn"},
{Name: "*", Level: "fatal"},
})
tests := []struct {
name string
want zap.AtomicLevel
}{
{
name: "app",
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
},
{
name: "app.aaa",
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
},
{
name: "app.sub",
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
},
{
name: "random",
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getLevel() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getLevel3(t *testing.T) {
SetNamedLevels([]NamedLevel{
{Name: "app", Level: "info"},
{Name: "*.sub", Level: "warn"},
{Name: "*", Level: "fatal"},
})
tests := []struct {
name string
want zap.AtomicLevel
}{
{
name: "app",
want: zap.NewAtomicLevelAt(zap.InfoLevel),
},
{
name: "app.sub",
want: zap.NewAtomicLevelAt(zap.WarnLevel),
},
{
name: "random",
want: zap.NewAtomicLevelAt(zap.FatalLevel),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getLevel() = %v, want %v", got, tt.want)
}
})
}
}
func Test_getLevel4(t *testing.T) {
SetNamedLevels([]NamedLevel{
{Name: "*", Level: "invalid"},
{Name: "app", Level: "info"},
{Name: "b", Level: "invalid"},
})
tests := []struct {
name string
want zap.AtomicLevel
}{
{
name: "app",
want: zap.NewAtomicLevelAt(zap.InfoLevel),
},
{
name: "app.sub",
want: zap.NewAtomicLevelAt(logger.Level()),
},
{
name: "b",
want: zap.NewAtomicLevelAt(logger.Level()),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
t.Errorf("getLevel() = %v, want %v", got, tt.want)
}
})
}
}

136
app/ocache/entry.go Normal file
View File

@ -0,0 +1,136 @@
package ocache
import (
"context"
"sync"
"time"
"go.uber.org/zap"
)
type entryState int
const (
entryStateLoading = iota
entryStateActive
entryStateClosing
entryStateClosed
)
type entry struct {
id string
state entryState
lastUsage time.Time
load chan struct{}
loadErr error
value Object
close chan struct{}
mx sync.Mutex
cancel context.CancelFunc
}
func newEntry(id string, value Object, state entryState) *entry {
return &entry{
id: id,
load: make(chan struct{}),
lastUsage: time.Now(),
state: state,
value: value,
}
}
func (e *entry) isActive() bool {
e.mx.Lock()
defer e.mx.Unlock()
return e.state == entryStateActive
}
func (e *entry) isClosing() bool {
e.mx.Lock()
defer e.mx.Unlock()
return e.state == entryStateClosed || e.state == entryStateClosing
}
func (e *entry) setCancel(cancel context.CancelFunc) {
e.mx.Lock()
defer e.mx.Unlock()
e.cancel = cancel
}
func (e *entry) cancelLoad() {
e.mx.Lock()
defer e.mx.Unlock()
if e.cancel != nil {
e.cancel()
}
}
func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) {
select {
case <-ctx.Done():
log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id))
return nil, ctx.Err()
case <-e.load:
return e.value, e.loadErr
}
}
func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) {
e.mx.Lock()
switch e.state {
case entryStateClosing:
waitCh := e.close
e.mx.Unlock()
select {
case <-ctx.Done():
log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id))
return false, ctx.Err()
case <-waitCh:
return true, nil
}
case entryStateClosed:
e.mx.Unlock()
return true, nil
default:
e.mx.Unlock()
return false, nil
}
}
func (e *entry) setClosing(wait bool) (prevState, curState entryState) {
e.mx.Lock()
prevState = e.state
curState = e.state
if e.state == entryStateClosing {
waitCh := e.close
e.mx.Unlock()
if !wait {
return
}
<-waitCh
e.mx.Lock()
}
if e.state != entryStateClosed {
e.state = entryStateClosing
e.close = make(chan struct{})
}
curState = e.state
e.mx.Unlock()
return
}
func (e *entry) setActive(chClose bool) {
e.mx.Lock()
defer e.mx.Unlock()
if chClose {
close(e.close)
}
e.state = entryStateActive
}
func (e *entry) setClosed() {
e.mx.Lock()
defer e.mx.Unlock()
close(e.close)
e.state = entryStateClosed
}

View File

@ -1,14 +1,22 @@
package ocache package ocache
import "github.com/prometheus/client_golang/prometheus" import (
"github.com/prometheus/client_golang/prometheus"
"strings"
)
func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option { func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option {
if subsystem == "" {
subsystem = "cache"
}
if reg == nil { if reg == nil {
return nil return nil
} }
if subsystem == "" {
subsystem = "cache"
}
nameSplit := strings.Split(namespace, ".")
subSplit := strings.Split(subsystem, ".")
namespace = strings.Join(nameSplit, "_")
subsystem = strings.Join(subSplit, "_")
return func(cache *oCache) { return func(cache *oCache) {
cache.metrics = &metrics{ cache.metrics = &metrics{
hit: prometheus.NewCounter(prometheus.CounterOpts{ hit: prometheus.NewCounter(prometheus.CounterOpts{

View File

@ -0,0 +1,19 @@
package ocache
import (
"context"
"github.com/prometheus/client_golang/prometheus"
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestWithPrometheus_MetricsConvertsDots(t *testing.T) {
opt := WithPrometheus(prometheus.NewRegistry(), "some.name", "some.system")
cache := New(func(ctx context.Context, id string) (value Object, err error) {
return &testObject{}, nil
}, opt).(*oCache)
_, err := cache.Get(context.Background(), "id")
require.NoError(t, err)
require.True(t, strings.Contains(cache.metrics.hit.Desc().String(), "some_name_some_system_hit"))
}

View File

@ -3,10 +3,11 @@ package ocache
import ( import (
"context" "context"
"errors" "errors"
"github.com/anytypeio/any-sync/app/logger"
"go.uber.org/zap"
"sync" "sync"
"time" "time"
"github.com/anyproto/any-sync/app/logger"
"go.uber.org/zap"
) )
var ( var (
@ -44,12 +45,6 @@ var WithGCPeriod = func(gcPeriod time.Duration) Option {
} }
} }
var WithRefCounter = func(enable bool) Option {
return func(cache *oCache) {
cache.refCounter = enable
}
}
func New(loadFunc LoadFunc, opts ...Option) OCache { func New(loadFunc LoadFunc, opts ...Option) OCache {
c := &oCache{ c := &oCache{
data: make(map[string]*entry), data: make(map[string]*entry),
@ -73,33 +68,7 @@ func New(loadFunc LoadFunc, opts ...Option) OCache {
type Object interface { type Object interface {
Close() (err error) Close() (err error)
} TryClose(objectTTL time.Duration) (res bool, err error)
type ObjectLocker interface {
Object
Locked() bool
}
type ObjectLastUsage interface {
LastUsage() time.Time
}
type entry struct {
id string
lastUsage time.Time
refCount uint32
isClosing bool
load chan struct{}
loadErr error
value Object
close chan struct{}
}
func (e *entry) locked() bool {
if locker, ok := e.value.(ObjectLocker); ok {
return locker.Locked()
}
return false
} }
type OCache interface { type OCache interface {
@ -116,12 +85,8 @@ type OCache interface {
// Add adds new object to cache // Add adds new object to cache
// Returns error when object exists // Returns error when object exists
Add(id string, value Object) (err error) Add(id string, value Object) (err error)
// Release decreases the refs counter
Release(id string) bool
// Reset sets refs counter to 0
Reset(id string) bool
// Remove closes and removes object // Remove closes and removes object
Remove(id string) (ok bool, err error) Remove(ctx context.Context, id string) (ok bool, err error)
// ForEach iterates over all loaded objects, breaks when callback returns false // ForEach iterates over all loaded objects, breaks when callback returns false
ForEach(f func(v Object) (isContinue bool)) ForEach(f func(v Object) (isContinue bool))
// GC frees not used and expired objects // GC frees not used and expired objects
@ -144,7 +109,6 @@ type oCache struct {
closeCh chan struct{} closeCh chan struct{}
log *zap.SugaredLogger log *zap.SugaredLogger
metrics *metrics metrics *metrics
refCounter bool
} }
func (c *oCache) Get(ctx context.Context, id string) (value Object, err error) { func (c *oCache) Get(ctx context.Context, id string) (value Object, err error) {
@ -160,68 +124,44 @@ Load:
return nil, ErrClosed return nil, ErrClosed
} }
if e, ok = c.data[id]; !ok { if e, ok = c.data[id]; !ok {
e = newEntry(id, nil, entryStateLoading)
load = true load = true
e = &entry{
id: id,
load: make(chan struct{}),
}
c.data[id] = e c.data[id] = e
} }
closing := e.isClosing e.lastUsage = time.Now()
if !e.isClosing {
e.lastUsage = c.timeNow()
if c.refCounter {
e.refCount++
}
}
c.mu.Unlock() c.mu.Unlock()
if closing { reload, err := e.waitClose(ctx, id)
<-e.close if err != nil {
return nil, err
}
if reload {
goto Load goto Load
} }
if load { if load {
go c.load(ctx, id, e) go c.load(ctx, id, e)
} }
if c.metrics != nil { c.metricsGet(!load)
if load { return e.waitLoad(ctx, id)
c.metrics.miss.Inc()
} else {
c.metrics.hit.Inc()
}
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-e.load:
}
return e.value, e.loadErr
} }
func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error) { func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error) {
c.mu.Lock() c.mu.Lock()
val, ok := c.data[id] val, ok := c.data[id]
if !ok || val.isClosing { if !ok || val.isClosing() {
c.mu.Unlock() c.mu.Unlock()
return nil, ErrNotExists return nil, ErrNotExists
} }
c.mu.Unlock() c.mu.Unlock()
c.metricsGet(true)
if c.metrics != nil { return val.waitLoad(ctx, id)
c.metrics.hit.Inc()
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-val.load:
return val.value, val.loadErr
}
} }
func (c *oCache) load(ctx context.Context, id string, e *entry) { func (c *oCache) load(ctx context.Context, id string, e *entry) {
defer close(e.load) defer close(e.load)
ctx, cancel := context.WithCancel(ctx)
e.setCancel(cancel)
value, err := c.loadFunc(ctx, id) value, err := c.loadFunc(ctx, id)
cancel()
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -230,63 +170,39 @@ func (c *oCache) load(ctx context.Context, id string, e *entry) {
delete(c.data, id) delete(c.data, id)
} else { } else {
e.value = value e.value = value
e.setActive(false)
} }
} }
func (c *oCache) Release(id string) bool { func (c *oCache) Remove(ctx context.Context, id string) (ok bool, err error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return false
}
if e, ok := c.data[id]; ok {
if c.refCounter && e.refCount > 0 {
e.refCount--
return true
}
}
return false
}
func (c *oCache) Reset(id string) bool {
c.mu.Lock()
defer c.mu.Unlock()
if c.closed {
return false
}
if e, ok := c.data[id]; ok {
e.refCount = 0
return true
}
return false
}
func (c *oCache) Remove(id string) (ok bool, err error) {
c.mu.Lock() c.mu.Lock()
if c.closed { if c.closed {
c.mu.Unlock() c.mu.Unlock()
err = ErrClosed err = ErrClosed
return return
} }
var e *entry e, ok := c.data[id]
e, ok = c.data[id] if !ok {
if !ok || e.isClosing {
c.mu.Unlock() c.mu.Unlock()
return return false, ErrNotExists
} }
e.isClosing = true
e.close = make(chan struct{})
c.mu.Unlock() c.mu.Unlock()
return c.remove(ctx, e)
}
<-e.load func (c *oCache) remove(ctx context.Context, e *entry) (ok bool, err error) {
if e.value != nil { if _, err = e.waitLoad(ctx, e.id); err != nil {
err = e.value.Close() return false, err
} }
_, curState := e.setClosing(true)
if curState == entryStateClosing {
ok = true
err = e.value.Close()
c.mu.Lock() c.mu.Lock()
close(e.close) e.setClosed()
delete(c.data, e.id) delete(c.data, e.id)
c.mu.Unlock() c.mu.Unlock()
}
return return
} }
@ -308,13 +224,7 @@ func (c *oCache) Add(id string, value Object) (err error) {
if _, ok := c.data[id]; ok { if _, ok := c.data[id]; ok {
return ErrExists return ErrExists
} }
e := &entry{ e := newEntry(id, value, entryStateActive)
id: id,
lastUsage: time.Now(),
refCount: 0,
load: make(chan struct{}),
value: value,
}
close(e.load) close(e.load)
c.data[id] = e c.data[id] = e
return return
@ -326,7 +236,7 @@ func (c *oCache) ForEach(f func(obj Object) (isContinue bool)) {
for _, v := range c.data { for _, v := range c.data {
select { select {
case <-v.load: case <-v.load:
if v.value != nil && !v.isClosing { if v.value != nil && !v.isClosing() {
objects = append(objects, v.value) objects = append(objects, v.value)
} }
default: default:
@ -362,40 +272,35 @@ func (c *oCache) GC() {
deadline := c.timeNow().Add(-c.ttl) deadline := c.timeNow().Add(-c.ttl)
var toClose []*entry var toClose []*entry
for _, e := range c.data { for _, e := range c.data {
if e.isClosing { if e.isActive() && e.lastUsage.Before(deadline) {
continue
}
lu := e.lastUsage
if lug, ok := e.value.(ObjectLastUsage); ok {
lu = lug.LastUsage()
}
if !e.locked() && e.refCount <= 0 && lu.Before(deadline) {
e.isClosing = true
e.close = make(chan struct{}) e.close = make(chan struct{})
toClose = append(toClose, e) toClose = append(toClose, e)
} }
} }
size := len(c.data) size := len(c.data)
c.mu.Unlock() c.mu.Unlock()
closedNum := 0
for _, e := range toClose { for _, e := range toClose {
<-e.load prevState, _ := e.setClosing(false)
if e.value != nil { if prevState == entryStateClosing || prevState == entryStateClosed {
if err := e.value.Close(); err != nil { continue
}
closed, err := e.value.TryClose(c.ttl)
if err != nil {
c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err) c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err)
} }
} if !closed {
} e.setActive(true)
c.log.Infof("GC: removed %d; cache size: %d", len(toClose), size) continue
if len(toClose) > 0 && c.metrics != nil { } else {
c.metrics.gc.Add(float64(len(toClose))) closedNum++
}
c.mu.Lock() c.mu.Lock()
for _, e := range toClose { e.setClosed()
close(e.close)
delete(c.data, e.id) delete(c.data, e.id)
}
c.mu.Unlock() c.mu.Unlock()
}
}
c.metricsClosed(closedNum, size)
} }
func (c *oCache) Len() int { func (c *oCache) Len() int {
@ -412,25 +317,35 @@ func (c *oCache) Close() (err error) {
} }
c.closed = true c.closed = true
close(c.closeCh) close(c.closeCh)
var toClose, alreadyClosing []*entry var toClose []*entry
for _, e := range c.data { for _, e := range c.data {
if e.isClosing { e.cancelLoad()
alreadyClosing = append(alreadyClosing, e)
} else {
toClose = append(toClose, e) toClose = append(toClose, e)
} }
}
c.mu.Unlock() c.mu.Unlock()
for _, e := range toClose { for _, e := range toClose {
<-e.load if _, err := c.remove(context.Background(), e); err != nil && err != ErrNotExists {
if e.value != nil { c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", err)
if clErr := e.value.Close(); clErr != nil {
c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", clErr)
} }
} }
}
for _, e := range alreadyClosing {
<-e.close
}
return nil return nil
} }
func (c *oCache) metricsGet(hit bool) {
if c.metrics == nil {
return
}
if hit {
c.metrics.hit.Inc()
} else {
c.metrics.miss.Inc()
}
}
func (c *oCache) metricsClosed(closedLen, size int) {
c.log.Infof("GC: removed %d; cache size: %d", closedLen, size)
if c.metrics == nil || closedLen == 0 {
return
}
c.metrics.gc.Add(float64(closedLen))
}

View File

@ -3,6 +3,8 @@ package ocache
import ( import (
"context" "context"
"errors" "errors"
"fmt"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time" "time"
@ -11,26 +13,48 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
var ctx = context.Background()
type testObject struct { type testObject struct {
name string name string
closeErr error closeErr error
closeCh chan struct{} closeCh chan struct{}
tryReturn bool
closeCalled bool
tryCloseCalled bool
} }
func NewTestObject(name string, closeCh chan struct{}) *testObject { func NewTestObject(name string, tryReturn bool, closeCh chan struct{}) *testObject {
return &testObject{ return &testObject{
name: name, name: name,
closeCh: closeCh, closeCh: closeCh,
tryReturn: tryReturn,
} }
} }
func (t *testObject) Close() (err error) { func (t *testObject) Close() (err error) {
if t.closeCalled || (t.tryCloseCalled && t.tryReturn) {
panic("close called twice")
}
t.closeCalled = true
if t.closeCh != nil { if t.closeCh != nil {
<-t.closeCh <-t.closeCh
} }
return t.closeErr return t.closeErr
} }
func (t *testObject) TryClose(objectTTL time.Duration) (res bool, err error) {
if t.closeCalled || (t.tryCloseCalled && t.tryReturn) {
panic("close called twice")
}
t.tryCloseCalled = true
if t.closeCh != nil {
<-t.closeCh
return t.tryReturn, t.closeErr
}
return t.tryReturn, nil
}
func TestOCache_Get(t *testing.T) { func TestOCache_Get(t *testing.T) {
t.Run("successful", func(t *testing.T) { t.Run("successful", func(t *testing.T) {
c := New(func(ctx context.Context, id string) (value Object, err error) { c := New(func(ctx context.Context, id string) (value Object, err error) {
@ -116,42 +140,37 @@ func TestOCache_Get(t *testing.T) {
} }
func TestOCache_GC(t *testing.T) { func TestOCache_GC(t *testing.T) {
t.Run("test without close wait", func(t *testing.T) { t.Run("test gc expired object", func(t *testing.T) {
c := New(func(ctx context.Context, id string) (value Object, err error) { c := New(func(ctx context.Context, id string) (value Object, err error) {
return &testObject{name: id}, nil return NewTestObject(id, true, nil), nil
}, WithTTL(time.Millisecond*10), WithRefCounter(true)) }, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id") val, err := c.Get(context.TODO(), "id")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, val) require.NotNil(t, val)
assert.Equal(t, 1, c.Len()) assert.Equal(t, 1, c.Len())
c.GC() c.GC()
assert.Equal(t, 1, c.Len()) assert.Equal(t, 1, c.Len())
time.Sleep(time.Millisecond * 30) time.Sleep(time.Millisecond * 20)
c.GC()
assert.Equal(t, 1, c.Len())
assert.True(t, c.Release("id"))
c.GC() c.GC()
assert.Equal(t, 0, c.Len()) assert.Equal(t, 0, c.Len())
assert.False(t, c.Release("id"))
}) })
t.Run("test with close wait", func(t *testing.T) { t.Run("test gc tryClose true, close before get", func(t *testing.T) {
closeCh := make(chan struct{}) closeCh := make(chan struct{})
getCh := make(chan struct{}) getCh := make(chan struct{})
c := New(func(ctx context.Context, id string) (value Object, err error) { c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, closeCh), nil return NewTestObject(id, true, closeCh), nil
}, WithTTL(time.Millisecond*10), WithRefCounter(true)) }, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id") val, err := c.Get(context.TODO(), "id")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, val) require.NotNil(t, val)
assert.Equal(t, 1, c.Len()) assert.Equal(t, 1, c.Len())
assert.True(t, c.Release("id"))
// making ttl pass // making ttl pass
time.Sleep(time.Millisecond * 40) time.Sleep(time.Millisecond * 20)
// first gc will be run after 20 secs, so calling it manually // first gc will be run after 20 secs, so calling it manually
go c.GC() go c.GC()
// waiting until all objects are marked as closing // waiting until all objects are marked as closing
time.Sleep(time.Millisecond * 40) time.Sleep(time.Millisecond * 20)
var events []string var events []string
go func() { go func() {
_, err := c.Get(context.TODO(), "id") _, err := c.Get(context.TODO(), "id")
@ -160,33 +179,114 @@ func TestOCache_GC(t *testing.T) {
events = append(events, "get") events = append(events, "get")
close(getCh) close(getCh)
}() }()
events = append(events, "close")
// sleeping to make sure that Get is called // sleeping to make sure that Get is called
time.Sleep(time.Millisecond * 40) time.Sleep(time.Millisecond * 20)
events = append(events, "close")
close(closeCh) close(closeCh)
<-getCh <-getCh
require.Equal(t, []string{"close", "get"}, events) require.Equal(t, []string{"close", "get"}, events)
}) })
t.Run("test gc tryClose false, many parallel get", func(t *testing.T) {
timesCalled := &atomic.Int32{}
obj := NewTestObject("id", false, nil)
c := New(func(ctx context.Context, id string) (value Object, err error) {
timesCalled.Add(1)
return obj, nil
}, WithTTL(0))
val, err := c.Get(context.TODO(), "id")
require.NoError(t, err)
require.NotNil(t, val)
assert.Equal(t, 1, c.Len())
begin := make(chan struct{})
wg := sync.WaitGroup{}
once := sync.Once{}
wg.Add(1)
go func() {
<-begin
c.GC()
wg.Done()
}()
for i := 0; i < 50; i++ {
wg.Add(1)
go func(i int) {
once.Do(func() {
close(begin)
})
if i%2 != 0 {
time.Sleep(time.Millisecond)
}
_, err := c.Get(context.TODO(), "id")
require.NoError(t, err)
wg.Done()
}(i)
}
require.NoError(t, err)
wg.Wait()
require.Equal(t, timesCalled.Load(), int32(1))
require.True(t, obj.tryCloseCalled)
})
t.Run("test gc tryClose different, many objects", func(t *testing.T) {
tryCloseIds := make(map[string]bool)
called := make(map[string]int)
max := 1000
getId := func(i int) string {
return fmt.Sprintf("id%d", i)
}
for i := 0; i < max; i++ {
if i%2 == 1 {
tryCloseIds[getId(i)] = true
} else {
tryCloseIds[getId(i)] = false
}
}
c := New(func(ctx context.Context, id string) (value Object, err error) {
called[id] = called[id] + 1
return NewTestObject(id, tryCloseIds[id], nil), nil
}, WithTTL(time.Millisecond*10))
for i := 0; i < max; i++ {
val, err := c.Get(context.TODO(), getId(i))
require.NoError(t, err)
require.NotNil(t, val)
}
assert.Equal(t, max, c.Len())
time.Sleep(time.Millisecond * 20)
c.GC()
for i := 0; i < max; i++ {
val, err := c.Get(context.TODO(), getId(i))
require.NoError(t, err)
require.NotNil(t, val)
}
for i := 0; i < max; i++ {
val, err := c.Get(context.TODO(), getId(i))
require.NoError(t, err)
require.NotNil(t, val)
require.Equal(t, called[getId(i)], i%2+1)
}
})
} }
func Test_OCache_Remove(t *testing.T) { func Test_OCache_Remove(t *testing.T) {
t.Run("remove simple", func(t *testing.T) {
closeCh := make(chan struct{}) closeCh := make(chan struct{})
getCh := make(chan struct{}) getCh := make(chan struct{})
c := New(func(ctx context.Context, id string) (value Object, err error) { c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, closeCh), nil return NewTestObject(id, false, closeCh), nil
}, WithTTL(time.Millisecond*10)) }, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id") val, err := c.Get(context.TODO(), "id")
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, val) require.NotNil(t, val)
assert.Equal(t, 1, c.Len()) assert.Equal(t, 1, c.Len())
// removing the object, so we will wait on closing // removing the object, so we will wait on closing
go func() { go func() {
_, err := c.Remove("id") _, err := c.Remove(ctx, "id")
require.NoError(t, err) require.NoError(t, err)
}() }()
time.Sleep(time.Millisecond * 40) time.Sleep(time.Millisecond * 20)
var events []string var events []string
go func() { go func() {
@ -196,11 +296,215 @@ func Test_OCache_Remove(t *testing.T) {
events = append(events, "get") events = append(events, "get")
close(getCh) close(getCh)
}() }()
events = append(events, "close")
// sleeping to make sure that Get is called // sleeping to make sure that Get is called
time.Sleep(time.Millisecond * 40) time.Sleep(time.Millisecond * 20)
events = append(events, "close")
close(closeCh) close(closeCh)
<-getCh <-getCh
require.Equal(t, []string{"close", "get"}, events) require.Equal(t, []string{"close", "get"}, events)
})
t.Run("test remove while gc, tryClose false", func(t *testing.T) {
closeCh := make(chan struct{})
removeCh := make(chan struct{})
c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, false, closeCh), nil
}, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id")
require.NoError(t, err)
require.NotNil(t, val)
assert.Equal(t, 1, c.Len())
time.Sleep(time.Millisecond * 20)
go c.GC()
time.Sleep(time.Millisecond * 20)
var events []string
go func() {
ok, err := c.Remove(ctx, "id")
require.NoError(t, err)
require.True(t, ok)
events = append(events, "remove")
close(removeCh)
}()
time.Sleep(time.Millisecond * 20)
events = append(events, "close")
close(closeCh)
<-removeCh
require.Equal(t, []string{"close", "remove"}, events)
})
t.Run("test remove while gc, tryClose true", func(t *testing.T) {
closeCh := make(chan struct{})
removeCh := make(chan struct{})
c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, true, closeCh), nil
}, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id")
require.NoError(t, err)
require.NotNil(t, val)
assert.Equal(t, 1, c.Len())
time.Sleep(time.Millisecond * 20)
go c.GC()
time.Sleep(time.Millisecond * 20)
var events []string
go func() {
ok, err := c.Remove(ctx, "id")
require.NoError(t, err)
require.False(t, ok)
events = append(events, "remove")
close(removeCh)
}()
time.Sleep(time.Millisecond * 20)
events = append(events, "close")
close(closeCh)
<-removeCh
require.Equal(t, []string{"close", "remove"}, events)
})
t.Run("test gc while remove, tryClose true", func(t *testing.T) {
closeCh := make(chan struct{})
removeCh := make(chan struct{})
c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, true, closeCh), nil
}, WithTTL(time.Millisecond*10))
val, err := c.Get(context.TODO(), "id")
require.NoError(t, err)
require.NotNil(t, val)
assert.Equal(t, 1, c.Len())
go func() {
ok, err := c.Remove(ctx, "id")
require.NoError(t, err)
require.True(t, ok)
close(removeCh)
}()
time.Sleep(20 * time.Millisecond)
c.GC()
close(closeCh)
<-removeCh
})
}
func TestOCacheCancelWhenRemove(t *testing.T) {
c := New(func(ctx context.Context, id string) (value Object, err error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
}
}, WithTTL(time.Millisecond*10))
stopLoad := make(chan struct{})
var err error
go func() {
_, err = c.Get(context.TODO(), "id")
stopLoad <- struct{}{}
}()
time.Sleep(time.Millisecond * 10)
c.Close()
<-stopLoad
require.Equal(t, context.Canceled, err)
}
func TestOCacheFuzzy(t *testing.T) {
t.Run("test many objects gc, get and remove simultaneously, close after", func(t *testing.T) {
tryCloseIds := make(map[string]bool)
max := 2000
getId := func(i int) string {
return fmt.Sprintf("id%d", i)
}
for i := 0; i < max; i++ {
if i%2 == 1 {
tryCloseIds[getId(i)] = true
} else {
tryCloseIds[getId(i)] = false
}
}
c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, tryCloseIds[id], nil), nil
}, WithTTL(time.Nanosecond))
stopGC := make(chan struct{})
wg := sync.WaitGroup{}
go func() {
for {
select {
case <-stopGC:
return
default:
c.GC()
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
for i := 0; i < max; i++ {
val, err := c.Get(context.TODO(), getId(i))
require.NoError(t, err)
require.NotNil(t, val)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < 10; j++ {
for i := 0; i < max; i++ {
c.Remove(ctx, getId(i))
}
}
}()
wg.Wait()
close(stopGC)
err := c.Close()
require.NoError(t, err)
require.Equal(t, 0, c.Len())
})
t.Run("test many objects gc, get, remove and close simultaneously", func(t *testing.T) {
tryCloseIds := make(map[string]bool)
max := 2000
getId := func(i int) string {
return fmt.Sprintf("id%d", i)
}
for i := 0; i < max; i++ {
if i%2 == 1 {
tryCloseIds[getId(i)] = true
} else {
tryCloseIds[getId(i)] = false
}
}
c := New(func(ctx context.Context, id string) (value Object, err error) {
return NewTestObject(id, tryCloseIds[id], nil), nil
}, WithTTL(time.Nanosecond))
go func() {
for {
c.GC()
}
}()
go func() {
for j := 0; j < 10; j++ {
for i := 0; i < max; i++ {
val, err := c.Get(context.TODO(), getId(i))
if err == ErrClosed {
return
}
require.NoError(t, err)
require.NotNil(t, val)
}
}
}()
go func() {
for j := 0; j < 10; j++ {
for i := 0; i < max; i++ {
c.Remove(ctx, getId(i))
}
}
}()
time.Sleep(time.Millisecond)
err := c.Close()
require.NoError(t, err)
require.Equal(t, 0, c.Len())
})
} }

View File

@ -2,8 +2,8 @@ package fileblockstore
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonfile/fileproto/fileprotoerr" "github.com/anyproto/any-sync/commonfile/fileproto/fileprotoerr"
blocks "github.com/ipfs/go-block-format" blocks "github.com/ipfs/go-block-format"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
) )
@ -21,6 +21,7 @@ type ctxKey uint
const ( const (
ctxKeySpaceId ctxKey = iota ctxKeySpaceId ctxKey = iota
ctxKeyFileId
) )
type BlockStore interface { type BlockStore interface {
@ -48,3 +49,12 @@ func CtxGetSpaceId(ctx context.Context) (spaceId string) {
spaceId, _ = ctx.Value(ctxKeySpaceId).(string) spaceId, _ = ctx.Value(ctxKeySpaceId).(string)
return return
} }
func CtxWithFileId(ctx context.Context, spaceId string) context.Context {
return context.WithValue(ctx, ctxKeyFileId, spaceId)
}
func CtxGetFileId(ctx context.Context) (spaceId string) {
spaceId, _ = ctx.Value(ctxKeyFileId).(string)
return
}

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: commonfile/fileproto/protos/file.proto // source: commonfile/fileproto/protos/file.proto
package fileproto package fileproto
@ -40,10 +40,14 @@ func (drpcEncoding_File_commonfile_fileproto_protos_file_proto) JSONUnmarshal(bu
type DRPCFileClient interface { type DRPCFileClient interface {
DRPCConn() drpc.Conn DRPCConn() drpc.Conn
GetBlocks(ctx context.Context) (DRPCFile_GetBlocksClient, error) BlockGet(ctx context.Context, in *BlockGetRequest) (*BlockGetResponse, error)
PushBlock(ctx context.Context, in *PushBlockRequest) (*PushBlockResponse, error) BlockPush(ctx context.Context, in *BlockPushRequest) (*BlockPushResponse, error)
DeleteBlocks(ctx context.Context, in *DeleteBlocksRequest) (*DeleteBlocksResponse, error) BlocksCheck(ctx context.Context, in *BlocksCheckRequest) (*BlocksCheckResponse, error)
BlocksBind(ctx context.Context, in *BlocksBindRequest) (*BlocksBindResponse, error)
FilesDelete(ctx context.Context, in *FilesDeleteRequest) (*FilesDeleteResponse, error)
FilesInfo(ctx context.Context, in *FilesInfoRequest) (*FilesInfoResponse, error)
Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error) Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error)
SpaceInfo(ctx context.Context, in *SpaceInfoRequest) (*SpaceInfoResponse, error)
} }
type drpcFileClient struct { type drpcFileClient struct {
@ -56,53 +60,54 @@ func NewDRPCFileClient(cc drpc.Conn) DRPCFileClient {
func (c *drpcFileClient) DRPCConn() drpc.Conn { return c.cc } func (c *drpcFileClient) DRPCConn() drpc.Conn { return c.cc }
func (c *drpcFileClient) GetBlocks(ctx context.Context) (DRPCFile_GetBlocksClient, error) { func (c *drpcFileClient) BlockGet(ctx context.Context, in *BlockGetRequest) (*BlockGetResponse, error) {
stream, err := c.cc.NewStream(ctx, "/anyFile.File/GetBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}) out := new(BlockGetResponse)
if err != nil { err := c.cc.Invoke(ctx, "/filesync.File/BlockGet", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
return nil, err
}
x := &drpcFile_GetBlocksClient{stream}
return x, nil
}
type DRPCFile_GetBlocksClient interface {
drpc.Stream
Send(*GetBlockRequest) error
Recv() (*GetBlockResponse, error)
}
type drpcFile_GetBlocksClient struct {
drpc.Stream
}
func (x *drpcFile_GetBlocksClient) Send(m *GetBlockRequest) error {
return x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
}
func (x *drpcFile_GetBlocksClient) Recv() (*GetBlockResponse, error) {
m := new(GetBlockResponse)
if err := x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return nil, err
}
return m, nil
}
func (x *drpcFile_GetBlocksClient) RecvMsg(m *GetBlockResponse) error {
return x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
}
func (c *drpcFileClient) PushBlock(ctx context.Context, in *PushBlockRequest) (*PushBlockResponse, error) {
out := new(PushBlockResponse)
err := c.cc.Invoke(ctx, "/anyFile.File/PushBlock", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return out, nil return out, nil
} }
func (c *drpcFileClient) DeleteBlocks(ctx context.Context, in *DeleteBlocksRequest) (*DeleteBlocksResponse, error) { func (c *drpcFileClient) BlockPush(ctx context.Context, in *BlockPushRequest) (*BlockPushResponse, error) {
out := new(DeleteBlocksResponse) out := new(BlockPushResponse)
err := c.cc.Invoke(ctx, "/anyFile.File/DeleteBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out) err := c.cc.Invoke(ctx, "/filesync.File/BlockPush", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcFileClient) BlocksCheck(ctx context.Context, in *BlocksCheckRequest) (*BlocksCheckResponse, error) {
out := new(BlocksCheckResponse)
err := c.cc.Invoke(ctx, "/filesync.File/BlocksCheck", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcFileClient) BlocksBind(ctx context.Context, in *BlocksBindRequest) (*BlocksBindResponse, error) {
out := new(BlocksBindResponse)
err := c.cc.Invoke(ctx, "/filesync.File/BlocksBind", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcFileClient) FilesDelete(ctx context.Context, in *FilesDeleteRequest) (*FilesDeleteResponse, error) {
out := new(FilesDeleteResponse)
err := c.cc.Invoke(ctx, "/filesync.File/FilesDelete", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcFileClient) FilesInfo(ctx context.Context, in *FilesInfoRequest) (*FilesInfoResponse, error) {
out := new(FilesInfoResponse)
err := c.cc.Invoke(ctx, "/filesync.File/FilesInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -111,7 +116,16 @@ func (c *drpcFileClient) DeleteBlocks(ctx context.Context, in *DeleteBlocksReque
func (c *drpcFileClient) Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error) { func (c *drpcFileClient) Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error) {
out := new(CheckResponse) out := new(CheckResponse)
err := c.cc.Invoke(ctx, "/anyFile.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out) err := c.cc.Invoke(ctx, "/filesync.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
func (c *drpcFileClient) SpaceInfo(ctx context.Context, in *SpaceInfoRequest) (*SpaceInfoResponse, error) {
out := new(SpaceInfoResponse)
err := c.cc.Invoke(ctx, "/filesync.File/SpaceInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -119,23 +133,39 @@ func (c *drpcFileClient) Check(ctx context.Context, in *CheckRequest) (*CheckRes
} }
type DRPCFileServer interface { type DRPCFileServer interface {
GetBlocks(DRPCFile_GetBlocksStream) error BlockGet(context.Context, *BlockGetRequest) (*BlockGetResponse, error)
PushBlock(context.Context, *PushBlockRequest) (*PushBlockResponse, error) BlockPush(context.Context, *BlockPushRequest) (*BlockPushResponse, error)
DeleteBlocks(context.Context, *DeleteBlocksRequest) (*DeleteBlocksResponse, error) BlocksCheck(context.Context, *BlocksCheckRequest) (*BlocksCheckResponse, error)
BlocksBind(context.Context, *BlocksBindRequest) (*BlocksBindResponse, error)
FilesDelete(context.Context, *FilesDeleteRequest) (*FilesDeleteResponse, error)
FilesInfo(context.Context, *FilesInfoRequest) (*FilesInfoResponse, error)
Check(context.Context, *CheckRequest) (*CheckResponse, error) Check(context.Context, *CheckRequest) (*CheckResponse, error)
SpaceInfo(context.Context, *SpaceInfoRequest) (*SpaceInfoResponse, error)
} }
type DRPCFileUnimplementedServer struct{} type DRPCFileUnimplementedServer struct{}
func (s *DRPCFileUnimplementedServer) GetBlocks(DRPCFile_GetBlocksStream) error { func (s *DRPCFileUnimplementedServer) BlockGet(context.Context, *BlockGetRequest) (*BlockGetResponse, error) {
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCFileUnimplementedServer) PushBlock(context.Context, *PushBlockRequest) (*PushBlockResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
} }
func (s *DRPCFileUnimplementedServer) DeleteBlocks(context.Context, *DeleteBlocksRequest) (*DeleteBlocksResponse, error) { func (s *DRPCFileUnimplementedServer) BlockPush(context.Context, *BlockPushRequest) (*BlockPushResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCFileUnimplementedServer) BlocksCheck(context.Context, *BlocksCheckRequest) (*BlocksCheckResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCFileUnimplementedServer) BlocksBind(context.Context, *BlocksBindRequest) (*BlocksBindResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCFileUnimplementedServer) FilesDelete(context.Context, *FilesDeleteRequest) (*FilesDeleteResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
func (s *DRPCFileUnimplementedServer) FilesInfo(context.Context, *FilesInfoRequest) (*FilesInfoResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
} }
@ -143,40 +173,72 @@ func (s *DRPCFileUnimplementedServer) Check(context.Context, *CheckRequest) (*Ch
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
} }
func (s *DRPCFileUnimplementedServer) SpaceInfo(context.Context, *SpaceInfoRequest) (*SpaceInfoResponse, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCFileDescription struct{} type DRPCFileDescription struct{}
func (DRPCFileDescription) NumMethods() int { return 4 } func (DRPCFileDescription) NumMethods() int { return 8 }
func (DRPCFileDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { func (DRPCFileDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n { switch n {
case 0: case 0:
return "/anyFile.File/GetBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, return "/filesync.File/BlockGet", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return nil, srv.(DRPCFileServer). return srv.(DRPCFileServer).
GetBlocks( BlockGet(
&drpcFile_GetBlocksStream{in1.(drpc.Stream)}, ctx,
in1.(*BlockGetRequest),
) )
}, DRPCFileServer.GetBlocks, true }, DRPCFileServer.BlockGet, true
case 1: case 1:
return "/anyFile.File/PushBlock", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, return "/filesync.File/BlockPush", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer). return srv.(DRPCFileServer).
PushBlock( BlockPush(
ctx, ctx,
in1.(*PushBlockRequest), in1.(*BlockPushRequest),
) )
}, DRPCFileServer.PushBlock, true }, DRPCFileServer.BlockPush, true
case 2: case 2:
return "/anyFile.File/DeleteBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, return "/filesync.File/BlocksCheck", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer). return srv.(DRPCFileServer).
DeleteBlocks( BlocksCheck(
ctx, ctx,
in1.(*DeleteBlocksRequest), in1.(*BlocksCheckRequest),
) )
}, DRPCFileServer.DeleteBlocks, true }, DRPCFileServer.BlocksCheck, true
case 3: case 3:
return "/anyFile.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, return "/filesync.File/BlocksBind", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer).
BlocksBind(
ctx,
in1.(*BlocksBindRequest),
)
}, DRPCFileServer.BlocksBind, true
case 4:
return "/filesync.File/FilesDelete", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer).
FilesDelete(
ctx,
in1.(*FilesDeleteRequest),
)
}, DRPCFileServer.FilesDelete, true
case 5:
return "/filesync.File/FilesInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer).
FilesInfo(
ctx,
in1.(*FilesInfoRequest),
)
}, DRPCFileServer.FilesInfo, true
case 6:
return "/filesync.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer). return srv.(DRPCFileServer).
Check( Check(
@ -184,6 +246,15 @@ func (DRPCFileDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
in1.(*CheckRequest), in1.(*CheckRequest),
) )
}, DRPCFileServer.Check, true }, DRPCFileServer.Check, true
case 7:
return "/filesync.File/SpaceInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCFileServer).
SpaceInfo(
ctx,
in1.(*SpaceInfoRequest),
)
}, DRPCFileServer.SpaceInfo, true
default: default:
return "", nil, nil, nil, false return "", nil, nil, nil, false
} }
@ -193,58 +264,96 @@ func DRPCRegisterFile(mux drpc.Mux, impl DRPCFileServer) error {
return mux.Register(impl, DRPCFileDescription{}) return mux.Register(impl, DRPCFileDescription{})
} }
type DRPCFile_GetBlocksStream interface { type DRPCFile_BlockGetStream interface {
drpc.Stream drpc.Stream
Send(*GetBlockResponse) error SendAndClose(*BlockGetResponse) error
Recv() (*GetBlockRequest, error)
} }
type drpcFile_GetBlocksStream struct { type drpcFile_BlockGetStream struct {
drpc.Stream drpc.Stream
} }
func (x *drpcFile_GetBlocksStream) Send(m *GetBlockResponse) error { func (x *drpcFile_BlockGetStream) SendAndClose(m *BlockGetResponse) error {
return x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
}
func (x *drpcFile_GetBlocksStream) Recv() (*GetBlockRequest, error) {
m := new(GetBlockRequest)
if err := x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return nil, err
}
return m, nil
}
func (x *drpcFile_GetBlocksStream) RecvMsg(m *GetBlockRequest) error {
return x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
}
type DRPCFile_PushBlockStream interface {
drpc.Stream
SendAndClose(*PushBlockResponse) error
}
type drpcFile_PushBlockStream struct {
drpc.Stream
}
func (x *drpcFile_PushBlockStream) SendAndClose(m *PushBlockResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil { if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err return err
} }
return x.CloseSend() return x.CloseSend()
} }
type DRPCFile_DeleteBlocksStream interface { type DRPCFile_BlockPushStream interface {
drpc.Stream drpc.Stream
SendAndClose(*DeleteBlocksResponse) error SendAndClose(*BlockPushResponse) error
} }
type drpcFile_DeleteBlocksStream struct { type drpcFile_BlockPushStream struct {
drpc.Stream drpc.Stream
} }
func (x *drpcFile_DeleteBlocksStream) SendAndClose(m *DeleteBlocksResponse) error { func (x *drpcFile_BlockPushStream) SendAndClose(m *BlockPushResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCFile_BlocksCheckStream interface {
drpc.Stream
SendAndClose(*BlocksCheckResponse) error
}
type drpcFile_BlocksCheckStream struct {
drpc.Stream
}
func (x *drpcFile_BlocksCheckStream) SendAndClose(m *BlocksCheckResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCFile_BlocksBindStream interface {
drpc.Stream
SendAndClose(*BlocksBindResponse) error
}
type drpcFile_BlocksBindStream struct {
drpc.Stream
}
func (x *drpcFile_BlocksBindStream) SendAndClose(m *BlocksBindResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCFile_FilesDeleteStream interface {
drpc.Stream
SendAndClose(*FilesDeleteResponse) error
}
type drpcFile_FilesDeleteStream struct {
drpc.Stream
}
func (x *drpcFile_FilesDeleteStream) SendAndClose(m *FilesDeleteResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err
}
return x.CloseSend()
}
type DRPCFile_FilesInfoStream interface {
drpc.Stream
SendAndClose(*FilesInfoResponse) error
}
type drpcFile_FilesInfoStream struct {
drpc.Stream
}
func (x *drpcFile_FilesInfoStream) SendAndClose(m *FilesInfoResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil { if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err return err
} }
@ -266,3 +375,19 @@ func (x *drpcFile_CheckStream) SendAndClose(m *CheckResponse) error {
} }
return x.CloseSend() return x.CloseSend()
} }
type DRPCFile_SpaceInfoStream interface {
drpc.Stream
SendAndClose(*SpaceInfoResponse) error
}
type drpcFile_SpaceInfoStream struct {
drpc.Stream
}
func (x *drpcFile_SpaceInfoStream) SendAndClose(m *SpaceInfoResponse) error {
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
return err
}
return x.CloseSend()
}

View File

@ -2,12 +2,16 @@ package fileprotoerr
import ( import (
"fmt" "fmt"
"github.com/anytypeio/any-sync/commonfile/fileproto" "github.com/anyproto/any-sync/commonfile/fileproto"
"github.com/anytypeio/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/net/rpc/rpcerr"
) )
var ( var (
errGroup = rpcerr.ErrGroup(fileproto.ErrCodes_ErrorOffset) errGroup = rpcerr.ErrGroup(fileproto.ErrCodes_ErrorOffset)
ErrUnexpected = errGroup.Register(fmt.Errorf("unexpected fileproto error"), uint64(fileproto.ErrCodes_Unexpected)) ErrUnexpected = errGroup.Register(fmt.Errorf("unexpected fileproto error"), uint64(fileproto.ErrCodes_Unexpected))
ErrCIDNotFound = errGroup.Register(fmt.Errorf("CID not found"), uint64(fileproto.ErrCodes_CIDNotFound)) ErrCIDNotFound = errGroup.Register(fmt.Errorf("CID not found"), uint64(fileproto.ErrCodes_CIDNotFound))
ErrForbidden = errGroup.Register(fmt.Errorf("forbidden"), uint64(fileproto.ErrCodes_Forbidden))
ErrSpaceLimitExceeded = errGroup.Register(fmt.Errorf("space limit exceeded"), uint64(fileproto.ErrCodes_SpaceLimitExceeded))
ErrQuerySizeExceeded = errGroup.Register(fmt.Errorf("query size exceeded"), uint64(fileproto.ErrCodes_QuerySizeExceeded))
ErrWrongHash = errGroup.Register(fmt.Errorf("wrong block hash"), uint64(fileproto.ErrCodes_WrongHash))
) )

View File

@ -1,60 +1,123 @@
syntax = "proto3"; syntax = "proto3";
package anyFile; package filesync;
option go_package = "commonfile/fileproto"; option go_package = "commonfile/fileproto";
enum ErrCodes { enum ErrCodes {
Unexpected = 0; Unexpected = 0;
CIDNotFound = 1; CIDNotFound = 1;
Forbidden = 2;
SpaceLimitExceeded = 3;
QuerySizeExceeded = 4;
WrongHash = 5;
ErrorOffset = 200; ErrorOffset = 200;
} }
service File { service File {
// GetBlocks streams ipfs blocks from server to client // BlockGet gets one block from a server
rpc GetBlocks(stream GetBlockRequest) returns (stream GetBlockResponse); rpc BlockGet(BlockGetRequest) returns (BlockGetResponse);
// PushBlock pushes one block to server // BlockPush pushes one block to a server
rpc PushBlock(PushBlockRequest) returns (PushBlockResponse); rpc BlockPush(BlockPushRequest) returns (BlockPushResponse);
// DeleteBlock deletes block from space // BlocksCheck checks is CIDs exists
rpc DeleteBlocks(DeleteBlocksRequest) returns (DeleteBlocksResponse); rpc BlocksCheck(BlocksCheckRequest) returns (BlocksCheckResponse);
// Ping checks the connection // BlocksBind binds CIDs to space
rpc BlocksBind(BlocksBindRequest) returns (BlocksBindResponse);
// FilesDelete deletes files by id
rpc FilesDelete(FilesDeleteRequest) returns (FilesDeleteResponse);
// FilesInfo return info by given files id
rpc FilesInfo(FilesInfoRequest) returns (FilesInfoResponse);
// Check checks the connection and credentials
rpc Check(CheckRequest) returns (CheckResponse); rpc Check(CheckRequest) returns (CheckResponse);
// SpaceInfo returns usage, limit, etc about space
rpc SpaceInfo(SpaceInfoRequest) returns (SpaceInfoResponse);
} }
message GetBlockRequest { message BlockGetRequest {
string spaceId = 1; string spaceId = 1;
bytes cid = 2; bytes cid = 2;
} }
message GetBlockResponse { message BlockGetResponse {
bytes cid = 1; bytes cid = 1;
bytes data = 2; bytes data = 2;
CIDError code = 3;
} }
message PushBlockRequest { message BlockPushRequest {
string spaceId = 1; string spaceId = 1;
bytes cid = 2; string fileId = 2;
bytes data = 3; bytes cid = 3;
bytes data = 4;
} }
message PushBlockResponse {} message BlockPushResponse {}
message DeleteBlocksRequest {
message BlocksCheckRequest {
string spaceId = 1; string spaceId = 1;
repeated bytes cid = 2; repeated bytes cids = 2;
} }
message DeleteBlocksResponse {} message BlocksCheckResponse {
repeated BlockAvailability blocksAvailability = 1;
}
message BlockAvailability {
bytes cid = 1;
AvailabilityStatus status = 2;
}
enum AvailabilityStatus {
NotExists = 0;
Exists = 1;
ExistsInSpace = 2;
}
message BlocksBindRequest {
string spaceId = 1;
string fileId = 2;
repeated bytes cids = 3;
}
message BlocksBindResponse {}
message FilesDeleteRequest {
string spaceId = 1;
repeated string fileIds = 2;
}
message FilesDeleteResponse {}
message FilesInfoRequest {
string spaceId = 1;
repeated string fileIds = 2;
}
message FilesInfoResponse {
repeated FileInfo filesInfo = 1;
}
message FileInfo {
string fileId = 1;
uint64 usageBytes = 2;
uint32 cidsCount = 3;
}
message CheckRequest {} message CheckRequest {}
message CheckResponse { message CheckResponse {
repeated string spaceIds = 1; repeated string spaceIds = 1;
bool allowWrite = 2;
} }
message SpaceInfoRequest {
enum CIDError { string spaceId = 1;
CIDErrorOk = 0;
CIDErrorNotFound = 1;
CIDErrorUnexpected = 2;
} }
message SpaceInfoResponse {
uint64 limitBytes = 1;
uint64 usageBytes = 2;
uint64 cidsCount = 3;
uint64 filesCount = 4;
}

View File

@ -2,7 +2,7 @@ package fileservice
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/commonfile/fileblockstore" "github.com/anyproto/any-sync/commonfile/fileblockstore"
blocks "github.com/ipfs/go-block-format" blocks "github.com/ipfs/go-block-format"
"github.com/ipfs/go-blockservice" "github.com/ipfs/go-blockservice"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"

View File

@ -3,9 +3,9 @@ package fileservice
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anytypeio/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonfile/fileblockstore" "github.com/anyproto/any-sync/commonfile/fileblockstore"
"github.com/ipfs/go-cid" "github.com/ipfs/go-cid"
chunker "github.com/ipfs/go-ipfs-chunker" chunker "github.com/ipfs/go-ipfs-chunker"
ipld "github.com/ipfs/go-ipld-format" ipld "github.com/ipfs/go-ipld-format"
@ -22,6 +22,10 @@ const CName = "common.commonfile.fileservice"
var log = logger.NewNamed(CName) var log = logger.NewNamed(CName)
const (
ChunkSize = 1 << 20
)
func New() FileService { func New() FileService {
return &fileService{} return &fileService{}
} }
@ -74,7 +78,7 @@ func (fs *fileService) AddFile(ctx context.Context, r io.Reader) (ipld.Node, err
Maxlinks: helpers.DefaultLinksPerBlock, Maxlinks: helpers.DefaultLinksPerBlock,
CidBuilder: &fs.prefix, CidBuilder: &fs.prefix,
} }
dbh, err := dbp.New(chunker.DefaultSplitter(r)) dbh, err := dbp.New(chunker.NewSizeSplitter(r, ChunkSize))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -1,53 +0,0 @@
package commonspace
import (
"context"
"github.com/anytypeio/any-sync/commonspace/object/syncobjectgetter"
"github.com/anytypeio/any-sync/commonspace/object/tree/objecttree"
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
)
type commonGetter struct {
treegetter.TreeGetter
spaceId string
reservedObjects []syncobjectgetter.SyncObject
}
func newCommonGetter(spaceId string, getter treegetter.TreeGetter) *commonGetter {
return &commonGetter{
TreeGetter: getter,
spaceId: spaceId,
}
}
func (c *commonGetter) AddObject(object syncobjectgetter.SyncObject) {
c.reservedObjects = append(c.reservedObjects, object)
}
func (c *commonGetter) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) {
if obj := c.getReservedObject(treeId); obj != nil {
return obj.(objecttree.ObjectTree), nil
}
return c.TreeGetter.GetTree(ctx, spaceId, treeId)
}
func (c *commonGetter) getReservedObject(id string) syncobjectgetter.SyncObject {
for _, obj := range c.reservedObjects {
if obj != nil && obj.Id() == id {
return obj
}
}
return nil
}
func (c *commonGetter) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
if obj := c.getReservedObject(objectId); obj != nil {
return obj, nil
}
t, err := c.TreeGetter.GetTree(ctx, c.spaceId, objectId)
if err != nil {
return
}
obj = t.(syncobjectgetter.SyncObject)
return
}

View File

@ -1,8 +1,8 @@
package commonspace package commonspace
import ( import (
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anytypeio/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
) )
type commonStorage struct { type commonStorage struct {

View File

@ -1,10 +0,0 @@
package commonspace
type ConfigGetter interface {
GetSpace() Config
}
type Config struct {
GCTTL int `yaml:"gcTTL"`
SyncPeriod int `yaml:"syncPeriod"`
}

View File

@ -0,0 +1,11 @@
package config
type ConfigGetter interface {
GetSpace() Config
}
type Config struct {
GCTTL int `yaml:"gcTTL"`
SyncPeriod int `yaml:"syncPeriod"`
KeepTreeDataInMemory bool `yaml:"keepTreeDataInMemory"`
}

View File

@ -0,0 +1,34 @@
//go:generate mockgen -destination mock_credentialprovider/mock_credentialprovider.go github.com/anyproto/any-sync/commonspace/credentialprovider CredentialProvider
package credentialprovider
import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
)
const CName = "common.commonspace.credentialprovider"
func NewNoOp() CredentialProvider {
return &noOpProvider{}
}
type CredentialProvider interface {
app.Component
GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error)
}
type noOpProvider struct {
}
func (n noOpProvider) Init(a *app.App) (err error) {
return nil
}
func (n noOpProvider) Name() (name string) {
return CName
}
func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
return nil, nil
}

View File

@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/credentialprovider (interfaces: CredentialProvider)
// Package mock_credentialprovider is a generated GoMock package.
package mock_credentialprovider
import (
context "context"
reflect "reflect"
app "github.com/anyproto/any-sync/app"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "go.uber.org/mock/gomock"
)
// MockCredentialProvider is a mock of CredentialProvider interface.
type MockCredentialProvider struct {
ctrl *gomock.Controller
recorder *MockCredentialProviderMockRecorder
}
// MockCredentialProviderMockRecorder is the mock recorder for MockCredentialProvider.
type MockCredentialProviderMockRecorder struct {
mock *MockCredentialProvider
}
// NewMockCredentialProvider creates a new mock instance.
func NewMockCredentialProvider(ctrl *gomock.Controller) *MockCredentialProvider {
mock := &MockCredentialProvider{ctrl: ctrl}
mock.recorder = &MockCredentialProviderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCredentialProvider) EXPECT() *MockCredentialProviderMockRecorder {
return m.recorder
}
// GetCredential mocks base method.
func (m *MockCredentialProvider) GetCredential(arg0 context.Context, arg1 *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetCredential", arg0, arg1)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetCredential indicates an expected call of GetCredential.
func (mr *MockCredentialProviderMockRecorder) GetCredential(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1)
}
// Init mocks base method.
func (m *MockCredentialProvider) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockCredentialProviderMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockCredentialProvider)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockCredentialProvider) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockCredentialProviderMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCredentialProvider)(nil).Name))
}

View File

@ -0,0 +1,278 @@
package commonspace
import (
"context"
"fmt"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/stretchr/testify/require"
"math/rand"
"testing"
"time"
)
func addIncorrectSnapshot(settingsObject settings.SettingsObject, acc *accountdata.AccountKeys, partialIds map[string]struct{}, newId string) (err error) {
factory := settingsstate.NewChangeFactory()
bytes, err := factory.CreateObjectDeleteChange(newId, &settingsstate.State{DeletedIds: partialIds}, true)
if err != nil {
return
}
ch, err := settingsObject.PrepareChange(objecttree.SignableChangeContent{
Data: bytes,
Key: acc.SignKey,
IsSnapshot: true,
IsEncrypted: false,
Timestamp: time.Now().Unix(),
})
if err != nil {
return
}
res, err := settingsObject.AddRawChanges(context.Background(), objecttree.RawChangesPayload{
NewHeads: []string{ch.Id},
RawChanges: []*treechangeproto.RawTreeChangeWithId{ch},
})
if err != nil {
return
}
if res.Mode != objecttree.Rebuild {
return fmt.Errorf("incorrect mode: %d", res.Mode)
}
return
}
func TestSpaceDeleteIds(t *testing.T) {
fx := newFixture(t)
acc := fx.account.Account()
rk := crypto.NewAES()
ctx := context.Background()
totalObjs := 1500
// creating space
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
SigningKey: acc.SignKey,
SpaceType: "type",
ReadKey: rk.Bytes(),
ReplicationKey: 10,
MasterKey: acc.PeerKey,
})
require.NoError(t, err)
require.NotNil(t, sp)
// initializing space
spc, err := fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err)
require.NotNil(t, spc)
// adding space to tree manager
fx.treeManager.space = spc
err = spc.Init(ctx)
require.NoError(t, err)
close(fx.treeManager.waitLoad)
var ids []string
for i := 0; i < totalObjs; i++ {
// creating a tree
bytes := make([]byte, 32)
rand.Read(bytes)
doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey,
ChangeType: "some",
SpaceId: spc.Id(),
IsEncrypted: false,
Seed: bytes,
Timestamp: time.Now().Unix(),
})
require.NoError(t, err)
tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err)
ids = append(ids, tr.Id())
tr.Close()
}
// deleting trees
for _, id := range ids {
err = spc.DeleteTree(ctx, id)
require.NoError(t, err)
}
time.Sleep(3 * time.Second)
spc.Close()
require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
}
func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string {
bytes := make([]byte, 32)
rand.Read(bytes)
doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey,
ChangeType: "some",
SpaceId: spc.Id(),
IsEncrypted: false,
Seed: bytes,
Timestamp: time.Now().Unix(),
})
require.NoError(t, err)
tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err)
tr.Close()
return tr.Id()
}
func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
fx := newFixture(t)
acc := fx.account.Account()
rk := crypto.NewAES()
ctx := context.Background()
totalObjs := 1500
partialObjs := 300
// creating space
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
SigningKey: acc.SignKey,
SpaceType: "type",
ReadKey: rk.Bytes(),
ReplicationKey: 10,
MasterKey: acc.PeerKey,
})
require.NoError(t, err)
require.NotNil(t, sp)
// initializing space
spc, err := fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err)
require.NotNil(t, spc)
// adding space to tree manager
fx.treeManager.space = spc
err = spc.Init(ctx)
close(fx.treeManager.waitLoad)
require.NoError(t, err)
settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string
for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc)
ids = append(ids, id)
}
// copying storage, so we will have all the trees locally
inmemory := spc.Storage().(*commonStorage).SpaceStorage.(*spacestorage.InMemorySpaceStorage)
storageCopy := inmemory.CopyStorage()
treesCopy := inmemory.AllTrees()
// deleting trees
for _, id := range ids {
err = spc.DeleteTree(ctx, id)
require.NoError(t, err)
}
mapIds := map[string]struct{}{}
for _, id := range ids[:partialObjs] {
mapIds[id] = struct{}{}
}
// adding snapshot that breaks the state
err = addIncorrectSnapshot(settingsObject, acc, mapIds, ids[partialObjs])
require.NoError(t, err)
// copying the contents of the settings tree
treesCopy[settingsObject.Id()] = settingsObject.Storage()
storageCopy.SetTrees(treesCopy)
spc.Close()
time.Sleep(100 * time.Millisecond)
// now we replace the storage, so the trees are back, but the settings object says that they are deleted
fx.storageProvider.(*spacestorage.InMemorySpaceStorageProvider).SetStorage(storageCopy)
spc, err = fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err)
require.NotNil(t, spc)
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.space = spc
fx.treeManager.deletedIds = nil
err = spc.Init(ctx)
require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted
time.Sleep(3 * time.Second)
require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
// checking that new snapshot will contain all the changes
settingsObject = spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
settings.DoSnapshot = func(treeLen int) bool {
return true
}
id := createTree(t, ctx, spc, acc)
err = spc.DeleteTree(ctx, id)
require.NoError(t, err)
delIds := settingsObject.Root().Model.(*spacesyncproto.SettingsData).Snapshot.DeletedIds
require.Equal(t, totalObjs+1, len(delIds))
}
func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
fx := newFixture(t)
acc := fx.account.Account()
rk := crypto.NewAES()
ctx := context.Background()
totalObjs := 1500
// creating space
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
SigningKey: acc.SignKey,
SpaceType: "type",
ReadKey: rk.Bytes(),
ReplicationKey: 10,
MasterKey: acc.PeerKey,
})
require.NoError(t, err)
require.NotNil(t, sp)
// initializing space
spc, err := fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err)
require.NotNil(t, spc)
// adding space to tree manager
fx.treeManager.space = spc
err = spc.Init(ctx)
require.NoError(t, err)
close(fx.treeManager.waitLoad)
settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string
for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc)
ids = append(ids, id)
}
// copying storage, so we will have the same contents, except for empty trees
inmemory := spc.Storage().(*commonStorage).SpaceStorage.(*spacestorage.InMemorySpaceStorage)
storageCopy := inmemory.CopyStorage()
// deleting trees, this will prepare the document to have all the deletion changes
for _, id := range ids {
err = spc.DeleteTree(ctx, id)
require.NoError(t, err)
}
treesMap := map[string]treestorage.TreeStorage{}
// copying the contents of the settings tree
treesMap[settingsObject.Id()] = settingsObject.Storage()
storageCopy.SetTrees(treesMap)
spc.Close()
time.Sleep(100 * time.Millisecond)
// now we replace the storage, so the trees are back, but the settings object says that they are deleted
fx.storageProvider.(*spacestorage.InMemorySpaceStorageProvider).SetStorage(storageCopy)
spc, err = fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err)
require.NotNil(t, spc)
fx.treeManager.space = spc
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.deletedIds = nil
fx.treeManager.markedIds = nil
err = spc.Init(ctx)
require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted
time.Sleep(3 * time.Second)
require.Equal(t, len(ids), len(fx.treeManager.markedIds))
require.Zero(t, len(fx.treeManager.deletedIds))
}

View File

@ -1,59 +1,73 @@
//go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anytypeio/any-sync/commonspace/settings/deletionstate DeletionState //go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anyproto/any-sync/commonspace/deletionstate ObjectDeletionState
package deletionstate package deletionstate
import ( import (
"github.com/anytypeio/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/app"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"go.uber.org/zap"
"sync" "sync"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.deletionstate"
type StateUpdateObserver func(ids []string) type StateUpdateObserver func(ids []string)
type DeletionState interface { type ObjectDeletionState interface {
app.Component
AddObserver(observer StateUpdateObserver) AddObserver(observer StateUpdateObserver)
Add(ids []string) (err error) Add(ids map[string]struct{})
GetQueued() (ids []string) GetQueued() (ids []string)
Delete(id string) (err error) Delete(id string) (err error)
Exists(id string) bool Exists(id string) bool
FilterJoin(ids ...[]string) (filtered []string) Filter(ids []string) (filtered []string)
CreateDeleteChange(id string, isSnapshot bool) (res []byte, err error)
} }
type deletionState struct { type objectDeletionState struct {
sync.RWMutex sync.RWMutex
log logger.CtxLogger
queued map[string]struct{} queued map[string]struct{}
deleted map[string]struct{} deleted map[string]struct{}
stateUpdateObservers []StateUpdateObserver stateUpdateObservers []StateUpdateObserver
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
} }
func NewDeletionState(storage spacestorage.SpaceStorage) DeletionState { func (st *objectDeletionState) Init(a *app.App) (err error) {
return &deletionState{ st.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
return nil
}
func (st *objectDeletionState) Name() (name string) {
return CName
}
func New() ObjectDeletionState {
return &objectDeletionState{
log: log,
queued: map[string]struct{}{}, queued: map[string]struct{}{},
deleted: map[string]struct{}{}, deleted: map[string]struct{}{},
storage: storage,
} }
} }
func (st *deletionState) AddObserver(observer StateUpdateObserver) { func (st *objectDeletionState) AddObserver(observer StateUpdateObserver) {
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()
st.stateUpdateObservers = append(st.stateUpdateObservers, observer) st.stateUpdateObservers = append(st.stateUpdateObservers, observer)
} }
func (st *deletionState) Add(ids []string) (err error) { func (st *objectDeletionState) Add(ids map[string]struct{}) {
var added []string
st.Lock() st.Lock()
defer func() { defer func() {
st.Unlock() st.Unlock()
if err != nil {
return
}
for _, ob := range st.stateUpdateObservers { for _, ob := range st.stateUpdateObservers {
ob(ids) ob(added)
} }
}() }()
for _, id := range ids { for id := range ids {
if _, exists := st.deleted[id]; exists { if _, exists := st.deleted[id]; exists {
continue continue
} }
@ -62,9 +76,10 @@ func (st *deletionState) Add(ids []string) (err error) {
} }
var status string var status string
status, err = st.storage.TreeDeletedStatus(id) status, err := st.storage.TreeDeletedStatus(id)
if err != nil { if err != nil {
return st.log.Warn("failed to get deleted status", zap.String("treeId", id), zap.Error(err))
continue
} }
switch status { switch status {
@ -73,17 +88,18 @@ func (st *deletionState) Add(ids []string) (err error) {
case spacestorage.TreeDeletedStatusDeleted: case spacestorage.TreeDeletedStatusDeleted:
st.deleted[id] = struct{}{} st.deleted[id] = struct{}{}
default: default:
st.queued[id] = struct{}{} err := st.storage.SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued)
err = st.storage.SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued)
if err != nil { if err != nil {
return st.log.Warn("failed to set deleted status", zap.String("treeId", id), zap.Error(err))
continue
} }
st.queued[id] = struct{}{}
} }
added = append(added, id)
} }
return
} }
func (st *deletionState) GetQueued() (ids []string) { func (st *objectDeletionState) GetQueued() (ids []string) {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
ids = make([]string, 0, len(st.queued)) ids = make([]string, 0, len(st.queued))
@ -93,7 +109,7 @@ func (st *deletionState) GetQueued() (ids []string) {
return return
} }
func (st *deletionState) Delete(id string) (err error) { func (st *objectDeletionState) Delete(id string) (err error) {
st.Lock() st.Lock()
defer st.Unlock() defer st.Unlock()
delete(st.queued, id) delete(st.queued, id)
@ -105,44 +121,24 @@ func (st *deletionState) Delete(id string) (err error) {
return return
} }
func (st *deletionState) Exists(id string) bool { func (st *objectDeletionState) Exists(id string) bool {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
return st.exists(id) return st.exists(id)
} }
func (st *deletionState) FilterJoin(ids ...[]string) (filtered []string) { func (st *objectDeletionState) Filter(ids []string) (filtered []string) {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
filter := func(ids []string) {
for _, id := range ids { for _, id := range ids {
if !st.exists(id) { if !st.exists(id) {
filtered = append(filtered, id) filtered = append(filtered, id)
} }
} }
}
for _, arr := range ids {
filter(arr)
}
return return
} }
func (st *deletionState) CreateDeleteChange(id string, isSnapshot bool) (res []byte, err error) { func (st *objectDeletionState) exists(id string) bool {
content := &spacesyncproto.SpaceSettingsContent_ObjectDelete{
ObjectDelete: &spacesyncproto.ObjectDelete{Id: id},
}
change := &spacesyncproto.SettingsData{
Content: []*spacesyncproto.SpaceSettingsContent{
{Value: content},
},
Snapshot: nil,
}
// TODO: add snapshot logic
res, err = change.Marshal()
return
}
func (st *deletionState) exists(id string) bool {
if _, exists := st.deleted[id]; exists { if _, exists := st.deleted[id]; exists {
return true return true
} }

View File

@ -1,23 +1,25 @@
package deletionstate package deletionstate
import ( import (
"github.com/anytypeio/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sort"
"testing" "testing"
) )
type fixture struct { type fixture struct {
ctrl *gomock.Controller ctrl *gomock.Controller
delState *deletionState delState *objectDeletionState
spaceStorage *mock_spacestorage.MockSpaceStorage spaceStorage *mock_spacestorage.MockSpaceStorage
} }
func newFixture(t *testing.T) *fixture { func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl) spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := NewDeletionState(spaceStorage).(*deletionState) delState := New().(*objectDeletionState)
delState.storage = spaceStorage
return &fixture{ return &fixture{
ctrl: ctrl, ctrl: ctrl,
delState: delState, delState: delState,
@ -36,8 +38,7 @@ func TestDeletionState_Add(t *testing.T) {
id := "newId" id := "newId"
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil) fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil)
fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil) fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil)
err := fx.delState.Add([]string{id}) fx.delState.Add(map[string]struct{}{id: {}})
require.NoError(t, err)
require.Contains(t, fx.delState.queued, id) require.Contains(t, fx.delState.queued, id)
}) })
@ -46,8 +47,7 @@ func TestDeletionState_Add(t *testing.T) {
defer fx.stop() defer fx.stop()
id := "newId" id := "newId"
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusQueued, nil) fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusQueued, nil)
err := fx.delState.Add([]string{id}) fx.delState.Add(map[string]struct{}{id: {}})
require.NoError(t, err)
require.Contains(t, fx.delState.queued, id) require.Contains(t, fx.delState.queued, id)
}) })
@ -56,8 +56,7 @@ func TestDeletionState_Add(t *testing.T) {
defer fx.stop() defer fx.stop()
id := "newId" id := "newId"
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusDeleted, nil) fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusDeleted, nil)
err := fx.delState.Add([]string{id}) fx.delState.Add(map[string]struct{}{id: {}})
require.NoError(t, err)
require.Contains(t, fx.delState.deleted, id) require.Contains(t, fx.delState.deleted, id)
}) })
} }
@ -70,6 +69,7 @@ func TestDeletionState_GetQueued(t *testing.T) {
fx.delState.queued["id2"] = struct{}{} fx.delState.queued["id2"] = struct{}{}
queued := fx.delState.GetQueued() queued := fx.delState.GetQueued()
sort.Strings(queued)
require.Equal(t, []string{"id1", "id2"}, queued) require.Equal(t, []string{"id1", "id2"}, queued)
} }
@ -80,8 +80,8 @@ func TestDeletionState_FilterJoin(t *testing.T) {
fx.delState.queued["id1"] = struct{}{} fx.delState.queued["id1"] = struct{}{}
fx.delState.queued["id2"] = struct{}{} fx.delState.queued["id2"] = struct{}{}
filtered := fx.delState.FilterJoin([]string{"id1"}, []string{"id3", "id2"}, []string{"id4"}) filtered := fx.delState.Filter([]string{"id3", "id2"})
require.Equal(t, []string{"id3", "id4"}, filtered) require.Equal(t, []string{"id3"}, filtered)
} }
func TestDeletionState_AddObserver(t *testing.T) { func TestDeletionState_AddObserver(t *testing.T) {
@ -96,8 +96,7 @@ func TestDeletionState_AddObserver(t *testing.T) {
id := "newId" id := "newId"
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil) fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil)
fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil) fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil)
err := fx.delState.Add([]string{id}) fx.delState.Add(map[string]struct{}{id: {}})
require.NoError(t, err)
require.Contains(t, fx.delState.queued, id) require.Contains(t, fx.delState.queued, id)
require.Equal(t, []string{id}, queued) require.Equal(t, []string{id}, queued)
} }

View File

@ -0,0 +1,144 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/deletionstate (interfaces: ObjectDeletionState)
// Package mock_deletionstate is a generated GoMock package.
package mock_deletionstate
import (
reflect "reflect"
app "github.com/anyproto/any-sync/app"
deletionstate "github.com/anyproto/any-sync/commonspace/deletionstate"
gomock "go.uber.org/mock/gomock"
)
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
type MockObjectDeletionState struct {
ctrl *gomock.Controller
recorder *MockObjectDeletionStateMockRecorder
}
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
type MockObjectDeletionStateMockRecorder struct {
mock *MockObjectDeletionState
}
// NewMockObjectDeletionState creates a new mock instance.
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
mock := &MockObjectDeletionState{ctrl: ctrl}
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0)
}
// Add indicates an expected call of Add.
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
}
// AddObserver mocks base method.
func (m *MockObjectDeletionState) AddObserver(arg0 deletionstate.StateUpdateObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddObserver", arg0)
}
// AddObserver indicates an expected call of AddObserver.
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
}
// Delete mocks base method.
func (m *MockObjectDeletionState) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
}
// Exists mocks base method.
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exists", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Exists indicates an expected call of Exists.
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
}
// Filter mocks base method.
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Filter", arg0)
ret0, _ := ret[0].([]string)
return ret0
}
// Filter indicates an expected call of Filter.
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
}
// GetQueued mocks base method.
func (m *MockObjectDeletionState) GetQueued() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetQueued")
ret0, _ := ret[0].([]string)
return ret0
}
// GetQueued indicates an expected call of GetQueued.
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
}
// Init mocks base method.
func (m *MockObjectDeletionState) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockObjectDeletionStateMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectDeletionState)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockObjectDeletionState) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockObjectDeletionStateMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectDeletionState)(nil).Name))
}

View File

@ -3,46 +3,45 @@ package headsync
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/app/ldiff"
"github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/tree/synctree"
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
"github.com/anytypeio/any-sync/commonspace/peermanager"
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate"
"github.com/anytypeio/any-sync/commonspace/spacestorage"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
"github.com/anytypeio/any-sync/commonspace/syncstatus"
"github.com/anytypeio/any-sync/net/peer"
"github.com/anytypeio/any-sync/net/rpc/rpcerr"
"go.uber.org/zap"
"time" "time"
"github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap"
) )
type DiffSyncer interface { type DiffSyncer interface {
Sync(ctx context.Context) error Sync(ctx context.Context) error
RemoveObjects(ids []string) RemoveObjects(ids []string)
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
Init(deletionState deletionstate.DeletionState) Init()
Close() error
} }
func newDiffSyncer( func newDiffSyncer(hs *headSync) DiffSyncer {
spaceId string,
diff ldiff.Diff,
peerManager peermanager.PeerManager,
cache treegetter.TreeGetter,
storage spacestorage.SpaceStorage,
clientFactory spacesyncproto.ClientFactory,
syncStatus syncstatus.StatusUpdater,
log logger.CtxLogger) DiffSyncer {
return &diffSyncer{ return &diffSyncer{
diff: diff, diff: hs.diff,
spaceId: spaceId, spaceId: hs.spaceId,
cache: cache, treeManager: hs.treeManager,
storage: storage, storage: hs.storage,
peerManager: peerManager, peerManager: hs.peerManager,
clientFactory: clientFactory, clientFactory: spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient),
credentialProvider: hs.credentialProvider,
log: log, log: log,
syncStatus: syncStatus, syncStatus: hs.syncStatus,
deletionState: hs.deletionState,
syncAcl: hs.syncAcl,
} }
} }
@ -50,17 +49,20 @@ type diffSyncer struct {
spaceId string spaceId string
diff ldiff.Diff diff ldiff.Diff
peerManager peermanager.PeerManager peerManager peermanager.PeerManager
cache treegetter.TreeGetter treeManager treemanager.TreeManager
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
clientFactory spacesyncproto.ClientFactory clientFactory spacesyncproto.ClientFactory
log logger.CtxLogger log logger.CtxLogger
deletionState deletionstate.DeletionState deletionState deletionstate.ObjectDeletionState
credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusUpdater syncStatus syncstatus.StatusUpdater
treeSyncer treemanager.TreeSyncer
syncAcl syncacl.SyncAcl
} }
func (d *diffSyncer) Init(deletionState deletionstate.DeletionState) { func (d *diffSyncer) Init() {
d.deletionState = deletionState
d.deletionState.AddObserver(d.RemoveObjects) d.deletionState.AddObserver(d.RemoveObjects)
d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager)
} }
func (d *diffSyncer) RemoveObjects(ids []string) { func (d *diffSyncer) RemoveObjects(ids []string) {
@ -86,6 +88,7 @@ func (d *diffSyncer) UpdateHeads(id string, heads []string) {
} }
func (d *diffSyncer) Sync(ctx context.Context) error { func (d *diffSyncer) Sync(ctx context.Context) error {
// TODO: split diffsyncer into components
st := time.Now() st := time.Now()
// diffing with responsible peers according to configuration // diffing with responsible peers according to configuration
peers, err := d.peerManager.GetResponsiblePeers(ctx) peers, err := d.peerManager.GetResponsiblePeers(ctx)
@ -108,66 +111,68 @@ func (d *diffSyncer) Sync(ctx context.Context) error {
func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) { func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) {
ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id())) ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id()))
conn, err := p.AcquireDrpcConn(ctx)
if err != nil {
return
}
defer p.ReleaseDrpcConn(conn)
var ( var (
cl = d.clientFactory.Client(p) cl = d.clientFactory.Client(conn)
rdiff = NewRemoteDiff(d.spaceId, cl) rdiff = NewRemoteDiff(d.spaceId, cl)
stateCounter = d.syncStatus.StateCounter() stateCounter = d.syncStatus.StateCounter()
syncAclId = d.syncAcl.Id()
) )
newIds, changedIds, removedIds, err := d.diff.Diff(ctx, rdiff) newIds, changedIds, removedIds, err := d.diff.Diff(ctx, rdiff)
err = rpcerr.Unwrap(err) err = rpcerr.Unwrap(err)
if err != nil && err != spacesyncproto.ErrSpaceMissing { if err != nil && err != spacesyncproto.ErrSpaceMissing {
if err == spacesyncproto.ErrSpaceIsDeleted {
d.log.Debug("got space deleted while syncing")
d.treeSyncer.SyncAll(ctx, p.Id(), []string{d.storage.SpaceSettingsId()}, nil)
}
d.syncStatus.SetNodesOnline(p.Id(), false) d.syncStatus.SetNodesOnline(p.Id(), false)
return fmt.Errorf("diff error: %v", err) return fmt.Errorf("diff error: %v", err)
} }
d.syncStatus.SetNodesOnline(p.Id(), true) d.syncStatus.SetNodesOnline(p.Id(), true)
if err == spacesyncproto.ErrSpaceMissing { if err == spacesyncproto.ErrSpaceMissing {
return d.sendPushSpaceRequest(ctx, cl) return d.sendPushSpaceRequest(ctx, p.Id(), cl)
} }
totalLen := len(newIds) + len(changedIds) + len(removedIds) totalLen := len(newIds) + len(changedIds) + len(removedIds)
// not syncing ids which were removed through settings document // not syncing ids which were removed through settings document
filteredIds := d.deletionState.FilterJoin(newIds, changedIds, removedIds) missingIds := d.deletionState.Filter(newIds)
existingIds := append(d.deletionState.Filter(removedIds), d.deletionState.Filter(changedIds)...)
d.syncStatus.RemoveAllExcept(p.Id(), existingIds, stateCounter)
d.syncStatus.RemoveAllExcept(p.Id(), filteredIds, stateCounter) prevExistingLen := len(existingIds)
existingIds = slice.DiscardFromSlice(existingIds, func(s string) bool {
return s == syncAclId
})
// if we removed acl head from the list
if len(existingIds) < prevExistingLen {
if syncErr := d.syncAcl.SyncWithPeer(ctx, p.Id()); syncErr != nil {
log.Warn("failed to send acl sync message to peer", zap.String("aclId", syncAclId))
}
}
d.pingTreesInCache(ctx, filteredIds) // treeSyncer should not get acl id, that's why we filter existing ids before
err = d.treeSyncer.SyncAll(ctx, p.Id(), existingIds, missingIds)
d.log.Info("sync done:", zap.Int("newIds", len(newIds)), if err != nil {
return err
}
d.log.Info("sync done:",
zap.Int("newIds", len(newIds)),
zap.Int("changedIds", len(changedIds)), zap.Int("changedIds", len(changedIds)),
zap.Int("removedIds", len(removedIds)), zap.Int("removedIds", len(removedIds)),
zap.Int("already deleted ids", totalLen-len(filteredIds)), zap.Int("already deleted ids", totalLen-prevExistingLen-len(missingIds)),
zap.String("peerId", p.Id()), zap.String("peerId", p.Id()),
) )
return return
} }
func (d *diffSyncer) pingTreesInCache(ctx context.Context, trees []string) { func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, peerId string, cl spacesyncproto.DRPCSpaceSyncClient) (err error) {
for _, tId := range trees {
tree, err := d.cache.GetTree(ctx, d.spaceId, tId)
if err != nil {
d.log.InfoCtx(ctx, "can't load tree", zap.Error(err))
continue
}
syncTree, ok := tree.(synctree.SyncTree)
if !ok {
d.log.InfoCtx(ctx, "not a sync tree", zap.String("objectId", tId))
continue
}
// the idea why we call it directly is that if we try to get it from cache
// it may be already there (i.e. loaded)
// and build func will not be called, thus we won't sync the tree
// therefore we just do it manually
if err = syncTree.Ping(ctx); err != nil {
d.log.WarnCtx(ctx, "synctree.Ping error", zap.Error(err), zap.String("treeId", tId))
} else {
d.log.DebugCtx(ctx, "success tree ping", zap.String("treeId", tId))
}
}
}
func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto.DRPCSpaceSyncClient) (err error) {
aclStorage, err := d.storage.AclStorage() aclStorage, err := d.storage.AclStorage()
if err != nil { if err != nil {
return return
@ -192,6 +197,10 @@ func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto
return return
} }
cred, err := d.credentialProvider.GetCredential(ctx, header)
if err != nil {
return
}
spacePayload := &spacesyncproto.SpacePayload{ spacePayload := &spacesyncproto.SpacePayload{
SpaceHeader: header, SpaceHeader: header,
AclPayload: root.Payload, AclPayload: root.Payload,
@ -201,6 +210,31 @@ func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto
} }
_, err = cl.SpacePush(ctx, &spacesyncproto.SpacePushRequest{ _, err = cl.SpacePush(ctx, &spacesyncproto.SpacePushRequest{
Payload: spacePayload, Payload: spacePayload,
Credential: cred,
}) })
if err != nil {
return
}
if e := d.subscribe(ctx, peerId); e != nil {
d.log.WarnCtx(ctx, "error subscribing for space", zap.Error(e))
}
return return
} }
func (d *diffSyncer) subscribe(ctx context.Context, peerId string) (err error) {
var msg = &spacesyncproto.SpaceSubscription{
SpaceIds: []string{d.spaceId},
Action: spacesyncproto.SpaceSubscriptionAction_Subscribe,
}
payload, err := msg.Marshal()
if err != nil {
return
}
return d.peerManager.SendPeer(ctx, peerId, &spacesyncproto.ObjectSyncMessage{
Payload: payload,
})
}
func (d *diffSyncer) Close() error {
return d.treeSyncer.Close()
}

View File

@ -1,170 +1,200 @@
package headsync package headsync
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/app/ldiff"
"github.com/anytypeio/any-sync/app/ldiff/mock_ldiff"
"github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
mock_treestorage "github.com/anytypeio/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anytypeio/any-sync/commonspace/object/treegetter/mock_treegetter"
"github.com/anytypeio/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate/mock_deletionstate"
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anytypeio/any-sync/commonspace/syncstatus"
"github.com/anytypeio/any-sync/net/peer"
"github.com/golang/mock/gomock"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/require"
"storj.io/drpc"
"testing" "testing"
"time" "time"
"github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"storj.io/drpc"
) )
type pushSpaceRequestMatcher struct { type pushSpaceRequestMatcher struct {
spaceId string spaceId string
aclRootId string aclRootId string
settingsId string settingsId string
credential []byte
spaceHeader *spacesyncproto.RawSpaceHeaderWithId spaceHeader *spacesyncproto.RawSpaceHeaderWithId
} }
func newPushSpaceRequestMatcher(
spaceId string,
aclRootId string,
settingsId string,
credential []byte,
spaceHeader *spacesyncproto.RawSpaceHeaderWithId) *pushSpaceRequestMatcher {
return &pushSpaceRequestMatcher{
spaceId: spaceId,
aclRootId: aclRootId,
settingsId: settingsId,
credential: credential,
spaceHeader: spaceHeader,
}
}
func (p pushSpaceRequestMatcher) Matches(x interface{}) bool { func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
res, ok := x.(*spacesyncproto.SpacePushRequest) res, ok := x.(*spacesyncproto.SpacePushRequest)
if !ok { if !ok {
return false return false
} }
return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
} }
func (p pushSpaceRequestMatcher) String() string { func (p pushSpaceRequestMatcher) String() string {
return "" return ""
} }
type mockPeer struct{} type mockPeer struct {
}
func (m mockPeer) Id() string { func (m mockPeer) Id() string {
return "mockId" return "peerId"
} }
func (m mockPeer) LastUsage() time.Time { func (m mockPeer) Context() context.Context {
return time.Time{} return context.Background()
} }
func (m mockPeer) Secure() sec.SecureConn { func (m mockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
return nil
}
func (m mockPeer) UpdateLastUsage() {
}
func (m mockPeer) Close() error {
return nil
}
func (m mockPeer) Closed() <-chan struct{} {
return make(chan struct{})
}
func (m mockPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
return nil
}
func (m mockPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
return nil, nil return nil, nil
} }
func newPushSpaceRequestMatcher( func (m mockPeer) ReleaseDrpcConn(conn drpc.Conn) {
spaceId string, return
aclRootId string,
settingsId string,
spaceHeader *spacesyncproto.RawSpaceHeaderWithId) *pushSpaceRequestMatcher {
return &pushSpaceRequestMatcher{
spaceId: spaceId,
aclRootId: aclRootId,
settingsId: settingsId,
spaceHeader: spaceHeader,
}
} }
func TestDiffSyncer_Sync(t *testing.T) { func (m mockPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
// setup return nil
ctx := context.Background() }
ctrl := gomock.NewController(t)
defer ctrl.Finish()
diffMock := mock_ldiff.NewMockDiff(ctrl) func (m mockPeer) IsClosed() bool {
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl) return false
cacheMock := mock_treegetter.NewMockTreeGetter(ctrl) }
stMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
factory := spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { return false, err
return clientMock }
func (m mockPeer) Close() (err error) {
return nil
}
func (fx *headSyncFixture) initDiffSyncer(t *testing.T) {
fx.init(t)
fx.diffSyncer = newDiffSyncer(fx.headSync).(*diffSyncer)
fx.diffSyncer.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
return fx.clientMock
}) })
delState := mock_deletionstate.NewMockDeletionState(ctrl) fx.deletionStateMock.EXPECT().AddObserver(gomock.Any())
spaceId := "spaceId" fx.treeManagerMock.EXPECT().NewTreeSyncer(fx.spaceState.SpaceId, fx.treeManagerMock).Return(fx.treeSyncerMock)
aclRootId := "aclRootId" fx.diffSyncer.Init()
l := logger.NewNamed(spaceId) }
diffSyncer := newDiffSyncer(spaceId, diffMock, peerManagerMock, cacheMock, stMock, factory, syncstatus.NewNoOpSyncStatus(), l)
delState.EXPECT().AddObserver(gomock.Any()) func TestDiffSyncer(t *testing.T) {
diffSyncer.Init(delState) ctx := context.Background()
t.Run("diff syncer sync", func(t *testing.T) { t.Run("diff syncer sync", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{}
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mPeer}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return([]string{"new"}, []string{"changed"}, nil, nil) Return([]string{"new"}, []string{"changed"}, nil, nil)
delState.EXPECT().FilterJoin(gomock.Any()).Return([]string{"new", "changed"}) fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
for _, arg := range []string{"new", "changed"} { fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1)
cacheMock.EXPECT(). fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
GetTree(gomock.Any(), spaceId, arg). fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
Return(nil, nil) require.NoError(t, fx.diffSyncer.Sync(ctx))
} })
require.NoError(t, diffSyncer.Sync(ctx))
t.Run("diff syncer sync, acl changed", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{}
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil)
fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return([]string{"new"}, []string{"changed"}, nil, nil)
fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed", "aclId"}).Times(1)
fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
fx.aclMock.EXPECT().SyncWithPeer(gomock.Any(), mPeer.Id()).Return(nil)
require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync conf error", func(t *testing.T) { t.Run("diff syncer sync conf error", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
ctx := context.Background()
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return(nil, fmt.Errorf("some error")) Return(nil, fmt.Errorf("some error"))
require.Error(t, diffSyncer.Sync(ctx)) require.Error(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("deletion state remove objects", func(t *testing.T) { t.Run("deletion state remove objects", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
deletedId := "id" deletedId := "id"
delState.EXPECT().Exists(deletedId).Return(true) fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.deletionStateMock.EXPECT().Exists(deletedId).Return(true)
// this should not result in any mock being called // this should not result in any mock being called
diffSyncer.UpdateHeads(deletedId, []string{"someHead"}) fx.diffSyncer.UpdateHeads(deletedId, []string{"someHead"})
}) })
t.Run("update heads updates diff", func(t *testing.T) { t.Run("update heads updates diff", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
newId := "newId" newId := "newId"
newHeads := []string{"h1", "h2"} newHeads := []string{"h1", "h2"}
hash := "hash" hash := "hash"
diffMock.EXPECT().Set(ldiff.Element{ fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.diffMock.EXPECT().Set(ldiff.Element{
Id: newId, Id: newId,
Head: concatStrings(newHeads), Head: concatStrings(newHeads),
}) })
diffMock.EXPECT().Hash().Return(hash) fx.diffMock.EXPECT().Hash().Return(hash)
delState.EXPECT().Exists(newId).Return(false) fx.deletionStateMock.EXPECT().Exists(newId).Return(false)
stMock.EXPECT().WriteSpaceHash(hash) fx.storageMock.EXPECT().WriteSpaceHash(hash)
diffSyncer.UpdateHeads(newId, newHeads) fx.diffSyncer.UpdateHeads(newId, newHeads)
}) })
t.Run("diff syncer sync space missing", func(t *testing.T) { t.Run("diff syncer sync space missing", func(t *testing.T) {
aclStorageMock := mock_liststorage.NewMockListStorage(ctrl) fx := newHeadSyncFixture(t)
settingsStorage := mock_treestorage.NewMockTreeStorage(ctrl) fx.initDiffSyncer(t)
defer fx.stop()
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
aclStorageMock := mock_liststorage.NewMockListStorage(fx.ctrl)
settingsStorage := mock_treestorage.NewMockTreeStorage(fx.ctrl)
settingsId := "settingsId" settingsId := "settingsId"
aclRoot := &aclrecordproto.RawAclRecordWithId{ aclRootId := "aclRootId"
aclRoot := &consensusproto.RawRecordWithId{
Id: aclRootId, Id: aclRootId,
} }
settingsRoot := &treechangeproto.RawTreeChangeWithId{ settingsRoot := &treechangeproto.RawTreeChangeWithId{
@ -172,38 +202,65 @@ func TestDiffSyncer_Sync(t *testing.T) {
} }
spaceHeader := &spacesyncproto.RawSpaceHeaderWithId{} spaceHeader := &spacesyncproto.RawSpaceHeaderWithId{}
spaceSettingsId := "spaceSettingsId" spaceSettingsId := "spaceSettingsId"
credential := []byte("credential")
peerManagerMock.EXPECT(). fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing) Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing)
stMock.EXPECT().AclStorage().Return(aclStorageMock, nil) fx.storageMock.EXPECT().AclStorage().Return(aclStorageMock, nil)
stMock.EXPECT().SpaceHeader().Return(spaceHeader, nil) fx.storageMock.EXPECT().SpaceHeader().Return(spaceHeader, nil)
stMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId) fx.storageMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId)
stMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil) fx.storageMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil)
settingsStorage.EXPECT().Root().Return(settingsRoot, nil) settingsStorage.EXPECT().Root().Return(settingsRoot, nil)
aclStorageMock.EXPECT(). aclStorageMock.EXPECT().
Root(). Root().
Return(aclRoot, nil) Return(aclRoot, nil)
clientMock.EXPECT(). fx.credentialProviderMock.EXPECT().
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(spaceId, aclRootId, settingsId, spaceHeader)). GetCredential(gomock.Any(), spaceHeader).
Return(credential, nil)
fx.clientMock.EXPECT().
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(fx.spaceState.SpaceId, aclRootId, settingsId, credential, spaceHeader)).
Return(nil, nil) Return(nil, nil)
fx.peerManagerMock.EXPECT().SendPeer(gomock.Any(), "peerId", gomock.Any())
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync other error", func(t *testing.T) { t.Run("diff syncer sync unexpected", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrUnexpected) Return(nil, nil, nil, spacesyncproto.ErrUnexpected)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
})
t.Run("diff syncer sync space is deleted error", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{}
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil)
fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted)
fx.storageMock.EXPECT().SpaceSettingsId().Return("settingsId")
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil)
require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
} }

View File

@ -1,96 +1,152 @@
//go:generate mockgen -destination mock_headsync/mock_headsync.go github.com/anytypeio/any-sync/commonspace/headsync DiffSyncer //go:generate mockgen -destination mock_headsync/mock_headsync.go github.com/anyproto/any-sync/commonspace/headsync DiffSyncer
package headsync package headsync
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/app/ldiff" "sync/atomic"
"github.com/anytypeio/any-sync/app/logger"
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
"github.com/anytypeio/any-sync/commonspace/peermanager"
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate"
"github.com/anytypeio/any-sync/commonspace/spacestorage"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
"github.com/anytypeio/any-sync/commonspace/syncstatus"
"github.com/anytypeio/any-sync/util/periodicsync"
"go.uber.org/zap"
"strings"
"time" "time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger"
config2 "github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/periodicsync"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap"
"golang.org/x/exp/slices"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.headsync"
type TreeHeads struct { type TreeHeads struct {
Id string Id string
Heads []string Heads []string
} }
type HeadSync interface { type HeadSync interface {
Init(objectIds []string, deletionState deletionstate.DeletionState) app.ComponentRunnable
ExternalIds() []string
DebugAllHeads() (res []TreeHeads)
AllIds() []string
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
RemoveObjects(ids []string) RemoveObjects(ids []string)
AllIds() []string
DebugAllHeads() (res []TreeHeads)
Close() (err error)
} }
type headSync struct { type headSync struct {
spaceId string spaceId string
spaceIsDeleted *atomic.Bool
syncPeriod int
periodicSync periodicsync.PeriodicSync periodicSync periodicsync.PeriodicSync
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
diff ldiff.Diff diff ldiff.Diff
log logger.CtxLogger log logger.CtxLogger
syncer DiffSyncer syncer DiffSyncer
configuration nodeconf.NodeConf
syncPeriod int peerManager peermanager.PeerManager
treeManager treemanager.TreeManager
credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusService
deletionState deletionstate.ObjectDeletionState
syncAcl syncacl.SyncAcl
} }
func NewHeadSync( func New() HeadSync {
spaceId string, return &headSync{}
syncPeriod int, }
storage spacestorage.SpaceStorage,
peerManager peermanager.PeerManager,
cache treegetter.TreeGetter,
syncStatus syncstatus.StatusUpdater,
log logger.CtxLogger) HeadSync {
diff := ldiff.New(16, 16) var createDiffSyncer = newDiffSyncer
l := log.With(zap.String("spaceId", spaceId))
factory := spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient)
syncer := newDiffSyncer(spaceId, diff, peerManager, cache, storage, factory, syncStatus, l)
periodicSync := periodicsync.NewPeriodicSync(syncPeriod, time.Minute*10, syncer.Sync, l)
return &headSync{ func (h *headSync) Init(a *app.App) (err error) {
spaceId: spaceId, shared := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
storage: storage, cfg := a.MustComponent("config").(config2.ConfigGetter)
syncer: syncer, h.syncAcl = a.MustComponent(syncacl.CName).(syncacl.SyncAcl)
periodicSync: periodicSync, h.spaceId = shared.SpaceId
diff: diff, h.spaceIsDeleted = shared.SpaceIsDeleted
log: log, h.syncPeriod = cfg.GetSpace().SyncPeriod
syncPeriod: syncPeriod, h.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
h.log = log.With(zap.String("spaceId", h.spaceId))
h.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
h.diff = ldiff.New(16, 16)
h.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
h.credentialProvider = a.MustComponent(credentialprovider.CName).(credentialprovider.CredentialProvider)
h.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
h.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
h.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
h.syncer = createDiffSyncer(h)
sync := func(ctx context.Context) (err error) {
// for clients cancelling the sync process
if h.spaceIsDeleted.Load() && !h.configuration.IsResponsible(h.spaceId) {
return spacesyncproto.ErrSpaceIsDeleted
} }
return h.syncer.Sync(ctx)
}
h.periodicSync = periodicsync.NewPeriodicSync(h.syncPeriod, time.Minute, sync, h.log)
h.syncAcl.SetHeadUpdater(h)
// TODO: move to run?
h.syncer.Init()
return nil
} }
func (d *headSync) Init(objectIds []string, deletionState deletionstate.DeletionState) { func (h *headSync) Name() (name string) {
d.fillDiff(objectIds) return CName
d.syncer.Init(deletionState)
d.periodicSync.Run()
} }
func (d *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { func (h *headSync) Run(ctx context.Context) (err error) {
return HandleRangeRequest(ctx, d.diff, req) initialIds, err := h.storage.StoredIds()
if err != nil {
return
}
h.fillDiff(initialIds)
h.periodicSync.Run()
return
} }
func (d *headSync) UpdateHeads(id string, heads []string) { func (h *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
d.syncer.UpdateHeads(id, heads) if h.spaceIsDeleted.Load() {
peerId, err := peer.CtxPeerId(ctx)
if err != nil {
return nil, err
}
// stop receiving all request for sync from clients
if !slices.Contains(h.configuration.NodeIds(h.spaceId), peerId) {
return nil, spacesyncproto.ErrSpaceIsDeleted
}
}
return HandleRangeRequest(ctx, h.diff, req)
} }
func (d *headSync) AllIds() []string { func (h *headSync) UpdateHeads(id string, heads []string) {
return d.diff.Ids() h.syncer.UpdateHeads(id, heads)
} }
func (d *headSync) DebugAllHeads() (res []TreeHeads) { func (h *headSync) AllIds() []string {
els := d.diff.Elements() return h.diff.Ids()
}
func (h *headSync) ExternalIds() []string {
settingsId := h.storage.SpaceSettingsId()
return slice.DiscardFromSlice(h.AllIds(), func(id string) bool {
return id == settingsId
})
}
func (h *headSync) DebugAllHeads() (res []TreeHeads) {
els := h.diff.Elements()
for _, el := range els { for _, el := range els {
idHead := TreeHeads{ idHead := TreeHeads{
Id: el.Id, Id: el.Id,
@ -101,19 +157,19 @@ func (d *headSync) DebugAllHeads() (res []TreeHeads) {
return return
} }
func (d *headSync) RemoveObjects(ids []string) { func (h *headSync) RemoveObjects(ids []string) {
d.syncer.RemoveObjects(ids) h.syncer.RemoveObjects(ids)
} }
func (d *headSync) Close() (err error) { func (h *headSync) Close(ctx context.Context) (err error) {
d.periodicSync.Close() h.periodicSync.Close()
return nil return h.syncer.Close()
} }
func (d *headSync) fillDiff(objectIds []string) { func (h *headSync) fillDiff(objectIds []string) {
var els = make([]ldiff.Element, 0, len(objectIds)) var els = make([]ldiff.Element, 0, len(objectIds))
for _, id := range objectIds { for _, id := range objectIds {
st, err := d.storage.TreeStorage(id) st, err := h.storage.TreeStorage(id)
if err != nil { if err != nil {
continue continue
} }
@ -126,32 +182,12 @@ func (d *headSync) fillDiff(objectIds []string) {
Head: concatStrings(heads), Head: concatStrings(heads),
}) })
} }
d.diff.Set(els...) els = append(els, ldiff.Element{
if err := d.storage.WriteSpaceHash(d.diff.Hash()); err != nil { Id: h.syncAcl.Id(),
d.log.Error("can't write space hash", zap.Error(err)) Head: h.syncAcl.Head().Id,
})
h.diff.Set(els...)
if err := h.storage.WriteSpaceHash(h.diff.Hash()); err != nil {
h.log.Error("can't write space hash", zap.Error(err))
} }
} }
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

View File

@ -1,70 +1,190 @@
package headsync package headsync
import ( import (
"github.com/anytypeio/any-sync/app/ldiff" "context"
"github.com/anytypeio/any-sync/app/ldiff/mock_ldiff" "github.com/anyproto/any-sync/app"
"github.com/anytypeio/any-sync/app/logger" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anytypeio/any-sync/commonspace/headsync/mock_headsync" "github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage/mock_treestorage" "github.com/anyproto/any-sync/commonspace/config"
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate/mock_deletionstate" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
"github.com/anytypeio/any-sync/util/periodicsync/mock_periodicsync" "github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/golang/mock/gomock" "github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync/mock_headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sync/atomic"
"testing" "testing"
) )
func TestDiffService(t *testing.T) { type mockConfig struct {
ctrl := gomock.NewController(t) }
defer ctrl.Finish()
spaceId := "spaceId" func (m mockConfig) Init(a *app.App) (err error) {
l := logger.NewNamed("sync") return nil
pSyncMock := mock_periodicsync.NewMockPeriodicSync(ctrl) }
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
treeStorageMock := mock_treestorage.NewMockTreeStorage(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
syncer := mock_headsync.NewMockDiffSyncer(ctrl)
delState := mock_deletionstate.NewMockDeletionState(ctrl)
syncPeriod := 1
initId := "initId"
service := &headSync{ func (m mockConfig) Name() (name string) {
spaceId: spaceId, return "config"
storage: storageMock, }
periodicSync: pSyncMock,
syncer: syncer, func (m mockConfig) GetSpace() config.Config {
diff: diffMock, return config.Config{}
log: l, }
syncPeriod: syncPeriod,
type headSyncFixture struct {
spaceState *spacestate.SpaceState
ctrl *gomock.Controller
app *app.App
configurationMock *mock_nodeconf.MockService
storageMock *mock_spacestorage.MockSpaceStorage
peerManagerMock *mock_peermanager.MockPeerManager
credentialProviderMock *mock_credentialprovider.MockCredentialProvider
syncStatus syncstatus.StatusService
treeManagerMock *mock_treemanager.MockTreeManager
deletionStateMock *mock_deletionstate.MockObjectDeletionState
diffSyncerMock *mock_headsync.MockDiffSyncer
treeSyncerMock *mock_treemanager.MockTreeSyncer
diffMock *mock_ldiff.MockDiff
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
aclMock *mock_syncacl.MockSyncAcl
headSync *headSync
diffSyncer *diffSyncer
}
func newHeadSyncFixture(t *testing.T) *headSyncFixture {
spaceState := &spacestate.SpaceState{
SpaceId: "spaceId",
SpaceIsDeleted: &atomic.Bool{},
} }
ctrl := gomock.NewController(t)
configurationMock := mock_nodeconf.NewMockService(ctrl)
configurationMock.EXPECT().Name().AnyTimes().Return(nodeconf.CName)
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
storageMock.EXPECT().Name().AnyTimes().Return(spacestorage.CName)
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl)
peerManagerMock.EXPECT().Name().AnyTimes().Return(peermanager.CName)
credentialProviderMock := mock_credentialprovider.NewMockCredentialProvider(ctrl)
credentialProviderMock.EXPECT().Name().AnyTimes().Return(credentialprovider.CName)
syncStatus := syncstatus.NewNoOpSyncStatus()
treeManagerMock := mock_treemanager.NewMockTreeManager(ctrl)
treeManagerMock.EXPECT().Name().AnyTimes().Return(treemanager.CName)
deletionStateMock := mock_deletionstate.NewMockObjectDeletionState(ctrl)
deletionStateMock.EXPECT().Name().AnyTimes().Return(deletionstate.CName)
diffSyncerMock := mock_headsync.NewMockDiffSyncer(ctrl)
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
aclMock := mock_syncacl.NewMockSyncAcl(ctrl)
aclMock.EXPECT().Name().AnyTimes().Return(syncacl.CName)
aclMock.EXPECT().SetHeadUpdater(gomock.Any()).AnyTimes()
hs := &headSync{}
a := &app.App{}
a.Register(spaceState).
Register(aclMock).
Register(mockConfig{}).
Register(configurationMock).
Register(storageMock).
Register(peerManagerMock).
Register(credentialProviderMock).
Register(syncStatus).
Register(treeManagerMock).
Register(deletionStateMock).
Register(hs)
return &headSyncFixture{
spaceState: spaceState,
ctrl: ctrl,
app: a,
configurationMock: configurationMock,
storageMock: storageMock,
peerManagerMock: peerManagerMock,
credentialProviderMock: credentialProviderMock,
syncStatus: syncStatus,
treeManagerMock: treeManagerMock,
deletionStateMock: deletionStateMock,
headSync: hs,
diffSyncerMock: diffSyncerMock,
treeSyncerMock: treeSyncerMock,
diffMock: diffMock,
clientMock: clientMock,
aclMock: aclMock,
}
}
t.Run("init", func(t *testing.T) { func (fx *headSyncFixture) init(t *testing.T) {
storageMock.EXPECT().TreeStorage(initId).Return(treeStorageMock, nil) createDiffSyncer = func(hs *headSync) DiffSyncer {
treeStorageMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil) return fx.diffSyncerMock
syncer.EXPECT().Init(delState) }
diffMock.EXPECT().Set(ldiff.Element{ fx.diffSyncerMock.EXPECT().Init()
Id: initId, err := fx.headSync.Init(fx.app)
require.NoError(t, err)
fx.headSync.diff = fx.diffMock
}
func (fx *headSyncFixture) stop() {
fx.ctrl.Finish()
}
func TestHeadSync(t *testing.T) {
ctx := context.Background()
t.Run("run close", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.init(t)
defer fx.stop()
ids := []string{"id1"}
treeMock := mock_treestorage.NewMockTreeStorage(fx.ctrl)
fx.storageMock.EXPECT().StoredIds().Return(ids, nil)
fx.storageMock.EXPECT().TreeStorage(ids[0]).Return(treeMock, nil)
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.aclMock.EXPECT().Head().AnyTimes().Return(&list.AclRecord{Id: "headId"})
treeMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil)
fx.diffMock.EXPECT().Set(ldiff.Element{
Id: "id1",
Head: "h1h2", Head: "h1h2",
}) })
hash := "123" fx.diffMock.EXPECT().Hash().Return("hash")
diffMock.EXPECT().Hash().Return(hash) fx.storageMock.EXPECT().WriteSpaceHash("hash").Return(nil)
storageMock.EXPECT().WriteSpaceHash(hash) fx.diffSyncerMock.EXPECT().Sync(gomock.Any()).Return(nil)
pSyncMock.EXPECT().Run() fx.diffSyncerMock.EXPECT().Close().Return(nil)
service.Init([]string{initId}, delState) err := fx.headSync.Run(ctx)
require.NoError(t, err)
err = fx.headSync.Close(ctx)
require.NoError(t, err)
}) })
t.Run("update heads", func(t *testing.T) { t.Run("update heads", func(t *testing.T) {
syncer.EXPECT().UpdateHeads(initId, []string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.UpdateHeads(initId, []string{"h1", "h2"}) fx.init(t)
defer fx.stop()
fx.diffSyncerMock.EXPECT().UpdateHeads("id1", []string{"h1"})
fx.headSync.UpdateHeads("id1", []string{"h1"})
}) })
t.Run("remove objects", func(t *testing.T) { t.Run("remove objects", func(t *testing.T) {
syncer.EXPECT().RemoveObjects([]string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.RemoveObjects([]string{"h1", "h2"}) fx.init(t)
}) defer fx.stop()
t.Run("close", func(t *testing.T) { fx.diffSyncerMock.EXPECT().RemoveObjects([]string{"id1"})
pSyncMock.EXPECT().Close() fx.headSync.RemoveObjects([]string{"id1"})
service.Close()
}) })
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/commonspace/headsync (interfaces: DiffSyncer) // Source: github.com/anyproto/any-sync/commonspace/headsync (interfaces: DiffSyncer)
// Package mock_headsync is a generated GoMock package. // Package mock_headsync is a generated GoMock package.
package mock_headsync package mock_headsync
@ -8,8 +8,7 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
deletionstate "github.com/anytypeio/any-sync/commonspace/settings/deletionstate" gomock "go.uber.org/mock/gomock"
gomock "github.com/golang/mock/gomock"
) )
// MockDiffSyncer is a mock of DiffSyncer interface. // MockDiffSyncer is a mock of DiffSyncer interface.
@ -35,16 +34,30 @@ func (m *MockDiffSyncer) EXPECT() *MockDiffSyncerMockRecorder {
return m.recorder return m.recorder
} }
// Init mocks base method. // Close mocks base method.
func (m *MockDiffSyncer) Init(arg0 deletionstate.DeletionState) { func (m *MockDiffSyncer) Close() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Init", arg0) ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockDiffSyncerMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDiffSyncer)(nil).Close))
}
// Init mocks base method.
func (m *MockDiffSyncer) Init() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Init")
} }
// Init indicates an expected call of Init. // Init indicates an expected call of Init.
func (mr *MockDiffSyncerMockRecorder) Init(arg0 interface{}) *gomock.Call { func (mr *MockDiffSyncerMockRecorder) Init() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init))
} }
// RemoveObjects mocks base method. // RemoveObjects mocks base method.

View File

@ -2,8 +2,8 @@ package headsync
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
) )
type Client interface { type Client interface {

View File

@ -3,8 +3,8 @@ package headsync
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing" "testing"

View File

@ -0,0 +1,27 @@
package headsync
import "strings"
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

View File

@ -1,14 +1,36 @@
package accountdata package accountdata
import ( import (
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey" "crypto/rand"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey" "github.com/anyproto/any-sync/util/crypto"
) )
type AccountData struct { // TODO: create a convenient constructor for this type AccountKeys struct {
Identity []byte // public key PeerKey crypto.PrivKey
PeerKey signingkey.PrivKey SignKey crypto.PrivKey
SignKey signingkey.PrivKey
EncKey encryptionkey.PrivKey
PeerId string PeerId string
} }
func New(peerKey, signKey crypto.PrivKey) *AccountKeys {
return &AccountKeys{
PeerKey: peerKey,
SignKey: signKey,
PeerId: peerKey.GetPublic().PeerId(),
}
}
func NewRandom() (*AccountKeys, error) {
peerKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
if err != nil {
return nil, err
}
signKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
if err != nil {
return nil, err
}
return &AccountKeys{
PeerKey: peerKey,
SignKey: signKey,
PeerId: peerKey.GetPublic().PeerId(),
}, nil
}

View File

@ -1,12 +0,0 @@
package aclrecordproto
import (
"github.com/anytypeio/any-sync/util/keys/symmetric"
)
func AclReadKeyDerive(signKey []byte, encKey []byte) (*symmetric.Key, error) {
concBuf := make([]byte, 0, len(signKey)+len(encKey))
concBuf = append(concBuf, signKey...)
concBuf = append(concBuf, encKey...)
return symmetric.DeriveFromBytes(concBuf)
}

File diff suppressed because it is too large Load Diff

View File

@ -2,117 +2,105 @@ syntax = "proto3";
package aclrecord; package aclrecord;
option go_package = "commonspace/object/acl/aclrecordproto"; option go_package = "commonspace/object/acl/aclrecordproto";
message RawAclRecord { // AclRoot is a root of access control list
bytes payload = 1;
bytes signature = 2;
bytes acceptorIdentity = 3;
bytes acceptorSignature = 4;
}
message RawAclRecordWithId {
bytes payload = 1;
string id = 2;
}
message AclRecord {
string prevId = 1;
bytes identity = 2;
bytes data = 3;
uint64 currentReadKeyHash = 4;
int64 timestamp = 5;
}
message AclRoot { message AclRoot {
bytes identity = 1; bytes identity = 1;
bytes encryptionKey = 2; bytes masterKey = 2;
string spaceId = 3; string spaceId = 3;
bytes encryptedReadKey = 4; bytes encryptedReadKey = 4;
string derivationScheme = 5; int64 timestamp = 5;
uint64 currentReadKeyHash = 6; bytes identitySignature = 6;
int64 timestamp = 7;
} }
message AclContentValue { // AclAccountInvite contains the public invite key, the private part of which is sent to the user directly
oneof value { message AclAccountInvite {
AclUserAdd userAdd = 1; bytes inviteKey = 1;
AclUserRemove userRemove = 2;
AclUserPermissionChange userPermissionChange = 3;
AclUserInvite userInvite = 4;
AclUserJoin userJoin = 5;
}
} }
message AclData { // AclAccountRequestJoin contains the reference to the invite record and the data of the person who wants to join, confirmed by the private invite key
repeated AclContentValue aclContent = 1; message AclAccountRequestJoin {
bytes inviteIdentity = 1;
string inviteRecordId = 2;
bytes inviteIdentitySignature = 3;
bytes metadata = 4;
} }
message AclState { // AclAccountRequestAccept contains the reference to join record and all read keys, encrypted with the identity of the requestor
repeated uint64 readKeyHashes = 1; message AclAccountRequestAccept {
repeated AclUserState userStates = 2;
map<string, AclUserInvite> invites = 3;
}
message AclUserState {
bytes identity = 1; bytes identity = 1;
bytes encryptionKey = 2; string requestRecordId = 2;
AclUserPermissions permissions = 3; repeated AclReadKeyWithRecord encryptedReadKeys = 3;
}
message AclUserAdd {
bytes identity = 1;
bytes encryptionKey = 2;
repeated bytes encryptedReadKeys = 3;
AclUserPermissions permissions = 4; AclUserPermissions permissions = 4;
} }
message AclUserInvite { // AclAccountRequestDecline contains the reference to join record
bytes acceptPublicKey = 1; message AclAccountRequestDecline {
uint64 encryptSymKeyHash = 2; string requestRecordId = 1;
repeated bytes encryptedReadKeys = 3;
AclUserPermissions permissions = 4;
} }
message AclUserJoin { // AclAccountInviteRevoke revokes the invite record
message AclAccountInviteRevoke {
string inviteRecordId = 1;
}
// AclReadKeys are a read key with record id
message AclReadKeyWithRecord {
string recordId = 1;
bytes encryptedReadKey = 2;
}
// AclEncryptedReadKeys are new key for specific identity
message AclEncryptedReadKey {
bytes identity = 1; bytes identity = 1;
bytes encryptionKey = 2; bytes encryptedReadKey = 2;
bytes acceptSignature = 3;
bytes acceptPubKey = 4;
repeated bytes encryptedReadKeys = 5;
} }
message AclUserRemove { // AclAccountPermissionChange changes permissions of specific account
bytes identity = 1; message AclAccountPermissionChange {
repeated AclReadKeyReplace readKeyReplaces = 2;
}
message AclReadKeyReplace {
bytes identity = 1;
bytes encryptionKey = 2;
bytes encryptedReadKey = 3;
}
message AclUserPermissionChange {
bytes identity = 1; bytes identity = 1;
AclUserPermissions permissions = 2; AclUserPermissions permissions = 2;
} }
enum AclUserPermissions { // AclReadKeyChange changes the key for a space
Admin = 0; message AclReadKeyChange {
Writer = 1; repeated AclEncryptedReadKey accountKeys = 1;
Reader = 2;
} }
message AclSyncMessage { // AclAccountRemove removes an account and changes read key for space
AclSyncContentValue content = 2; message AclAccountRemove {
repeated bytes identities = 1;
repeated AclEncryptedReadKey accountKeys = 2;
} }
// AclSyncContentValue provides different types for acl sync // AclAccountRequestRemove adds a request to remove an account
message AclSyncContentValue { message AclAccountRequestRemove {
}
// AclContentValue contains possible values for Acl
message AclContentValue {
oneof value { oneof value {
AclAddRecords addRecords = 1; AclAccountInvite invite = 1;
AclAccountInviteRevoke inviteRevoke = 2;
AclAccountRequestJoin requestJoin = 3;
AclAccountRequestAccept requestAccept = 4;
AclAccountPermissionChange permissionChange = 5;
AclAccountRemove accountRemove = 6;
AclReadKeyChange readKeyChange = 7;
AclAccountRequestDecline requestDecline = 8;
AclAccountRequestRemove accountRequestRemove = 9;
} }
} }
message AclAddRecords { // AclData contains different acl content
repeated RawAclRecordWithId records = 1; message AclData {
repeated AclContentValue aclContent = 1;
}
// AclUserPermissions contains different possible user roles
enum AclUserPermissions {
None = 0;
Owner = 1;
Admin = 2;
Writer = 3;
Reader = 4;
} }

View File

@ -1,166 +1,499 @@
package list package list
import ( import (
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/keychain"
"github.com/anytypeio/any-sync/util/cidutil"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"github.com/gogo/protobuf/proto"
"time" "time"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto"
"github.com/gogo/protobuf/proto"
) )
// remove interface type RootContent struct {
PrivKey crypto.PrivKey
MasterKey crypto.PrivKey
SpaceId string
EncryptedReadKey []byte
}
type RequestJoinPayload struct {
InviteRecordId string
InviteKey crypto.PrivKey
Metadata []byte
}
type RequestAcceptPayload struct {
RequestRecordId string
Permissions AclPermissions
}
type PermissionChangePayload struct {
Identity crypto.PubKey
Permissions AclPermissions
}
type AccountRemovePayload struct {
Identities []crypto.PubKey
ReadKey crypto.SymKey
}
type InviteResult struct {
InviteRec *consensusproto.RawRecord
InviteKey crypto.PrivKey
}
type AclRecordBuilder interface { type AclRecordBuilder interface {
ConvertFromRaw(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error) UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error)
BuildUserJoin(acceptPrivKeyBytes []byte, encSymKeyBytes []byte, state *AclState) (rec *aclrecordproto.RawAclRecord, err error) Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error)
BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error)
BuildInvite() (res InviteResult, err error)
BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error)
BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error)
BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error)
BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error)
} }
type aclRecordBuilder struct { type aclRecordBuilder struct {
id string id string
keychain *keychain.Keychain keyStorage crypto.KeyStorage
accountKeys *accountdata.AccountKeys
verifier AcceptorVerifier
state *AclState
} }
func newAclRecordBuilder(id string, keychain *keychain.Keychain) AclRecordBuilder { func NewAclRecordBuilder(id string, keyStorage crypto.KeyStorage, keys *accountdata.AccountKeys, verifier AcceptorVerifier) AclRecordBuilder {
return &aclRecordBuilder{ return &aclRecordBuilder{
id: id, id: id,
keychain: keychain, keyStorage: keyStorage,
accountKeys: keys,
verifier: verifier,
} }
} }
func (a *aclRecordBuilder) BuildUserJoin(acceptPrivKeyBytes []byte, encSymKeyBytes []byte, state *AclState) (rec *aclrecordproto.RawAclRecord, err error) { func (a *aclRecordBuilder) buildRecord(aclContent *aclrecordproto.AclContentValue) (rawRec *consensusproto.RawRecord, err error) {
acceptPrivKey, err := signingkey.NewSigningEd25519PrivKeyFromBytes(acceptPrivKeyBytes)
if err != nil {
return
}
acceptPubKeyBytes, err := acceptPrivKey.GetPublic().Raw()
if err != nil {
return
}
encSymKey, err := symmetric.FromBytes(encSymKeyBytes)
if err != nil {
return
}
invite, err := state.Invite(acceptPubKeyBytes)
if err != nil {
return
}
encPrivKey, signPrivKey := state.UserKeys()
var symKeys [][]byte
for _, rk := range invite.EncryptedReadKeys {
dec, err := encSymKey.Decrypt(rk)
if err != nil {
return nil, err
}
newEnc, err := encPrivKey.GetPublic().Encrypt(dec)
if err != nil {
return nil, err
}
symKeys = append(symKeys, newEnc)
}
idSignature, err := acceptPrivKey.Sign(state.Identity())
if err != nil {
return
}
encPubKeyBytes, err := encPrivKey.GetPublic().Raw()
if err != nil {
return
}
userJoin := &aclrecordproto.AclUserJoin{
Identity: state.Identity(),
EncryptionKey: encPubKeyBytes,
AcceptSignature: idSignature,
AcceptPubKey: acceptPubKeyBytes,
EncryptedReadKeys: symKeys,
}
aclData := &aclrecordproto.AclData{AclContent: []*aclrecordproto.AclContentValue{ aclData := &aclrecordproto.AclData{AclContent: []*aclrecordproto.AclContentValue{
{Value: &aclrecordproto.AclContentValue_UserJoin{UserJoin: userJoin}}, aclContent,
}} }}
marshalledJoin, err := aclData.Marshal() marshalledData, err := aclData.Marshal()
if err != nil { if err != nil {
return return
} }
aclRecord := &aclrecordproto.AclRecord{ protoKey, err := a.accountKeys.SignKey.GetPublic().Marshall()
PrevId: state.LastRecordId(),
Identity: state.Identity(),
Data: marshalledJoin,
CurrentReadKeyHash: state.CurrentReadKeyHash(),
Timestamp: time.Now().UnixNano(),
}
marshalledRecord, err := aclRecord.Marshal()
if err != nil { if err != nil {
return return
} }
recSignature, err := signPrivKey.Sign(marshalledRecord) rec := &consensusproto.Record{
PrevId: a.state.lastRecordId,
Identity: protoKey,
Data: marshalledData,
Timestamp: time.Now().Unix(),
}
marshalledRec, err := rec.Marshal()
if err != nil { if err != nil {
return return
} }
rec = &aclrecordproto.RawAclRecord{ signature, err := a.accountKeys.SignKey.Sign(marshalledRec)
Payload: marshalledRecord, if err != nil {
Signature: recSignature, return
}
rawRec = &consensusproto.RawRecord{
Payload: marshalledRec,
Signature: signature,
} }
return return
} }
func (a *aclRecordBuilder) ConvertFromRaw(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error) { func (a *aclRecordBuilder) BuildInvite() (res InviteResult, err error) {
rawRec := &aclrecordproto.RawAclRecord{} if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
privKey, pubKey, err := crypto.GenerateRandomEd25519KeyPair()
if err != nil {
return
}
invitePubKey, err := pubKey.Marshall()
if err != nil {
return
}
inviteRec := &aclrecordproto.AclAccountInvite{InviteKey: invitePubKey}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_Invite{Invite: inviteRec}}
rawRec, err := a.buildRecord(content)
if err != nil {
return
}
res.InviteKey = privKey
res.InviteRec = rawRec
return
}
func (a *aclRecordBuilder) BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
_, exists := a.state.inviteKeys[inviteRecordId]
if !exists {
err = ErrNoSuchInvite
return
}
revokeRec := &aclrecordproto.AclAccountInviteRevoke{InviteRecordId: inviteRecordId}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_InviteRevoke{InviteRevoke: revokeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error) {
key, exists := a.state.inviteKeys[payload.InviteRecordId]
if !exists {
err = ErrNoSuchInvite
return
}
if !payload.InviteKey.GetPublic().Equals(key) {
err = ErrIncorrectInviteKey
}
rawIdentity, err := a.accountKeys.SignKey.GetPublic().Raw()
if err != nil {
return
}
signature, err := payload.InviteKey.Sign(rawIdentity)
if err != nil {
return
}
protoIdentity, err := a.accountKeys.SignKey.GetPublic().Marshall()
if err != nil {
return
}
joinRec := &aclrecordproto.AclAccountRequestJoin{
InviteIdentity: protoIdentity,
InviteRecordId: payload.InviteRecordId,
InviteIdentitySignature: signature,
Metadata: payload.Metadata,
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestJoin{RequestJoin: joinRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
request, exists := a.state.requestRecords[payload.RequestRecordId]
if !exists {
err = ErrNoSuchRequest
return
}
var encryptedReadKeys []*aclrecordproto.AclReadKeyWithRecord
for keyId, key := range a.state.userReadKeys {
rawKey, err := key.Raw()
if err != nil {
return nil, err
}
enc, err := request.RequestIdentity.Encrypt(rawKey)
if err != nil {
return nil, err
}
encryptedReadKeys = append(encryptedReadKeys, &aclrecordproto.AclReadKeyWithRecord{
RecordId: keyId,
EncryptedReadKey: enc,
})
}
if err != nil {
return
}
requestIdentityProto, err := request.RequestIdentity.Marshall()
if err != nil {
return
}
acceptRec := &aclrecordproto.AclAccountRequestAccept{
Identity: requestIdentityProto,
RequestRecordId: payload.RequestRecordId,
EncryptedReadKeys: encryptedReadKeys,
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestAccept{RequestAccept: acceptRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
_, exists := a.state.requestRecords[requestRecordId]
if !exists {
err = ErrNoSuchRequest
return
}
declineRec := &aclrecordproto.AclAccountRequestDecline{RequestRecordId: requestRecordId}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestDecline{RequestDecline: declineRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error) {
permissions := a.state.Permissions(a.state.pubKey)
if !permissions.CanManageAccounts() || payload.Identity.Equals(a.state.pubKey) {
err = ErrInsufficientPermissions
return
}
if payload.Permissions.IsOwner() {
err = ErrIsOwner
return
}
protoIdentity, err := payload.Identity.Marshall()
if err != nil {
return
}
permissionRec := &aclrecordproto.AclAccountPermissionChange{
Identity: protoIdentity,
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_PermissionChange{PermissionChange: permissionRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
rawKey, err := newKey.Raw()
if err != nil {
return
}
if len(rawKey) != crypto.KeyBytes {
err = ErrIncorrectReadKey
return
}
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
for _, st := range a.state.userStates {
protoIdentity, err := st.PubKey.Marshall()
if err != nil {
return nil, err
}
enc, err := st.PubKey.Encrypt(rawKey)
if err != nil {
return nil, err
}
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
Identity: protoIdentity,
EncryptedReadKey: enc,
})
}
readRec := &aclrecordproto.AclReadKeyChange{AccountKeys: aclReadKeys}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_ReadKeyChange{ReadKeyChange: readRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error) {
deletedMap := map[string]struct{}{}
for _, key := range payload.Identities {
permissions := a.state.Permissions(key)
if permissions.IsOwner() {
return nil, ErrInsufficientPermissions
}
if permissions.NoPermissions() {
return nil, ErrNoSuchAccount
}
deletedMap[mapKeyFromPubKey(key)] = struct{}{}
}
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
rawKey, err := payload.ReadKey.Raw()
if err != nil {
return
}
if len(rawKey) != crypto.KeyBytes {
err = ErrIncorrectReadKey
return
}
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
for _, st := range a.state.userStates {
if _, exists := deletedMap[mapKeyFromPubKey(st.PubKey)]; exists {
continue
}
protoIdentity, err := st.PubKey.Marshall()
if err != nil {
return nil, err
}
enc, err := st.PubKey.Encrypt(rawKey)
if err != nil {
return nil, err
}
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
Identity: protoIdentity,
EncryptedReadKey: enc,
})
}
var marshalledIdentities [][]byte
for _, key := range payload.Identities {
protoIdentity, err := key.Marshall()
if err != nil {
return nil, err
}
marshalledIdentities = append(marshalledIdentities, protoIdentity)
}
removeRec := &aclrecordproto.AclAccountRemove{AccountKeys: aclReadKeys, Identities: marshalledIdentities}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRemove{AccountRemove: removeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error) {
permissions := a.state.Permissions(a.state.pubKey)
if permissions.NoPermissions() {
err = ErrNoSuchAccount
return
}
if permissions.IsOwner() {
err = ErrIsOwner
return
}
removeRec := &aclrecordproto.AclAccountRequestRemove{}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRequestRemove{AccountRequestRemove: removeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error) {
aclRecord := &consensusproto.Record{}
err = proto.Unmarshal(rawRecord.Payload, aclRecord)
if err != nil {
return
}
pubKey, err := a.keyStorage.PubKeyFromProto(aclRecord.Identity)
if err != nil {
return
}
aclData := &aclrecordproto.AclData{}
err = proto.Unmarshal(aclRecord.Data, aclData)
if err != nil {
return
}
rec = &AclRecord{
PrevId: aclRecord.PrevId,
Timestamp: aclRecord.Timestamp,
Data: aclRecord.Data,
Signature: rawRecord.Signature,
Identity: pubKey,
Model: aclData,
}
res, err := pubKey.Verify(rawRecord.Payload, rawRecord.Signature)
if err != nil {
return
}
if !res {
err = ErrInvalidSignature
return
}
return
}
func (a *aclRecordBuilder) UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error) {
var (
rawRec = &consensusproto.RawRecord{}
pubKey crypto.PubKey
)
err = proto.Unmarshal(rawIdRecord.Payload, rawRec) err = proto.Unmarshal(rawIdRecord.Payload, rawRec)
if err != nil { if err != nil {
return return
} }
if rawIdRecord.Id == a.id { if rawIdRecord.Id == a.id {
aclRoot := &aclrecordproto.AclRoot{} aclRoot := &aclrecordproto.AclRoot{}
err = proto.Unmarshal(rawRec.Payload, aclRoot) err = proto.Unmarshal(rawRec.Payload, aclRoot)
if err != nil { if err != nil {
return return
} }
pubKey, err = a.keyStorage.PubKeyFromProto(aclRoot.Identity)
if err != nil {
return
}
rec = &AclRecord{ rec = &AclRecord{
Id: rawIdRecord.Id, Id: rawIdRecord.Id,
CurrentReadKeyHash: aclRoot.CurrentReadKeyHash,
Timestamp: aclRoot.Timestamp, Timestamp: aclRoot.Timestamp,
Signature: rawRec.Signature, Signature: rawRec.Signature,
Identity: aclRoot.Identity, Identity: pubKey,
Model: aclRoot, Model: aclRoot,
} }
} else { } else {
aclRecord := &aclrecordproto.AclRecord{} err = a.verifier.VerifyAcceptor(rawRec)
if err != nil {
return
}
aclRecord := &consensusproto.Record{}
err = proto.Unmarshal(rawRec.Payload, aclRecord) err = proto.Unmarshal(rawRec.Payload, aclRecord)
if err != nil { if err != nil {
return return
} }
pubKey, err = a.keyStorage.PubKeyFromProto(aclRecord.Identity)
rec = &AclRecord{
Id: rawIdRecord.Id,
PrevId: aclRecord.PrevId,
CurrentReadKeyHash: aclRecord.CurrentReadKeyHash,
Timestamp: aclRecord.Timestamp,
Data: aclRecord.Data,
Signature: rawRec.Signature,
Identity: aclRecord.Identity,
}
}
err = verifyRaw(a.keychain, rawRec, rawIdRecord, rec.Identity)
return
}
func verifyRaw(
keychain *keychain.Keychain,
rawRec *aclrecordproto.RawAclRecord,
recWithId *aclrecordproto.RawAclRecordWithId,
identity []byte) (err error) {
identityKey, err := keychain.GetOrAdd(string(identity))
if err != nil { if err != nil {
return return
} }
aclData := &aclrecordproto.AclData{}
err = proto.Unmarshal(aclRecord.Data, aclData)
if err != nil {
return
}
rec = &AclRecord{
Id: rawIdRecord.Id,
PrevId: aclRecord.PrevId,
Timestamp: aclRecord.Timestamp,
Data: aclRecord.Data,
Signature: rawRec.Signature,
Identity: pubKey,
Model: aclData,
}
}
err = verifyRaw(pubKey, rawRec, rawIdRecord)
return
}
func (a *aclRecordBuilder) BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error) {
rawIdentity, err := content.PrivKey.GetPublic().Raw()
if err != nil {
return
}
identity, err := content.PrivKey.GetPublic().Marshall()
if err != nil {
return
}
masterKey, err := content.MasterKey.GetPublic().Marshall()
if err != nil {
return
}
identitySignature, err := content.MasterKey.Sign(rawIdentity)
if err != nil {
return
}
var timestamp int64
if content.EncryptedReadKey != nil {
timestamp = time.Now().Unix()
}
aclRoot := &aclrecordproto.AclRoot{
Identity: identity,
SpaceId: content.SpaceId,
EncryptedReadKey: content.EncryptedReadKey,
MasterKey: masterKey,
IdentitySignature: identitySignature,
Timestamp: timestamp,
}
return marshalAclRoot(aclRoot, content.PrivKey)
}
func verifyRaw(
pubKey crypto.PubKey,
rawRec *consensusproto.RawRecord,
recWithId *consensusproto.RawRecordWithId) (err error) {
// verifying signature // verifying signature
res, err := identityKey.Verify(rawRec.Payload, rawRec.Signature) res, err := pubKey.Verify(rawRec.Payload, rawRec.Signature)
if err != nil { if err != nil {
return return
} }
@ -175,3 +508,31 @@ func verifyRaw(
} }
return return
} }
func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWithId *consensusproto.RawRecordWithId, err error) {
marshalledRoot, err := aclRoot.Marshal()
if err != nil {
return
}
signature, err := key.Sign(marshalledRoot)
if err != nil {
return
}
raw := &consensusproto.RawRecord{
Payload: marshalledRoot,
Signature: signature,
}
marshalledRaw, err := raw.Marshal()
if err != nil {
return
}
aclHeadId, err := cidutil.NewCidFromBytes(marshalledRaw)
if err != nil {
return
}
rawWithId = &consensusproto.RawRecordWithId{
Payload: marshalledRaw,
Id: aclHeadId,
}
return
}

View File

@ -1,50 +0,0 @@
package list
import (
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
acllistbuilder2 "github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder"
"github.com/anytypeio/any-sync/commonspace/object/keychain"
"github.com/anytypeio/any-sync/util/cidutil"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/stretchr/testify/require"
"testing"
)
func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
st, err := acllistbuilder2.NewListStorageWithTestName("userjoinexample.yml")
require.NoError(t, err, "building storage should not result in error")
testKeychain := st.(*acllistbuilder2.AclListStorageBuilder).GetKeychain()
identity := testKeychain.GeneratedIdentities["D"]
signPrivKey := testKeychain.SigningKeysByYAMLName["D"]
encPrivKey := testKeychain.EncryptionKeysByYAMLName["D"]
acc := &accountdata.AccountData{
Identity: []byte(identity),
SignKey: signPrivKey,
EncKey: encPrivKey,
}
aclList, err := BuildAclListWithIdentity(acc, st)
require.NoError(t, err, "building acl list should be without error")
recordBuilder := newAclRecordBuilder(aclList.Id(), keychain.NewKeychain())
rk, err := testKeychain.GetKey("key.Read.EncKey").(*acllistbuilder2.SymKey).Key.Raw()
require.NoError(t, err)
privKey, err := testKeychain.GetKey("key.Sign.Onetime1").(signingkey.PrivKey).Raw()
require.NoError(t, err)
userJoin, err := recordBuilder.BuildUserJoin(privKey, rk, aclList.AclState())
require.NoError(t, err)
marshalledJoin, err := userJoin.Marshal()
require.NoError(t, err)
id, err := cidutil.NewCidFromBytes(marshalledJoin)
require.NoError(t, err)
rawRec := &aclrecordproto.RawAclRecordWithId{
Payload: marshalledJoin,
Id: id,
}
res, err := aclList.AddRawRecord(rawRec)
require.True(t, res)
require.NoError(t, err)
require.Equal(t, aclrecordproto.AclUserPermissions_Writer, aclList.AclState().UserStates()[identity].Permissions)
}

View File

@ -1,120 +1,142 @@
package list package list
import ( import (
"bytes"
"errors" "errors"
"fmt"
"github.com/anytypeio/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
aclrecordproto2 "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/keychain" "github.com/anyproto/any-sync/util/crypto"
"github.com/anytypeio/any-sync/util/keys"
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
"hash/fnv"
) )
var log = logger.NewNamed("acllist").Sugar() var log = logger.NewNamedSugared("common.commonspace.acllist")
var ( var (
ErrNoSuchUser = errors.New("no such user") ErrNoSuchAccount = errors.New("no such account")
ErrPendingRequest = errors.New("already exists pending request")
ErrUnexpectedContentType = errors.New("unexpected content type")
ErrIncorrectIdentity = errors.New("incorrect identity")
ErrIncorrectInviteKey = errors.New("incorrect invite key")
ErrFailedToDecrypt = errors.New("failed to decrypt key") ErrFailedToDecrypt = errors.New("failed to decrypt key")
ErrUserRemoved = errors.New("user was removed from the document")
ErrDocumentForbidden = errors.New("your user was forbidden access to the document")
ErrUserAlreadyExists = errors.New("user already exists")
ErrNoSuchRecord = errors.New("no such record") ErrNoSuchRecord = errors.New("no such record")
ErrNoSuchRequest = errors.New("no such request")
ErrNoSuchInvite = errors.New("no such invite") ErrNoSuchInvite = errors.New("no such invite")
ErrOldInvite = errors.New("invite is too old")
ErrInsufficientPermissions = errors.New("insufficient permissions") ErrInsufficientPermissions = errors.New("insufficient permissions")
ErrIsOwner = errors.New("can't be made by owner")
ErrIncorrectNumberOfAccounts = errors.New("incorrect number of accounts")
ErrDuplicateAccounts = errors.New("duplicate accounts")
ErrNoReadKey = errors.New("acl state doesn't have a read key") ErrNoReadKey = errors.New("acl state doesn't have a read key")
ErrIncorrectReadKey = errors.New("incorrect read key")
ErrInvalidSignature = errors.New("signature is invalid") ErrInvalidSignature = errors.New("signature is invalid")
ErrIncorrectRoot = errors.New("incorrect root") ErrIncorrectRoot = errors.New("incorrect root")
ErrIncorrectRecordSequence = errors.New("incorrect prev id of a record") ErrIncorrectRecordSequence = errors.New("incorrect prev id of a record")
) )
type UserPermissionPair struct { type UserPermissionPair struct {
Identity string Identity crypto.PubKey
Permission aclrecordproto2.AclUserPermissions Permission aclrecordproto.AclUserPermissions
} }
type AclState struct { type AclState struct {
id string id string
currentReadKeyHash uint64 currentReadKeyId string
userReadKeys map[uint64]*symmetric.Key // userReadKeys is a map recordId -> read key which tells us about every read key
userStates map[string]*aclrecordproto2.AclUserState userReadKeys map[string]crypto.SymKey
userInvites map[string]*aclrecordproto2.AclUserInvite // userStates is a map pubKey -> state which defines current user state
encryptionKey encryptionkey.PrivKey userStates map[string]AclUserState
signingKey signingkey.PrivKey // statesAtRecord is a map recordId -> state which define user state at particular record
// probably this can grow rather large at some point, so we can maybe optimise later to have:
// - map pubKey -> []recordIds (where recordIds is an array where such identity permissions were changed)
statesAtRecord map[string][]AclUserState
// inviteKeys is a map recordId -> invite
inviteKeys map[string]crypto.PubKey
// requestRecords is a map recordId -> RequestRecord
requestRecords map[string]RequestRecord
// pendingRequests is a map pubKey -> recordId
pendingRequests map[string]string
key crypto.PrivKey
pubKey crypto.PubKey
keyStore crypto.KeyStorage
totalReadKeys int totalReadKeys int
identity string
permissionsAtRecord map[string][]UserPermissionPair
lastRecordId string lastRecordId string
contentValidator ContentValidator
keychain *keychain.Keychain
} }
func newAclStateWithKeys( func newAclStateWithKeys(
id string, id string,
signingKey signingkey.PrivKey, key crypto.PrivKey) (*AclState, error) {
encryptionKey encryptionkey.PrivKey) (*AclState, error) { st := &AclState{
identity, err := signingKey.GetPublic().Raw()
if err != nil {
return nil, err
}
return &AclState{
id: id, id: id,
identity: string(identity), key: key,
signingKey: signingKey, pubKey: key.GetPublic(),
encryptionKey: encryptionKey, userReadKeys: make(map[string]crypto.SymKey),
userReadKeys: make(map[uint64]*symmetric.Key), userStates: make(map[string]AclUserState),
userStates: make(map[string]*aclrecordproto2.AclUserState), statesAtRecord: make(map[string][]AclUserState),
userInvites: make(map[string]*aclrecordproto2.AclUserInvite), inviteKeys: make(map[string]crypto.PubKey),
permissionsAtRecord: make(map[string][]UserPermissionPair), requestRecords: make(map[string]RequestRecord),
}, nil pendingRequests: make(map[string]string),
keyStore: crypto.NewKeyStorage(),
}
st.contentValidator = &contentValidator{
keyStore: st.keyStore,
aclState: st,
}
return st, nil
} }
func newAclState(id string) *AclState { func newAclState(id string) *AclState {
return &AclState{ st := &AclState{
id: id, id: id,
userReadKeys: make(map[uint64]*symmetric.Key), userReadKeys: make(map[string]crypto.SymKey),
userStates: make(map[string]*aclrecordproto2.AclUserState), userStates: make(map[string]AclUserState),
userInvites: make(map[string]*aclrecordproto2.AclUserInvite), statesAtRecord: make(map[string][]AclUserState),
permissionsAtRecord: make(map[string][]UserPermissionPair), inviteKeys: make(map[string]crypto.PubKey),
requestRecords: make(map[string]RequestRecord),
pendingRequests: make(map[string]string),
keyStore: crypto.NewKeyStorage(),
} }
st.contentValidator = &contentValidator{
keyStore: st.keyStore,
aclState: st,
}
return st
} }
func (st *AclState) CurrentReadKeyHash() uint64 { func (st *AclState) Validator() ContentValidator {
return st.currentReadKeyHash return st.contentValidator
} }
func (st *AclState) CurrentReadKey() (*symmetric.Key, error) { func (st *AclState) CurrentReadKeyId() string {
key, exists := st.userReadKeys[st.currentReadKeyHash] return st.currentReadKeyId
}
func (st *AclState) CurrentReadKey() (crypto.SymKey, error) {
key, exists := st.userReadKeys[st.CurrentReadKeyId()]
if !exists { if !exists {
return nil, ErrNoReadKey return nil, ErrNoReadKey
} }
return key, nil return key, nil
} }
func (st *AclState) UserReadKeys() map[uint64]*symmetric.Key { func (st *AclState) UserReadKeys() map[string]crypto.SymKey {
return st.userReadKeys return st.userReadKeys
} }
func (st *AclState) PermissionsAtRecord(id string, identity string) (UserPermissionPair, error) { func (st *AclState) StateAtRecord(id string, pubKey crypto.PubKey) (AclUserState, error) {
permissions, ok := st.permissionsAtRecord[id] userState, ok := st.statesAtRecord[id]
if !ok { if !ok {
log.Errorf("missing record at id %s", id) log.Errorf("missing record at id %s", id)
return UserPermissionPair{}, ErrNoSuchRecord return AclUserState{}, ErrNoSuchRecord
} }
for _, perm := range permissions { for _, perm := range userState {
if perm.Identity == identity { if perm.PubKey.Equals(pubKey) {
return perm, nil return perm, nil
} }
} }
return UserPermissionPair{}, ErrNoSuchUser return AclUserState{}, ErrNoSuchAccount
} }
func (st *AclState) applyRecord(record *AclRecord) (err error) { func (st *AclState) applyRecord(record *AclRecord) (err error) {
@ -127,338 +149,316 @@ func (st *AclState) applyRecord(record *AclRecord) (err error) {
err = ErrIncorrectRecordSequence err = ErrIncorrectRecordSequence
return return
} }
// if the record is root record
if record.Id == st.id { if record.Id == st.id {
root, ok := record.Model.(*aclrecordproto2.AclRoot) err = st.applyRoot(record)
if !ok {
return ErrIncorrectRoot
}
err = st.applyRoot(root)
if err != nil { if err != nil {
return return
} }
st.permissionsAtRecord[record.Id] = []UserPermissionPair{ st.statesAtRecord[record.Id] = []AclUserState{
{Identity: string(root.Identity), Permission: aclrecordproto2.AclUserPermissions_Admin}, st.userStates[mapKeyFromPubKey(record.Identity)],
} }
return return
} }
aclData := &aclrecordproto2.AclData{} // if the model is not cached
if record.Model == nil {
if record.Model != nil { aclData := &aclrecordproto.AclData{}
aclData = record.Model.(*aclrecordproto2.AclData)
} else {
err = proto.Unmarshal(record.Data, aclData) err = proto.Unmarshal(record.Data, aclData)
if err != nil { if err != nil {
return return
} }
record.Model = aclData record.Model = aclData
} }
// applying records contents
err = st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity) err = st.applyChangeData(record)
if err != nil { if err != nil {
return return
} }
// getting all states for users at record and saving them
// getting all permissions for users at record var states []AclUserState
var permissions []UserPermissionPair
for _, state := range st.userStates { for _, state := range st.userStates {
permission := UserPermissionPair{ states = append(states, state)
Identity: string(state.Identity),
Permission: state.Permissions,
} }
permissions = append(permissions, permission) st.statesAtRecord[record.Id] = states
}
st.permissionsAtRecord[record.Id] = permissions
return return
} }
func (st *AclState) applyRoot(root *aclrecordproto2.AclRoot) (err error) { func (st *AclState) applyRoot(record *AclRecord) (err error) {
if st.signingKey != nil && st.encryptionKey != nil && st.identity == string(root.Identity) { if st.key != nil && st.pubKey.Equals(record.Identity) {
err = st.saveReadKeyFromRoot(root) err = st.saveReadKeyFromRoot(record)
if err != nil { if err != nil {
return return
} }
} }
// adding user to the list // adding user to the list
userState := &aclrecordproto2.AclUserState{ userState := AclUserState{
Identity: root.Identity, PubKey: record.Identity,
EncryptionKey: root.EncryptionKey, Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Owner),
Permissions: aclrecordproto2.AclUserPermissions_Admin,
} }
st.currentReadKeyHash = root.CurrentReadKeyHash st.currentReadKeyId = record.Id
st.userStates[string(root.Identity)] = userState st.userStates[mapKeyFromPubKey(record.Identity)] = userState
st.totalReadKeys++ st.totalReadKeys++
return return
} }
func (st *AclState) saveReadKeyFromRoot(root *aclrecordproto2.AclRoot) (err error) { func (st *AclState) saveReadKeyFromRoot(record *AclRecord) (err error) {
var readKey *symmetric.Key var readKey crypto.SymKey
if len(root.GetDerivationScheme()) != 0 { root, ok := record.Model.(*aclrecordproto.AclRoot)
var encPrivKey []byte if !ok {
encPrivKey, err = st.encryptionKey.Raw() return ErrIncorrectRoot
if err != nil {
return
} }
var signPrivKey []byte if root.EncryptedReadKey == nil {
signPrivKey, err = st.signingKey.Raw() readKey, err = st.deriveKey()
if err != nil {
return
}
readKey, err = aclrecordproto2.AclReadKeyDerive(signPrivKey, encPrivKey)
if err != nil { if err != nil {
return return
} }
} else { } else {
readKey, _, err = st.decryptReadKeyAndHash(root.EncryptedReadKey) readKey, err = st.decryptReadKey(root.EncryptedReadKey)
if err != nil { if err != nil {
return return
} }
} }
st.userReadKeys[record.Id] = readKey
hasher := fnv.New64()
_, err = hasher.Write(readKey.Bytes())
if err != nil {
return
}
if hasher.Sum64() != root.CurrentReadKeyHash {
return ErrIncorrectRoot
}
st.userReadKeys[root.CurrentReadKeyHash] = readKey
return return
} }
func (st *AclState) applyChangeData(changeData *aclrecordproto2.AclData, hash uint64, identity []byte) (err error) { func (st *AclState) applyChangeData(record *AclRecord) (err error) {
defer func() { model := record.Model.(*aclrecordproto.AclData)
if err != nil { for _, ch := range model.GetAclContent() {
return if err = st.applyChangeContent(ch, record.Id, record.Identity); err != nil {
}
if hash != st.currentReadKeyHash {
st.totalReadKeys++
st.currentReadKeyHash = hash
}
}()
if !st.isUserJoin(changeData) {
// we check signature when we add this to the List, so no need to do it here
if _, exists := st.userStates[string(identity)]; !exists {
err = ErrNoSuchUser
return
}
if !st.HasPermission(identity, aclrecordproto2.AclUserPermissions_Admin) {
err = fmt.Errorf("user %s must have admin permissions", identity)
return
}
}
for _, ch := range changeData.GetAclContent() {
if err = st.applyChangeContent(ch); err != nil {
log.Info("error while applying changes: %v; ignore", zap.Error(err)) log.Info("error while applying changes: %v; ignore", zap.Error(err))
return err return err
} }
} }
return nil return nil
} }
func (st *AclState) applyChangeContent(ch *aclrecordproto2.AclContentValue) error { func (st *AclState) applyChangeContent(ch *aclrecordproto.AclContentValue, recordId string, authorIdentity crypto.PubKey) error {
switch { switch {
case ch.GetUserPermissionChange() != nil: case ch.GetPermissionChange() != nil:
return st.applyUserPermissionChange(ch.GetUserPermissionChange()) return st.applyPermissionChange(ch.GetPermissionChange(), recordId, authorIdentity)
case ch.GetUserAdd() != nil: case ch.GetInvite() != nil:
return st.applyUserAdd(ch.GetUserAdd()) return st.applyInvite(ch.GetInvite(), recordId, authorIdentity)
case ch.GetUserRemove() != nil: case ch.GetInviteRevoke() != nil:
return st.applyUserRemove(ch.GetUserRemove()) return st.applyInviteRevoke(ch.GetInviteRevoke(), recordId, authorIdentity)
case ch.GetUserInvite() != nil: case ch.GetRequestJoin() != nil:
return st.applyUserInvite(ch.GetUserInvite()) return st.applyRequestJoin(ch.GetRequestJoin(), recordId, authorIdentity)
case ch.GetUserJoin() != nil: case ch.GetRequestAccept() != nil:
return st.applyUserJoin(ch.GetUserJoin()) return st.applyRequestAccept(ch.GetRequestAccept(), recordId, authorIdentity)
case ch.GetRequestDecline() != nil:
return st.applyRequestDecline(ch.GetRequestDecline(), recordId, authorIdentity)
case ch.GetAccountRemove() != nil:
return st.applyAccountRemove(ch.GetAccountRemove(), recordId, authorIdentity)
case ch.GetReadKeyChange() != nil:
return st.applyReadKeyChange(ch.GetReadKeyChange(), recordId, authorIdentity)
case ch.GetAccountRequestRemove() != nil:
return st.applyRequestRemove(ch.GetAccountRequestRemove(), recordId, authorIdentity)
default: default:
return fmt.Errorf("unexpected change type: %v", ch) return ErrUnexpectedContentType
} }
} }
func (st *AclState) applyUserPermissionChange(ch *aclrecordproto2.AclUserPermissionChange) error { func (st *AclState) applyPermissionChange(ch *aclrecordproto.AclAccountPermissionChange, recordId string, authorIdentity crypto.PubKey) error {
chIdentity := string(ch.Identity) chIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
state, exists := st.userStates[chIdentity] if err != nil {
if !exists { return err
return ErrNoSuchUser
} }
err = st.contentValidator.ValidatePermissionChange(ch, authorIdentity)
state.Permissions = ch.Permissions if err != nil {
return err
}
stringKey := mapKeyFromPubKey(chIdentity)
state, _ := st.userStates[stringKey]
state.Permissions = AclPermissions(ch.Permissions)
st.userStates[stringKey] = state
return nil return nil
} }
func (st *AclState) applyUserInvite(ch *aclrecordproto2.AclUserInvite) error { func (st *AclState) applyInvite(ch *aclrecordproto.AclAccountInvite, recordId string, authorIdentity crypto.PubKey) error {
st.userInvites[string(ch.AcceptPublicKey)] = ch inviteKey, err := st.keyStore.PubKeyFromProto(ch.InviteKey)
if err != nil {
return err
}
err = st.contentValidator.ValidateInvite(ch, authorIdentity)
if err != nil {
return err
}
st.inviteKeys[recordId] = inviteKey
return nil return nil
} }
func (st *AclState) applyUserJoin(ch *aclrecordproto2.AclUserJoin) error { func (st *AclState) applyInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, recordId string, authorIdentity crypto.PubKey) error {
invite, exists := st.userInvites[string(ch.AcceptPubKey)] err := st.contentValidator.ValidateInviteRevoke(ch, authorIdentity)
if !exists {
return fmt.Errorf("no such invite with such public key %s", keys.EncodeBytesToString(ch.AcceptPubKey))
}
chIdentity := string(ch.Identity)
if _, exists = st.userStates[chIdentity]; exists {
return ErrUserAlreadyExists
}
// validating signature
signature := ch.GetAcceptSignature()
verificationKey, err := signingkey.NewSigningEd25519PubKeyFromBytes(invite.AcceptPublicKey)
if err != nil { if err != nil {
return fmt.Errorf("public key verifying invite accepts is given in incorrect format: %v", err) return err
} }
delete(st.inviteKeys, ch.InviteRecordId)
return nil
}
res, err := verificationKey.Verify(ch.Identity, signature) func (st *AclState) applyRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateRequestJoin(ch, authorIdentity)
if err != nil { if err != nil {
return fmt.Errorf("verification returned error: %w", err) return err
} }
if !res { st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
return ErrInvalidSignature st.requestRecords[recordId] = RequestRecord{
RequestIdentity: authorIdentity,
RequestMetadata: ch.Metadata,
Type: RequestTypeJoin,
} }
return nil
}
// if ourselves -> we need to decrypt the read keys func (st *AclState) applyRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, recordId string, authorIdentity crypto.PubKey) error {
if st.identity == chIdentity { err := st.contentValidator.ValidateRequestAccept(ch, authorIdentity)
if err != nil {
return err
}
acceptIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return err
}
record, _ := st.requestRecords[ch.RequestRecordId]
st.userStates[mapKeyFromPubKey(acceptIdentity)] = AclUserState{
PubKey: acceptIdentity,
Permissions: AclPermissions(ch.Permissions),
RequestMetadata: record.RequestMetadata,
}
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
if !st.pubKey.Equals(acceptIdentity) {
return nil
}
for _, key := range ch.EncryptedReadKeys { for _, key := range ch.EncryptedReadKeys {
key, hash, err := st.decryptReadKeyAndHash(key) decrypted, err := st.key.Decrypt(key.EncryptedReadKey)
if err != nil { if err != nil {
return ErrFailedToDecrypt return err
} }
sym, err := crypto.UnmarshallAESKey(decrypted)
st.userReadKeys[hash] = key
}
}
// adding user to the list
userState := &aclrecordproto2.AclUserState{
Identity: ch.Identity,
EncryptionKey: ch.EncryptionKey,
Permissions: invite.Permissions,
}
st.userStates[chIdentity] = userState
return nil
}
func (st *AclState) applyUserAdd(ch *aclrecordproto2.AclUserAdd) error {
chIdentity := string(ch.Identity)
if _, exists := st.userStates[chIdentity]; exists {
return ErrUserAlreadyExists
}
st.userStates[chIdentity] = &aclrecordproto2.AclUserState{
Identity: ch.Identity,
EncryptionKey: ch.EncryptionKey,
Permissions: ch.Permissions,
}
if chIdentity == st.identity {
for _, key := range ch.EncryptedReadKeys {
key, hash, err := st.decryptReadKeyAndHash(key)
if err != nil { if err != nil {
return ErrFailedToDecrypt return err
}
st.userReadKeys[hash] = key
}
}
return nil
}
func (st *AclState) applyUserRemove(ch *aclrecordproto2.AclUserRemove) error {
chIdentity := string(ch.Identity)
if chIdentity == st.identity {
return ErrDocumentForbidden
}
if _, exists := st.userStates[chIdentity]; !exists {
return ErrNoSuchUser
}
delete(st.userStates, chIdentity)
for _, replace := range ch.ReadKeyReplaces {
repIdentity := string(replace.Identity)
// if this is our identity then we have to decrypt the key
if repIdentity == st.identity {
key, hash, err := st.decryptReadKeyAndHash(replace.EncryptedReadKey)
if err != nil {
return ErrFailedToDecrypt
}
st.userReadKeys[hash] = key
break
} }
st.userReadKeys[key.RecordId] = sym
} }
return nil return nil
} }
func (st *AclState) decryptReadKeyAndHash(msg []byte) (*symmetric.Key, uint64, error) { func (st *AclState) applyRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, recordId string, authorIdentity crypto.PubKey) error {
decrypted, err := st.encryptionKey.Decrypt(msg) err := st.contentValidator.ValidateRequestDecline(ch, authorIdentity)
if err != nil { if err != nil {
return nil, 0, ErrFailedToDecrypt return err
} }
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
key, err := symmetric.FromBytes(decrypted) delete(st.requestRecords, ch.RequestRecordId)
if err != nil { return nil
return nil, 0, ErrFailedToDecrypt
}
hasher := fnv.New64()
hasher.Write(decrypted)
return key, hasher.Sum64(), nil
} }
func (st *AclState) HasPermission(identity []byte, permission aclrecordproto2.AclUserPermissions) bool { func (st *AclState) applyRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, recordId string, authorIdentity crypto.PubKey) error {
state, exists := st.userStates[string(identity)] err := st.contentValidator.ValidateRequestRemove(ch, authorIdentity)
if err != nil {
return err
}
st.requestRecords[recordId] = RequestRecord{
RequestIdentity: authorIdentity,
Type: RequestTypeRemove,
}
st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
return nil
}
func (st *AclState) applyAccountRemove(ch *aclrecordproto.AclAccountRemove, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateAccountRemove(ch, authorIdentity)
if err != nil {
return err
}
for _, rawIdentity := range ch.Identities {
identity, err := st.keyStore.PubKeyFromProto(rawIdentity)
if err != nil {
return err
}
idKey := mapKeyFromPubKey(identity)
delete(st.userStates, idKey)
delete(st.pendingRequests, idKey)
}
return st.updateReadKey(ch.AccountKeys, recordId)
}
func (st *AclState) applyReadKeyChange(ch *aclrecordproto.AclReadKeyChange, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateReadKeyChange(ch, authorIdentity)
if err != nil {
return err
}
return st.updateReadKey(ch.AccountKeys, recordId)
}
func (st *AclState) updateReadKey(keys []*aclrecordproto.AclEncryptedReadKey, recordId string) error {
for _, accKey := range keys {
identity, _ := st.keyStore.PubKeyFromProto(accKey.Identity)
if st.pubKey.Equals(identity) {
res, err := st.decryptReadKey(accKey.EncryptedReadKey)
if err != nil {
return err
}
st.userReadKeys[recordId] = res
}
}
st.currentReadKeyId = recordId
return nil
}
func (st *AclState) decryptReadKey(msg []byte) (crypto.SymKey, error) {
decrypted, err := st.key.Decrypt(msg)
if err != nil {
return nil, ErrFailedToDecrypt
}
key, err := crypto.UnmarshallAESKey(decrypted)
if err != nil {
return nil, ErrFailedToDecrypt
}
return key, nil
}
func (st *AclState) Permissions(identity crypto.PubKey) AclPermissions {
state, exists := st.userStates[mapKeyFromPubKey(identity)]
if !exists { if !exists {
return false return AclPermissions(aclrecordproto.AclUserPermissions_None)
} }
return state.Permissions
return state.Permissions == permission
} }
func (st *AclState) isUserJoin(data *aclrecordproto2.AclData) bool { func (st *AclState) JoinRecords() (records []RequestRecord) {
// if we have a UserJoin, then it should always be the first one applied for _, recId := range st.pendingRequests {
return data.GetAclContent() != nil && data.GetAclContent()[0].GetUserJoin() != nil rec := st.requestRecords[recId]
} if rec.Type == RequestTypeJoin {
records = append(records, rec)
func (st *AclState) isUserAdd(data *aclrecordproto2.AclData, identity []byte) bool {
// if we have a UserAdd, then it should always be the first one applied
userAdd := data.GetAclContent()[0].GetUserAdd()
return data.GetAclContent() != nil && userAdd != nil && bytes.Compare(userAdd.GetIdentity(), identity) == 0
}
func (st *AclState) UserStates() map[string]*aclrecordproto2.AclUserState {
return st.userStates
}
func (st *AclState) Invite(acceptPubKey []byte) (invite *aclrecordproto2.AclUserInvite, err error) {
invite, exists := st.userInvites[string(acceptPubKey)]
if !exists {
err = ErrNoSuchInvite
return
} }
if len(invite.EncryptedReadKeys) != st.totalReadKeys {
err = ErrOldInvite
} }
return return
} }
func (st *AclState) UserKeys() (encKey encryptionkey.PrivKey, signKey signingkey.PrivKey) { func (st *AclState) RemoveRecords() (records []RequestRecord) {
return st.encryptionKey, st.signingKey for _, recId := range st.pendingRequests {
} rec := st.requestRecords[recId]
if rec.Type == RequestTypeRemove {
func (st *AclState) Identity() []byte { records = append(records, rec)
return []byte(st.identity) }
}
return
} }
func (st *AclState) LastRecordId() string { func (st *AclState) LastRecordId() string {
return st.lastRecordId return st.lastRecordId
} }
func (st *AclState) deriveKey() (crypto.SymKey, error) {
keyBytes, err := st.key.Raw()
if err != nil {
return nil, err
}
return crypto.DeriveSymmetricKey(keyBytes, crypto.AnysyncSpacePath)
}
func mapKeyFromPubKey(pubKey crypto.PubKey) string {
return string(pubKey.Storage())
}

View File

@ -1,21 +1,18 @@
package list package list
import ( import (
"github.com/anytypeio/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey" "github.com/anyproto/any-sync/util/crypto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
) )
type aclStateBuilder struct { type aclStateBuilder struct {
signPrivKey signingkey.PrivKey privKey crypto.PrivKey
encPrivKey encryptionkey.PrivKey
id string id string
} }
func newAclStateBuilderWithIdentity(accountData *accountdata.AccountData) *aclStateBuilder { func newAclStateBuilderWithIdentity(keys *accountdata.AccountKeys) *aclStateBuilder {
return &aclStateBuilder{ return &aclStateBuilder{
signPrivKey: accountData.SignKey, privKey: keys.SignKey,
encPrivKey: accountData.EncKey,
} }
} }
@ -28,8 +25,8 @@ func (sb *aclStateBuilder) Init(id string) {
} }
func (sb *aclStateBuilder) Build(records []*AclRecord) (state *AclState, err error) { func (sb *aclStateBuilder) Build(records []*AclRecord) (state *AclState, err error) {
if sb.encPrivKey != nil && sb.signPrivKey != nil { if sb.privKey != nil {
state, err = newAclStateWithKeys(sb.id, sb.signPrivKey, sb.encPrivKey) state, err = newAclStateWithKeys(sb.id, sb.privKey)
if err != nil { if err != nil {
return return
} }

View File

@ -1,20 +1,25 @@
//go:generate mockgen -destination mock_list/mock_list.go github.com/anytypeio/any-sync/commonspace/object/acl/list AclList //go:generate mockgen -destination mock_list/mock_list.go github.com/anyproto/any-sync/commonspace/object/acl/list AclList
package list package list
import ( import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage"
"github.com/anytypeio/any-sync/commonspace/object/keychain"
"sync" "sync"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto"
) )
type IterFunc = func(record *AclRecord) (IsContinue bool) type IterFunc = func(record *AclRecord) (IsContinue bool)
var ErrIncorrectCID = errors.New("incorrect CID") var (
ErrIncorrectCID = errors.New("incorrect CID")
ErrRecordAlreadyExists = errors.New("record already exists")
)
type RWLocker interface { type RWLocker interface {
sync.Locker sync.Locker
@ -22,48 +27,97 @@ type RWLocker interface {
RUnlock() RUnlock()
} }
type AcceptorVerifier interface {
VerifyAcceptor(rec *consensusproto.RawRecord) (err error)
}
type NoOpAcceptorVerifier struct {
}
func (n NoOpAcceptorVerifier) VerifyAcceptor(rec *consensusproto.RawRecord) (err error) {
return nil
}
type AclList interface { type AclList interface {
RWLocker RWLocker
Id() string Id() string
Root() *aclrecordproto.RawAclRecordWithId Root() *consensusproto.RawRecordWithId
Records() []*AclRecord Records() []*AclRecord
AclState() *AclState AclState() *AclState
IsAfter(first string, second string) (bool, error) IsAfter(first string, second string) (bool, error)
HasHead(head string) bool
Head() *AclRecord Head() *AclRecord
RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error)
Get(id string) (*AclRecord, error) Get(id string) (*AclRecord, error)
GetIndex(idx int) (*AclRecord, error)
Iterate(iterFunc IterFunc) Iterate(iterFunc IterFunc)
IterateFrom(startId string, iterFunc IterFunc) IterateFrom(startId string, iterFunc IterFunc)
AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error) KeyStorage() crypto.KeyStorage
RecordBuilder() AclRecordBuilder
Close() (err error) ValidateRawRecord(record *consensusproto.RawRecord) (err error)
AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error)
AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error)
Close(ctx context.Context) (err error)
} }
type aclList struct { type aclList struct {
root *aclrecordproto.RawAclRecordWithId root *consensusproto.RawRecordWithId
records []*AclRecord records []*AclRecord
indexes map[string]int indexes map[string]int
id string id string
stateBuilder *aclStateBuilder stateBuilder *aclStateBuilder
recordBuilder AclRecordBuilder recordBuilder AclRecordBuilder
keyStorage crypto.KeyStorage
aclState *AclState aclState *AclState
keychain *keychain.Keychain
storage liststorage.ListStorage storage liststorage.ListStorage
sync.RWMutex sync.RWMutex
} }
func BuildAclListWithIdentity(acc *accountdata.AccountData, storage liststorage.ListStorage) (AclList, error) { type internalDeps struct {
builder := newAclStateBuilderWithIdentity(acc) storage liststorage.ListStorage
return build(storage.Id(), builder, newAclRecordBuilder(storage.Id(), keychain.NewKeychain()), storage) keyStorage crypto.KeyStorage
stateBuilder *aclStateBuilder
recordBuilder AclRecordBuilder
acceptorVerifier AcceptorVerifier
} }
func BuildAclList(storage liststorage.ListStorage) (AclList, error) { func BuildAclListWithIdentity(acc *accountdata.AccountKeys, storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
return build(storage.Id(), newAclStateBuilder(), newAclRecordBuilder(storage.Id(), keychain.NewKeychain()), storage) keyStorage := crypto.NewKeyStorage()
deps := internalDeps{
storage: storage,
keyStorage: keyStorage,
stateBuilder: newAclStateBuilderWithIdentity(acc),
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, acc, verifier),
acceptorVerifier: verifier,
}
return build(deps)
} }
func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder, storage liststorage.ListStorage) (list AclList, err error) { func BuildAclList(storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
keyStorage := crypto.NewKeyStorage()
deps := internalDeps{
storage: storage,
keyStorage: keyStorage,
stateBuilder: newAclStateBuilder(),
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, nil, verifier),
acceptorVerifier: verifier,
}
return build(deps)
}
func build(deps internalDeps) (list AclList, err error) {
var (
storage = deps.storage
id = deps.storage.Id()
recBuilder = deps.recordBuilder
stateBuilder = deps.stateBuilder
)
head, err := storage.Head() head, err := storage.Head()
if err != nil { if err != nil {
return return
@ -74,7 +128,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
return return
} }
record, err := recBuilder.ConvertFromRaw(rawRecordWithId) record, err := recBuilder.UnmarshallWithId(rawRecordWithId)
if err != nil { if err != nil {
return return
} }
@ -86,7 +140,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
return return
} }
record, err = recBuilder.ConvertFromRaw(rawRecordWithId) record, err = recBuilder.UnmarshallWithId(rawRecordWithId)
if err != nil { if err != nil {
return return
} }
@ -116,6 +170,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
return return
} }
recBuilder.(*aclRecordBuilder).state = state
list = &aclList{ list = &aclList{
root: rootWithId, root: rootWithId,
records: records, records: records,
@ -129,15 +184,37 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
return return
} }
func (a *aclList) RecordBuilder() AclRecordBuilder {
return a.recordBuilder
}
func (a *aclList) Records() []*AclRecord { func (a *aclList) Records() []*AclRecord {
return a.records return a.records
} }
func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error) { func (a *aclList) ValidateRawRecord(rawRec *consensusproto.RawRecord) (err error) {
if _, ok := a.indexes[rawRec.Id]; ok { record, err := a.recordBuilder.Unmarshall(rawRec)
if err != nil {
return return
} }
record, err := a.recordBuilder.ConvertFromRaw(rawRec) return a.aclState.Validator().ValidateAclRecordContents(record)
}
func (a *aclList) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
for _, rec := range rawRecords {
err = a.AddRawRecord(rec)
if err != nil && err != ErrRecordAlreadyExists {
return
}
}
return
}
func (a *aclList) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
if _, ok := a.indexes[rawRec.Id]; ok {
return ErrRecordAlreadyExists
}
record, err := a.recordBuilder.UnmarshallWithId(rawRec)
if err != nil { if err != nil {
return return
} }
@ -152,15 +229,6 @@ func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added
if err = a.storage.SetHead(rawRec.Id); err != nil { if err = a.storage.SetHead(rawRec.Id); err != nil {
return return
} }
return true, nil
}
func (a *aclList) IsValidNext(rawRec *aclrecordproto.RawAclRecordWithId) (err error) {
_, err = a.recordBuilder.ConvertFromRaw(rawRec)
if err != nil {
return
}
// TODO: change state and add "check" method for records
return return
} }
@ -168,7 +236,7 @@ func (a *aclList) Id() string {
return a.id return a.id
} }
func (a *aclList) Root() *aclrecordproto.RawAclRecordWithId { func (a *aclList) Root() *consensusproto.RawRecordWithId {
return a.root return a.root
} }
@ -176,6 +244,10 @@ func (a *aclList) AclState() *AclState {
return a.aclState return a.aclState
} }
func (a *aclList) KeyStorage() crypto.KeyStorage {
return a.keyStorage
}
func (a *aclList) IsAfter(first string, second string) (bool, error) { func (a *aclList) IsAfter(first string, second string) (bool, error) {
firstRec, okFirst := a.indexes[first] firstRec, okFirst := a.indexes[first]
secondRec, okSecond := a.indexes[second] secondRec, okSecond := a.indexes[second]
@ -189,14 +261,27 @@ func (a *aclList) Head() *AclRecord {
return a.records[len(a.records)-1] return a.records[len(a.records)-1]
} }
func (a *aclList) HasHead(head string) bool {
_, exists := a.indexes[head]
return exists
}
func (a *aclList) Get(id string) (*AclRecord, error) { func (a *aclList) Get(id string) (*AclRecord, error) {
recIdx, ok := a.indexes[id] recIdx, ok := a.indexes[id]
if !ok { if !ok {
return nil, fmt.Errorf("no such record") return nil, ErrNoSuchRecord
} }
return a.records[recIdx], nil return a.records[recIdx], nil
} }
func (a *aclList) GetIndex(idx int) (*AclRecord, error) {
// TODO: when we add snapshots we will have to monitor record num in snapshots
if idx < 0 || idx >= len(a.records) {
return nil, ErrNoSuchRecord
}
return a.records[idx], nil
}
func (a *aclList) Iterate(iterFunc IterFunc) { func (a *aclList) Iterate(iterFunc IterFunc) {
for _, rec := range a.records { for _, rec := range a.records {
if !iterFunc(rec) { if !iterFunc(rec) {
@ -205,6 +290,21 @@ func (a *aclList) Iterate(iterFunc IterFunc) {
} }
} }
func (a *aclList) RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error) {
recIdx, ok := a.indexes[id]
if !ok {
return nil, ErrNoSuchRecord
}
for i := recIdx + 1; i < len(a.records); i++ {
rawRec, err := a.storage.GetRawRecord(ctx, a.records[i].Id)
if err != nil {
return nil, err
}
records = append(records, rawRec)
}
return
}
func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) { func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
recIdx, ok := a.indexes[startId] recIdx, ok := a.indexes[startId]
if !ok { if !ok {
@ -217,6 +317,21 @@ func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
} }
} }
func (a *aclList) Close() (err error) { func (a *aclList) Close(ctx context.Context) (err error) {
return nil return nil
} }
func WrapAclRecord(rawRec *consensusproto.RawRecord) *consensusproto.RawRecordWithId {
payload, err := rawRec.Marshal()
if err != nil {
panic(err)
}
id, err := cidutil.NewCidFromBytes(payload)
if err != nil {
panic(err)
}
return &consensusproto.RawRecordWithId{
Payload: payload,
Id: id,
}
}

View File

@ -1,91 +1,293 @@
package list package list
import ( import (
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto" "fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing" "testing"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/stretchr/testify/require"
) )
func TestAclList_AclState_UserInviteAndJoin(t *testing.T) { type aclFixture struct {
st, err := acllistbuilder.NewListStorageWithTestName("userjoinexample.yml") ownerKeys *accountdata.AccountKeys
require.NoError(t, err, "building storage should not result in error") accountKeys *accountdata.AccountKeys
ownerAcl *aclList
keychain := st.(*acllistbuilder.AclListStorageBuilder).GetKeychain() accountAcl *aclList
spaceId string
aclList, err := BuildAclList(st)
require.NoError(t, err, "building acl list should be without error")
idA := keychain.GetIdentity("A")
idB := keychain.GetIdentity("B")
idC := keychain.GetIdentity("C")
// checking final state
assert.Equal(t, aclrecordproto.AclUserPermissions_Admin, aclList.AclState().UserStates()[idA].Permissions)
assert.Equal(t, aclrecordproto.AclUserPermissions_Writer, aclList.AclState().UserStates()[idB].Permissions)
assert.Equal(t, aclrecordproto.AclUserPermissions_Reader, aclList.AclState().UserStates()[idC].Permissions)
assert.Equal(t, aclList.Head().CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash())
var records []*AclRecord
aclList.Iterate(func(record *AclRecord) (IsContinue bool) {
records = append(records, record)
return true
})
// checking permissions at specific records
assert.Equal(t, 3, len(records))
_, err = aclList.AclState().PermissionsAtRecord(records[1].Id, idB)
assert.Error(t, err, "B should have no permissions at record 1")
perm, err := aclList.AclState().PermissionsAtRecord(records[2].Id, idB)
assert.NoError(t, err, "should have no error with permissions of B in the record 2")
assert.Equal(t, UserPermissionPair{
Identity: idB,
Permission: aclrecordproto.AclUserPermissions_Writer,
}, perm)
} }
func TestAclList_AclState_UserJoinAndRemove(t *testing.T) { func newFixture(t *testing.T) *aclFixture {
st, err := acllistbuilder.NewListStorageWithTestName("userremoveexample.yml") ownerKeys, err := accountdata.NewRandom()
require.NoError(t, err, "building storage should not result in error") require.NoError(t, err)
accountKeys, err := accountdata.NewRandom()
keychain := st.(*acllistbuilder.AclListStorageBuilder).GetKeychain() require.NoError(t, err)
spaceId := "spaceId"
aclList, err := BuildAclList(st) ownerAcl, err := NewTestDerivedAcl(spaceId, ownerKeys)
require.NoError(t, err, "building acl list should be without error") require.NoError(t, err)
accountAcl, err := NewTestAclWithRoot(accountKeys, ownerAcl.Root())
idA := keychain.GetIdentity("A") require.NoError(t, err)
idB := keychain.GetIdentity("B") return &aclFixture{
idC := keychain.GetIdentity("C") ownerKeys: ownerKeys,
accountKeys: accountKeys,
// checking final state ownerAcl: ownerAcl.(*aclList),
assert.Equal(t, aclrecordproto.AclUserPermissions_Admin, aclList.AclState().UserStates()[idA].Permissions) accountAcl: accountAcl.(*aclList),
assert.Equal(t, aclrecordproto.AclUserPermissions_Reader, aclList.AclState().UserStates()[idC].Permissions) spaceId: spaceId,
assert.Equal(t, aclList.Head().CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash()) }
}
_, exists := aclList.AclState().UserStates()[idB]
assert.Equal(t, false, exists) func (fx *aclFixture) addRec(t *testing.T, rec *consensusproto.RawRecordWithId) {
err := fx.ownerAcl.AddRawRecord(rec)
var records []*AclRecord require.NoError(t, err)
aclList.Iterate(func(record *AclRecord) (IsContinue bool) { err = fx.accountAcl.AddRawRecord(rec)
records = append(records, record) require.NoError(t, err)
return true }
})
func (fx *aclFixture) inviteAccount(t *testing.T, perms AclPermissions) {
// checking permissions at specific records var (
assert.Equal(t, 4, len(records)) ownerAcl = fx.ownerAcl
ownerState = fx.ownerAcl.aclState
assert.NotEqual(t, records[2].CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash()) accountAcl = fx.accountAcl
accountState = fx.accountAcl.aclState
perm, err := aclList.AclState().PermissionsAtRecord(records[2].Id, idB) )
assert.NoError(t, err, "should have no error with permissions of B in the record 2") // building invite
assert.Equal(t, UserPermissionPair{ inv, err := ownerAcl.RecordBuilder().BuildInvite()
Identity: idB, require.NoError(t, err)
Permission: aclrecordproto.AclUserPermissions_Writer, inviteRec := WrapAclRecord(inv.InviteRec)
}, perm) fx.addRec(t, inviteRec)
_, err = aclList.AclState().PermissionsAtRecord(records[3].Id, idB) // building request join
assert.Error(t, err, "B should have no permissions at record 3, because user should be removed") requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
InviteRecordId: inviteRec.Id,
InviteKey: inv.InviteKey,
})
require.NoError(t, err)
requestJoinRec := WrapAclRecord(requestJoin)
fx.addRec(t, requestJoinRec)
// building request accept
requestAccept, err := ownerAcl.RecordBuilder().BuildRequestAccept(RequestAcceptPayload{
RequestRecordId: requestJoinRec.Id,
Permissions: perms,
})
require.NoError(t, err)
// validate
err = ownerAcl.ValidateRawRecord(requestAccept)
require.NoError(t, err)
requestAcceptRec := WrapAclRecord(requestAccept)
fx.addRec(t, requestAcceptRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).CanWrite())
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).CanWrite())
_, err = ownerState.StateAtRecord(requestJoinRec.Id, accountState.pubKey)
require.Equal(t, ErrNoSuchAccount, err)
stateAtRec, err := ownerState.StateAtRecord(requestAcceptRec.Id, accountState.pubKey)
require.NoError(t, err)
require.True(t, stateAtRec.Permissions == perms)
}
func TestAclList_BuildRoot(t *testing.T) {
randomKeys, err := accountdata.NewRandom()
require.NoError(t, err)
randomAcl, err := NewTestDerivedAcl("spaceId", randomKeys)
require.NoError(t, err)
fmt.Println(randomAcl.Id())
}
func TestAclList_InvitePipeline(t *testing.T) {
fx := newFixture(t)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
}
func TestAclList_InviteRevoke(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
// building invite
inv, err := fx.ownerAcl.RecordBuilder().BuildInvite()
require.NoError(t, err)
inviteRec := WrapAclRecord(inv.InviteRec)
fx.addRec(t, inviteRec)
// building invite revoke
inviteRevoke, err := fx.ownerAcl.RecordBuilder().BuildInviteRevoke(ownerState.lastRecordId)
require.NoError(t, err)
inviteRevokeRec := WrapAclRecord(inviteRevoke)
fx.addRec(t, inviteRevokeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.Empty(t, ownerState.inviteKeys)
require.Empty(t, accountState.inviteKeys)
}
func TestAclList_RequestDecline(t *testing.T) {
fx := newFixture(t)
var (
ownerAcl = fx.ownerAcl
ownerState = fx.ownerAcl.aclState
accountAcl = fx.accountAcl
accountState = fx.accountAcl.aclState
)
// building invite
inv, err := ownerAcl.RecordBuilder().BuildInvite()
require.NoError(t, err)
inviteRec := WrapAclRecord(inv.InviteRec)
fx.addRec(t, inviteRec)
// building request join
requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
InviteRecordId: inviteRec.Id,
InviteKey: inv.InviteKey,
})
require.NoError(t, err)
requestJoinRec := WrapAclRecord(requestJoin)
fx.addRec(t, requestJoinRec)
// building request decline
requestDecline, err := ownerAcl.RecordBuilder().BuildRequestDecline(ownerState.lastRecordId)
require.NoError(t, err)
requestDeclineRec := WrapAclRecord(requestDecline)
fx.addRec(t, requestDeclineRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.Empty(t, ownerState.pendingRequests)
require.Empty(t, accountState.pendingRequests)
}
func TestAclList_Remove(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
newReadKey := crypto.NewAES()
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
Identities: []crypto.PubKey{fx.accountKeys.SignKey.GetPublic()},
ReadKey: newReadKey,
})
require.NoError(t, err)
removeRec := WrapAclRecord(remove)
fx.addRec(t, removeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
require.Nil(t, accountState.userReadKeys[removeRec.Id])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
}
func TestAclList_ReadKeyChange(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
newReadKey := crypto.NewAES()
readKeyChange, err := fx.ownerAcl.RecordBuilder().BuildReadKeyChange(newReadKey)
require.NoError(t, err)
readKeyRec := WrapAclRecord(readKeyChange)
fx.addRec(t, readKeyRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).CanManageAccounts())
require.True(t, ownerState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
require.True(t, accountState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
readKey, err := ownerState.CurrentReadKey()
require.NoError(t, err)
require.True(t, newReadKey.Equals(readKey))
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
}
func TestAclList_PermissionChange(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
permissionChange, err := fx.ownerAcl.RecordBuilder().BuildPermissionChange(PermissionChangePayload{
Identity: fx.accountKeys.SignKey.GetPublic(),
Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Writer),
})
require.NoError(t, err)
permissionChangeRec := WrapAclRecord(permissionChange)
fx.addRec(t, permissionChangeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
}
func TestAclList_RequestRemove(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
removeRequest, err := fx.accountAcl.RecordBuilder().BuildRequestRemove()
require.NoError(t, err)
removeRequestRec := WrapAclRecord(removeRequest)
fx.addRec(t, removeRequestRec)
recs := fx.accountAcl.AclState().RemoveRecords()
require.Len(t, recs, 1)
require.True(t, accountState.pubKey.Equals(recs[0].RequestIdentity))
newReadKey := crypto.NewAES()
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
Identities: []crypto.PubKey{recs[0].RequestIdentity},
ReadKey: newReadKey,
})
require.NoError(t, err)
removeRec := WrapAclRecord(remove)
fx.addRec(t, removeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
require.Nil(t, accountState.userReadKeys[removeRec.Id])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
} }

View File

@ -0,0 +1,41 @@
package list
import (
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/crypto"
)
func NewTestDerivedAcl(spaceId string, keys *accountdata.AccountKeys) (AclList, error) {
builder := NewAclRecordBuilder("", crypto.NewKeyStorage(), keys, NoOpAcceptorVerifier{})
masterKey, _, err := crypto.GenerateRandomEd25519KeyPair()
if err != nil {
return nil, err
}
root, err := builder.BuildRoot(RootContent{
PrivKey: keys.SignKey,
SpaceId: spaceId,
MasterKey: masterKey,
})
if err != nil {
return nil, err
}
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
root,
})
if err != nil {
return nil, err
}
return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
}
func NewTestAclWithRoot(keys *accountdata.AccountKeys, root *consensusproto.RawRecordWithId) (AclList, error) {
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
root,
})
if err != nil {
return nil, err
}
return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
}

View File

@ -1,15 +1,17 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/commonspace/object/acl/list (interfaces: AclList) // Source: github.com/anyproto/any-sync/commonspace/object/acl/list (interfaces: AclList)
// Package mock_list is a generated GoMock package. // Package mock_list is a generated GoMock package.
package mock_list package mock_list
import ( import (
context "context"
reflect "reflect" reflect "reflect"
aclrecordproto "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto" list "github.com/anyproto/any-sync/commonspace/object/acl/list"
list "github.com/anytypeio/any-sync/commonspace/object/acl/list" consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
gomock "github.com/golang/mock/gomock" crypto "github.com/anyproto/any-sync/util/crypto"
gomock "go.uber.org/mock/gomock"
) )
// MockAclList is a mock of AclList interface. // MockAclList is a mock of AclList interface.
@ -50,12 +52,11 @@ func (mr *MockAclListMockRecorder) AclState() *gomock.Call {
} }
// AddRawRecord mocks base method. // AddRawRecord mocks base method.
func (m *MockAclList) AddRawRecord(arg0 *aclrecordproto.RawAclRecordWithId) (bool, error) { func (m *MockAclList) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0) ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// AddRawRecord indicates an expected call of AddRawRecord. // AddRawRecord indicates an expected call of AddRawRecord.
@ -64,18 +65,32 @@ func (mr *MockAclListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0)
} }
// Close mocks base method. // AddRawRecords mocks base method.
func (m *MockAclList) Close() error { func (m *MockAclList) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close") ret := m.ctrl.Call(m, "AddRawRecords", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecords indicates an expected call of AddRawRecords.
func (mr *MockAclListMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockAclList)(nil).AddRawRecords), arg0)
}
// Close mocks base method.
func (m *MockAclList) Close(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close.
func (mr *MockAclListMockRecorder) Close() *gomock.Call { func (mr *MockAclListMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close), arg0)
} }
// Get mocks base method. // Get mocks base method.
@ -93,6 +108,35 @@ func (mr *MockAclListMockRecorder) Get(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAclList)(nil).Get), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAclList)(nil).Get), arg0)
} }
// GetIndex mocks base method.
func (m *MockAclList) GetIndex(arg0 int) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIndex", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetIndex indicates an expected call of GetIndex.
func (mr *MockAclListMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockAclList)(nil).GetIndex), arg0)
}
// HasHead mocks base method.
func (m *MockAclList) HasHead(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasHead", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// HasHead indicates an expected call of HasHead.
func (mr *MockAclListMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockAclList)(nil).HasHead), arg0)
}
// Head mocks base method. // Head mocks base method.
func (m *MockAclList) Head() *list.AclRecord { func (m *MockAclList) Head() *list.AclRecord {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -160,6 +204,20 @@ func (mr *MockAclListMockRecorder) IterateFrom(arg0, arg1 interface{}) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockAclList)(nil).IterateFrom), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockAclList)(nil).IterateFrom), arg0, arg1)
} }
// KeyStorage mocks base method.
func (m *MockAclList) KeyStorage() crypto.KeyStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyStorage")
ret0, _ := ret[0].(crypto.KeyStorage)
return ret0
}
// KeyStorage indicates an expected call of KeyStorage.
func (mr *MockAclListMockRecorder) KeyStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyStorage", reflect.TypeOf((*MockAclList)(nil).KeyStorage))
}
// Lock mocks base method. // Lock mocks base method.
func (m *MockAclList) Lock() { func (m *MockAclList) Lock() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -196,6 +254,20 @@ func (mr *MockAclListMockRecorder) RUnlock() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockAclList)(nil).RUnlock)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockAclList)(nil).RUnlock))
} }
// RecordBuilder mocks base method.
func (m *MockAclList) RecordBuilder() list.AclRecordBuilder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordBuilder")
ret0, _ := ret[0].(list.AclRecordBuilder)
return ret0
}
// RecordBuilder indicates an expected call of RecordBuilder.
func (mr *MockAclListMockRecorder) RecordBuilder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockAclList)(nil).RecordBuilder))
}
// Records mocks base method. // Records mocks base method.
func (m *MockAclList) Records() []*list.AclRecord { func (m *MockAclList) Records() []*list.AclRecord {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -210,11 +282,26 @@ func (mr *MockAclListMockRecorder) Records() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records))
} }
// RecordsAfter mocks base method.
func (m *MockAclList) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RecordsAfter indicates an expected call of RecordsAfter.
func (mr *MockAclListMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockAclList)(nil).RecordsAfter), arg0, arg1)
}
// Root mocks base method. // Root mocks base method.
func (m *MockAclList) Root() *aclrecordproto.RawAclRecordWithId { func (m *MockAclList) Root() *consensusproto.RawRecordWithId {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root") ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
return ret0 return ret0
} }
@ -235,3 +322,17 @@ func (mr *MockAclListMockRecorder) Unlock() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockAclList)(nil).Unlock)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockAclList)(nil).Unlock))
} }
// ValidateRawRecord mocks base method.
func (m *MockAclList) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
func (mr *MockAclListMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockAclList)(nil).ValidateRawRecord), arg0)
}

View File

@ -0,0 +1,69 @@
package list
import (
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/util/crypto"
)
type AclRecord struct {
Id string
PrevId string
Timestamp int64
Data []byte
Identity crypto.PubKey
Model interface{}
Signature []byte
}
type RequestRecord struct {
RequestIdentity crypto.PubKey
RequestMetadata []byte
Type RequestType
}
type AclUserState struct {
PubKey crypto.PubKey
Permissions AclPermissions
RequestMetadata []byte
}
type RequestType int
const (
RequestTypeRemove RequestType = iota
RequestTypeJoin
)
type AclPermissions aclrecordproto.AclUserPermissions
func (p AclPermissions) NoPermissions() bool {
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_None
}
func (p AclPermissions) IsOwner() bool {
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_Owner
}
func (p AclPermissions) CanWrite() bool {
switch aclrecordproto.AclUserPermissions(p) {
case aclrecordproto.AclUserPermissions_Admin:
return true
case aclrecordproto.AclUserPermissions_Writer:
return true
case aclrecordproto.AclUserPermissions_Owner:
return true
default:
return false
}
}
func (p AclPermissions) CanManageAccounts() bool {
switch aclrecordproto.AclUserPermissions(p) {
case aclrecordproto.AclUserPermissions_Admin:
return true
case aclrecordproto.AclUserPermissions_Owner:
return true
default:
return false
}
}

View File

@ -1,12 +0,0 @@
package list
type AclRecord struct {
Id string
PrevId string
CurrentReadKeyHash uint64
Timestamp int64
Data []byte
Identity []byte
Model interface{}
Signature []byte
}

View File

@ -0,0 +1,218 @@
package list
import (
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/util/crypto"
)
type ContentValidator interface {
ValidateAclRecordContents(ch *AclRecord) (err error)
ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error)
ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error)
ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error)
ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error)
ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error)
ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error)
ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error)
ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error)
ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error)
}
type contentValidator struct {
keyStore crypto.KeyStorage
aclState *AclState
}
func (c *contentValidator) ValidateAclRecordContents(ch *AclRecord) (err error) {
if ch.PrevId != c.aclState.lastRecordId {
return ErrIncorrectRecordSequence
}
aclData := ch.Model.(*aclrecordproto.AclData)
for _, content := range aclData.AclContent {
err = c.validateAclRecordContent(content, ch.Identity)
if err != nil {
return
}
}
return
}
func (c *contentValidator) validateAclRecordContent(ch *aclrecordproto.AclContentValue, authorIdentity crypto.PubKey) (err error) {
switch {
case ch.GetPermissionChange() != nil:
return c.ValidatePermissionChange(ch.GetPermissionChange(), authorIdentity)
case ch.GetInvite() != nil:
return c.ValidateInvite(ch.GetInvite(), authorIdentity)
case ch.GetInviteRevoke() != nil:
return c.ValidateInviteRevoke(ch.GetInviteRevoke(), authorIdentity)
case ch.GetRequestJoin() != nil:
return c.ValidateRequestJoin(ch.GetRequestJoin(), authorIdentity)
case ch.GetRequestAccept() != nil:
return c.ValidateRequestAccept(ch.GetRequestAccept(), authorIdentity)
case ch.GetRequestDecline() != nil:
return c.ValidateRequestDecline(ch.GetRequestDecline(), authorIdentity)
case ch.GetAccountRemove() != nil:
return c.ValidateAccountRemove(ch.GetAccountRemove(), authorIdentity)
case ch.GetAccountRequestRemove() != nil:
return c.ValidateRequestRemove(ch.GetAccountRequestRemove(), authorIdentity)
case ch.GetReadKeyChange() != nil:
return c.ValidateReadKeyChange(ch.GetReadKeyChange(), authorIdentity)
default:
return ErrUnexpectedContentType
}
}
func (c *contentValidator) ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
chIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return err
}
_, exists := c.aclState.userStates[mapKeyFromPubKey(chIdentity)]
if !exists {
return ErrNoSuchAccount
}
return
}
func (c *contentValidator) ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, err = c.keyStore.PubKeyFromProto(ch.InviteKey)
return
}
func (c *contentValidator) ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, exists := c.aclState.inviteKeys[ch.InviteRecordId]
if !exists {
return ErrNoSuchInvite
}
return
}
func (c *contentValidator) ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error) {
inviteKey, exists := c.aclState.inviteKeys[ch.InviteRecordId]
if !exists {
return ErrNoSuchInvite
}
inviteIdentity, err := c.keyStore.PubKeyFromProto(ch.InviteIdentity)
if err != nil {
return
}
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(inviteIdentity)]; exists {
return ErrPendingRequest
}
if !authorIdentity.Equals(inviteIdentity) {
return ErrIncorrectIdentity
}
rawInviteIdentity, err := inviteIdentity.Raw()
if err != nil {
return err
}
ok, err := inviteKey.Verify(rawInviteIdentity, ch.InviteIdentitySignature)
if err != nil {
return ErrInvalidSignature
}
if !ok {
return ErrInvalidSignature
}
return
}
func (c *contentValidator) ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
record, exists := c.aclState.requestRecords[ch.RequestRecordId]
if !exists {
return ErrNoSuchRequest
}
acceptIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return
}
if !acceptIdentity.Equals(record.RequestIdentity) {
return ErrIncorrectIdentity
}
if ch.Permissions == aclrecordproto.AclUserPermissions_Owner {
return ErrInsufficientPermissions
}
return
}
func (c *contentValidator) ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, exists := c.aclState.requestRecords[ch.RequestRecordId]
if !exists {
return ErrNoSuchRequest
}
return
}
func (c *contentValidator) ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
seenIdentities := map[string]struct{}{}
for _, rawIdentity := range ch.Identities {
identity, err := c.keyStore.PubKeyFromProto(rawIdentity)
if err != nil {
return err
}
if identity.Equals(authorIdentity) {
return ErrInsufficientPermissions
}
permissions := c.aclState.Permissions(identity)
if permissions.NoPermissions() {
return ErrNoSuchAccount
}
if permissions.IsOwner() {
return ErrInsufficientPermissions
}
idKey := mapKeyFromPubKey(identity)
if _, exists := seenIdentities[idKey]; exists {
return ErrDuplicateAccounts
}
seenIdentities[mapKeyFromPubKey(identity)] = struct{}{}
}
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates)-len(ch.Identities))
}
func (c *contentValidator) ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error) {
if c.aclState.Permissions(authorIdentity).NoPermissions() {
return ErrInsufficientPermissions
}
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(authorIdentity)]; exists {
return ErrPendingRequest
}
return
}
func (c *contentValidator) ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error) {
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates))
}
func (c *contentValidator) validateAccountReadKeys(accountKeys []*aclrecordproto.AclEncryptedReadKey, usersNum int) (err error) {
if len(accountKeys) != usersNum {
return ErrIncorrectNumberOfAccounts
}
for _, encKeys := range accountKeys {
identity, err := c.keyStore.PubKeyFromProto(encKeys.Identity)
if err != nil {
return err
}
_, exists := c.aclState.userStates[mapKeyFromPubKey(identity)]
if !exists {
return ErrNoSuchAccount
}
}
return
}

View File

@ -3,24 +3,26 @@ package liststorage
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"sync" "sync"
) )
type inMemoryAclListStorage struct { type inMemoryAclListStorage struct {
id string id string
root *aclrecordproto.RawAclRecordWithId root *consensusproto.RawRecordWithId
head string head string
records map[string]*aclrecordproto.RawAclRecordWithId records map[string]*consensusproto.RawRecordWithId
sync.RWMutex sync.RWMutex
} }
func NewInMemoryAclListStorage( func NewInMemoryAclListStorage(
id string, id string,
records []*aclrecordproto.RawAclRecordWithId) (ListStorage, error) { records []*consensusproto.RawRecordWithId) (ListStorage, error) {
allRecords := make(map[string]*aclrecordproto.RawAclRecordWithId) allRecords := make(map[string]*consensusproto.RawRecordWithId)
for _, ch := range records { for _, ch := range records {
allRecords[ch.Id] = ch allRecords[ch.Id] = ch
} }
@ -41,7 +43,7 @@ func (t *inMemoryAclListStorage) Id() string {
return t.id return t.id
} }
func (t *inMemoryAclListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) { func (t *inMemoryAclListStorage) Root() (*consensusproto.RawRecordWithId, error) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
return t.root, nil return t.root, nil
@ -60,7 +62,7 @@ func (t *inMemoryAclListStorage) SetHead(head string) error {
return nil return nil
} }
func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclrecordproto.RawAclRecordWithId) error { func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *consensusproto.RawRecordWithId) error {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
// TODO: better to do deep copy // TODO: better to do deep copy
@ -68,7 +70,7 @@ func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclre
return nil return nil
} }
func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*aclrecordproto.RawAclRecordWithId, error) { func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*consensusproto.RawRecordWithId, error) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
if res, exists := t.records[recordId]; exists { if res, exists := t.records[recordId]; exists {

View File

@ -1,10 +1,11 @@
//go:generate mockgen -destination mock_liststorage/mock_liststorage.go github.com/anytypeio/any-sync/commonspace/object/acl/liststorage ListStorage //go:generate mockgen -destination mock_liststorage/mock_liststorage.go github.com/anyproto/any-sync/commonspace/object/acl/liststorage ListStorage
package liststorage package liststorage
import ( import (
"context" "context"
"errors" "errors"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
) )
var ( var (
@ -13,12 +14,16 @@ var (
ErrUnknownRecord = errors.New("record doesn't exist") ErrUnknownRecord = errors.New("record doesn't exist")
) )
type Exporter interface {
ListStorage(root *consensusproto.RawRecordWithId) (ListStorage, error)
}
type ListStorage interface { type ListStorage interface {
Id() string Id() string
Root() (*aclrecordproto.RawAclRecordWithId, error) Root() (*consensusproto.RawRecordWithId, error)
Head() (string, error) Head() (string, error)
SetHead(headId string) error SetHead(headId string) error
GetRawRecord(ctx context.Context, id string) (*aclrecordproto.RawAclRecordWithId, error) GetRawRecord(ctx context.Context, id string) (*consensusproto.RawRecordWithId, error)
AddRawRecord(ctx context.Context, rec *aclrecordproto.RawAclRecordWithId) error AddRawRecord(ctx context.Context, rec *consensusproto.RawRecordWithId) error
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/commonspace/object/acl/liststorage (interfaces: ListStorage) // Source: github.com/anyproto/any-sync/commonspace/object/acl/liststorage (interfaces: ListStorage)
// Package mock_liststorage is a generated GoMock package. // Package mock_liststorage is a generated GoMock package.
package mock_liststorage package mock_liststorage
@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
aclrecordproto "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto" consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockListStorage is a mock of ListStorage interface. // MockListStorage is a mock of ListStorage interface.
@ -36,7 +36,7 @@ func (m *MockListStorage) EXPECT() *MockListStorageMockRecorder {
} }
// AddRawRecord mocks base method. // AddRawRecord mocks base method.
func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *aclrecordproto.RawAclRecordWithId) error { func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0, arg1) ret := m.ctrl.Call(m, "AddRawRecord", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -50,10 +50,10 @@ func (mr *MockListStorageMockRecorder) AddRawRecord(arg0, arg1 interface{}) *gom
} }
// GetRawRecord mocks base method. // GetRawRecord mocks base method.
func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*aclrecordproto.RawAclRecordWithId, error) { func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRawRecord", arg0, arg1) ret := m.ctrl.Call(m, "GetRawRecord", arg0, arg1)
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -94,10 +94,10 @@ func (mr *MockListStorageMockRecorder) Id() *gomock.Call {
} }
// Root mocks base method. // Root mocks base method.
func (m *MockListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) { func (m *MockListStorage) Root() (*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root") ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View File

@ -0,0 +1,120 @@
//go:generate mockgen -destination mock_syncacl/mock_syncacl.go github.com/anyproto/any-sync/commonspace/object/acl/syncacl SyncAcl,SyncClient,RequestFactory,AclSyncProtocol
package syncacl
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/consensus/consensusproto"
"go.uber.org/zap"
)
type AclSyncProtocol interface {
HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error)
FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error)
FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error)
}
type aclSyncProtocol struct {
log logger.CtxLogger
spaceId string
aclList list.AclList
reqFactory RequestFactory
}
func (a *aclSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error) {
isEmptyUpdate := len(update.Records) == 0
log := a.log.With(
zap.String("senderId", senderId),
zap.String("update head", update.Head),
zap.Int("len(update records)", len(update.Records)))
log.DebugCtx(ctx, "received acl head update message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl head update finished with error", zap.Error(err))
} else if request != nil {
cnt := request.Content.GetFullSyncRequest()
log.DebugCtx(ctx, "returning acl full sync request", zap.String("request head", cnt.Head))
} else {
if !isEmptyUpdate {
log.DebugCtx(ctx, "acl head update finished correctly")
}
}
}()
if isEmptyUpdate {
headEquals := a.aclList.Head().Id == update.Head
log.DebugCtx(ctx, "is empty acl head update", zap.Bool("headEquals", headEquals))
if headEquals {
return
}
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
}
if a.aclList.HasHead(update.Head) {
return
}
err = a.aclList.AddRawRecords(update.Records)
if err == list.ErrIncorrectRecordSequence {
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
}
return
}
func (a *aclSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error) {
log := a.log.With(
zap.String("senderId", senderId),
zap.String("request head", request.Head),
zap.Int("len(request records)", len(request.Records)))
log.DebugCtx(ctx, "received acl full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl full sync request finished with error", zap.Error(err))
} else if response != nil {
cnt := response.Content.GetFullSyncResponse()
log.DebugCtx(ctx, "acl full sync response sent", zap.String("response head", cnt.Head), zap.Int("len(response records)", len(cnt.Records)))
}
}()
if !a.aclList.HasHead(request.Head) {
if len(request.Records) > 0 {
// in this case we can try to add some records
err = a.aclList.AddRawRecords(request.Records)
if err != nil {
return
}
} else {
// here it is impossible for us to do anything, we can't return records after head as defined in request, because we don't have it
return nil, list.ErrIncorrectRecordSequence
}
}
return a.reqFactory.CreateFullSyncResponse(a.aclList, request.Head)
}
func (a *aclSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error) {
log := a.log.With(
zap.String("senderId", senderId),
zap.String("response head", response.Head),
zap.Int("len(response records)", len(response.Records)))
log.DebugCtx(ctx, "received acl full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "acl full sync response succeeded")
}
}()
if a.aclList.HasHead(response.Head) {
return
}
return a.aclList.AddRawRecords(response.Records)
}
func newAclSyncProtocol(spaceId string, aclList list.AclList, reqFactory RequestFactory) *aclSyncProtocol {
return &aclSyncProtocol{
log: log.With(zap.String("spaceId", spaceId), zap.String("aclId", aclList.Id())),
spaceId: spaceId,
aclList: aclList,
reqFactory: reqFactory,
}
}

View File

@ -0,0 +1,213 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing"
)
type aclSyncProtocolFixture struct {
log logger.CtxLogger
spaceId string
senderId string
aclId string
aclMock *mock_list.MockAclList
reqFactory *mock_syncacl.MockRequestFactory
ctrl *gomock.Controller
syncProtocol AclSyncProtocol
}
func newSyncProtocolFixture(t *testing.T) *aclSyncProtocolFixture {
ctrl := gomock.NewController(t)
aclList := mock_list.NewMockAclList(ctrl)
spaceId := "spaceId"
reqFactory := mock_syncacl.NewMockRequestFactory(ctrl)
aclList.EXPECT().Id().Return("aclId")
syncProtocol := newAclSyncProtocol(spaceId, aclList, reqFactory)
return &aclSyncProtocolFixture{
log: log,
spaceId: spaceId,
senderId: "senderId",
aclId: "aclId",
aclMock: aclList,
reqFactory: reqFactory,
ctrl: ctrl,
syncProtocol: syncProtocol,
}
}
func (fx *aclSyncProtocolFixture) stop() {
fx.ctrl.Finish()
}
func TestHeadUpdate(t *testing.T) {
ctx := context.Background()
fullRequest := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncRequest{
FullSyncRequest: &consensusproto.LogFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update results in full request", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(list.ErrIncorrectRecordSequence)
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Equal(t, fullRequest, req)
require.NoError(t, err)
})
t.Run("head update old heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update empty equals", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h1"})
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update empty results in full request", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h2"})
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Equal(t, fullRequest, req)
require.NoError(t, err)
})
}
func TestFullSyncRequest(t *testing.T) {
ctx := context.Background()
fullResponse := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
FullSyncResponse: &consensusproto.LogFullSyncResponse{},
},
},
}
t.Run("full sync request non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(fullRequest.Records).Return(nil)
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Equal(t, fullResponse, resp)
require.NoError(t, err)
})
t.Run("full sync request non empty head exists", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Equal(t, fullResponse, resp)
require.NoError(t, err)
})
t.Run("full sync request empty head not exists", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Nil(t, resp)
require.Error(t, list.ErrIncorrectRecordSequence, err)
})
}
func TestFullSyncResponse(t *testing.T) {
ctx := context.Background()
t.Run("full sync response no heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(fullResponse.Records).Return(nil)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
require.NoError(t, err)
})
t.Run("full sync response has heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
require.NoError(t, err)
})
}

View File

@ -0,0 +1,5 @@
package headupdater
type HeadUpdater interface {
UpdateHeads(id string, heads []string)
}

View File

@ -0,0 +1,694 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/object/acl/syncacl (interfaces: SyncAcl,SyncClient,RequestFactory,AclSyncProtocol)
// Package mock_syncacl is a generated GoMock package.
package mock_syncacl
import (
context "context"
reflect "reflect"
app "github.com/anyproto/any-sync/app"
list "github.com/anyproto/any-sync/commonspace/object/acl/list"
headupdater "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
crypto "github.com/anyproto/any-sync/util/crypto"
gomock "go.uber.org/mock/gomock"
)
// MockSyncAcl is a mock of SyncAcl interface.
type MockSyncAcl struct {
ctrl *gomock.Controller
recorder *MockSyncAclMockRecorder
}
// MockSyncAclMockRecorder is the mock recorder for MockSyncAcl.
type MockSyncAclMockRecorder struct {
mock *MockSyncAcl
}
// NewMockSyncAcl creates a new mock instance.
func NewMockSyncAcl(ctrl *gomock.Controller) *MockSyncAcl {
mock := &MockSyncAcl{ctrl: ctrl}
mock.recorder = &MockSyncAclMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncAcl) EXPECT() *MockSyncAclMockRecorder {
return m.recorder
}
// AclState mocks base method.
func (m *MockSyncAcl) AclState() *list.AclState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AclState")
ret0, _ := ret[0].(*list.AclState)
return ret0
}
// AclState indicates an expected call of AclState.
func (mr *MockSyncAclMockRecorder) AclState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AclState", reflect.TypeOf((*MockSyncAcl)(nil).AclState))
}
// AddRawRecord mocks base method.
func (m *MockSyncAcl) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecord indicates an expected call of AddRawRecord.
func (mr *MockSyncAclMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecord), arg0)
}
// AddRawRecords mocks base method.
func (m *MockSyncAcl) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecords", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecords indicates an expected call of AddRawRecords.
func (mr *MockSyncAclMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecords), arg0)
}
// Close mocks base method.
func (m *MockSyncAcl) Close(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockSyncAclMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSyncAcl)(nil).Close), arg0)
}
// Get mocks base method.
func (m *MockSyncAcl) Get(arg0 string) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockSyncAclMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSyncAcl)(nil).Get), arg0)
}
// GetIndex mocks base method.
func (m *MockSyncAcl) GetIndex(arg0 int) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIndex", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetIndex indicates an expected call of GetIndex.
func (mr *MockSyncAclMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockSyncAcl)(nil).GetIndex), arg0)
}
// HandleMessage mocks base method.
func (m *MockSyncAcl) HandleMessage(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HandleMessage indicates an expected call of HandleMessage.
func (mr *MockSyncAclMockRecorder) HandleMessage(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncAcl)(nil).HandleMessage), arg0, arg1, arg2)
}
// HandleRequest mocks base method.
func (m *MockSyncAcl) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncAclMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncAcl)(nil).HandleRequest), arg0, arg1, arg2)
}
// HasHead mocks base method.
func (m *MockSyncAcl) HasHead(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasHead", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// HasHead indicates an expected call of HasHead.
func (mr *MockSyncAclMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockSyncAcl)(nil).HasHead), arg0)
}
// Head mocks base method.
func (m *MockSyncAcl) Head() *list.AclRecord {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Head")
ret0, _ := ret[0].(*list.AclRecord)
return ret0
}
// Head indicates an expected call of Head.
func (mr *MockSyncAclMockRecorder) Head() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Head", reflect.TypeOf((*MockSyncAcl)(nil).Head))
}
// Id mocks base method.
func (m *MockSyncAcl) Id() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Id")
ret0, _ := ret[0].(string)
return ret0
}
// Id indicates an expected call of Id.
func (mr *MockSyncAclMockRecorder) Id() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSyncAcl)(nil).Id))
}
// Init mocks base method.
func (m *MockSyncAcl) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockSyncAclMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockSyncAcl)(nil).Init), arg0)
}
// IsAfter mocks base method.
func (m *MockSyncAcl) IsAfter(arg0, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsAfter", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsAfter indicates an expected call of IsAfter.
func (mr *MockSyncAclMockRecorder) IsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAfter", reflect.TypeOf((*MockSyncAcl)(nil).IsAfter), arg0, arg1)
}
// Iterate mocks base method.
func (m *MockSyncAcl) Iterate(arg0 func(*list.AclRecord) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Iterate", arg0)
}
// Iterate indicates an expected call of Iterate.
func (mr *MockSyncAclMockRecorder) Iterate(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockSyncAcl)(nil).Iterate), arg0)
}
// IterateFrom mocks base method.
func (m *MockSyncAcl) IterateFrom(arg0 string, arg1 func(*list.AclRecord) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "IterateFrom", arg0, arg1)
}
// IterateFrom indicates an expected call of IterateFrom.
func (mr *MockSyncAclMockRecorder) IterateFrom(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockSyncAcl)(nil).IterateFrom), arg0, arg1)
}
// KeyStorage mocks base method.
func (m *MockSyncAcl) KeyStorage() crypto.KeyStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyStorage")
ret0, _ := ret[0].(crypto.KeyStorage)
return ret0
}
// KeyStorage indicates an expected call of KeyStorage.
func (mr *MockSyncAclMockRecorder) KeyStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyStorage", reflect.TypeOf((*MockSyncAcl)(nil).KeyStorage))
}
// Lock mocks base method.
func (m *MockSyncAcl) Lock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Lock")
}
// Lock indicates an expected call of Lock.
func (mr *MockSyncAclMockRecorder) Lock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockSyncAcl)(nil).Lock))
}
// Name mocks base method.
func (m *MockSyncAcl) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockSyncAclMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockSyncAcl)(nil).Name))
}
// RLock mocks base method.
func (m *MockSyncAcl) RLock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RLock")
}
// RLock indicates an expected call of RLock.
func (mr *MockSyncAclMockRecorder) RLock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RLock", reflect.TypeOf((*MockSyncAcl)(nil).RLock))
}
// RUnlock mocks base method.
func (m *MockSyncAcl) RUnlock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RUnlock")
}
// RUnlock indicates an expected call of RUnlock.
func (mr *MockSyncAclMockRecorder) RUnlock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockSyncAcl)(nil).RUnlock))
}
// RecordBuilder mocks base method.
func (m *MockSyncAcl) RecordBuilder() list.AclRecordBuilder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordBuilder")
ret0, _ := ret[0].(list.AclRecordBuilder)
return ret0
}
// RecordBuilder indicates an expected call of RecordBuilder.
func (mr *MockSyncAclMockRecorder) RecordBuilder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockSyncAcl)(nil).RecordBuilder))
}
// Records mocks base method.
func (m *MockSyncAcl) Records() []*list.AclRecord {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Records")
ret0, _ := ret[0].([]*list.AclRecord)
return ret0
}
// Records indicates an expected call of Records.
func (mr *MockSyncAclMockRecorder) Records() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockSyncAcl)(nil).Records))
}
// RecordsAfter mocks base method.
func (m *MockSyncAcl) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RecordsAfter indicates an expected call of RecordsAfter.
func (mr *MockSyncAclMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockSyncAcl)(nil).RecordsAfter), arg0, arg1)
}
// Root mocks base method.
func (m *MockSyncAcl) Root() *consensusproto.RawRecordWithId {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
return ret0
}
// Root indicates an expected call of Root.
func (mr *MockSyncAclMockRecorder) Root() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Root", reflect.TypeOf((*MockSyncAcl)(nil).Root))
}
// Run mocks base method.
func (m *MockSyncAcl) Run(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
func (mr *MockSyncAclMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSyncAcl)(nil).Run), arg0)
}
// SetHeadUpdater mocks base method.
func (m *MockSyncAcl) SetHeadUpdater(arg0 headupdater.HeadUpdater) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHeadUpdater", arg0)
}
// SetHeadUpdater indicates an expected call of SetHeadUpdater.
func (mr *MockSyncAclMockRecorder) SetHeadUpdater(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeadUpdater", reflect.TypeOf((*MockSyncAcl)(nil).SetHeadUpdater), arg0)
}
// SyncWithPeer mocks base method.
func (m *MockSyncAcl) SyncWithPeer(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncWithPeer", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SyncWithPeer indicates an expected call of SyncWithPeer.
func (mr *MockSyncAclMockRecorder) SyncWithPeer(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncWithPeer", reflect.TypeOf((*MockSyncAcl)(nil).SyncWithPeer), arg0, arg1)
}
// Unlock mocks base method.
func (m *MockSyncAcl) Unlock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Unlock")
}
// Unlock indicates an expected call of Unlock.
func (mr *MockSyncAclMockRecorder) Unlock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockSyncAcl)(nil).Unlock))
}
// ValidateRawRecord mocks base method.
func (m *MockSyncAcl) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
func (mr *MockSyncAclMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).ValidateRawRecord), arg0)
}
// MockSyncClient is a mock of SyncClient interface.
type MockSyncClient struct {
ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder
}
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient.
type MockSyncClientMockRecorder struct {
mock *MockSyncClient
}
// NewMockSyncClient creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient {
mock := &MockSyncClient{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder {
return m.recorder
}
// Broadcast mocks base method.
func (m *MockSyncClient) Broadcast(arg0 *consensusproto.LogSyncMessage) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0)
}
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1)
}
// QueueRequest mocks base method.
func (m *MockSyncClient) QueueRequest(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// QueueRequest indicates an expected call of QueueRequest.
func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1)
}
// SendRequest mocks base method.
func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendRequest indicates an expected call of SendRequest.
func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2)
}
// SendUpdate mocks base method.
func (m *MockSyncClient) SendUpdate(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SendUpdate indicates an expected call of SendUpdate.
func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1)
}
// MockRequestFactory is a mock of RequestFactory interface.
type MockRequestFactory struct {
ctrl *gomock.Controller
recorder *MockRequestFactoryMockRecorder
}
// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory.
type MockRequestFactoryMockRecorder struct {
mock *MockRequestFactory
}
// NewMockRequestFactory creates a new mock instance.
func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory {
mock := &MockRequestFactory{ctrl: ctrl}
mock.recorder = &MockRequestFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder {
return m.recorder
}
// CreateFullSyncRequest mocks base method.
func (m *MockRequestFactory) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1)
}
// CreateFullSyncResponse mocks base method.
func (m *MockRequestFactory) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1)
}
// CreateHeadUpdate mocks base method.
func (m *MockRequestFactory) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1)
}
// MockAclSyncProtocol is a mock of AclSyncProtocol interface.
type MockAclSyncProtocol struct {
ctrl *gomock.Controller
recorder *MockAclSyncProtocolMockRecorder
}
// MockAclSyncProtocolMockRecorder is the mock recorder for MockAclSyncProtocol.
type MockAclSyncProtocolMockRecorder struct {
mock *MockAclSyncProtocol
}
// NewMockAclSyncProtocol creates a new mock instance.
func NewMockAclSyncProtocol(ctrl *gomock.Controller) *MockAclSyncProtocol {
mock := &MockAclSyncProtocol{ctrl: ctrl}
mock.recorder = &MockAclSyncProtocolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAclSyncProtocol) EXPECT() *MockAclSyncProtocolMockRecorder {
return m.recorder
}
// FullSyncRequest mocks base method.
func (m *MockAclSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncRequest) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FullSyncRequest indicates an expected call of FullSyncRequest.
func (mr *MockAclSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2)
}
// FullSyncResponse mocks base method.
func (m *MockAclSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncResponse) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// FullSyncResponse indicates an expected call of FullSyncResponse.
func (mr *MockAclSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2)
}
// HeadUpdate mocks base method.
func (m *MockAclSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *consensusproto.LogHeadUpdate) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HeadUpdate indicates an expected call of HeadUpdate.
func (mr *MockAclSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockAclSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2)
}

View File

@ -0,0 +1,54 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/consensus/consensusproto"
)
type RequestFactory interface {
CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage)
CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error)
CreateFullSyncResponse(l list.AclList, theirHead string) (*consensusproto.LogSyncMessage, error)
}
type requestFactory struct{}
func NewRequestFactory() RequestFactory {
return &requestFactory{}
}
func (r *requestFactory) CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage) {
return consensusproto.WrapHeadUpdate(&consensusproto.LogHeadUpdate{
Head: l.Head().Id,
Records: added,
}, l.Root())
}
func (r *requestFactory) CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error) {
if !l.HasHead(theirHead) {
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
Head: l.Head().Id,
}, l.Root()), nil
}
records, err := l.RecordsAfter(context.Background(), theirHead)
if err != nil {
return
}
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
Head: l.Head().Id,
Records: records,
}, l.Root()), nil
}
func (r *requestFactory) CreateFullSyncResponse(l list.AclList, theirHead string) (resp *consensusproto.LogSyncMessage, err error) {
records, err := l.RecordsAfter(context.Background(), theirHead)
if err != nil {
return
}
return consensusproto.WrapFullResponse(&consensusproto.LogFullSyncResponse{
Head: l.Head().Id,
Records: records,
}, l.Root()), nil
}

View File

@ -1,21 +1,130 @@
package syncacl package syncacl
import ( import (
"github.com/anytypeio/any-sync/commonspace/object/acl/list" "context"
"github.com/anytypeio/any-sync/commonspace/objectsync" "errors"
"github.com/anytypeio/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
) )
type SyncAcl struct { const CName = "common.acl.syncacl"
var (
log = logger.NewNamed(CName)
ErrSyncAclClosed = errors.New("sync acl is closed")
)
type SyncAcl interface {
app.ComponentRunnable
list.AclList list.AclList
synchandler.SyncHandler syncobjectgetter.SyncObject
messagePool objectsync.MessagePool SetHeadUpdater(updater headupdater.HeadUpdater)
SyncWithPeer(ctx context.Context, peerId string) (err error)
} }
func NewSyncAcl(aclList list.AclList, messagePool objectsync.MessagePool) *SyncAcl { func New() SyncAcl {
return &SyncAcl{ return &syncAcl{}
AclList: aclList, }
SyncHandler: nil,
messagePool: messagePool, type syncAcl struct {
} list.AclList
syncClient SyncClient
syncHandler synchandler.SyncHandler
headUpdater headupdater.HeadUpdater
isClosed bool
}
func (s *syncAcl) Run(ctx context.Context) (err error) {
return
}
func (s *syncAcl) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
return s.syncHandler.HandleRequest(ctx, senderId, request)
}
func (s *syncAcl) SetHeadUpdater(updater headupdater.HeadUpdater) {
s.headUpdater = updater
}
func (s *syncAcl) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
return s.syncHandler.HandleMessage(ctx, senderId, request)
}
func (s *syncAcl) Init(a *app.App) (err error) {
storage := a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
aclStorage, err := storage.AclStorage()
if err != nil {
return err
}
acc := a.MustComponent(accountservice.CName).(accountservice.Service)
s.AclList, err = list.BuildAclListWithIdentity(acc.Account(), aclStorage, list.NoOpAcceptorVerifier{})
if err != nil {
return
}
spaceId := storage.Id()
requestManager := a.MustComponent(requestmanager.CName).(requestmanager.RequestManager)
peerManager := a.MustComponent(peermanager.CName).(peermanager.PeerManager)
syncStatus := a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
s.syncClient = NewSyncClient(spaceId, requestManager, peerManager)
s.syncHandler = newSyncAclHandler(storage.Id(), s, s.syncClient, syncStatus)
return err
}
func (s *syncAcl) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
if s.isClosed {
return ErrSyncAclClosed
}
err = s.AclList.AddRawRecord(rawRec)
if err != nil {
return
}
headUpdate := s.syncClient.CreateHeadUpdate(s, []*consensusproto.RawRecordWithId{rawRec})
s.headUpdater.UpdateHeads(s.Id(), []string{rawRec.Id})
s.syncClient.Broadcast(headUpdate)
return
}
func (s *syncAcl) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
if s.isClosed {
return ErrSyncAclClosed
}
err = s.AclList.AddRawRecords(rawRecords)
if err != nil {
return
}
headUpdate := s.syncClient.CreateHeadUpdate(s, rawRecords)
s.headUpdater.UpdateHeads(s.Id(), []string{rawRecords[len(rawRecords)-1].Id})
s.syncClient.Broadcast(headUpdate)
return
}
func (s *syncAcl) SyncWithPeer(ctx context.Context, peerId string) (err error) {
s.Lock()
defer s.Unlock()
headUpdate := s.syncClient.CreateHeadUpdate(s, nil)
return s.syncClient.SendUpdate(peerId, headUpdate)
}
func (s *syncAcl) Close(ctx context.Context) (err error) {
s.Lock()
defer s.Unlock()
s.isClosed = true
return
}
func (s *syncAcl) Name() (name string) {
return CName
} }

View File

@ -2,30 +2,81 @@ package syncacl
import ( import (
"context" "context"
"fmt" "errors"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anytypeio/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/gogo/protobuf/proto"
)
var (
ErrMessageIsRequest = errors.New("message is request")
ErrMessageIsNotRequest = errors.New("message is not request")
) )
type syncAclHandler struct { type syncAclHandler struct {
acl list.AclList aclList list.AclList
syncClient SyncClient
syncProtocol AclSyncProtocol
syncStatus syncstatus.StatusUpdater
spaceId string
} }
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, req *spacesyncproto.ObjectSyncMessage) (err error) { func newSyncAclHandler(spaceId string, aclList list.AclList, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler {
aclMsg := &aclrecordproto.AclSyncMessage{} return &syncAclHandler{
if err = aclMsg.Unmarshal(req.Payload); err != nil { aclList: aclList,
syncClient: syncClient,
syncProtocol: newAclSyncProtocol(spaceId, aclList, syncClient),
syncStatus: syncStatus,
spaceId: spaceId,
}
}
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(message.Payload, unmarshalled)
if err != nil {
return return
} }
content := aclMsg.GetContent() content := unmarshalled.GetContent()
head := consensusproto.GetHead(unmarshalled)
s.syncStatus.HeadsReceive(senderId, s.aclList.Id(), []string{head})
s.aclList.Lock()
defer s.aclList.Unlock()
switch { switch {
case content.GetAddRecords() != nil: case content.GetHeadUpdate() != nil:
return s.handleAddRecords(ctx, senderId, content.GetAddRecords()) var syncReq *consensusproto.LogSyncMessage
default: syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate())
return fmt.Errorf("unexpected aclSync message: %T", content.Value) if err != nil || syncReq == nil {
return
}
return s.syncClient.QueueRequest(senderId, syncReq)
case content.GetFullSyncRequest() != nil:
return ErrMessageIsRequest
case content.GetFullSyncResponse() != nil:
return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
} }
}
func (s *syncAclHandler) handleAddRecords(ctx context.Context, senderId string, addRecord *aclrecordproto.AclAddRecords) (err error) {
return return
} }
func (s *syncAclHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(request.Payload, unmarshalled)
if err != nil {
return
}
fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest()
if fullSyncRequest == nil {
return nil, ErrMessageIsNotRequest
}
s.aclList.Lock()
defer s.aclList.Unlock()
aclResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest)
if err != nil {
return
}
return spacesyncproto.MarshallSyncMessage(aclResp, s.spaceId, s.aclList.Id())
}

View File

@ -0,0 +1,233 @@
package syncacl
import (
"context"
"fmt"
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sync"
"testing"
)
type testAclMock struct {
*mock_list.MockAclList
m sync.RWMutex
}
func newTestAclMock(mockAcl *mock_list.MockAclList) *testAclMock {
return &testAclMock{
MockAclList: mockAcl,
}
}
func (t *testAclMock) Lock() {
t.m.Lock()
}
func (t *testAclMock) RLock() {
t.m.RLock()
}
func (t *testAclMock) Unlock() {
t.m.Unlock()
}
func (t *testAclMock) RUnlock() {
t.m.RUnlock()
}
func (t *testAclMock) TryLock() bool {
return t.m.TryLock()
}
func (t *testAclMock) TryRLock() bool {
return t.m.TryRLock()
}
type syncHandlerFixture struct {
ctrl *gomock.Controller
syncClientMock *mock_syncacl.MockSyncClient
aclMock *testAclMock
syncProtocolMock *mock_syncacl.MockAclSyncProtocol
spaceId string
senderId string
aclId string
syncHandler *syncAclHandler
}
func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
ctrl := gomock.NewController(t)
aclMock := newTestAclMock(mock_list.NewMockAclList(ctrl))
syncClientMock := mock_syncacl.NewMockSyncClient(ctrl)
syncProtocolMock := mock_syncacl.NewMockAclSyncProtocol(ctrl)
spaceId := "spaceId"
syncHandler := &syncAclHandler{
aclList: aclMock,
syncClient: syncClientMock,
syncProtocol: syncProtocolMock,
syncStatus: syncstatus.NewNoOpSyncStatus(),
spaceId: spaceId,
}
return &syncHandlerFixture{
ctrl: ctrl,
syncClientMock: syncClientMock,
aclMock: aclMock,
syncProtocolMock: syncProtocolMock,
spaceId: spaceId,
senderId: "senderId",
aclId: "aclId",
syncHandler: syncHandler,
}
}
func (fx *syncHandlerFixture) stop() {
fx.ctrl.Finish()
}
func TestSyncAclHandler_HandleMessage(t *testing.T) {
ctx := context.Background()
t.Run("handle head update, request returned", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
syncReq := &consensusproto.LogSyncMessage{}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil)
fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, syncReq).Return(nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("handle head update, no request", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("handle head update, returned error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
expectedErr := fmt.Errorf("some error")
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, expectedErr)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.Error(t, expectedErr, err)
})
t.Run("handle full sync request is forbidden", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.Error(t, ErrMessageIsRequest, err)
})
t.Run("handle full sync response, no error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullResponse(fullResponse, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
}
func TestSyncAclHandler_HandleRequest(t *testing.T) {
ctx := context.Background()
t.Run("handle full sync request, no error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fullResp := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
FullSyncResponse: &consensusproto.LogFullSyncResponse{
Head: "returnedHead",
},
},
},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(fullResp, nil)
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(res.Payload, unmarshalled)
if err != nil {
return
}
require.Equal(t, "returnedHead", consensusproto.GetHead(unmarshalled))
})
t.Run("handle other message returns error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
_, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.Error(t, ErrMessageIsNotRequest, err)
})
}

View File

@ -0,0 +1,70 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(msg *consensusproto.LogSyncMessage)
SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error)
QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error)
SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type syncClient struct {
RequestFactory
spaceId string
requestManager requestmanager.RequestManager
peerManager peermanager.PeerManager
}
func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient {
return &syncClient{
RequestFactory: &requestFactory{},
spaceId: spaceId,
requestManager: requestManager,
peerManager: peerManager,
}
}
func (s *syncClient) Broadcast(msg *consensusproto.LogSyncMessage) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
err = s.peerManager.Broadcast(context.Background(), objMsg)
if err != nil {
log.Debug("broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.peerManager.SendPeer(context.Background(), peerId, objMsg)
}
func (s *syncClient) SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.requestManager.SendRequest(ctx, peerId, objMsg)
}
func (s *syncClient) QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.requestManager.QueueRequest(peerId, objMsg)
}

View File

@ -1,194 +0,0 @@
package acllistbuilder
import (
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/util/keys"
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"hash/fnv"
"strings"
)
type SymKey struct {
Hash uint64
Key *symmetric.Key
}
type YAMLKeychain struct {
SigningKeysByYAMLName map[string]signingkey.PrivKey
SigningKeysByRealIdentity map[string]signingkey.PrivKey
EncryptionKeysByYAMLName map[string]encryptionkey.PrivKey
ReadKeysByYAMLName map[string]*SymKey
ReadKeysByHash map[uint64]*SymKey
GeneratedIdentities map[string]string
}
func NewKeychain() *YAMLKeychain {
return &YAMLKeychain{
SigningKeysByYAMLName: map[string]signingkey.PrivKey{},
SigningKeysByRealIdentity: map[string]signingkey.PrivKey{},
EncryptionKeysByYAMLName: map[string]encryptionkey.PrivKey{},
GeneratedIdentities: map[string]string{},
ReadKeysByYAMLName: map[string]*SymKey{},
ReadKeysByHash: map[uint64]*SymKey{},
}
}
func (k *YAMLKeychain) ParseKeys(keys *Keys) {
for _, encKey := range keys.Enc {
k.AddEncryptionKey(encKey)
}
for _, signKey := range keys.Sign {
k.AddSigningKey(signKey)
}
for _, readKey := range keys.Read {
k.AddReadKey(readKey)
}
}
func (k *YAMLKeychain) AddEncryptionKey(key *Key) {
if _, exists := k.EncryptionKeysByYAMLName[key.Name]; exists {
return
}
var (
newPrivKey encryptionkey.PrivKey
err error
)
if key.Value == "generated" {
newPrivKey, _, err = encryptionkey.GenerateRandomRSAKeyPair(2048)
if err != nil {
panic(err)
}
} else {
newPrivKey, err = keys.DecodeKeyFromString(key.Value, encryptionkey.NewEncryptionRsaPrivKeyFromBytes, nil)
if err != nil {
panic(err)
}
}
k.EncryptionKeysByYAMLName[key.Name] = newPrivKey
}
func (k *YAMLKeychain) AddSigningKey(key *Key) {
if _, exists := k.SigningKeysByYAMLName[key.Name]; exists {
return
}
var (
newPrivKey signingkey.PrivKey
pubKey signingkey.PubKey
err error
)
if key.Value == "generated" {
newPrivKey, pubKey, err = signingkey.GenerateRandomEd25519KeyPair()
if err != nil {
panic(err)
}
} else {
newPrivKey, err = keys.DecodeKeyFromString(key.Value, signingkey.NewSigningEd25519PrivKeyFromBytes, nil)
if err != nil {
panic(err)
}
pubKey = newPrivKey.GetPublic()
}
k.SigningKeysByYAMLName[key.Name] = newPrivKey
rawPubKey, err := pubKey.Raw()
if err != nil {
panic(err)
}
encoded := string(rawPubKey)
k.SigningKeysByRealIdentity[encoded] = newPrivKey
k.GeneratedIdentities[key.Name] = encoded
}
func (k *YAMLKeychain) AddReadKey(key *Key) {
if _, exists := k.ReadKeysByYAMLName[key.Name]; exists {
return
}
var (
rkey *symmetric.Key
err error
)
if key.Value == "generated" {
rkey, err = symmetric.NewRandom()
if err != nil {
panic("should be able to generate symmetric key")
}
} else if key.Value == "derived" {
signKey, _ := k.SigningKeysByYAMLName[key.Name].Raw()
encKey, _ := k.EncryptionKeysByYAMLName[key.Name].Raw()
rkey, err = aclrecordproto.AclReadKeyDerive(signKey, encKey)
if err != nil {
panic("should be able to derive symmetric key")
}
} else {
rkey, err = symmetric.FromString(key.Value)
if err != nil {
panic("should be able to parse symmetric key")
}
}
hasher := fnv.New64()
hasher.Write(rkey.Bytes())
k.ReadKeysByYAMLName[key.Name] = &SymKey{
Hash: hasher.Sum64(),
Key: rkey,
}
k.ReadKeysByHash[hasher.Sum64()] = &SymKey{
Hash: hasher.Sum64(),
Key: rkey,
}
}
func (k *YAMLKeychain) AddKey(key *Key) {
parts := strings.Split(key.Name, ".")
if len(parts) != 3 {
panic("cannot parse a key")
}
switch parts[1] {
case "Signature":
k.AddSigningKey(key)
case "Enc":
k.AddEncryptionKey(key)
case "Read":
k.AddReadKey(key)
default:
panic("incorrect format")
}
}
func (k *YAMLKeychain) GetKey(key string) interface{} {
parts := strings.Split(key, ".")
if len(parts) != 3 {
panic("cannot parse a key")
}
name := parts[2]
switch parts[1] {
case "Sign":
if key, exists := k.SigningKeysByYAMLName[name]; exists {
return key
}
case "Enc":
if key, exists := k.EncryptionKeysByYAMLName[name]; exists {
return key
}
case "Read":
if key, exists := k.ReadKeysByYAMLName[name]; exists {
return key
}
default:
panic("incorrect format")
}
return nil
}
func (k *YAMLKeychain) GetIdentity(name string) string {
return k.GeneratedIdentities[name]
}

View File

@ -1,295 +0,0 @@
package acllistbuilder
import (
"context"
"fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage"
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/yamltests"
"github.com/anytypeio/any-sync/util/cidutil"
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"gopkg.in/yaml.v3"
"io/ioutil"
"path"
"time"
"github.com/gogo/protobuf/proto"
)
type AclListStorageBuilder struct {
liststorage.ListStorage
keychain *YAMLKeychain
}
func NewAclListStorageBuilder(keychain *YAMLKeychain) *AclListStorageBuilder {
return &AclListStorageBuilder{
keychain: keychain,
}
}
func NewListStorageWithTestName(name string) (liststorage.ListStorage, error) {
filePath := path.Join(yamltests.Path(), name)
return NewAclListStorageBuilderFromFile(filePath)
}
func NewAclListStorageBuilderFromFile(file string) (*AclListStorageBuilder, error) {
content, err := ioutil.ReadFile(file)
if err != nil {
return nil, err
}
ymlTree := YMLList{}
err = yaml.Unmarshal(content, &ymlTree)
if err != nil {
return nil, err
}
tb := NewAclListStorageBuilder(NewKeychain())
tb.Parse(&ymlTree)
return tb, nil
}
func (t *AclListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) *aclrecordproto.RawAclRecordWithId {
protoMarshalled, err := rec.Marshal()
if err != nil {
panic("should be able to marshal final acl message!")
}
signature, err := t.keychain.SigningKeysByRealIdentity[string(identity)].Sign(protoMarshalled)
if err != nil {
panic("should be able to sign final acl message!")
}
rawRec := &aclrecordproto.RawAclRecord{
Payload: protoMarshalled,
Signature: signature,
}
rawMarshalled, err := proto.Marshal(rawRec)
if err != nil {
panic(err)
}
id, _ := cidutil.NewCidFromBytes(rawMarshalled)
return &aclrecordproto.RawAclRecordWithId{
Payload: rawMarshalled,
Id: id,
}
}
func (t *AclListStorageBuilder) GetKeychain() *YAMLKeychain {
return t.keychain
}
func (t *AclListStorageBuilder) Parse(l *YMLList) {
// Just to clarify - we are generating new identities for the ones that
// are specified in the yml file, because our identities should be Ed25519
// the same thing is happening for the encryption keys
t.keychain.ParseKeys(&l.Keys)
rawRoot := t.parseRoot(l.Root)
var err error
t.ListStorage, err = liststorage.NewInMemoryAclListStorage(rawRoot.Id, []*aclrecordproto.RawAclRecordWithId{rawRoot})
if err != nil {
panic(err)
}
prevId := rawRoot.Id
for _, rec := range l.Records {
newRecord := t.parseRecord(rec, prevId)
rawRecord := t.createRaw(newRecord, newRecord.Identity)
err = t.AddRawRecord(context.Background(), rawRecord)
if err != nil {
panic(err)
}
prevId = rawRecord.Id
}
t.SetHead(prevId)
}
func (t *AclListStorageBuilder) parseRecord(rec *Record, prevId string) *aclrecordproto.AclRecord {
k := t.keychain.GetKey(rec.ReadKey).(*SymKey)
var aclChangeContents []*aclrecordproto.AclContentValue
for _, ch := range rec.AclChanges {
aclChangeContent := t.parseAclChange(ch)
aclChangeContents = append(aclChangeContents, aclChangeContent)
}
data := &aclrecordproto.AclData{
AclContent: aclChangeContents,
}
bytes, _ := data.Marshal()
return &aclrecordproto.AclRecord{
PrevId: prevId,
Identity: []byte(t.keychain.GetIdentity(rec.Identity)),
Data: bytes,
CurrentReadKeyHash: k.Hash,
Timestamp: time.Now().UnixNano(),
}
}
func (t *AclListStorageBuilder) parseAclChange(ch *AclChange) (convCh *aclrecordproto.AclContentValue) {
switch {
case ch.UserAdd != nil:
add := ch.UserAdd
encKey := t.keychain.GetKey(add.EncryptionKey).(encryptionkey.PrivKey)
rawKey, _ := encKey.GetPublic().Raw()
convCh = &aclrecordproto.AclContentValue{
Value: &aclrecordproto.AclContentValue_UserAdd{
UserAdd: &aclrecordproto.AclUserAdd{
Identity: []byte(t.keychain.GetIdentity(add.Identity)),
EncryptionKey: rawKey,
EncryptedReadKeys: t.encryptReadKeysWithPubKey(add.EncryptedReadKeys, encKey),
Permissions: t.convertPermission(add.Permission),
},
},
}
case ch.UserJoin != nil:
join := ch.UserJoin
encKey := t.keychain.GetKey(join.EncryptionKey).(encryptionkey.PrivKey)
rawKey, _ := encKey.GetPublic().Raw()
idKey, _ := t.keychain.SigningKeysByYAMLName[join.Identity].GetPublic().Raw()
signKey := t.keychain.GetKey(join.AcceptKey).(signingkey.PrivKey)
signature, err := signKey.Sign(idKey)
if err != nil {
panic(err)
}
acceptPubKey, _ := signKey.GetPublic().Raw()
convCh = &aclrecordproto.AclContentValue{
Value: &aclrecordproto.AclContentValue_UserJoin{
UserJoin: &aclrecordproto.AclUserJoin{
Identity: []byte(t.keychain.GetIdentity(join.Identity)),
EncryptionKey: rawKey,
AcceptSignature: signature,
AcceptPubKey: acceptPubKey,
EncryptedReadKeys: t.encryptReadKeysWithPubKey(join.EncryptedReadKeys, encKey),
},
},
}
case ch.UserInvite != nil:
invite := ch.UserInvite
rawAcceptKey, _ := t.keychain.GetKey(invite.AcceptKey).(signingkey.PrivKey).GetPublic().Raw()
hash := t.keychain.GetKey(invite.EncryptionKey).(*SymKey).Hash
encKey := t.keychain.ReadKeysByHash[hash]
convCh = &aclrecordproto.AclContentValue{
Value: &aclrecordproto.AclContentValue_UserInvite{
UserInvite: &aclrecordproto.AclUserInvite{
AcceptPublicKey: rawAcceptKey,
EncryptSymKeyHash: hash,
EncryptedReadKeys: t.encryptReadKeysWithSymKey(invite.EncryptedReadKeys, encKey.Key),
Permissions: t.convertPermission(invite.Permissions),
},
},
}
case ch.UserPermissionChange != nil:
permissionChange := ch.UserPermissionChange
convCh = &aclrecordproto.AclContentValue{
Value: &aclrecordproto.AclContentValue_UserPermissionChange{
UserPermissionChange: &aclrecordproto.AclUserPermissionChange{
Identity: []byte(t.keychain.GetIdentity(permissionChange.Identity)),
Permissions: t.convertPermission(permissionChange.Permission),
},
},
}
case ch.UserRemove != nil:
remove := ch.UserRemove
newReadKey := t.keychain.GetKey(remove.NewReadKey).(*SymKey)
var replaces []*aclrecordproto.AclReadKeyReplace
for _, id := range remove.IdentitiesLeft {
encKey := t.keychain.EncryptionKeysByYAMLName[id]
rawEncKey, _ := encKey.GetPublic().Raw()
encReadKey, err := encKey.GetPublic().Encrypt(newReadKey.Key.Bytes())
if err != nil {
panic(err)
}
replaces = append(replaces, &aclrecordproto.AclReadKeyReplace{
Identity: []byte(t.keychain.GetIdentity(id)),
EncryptionKey: rawEncKey,
EncryptedReadKey: encReadKey,
})
}
convCh = &aclrecordproto.AclContentValue{
Value: &aclrecordproto.AclContentValue_UserRemove{
UserRemove: &aclrecordproto.AclUserRemove{
Identity: []byte(t.keychain.GetIdentity(remove.RemovedIdentity)),
ReadKeyReplaces: replaces,
},
},
}
}
if convCh == nil {
panic("cannot have empty acl change")
}
return convCh
}
func (t *AclListStorageBuilder) encryptReadKeysWithPubKey(keys []string, encKey encryptionkey.PrivKey) (enc [][]byte) {
for _, k := range keys {
realKey := t.keychain.GetKey(k).(*SymKey).Key.Bytes()
res, err := encKey.GetPublic().Encrypt(realKey)
if err != nil {
panic(err)
}
enc = append(enc, res)
}
return
}
func (t *AclListStorageBuilder) encryptReadKeysWithSymKey(keys []string, key *symmetric.Key) (enc [][]byte) {
for _, k := range keys {
realKey := t.keychain.GetKey(k).(*SymKey).Key.Bytes()
res, err := key.Encrypt(realKey)
if err != nil {
panic(err)
}
enc = append(enc, res)
}
return
}
func (t *AclListStorageBuilder) convertPermission(perm string) aclrecordproto.AclUserPermissions {
switch perm {
case "admin":
return aclrecordproto.AclUserPermissions_Admin
case "writer":
return aclrecordproto.AclUserPermissions_Writer
case "reader":
return aclrecordproto.AclUserPermissions_Reader
default:
panic(fmt.Sprintf("incorrect permission: %s", perm))
}
}
func (t *AclListStorageBuilder) traverseFromHead(f func(rec *aclrecordproto.AclRecord, id string) error) (err error) {
panic("this was removed, add if needed")
}
func (t *AclListStorageBuilder) parseRoot(root *Root) (rawRoot *aclrecordproto.RawAclRecordWithId) {
rawSignKey, _ := t.keychain.SigningKeysByYAMLName[root.Identity].GetPublic().Raw()
rawEncKey, _ := t.keychain.EncryptionKeysByYAMLName[root.Identity].GetPublic().Raw()
readKey := t.keychain.ReadKeysByYAMLName[root.Identity]
aclRoot := &aclrecordproto.AclRoot{
Identity: rawSignKey,
EncryptionKey: rawEncKey,
SpaceId: root.SpaceId,
EncryptedReadKey: nil,
DerivationScheme: "scheme",
CurrentReadKeyHash: readKey.Hash,
}
return t.createRaw(aclRoot, rawSignKey)
}

View File

@ -1,11 +0,0 @@
//go:build ((!linux && !darwin) || android || ios || nographviz) && !amd64
// +build !linux,!darwin android ios nographviz
// +build !amd64
package acllistbuilder
import "fmt"
func (t *AclListStorageBuilder) Graph() (string, error) {
return "", fmt.Errorf("building graphs is not supported")
}

View File

@ -1,120 +0,0 @@
//go:build (linux || darwin) && !android && !ios && !nographviz && (amd64 || arm64)
// +build linux darwin
// +build !android
// +build !ios
// +build !nographviz
// +build amd64 arm64
package acllistbuilder
import (
"fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/gogo/protobuf/proto"
"strings"
"unicode"
"github.com/awalterschulze/gographviz"
)
// To quickly look at visualized string you can use https://dreampuf.github.io/GraphvizOnline
type EdgeParameters struct {
style string
color string
label string
}
func (t *AclListStorageBuilder) Graph() (string, error) {
// TODO: check updates on https://github.com/goccy/go-graphviz/issues/52 or make a fix yourself to use better library here
graph := gographviz.NewGraph()
graph.SetName("G")
graph.SetDir(true)
var nodes = make(map[string]struct{})
var addNodes = func(r *aclrecordproto.AclRecord, id string) error {
style := "solid"
var chSymbs []string
aclData := &aclrecordproto.AclData{}
err := proto.Unmarshal(r.GetData(), aclData)
if err != nil {
return err
}
for _, chc := range aclData.AclContent {
tp := fmt.Sprintf("%T", chc.Value)
tp = strings.Replace(tp, "AclChangeAclContentValueValueOf", "", 1)
res := ""
for _, ts := range tp {
if unicode.IsUpper(ts) {
res += string(ts)
}
}
chSymbs = append(chSymbs, res)
}
shortId := id
label := fmt.Sprintf("Id: %s\nChanges: %s\n",
shortId,
strings.Join(chSymbs, ","),
)
e := graph.AddNode("G", "\""+id+"\"", map[string]string{
"label": "\"" + label + "\"",
"style": "\"" + style + "\"",
})
if e != nil {
return e
}
nodes[id] = struct{}{}
return nil
}
var createEdge = func(firstId, secondId string, params EdgeParameters) error {
_, exists := nodes[firstId]
if !exists {
return fmt.Errorf("no such node")
}
_, exists = nodes[secondId]
if !exists {
return fmt.Errorf("no previous node")
}
err := graph.AddEdge("\""+firstId+"\"", "\""+secondId+"\"", true, map[string]string{
"color": params.color,
"style": params.style,
})
if err != nil {
return err
}
return nil
}
var addLinks = func(r *aclrecordproto.AclRecord, id string) error {
if r.PrevId == "" {
return nil
}
err := createEdge(id, r.PrevId, EdgeParameters{
style: "dashed",
color: "red",
})
if err != nil {
return err
}
return nil
}
err := t.traverseFromHead(addNodes)
if err != nil {
return "", err
}
err = t.traverseFromHead(addLinks)
if err != nil {
return "", err
}
return graph.String(), nil
}

View File

@ -1,70 +0,0 @@
package acllistbuilder
type Key struct {
Name string `yaml:"name"`
Value string `yaml:"value"`
}
type Keys struct {
Derived string `yaml:"Derived"`
Enc []*Key `yaml:"Enc"`
Sign []*Key `yaml:"Sign"`
Read []*Key `yaml:"Read"`
}
type AclChange struct {
UserAdd *struct {
Identity string `yaml:"identity"`
EncryptionKey string `yaml:"encryptionKey"`
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
Permission string `yaml:"permission"`
} `yaml:"userAdd"`
UserJoin *struct {
Identity string `yaml:"identity"`
EncryptionKey string `yaml:"encryptionKey"`
AcceptKey string `yaml:"acceptKey"`
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
} `yaml:"userJoin"`
UserInvite *struct {
AcceptKey string `yaml:"acceptKey"`
EncryptionKey string `yaml:"encryptionKey"`
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
Permissions string `yaml:"permissions"`
} `yaml:"userInvite"`
UserRemove *struct {
RemovedIdentity string `yaml:"removedIdentity"`
NewReadKey string `yaml:"newReadKey"`
IdentitiesLeft []string `yaml:"identitiesLeft"`
} `yaml:"userRemove"`
UserPermissionChange *struct {
Identity string `yaml:"identity"`
Permission string `yaml:"permission"`
}
}
type Record struct {
Identity string `yaml:"identity"`
AclChanges []*AclChange `yaml:"aclChanges"`
ReadKey string `yaml:"readKey"`
}
type Header struct {
FirstChangeId string `yaml:"firstChangeId"`
IsWorkspace bool `yaml:"isWorkspace"`
}
type Root struct {
Identity string `yaml:"identity"`
SpaceId string `yaml:"spaceId"`
}
type YMLList struct {
Root *Root
Records []*Record `yaml:"records"`
Keys Keys `yaml:"keys"`
}

View File

@ -1,15 +0,0 @@
package yamltests
import (
"path/filepath"
"runtime"
)
var (
_, b, _, _ = runtime.Caller(0)
basepath = filepath.Dir(b)
)
func Path() string {
return basepath
}

View File

@ -1,53 +0,0 @@
root:
identity: A
spaceId: space
records:
- identity: A
aclChanges:
- userInvite:
acceptKey: key.Sign.Onetime1
encryptionKey: key.Read.EncKey
encryptedReadKeys: [key.Read.A]
permissions: writer
- userAdd:
identity: C
permission: reader
encryptionKey: key.Enc.C
encryptedReadKeys: [key.Read.A]
readKey: key.Read.A
- identity: B
aclChanges:
- userJoin:
identity: B
encryptionKey: key.Enc.B
acceptKey: key.Sign.Onetime1
encryptedReadKeys: [key.Read.A]
readKey: key.Read.A
keys:
Enc:
- name: A
value: generated
- name: B
value: generated
- name: C
value: generated
- name: D
value: generated
- name: Onetime1
value: generated
Sign:
- name: A
value: generated
- name: B
value: generated
- name: C
value: generated
- name: D
value: generated
- name: Onetime1
value: generated
Read:
- name: A
value: derived
- name: EncKey
value: generated

View File

@ -1,58 +0,0 @@
root:
identity: A
spaceId: space
records:
- identity: A
aclChanges:
- userInvite:
acceptKey: key.Sign.Onetime1
encryptionKey: key.Read.EncKey
encryptedReadKeys: [key.Read.A]
permissions: writer
- userAdd:
identity: C
permission: reader
encryptionKey: key.Enc.C
encryptedReadKeys: [key.Read.A]
readKey: key.Read.A
- identity: B
aclChanges:
- userJoin:
identity: B
encryptionKey: key.Enc.B
acceptKey: key.Sign.Onetime1
encryptedReadKeys: [key.Read.A]
readKey: key.Read.A
- identity: A
aclChanges:
- userRemove:
removedIdentity: B
newReadKey: key.Read.2
identitiesLeft: [A, C]
readKey: key.Read.2
keys:
Enc:
- name: A
value: generated
- name: B
value: generated
- name: C
value: generated
- name: Onetime1
value: generated
Sign:
- name: A
value: generated
- name: B
value: generated
- name: C
value: generated
- name: Onetime1
value: generated
Read:
- name: A
value: derived
- name: 2
value: generated
- name: EncKey
value: generated

View File

@ -1,28 +0,0 @@
package keychain
import (
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
)
type Keychain struct {
keys map[string]signingkey.PubKey
}
func NewKeychain() *Keychain {
return &Keychain{
keys: make(map[string]signingkey.PubKey),
}
}
func (k *Keychain) GetOrAdd(identity string) (signingkey.PubKey, error) {
if key, exists := k.keys[identity]; exists {
return key, nil
}
res, err := signingkey.NewSigningEd25519PubKeyFromBytes([]byte(identity))
if err != nil {
return nil, err
}
k.keys[identity] = res.(signingkey.PubKey)
return res.(signingkey.PubKey), nil
}

View File

@ -2,7 +2,7 @@ package syncobjectgetter
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
) )
type SyncObject interface { type SyncObject interface {

View File

@ -0,0 +1,80 @@
package exporter
import (
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/crypto"
)
type DataConverter interface {
Unmarshall(decrypted []byte) (any, error)
Marshall(model any) ([]byte, error)
}
type TreeExporterParams struct {
ListStorageExporter liststorage.Exporter
TreeStorageExporter treestorage.Exporter
DataConverter DataConverter
}
type TreeExporter interface {
ExportUnencrypted(tree objecttree.ReadableObjectTree) (err error)
}
type treeExporter struct {
listExporter liststorage.Exporter
treeExporter treestorage.Exporter
converter DataConverter
}
func NewTreeExporter(params TreeExporterParams) TreeExporter {
return &treeExporter{
listExporter: params.ListStorageExporter,
treeExporter: params.TreeStorageExporter,
converter: params.DataConverter,
}
}
func (t *treeExporter) ExportUnencrypted(tree objecttree.ReadableObjectTree) (err error) {
lst := tree.AclList()
// this exports root which should be enough before we implement acls
_, err = t.listExporter.ListStorage(lst.Root())
if err != nil {
return
}
treeStorage, err := t.treeExporter.TreeStorage(tree.Header())
if err != nil {
return
}
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), tree.Header())
putStorage := func(change *objecttree.Change) (err error) {
var raw *treechangeproto.RawTreeChangeWithId
raw, err = changeBuilder.Marshall(change)
if err != nil {
return
}
return treeStorage.AddRawChange(raw)
}
err = tree.IterateRoot(t.converter.Unmarshall, func(change *objecttree.Change) bool {
if change.Id == tree.Id() {
err = putStorage(change)
return err == nil
}
var data []byte
data, err = t.converter.Marshall(change.Model)
if err != nil {
return false
}
// that means that change is unencrypted
change.ReadKeyId = ""
change.Data = data
err = putStorage(change)
return err == nil
})
if err != nil {
return
}
return treeStorage.SetHeads(tree.Heads())
}

View File

@ -0,0 +1,28 @@
package exporter
import (
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
)
type TreeImportParams struct {
ListStorage liststorage.ListStorage
TreeStorage treestorage.TreeStorage
BeforeId string
IncludeBeforeId bool
}
func ImportHistoryTree(params TreeImportParams) (tree objecttree.ReadableObjectTree, err error) {
aclList, err := list.BuildAclList(params.ListStorage, list.NoOpAcceptorVerifier{})
if err != nil {
return
}
return objecttree.BuildNonVerifiableHistoryTree(objecttree.HistoryTreeParams{
TreeStorage: params.TreeStorage,
AclList: aclList,
BeforeId: params.BeforeId,
IncludeBeforeId: params.IncludeBeforeId,
})
}

View File

@ -2,7 +2,9 @@ package objecttree
import ( import (
"errors" "errors"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/gogo/protobuf/proto"
) )
var ( var (
@ -17,46 +19,51 @@ type Change struct {
AclHeadId string AclHeadId string
Id string Id string
SnapshotId string SnapshotId string
IsSnapshot bool
Timestamp int64 Timestamp int64
ReadKeyHash uint64 ReadKeyId string
Identity string Identity crypto.PubKey
Data []byte Data []byte
Model interface{} Model interface{}
Signature []byte
// iterator helpers // iterator helpers
visited bool visited bool
branchesFinished bool branchesFinished bool
IsSnapshot bool
Signature []byte
} }
func NewChange(id string, ch *treechangeproto.TreeChange, signature []byte) *Change { func NewChange(id string, identity crypto.PubKey, ch *treechangeproto.TreeChange, signature []byte) *Change {
return &Change{ return &Change{
Next: nil, Next: nil,
PreviousIds: ch.TreeHeadIds, PreviousIds: ch.TreeHeadIds,
AclHeadId: ch.AclHeadId, AclHeadId: ch.AclHeadId,
Timestamp: ch.Timestamp, Timestamp: ch.Timestamp,
ReadKeyHash: ch.CurrentReadKeyHash, ReadKeyId: ch.ReadKeyId,
Id: id, Id: id,
Data: ch.ChangesData, Data: ch.ChangesData,
SnapshotId: ch.SnapshotBaseId, SnapshotId: ch.SnapshotBaseId,
IsSnapshot: ch.IsSnapshot, IsSnapshot: ch.IsSnapshot,
Identity: string(ch.Identity), Identity: identity,
Signature: signature, Signature: signature,
} }
} }
func NewChangeFromRoot(id string, ch *treechangeproto.RootChange, signature []byte) *Change { func NewChangeFromRoot(id string, identity crypto.PubKey, ch *treechangeproto.RootChange, signature []byte) *Change {
changeInfo := &treechangeproto.TreeChangeInfo{
ChangeType: ch.ChangeType,
ChangePayload: ch.ChangePayload,
}
data, _ := proto.Marshal(changeInfo)
return &Change{ return &Change{
Next: nil, Next: nil,
AclHeadId: ch.AclHeadId, AclHeadId: ch.AclHeadId,
Id: id, Id: id,
IsSnapshot: true, IsSnapshot: true,
Timestamp: ch.Timestamp, Timestamp: ch.Timestamp,
Identity: string(ch.Identity), Identity: identity,
Signature: signature, Signature: signature,
Data: []byte(ch.ChangeType), Data: data,
Model: changeInfo,
} }
} }

View File

@ -2,13 +2,10 @@ package objecttree
import ( import (
"errors" "errors"
"github.com/anytypeio/any-sync/commonspace/object/keychain" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/util/cidutil"
"github.com/anytypeio/any-sync/util/cidutil" "github.com/anyproto/any-sync/util/crypto"
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"time"
) )
var ErrEmptyChange = errors.New("change payload should not be empty") var ErrEmptyChange = errors.New("change payload should not be empty")
@ -17,42 +14,64 @@ type BuilderContent struct {
TreeHeadIds []string TreeHeadIds []string
AclHeadId string AclHeadId string
SnapshotBaseId string SnapshotBaseId string
CurrentReadKeyHash uint64 ReadKeyId string
Identity []byte
IsSnapshot bool IsSnapshot bool
SigningKey signingkey.PrivKey PrivKey crypto.PrivKey
ReadKey *symmetric.Key ReadKey crypto.SymKey
Content []byte Content []byte
Timestamp int64
} }
type InitialContent struct { type InitialContent struct {
AclHeadId string AclHeadId string
Identity []byte PrivKey crypto.PrivKey
SigningKey signingkey.PrivKey
SpaceId string SpaceId string
Seed []byte Seed []byte
ChangeType string ChangeType string
ChangePayload []byte
Timestamp int64 Timestamp int64
} }
type ChangeBuilder interface { type nonVerifiableChangeBuilder struct {
ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) ChangeBuilder
BuildContent(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
BuildInitialContent(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
BuildRaw(ch *Change) (*treechangeproto.RawTreeChangeWithId, error)
SetRootRawChange(rawIdChange *treechangeproto.RawTreeChangeWithId)
} }
func (c *nonVerifiableChangeBuilder) BuildRoot(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
return c.ChangeBuilder.BuildRoot(payload)
}
func (c *nonVerifiableChangeBuilder) Unmarshall(rawChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
return c.ChangeBuilder.Unmarshall(rawChange, false)
}
func (c *nonVerifiableChangeBuilder) Build(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
return c.ChangeBuilder.Build(payload)
}
func (c *nonVerifiableChangeBuilder) Marshall(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return c.ChangeBuilder.Marshall(ch)
}
type ChangeBuilder interface {
Unmarshall(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error)
Build(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
BuildRoot(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
Marshall(ch *Change) (*treechangeproto.RawTreeChangeWithId, error)
}
type newChangeFunc = func(id string, identity crypto.PubKey, ch *treechangeproto.TreeChange, signature []byte) *Change
type changeBuilder struct { type changeBuilder struct {
rootChange *treechangeproto.RawTreeChangeWithId rootChange *treechangeproto.RawTreeChangeWithId
keys *keychain.Keychain keys crypto.KeyStorage
newChange newChangeFunc
} }
func NewChangeBuilder(keys *keychain.Keychain, rootChange *treechangeproto.RawTreeChangeWithId) ChangeBuilder { func NewChangeBuilder(keys crypto.KeyStorage, rootChange *treechangeproto.RawTreeChangeWithId) ChangeBuilder {
return &changeBuilder{keys: keys, rootChange: rootChange} return &changeBuilder{keys: keys, rootChange: rootChange, newChange: NewChange}
} }
func (c *changeBuilder) ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) { func (c *changeBuilder) Unmarshall(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
if rawIdChange.GetRawChange() == nil { if rawIdChange.GetRawChange() == nil {
err = ErrEmptyChange err = ErrEmptyChange
return return
@ -77,15 +96,9 @@ func (c *changeBuilder) ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChang
} }
if verify { if verify {
var identityKey signingkey.PubKey
identityKey, err = c.keys.GetOrAdd(ch.Identity)
if err != nil {
return
}
// verifying signature // verifying signature
var res bool var res bool
res, err = identityKey.Verify(raw.Payload, raw.Signature) res, err = ch.Identity.Verify(raw.Payload, raw.Signature)
if err != nil { if err != nil {
return return
} }
@ -101,43 +114,41 @@ func (c *changeBuilder) SetRootRawChange(rawIdChange *treechangeproto.RawTreeCha
c.rootChange = rawIdChange c.rootChange = rawIdChange
} }
func (c *changeBuilder) BuildInitialContent(payload InitialContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) { func (c *changeBuilder) BuildRoot(payload InitialContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
identity, err := payload.PrivKey.GetPublic().Marshall()
if err != nil {
return
}
change := &treechangeproto.RootChange{ change := &treechangeproto.RootChange{
AclHeadId: payload.AclHeadId, AclHeadId: payload.AclHeadId,
Timestamp: payload.Timestamp, Timestamp: payload.Timestamp,
Identity: payload.Identity, Identity: identity,
ChangeType: payload.ChangeType, ChangeType: payload.ChangeType,
ChangePayload: payload.ChangePayload,
SpaceId: payload.SpaceId, SpaceId: payload.SpaceId,
Seed: payload.Seed, Seed: payload.Seed,
} }
marshalledChange, err := proto.Marshal(change) marshalledChange, err := proto.Marshal(change)
if err != nil { if err != nil {
return return
} }
signature, err := payload.PrivKey.Sign(marshalledChange)
signature, err := payload.SigningKey.Sign(marshalledChange)
if err != nil { if err != nil {
return return
} }
raw := &treechangeproto.RawTreeChange{ raw := &treechangeproto.RawTreeChange{
Payload: marshalledChange, Payload: marshalledChange,
Signature: signature, Signature: signature,
} }
marshalledRawChange, err := proto.Marshal(raw) marshalledRawChange, err := proto.Marshal(raw)
if err != nil { if err != nil {
return return
} }
id, err := cidutil.NewCidFromBytes(marshalledRawChange) id, err := cidutil.NewCidFromBytes(marshalledRawChange)
if err != nil { if err != nil {
return return
} }
ch = NewChangeFromRoot(id, payload.PrivKey.GetPublic(), change, signature)
ch = NewChangeFromRoot(id, change, signature)
rawIdChange = &treechangeproto.RawTreeChangeWithId{ rawIdChange = &treechangeproto.RawTreeChangeWithId{
RawChange: marshalledRawChange, RawChange: marshalledRawChange,
Id: id, Id: id,
@ -145,14 +156,18 @@ func (c *changeBuilder) BuildInitialContent(payload InitialContent) (ch *Change,
return return
} }
func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) { func (c *changeBuilder) Build(payload BuilderContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
identity, err := payload.PrivKey.GetPublic().Marshall()
if err != nil {
return
}
change := &treechangeproto.TreeChange{ change := &treechangeproto.TreeChange{
TreeHeadIds: payload.TreeHeadIds, TreeHeadIds: payload.TreeHeadIds,
AclHeadId: payload.AclHeadId, AclHeadId: payload.AclHeadId,
SnapshotBaseId: payload.SnapshotBaseId, SnapshotBaseId: payload.SnapshotBaseId,
CurrentReadKeyHash: payload.CurrentReadKeyHash, ReadKeyId: payload.ReadKeyId,
Timestamp: time.Now().UnixNano(), Timestamp: payload.Timestamp,
Identity: payload.Identity, Identity: identity,
IsSnapshot: payload.IsSnapshot, IsSnapshot: payload.IsSnapshot,
} }
if payload.ReadKey != nil { if payload.ReadKey != nil {
@ -165,34 +180,27 @@ func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdC
} else { } else {
change.ChangesData = payload.Content change.ChangesData = payload.Content
} }
marshalledChange, err := proto.Marshal(change) marshalledChange, err := proto.Marshal(change)
if err != nil { if err != nil {
return return
} }
signature, err := payload.PrivKey.Sign(marshalledChange)
signature, err := payload.SigningKey.Sign(marshalledChange)
if err != nil { if err != nil {
return return
} }
raw := &treechangeproto.RawTreeChange{ raw := &treechangeproto.RawTreeChange{
Payload: marshalledChange, Payload: marshalledChange,
Signature: signature, Signature: signature,
} }
marshalledRawChange, err := proto.Marshal(raw) marshalledRawChange, err := proto.Marshal(raw)
if err != nil { if err != nil {
return return
} }
id, err := cidutil.NewCidFromBytes(marshalledRawChange) id, err := cidutil.NewCidFromBytes(marshalledRawChange)
if err != nil { if err != nil {
return return
} }
ch = c.newChange(id, payload.PrivKey.GetPublic(), change, signature)
ch = NewChange(id, change, signature)
rawIdChange = &treechangeproto.RawTreeChangeWithId{ rawIdChange = &treechangeproto.RawTreeChangeWithId{
RawChange: marshalledRawChange, RawChange: marshalledRawChange,
Id: id, Id: id,
@ -200,18 +208,22 @@ func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdC
return return
} }
func (c *changeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) { func (c *changeBuilder) Marshall(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
if ch.Id == c.rootChange.Id { if c.isRoot(ch.Id) {
return c.rootChange, nil return c.rootChange, nil
} }
identity, err := ch.Identity.Marshall()
if err != nil {
return
}
treeChange := &treechangeproto.TreeChange{ treeChange := &treechangeproto.TreeChange{
TreeHeadIds: ch.PreviousIds, TreeHeadIds: ch.PreviousIds,
AclHeadId: ch.AclHeadId, AclHeadId: ch.AclHeadId,
SnapshotBaseId: ch.SnapshotId, SnapshotBaseId: ch.SnapshotId,
ChangesData: ch.Data, ChangesData: ch.Data,
CurrentReadKeyHash: ch.ReadKeyHash, ReadKeyId: ch.ReadKeyId,
Timestamp: ch.Timestamp, Timestamp: ch.Timestamp,
Identity: []byte(ch.Identity), Identity: identity,
IsSnapshot: ch.IsSnapshot, IsSnapshot: ch.IsSnapshot,
} }
var marshalled []byte var marshalled []byte
@ -236,22 +248,36 @@ func (c *changeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChange
} }
func (c *changeBuilder) unmarshallRawChange(raw *treechangeproto.RawTreeChange, id string) (ch *Change, err error) { func (c *changeBuilder) unmarshallRawChange(raw *treechangeproto.RawTreeChange, id string) (ch *Change, err error) {
if c.rootChange.Id == id { var key crypto.PubKey
if c.isRoot(id) {
unmarshalled := &treechangeproto.RootChange{} unmarshalled := &treechangeproto.RootChange{}
err = proto.Unmarshal(raw.Payload, unmarshalled) err = proto.Unmarshal(raw.Payload, unmarshalled)
if err != nil { if err != nil {
return return
} }
ch = NewChangeFromRoot(id, unmarshalled, raw.Signature) key, err = c.keys.PubKeyFromProto(unmarshalled.Identity)
if err != nil {
return
}
ch = NewChangeFromRoot(id, key, unmarshalled, raw.Signature)
return return
} }
unmarshalled := &treechangeproto.TreeChange{} unmarshalled := &treechangeproto.TreeChange{}
err = proto.Unmarshal(raw.Payload, unmarshalled) err = proto.Unmarshal(raw.Payload, unmarshalled)
if err != nil { if err != nil {
return return
} }
key, err = c.keys.PubKeyFromProto(unmarshalled.Identity)
ch = NewChange(id, unmarshalled, raw.Signature) if err != nil {
return
}
ch = c.newChange(id, key, unmarshalled, raw.Signature)
return return
} }
func (c *changeBuilder) isRoot(id string) bool {
if c.rootChange != nil {
return c.rootChange.Id == id
}
return false
}

View File

@ -14,39 +14,47 @@ type historyTree struct {
*objectTree *objectTree
} }
func (h *historyTree) rebuildFromStorage(beforeId string, include bool) (err error) { func (h *historyTree) rebuildFromStorage(params HistoryTreeParams) (err error) {
ot := h.objectTree err = h.rebuild(params)
ot.treeBuilder.Reset() if err != nil {
if beforeId == ot.Id() && !include { return
}
h.aclList.RLock()
defer h.aclList.RUnlock()
state := h.aclList.AclState()
return h.readKeysFromAclState(state)
}
func (h *historyTree) rebuild(params HistoryTreeParams) (err error) {
var (
beforeId = params.BeforeId
include = params.IncludeBeforeId
full = params.BuildFullTree
)
h.treeBuilder.Reset()
if full {
h.tree, err = h.treeBuilder.BuildFull()
return
}
if beforeId == h.Id() && !include {
return ErrLoadBeforeRoot return ErrLoadBeforeRoot
} }
heads := []string{beforeId} heads := []string{beforeId}
if beforeId == "" { if beforeId == "" {
heads, err = ot.treeStorage.Heads() heads, err = h.treeStorage.Heads()
if err != nil { if err != nil {
return return
} }
} else if !include { } else if !include {
beforeChange, err := ot.treeBuilder.loadChange(beforeId) beforeChange, err := h.treeBuilder.loadChange(beforeId)
if err != nil { if err != nil {
return err return err
} }
heads = beforeChange.PreviousIds heads = beforeChange.PreviousIds
} }
ot.tree, err = ot.treeBuilder.build(heads, nil, nil) h.tree, err = h.treeBuilder.build(heads, nil, nil)
if err != nil {
return
}
ot.aclList.RLock()
defer ot.aclList.RUnlock()
state := ot.aclList.AclState()
if len(ot.keys) != len(state.UserReadKeys()) {
for key, value := range state.UserReadKeys() {
ot.keys[key] = value
}
}
return return
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anytypeio/any-sync/commonspace/object/tree/objecttree (interfaces: ObjectTree) // Source: github.com/anyproto/any-sync/commonspace/object/tree/objecttree (interfaces: ObjectTree)
// Package mock_objecttree is a generated GoMock package. // Package mock_objecttree is a generated GoMock package.
package mock_objecttree package mock_objecttree
@ -7,11 +7,13 @@ package mock_objecttree
import ( import (
context "context" context "context"
reflect "reflect" reflect "reflect"
time "time"
objecttree "github.com/anytypeio/any-sync/commonspace/object/tree/objecttree" list "github.com/anyproto/any-sync/commonspace/object/acl/list"
treechangeproto "github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto" objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
treestorage "github.com/anytypeio/any-sync/commonspace/object/tree/treestorage" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
gomock "github.com/golang/mock/gomock" treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
gomock "go.uber.org/mock/gomock"
) )
// MockObjectTree is a mock of ObjectTree interface. // MockObjectTree is a mock of ObjectTree interface.
@ -37,6 +39,20 @@ func (m *MockObjectTree) EXPECT() *MockObjectTreeMockRecorder {
return m.recorder return m.recorder
} }
// AclList mocks base method.
func (m *MockObjectTree) AclList() list.AclList {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AclList")
ret0, _ := ret[0].(list.AclList)
return ret0
}
// AclList indicates an expected call of AclList.
func (mr *MockObjectTreeMockRecorder) AclList() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AclList", reflect.TypeOf((*MockObjectTree)(nil).AclList))
}
// AddContent mocks base method. // AddContent mocks base method.
func (m *MockObjectTree) AddContent(arg0 context.Context, arg1 objecttree.SignableChangeContent) (objecttree.AddResult, error) { func (m *MockObjectTree) AddContent(arg0 context.Context, arg1 objecttree.SignableChangeContent) (objecttree.AddResult, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -67,6 +83,20 @@ func (mr *MockObjectTreeMockRecorder) AddRawChanges(arg0, arg1 interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChanges", reflect.TypeOf((*MockObjectTree)(nil).AddRawChanges), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChanges", reflect.TypeOf((*MockObjectTree)(nil).AddRawChanges), arg0, arg1)
} }
// ChangeInfo mocks base method.
func (m *MockObjectTree) ChangeInfo() *treechangeproto.TreeChangeInfo {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ChangeInfo")
ret0, _ := ret[0].(*treechangeproto.TreeChangeInfo)
return ret0
}
// ChangeInfo indicates an expected call of ChangeInfo.
func (mr *MockObjectTreeMockRecorder) ChangeInfo() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeInfo", reflect.TypeOf((*MockObjectTree)(nil).ChangeInfo))
}
// ChangesAfterCommonSnapshot mocks base method. // ChangesAfterCommonSnapshot mocks base method.
func (m *MockObjectTree) ChangesAfterCommonSnapshot(arg0, arg1 []string) ([]*treechangeproto.RawTreeChangeWithId, error) { func (m *MockObjectTree) ChangesAfterCommonSnapshot(arg0, arg1 []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -96,19 +126,19 @@ func (mr *MockObjectTreeMockRecorder) Close() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectTree)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectTree)(nil).Close))
} }
// DebugDump mocks base method. // Debug mocks base method.
func (m *MockObjectTree) DebugDump(arg0 objecttree.DescriptionParser) (string, error) { func (m *MockObjectTree) Debug(arg0 objecttree.DescriptionParser) (objecttree.DebugInfo, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DebugDump", arg0) ret := m.ctrl.Call(m, "Debug", arg0)
ret0, _ := ret[0].(string) ret0, _ := ret[0].(objecttree.DebugInfo)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// DebugDump indicates an expected call of DebugDump. // Debug indicates an expected call of Debug.
func (mr *MockObjectTreeMockRecorder) DebugDump(arg0 interface{}) *gomock.Call { func (mr *MockObjectTreeMockRecorder) Debug(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugDump", reflect.TypeOf((*MockObjectTree)(nil).DebugDump), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockObjectTree)(nil).Debug), arg0)
} }
// Delete mocks base method. // Delete mocks base method.
@ -228,6 +258,20 @@ func (mr *MockObjectTreeMockRecorder) IterateRoot(arg0, arg1 interface{}) *gomoc
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateRoot", reflect.TypeOf((*MockObjectTree)(nil).IterateRoot), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateRoot", reflect.TypeOf((*MockObjectTree)(nil).IterateRoot), arg0, arg1)
} }
// Len mocks base method.
func (m *MockObjectTree) Len() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Len")
ret0, _ := ret[0].(int)
return ret0
}
// Len indicates an expected call of Len.
func (mr *MockObjectTreeMockRecorder) Len() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockObjectTree)(nil).Len))
}
// Lock mocks base method. // Lock mocks base method.
func (m *MockObjectTree) Lock() { func (m *MockObjectTree) Lock() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -240,6 +284,21 @@ func (mr *MockObjectTreeMockRecorder) Lock() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockObjectTree)(nil).Lock)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockObjectTree)(nil).Lock))
} }
// PrepareChange mocks base method.
func (m *MockObjectTree) PrepareChange(arg0 objecttree.SignableChangeContent) (*treechangeproto.RawTreeChangeWithId, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PrepareChange", arg0)
ret0, _ := ret[0].(*treechangeproto.RawTreeChangeWithId)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PrepareChange indicates an expected call of PrepareChange.
func (mr *MockObjectTreeMockRecorder) PrepareChange(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareChange", reflect.TypeOf((*MockObjectTree)(nil).PrepareChange), arg0)
}
// RLock mocks base method. // RLock mocks base method.
func (m *MockObjectTree) RLock() { func (m *MockObjectTree) RLock() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -306,6 +365,49 @@ func (mr *MockObjectTreeMockRecorder) Storage() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockObjectTree)(nil).Storage)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockObjectTree)(nil).Storage))
} }
// TryClose mocks base method.
func (m *MockObjectTree) TryClose(arg0 time.Duration) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryClose", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// TryClose indicates an expected call of TryClose.
func (mr *MockObjectTreeMockRecorder) TryClose(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockObjectTree)(nil).TryClose), arg0)
}
// TryLock mocks base method.
func (m *MockObjectTree) TryLock() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryLock")
ret0, _ := ret[0].(bool)
return ret0
}
// TryLock indicates an expected call of TryLock.
func (mr *MockObjectTreeMockRecorder) TryLock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryLock", reflect.TypeOf((*MockObjectTree)(nil).TryLock))
}
// TryRLock mocks base method.
func (m *MockObjectTree) TryRLock() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryRLock")
ret0, _ := ret[0].(bool)
return ret0
}
// TryRLock indicates an expected call of TryRLock.
func (mr *MockObjectTreeMockRecorder) TryRLock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryRLock", reflect.TypeOf((*MockObjectTree)(nil).TryRLock))
}
// Unlock mocks base method. // Unlock mocks base method.
func (m *MockObjectTree) Unlock() { func (m *MockObjectTree) Unlock() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -331,3 +433,18 @@ func (mr *MockObjectTreeMockRecorder) UnmarshalledHeader() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalledHeader", reflect.TypeOf((*MockObjectTree)(nil).UnmarshalledHeader)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalledHeader", reflect.TypeOf((*MockObjectTree)(nil).UnmarshalledHeader))
} }
// UnpackChange mocks base method.
func (m *MockObjectTree) UnpackChange(arg0 *treechangeproto.RawTreeChangeWithId) ([]byte, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UnpackChange", arg0)
ret0, _ := ret[0].([]byte)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UnpackChange indicates an expected call of UnpackChange.
func (mr *MockObjectTreeMockRecorder) UnpackChange(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackChange", reflect.TypeOf((*MockObjectTree)(nil).UnpackChange), arg0)
}

View File

@ -1,29 +1,33 @@
//go:generate mockgen -destination mock_objecttree/mock_objecttree.go github.com/anytypeio/any-sync/commonspace/object/tree/objecttree ObjectTree //go:generate mockgen -destination mock_objecttree/mock_objecttree.go github.com/anyproto/any-sync/commonspace/object/tree/objecttree ObjectTree
package objecttree package objecttree
import ( import (
"context" "context"
"errors" "errors"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
list "github.com/anytypeio/any-sync/commonspace/object/acl/list"
"github.com/anytypeio/any-sync/commonspace/object/keychain"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
"github.com/anytypeio/any-sync/util/keys/symmetric"
"github.com/anytypeio/any-sync/util/slice"
"sync" "sync"
"time"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/slice"
) )
type RWLocker interface { type RWLocker interface {
sync.Locker sync.Locker
RLock() RLock()
RUnlock() RUnlock()
TryRLock() bool
TryLock() bool
} }
var ( var (
ErrHasInvalidChanges = errors.New("the change is invalid") ErrHasInvalidChanges = errors.New("the change is invalid")
ErrNoCommonSnapshot = errors.New("trees doesn't have a common snapshot") ErrNoCommonSnapshot = errors.New("trees doesn't have a common snapshot")
ErrNoChangeInTree = errors.New("no such change in tree") ErrNoChangeInTree = errors.New("no such change in tree")
ErrMissingKey = errors.New("missing current read key")
) )
type AddResultSummary int type AddResultSummary int
@ -50,13 +54,17 @@ type ReadableObjectTree interface {
Id() string Id() string
Header() *treechangeproto.RawTreeChangeWithId Header() *treechangeproto.RawTreeChangeWithId
UnmarshalledHeader() *Change UnmarshalledHeader() *Change
ChangeInfo() *treechangeproto.TreeChangeInfo
Heads() []string Heads() []string
Root() *Change Root() *Change
Len() int
AclList() list.AclList
HasChanges(...string) bool HasChanges(...string) bool
GetChange(string) (*Change, error) GetChange(string) (*Change, error)
DebugDump(parser DescriptionParser) (string, error) Debug(parser DescriptionParser) (DebugInfo, error)
IterateRoot(convert ChangeConvertFunc, iterate ChangeIterateFunc) error IterateRoot(convert ChangeConvertFunc, iterate ChangeIterateFunc) error
IterateFrom(id string, convert ChangeConvertFunc, iterate ChangeIterateFunc) error IterateFrom(id string, convert ChangeConvertFunc, iterate ChangeIterateFunc) error
} }
@ -72,8 +80,12 @@ type ObjectTree interface {
AddContent(ctx context.Context, content SignableChangeContent) (AddResult, error) AddContent(ctx context.Context, content SignableChangeContent) (AddResult, error)
AddRawChanges(ctx context.Context, changes RawChangesPayload) (AddResult, error) AddRawChanges(ctx context.Context, changes RawChangesPayload) (AddResult, error)
UnpackChange(raw *treechangeproto.RawTreeChangeWithId) (data []byte, err error)
PrepareChange(content SignableChangeContent) (res *treechangeproto.RawTreeChangeWithId, err error)
Delete() error Delete() error
Close() error Close() error
TryClose(objectTTL time.Duration) (bool, error)
} }
type objectTree struct { type objectTree struct {
@ -89,7 +101,8 @@ type objectTree struct {
root *Change root *Change
tree *Tree tree *Tree
keys map[uint64]*symmetric.Key keys map[string]crypto.SymKey
currentReadKey crypto.SymKey
// buffers // buffers
difSnapshotBuf []*treechangeproto.RawTreeChangeWithId difSnapshotBuf []*treechangeproto.RawTreeChangeWithId
@ -102,41 +115,26 @@ type objectTree struct {
sync.RWMutex sync.RWMutex
} }
type objectTreeDeps struct {
changeBuilder ChangeBuilder
treeBuilder *treeBuilder
treeStorage treestorage.TreeStorage
validator ObjectTreeValidator
rawChangeLoader *rawChangeLoader
aclList list.AclList
}
func defaultObjectTreeDeps(
rootChange *treechangeproto.RawTreeChangeWithId,
treeStorage treestorage.TreeStorage,
aclList list.AclList) objectTreeDeps {
keychain := keychain.NewKeychain()
changeBuilder := NewChangeBuilder(keychain, rootChange)
treeBuilder := newTreeBuilder(treeStorage, changeBuilder)
return objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: treeBuilder,
treeStorage: treeStorage,
validator: newTreeValidator(),
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
aclList: aclList,
}
}
func (ot *objectTree) rebuildFromStorage(theirHeads []string, newChanges []*Change) (err error) { func (ot *objectTree) rebuildFromStorage(theirHeads []string, newChanges []*Change) (err error) {
oldTree := ot.tree
ot.treeBuilder.Reset() ot.treeBuilder.Reset()
ot.tree, err = ot.treeBuilder.Build(theirHeads, newChanges) ot.tree, err = ot.treeBuilder.Build(theirHeads, newChanges)
if err != nil { if err != nil {
return return
} }
// in case there are new heads
if theirHeads != nil && oldTree != nil {
// checking that old root is still in tree
rootCh, rootExists := ot.tree.attached[oldTree.RootId()]
// checking the case where theirHeads were actually below prevHeads
// so if we did load some extra data in the tree, let's reduce it to old root
if slice.UnsortedEquals(oldTree.headIds, ot.tree.headIds) && rootExists && ot.tree.RootId() != oldTree.RootId() {
ot.tree.makeRootAndRemove(rootCh)
}
}
// during building the tree we may have marked some changes as possible roots, // during building the tree we may have marked some changes as possible roots,
// but obviously they are not roots, because of the way how we construct the tree // but obviously they are not roots, because of the way how we construct the tree
ot.tree.clearPossibleRoots() ot.tree.clearPossibleRoots()
@ -150,6 +148,14 @@ func (ot *objectTree) Id() string {
return ot.id return ot.id
} }
func (ot *objectTree) Len() int {
return ot.tree.Len()
}
func (ot *objectTree) AclList() list.AclList {
return ot.aclList
}
func (ot *objectTree) Header() *treechangeproto.RawTreeChangeWithId { func (ot *objectTree) Header() *treechangeproto.RawTreeChangeWithId {
return ot.rawRoot return ot.rawRoot
} }
@ -158,6 +164,10 @@ func (ot *objectTree) UnmarshalledHeader() *Change {
return ot.root return ot.root
} }
func (ot *objectTree) ChangeInfo() *treechangeproto.TreeChangeInfo {
return ot.root.Model.(*treechangeproto.TreeChangeInfo)
}
func (ot *objectTree) Storage() treestorage.TreeStorage { func (ot *objectTree) Storage() treestorage.TreeStorage {
return ot.treeStorage return ot.treeStorage
} }
@ -179,7 +189,7 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
oldHeads := make([]string, 0, len(ot.tree.Heads())) oldHeads := make([]string, 0, len(ot.tree.Heads()))
oldHeads = append(oldHeads, ot.tree.Heads()...) oldHeads = append(oldHeads, ot.tree.Heads()...)
objChange, rawChange, err := ot.changeBuilder.BuildContent(payload) objChange, rawChange, err := ot.changeBuilder.Build(payload)
if content.IsSnapshot { if content.IsSnapshot {
// clearing tree, because we already saved everything in the last snapshot // clearing tree, because we already saved everything in the last snapshot
ot.tree = &Tree{} ot.tree = &Tree{}
@ -189,7 +199,7 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
panic(err) panic(err)
} }
err = ot.treeStorage.TransactionAdd([]*treechangeproto.RawTreeChangeWithId{rawChange}, []string{objChange.Id}) err = ot.treeStorage.AddRawChangesSetHeads([]*treechangeproto.RawTreeChangeWithId{rawChange}, []string{objChange.Id})
if err != nil { if err != nil {
return return
} }
@ -210,39 +220,61 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
return return
} }
func (ot *objectTree) UnpackChange(raw *treechangeproto.RawTreeChangeWithId) (data []byte, err error) {
unmarshalled, err := ot.changeBuilder.Unmarshall(raw, true)
if err != nil {
return
}
data = unmarshalled.Data
return
}
func (ot *objectTree) PrepareChange(content SignableChangeContent) (res *treechangeproto.RawTreeChangeWithId, err error) {
payload, err := ot.prepareBuilderContent(content)
if err != nil {
return
}
_, res, err = ot.changeBuilder.Build(payload)
return
}
func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt BuilderContent, err error) { func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt BuilderContent, err error) {
ot.aclList.RLock() ot.aclList.RLock()
defer ot.aclList.RUnlock() defer ot.aclList.RUnlock()
var ( var (
state = ot.aclList.AclState() // special method for own keys state = ot.aclList.AclState() // special method for own keys
readKey *symmetric.Key readKey crypto.SymKey
readKeyHash uint64 pubKey = content.Key.GetPublic()
readKeyId string
) )
canWrite := state.HasPermission(content.Identity, aclrecordproto.AclUserPermissions_Writer) || if !state.Permissions(pubKey).CanWrite() {
state.HasPermission(content.Identity, aclrecordproto.AclUserPermissions_Admin)
if !canWrite {
err = list.ErrInsufficientPermissions err = list.ErrInsufficientPermissions
return return
} }
if content.IsEncrypted { if content.IsEncrypted {
readKeyHash = state.CurrentReadKeyHash() readKeyId = state.CurrentReadKeyId()
readKey, err = state.CurrentReadKey() if ot.currentReadKey == nil {
if err != nil { err = ErrMissingKey
return return
} }
readKey = ot.currentReadKey
}
timestamp := content.Timestamp
if timestamp <= 0 {
timestamp = time.Now().Unix()
} }
cnt = BuilderContent{ cnt = BuilderContent{
TreeHeadIds: ot.tree.Heads(), TreeHeadIds: ot.tree.Heads(),
AclHeadId: ot.aclList.Head().Id, AclHeadId: ot.aclList.Head().Id,
SnapshotBaseId: ot.tree.RootId(), SnapshotBaseId: ot.tree.RootId(),
CurrentReadKeyHash: readKeyHash, ReadKeyId: readKeyId,
Identity: content.Identity,
IsSnapshot: content.IsSnapshot, IsSnapshot: content.IsSnapshot,
SigningKey: content.Key, PrivKey: content.Key,
ReadKey: readKey, ReadKey: readKey,
Content: content.Data, Content: content.Data,
Timestamp: timestamp,
} }
return return
} }
@ -262,7 +294,11 @@ func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChang
addResult.Mode = Rebuild addResult.Mode = Rebuild
} }
err = ot.treeStorage.TransactionAdd(addResult.Added, addResult.Heads) err = ot.treeStorage.AddRawChangesSetHeads(addResult.Added, addResult.Heads)
if err != nil {
// rolling back all changes made to inmemory state
ot.rebuildFromStorage(nil, nil)
}
return return
} }
@ -293,7 +329,7 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
if unAttached, exists := ot.tree.unAttached[ch.Id]; exists { if unAttached, exists := ot.tree.unAttached[ch.Id]; exists {
change = unAttached change = unAttached
} else { } else {
change, err = ot.changeBuilder.ConvertFromRaw(ch, true) change, err = ot.changeBuilder.Unmarshall(ch, true)
if err != nil { if err != nil {
return return
} }
@ -325,7 +361,7 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
} }
// checks if we need to go to database // checks if we need to go to database
isOldSnapshot := func(ch *Change) bool { snapshotNotInTree := func(ch *Change) bool {
if ch.SnapshotId == ot.tree.RootId() { if ch.SnapshotId == ot.tree.RootId() {
return false return false
} }
@ -340,26 +376,12 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
shouldRebuildFromStorage := false shouldRebuildFromStorage := false
// checking if we have some changes with different snapshot and then rebuilding // checking if we have some changes with different snapshot and then rebuilding
for idx, ch := range ot.newChangesBuf { for _, ch := range ot.newChangesBuf {
if isOldSnapshot(ch) { if snapshotNotInTree(ch) {
var exists bool
// checking if it exists in the storage, if yes, then at some point it was added to the tree
// thus we don't need to look at this change
exists, err = ot.treeStorage.HasChange(ctx, ch.Id)
if err != nil {
return
}
if exists {
// marking as nil to delete after
ot.newChangesBuf[idx] = nil
continue
}
// we haven't seen the change, and it refers to old snapshot, so we should rebuild
shouldRebuildFromStorage = true shouldRebuildFromStorage = true
break
} }
} }
// discarding all previously seen changes
ot.newChangesBuf = slice.DiscardFromSlice(ot.newChangesBuf, func(ch *Change) bool { return ch == nil })
if shouldRebuildFromStorage { if shouldRebuildFromStorage {
err = ot.rebuildFromStorage(changesPayload.NewHeads, ot.newChangesBuf) err = ot.rebuildFromStorage(changesPayload.NewHeads, ot.newChangesBuf)
@ -444,7 +466,7 @@ func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesA
// if we got some changes that we need to convert to raw // if we got some changes that we need to convert to raw
if _, exists := alreadyConverted[ch]; !exists { if _, exists := alreadyConverted[ch]; !exists {
var raw *treechangeproto.RawTreeChangeWithId var raw *treechangeproto.RawTreeChangeWithId
raw, err = ot.changeBuilder.BuildRaw(ch) raw, err = ot.changeBuilder.Marshall(ch)
if err != nil { if err != nil {
return return
} }
@ -456,6 +478,11 @@ func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesA
var added []*treechangeproto.RawTreeChangeWithId var added []*treechangeproto.RawTreeChangeWithId
added, err = getAddedChanges(treeChangesAdded) added, err = getAddedChanges(treeChangesAdded)
if !ot.treeBuilder.keepInMemoryData {
for _, ch := range treeChangesAdded {
ch.Data = nil
}
}
if err != nil { if err != nil {
return return
} }
@ -479,11 +506,11 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
} }
decrypt := func(c *Change) (decrypted []byte, err error) { decrypt := func(c *Change) (decrypted []byte, err error) {
// the change is not encrypted // the change is not encrypted
if c.ReadKeyHash == 0 { if c.ReadKeyId == "" {
decrypted = c.Data decrypted = c.Data
return return
} }
readKey, exists := ot.keys[c.ReadKeyHash] readKey, exists := ot.keys[c.ReadKeyId]
if !exists { if !exists {
err = list.ErrNoReadKey err = list.ErrNoReadKey
return return
@ -522,22 +549,8 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
} }
func (ot *objectTree) HasChanges(chs ...string) bool { func (ot *objectTree) HasChanges(chs ...string) bool {
hasChange := func(s string) bool {
_, attachedExists := ot.tree.attached[s]
if attachedExists {
return attachedExists
}
has, err := ot.treeStorage.HasChange(context.Background(), s)
if err != nil {
return false
}
return has
}
for _, ch := range chs { for _, ch := range chs {
if !hasChange(ch) { if _, attachedExists := ot.tree.attached[ch]; !attachedExists {
return false return false
} }
} }
@ -553,6 +566,10 @@ func (ot *objectTree) Root() *Change {
return ot.tree.Root() return ot.tree.Root()
} }
func (ot *objectTree) TryClose(objectTTL time.Duration) (bool, error) {
return true, ot.Close()
}
func (ot *objectTree) Close() error { func (ot *objectTree) Close() error {
return nil return nil
} }
@ -600,19 +617,7 @@ func (ot *objectTree) ChangesAfterCommonSnapshot(theirPath, theirHeads []string)
} }
} }
if commonSnapshot == ot.tree.RootId() { return ot.rawChangeLoader.Load(commonSnapshot, ot.tree, theirHeads)
return ot.getChangesFromTree(theirHeads)
} else {
return ot.getChangesFromDB(commonSnapshot, theirHeads)
}
}
func (ot *objectTree) getChangesFromTree(theirHeads []string) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
return ot.rawChangeLoader.LoadFromTree(ot.tree, theirHeads)
}
func (ot *objectTree) getChangesFromDB(commonSnapshot string, theirHeads []string) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
return ot.rawChangeLoader.LoadFromStorage(commonSnapshot, ot.tree.headIds, theirHeads)
} }
func (ot *objectTree) snapshotPathIsActual() bool { func (ot *objectTree) snapshotPathIsActual() bool {
@ -624,11 +629,9 @@ func (ot *objectTree) validateTree(newChanges []*Change) error {
defer ot.aclList.RUnlock() defer ot.aclList.RUnlock()
state := ot.aclList.AclState() state := ot.aclList.AclState()
// just not to take lock many times, updating the key map from aclList err := ot.readKeysFromAclState(state)
if len(ot.keys) != len(state.UserReadKeys()) { if err != nil {
for key, value := range state.UserReadKeys() { return err
ot.keys[key] = value
}
} }
if len(newChanges) == 0 { if len(newChanges) == 0 {
return ot.validator.ValidateFullTree(ot.tree, ot.aclList) return ot.validator.ValidateFullTree(ot.tree, ot.aclList)
@ -637,6 +640,26 @@ func (ot *objectTree) validateTree(newChanges []*Change) error {
return ot.validator.ValidateNewChanges(ot.tree, ot.aclList, newChanges) return ot.validator.ValidateNewChanges(ot.tree, ot.aclList, newChanges)
} }
func (ot *objectTree) DebugDump(parser DescriptionParser) (string, error) { func (ot *objectTree) readKeysFromAclState(state *list.AclState) (err error) {
return ot.tree.Graph(parser) // just not to take lock many times, updating the key map from aclList
if len(ot.keys) == len(state.UserReadKeys()) {
return nil
}
for key, value := range state.UserReadKeys() {
treeKey, err := deriveTreeKey(value, ot.id)
if err != nil {
return err
}
ot.keys[key] = treeKey
}
curKey, err := state.CurrentReadKey()
if err != nil {
return err
}
ot.currentReadKey, err = deriveTreeKey(curKey, ot.id)
return err
}
func (ot *objectTree) Debug(parser DescriptionParser) (DebugInfo, error) {
return objectTreeDebug{}.debugInfo(ot, parser)
} }

View File

@ -2,151 +2,74 @@ package objecttree
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/commonspace/object/acl/list" "fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder" "testing"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto" "time"
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
) )
type mockChangeCreator struct{}
func (c *mockChangeCreator) createRoot(id, aclId string) *treechangeproto.RawTreeChangeWithId {
aclChange := &treechangeproto.RootChange{
AclHeadId: aclId,
}
res, _ := aclChange.Marshal()
raw := &treechangeproto.RawTreeChange{
Payload: res,
Signature: nil,
}
rawMarshalled, _ := raw.Marshal()
return &treechangeproto.RawTreeChangeWithId{
RawChange: rawMarshalled,
Id: id,
}
}
func (c *mockChangeCreator) createRaw(id, aclId, snapshotId string, isSnapshot bool, prevIds ...string) *treechangeproto.RawTreeChangeWithId {
aclChange := &treechangeproto.TreeChange{
TreeHeadIds: prevIds,
AclHeadId: aclId,
SnapshotBaseId: snapshotId,
ChangesData: nil,
IsSnapshot: isSnapshot,
}
res, _ := aclChange.Marshal()
raw := &treechangeproto.RawTreeChange{
Payload: res,
Signature: nil,
}
rawMarshalled, _ := raw.Marshal()
return &treechangeproto.RawTreeChangeWithId{
RawChange: rawMarshalled,
Id: id,
}
}
func (c *mockChangeCreator) createNewTreeStorage(treeId, aclHeadId string) treestorage.TreeStorage {
root := c.createRoot(treeId, aclHeadId)
treeStorage, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root})
return treeStorage
}
type mockChangeBuilder struct {
originalBuilder ChangeBuilder
}
func (c *mockChangeBuilder) BuildInitialContent(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
panic("implement me")
}
func (c *mockChangeBuilder) SetRootRawChange(rawIdChange *treechangeproto.RawTreeChangeWithId) {
c.originalBuilder.SetRootRawChange(rawIdChange)
}
func (c *mockChangeBuilder) ConvertFromRaw(rawChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
return c.originalBuilder.ConvertFromRaw(rawChange, false)
}
func (c *mockChangeBuilder) BuildContent(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
panic("implement me")
}
func (c *mockChangeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return c.originalBuilder.BuildRaw(ch)
}
type mockChangeValidator struct{}
func (m *mockChangeValidator) ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error {
return nil
}
func (m *mockChangeValidator) ValidateFullTree(tree *Tree, aclList list.AclList) error {
return nil
}
type testTreeContext struct { type testTreeContext struct {
aclList list.AclList aclList list.AclList
treeStorage treestorage.TreeStorage treeStorage treestorage.TreeStorage
changeBuilder *mockChangeBuilder changeCreator *MockChangeCreator
changeCreator *mockChangeCreator
objTree ObjectTree objTree ObjectTree
} }
func prepareAclList(t *testing.T) list.AclList { func prepareAclList(t *testing.T) (list.AclList, *accountdata.AccountKeys) {
st, err := acllistbuilder.NewListStorageWithTestName("userjoinexample.yml") randKeys, err := accountdata.NewRandom()
require.NoError(t, err, "building storage should not result in error") require.NoError(t, err)
aclList, err := list.NewTestDerivedAcl("spaceId", randKeys)
aclList, err := list.BuildAclList(st)
require.NoError(t, err, "building acl list should be without error") require.NoError(t, err, "building acl list should be without error")
return aclList return aclList, randKeys
} }
func prepareTreeDeps(aclList list.AclList) (*mockChangeCreator, objectTreeDeps) { func prepareHistoryTreeDeps(aclList list.AclList) (*MockChangeCreator, objectTreeDeps) {
changeCreator := &mockChangeCreator{} changeCreator := NewMockChangeCreator()
treeStorage := changeCreator.createNewTreeStorage("0", aclList.Head().Id) treeStorage := changeCreator.CreateNewTreeStorage("0", aclList.Head().Id)
root, _ := treeStorage.Root() root, _ := treeStorage.Root()
changeBuilder := &mockChangeBuilder{ changeBuilder := &nonVerifiableChangeBuilder{
originalBuilder: NewChangeBuilder(nil, root), ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
} }
deps := objectTreeDeps{ deps := objectTreeDeps{
changeBuilder: changeBuilder, changeBuilder: changeBuilder,
treeBuilder: newTreeBuilder(treeStorage, changeBuilder), treeBuilder: newTreeBuilder(true, treeStorage, changeBuilder),
treeStorage: treeStorage, treeStorage: treeStorage,
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder), rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
validator: &mockChangeValidator{}, validator: &noOpTreeValidator{},
aclList: aclList, aclList: aclList,
} }
return changeCreator, deps return changeCreator, deps
} }
func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext { func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext {
changeCreator := &mockChangeCreator{} return prepareContext(t, aclList, BuildTestableTree, nil)
treeStorage := changeCreator.createNewTreeStorage("0", aclList.Head().Id) }
root, _ := treeStorage.Root()
changeBuilder := &mockChangeBuilder{
originalBuilder: NewChangeBuilder(nil, root),
}
deps := objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: newTreeBuilder(treeStorage, changeBuilder),
treeStorage: treeStorage,
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
validator: &mockChangeValidator{},
aclList: aclList,
}
// check build func prepareEmptyDataTreeContext(t *testing.T, aclList list.AclList, additionalChanges func(changeCreator *MockChangeCreator) RawChangesPayload) testTreeContext {
objTree, err := buildObjectTree(deps) return prepareContext(t, aclList, BuildEmptyDataTestableTree, additionalChanges)
}
func prepareContext(
t *testing.T,
aclList list.AclList,
objTreeBuilder BuildObjectTreeFunc,
additionalChanges func(changeCreator *MockChangeCreator) RawChangesPayload) testTreeContext {
changeCreator := NewMockChangeCreator()
treeStorage := changeCreator.CreateNewTreeStorage("0", aclList.Head().Id)
if additionalChanges != nil {
payload := additionalChanges(changeCreator)
err := treeStorage.AddRawChangesSetHeads(payload.RawChanges, payload.NewHeads)
require.NoError(t, err)
}
objTree, err := objTreeBuilder(treeStorage, aclList)
require.NoError(t, err, "building tree should be without error") require.NoError(t, err, "building tree should be without error")
// check tree iterate // check tree iterate
@ -156,18 +79,71 @@ func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext {
return true return true
}) })
require.NoError(t, err, "iterate should be without error") require.NoError(t, err, "iterate should be without error")
if additionalChanges == nil {
assert.Equal(t, []string{"0"}, iterChangesId) assert.Equal(t, []string{"0"}, iterChangesId)
}
return testTreeContext{ return testTreeContext{
aclList: aclList, aclList: aclList,
treeStorage: treeStorage, treeStorage: treeStorage,
changeBuilder: changeBuilder,
changeCreator: changeCreator, changeCreator: changeCreator,
objTree: objTree, objTree: objTree,
} }
} }
func TestObjectTree(t *testing.T) { func TestObjectTree(t *testing.T) {
aclList := prepareAclList(t) aclList, keys := prepareAclList(t)
ctx := context.Background()
t.Run("add content", func(t *testing.T) {
root, err := CreateObjectTreeRoot(ObjectTreeCreatePayload{
PrivKey: keys.SignKey,
ChangeType: "changeType",
ChangePayload: nil,
SpaceId: "spaceId",
IsEncrypted: true,
}, aclList)
require.NoError(t, err)
store, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root})
oTree, err := BuildObjectTree(store, aclList)
require.NoError(t, err)
t.Run("0 timestamp is changed to current", func(t *testing.T) {
start := time.Now()
res, err := oTree.AddContent(ctx, SignableChangeContent{
Data: []byte("some"),
Key: keys.SignKey,
IsSnapshot: false,
IsEncrypted: true,
Timestamp: 0,
})
end := time.Now()
require.NoError(t, err)
require.Len(t, oTree.Heads(), 1)
require.Equal(t, res.Added[0].Id, oTree.Heads()[0])
ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true)
require.NoError(t, err)
require.GreaterOrEqual(t, start.Unix(), ch.Timestamp)
require.LessOrEqual(t, end.Unix(), ch.Timestamp)
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
})
t.Run("timestamp is set correctly", func(t *testing.T) {
someTs := time.Now().Add(time.Hour).Unix()
res, err := oTree.AddContent(ctx, SignableChangeContent{
Data: []byte("some"),
Key: keys.SignKey,
IsSnapshot: false,
IsEncrypted: true,
Timestamp: someTs,
})
require.NoError(t, err)
require.Len(t, oTree.Heads(), 1)
require.Equal(t, res.Added[0].Id, oTree.Heads()[0])
ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true)
require.NoError(t, err)
require.Equal(t, ch.Timestamp, someTs)
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
})
})
t.Run("add simple", func(t *testing.T) { t.Run("add simple", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList) ctx := prepareTreeContext(t, aclList)
@ -176,8 +152,8 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -221,7 +197,7 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("0", aclList.Head().Id, "", true, ""), changeCreator.CreateRaw("0", aclList.Head().Id, "", true, ""),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -245,7 +221,33 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
}
payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
RawChanges: rawChanges,
}
res, err := objTree.AddRawChanges(context.Background(), payload)
require.NoError(t, err, "adding changes should be without error")
// check result
assert.Equal(t, []string{"0"}, res.OldHeads)
assert.Equal(t, []string{"0"}, res.Heads)
assert.Equal(t, 0, len(res.Added))
assert.Equal(t, Nothing, res.Mode)
// check tree heads
assert.Equal(t, []string{"0"}, objTree.Heads())
})
t.Run("add not connected changes", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
objTree := ctx.objTree
// this change could in theory replace current snapshot, we should prevent that
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("2", aclList.Head().Id, "0", true, "1"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -271,10 +273,10 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
changeCreator.createRaw("4", aclList.Head().Id, "3", false, "3"), changeCreator.CreateRaw("4", aclList.Head().Id, "3", false, "3"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -321,9 +323,9 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -339,18 +341,195 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, true, objTree.(*objectTree).snapshotPathIsActual()) assert.Equal(t, true, objTree.(*objectTree).snapshotPathIsActual())
}) })
t.Run("test empty data tree", func(t *testing.T) {
t.Run("empty tree add", func(t *testing.T) {
ctx := prepareEmptyDataTreeContext(t, aclList, nil)
changeCreator := ctx.changeCreator
objTree := ctx.objTree
rawChangesFirst := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRawWithData("1", aclList.Head().Id, "0", false, []byte("1"), "0"),
changeCreator.CreateRawWithData("2", aclList.Head().Id, "0", false, []byte("2"), "1"),
changeCreator.CreateRawWithData("3", aclList.Head().Id, "0", false, []byte("3"), "2"),
}
rawChangesSecond := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRawWithData("4", aclList.Head().Id, "0", false, []byte("4"), "2"),
changeCreator.CreateRawWithData("5", aclList.Head().Id, "0", false, []byte("5"), "1"),
changeCreator.CreateRawWithData("6", aclList.Head().Id, "0", false, []byte("6"), "3", "4", "5"),
}
// making them to be saved in unattached
_, err := objTree.AddRawChanges(context.Background(), RawChangesPayload{
NewHeads: []string{"6"},
RawChanges: rawChangesSecond,
})
require.NoError(t, err, "adding changes should be without error")
// attaching them
res, err := objTree.AddRawChanges(context.Background(), RawChangesPayload{
NewHeads: []string{"3"},
RawChanges: rawChangesFirst,
})
require.NoError(t, err, "adding changes should be without error")
require.Equal(t, "0", objTree.Root().Id)
require.Equal(t, []string{"6"}, objTree.Heads())
require.Equal(t, 6, len(res.Added))
// checking that added changes still have data
for _, ch := range res.Added {
unmarshallRaw := &treechangeproto.RawTreeChange{}
proto.Unmarshal(ch.RawChange, unmarshallRaw)
treeCh := &treechangeproto.TreeChange{}
proto.Unmarshal(unmarshallRaw.Payload, treeCh)
require.Equal(t, ch.Id, string(treeCh.ChangesData))
}
// checking that the tree doesn't have data in memory
err = objTree.IterateRoot(nil, func(change *Change) bool {
if change.Id == "0" {
return true
}
require.Nil(t, change.Data)
return true
})
})
t.Run("empty tree load", func(t *testing.T) {
ctx := prepareEmptyDataTreeContext(t, aclList, func(changeCreator *MockChangeCreator) RawChangesPayload {
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRawWithData("1", aclList.Head().Id, "0", false, []byte("1"), "0"),
changeCreator.CreateRawWithData("2", aclList.Head().Id, "0", false, []byte("2"), "1"),
changeCreator.CreateRawWithData("3", aclList.Head().Id, "0", false, []byte("3"), "2"),
changeCreator.CreateRawWithData("4", aclList.Head().Id, "0", false, []byte("4"), "2"),
changeCreator.CreateRawWithData("5", aclList.Head().Id, "0", false, []byte("5"), "1"),
changeCreator.CreateRawWithData("6", aclList.Head().Id, "0", false, []byte("6"), "3", "4", "5"),
}
return RawChangesPayload{NewHeads: []string{"6"}, RawChanges: rawChanges}
})
ctx.objTree.IterateRoot(nil, func(change *Change) bool {
if change.Id == "0" {
return true
}
require.Nil(t, change.Data)
return true
})
rawChanges, err := ctx.objTree.ChangesAfterCommonSnapshot([]string{"0"}, []string{"6"})
require.NoError(t, err)
for _, ch := range rawChanges {
unmarshallRaw := &treechangeproto.RawTreeChange{}
proto.Unmarshal(ch.RawChange, unmarshallRaw)
treeCh := &treechangeproto.TreeChange{}
proto.Unmarshal(unmarshallRaw.Payload, treeCh)
require.Equal(t, ch.Id, string(treeCh.ChangesData))
}
})
})
t.Run("rollback when add to storage returns error", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
objTree := ctx.objTree
store := ctx.treeStorage.(*treestorage.InMemoryTreeStorage)
addErr := fmt.Errorf("error saving")
store.SetReturnErrorOnAdd(addErr)
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
}
payload := RawChangesPayload{
NewHeads: []string{"1"},
RawChanges: rawChanges,
}
_, err := objTree.AddRawChanges(context.Background(), payload)
require.Error(t, err, addErr)
require.Equal(t, "0", objTree.Root().Id)
})
t.Run("their heads before common snapshot", func(t *testing.T) {
// checking that adding old changes did not affect the tree
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
changeCreator.CreateRaw("5", aclList.Head().Id, "1", false, "1"),
changeCreator.CreateRaw("6", aclList.Head().Id, "1", true, "3", "4", "5"),
}
payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
RawChanges: rawChanges,
}
_, err := objTree.AddRawChanges(context.Background(), payload)
require.NoError(t, err, "adding changes should be without error")
require.Equal(t, "6", objTree.Root().Id)
rawChangesPrevious := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
}
payload = RawChangesPayload{
NewHeads: []string{"1"},
RawChanges: rawChangesPrevious,
}
_, err = objTree.AddRawChanges(context.Background(), payload)
require.NoError(t, err, "adding changes should be without error")
require.Equal(t, "6", objTree.Root().Id)
})
t.Run("stored changes will not break the pipeline if heads were not updated", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
objTree := ctx.objTree
store := ctx.treeStorage.(*treestorage.InMemoryTreeStorage)
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
}
payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
RawChanges: rawChanges,
}
_, err := objTree.AddRawChanges(context.Background(), payload)
require.NoError(t, err, "adding changes should be without error")
require.Equal(t, "1", objTree.Root().Id)
// creating changes to save in the storage
// to imitate the condition where all changes are in the storage
// but the head was not updated
storageChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
changeCreator.CreateRaw("5", aclList.Head().Id, "1", false, "1"),
changeCreator.CreateRaw("6", aclList.Head().Id, "1", true, "3", "4", "5"),
}
store.AddRawChangesSetHeads(storageChanges, []string{"1"})
// updating with subset of those changes to see that everything will still work
payload = RawChangesPayload{
NewHeads: []string{"6"},
RawChanges: storageChanges,
}
_, err = objTree.AddRawChanges(context.Background(), payload)
require.NoError(t, err, "adding changes should be without error")
require.Equal(t, "6", objTree.Root().Id)
})
t.Run("changes from tree after common snapshot complex", func(t *testing.T) { t.Run("changes from tree after common snapshot complex", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList) ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator changeCreator := ctx.changeCreator
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"), changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
@ -424,13 +603,13 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
// main difference from tree example // main difference from tree example
changeCreator.createRaw("6", aclList.Head().Id, "0", true, "3", "4", "5"), changeCreator.CreateRaw("6", aclList.Head().Id, "0", true, "3", "4", "5"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
@ -505,9 +684,9 @@ func TestObjectTree(t *testing.T) {
objTree := ctx.objTree objTree := ctx.objTree
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
} }
payload := RawChangesPayload{ payload := RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -519,9 +698,9 @@ func TestObjectTree(t *testing.T) {
require.Equal(t, "3", objTree.Root().Id) require.Equal(t, "3", objTree.Root().Id)
rawChanges = []*treechangeproto.RawTreeChangeWithId{ rawChanges = []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"), changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
} }
payload = RawChangesPayload{ payload = RawChangesPayload{
NewHeads: []string{rawChanges[len(rawChanges)-1].Id}, NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
@ -562,17 +741,17 @@ func TestObjectTree(t *testing.T) {
}) })
t.Run("test history tree not include", func(t *testing.T) { t.Run("test history tree not include", func(t *testing.T) {
changeCreator, deps := prepareTreeDeps(aclList) changeCreator, deps := prepareHistoryTreeDeps(aclList)
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"), changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
} }
deps.treeStorage.TransactionAdd(rawChanges, []string{"6"}) deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
hTree, err := buildHistoryTree(deps, HistoryTreeParams{ hTree, err := buildHistoryTree(deps, HistoryTreeParams{
BeforeId: "6", BeforeId: "6",
IncludeBeforeId: false, IncludeBeforeId: false,
@ -592,18 +771,49 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, "0", hTree.Root().Id) assert.Equal(t, "0", hTree.Root().Id)
}) })
t.Run("test history tree build full", func(t *testing.T) {
changeCreator, deps := prepareHistoryTreeDeps(aclList)
// sequence of snapshots: 5->1->0
rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
changeCreator.CreateRaw("5", aclList.Head().Id, "1", true, "3", "4"),
changeCreator.CreateRaw("6", aclList.Head().Id, "5", false, "5"),
}
deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
hTree, err := buildHistoryTree(deps, HistoryTreeParams{
BuildFullTree: true,
})
require.NoError(t, err)
// check tree heads
assert.Equal(t, []string{"6"}, hTree.Heads())
// check tree iterate
var iterChangesId []string
err = hTree.IterateFrom(hTree.Root().Id, nil, func(change *Change) bool {
iterChangesId = append(iterChangesId, change.Id)
return true
})
require.NoError(t, err, "iterate should be without error")
assert.Equal(t, []string{"0", "1", "2", "3", "4", "5", "6"}, iterChangesId)
assert.Equal(t, "0", hTree.Root().Id)
})
t.Run("test history tree include", func(t *testing.T) { t.Run("test history tree include", func(t *testing.T) {
changeCreator, deps := prepareTreeDeps(aclList) changeCreator, deps := prepareHistoryTreeDeps(aclList)
rawChanges := []*treechangeproto.RawTreeChangeWithId{ rawChanges := []*treechangeproto.RawTreeChangeWithId{
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"), changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"), changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"), changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"), changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"), changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
} }
deps.treeStorage.TransactionAdd(rawChanges, []string{"6"}) deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
hTree, err := buildHistoryTree(deps, HistoryTreeParams{ hTree, err := buildHistoryTree(deps, HistoryTreeParams{
BeforeId: "6", BeforeId: "6",
IncludeBeforeId: true, IncludeBeforeId: true,
@ -624,7 +834,7 @@ func TestObjectTree(t *testing.T) {
}) })
t.Run("test history tree root", func(t *testing.T) { t.Run("test history tree root", func(t *testing.T) {
_, deps := prepareTreeDeps(aclList) _, deps := prepareHistoryTreeDeps(aclList)
hTree, err := buildHistoryTree(deps, HistoryTreeParams{ hTree, err := buildHistoryTree(deps, HistoryTreeParams{
BeforeId: "0", BeforeId: "0",
IncludeBeforeId: true, IncludeBeforeId: true,
@ -643,4 +853,40 @@ func TestObjectTree(t *testing.T) {
assert.Equal(t, []string{"0"}, iterChangesId) assert.Equal(t, []string{"0"}, iterChangesId)
assert.Equal(t, "0", hTree.Root().Id) assert.Equal(t, "0", hTree.Root().Id)
}) })
t.Run("validate correct tree", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
rawChanges := []*treechangeproto.RawTreeChangeWithId{
ctx.objTree.Header(),
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
}
defaultObjectTreeDeps = nonVerifiableTreeDeps
err := ValidateRawTree(treestorage.TreeStorageCreatePayload{
RootRawChange: ctx.objTree.Header(),
Heads: []string{"3"},
Changes: rawChanges,
}, ctx.aclList)
require.NoError(t, err)
})
t.Run("fail to validate not connected tree", func(t *testing.T) {
ctx := prepareTreeContext(t, aclList)
changeCreator := ctx.changeCreator
rawChanges := []*treechangeproto.RawTreeChangeWithId{
ctx.objTree.Header(),
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
}
defaultObjectTreeDeps = nonVerifiableTreeDeps
err := ValidateRawTree(treestorage.TreeStorageCreatePayload{
RootRawChange: ctx.objTree.Header(),
Heads: []string{"3"},
Changes: rawChanges,
}, ctx.aclList)
require.Equal(t, ErrHasInvalidChanges, err)
})
} }

View File

@ -0,0 +1,25 @@
package objecttree
type objectTreeDebug struct {
}
type DebugInfo struct {
TreeLen int
TreeString string
Graphviz string
Heads []string
SnapshotPath []string
}
func (o objectTreeDebug) debugInfo(ot *objectTree, parser DescriptionParser) (di DebugInfo, err error) {
di = DebugInfo{}
di.Graphviz, err = ot.tree.Graph(parser)
if err != nil {
return
}
di.TreeString = ot.tree.String()
di.TreeLen = ot.tree.Len()
di.Heads = ot.Heads()
di.SnapshotPath = ot.SnapshotPath()
return
}

View File

@ -0,0 +1,250 @@
package objecttree
import (
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/crypto"
)
type ObjectTreeCreatePayload struct {
PrivKey crypto.PrivKey
ChangeType string
ChangePayload []byte
SpaceId string
IsEncrypted bool
Seed []byte
Timestamp int64
}
type HistoryTreeParams struct {
TreeStorage treestorage.TreeStorage
AclList list.AclList
BeforeId string
IncludeBeforeId bool
BuildFullTree bool
}
type objectTreeDeps struct {
changeBuilder ChangeBuilder
treeBuilder *treeBuilder
treeStorage treestorage.TreeStorage
validator ObjectTreeValidator
rawChangeLoader *rawChangeLoader
aclList list.AclList
}
type BuildObjectTreeFunc = func(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error)
var defaultObjectTreeDeps = verifiableTreeDeps
func verifiableTreeDeps(
rootChange *treechangeproto.RawTreeChangeWithId,
treeStorage treestorage.TreeStorage,
aclList list.AclList) objectTreeDeps {
changeBuilder := NewChangeBuilder(crypto.NewKeyStorage(), rootChange)
treeBuilder := newTreeBuilder(true, treeStorage, changeBuilder)
return objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: treeBuilder,
treeStorage: treeStorage,
validator: newTreeValidator(),
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
aclList: aclList,
}
}
func emptyDataTreeDeps(
rootChange *treechangeproto.RawTreeChangeWithId,
treeStorage treestorage.TreeStorage,
aclList list.AclList) objectTreeDeps {
changeBuilder := NewChangeBuilder(crypto.NewKeyStorage(), rootChange)
treeBuilder := newTreeBuilder(false, treeStorage, changeBuilder)
return objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: treeBuilder,
treeStorage: treeStorage,
validator: newTreeValidator(),
rawChangeLoader: newStorageLoader(treeStorage, changeBuilder),
aclList: aclList,
}
}
func nonVerifiableTreeDeps(
rootChange *treechangeproto.RawTreeChangeWithId,
treeStorage treestorage.TreeStorage,
aclList list.AclList) objectTreeDeps {
changeBuilder := &nonVerifiableChangeBuilder{NewChangeBuilder(newMockKeyStorage(), rootChange)}
treeBuilder := newTreeBuilder(true, treeStorage, changeBuilder)
return objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: treeBuilder,
treeStorage: treeStorage,
validator: &noOpTreeValidator{},
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
aclList: aclList,
}
}
func BuildEmptyDataObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
rootChange, err := treeStorage.Root()
if err != nil {
return nil, err
}
deps := emptyDataTreeDeps(rootChange, treeStorage, aclList)
return buildObjectTree(deps)
}
func BuildTestableTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
root, _ := treeStorage.Root()
changeBuilder := &nonVerifiableChangeBuilder{
ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
}
deps := objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: newTreeBuilder(true, treeStorage, changeBuilder),
treeStorage: treeStorage,
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
validator: &noOpTreeValidator{},
aclList: aclList,
}
return buildObjectTree(deps)
}
func BuildEmptyDataTestableTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
root, _ := treeStorage.Root()
changeBuilder := &nonVerifiableChangeBuilder{
ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
}
deps := objectTreeDeps{
changeBuilder: changeBuilder,
treeBuilder: newTreeBuilder(false, treeStorage, changeBuilder),
treeStorage: treeStorage,
rawChangeLoader: newStorageLoader(treeStorage, changeBuilder),
validator: &noOpTreeValidator{},
aclList: aclList,
}
return buildObjectTree(deps)
}
func BuildObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
rootChange, err := treeStorage.Root()
if err != nil {
return nil, err
}
deps := defaultObjectTreeDeps(rootChange, treeStorage, aclList)
return buildObjectTree(deps)
}
func BuildNonVerifiableHistoryTree(params HistoryTreeParams) (HistoryTree, error) {
rootChange, err := params.TreeStorage.Root()
if err != nil {
return nil, err
}
deps := nonVerifiableTreeDeps(rootChange, params.TreeStorage, params.AclList)
return buildHistoryTree(deps, params)
}
func BuildHistoryTree(params HistoryTreeParams) (HistoryTree, error) {
rootChange, err := params.TreeStorage.Root()
if err != nil {
return nil, err
}
deps := defaultObjectTreeDeps(rootChange, params.TreeStorage, params.AclList)
return buildHistoryTree(deps, params)
}
func CreateObjectTreeRoot(payload ObjectTreeCreatePayload, aclList list.AclList) (root *treechangeproto.RawTreeChangeWithId, err error) {
aclList.RLock()
aclHeadId := aclList.Head().Id
aclList.RUnlock()
if err != nil {
return
}
cnt := InitialContent{
AclHeadId: aclHeadId,
PrivKey: payload.PrivKey,
SpaceId: payload.SpaceId,
ChangeType: payload.ChangeType,
ChangePayload: payload.ChangePayload,
Timestamp: payload.Timestamp,
Seed: payload.Seed,
}
_, root, err = NewChangeBuilder(crypto.NewKeyStorage(), nil).BuildRoot(cnt)
return
}
func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) {
objTree := &objectTree{
id: deps.treeStorage.Id(),
treeStorage: deps.treeStorage,
treeBuilder: deps.treeBuilder,
validator: deps.validator,
aclList: deps.aclList,
changeBuilder: deps.changeBuilder,
rawChangeLoader: deps.rawChangeLoader,
keys: make(map[string]crypto.SymKey),
newChangesBuf: make([]*Change, 0, 10),
difSnapshotBuf: make([]*treechangeproto.RawTreeChangeWithId, 0, 10),
notSeenIdxBuf: make([]int, 0, 10),
newSnapshotsBuf: make([]*Change, 0, 10),
}
err := objTree.rebuildFromStorage(nil, nil)
if err != nil {
return nil, err
}
objTree.rawRoot, err = objTree.treeStorage.Root()
if err != nil {
return nil, err
}
// verifying root
header, err := objTree.changeBuilder.Unmarshall(objTree.rawRoot, true)
if err != nil {
return nil, err
}
objTree.root = header
return objTree, nil
}
func buildHistoryTree(deps objectTreeDeps, params HistoryTreeParams) (ht HistoryTree, err error) {
objTree := &objectTree{
id: deps.treeStorage.Id(),
treeStorage: deps.treeStorage,
treeBuilder: deps.treeBuilder,
validator: deps.validator,
aclList: deps.aclList,
changeBuilder: deps.changeBuilder,
rawChangeLoader: deps.rawChangeLoader,
keys: make(map[string]crypto.SymKey),
newChangesBuf: make([]*Change, 0, 10),
difSnapshotBuf: make([]*treechangeproto.RawTreeChangeWithId, 0, 10),
notSeenIdxBuf: make([]int, 0, 10),
newSnapshotsBuf: make([]*Change, 0, 10),
}
hTree := &historyTree{objectTree: objTree}
err = hTree.rebuildFromStorage(params)
if err != nil {
return nil, err
}
objTree.id = objTree.treeStorage.Id()
objTree.rawRoot, err = objTree.treeStorage.Root()
if err != nil {
return nil, err
}
header, err := objTree.changeBuilder.Unmarshall(objTree.rawRoot, false)
if err != nil {
return nil, err
}
objTree.root = header
return hTree, nil
}

View File

@ -1,9 +1,12 @@
package objecttree package objecttree
import ( import (
"context"
"fmt" "fmt"
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anytypeio/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/slice"
) )
type ObjectTreeValidator interface { type ObjectTreeValidator interface {
@ -13,6 +16,16 @@ type ObjectTreeValidator interface {
ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error
} }
type noOpTreeValidator struct{}
func (n *noOpTreeValidator) ValidateFullTree(tree *Tree, aclList list.AclList) error {
return nil
}
func (n *noOpTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error {
return nil
}
type objectTreeValidator struct{} type objectTreeValidator struct{}
func newTreeValidator() ObjectTreeValidator { func newTreeValidator() ObjectTreeValidator {
@ -39,20 +52,18 @@ func (v *objectTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclLis
func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c *Change) (err error) { func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c *Change) (err error) {
var ( var (
perm list.UserPermissionPair userState list.AclUserState
state = aclList.AclState() state = aclList.AclState()
) )
// checking if the user could write // checking if the user could write
perm, err = state.PermissionsAtRecord(c.AclHeadId, c.Identity) userState, err = state.StateAtRecord(c.AclHeadId, c.Identity)
if err != nil { if err != nil {
return return
} }
if !userState.Permissions.CanWrite() {
if perm.Permission != aclrecordproto.AclUserPermissions_Writer && perm.Permission != aclrecordproto.AclUserPermissions_Admin {
err = list.ErrInsufficientPermissions err = list.ErrInsufficientPermissions
return return
} }
if c.Id == tree.RootId() { if c.Id == tree.RootId() {
return return
} }
@ -75,3 +86,25 @@ func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c
} }
return return
} }
func ValidateRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (err error) {
treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, []string{payload.RootRawChange.Id}, nil)
if err != nil {
return
}
tree, err := BuildObjectTree(treeStorage, aclList)
if err != nil {
return
}
res, err := tree.AddRawChanges(context.Background(), RawChangesPayload{
NewHeads: payload.Heads,
RawChanges: payload.Changes,
})
if err != nil {
return
}
if !slice.UnsortedEquals(res.Heads, payload.Heads) {
return ErrHasInvalidChanges
}
return
}

View File

@ -2,15 +2,17 @@ package objecttree
import ( import (
"context" "context"
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
"github.com/anytypeio/any-sync/util/slice"
"time" "time"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/slice"
) )
type rawChangeLoader struct { type rawChangeLoader struct {
treeStorage treestorage.TreeStorage treeStorage treestorage.TreeStorage
changeBuilder ChangeBuilder changeBuilder ChangeBuilder
alwaysFromStorage bool
// buffers // buffers
idStack []string idStack []string
@ -21,6 +23,13 @@ type rawCacheEntry struct {
change *Change change *Change
rawChange *treechangeproto.RawTreeChangeWithId rawChange *treechangeproto.RawTreeChangeWithId
position int position int
removed bool
}
func newStorageLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader {
loader := newRawChangeLoader(treeStorage, changeBuilder)
loader.alwaysFromStorage = true
return loader
} }
func newRawChangeLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader { func newRawChangeLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader {
@ -30,7 +39,15 @@ func newRawChangeLoader(treeStorage treestorage.TreeStorage, changeBuilder Chang
} }
} }
func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) { func (r *rawChangeLoader) Load(commonSnapshot string, t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
if commonSnapshot == t.root.Id && !r.alwaysFromStorage {
return r.loadFromTree(t, breakpoints)
} else {
return r.loadFromStorage(commonSnapshot, t.Heads(), breakpoints)
}
}
func (r *rawChangeLoader) loadFromTree(t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
var stack []*Change var stack []*Change
for _, h := range t.headIds { for _, h := range t.headIds {
stack = append(stack, t.attached[h]) stack = append(stack, t.attached[h])
@ -39,7 +56,7 @@ func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treech
convert := func(chs []*Change) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) { convert := func(chs []*Change) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
for _, ch := range chs { for _, ch := range chs {
var raw *treechangeproto.RawTreeChangeWithId var raw *treechangeproto.RawTreeChangeWithId
raw, err = r.changeBuilder.BuildRaw(ch) raw, err = r.changeBuilder.Marshall(ch)
if err != nil { if err != nil {
return return
} }
@ -98,7 +115,7 @@ func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treech
return convert(results) return convert(results)
} }
func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) { func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
// resetting cache // resetting cache
r.cache = make(map[string]rawCacheEntry) r.cache = make(map[string]rawCacheEntry)
defer func() { defer func() {
@ -111,7 +128,6 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
if err != nil { if err != nil {
continue continue
} }
entry.position = -1
r.cache[b] = entry r.cache[b] = entry
existingBreakpoints = append(existingBreakpoints, b) existingBreakpoints = append(existingBreakpoints, b)
} }
@ -120,8 +136,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
dfs := func( dfs := func(
commonSnapshot string, commonSnapshot string,
heads []string, heads []string,
startCounter int, shouldVisit func(entry rawCacheEntry, mapExists bool) bool,
shouldVisit func(counter int, mapExists bool) bool,
visit func(entry rawCacheEntry) rawCacheEntry) bool { visit func(entry rawCacheEntry) rawCacheEntry) bool {
// resetting stack // resetting stack
@ -135,7 +150,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
r.idStack = r.idStack[:len(r.idStack)-1] r.idStack = r.idStack[:len(r.idStack)-1]
entry, exists := r.cache[id] entry, exists := r.cache[id]
if !shouldVisit(entry.position, exists) { if !shouldVisit(entry, exists) {
continue continue
} }
if id == commonSnapshot { if id == commonSnapshot {
@ -144,7 +159,6 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
} }
if !exists { if !exists {
entry, err = r.loadEntry(id) entry, err = r.loadEntry(id)
entry.position = -1
if err != nil { if err != nil {
continue continue
} }
@ -159,7 +173,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
break break
} }
prevEntry, exists := r.cache[prev] prevEntry, exists := r.cache[prev]
if !shouldVisit(prevEntry.position, exists) { if !shouldVisit(prevEntry, exists) {
continue continue
} }
r.idStack = append(r.idStack, prev) r.idStack = append(r.idStack, prev)
@ -172,8 +186,8 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
r.idStack = append(r.idStack, heads...) r.idStack = append(r.idStack, heads...)
var buffer []*treechangeproto.RawTreeChangeWithId var buffer []*treechangeproto.RawTreeChangeWithId
rootVisited := dfs(commonSnapshot, heads, 0, rootVisited := dfs(commonSnapshot, heads,
func(counter int, mapExists bool) bool { func(_ rawCacheEntry, mapExists bool) bool {
return !mapExists return !mapExists
}, },
func(entry rawCacheEntry) rawCacheEntry { func(entry rawCacheEntry) rawCacheEntry {
@ -198,11 +212,13 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
} }
// marking all visited as nil // marking all visited as nil
dfs(commonSnapshot, existingBreakpoints, len(buffer), dfs(commonSnapshot, existingBreakpoints,
func(counter int, mapExists bool) bool { func(entry rawCacheEntry, mapExists bool) bool {
return !mapExists || counter < len(buffer) // only going through already loaded changes
return mapExists && !entry.removed
}, },
func(entry rawCacheEntry) rawCacheEntry { func(entry rawCacheEntry) rawCacheEntry {
entry.removed = true
if entry.position != -1 { if entry.position != -1 {
buffer[entry.position] = nil buffer[entry.position] = nil
} }
@ -226,13 +242,14 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error)
return return
} }
change, err := r.changeBuilder.ConvertFromRaw(rawChange, false) change, err := r.changeBuilder.Unmarshall(rawChange, false)
if err != nil { if err != nil {
return return
} }
entry = rawCacheEntry{ entry = rawCacheEntry{
change: change, change: change,
rawChange: rawChange, rawChange: rawChange,
position: -1,
} }
return return
} }

View File

@ -1,16 +0,0 @@
package objecttree
import (
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
)
func ValidateRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (err error) {
treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, payload.Heads, payload.Changes)
if err != nil {
return
}
_, err = BuildObjectTree(treeStorage, aclList)
return
}

View File

@ -1,13 +1,19 @@
package objecttree package objecttree
import ( import (
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey" "github.com/anyproto/any-sync/util/crypto"
) )
// SignableChangeContent is a payload to be passed when we are creating change
type SignableChangeContent struct { type SignableChangeContent struct {
// Data is a data provided by the client
Data []byte Data []byte
Key signingkey.PrivKey // Key is the key which will be used to sign the change
Identity []byte Key crypto.PrivKey
// IsSnapshot tells if the change has snapshot of all previous data
IsSnapshot bool IsSnapshot bool
// IsEncrypted tells if we encrypt the data with the relevant symmetric key
IsEncrypted bool IsEncrypted bool
// Timestamp is a timestamp of change, if it is <= 0, then we use current timestamp
Timestamp int64
} }

View File

@ -0,0 +1,121 @@
package objecttree
import (
"fmt"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/crypto"
libcrypto "github.com/libp2p/go-libp2p/core/crypto"
)
type mockPubKey struct {
}
const mockKeyValue = "mockKey"
func (m mockPubKey) Equals(key crypto.Key) bool {
return true
}
func (m mockPubKey) Raw() ([]byte, error) {
return []byte(mockKeyValue), nil
}
func (m mockPubKey) Encrypt(message []byte) ([]byte, error) {
return message, nil
}
func (m mockPubKey) Verify(data []byte, sig []byte) (bool, error) {
return true, nil
}
func (m mockPubKey) Marshall() ([]byte, error) {
return []byte(mockKeyValue), nil
}
func (m mockPubKey) Storage() []byte {
return []byte(mockKeyValue)
}
func (m mockPubKey) Account() string {
return mockKeyValue
}
func (m mockPubKey) Network() string {
return mockKeyValue
}
func (m mockPubKey) PeerId() string {
return mockKeyValue
}
func (m mockPubKey) LibP2P() (libcrypto.PubKey, error) {
return nil, fmt.Errorf("can't be converted in libp2p")
}
type mockKeyStorage struct {
}
func newMockKeyStorage() mockKeyStorage {
return mockKeyStorage{}
}
func (m mockKeyStorage) PubKeyFromProto(protoBytes []byte) (crypto.PubKey, error) {
return mockPubKey{}, nil
}
type MockChangeCreator struct{}
func NewMockChangeCreator() *MockChangeCreator {
return &MockChangeCreator{}
}
func (c *MockChangeCreator) CreateRoot(id, aclId string) *treechangeproto.RawTreeChangeWithId {
aclChange := &treechangeproto.RootChange{
AclHeadId: aclId,
}
res, _ := aclChange.Marshal()
raw := &treechangeproto.RawTreeChange{
Payload: res,
Signature: nil,
}
rawMarshalled, _ := raw.Marshal()
return &treechangeproto.RawTreeChangeWithId{
RawChange: rawMarshalled,
Id: id,
}
}
func (c *MockChangeCreator) CreateRaw(id, aclId, snapshotId string, isSnapshot bool, prevIds ...string) *treechangeproto.RawTreeChangeWithId {
return c.CreateRawWithData(id, aclId, snapshotId, isSnapshot, nil, prevIds...)
}
func (c *MockChangeCreator) CreateRawWithData(id, aclId, snapshotId string, isSnapshot bool, data []byte, prevIds ...string) *treechangeproto.RawTreeChangeWithId {
aclChange := &treechangeproto.TreeChange{
TreeHeadIds: prevIds,
AclHeadId: aclId,
SnapshotBaseId: snapshotId,
ChangesData: data,
IsSnapshot: isSnapshot,
}
res, _ := aclChange.Marshal()
raw := &treechangeproto.RawTreeChange{
Payload: res,
Signature: nil,
}
rawMarshalled, _ := raw.Marshal()
return &treechangeproto.RawTreeChangeWithId{
RawChange: rawMarshalled,
Id: id,
}
}
func (c *MockChangeCreator) CreateNewTreeStorage(treeId, aclHeadId string) treestorage.TreeStorage {
root := c.CreateRoot(treeId, aclHeadId)
treeStorage, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root})
return treeStorage
}

View File

@ -82,6 +82,7 @@ func (t *Tree) AddMergedHead(c *Change) error {
} }
} }
t.headIds = []string{c.Id} t.headIds = []string{c.Id}
t.lastIteratedHeadId = c.Id
return nil return nil
} }

Some files were not shown because too many files have changed in this diff Show More