Compare commits
No commits in common. "v0.0.4" and "main" have entirely different histories.
14
.github/dependabot.yml
vendored
Normal file
14
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,14 @@
|
||||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "gomod"
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
ignore:
|
||||
- dependency-name: "github.com/anyproto/go-chash"
|
||||
|
||||
30
.github/workflows/coverage.yml
vendored
30
.github/workflows/coverage.yml
vendored
@ -7,7 +7,7 @@ jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
GOPRIVATE: github.com/anytypeio
|
||||
GOPRIVATE: github.com/anyproto
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/setup-go@v3
|
||||
@ -17,20 +17,20 @@ jobs:
|
||||
- name: git config
|
||||
run: git config --global url.https://${{ secrets.ANYTYPE_PAT }}@github.com/.insteadOf https://github.com/
|
||||
|
||||
# cache {{
|
||||
- id: go-cache-paths
|
||||
run: |
|
||||
echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||
echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
- uses: actions/cache@v3
|
||||
with:
|
||||
path: |
|
||||
${{ steps.go-cache-paths.outputs.GOCACHE }}
|
||||
${{ steps.go-cache-paths.outputs.GOMODCACHE }}
|
||||
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-go-${{ matrix.go-version }}-
|
||||
# }}
|
||||
# # cache {{
|
||||
# - id: go-cache-paths
|
||||
# run: |
|
||||
# echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||
# echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||
# - uses: actions/cache@v3
|
||||
# with:
|
||||
# path: |
|
||||
# ${{ steps.go-cache-paths.outputs.GOCACHE }}
|
||||
# ${{ steps.go-cache-paths.outputs.GOMODCACHE }}
|
||||
# key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
|
||||
# restore-keys: |
|
||||
# ${{ runner.os }}-go-${{ matrix.go-version }}-
|
||||
# # }}
|
||||
|
||||
- name: deps
|
||||
run: make deps
|
||||
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Any Association
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
|
||||
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
|
||||
OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
12
Makefile
12
Makefile
@ -1,5 +1,5 @@
|
||||
.PHONY: proto test deps
|
||||
export GOPRIVATE=github.com/anytypeio
|
||||
export GOPRIVATE=github.com/anyproto
|
||||
export PATH:=deps:$(PATH)
|
||||
|
||||
proto:
|
||||
@ -7,14 +7,20 @@ proto:
|
||||
|
||||
@$(eval P_ACL_RECORDS_PATH_PB := commonspace/object/acl/aclrecordproto)
|
||||
@$(eval P_TREE_CHANGES_PATH_PB := commonspace/object/tree/treechangeproto)
|
||||
@$(eval P_ACL_RECORDS := M$(P_ACL_RECORDS_PATH_PB)/protos/aclrecord.proto=github.com/anytypeio/any-sync/$(P_ACL_RECORDS_PATH_PB))
|
||||
@$(eval P_TREE_CHANGES := M$(P_TREE_CHANGES_PATH_PB)/protos/treechange.proto=github.com/anytypeio/any-sync/$(P_TREE_CHANGES_PATH_PB))
|
||||
@$(eval P_CRYPTO_PATH_PB := util/crypto/cryptoproto)
|
||||
@$(eval P_ACL_RECORDS := M$(P_ACL_RECORDS_PATH_PB)/protos/aclrecord.proto=github.com/anyproto/any-sync/$(P_ACL_RECORDS_PATH_PB))
|
||||
@$(eval P_TREE_CHANGES := M$(P_TREE_CHANGES_PATH_PB)/protos/treechange.proto=github.com/anyproto/any-sync/$(P_TREE_CHANGES_PATH_PB))
|
||||
|
||||
protoc --gogofaster_out=:. $(P_ACL_RECORDS_PATH_PB)/protos/*.proto
|
||||
protoc --gogofaster_out=:. $(P_TREE_CHANGES_PATH_PB)/protos/*.proto
|
||||
protoc --gogofaster_out=:. $(P_CRYPTO_PATH_PB)/protos/*.proto
|
||||
$(eval PKGMAP := $$(P_TREE_CHANGES),$$(P_ACL_RECORDS))
|
||||
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonspace/spacesyncproto/protos/*.proto
|
||||
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. commonfile/fileproto/protos/*.proto
|
||||
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto
|
||||
protoc --gogofaster_out=:. net/secureservice/handshake/handshakeproto/protos/*.proto
|
||||
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. coordinator/coordinatorproto/protos/*.proto
|
||||
protoc --gogofaster_out=:. --go-drpc_out=protolib=github.com/gogo/protobuf:. consensus/consensusproto/protos/*.proto
|
||||
|
||||
deps:
|
||||
go mod download
|
||||
|
||||
43
README.md
Normal file
43
README.md
Normal file
@ -0,0 +1,43 @@
|
||||
# Any-Sync
|
||||
Any-Sync is an open-source protocol designed to create high-performance, local-first, peer-to-peer, end-to-end encrypted applications that facilitate seamless collaboration among multiple users and devices.
|
||||
|
||||
By utilizing this protocol, users can rest assured that they retain complete control over their data and digital experience. They are empowered to freely transition between various service providers, or even opt to self-host the applications.
|
||||
|
||||
This ensures utmost flexibility and autonomy for users in managing their personal information and digital interactions.
|
||||
|
||||
## Introduction
|
||||
Most existing information management tools are implemented on centralized client-server architecture or designed for an offline-first single-user usage. Either way there are trade-offs for users: they can face restricted freedoms and privacy violations or compromise on the functionality of tools to avoid this.
|
||||
|
||||
We believe this goes against fundamental digital freedoms and that a new generation of software is needed that will respect these freedoms, while providing best in-class user experience.
|
||||
|
||||
Our goal with `any-sync` is to develop a protocol that will enable the deployment of this software.
|
||||
|
||||
Features:
|
||||
- Conflict-free data replication across multiple devices and agents
|
||||
- Built-in end-to-end encryption
|
||||
- Cryptographically verifiable history of changes
|
||||
- Adoption to frequent operations (high performance)
|
||||
- Reliable and scalable infrastructure
|
||||
- Simultaneous support of p2p and remote communication
|
||||
|
||||
## Protocol explanation
|
||||
Plese read the [overview](https://tech.anytype.io/any-sync/overview) of protocol entities and design.
|
||||
|
||||
## Implementation
|
||||
|
||||
You can find the various parts of the protocol implemented in Go in the following repositories:
|
||||
- [`any-sync-node`](https://github.com/anyproto/any-sync-node) — implementation of a sync node responsible for storing spaces and objects.
|
||||
- [`any-sync-filenode`](https://github.com/anyproto/any-sync-filenode) — implementation of a file node responsible for storing files.
|
||||
- [`any-sync-coordinator`](https://github.com/anyproto/any-sync-coordinator) — implementation of a coordinator node responsible for network configuration management.
|
||||
|
||||
## Contribution
|
||||
Thank you for your desire to develop Anytype together.
|
||||
|
||||
Currently, we're not ready to accept PRs, but we will in the nearest future.
|
||||
|
||||
Follow us on [Github](https://github.com/anyproto) and join the [Contributors Community](https://github.com/orgs/anyproto/discussions).
|
||||
|
||||
---
|
||||
Made by Any — a Swiss association 🇨🇭
|
||||
|
||||
Licensed under [MIT License](./LICENSE).
|
||||
@ -1,23 +1,22 @@
|
||||
//go:generate mockgen -destination mock_accountservice/mock_accountservice.go github.com/anytypeio/any-sync/accountservice Service
|
||||
//go:generate mockgen -destination mock_accountservice/mock_accountservice.go github.com/anyproto/any-sync/accountservice Service
|
||||
package accountservice
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/app"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
)
|
||||
|
||||
const CName = "common.accountservice"
|
||||
|
||||
type Service interface {
|
||||
app.Component
|
||||
Account() *accountdata.AccountData
|
||||
Account() *accountdata.AccountKeys
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
PeerId string `yaml:"peerId"`
|
||||
PeerKey string `yaml:"peerKey"`
|
||||
SigningKey string `yaml:"signingKey"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
}
|
||||
|
||||
type ConfigGetter interface {
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
package mock_accountservice
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/accountservice"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/anyproto/any-sync/accountservice"
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountData) *MockService {
|
||||
func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountKeys) *MockService {
|
||||
mock := NewMockService(ctrl)
|
||||
mock.EXPECT().Name().Return(accountservice.CName).AnyTimes()
|
||||
mock.EXPECT().Init(gomock.Any()).AnyTimes()
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/accountservice (interfaces: Service)
|
||||
// Source: github.com/anyproto/any-sync/accountservice (interfaces: Service)
|
||||
|
||||
// Package mock_accountservice is a generated GoMock package.
|
||||
package mock_accountservice
|
||||
@ -7,9 +7,9 @@ package mock_accountservice
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
app "github.com/anytypeio/any-sync/app"
|
||||
accountdata "github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
app "github.com/anyproto/any-sync/app"
|
||||
accountdata "github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockService is a mock of Service interface.
|
||||
@ -36,10 +36,10 @@ func (m *MockService) EXPECT() *MockServiceMockRecorder {
|
||||
}
|
||||
|
||||
// Account mocks base method.
|
||||
func (m *MockService) Account() *accountdata.AccountData {
|
||||
func (m *MockService) Account() *accountdata.AccountKeys {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Account")
|
||||
ret0, _ := ret[0].(*accountdata.AccountData)
|
||||
ret0, _ := ret[0].(*accountdata.AccountKeys)
|
||||
return ret0
|
||||
}
|
||||
|
||||
|
||||
164
app/app.go
164
app/app.go
@ -4,10 +4,11 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"go.uber.org/zap"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
@ -15,12 +16,15 @@ import (
|
||||
|
||||
var (
|
||||
// values of this vars will be defined while compilation
|
||||
GitCommit, GitBranch, GitState, GitSummary, BuildDate string
|
||||
AppName, GitCommit, GitBranch, GitState, GitSummary, BuildDate string
|
||||
name string
|
||||
)
|
||||
|
||||
var (
|
||||
log = logger.NewNamed("app")
|
||||
StopDeadline = time.Minute
|
||||
StopWarningAfter = time.Second * 10
|
||||
StartWarningAfter = time.Second * 10
|
||||
)
|
||||
|
||||
// Component is a minimal interface for a common app.Component
|
||||
@ -51,10 +55,14 @@ type ComponentStatable interface {
|
||||
// App is the central part of the application
|
||||
// It contains and manages all components
|
||||
type App struct {
|
||||
parent *App
|
||||
components []Component
|
||||
mu sync.RWMutex
|
||||
startStat StartStat
|
||||
startStat Stat
|
||||
stopStat Stat
|
||||
deviceState int
|
||||
versionName string
|
||||
anySyncVersion string
|
||||
}
|
||||
|
||||
// Name returns app name
|
||||
@ -62,23 +70,47 @@ func (app *App) Name() string {
|
||||
return name
|
||||
}
|
||||
|
||||
func (app *App) AppName() string {
|
||||
return AppName
|
||||
}
|
||||
|
||||
// Version return app version
|
||||
func (app *App) Version() string {
|
||||
return GitSummary
|
||||
}
|
||||
|
||||
type StartStat struct {
|
||||
// SetVersionName sets the custom application version
|
||||
func (app *App) SetVersionName(v string) {
|
||||
app.versionName = v
|
||||
}
|
||||
|
||||
// VersionName returns a string with the settled app version or auto-generated version if it didn't set
|
||||
func (app *App) VersionName() string {
|
||||
if app.versionName != "" {
|
||||
return app.versionName
|
||||
}
|
||||
return AppName + ":" + GitSummary + "/any-sync:" + app.AnySyncVersion()
|
||||
}
|
||||
|
||||
type Stat struct {
|
||||
SpentMsPerComp map[string]int64
|
||||
SpentMsTotal int64
|
||||
}
|
||||
|
||||
// StartStat returns total time spent per comp
|
||||
func (app *App) StartStat() StartStat {
|
||||
// StartStat returns total time spent per comp for the last Start
|
||||
func (app *App) StartStat() Stat {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return app.startStat
|
||||
}
|
||||
|
||||
// StopStat returns total time spent per comp for the last Close
|
||||
func (app *App) StopStat() Stat {
|
||||
app.mu.Lock()
|
||||
defer app.mu.Unlock()
|
||||
return app.stopStat
|
||||
}
|
||||
|
||||
// VersionDescription return the full info about the build
|
||||
func (app *App) VersionDescription() string {
|
||||
return VersionDescription()
|
||||
@ -92,6 +124,16 @@ func VersionDescription() string {
|
||||
return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState)
|
||||
}
|
||||
|
||||
// ChildApp creates a child container which has access to parent's components
|
||||
// It doesn't call Start on any of the parent's components
|
||||
func (app *App) ChildApp() *App {
|
||||
return &App{
|
||||
parent: app,
|
||||
deviceState: app.deviceState,
|
||||
anySyncVersion: app.AnySyncVersion(),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds service to registry
|
||||
// All components will be started in the order they were registered
|
||||
func (app *App) Register(s Component) *App {
|
||||
@ -111,11 +153,15 @@ func (app *App) Register(s Component) *App {
|
||||
func (app *App) Component(name string) Component {
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
for _, s := range app.components {
|
||||
current := app
|
||||
for current != nil {
|
||||
for _, s := range current.components {
|
||||
if s.Name() == name {
|
||||
return s
|
||||
}
|
||||
}
|
||||
current = current.parent
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -132,11 +178,15 @@ func (app *App) MustComponent(name string) Component {
|
||||
func MustComponent[i any](app *App) i {
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
for _, s := range app.components {
|
||||
current := app
|
||||
for current != nil {
|
||||
for _, s := range current.components {
|
||||
if v, ok := s.(i); ok {
|
||||
return v
|
||||
}
|
||||
}
|
||||
current = current.parent
|
||||
}
|
||||
empty := new(i)
|
||||
panic(fmt.Errorf("component with interface %T is not found", empty))
|
||||
}
|
||||
@ -145,9 +195,13 @@ func MustComponent[i any](app *App) i {
|
||||
func (app *App) ComponentNames() (names []string) {
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
names = make([]string, len(app.components))
|
||||
for i, c := range app.components {
|
||||
names[i] = c.Name()
|
||||
names = make([]string, 0, len(app.components))
|
||||
current := app
|
||||
for current != nil {
|
||||
for _, c := range current.components {
|
||||
names = append(names, c.Name())
|
||||
}
|
||||
current = current.parent
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -158,7 +212,17 @@ func (app *App) Start(ctx context.Context) (err error) {
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
app.startStat.SpentMsPerComp = make(map[string]int64)
|
||||
|
||||
var currentComponentStarting string
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(StartWarningAfter):
|
||||
l := statLogger(app.stopStat, log).With(zap.String("in_progress", currentComponentStarting))
|
||||
l.Warn("components start in progress")
|
||||
}
|
||||
}()
|
||||
closeServices := func(idx int) {
|
||||
for i := idx; i >= 0; i-- {
|
||||
if serviceClose, ok := app.components[i].(ComponentRunnable); ok {
|
||||
@ -172,7 +236,7 @@ func (app *App) Start(ctx context.Context) (err error) {
|
||||
for i, s := range app.components {
|
||||
if err = s.Init(app); err != nil {
|
||||
closeServices(i)
|
||||
return fmt.Errorf("can't init service '%s': %v", s.Name(), err)
|
||||
return fmt.Errorf("can't init service '%s': %w", s.Name(), err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,17 +245,32 @@ func (app *App) Start(ctx context.Context) (err error) {
|
||||
start := time.Now()
|
||||
if err = serviceRun.Run(ctx); err != nil {
|
||||
closeServices(i)
|
||||
return fmt.Errorf("can't run service '%s': %v", serviceRun.Name(), err)
|
||||
return fmt.Errorf("can't run service '%s': %w", serviceRun.Name(), err)
|
||||
}
|
||||
spent := time.Since(start).Milliseconds()
|
||||
app.startStat.SpentMsTotal += spent
|
||||
app.startStat.SpentMsPerComp[s.Name()] = spent
|
||||
}
|
||||
}
|
||||
log.Debug("all components started")
|
||||
|
||||
close(done)
|
||||
l := statLogger(app.stopStat, log)
|
||||
if app.startStat.SpentMsTotal > StartWarningAfter.Milliseconds() {
|
||||
l.Warn("all components started")
|
||||
}
|
||||
l.Debug("all components started")
|
||||
return
|
||||
}
|
||||
|
||||
// IterateComponents iterates over all registered components. It's safe for concurrent use.
|
||||
func (app *App) IterateComponents(fn func(Component)) {
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
for _, s := range app.components {
|
||||
fn(s)
|
||||
}
|
||||
}
|
||||
|
||||
func stackAllGoroutines() []byte {
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
@ -203,18 +282,41 @@ func stackAllGoroutines() []byte {
|
||||
}
|
||||
}
|
||||
|
||||
func statLogger(stat Stat, ctxLogger logger.CtxLogger) logger.CtxLogger {
|
||||
l := ctxLogger
|
||||
for k, v := range stat.SpentMsPerComp {
|
||||
l = l.With(zap.Int64(k, v))
|
||||
}
|
||||
l = l.With(zap.Int64("total", stat.SpentMsTotal))
|
||||
|
||||
return l
|
||||
}
|
||||
|
||||
// Close stops the application
|
||||
// All components with ComponentRunnable implementation will be closed in the reversed order
|
||||
func (app *App) Close(ctx context.Context) error {
|
||||
log.Debug("close components...")
|
||||
app.mu.RLock()
|
||||
defer app.mu.RUnlock()
|
||||
app.stopStat.SpentMsPerComp = make(map[string]int64)
|
||||
var currentComponentStopping string
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(time.Minute):
|
||||
case <-time.After(StopWarningAfter):
|
||||
statLogger(app.stopStat, log).
|
||||
With(zap.String("in_progress", currentComponentStopping)).
|
||||
Warn("components close in progress")
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case <-time.After(StopDeadline):
|
||||
_, _ = os.Stderr.Write([]byte("app.Close timeout\n"))
|
||||
_, _ = os.Stderr.Write(stackAllGoroutines())
|
||||
panic("app.Close timeout")
|
||||
@ -224,16 +326,27 @@ func (app *App) Close(ctx context.Context) error {
|
||||
var errs []string
|
||||
for i := len(app.components) - 1; i >= 0; i-- {
|
||||
if serviceClose, ok := app.components[i].(ComponentRunnable); ok {
|
||||
start := time.Now()
|
||||
currentComponentStopping = app.components[i].Name()
|
||||
if e := serviceClose.Close(ctx); e != nil {
|
||||
errs = append(errs, fmt.Sprintf("Component '%s' close error: %v", serviceClose.Name(), e))
|
||||
}
|
||||
spent := time.Since(start).Milliseconds()
|
||||
app.stopStat.SpentMsTotal += spent
|
||||
app.stopStat.SpentMsPerComp[app.components[i].Name()] = spent
|
||||
}
|
||||
}
|
||||
close(done)
|
||||
if len(errs) > 0 {
|
||||
return errors.New(strings.Join(errs, "\n"))
|
||||
}
|
||||
log.Debug("all components have been closed")
|
||||
|
||||
l := statLogger(app.stopStat, log)
|
||||
if app.stopStat.SpentMsTotal > StopWarningAfter.Milliseconds() {
|
||||
l.Warn("all components have been closed")
|
||||
}
|
||||
|
||||
l.Debug("all components have been closed")
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -250,3 +363,20 @@ func (app *App) SetDeviceState(state int) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var onceVersion sync.Once
|
||||
|
||||
func (app *App) AnySyncVersion() string {
|
||||
onceVersion.Do(func() {
|
||||
info, ok := debug.ReadBuildInfo()
|
||||
if ok {
|
||||
for _, mod := range info.Deps {
|
||||
if mod.Path == "github.com/anyproto/any-sync" {
|
||||
app.anySyncVersion = mod.Version
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
return app.anySyncVersion
|
||||
}
|
||||
|
||||
@ -34,6 +34,40 @@ func TestAppServiceRegistry(t *testing.T) {
|
||||
names := app.ComponentNames()
|
||||
assert.Equal(t, names, []string{"c1", "r1", "s1"})
|
||||
})
|
||||
t.Run("Child MustComponent", func(t *testing.T) {
|
||||
app := app.ChildApp()
|
||||
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
|
||||
for _, name := range []string{"c1", "r1", "s1", "x1"} {
|
||||
assert.NotPanics(t, func() { app.MustComponent(name) }, name)
|
||||
}
|
||||
assert.Panics(t, func() { app.MustComponent("not-registered") })
|
||||
})
|
||||
t.Run("Child ComponentNames", func(t *testing.T) {
|
||||
app := app.ChildApp()
|
||||
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
|
||||
names := app.ComponentNames()
|
||||
assert.Equal(t, names, []string{"x1", "c1", "r1", "s1"})
|
||||
})
|
||||
t.Run("Child override", func(t *testing.T) {
|
||||
app := app.ChildApp()
|
||||
app.Register(newTestService(testTypeRunnable, "s1", nil, nil))
|
||||
_ = app.MustComponent("s1").(*testRunnable)
|
||||
})
|
||||
}
|
||||
|
||||
func TestApp_IterateComponents(t *testing.T) {
|
||||
app := new(App)
|
||||
|
||||
app.Register(newTestService(testTypeRunnable, "c1", nil, nil))
|
||||
app.Register(newTestService(testTypeRunnable, "r1", nil, nil))
|
||||
app.Register(newTestService(testTypeComponent, "s1", nil, nil))
|
||||
|
||||
var got []string
|
||||
app.IterateComponents(func(s Component) {
|
||||
got = append(got, s.Name())
|
||||
})
|
||||
|
||||
assert.ElementsMatch(t, []string{"c1", "r1", "s1"}, got)
|
||||
}
|
||||
|
||||
func TestAppStart(t *testing.T) {
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// Package ldiff provides a container of elements with fixed id and changeable content.
|
||||
// Diff can calculate the difference with another diff container (you can make it remote) with minimum hops and traffic.
|
||||
//
|
||||
//go:generate mockgen -destination mock_ldiff/mock_ldiff.go github.com/anytypeio/any-sync/app/ldiff Diff,Remote
|
||||
//go:generate mockgen -destination mock_ldiff/mock_ldiff.go github.com/anyproto/any-sync/app/ldiff Diff,Remote
|
||||
package ldiff
|
||||
|
||||
import (
|
||||
@ -51,7 +51,7 @@ var hashersPool = &sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
var ErrElementNotFound = errors.New("element not found")
|
||||
var ErrElementNotFound = errors.New("ldiff: element not found")
|
||||
|
||||
// Element of data
|
||||
type Element struct {
|
||||
@ -88,10 +88,14 @@ type Diff interface {
|
||||
Diff(ctx context.Context, dl Remote) (newIds, changedIds, removedIds []string, err error)
|
||||
// Elements retrieves all elements in the Diff
|
||||
Elements() []Element
|
||||
// Element returns an element by id
|
||||
Element(id string) (Element, error)
|
||||
// Ids retrieves ids of all elements in the Diff
|
||||
Ids() []string
|
||||
// Hash returns hash of all elements in the diff
|
||||
Hash() string
|
||||
// Len returns count of elements in the diff
|
||||
Len() int
|
||||
}
|
||||
|
||||
// Remote interface for using in the Diff
|
||||
@ -157,6 +161,12 @@ func (d *diff) Ids() (ids []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (d *diff) Len() int {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
return d.sl.Len()
|
||||
}
|
||||
|
||||
func (d *diff) Elements() (elements []Element) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
@ -172,6 +182,19 @@ func (d *diff) Elements() (elements []Element) {
|
||||
return
|
||||
}
|
||||
|
||||
func (d *diff) Element(id string) (Element, error) {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
el := d.sl.Get(&element{Element: Element{Id: id}, hash: xxhash.Sum64([]byte(id))})
|
||||
if el == nil {
|
||||
return Element{}, ErrElementNotFound
|
||||
}
|
||||
if e, ok := el.Key().(*element); ok {
|
||||
return e.Element, nil
|
||||
}
|
||||
return Element{}, ErrElementNotFound
|
||||
}
|
||||
|
||||
func (d *diff) Hash() string {
|
||||
d.mu.RLock()
|
||||
defer d.mu.RUnlock()
|
||||
|
||||
@ -3,10 +3,11 @@ package ldiff
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
"math"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@ -43,7 +44,7 @@ func TestDiff_Diff(t *testing.T) {
|
||||
d2 := New(16, 16)
|
||||
for i := 0; i < 1000; i++ {
|
||||
id := fmt.Sprint(i)
|
||||
head := bson.NewObjectId().Hex()
|
||||
head := uuid.NewString()
|
||||
d1.Set(Element{
|
||||
Id: id,
|
||||
Head: head,
|
||||
@ -91,7 +92,7 @@ func TestDiff_Diff(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
d2.Set(Element{
|
||||
Id: fmt.Sprint(i),
|
||||
Head: bson.NewObjectId().Hex(),
|
||||
Head: uuid.NewString(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -107,7 +108,7 @@ func TestDiff_Diff(t *testing.T) {
|
||||
for i := 0; i < 10; i++ {
|
||||
d2.Set(Element{
|
||||
Id: fmt.Sprint(i),
|
||||
Head: bson.NewObjectId().Hex(),
|
||||
Head: uuid.NewString(),
|
||||
})
|
||||
}
|
||||
var cancel func()
|
||||
@ -122,7 +123,7 @@ func BenchmarkDiff_Ranges(b *testing.B) {
|
||||
d := New(16, 16)
|
||||
for i := 0; i < 10000; i++ {
|
||||
id := fmt.Sprint(i)
|
||||
head := bson.NewObjectId().Hex()
|
||||
head := uuid.NewString()
|
||||
d.Set(Element{
|
||||
Id: id,
|
||||
Head: head,
|
||||
@ -148,3 +149,51 @@ func TestDiff_Hash(t *testing.T) {
|
||||
assert.NotEmpty(t, h2)
|
||||
assert.NotEqual(t, h1, h2)
|
||||
}
|
||||
|
||||
func TestDiff_Element(t *testing.T) {
|
||||
d := New(16, 16)
|
||||
for i := 0; i < 10; i++ {
|
||||
d.Set(Element{Id: fmt.Sprint("id", i), Head: fmt.Sprint("head", i)})
|
||||
}
|
||||
_, err := d.Element("not found")
|
||||
assert.Equal(t, ErrElementNotFound, err)
|
||||
|
||||
el, err := d.Element("id5")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "head5", el.Head)
|
||||
|
||||
d.Set(Element{"id5", "otherHead"})
|
||||
el, err = d.Element("id5")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "otherHead", el.Head)
|
||||
}
|
||||
|
||||
func TestDiff_Ids(t *testing.T) {
|
||||
d := New(16, 16)
|
||||
var ids []string
|
||||
for i := 0; i < 10; i++ {
|
||||
id := fmt.Sprint("id", i)
|
||||
d.Set(Element{Id: id, Head: fmt.Sprint("head", i)})
|
||||
ids = append(ids, id)
|
||||
}
|
||||
gotIds := d.Ids()
|
||||
sort.Strings(gotIds)
|
||||
assert.Equal(t, ids, gotIds)
|
||||
assert.Equal(t, len(ids), d.Len())
|
||||
}
|
||||
|
||||
func TestDiff_Elements(t *testing.T) {
|
||||
d := New(16, 16)
|
||||
var els []Element
|
||||
for i := 0; i < 10; i++ {
|
||||
id := fmt.Sprint("id", i)
|
||||
el := Element{Id: id, Head: fmt.Sprint("head", i)}
|
||||
d.Set(el)
|
||||
els = append(els, el)
|
||||
}
|
||||
gotEls := d.Elements()
|
||||
sort.Slice(gotEls, func(i, j int) bool {
|
||||
return gotEls[i].Id < gotEls[j].Id
|
||||
})
|
||||
assert.Equal(t, els, gotEls)
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/app/ldiff (interfaces: Diff,Remote)
|
||||
// Source: github.com/anyproto/any-sync/app/ldiff (interfaces: Diff,Remote)
|
||||
|
||||
// Package mock_ldiff is a generated GoMock package.
|
||||
package mock_ldiff
|
||||
@ -8,8 +8,8 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
ldiff "github.com/anytypeio/any-sync/app/ldiff"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
ldiff "github.com/anyproto/any-sync/app/ldiff"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDiff is a mock of Diff interface.
|
||||
@ -52,6 +52,21 @@ func (mr *MockDiffMockRecorder) Diff(arg0, arg1 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Diff", reflect.TypeOf((*MockDiff)(nil).Diff), arg0, arg1)
|
||||
}
|
||||
|
||||
// Element mocks base method.
|
||||
func (m *MockDiff) Element(arg0 string) (ldiff.Element, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Element", arg0)
|
||||
ret0, _ := ret[0].(ldiff.Element)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Element indicates an expected call of Element.
|
||||
func (mr *MockDiffMockRecorder) Element(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Element", reflect.TypeOf((*MockDiff)(nil).Element), arg0)
|
||||
}
|
||||
|
||||
// Elements mocks base method.
|
||||
func (m *MockDiff) Elements() []ldiff.Element {
|
||||
m.ctrl.T.Helper()
|
||||
@ -94,6 +109,20 @@ func (mr *MockDiffMockRecorder) Ids() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Ids", reflect.TypeOf((*MockDiff)(nil).Ids))
|
||||
}
|
||||
|
||||
// Len mocks base method.
|
||||
func (m *MockDiff) Len() int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Len")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Len indicates an expected call of Len.
|
||||
func (mr *MockDiffMockRecorder) Len() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockDiff)(nil).Len))
|
||||
}
|
||||
|
||||
// Ranges mocks base method.
|
||||
func (m *MockDiff) Ranges(arg0 context.Context, arg1 []ldiff.Range, arg2 []ldiff.RangeResult) ([]ldiff.RangeResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@ -1,33 +1,124 @@
|
||||
package logger
|
||||
|
||||
import "go.uber.org/zap"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
)
|
||||
|
||||
type LogFormat int
|
||||
|
||||
const (
|
||||
ColorizedOutput LogFormat = iota
|
||||
PlaintextOutput
|
||||
JSONOutput
|
||||
)
|
||||
|
||||
type NamedLevel struct {
|
||||
Name string `yaml:"name"`
|
||||
Level string `yaml:"level"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Production bool `yaml:"production"`
|
||||
DefaultLevel string `yaml:"defaultLevel"`
|
||||
NamedLevels map[string]string `yaml:"namedLevels"`
|
||||
Levels []NamedLevel `yaml:"levels"` // first match will be used
|
||||
AddOutputPaths []string `yaml:"outputPaths"`
|
||||
DisableStdErr bool `yaml:"disableStdErr"`
|
||||
Format LogFormat `yaml:"format"`
|
||||
ZapConfig *zap.Config `yaml:"-"` // optional, if set it will be used instead of other config options
|
||||
}
|
||||
|
||||
func (l Config) ApplyGlobal() {
|
||||
var conf zap.Config
|
||||
if l.ZapConfig != nil {
|
||||
conf = *l.ZapConfig
|
||||
} else {
|
||||
if l.Production {
|
||||
conf = zap.NewProductionConfig()
|
||||
} else {
|
||||
conf = zap.NewDevelopmentConfig()
|
||||
}
|
||||
encConfig := conf.EncoderConfig
|
||||
switch l.Format {
|
||||
case PlaintextOutput:
|
||||
encConfig.EncodeLevel = zapcore.CapitalLevelEncoder
|
||||
conf.Encoding = "console"
|
||||
case JSONOutput:
|
||||
encConfig.MessageKey = "msg"
|
||||
encConfig.TimeKey = "ts"
|
||||
encConfig.LevelKey = "level"
|
||||
encConfig.NameKey = "logger"
|
||||
encConfig.CallerKey = "caller"
|
||||
encConfig.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
conf.Encoding = "json"
|
||||
default:
|
||||
// default is ColorizedOutput
|
||||
conf.Encoding = "console"
|
||||
encConfig.EncodeLevel = zapcore.CapitalColorLevelEncoder
|
||||
}
|
||||
|
||||
conf.EncoderConfig = encConfig
|
||||
if len(l.AddOutputPaths) > 0 {
|
||||
conf.OutputPaths = append(conf.OutputPaths, l.AddOutputPaths...)
|
||||
}
|
||||
if l.DisableStdErr {
|
||||
conf.OutputPaths = slice.Filter(conf.OutputPaths, func(path string) bool {
|
||||
return path != "stderr"
|
||||
})
|
||||
}
|
||||
|
||||
if defaultLevel, err := zap.ParseAtomicLevel(l.DefaultLevel); err == nil {
|
||||
conf.Level = defaultLevel
|
||||
}
|
||||
var levels = make(map[string]zap.AtomicLevel)
|
||||
for k, v := range l.NamedLevels {
|
||||
if lev, err := zap.ParseAtomicLevel(v); err != nil {
|
||||
levels[k] = lev
|
||||
}
|
||||
for _, v := range l.Levels {
|
||||
if lev, err := zap.ParseAtomicLevel(v.Level); err == nil {
|
||||
// we need to have a minimum level of all named loggers for the main logger
|
||||
if lev.Level() < conf.Level.Level() {
|
||||
conf.Level.SetLevel(lev.Level())
|
||||
}
|
||||
}
|
||||
defaultLogger, err := conf.Build()
|
||||
}
|
||||
|
||||
lg, err := conf.Build()
|
||||
if err != nil {
|
||||
Default().Fatal("can't build logger", zap.Error(err))
|
||||
}
|
||||
SetDefault(defaultLogger)
|
||||
SetNamedLevels(levels)
|
||||
SetDefault(lg)
|
||||
SetNamedLevels(l.Levels)
|
||||
}
|
||||
|
||||
// LevelsFromStr parses a string of the form "name1=DEBUG;prefix*=WARN;*=ERROR" into a slice of NamedLevel
|
||||
// it may be useful to parse the log level from the OS env var
|
||||
func LevelsFromStr(s string) (levels []NamedLevel) {
|
||||
for _, kv := range strings.Split(s, ";") {
|
||||
strings.TrimSpace(kv)
|
||||
parts := strings.Split(kv, "=")
|
||||
var key, value string
|
||||
if len(parts) == 1 {
|
||||
key = "*"
|
||||
value = parts[0]
|
||||
_, err := zap.ParseAtomicLevel(value)
|
||||
if err != nil {
|
||||
fmt.Printf("Can't parse log level %s: %s\n", parts[0], err.Error())
|
||||
continue
|
||||
}
|
||||
levels = append(levels, NamedLevel{Name: key, Level: value})
|
||||
} else if len(parts) == 2 {
|
||||
key = parts[0]
|
||||
value = parts[1]
|
||||
}
|
||||
_, err := zap.ParseAtomicLevel(value)
|
||||
if err != nil {
|
||||
fmt.Printf("Can't parse log level %s: %s\n", parts[0], err.Error())
|
||||
continue
|
||||
}
|
||||
levels = append(levels, NamedLevel{Name: key, Level: value})
|
||||
}
|
||||
return levels
|
||||
}
|
||||
|
||||
60
app/logger/ctxfiled.go
Normal file
60
app/logger/ctxfiled.go
Normal file
@ -0,0 +1,60 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type ctxKey uint
|
||||
|
||||
const (
|
||||
ctxKeyFields ctxKey = iota
|
||||
)
|
||||
|
||||
func WithCtx(ctx context.Context, l *zap.Logger) *zap.Logger {
|
||||
return l.With(CtxGetFields(ctx)...)
|
||||
}
|
||||
|
||||
func CtxWithFields(ctx context.Context, fields ...zap.Field) context.Context {
|
||||
existingFields := CtxGetFields(ctx)
|
||||
if existingFields != nil {
|
||||
existingFields = append(existingFields, fields...)
|
||||
}
|
||||
return context.WithValue(ctx, ctxKeyFields, fields)
|
||||
}
|
||||
|
||||
func CtxGetFields(ctx context.Context) (fields []zap.Field) {
|
||||
if v := ctx.Value(ctxKeyFields); v != nil {
|
||||
return v.([]zap.Field)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
type CtxLogger struct {
|
||||
*zap.Logger
|
||||
name string
|
||||
}
|
||||
|
||||
func (cl CtxLogger) DebugCtx(ctx context.Context, msg string, fields ...zap.Field) {
|
||||
cl.Logger.Debug(msg, append(CtxGetFields(ctx), fields...)...)
|
||||
}
|
||||
|
||||
func (cl CtxLogger) InfoCtx(ctx context.Context, msg string, fields ...zap.Field) {
|
||||
cl.Logger.Info(msg, append(CtxGetFields(ctx), fields...)...)
|
||||
}
|
||||
|
||||
func (cl CtxLogger) WarnCtx(ctx context.Context, msg string, fields ...zap.Field) {
|
||||
cl.Logger.Warn(msg, append(CtxGetFields(ctx), fields...)...)
|
||||
}
|
||||
|
||||
func (cl CtxLogger) ErrorCtx(ctx context.Context, msg string, fields ...zap.Field) {
|
||||
cl.Logger.Error(msg, append(CtxGetFields(ctx), fields...)...)
|
||||
}
|
||||
|
||||
func (cl CtxLogger) With(fields ...zap.Field) CtxLogger {
|
||||
return CtxLogger{cl.Logger.With(fields...), cl.name}
|
||||
}
|
||||
|
||||
func (cl CtxLogger) Sugar() *zap.SugaredLogger {
|
||||
return NewNamedSugared(cl.name)
|
||||
}
|
||||
@ -1,50 +1,137 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
|
||||
"github.com/gobwas/glob"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
mu sync.Mutex
|
||||
defaultLogger *zap.Logger
|
||||
levels = make(map[string]zap.AtomicLevel)
|
||||
loggers = make(map[string]*zap.Logger)
|
||||
logger *zap.Logger
|
||||
loggerConfig zap.Config
|
||||
namedLevels []namedLevel
|
||||
namedGlobs = make(map[string]glob.Glob)
|
||||
namedLoggers = make(map[string]CtxLogger)
|
||||
namedSugarLoggers = make(map[string]*zap.SugaredLogger)
|
||||
)
|
||||
|
||||
func init() {
|
||||
defaultLogger, _ = zap.NewDevelopment()
|
||||
zap.NewProduction()
|
||||
type namedLevel struct {
|
||||
name string
|
||||
level zap.AtomicLevel
|
||||
}
|
||||
|
||||
func init() {
|
||||
loggerConfig = zap.NewDevelopmentConfig()
|
||||
logger, _ = loggerConfig.Build()
|
||||
}
|
||||
|
||||
// SetDefault replaces the default logger
|
||||
// you need to call SetNamedLevels after in case you have named loggers,
|
||||
// otherwise they will use the old logger
|
||||
func SetDefault(l *zap.Logger) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
*defaultLogger = *l
|
||||
for name, l := range loggers {
|
||||
*l = *defaultLogger.Named(name)
|
||||
}
|
||||
*logger = *l
|
||||
}
|
||||
|
||||
func SetNamedLevels(l map[string]zap.AtomicLevel) {
|
||||
// SetNamedLevels sets the namedLevels for named loggers
|
||||
// it also supports glob patterns for names, like "app*"
|
||||
// can be racy in case there are existing named loggers
|
||||
// so consider to call only once at the beginning
|
||||
func SetNamedLevels(nls []NamedLevel) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
levels = l
|
||||
namedLevels = namedLevels[:0]
|
||||
|
||||
var minLevel = logger.Level()
|
||||
for _, nl := range nls {
|
||||
l, err := zap.ParseAtomicLevel(nl.Level)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
namedLevels = append(namedLevels, namedLevel{name: nl.Name, level: l})
|
||||
g, err := glob.Compile(nl.Name)
|
||||
if err == nil {
|
||||
namedGlobs[nl.Name] = g
|
||||
}
|
||||
|
||||
if l.Level() < minLevel {
|
||||
minLevel = l.Level()
|
||||
}
|
||||
}
|
||||
|
||||
if minLevel < logger.Level() {
|
||||
// recreate logger if the min level is lower than the current min one
|
||||
loggerConfig.Level = zap.NewAtomicLevelAt(minLevel)
|
||||
logger, _ = loggerConfig.Build()
|
||||
}
|
||||
|
||||
for name, nl := range namedLoggers {
|
||||
level := getLevel(name)
|
||||
newCore := zap.New(logger.Core()).Named(name).WithOptions(
|
||||
zap.IncreaseLevel(level),
|
||||
)
|
||||
*(nl.Logger) = *newCore
|
||||
}
|
||||
|
||||
for name, nl := range namedSugarLoggers {
|
||||
level := getLevel(name)
|
||||
newCore := zap.New(logger.Core()).Named(name).WithOptions(
|
||||
zap.IncreaseLevel(level),
|
||||
).Sugar()
|
||||
*(nl) = *newCore
|
||||
}
|
||||
}
|
||||
|
||||
func Default() *zap.Logger {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
return defaultLogger
|
||||
return logger
|
||||
}
|
||||
|
||||
func NewNamed(name string, fields ...zap.Field) *zap.Logger {
|
||||
// getLevel returns the level for the given name
|
||||
// it return the first matching name or glob pattern whatever comes first
|
||||
func getLevel(name string) zap.AtomicLevel {
|
||||
for _, nl := range namedLevels {
|
||||
if nl.name == name {
|
||||
return nl.level
|
||||
}
|
||||
if g, ok := namedGlobs[nl.name]; ok && g.Match(name) {
|
||||
return nl.level
|
||||
}
|
||||
}
|
||||
return zap.NewAtomicLevelAt(logger.Level())
|
||||
}
|
||||
|
||||
func NewNamed(name string, fields ...zap.Field) CtxLogger {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
l := defaultLogger.Named(name)
|
||||
if len(fields) > 0 {
|
||||
l = l.With(fields...)
|
||||
|
||||
if l, nameExists := namedLoggers[name]; nameExists {
|
||||
return l
|
||||
}
|
||||
loggers[name] = l
|
||||
|
||||
level := getLevel(name)
|
||||
l := zap.New(logger.Core()).Named(name).WithOptions(zap.IncreaseLevel(level),
|
||||
zap.Fields(fields...))
|
||||
|
||||
ctxL := CtxLogger{Logger: l, name: name}
|
||||
namedLoggers[name] = ctxL
|
||||
return ctxL
|
||||
}
|
||||
|
||||
func NewNamedSugared(name string) *zap.SugaredLogger {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if l, nameExists := namedSugarLoggers[name]; nameExists {
|
||||
return l
|
||||
}
|
||||
|
||||
level := getLevel(name)
|
||||
l := zap.New(logger.Core()).Named(name).Sugar().WithOptions(zap.IncreaseLevel(level))
|
||||
namedSugarLoggers[name] = l
|
||||
return l
|
||||
}
|
||||
|
||||
150
app/logger/log_test.go
Normal file
150
app/logger/log_test.go
Normal file
@ -0,0 +1,150 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func Test_getLevel1(t *testing.T) {
|
||||
SetNamedLevels([]NamedLevel{
|
||||
{Name: "app", Level: "debug"},
|
||||
{Name: "app*", Level: "info"},
|
||||
{Name: "app.sub", Level: "warn"},
|
||||
{Name: "*", Level: "fatal"},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want zap.AtomicLevel
|
||||
}{
|
||||
{
|
||||
name: "app",
|
||||
want: zap.NewAtomicLevelAt(zap.DebugLevel),
|
||||
},
|
||||
{
|
||||
name: "app.aaa",
|
||||
want: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
},
|
||||
{
|
||||
name: "app.sub",
|
||||
want: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
},
|
||||
{
|
||||
name: "random",
|
||||
want: zap.NewAtomicLevelAt(zap.FatalLevel),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("getLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getLevel2(t *testing.T) {
|
||||
SetNamedLevels([]NamedLevel{
|
||||
{Name: "*", Level: "ERROR"},
|
||||
{Name: "app", Level: "info"},
|
||||
{Name: "app.sub", Level: "warn"},
|
||||
{Name: "*", Level: "fatal"},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want zap.AtomicLevel
|
||||
}{
|
||||
{
|
||||
name: "app",
|
||||
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
|
||||
},
|
||||
{
|
||||
name: "app.aaa",
|
||||
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
|
||||
},
|
||||
{
|
||||
name: "app.sub",
|
||||
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
|
||||
},
|
||||
{
|
||||
name: "random",
|
||||
want: zap.NewAtomicLevelAt(zap.ErrorLevel),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("getLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getLevel3(t *testing.T) {
|
||||
SetNamedLevels([]NamedLevel{
|
||||
{Name: "app", Level: "info"},
|
||||
{Name: "*.sub", Level: "warn"},
|
||||
{Name: "*", Level: "fatal"},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want zap.AtomicLevel
|
||||
}{
|
||||
{
|
||||
name: "app",
|
||||
want: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
},
|
||||
{
|
||||
name: "app.sub",
|
||||
want: zap.NewAtomicLevelAt(zap.WarnLevel),
|
||||
},
|
||||
{
|
||||
name: "random",
|
||||
want: zap.NewAtomicLevelAt(zap.FatalLevel),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("getLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getLevel4(t *testing.T) {
|
||||
SetNamedLevels([]NamedLevel{
|
||||
{Name: "*", Level: "invalid"},
|
||||
{Name: "app", Level: "info"},
|
||||
{Name: "b", Level: "invalid"},
|
||||
})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
want zap.AtomicLevel
|
||||
}{
|
||||
{
|
||||
name: "app",
|
||||
want: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
},
|
||||
{
|
||||
name: "app.sub",
|
||||
want: zap.NewAtomicLevelAt(logger.Level()),
|
||||
},
|
||||
{
|
||||
name: "b",
|
||||
want: zap.NewAtomicLevelAt(logger.Level()),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getLevel(tt.name); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("getLevel() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
136
app/ocache/entry.go
Normal file
136
app/ocache/entry.go
Normal file
@ -0,0 +1,136 @@
|
||||
package ocache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type entryState int
|
||||
|
||||
const (
|
||||
entryStateLoading = iota
|
||||
entryStateActive
|
||||
entryStateClosing
|
||||
entryStateClosed
|
||||
)
|
||||
|
||||
type entry struct {
|
||||
id string
|
||||
state entryState
|
||||
lastUsage time.Time
|
||||
load chan struct{}
|
||||
loadErr error
|
||||
value Object
|
||||
close chan struct{}
|
||||
mx sync.Mutex
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func newEntry(id string, value Object, state entryState) *entry {
|
||||
return &entry{
|
||||
id: id,
|
||||
load: make(chan struct{}),
|
||||
lastUsage: time.Now(),
|
||||
state: state,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *entry) isActive() bool {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
return e.state == entryStateActive
|
||||
}
|
||||
|
||||
func (e *entry) isClosing() bool {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
return e.state == entryStateClosed || e.state == entryStateClosing
|
||||
}
|
||||
|
||||
func (e *entry) setCancel(cancel context.CancelFunc) {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
e.cancel = cancel
|
||||
}
|
||||
|
||||
func (e *entry) cancelLoad() {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
if e.cancel != nil {
|
||||
e.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id))
|
||||
return nil, ctx.Err()
|
||||
case <-e.load:
|
||||
return e.value, e.loadErr
|
||||
}
|
||||
}
|
||||
|
||||
func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) {
|
||||
e.mx.Lock()
|
||||
switch e.state {
|
||||
case entryStateClosing:
|
||||
waitCh := e.close
|
||||
e.mx.Unlock()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id))
|
||||
return false, ctx.Err()
|
||||
case <-waitCh:
|
||||
return true, nil
|
||||
}
|
||||
case entryStateClosed:
|
||||
e.mx.Unlock()
|
||||
return true, nil
|
||||
default:
|
||||
e.mx.Unlock()
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (e *entry) setClosing(wait bool) (prevState, curState entryState) {
|
||||
e.mx.Lock()
|
||||
prevState = e.state
|
||||
curState = e.state
|
||||
if e.state == entryStateClosing {
|
||||
waitCh := e.close
|
||||
e.mx.Unlock()
|
||||
if !wait {
|
||||
return
|
||||
}
|
||||
<-waitCh
|
||||
e.mx.Lock()
|
||||
}
|
||||
if e.state != entryStateClosed {
|
||||
e.state = entryStateClosing
|
||||
e.close = make(chan struct{})
|
||||
}
|
||||
curState = e.state
|
||||
e.mx.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (e *entry) setActive(chClose bool) {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
if chClose {
|
||||
close(e.close)
|
||||
}
|
||||
e.state = entryStateActive
|
||||
}
|
||||
|
||||
func (e *entry) setClosed() {
|
||||
e.mx.Lock()
|
||||
defer e.mx.Unlock()
|
||||
close(e.close)
|
||||
e.state = entryStateClosed
|
||||
}
|
||||
@ -1,14 +1,22 @@
|
||||
package ocache
|
||||
|
||||
import "github.com/prometheus/client_golang/prometheus"
|
||||
import (
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option {
|
||||
if subsystem == "" {
|
||||
subsystem = "cache"
|
||||
}
|
||||
if reg == nil {
|
||||
return nil
|
||||
}
|
||||
if subsystem == "" {
|
||||
subsystem = "cache"
|
||||
}
|
||||
nameSplit := strings.Split(namespace, ".")
|
||||
subSplit := strings.Split(subsystem, ".")
|
||||
namespace = strings.Join(nameSplit, "_")
|
||||
subsystem = strings.Join(subSplit, "_")
|
||||
|
||||
return func(cache *oCache) {
|
||||
cache.metrics = &metrics{
|
||||
hit: prometheus.NewCounter(prometheus.CounterOpts{
|
||||
|
||||
19
app/ocache/metrics_test.go
Normal file
19
app/ocache/metrics_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package ocache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/require"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWithPrometheus_MetricsConvertsDots(t *testing.T) {
|
||||
opt := WithPrometheus(prometheus.NewRegistry(), "some.name", "some.system")
|
||||
cache := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return &testObject{}, nil
|
||||
}, opt).(*oCache)
|
||||
_, err := cache.Get(context.Background(), "id")
|
||||
require.NoError(t, err)
|
||||
require.True(t, strings.Contains(cache.metrics.hit.Desc().String(), "some_name_some_system_hit"))
|
||||
}
|
||||
@ -3,10 +3,11 @@ package ocache
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -44,12 +45,6 @@ var WithGCPeriod = func(gcPeriod time.Duration) Option {
|
||||
}
|
||||
}
|
||||
|
||||
var WithRefCounter = func(enable bool) Option {
|
||||
return func(cache *oCache) {
|
||||
cache.refCounter = enable
|
||||
}
|
||||
}
|
||||
|
||||
func New(loadFunc LoadFunc, opts ...Option) OCache {
|
||||
c := &oCache{
|
||||
data: make(map[string]*entry),
|
||||
@ -73,33 +68,7 @@ func New(loadFunc LoadFunc, opts ...Option) OCache {
|
||||
|
||||
type Object interface {
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type ObjectLocker interface {
|
||||
Object
|
||||
Locked() bool
|
||||
}
|
||||
|
||||
type ObjectLastUsage interface {
|
||||
LastUsage() time.Time
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
id string
|
||||
lastUsage time.Time
|
||||
refCount uint32
|
||||
isClosing bool
|
||||
load chan struct{}
|
||||
loadErr error
|
||||
value Object
|
||||
close chan struct{}
|
||||
}
|
||||
|
||||
func (e *entry) locked() bool {
|
||||
if locker, ok := e.value.(ObjectLocker); ok {
|
||||
return locker.Locked()
|
||||
}
|
||||
return false
|
||||
TryClose(objectTTL time.Duration) (res bool, err error)
|
||||
}
|
||||
|
||||
type OCache interface {
|
||||
@ -116,12 +85,8 @@ type OCache interface {
|
||||
// Add adds new object to cache
|
||||
// Returns error when object exists
|
||||
Add(id string, value Object) (err error)
|
||||
// Release decreases the refs counter
|
||||
Release(id string) bool
|
||||
// Reset sets refs counter to 0
|
||||
Reset(id string) bool
|
||||
// Remove closes and removes object
|
||||
Remove(id string) (ok bool, err error)
|
||||
Remove(ctx context.Context, id string) (ok bool, err error)
|
||||
// ForEach iterates over all loaded objects, breaks when callback returns false
|
||||
ForEach(f func(v Object) (isContinue bool))
|
||||
// GC frees not used and expired objects
|
||||
@ -144,7 +109,6 @@ type oCache struct {
|
||||
closeCh chan struct{}
|
||||
log *zap.SugaredLogger
|
||||
metrics *metrics
|
||||
refCounter bool
|
||||
}
|
||||
|
||||
func (c *oCache) Get(ctx context.Context, id string) (value Object, err error) {
|
||||
@ -160,68 +124,44 @@ Load:
|
||||
return nil, ErrClosed
|
||||
}
|
||||
if e, ok = c.data[id]; !ok {
|
||||
e = newEntry(id, nil, entryStateLoading)
|
||||
load = true
|
||||
e = &entry{
|
||||
id: id,
|
||||
load: make(chan struct{}),
|
||||
}
|
||||
c.data[id] = e
|
||||
}
|
||||
closing := e.isClosing
|
||||
if !e.isClosing {
|
||||
e.lastUsage = c.timeNow()
|
||||
if c.refCounter {
|
||||
e.refCount++
|
||||
}
|
||||
}
|
||||
e.lastUsage = time.Now()
|
||||
c.mu.Unlock()
|
||||
if closing {
|
||||
<-e.close
|
||||
reload, err := e.waitClose(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if reload {
|
||||
goto Load
|
||||
}
|
||||
|
||||
if load {
|
||||
go c.load(ctx, id, e)
|
||||
}
|
||||
if c.metrics != nil {
|
||||
if load {
|
||||
c.metrics.miss.Inc()
|
||||
} else {
|
||||
c.metrics.hit.Inc()
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-e.load:
|
||||
}
|
||||
return e.value, e.loadErr
|
||||
c.metricsGet(!load)
|
||||
return e.waitLoad(ctx, id)
|
||||
}
|
||||
|
||||
func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error) {
|
||||
c.mu.Lock()
|
||||
val, ok := c.data[id]
|
||||
if !ok || val.isClosing {
|
||||
if !ok || val.isClosing() {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrNotExists
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if c.metrics != nil {
|
||||
c.metrics.hit.Inc()
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-val.load:
|
||||
return val.value, val.loadErr
|
||||
}
|
||||
c.metricsGet(true)
|
||||
return val.waitLoad(ctx, id)
|
||||
}
|
||||
|
||||
func (c *oCache) load(ctx context.Context, id string, e *entry) {
|
||||
defer close(e.load)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
e.setCancel(cancel)
|
||||
value, err := c.loadFunc(ctx, id)
|
||||
cancel()
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
@ -230,63 +170,39 @@ func (c *oCache) load(ctx context.Context, id string, e *entry) {
|
||||
delete(c.data, id)
|
||||
} else {
|
||||
e.value = value
|
||||
e.setActive(false)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *oCache) Release(id string) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return false
|
||||
}
|
||||
if e, ok := c.data[id]; ok {
|
||||
if c.refCounter && e.refCount > 0 {
|
||||
e.refCount--
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *oCache) Reset(id string) bool {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
if c.closed {
|
||||
return false
|
||||
}
|
||||
if e, ok := c.data[id]; ok {
|
||||
e.refCount = 0
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (c *oCache) Remove(id string) (ok bool, err error) {
|
||||
func (c *oCache) Remove(ctx context.Context, id string) (ok bool, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
err = ErrClosed
|
||||
return
|
||||
}
|
||||
var e *entry
|
||||
e, ok = c.data[id]
|
||||
if !ok || e.isClosing {
|
||||
e, ok := c.data[id]
|
||||
if !ok {
|
||||
c.mu.Unlock()
|
||||
return
|
||||
return false, ErrNotExists
|
||||
}
|
||||
e.isClosing = true
|
||||
e.close = make(chan struct{})
|
||||
c.mu.Unlock()
|
||||
return c.remove(ctx, e)
|
||||
}
|
||||
|
||||
<-e.load
|
||||
if e.value != nil {
|
||||
err = e.value.Close()
|
||||
func (c *oCache) remove(ctx context.Context, e *entry) (ok bool, err error) {
|
||||
if _, err = e.waitLoad(ctx, e.id); err != nil {
|
||||
return false, err
|
||||
}
|
||||
_, curState := e.setClosing(true)
|
||||
if curState == entryStateClosing {
|
||||
ok = true
|
||||
err = e.value.Close()
|
||||
c.mu.Lock()
|
||||
close(e.close)
|
||||
e.setClosed()
|
||||
delete(c.data, e.id)
|
||||
c.mu.Unlock()
|
||||
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -308,13 +224,7 @@ func (c *oCache) Add(id string, value Object) (err error) {
|
||||
if _, ok := c.data[id]; ok {
|
||||
return ErrExists
|
||||
}
|
||||
e := &entry{
|
||||
id: id,
|
||||
lastUsage: time.Now(),
|
||||
refCount: 0,
|
||||
load: make(chan struct{}),
|
||||
value: value,
|
||||
}
|
||||
e := newEntry(id, value, entryStateActive)
|
||||
close(e.load)
|
||||
c.data[id] = e
|
||||
return
|
||||
@ -326,7 +236,7 @@ func (c *oCache) ForEach(f func(obj Object) (isContinue bool)) {
|
||||
for _, v := range c.data {
|
||||
select {
|
||||
case <-v.load:
|
||||
if v.value != nil && !v.isClosing {
|
||||
if v.value != nil && !v.isClosing() {
|
||||
objects = append(objects, v.value)
|
||||
}
|
||||
default:
|
||||
@ -362,40 +272,35 @@ func (c *oCache) GC() {
|
||||
deadline := c.timeNow().Add(-c.ttl)
|
||||
var toClose []*entry
|
||||
for _, e := range c.data {
|
||||
if e.isClosing {
|
||||
continue
|
||||
}
|
||||
lu := e.lastUsage
|
||||
if lug, ok := e.value.(ObjectLastUsage); ok {
|
||||
lu = lug.LastUsage()
|
||||
}
|
||||
if !e.locked() && e.refCount <= 0 && lu.Before(deadline) {
|
||||
e.isClosing = true
|
||||
if e.isActive() && e.lastUsage.Before(deadline) {
|
||||
e.close = make(chan struct{})
|
||||
toClose = append(toClose, e)
|
||||
}
|
||||
}
|
||||
size := len(c.data)
|
||||
c.mu.Unlock()
|
||||
|
||||
closedNum := 0
|
||||
for _, e := range toClose {
|
||||
<-e.load
|
||||
if e.value != nil {
|
||||
if err := e.value.Close(); err != nil {
|
||||
prevState, _ := e.setClosing(false)
|
||||
if prevState == entryStateClosing || prevState == entryStateClosed {
|
||||
continue
|
||||
}
|
||||
closed, err := e.value.TryClose(c.ttl)
|
||||
if err != nil {
|
||||
c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
c.log.Infof("GC: removed %d; cache size: %d", len(toClose), size)
|
||||
if len(toClose) > 0 && c.metrics != nil {
|
||||
c.metrics.gc.Add(float64(len(toClose)))
|
||||
}
|
||||
if !closed {
|
||||
e.setActive(true)
|
||||
continue
|
||||
} else {
|
||||
closedNum++
|
||||
c.mu.Lock()
|
||||
for _, e := range toClose {
|
||||
close(e.close)
|
||||
e.setClosed()
|
||||
delete(c.data, e.id)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
c.metricsClosed(closedNum, size)
|
||||
}
|
||||
|
||||
func (c *oCache) Len() int {
|
||||
@ -412,25 +317,35 @@ func (c *oCache) Close() (err error) {
|
||||
}
|
||||
c.closed = true
|
||||
close(c.closeCh)
|
||||
var toClose, alreadyClosing []*entry
|
||||
var toClose []*entry
|
||||
for _, e := range c.data {
|
||||
if e.isClosing {
|
||||
alreadyClosing = append(alreadyClosing, e)
|
||||
} else {
|
||||
e.cancelLoad()
|
||||
toClose = append(toClose, e)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
for _, e := range toClose {
|
||||
<-e.load
|
||||
if e.value != nil {
|
||||
if clErr := e.value.Close(); clErr != nil {
|
||||
c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", clErr)
|
||||
if _, err := c.remove(context.Background(), e); err != nil && err != ErrNotExists {
|
||||
c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, e := range alreadyClosing {
|
||||
<-e.close
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *oCache) metricsGet(hit bool) {
|
||||
if c.metrics == nil {
|
||||
return
|
||||
}
|
||||
if hit {
|
||||
c.metrics.hit.Inc()
|
||||
} else {
|
||||
c.metrics.miss.Inc()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *oCache) metricsClosed(closedLen, size int) {
|
||||
c.log.Infof("GC: removed %d; cache size: %d", closedLen, size)
|
||||
if c.metrics == nil || closedLen == 0 {
|
||||
return
|
||||
}
|
||||
c.metrics.gc.Add(float64(closedLen))
|
||||
}
|
||||
|
||||
@ -3,6 +3,8 @@ package ocache
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@ -11,26 +13,48 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var ctx = context.Background()
|
||||
|
||||
type testObject struct {
|
||||
name string
|
||||
closeErr error
|
||||
closeCh chan struct{}
|
||||
tryReturn bool
|
||||
closeCalled bool
|
||||
tryCloseCalled bool
|
||||
}
|
||||
|
||||
func NewTestObject(name string, closeCh chan struct{}) *testObject {
|
||||
func NewTestObject(name string, tryReturn bool, closeCh chan struct{}) *testObject {
|
||||
return &testObject{
|
||||
name: name,
|
||||
closeCh: closeCh,
|
||||
tryReturn: tryReturn,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testObject) Close() (err error) {
|
||||
if t.closeCalled || (t.tryCloseCalled && t.tryReturn) {
|
||||
panic("close called twice")
|
||||
}
|
||||
t.closeCalled = true
|
||||
if t.closeCh != nil {
|
||||
<-t.closeCh
|
||||
}
|
||||
return t.closeErr
|
||||
}
|
||||
|
||||
func (t *testObject) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
if t.closeCalled || (t.tryCloseCalled && t.tryReturn) {
|
||||
panic("close called twice")
|
||||
}
|
||||
t.tryCloseCalled = true
|
||||
if t.closeCh != nil {
|
||||
<-t.closeCh
|
||||
return t.tryReturn, t.closeErr
|
||||
}
|
||||
return t.tryReturn, nil
|
||||
}
|
||||
|
||||
func TestOCache_Get(t *testing.T) {
|
||||
t.Run("successful", func(t *testing.T) {
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
@ -116,42 +140,37 @@ func TestOCache_Get(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestOCache_GC(t *testing.T) {
|
||||
t.Run("test without close wait", func(t *testing.T) {
|
||||
t.Run("test gc expired object", func(t *testing.T) {
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return &testObject{name: id}, nil
|
||||
}, WithTTL(time.Millisecond*10), WithRefCounter(true))
|
||||
return NewTestObject(id, true, nil), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
c.GC()
|
||||
assert.Equal(t, 1, c.Len())
|
||||
time.Sleep(time.Millisecond * 30)
|
||||
c.GC()
|
||||
assert.Equal(t, 1, c.Len())
|
||||
assert.True(t, c.Release("id"))
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
c.GC()
|
||||
assert.Equal(t, 0, c.Len())
|
||||
assert.False(t, c.Release("id"))
|
||||
})
|
||||
t.Run("test with close wait", func(t *testing.T) {
|
||||
t.Run("test gc tryClose true, close before get", func(t *testing.T) {
|
||||
closeCh := make(chan struct{})
|
||||
getCh := make(chan struct{})
|
||||
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10), WithRefCounter(true))
|
||||
return NewTestObject(id, true, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
assert.True(t, c.Release("id"))
|
||||
// making ttl pass
|
||||
time.Sleep(time.Millisecond * 40)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
// first gc will be run after 20 secs, so calling it manually
|
||||
go c.GC()
|
||||
// waiting until all objects are marked as closing
|
||||
time.Sleep(time.Millisecond * 40)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
var events []string
|
||||
go func() {
|
||||
_, err := c.Get(context.TODO(), "id")
|
||||
@ -160,33 +179,114 @@ func TestOCache_GC(t *testing.T) {
|
||||
events = append(events, "get")
|
||||
close(getCh)
|
||||
}()
|
||||
events = append(events, "close")
|
||||
// sleeping to make sure that Get is called
|
||||
time.Sleep(time.Millisecond * 40)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
events = append(events, "close")
|
||||
close(closeCh)
|
||||
|
||||
<-getCh
|
||||
require.Equal(t, []string{"close", "get"}, events)
|
||||
})
|
||||
t.Run("test gc tryClose false, many parallel get", func(t *testing.T) {
|
||||
timesCalled := &atomic.Int32{}
|
||||
obj := NewTestObject("id", false, nil)
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
timesCalled.Add(1)
|
||||
return obj, nil
|
||||
}, WithTTL(0))
|
||||
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
begin := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
once := sync.Once{}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
<-begin
|
||||
c.GC()
|
||||
wg.Done()
|
||||
}()
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
once.Do(func() {
|
||||
close(begin)
|
||||
})
|
||||
if i%2 != 0 {
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
_, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
wg.Wait()
|
||||
require.Equal(t, timesCalled.Load(), int32(1))
|
||||
require.True(t, obj.tryCloseCalled)
|
||||
})
|
||||
t.Run("test gc tryClose different, many objects", func(t *testing.T) {
|
||||
tryCloseIds := make(map[string]bool)
|
||||
called := make(map[string]int)
|
||||
max := 1000
|
||||
getId := func(i int) string {
|
||||
return fmt.Sprintf("id%d", i)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
if i%2 == 1 {
|
||||
tryCloseIds[getId(i)] = true
|
||||
} else {
|
||||
tryCloseIds[getId(i)] = false
|
||||
}
|
||||
}
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
called[id] = called[id] + 1
|
||||
return NewTestObject(id, tryCloseIds[id], nil), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
|
||||
for i := 0; i < max; i++ {
|
||||
val, err := c.Get(context.TODO(), getId(i))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
}
|
||||
assert.Equal(t, max, c.Len())
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
c.GC()
|
||||
for i := 0; i < max; i++ {
|
||||
val, err := c.Get(context.TODO(), getId(i))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
val, err := c.Get(context.TODO(), getId(i))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
require.Equal(t, called[getId(i)], i%2+1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_OCache_Remove(t *testing.T) {
|
||||
t.Run("remove simple", func(t *testing.T) {
|
||||
closeCh := make(chan struct{})
|
||||
getCh := make(chan struct{})
|
||||
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, closeCh), nil
|
||||
return NewTestObject(id, false, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
// removing the object, so we will wait on closing
|
||||
go func() {
|
||||
_, err := c.Remove("id")
|
||||
_, err := c.Remove(ctx, "id")
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 40)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
var events []string
|
||||
go func() {
|
||||
@ -196,11 +296,215 @@ func Test_OCache_Remove(t *testing.T) {
|
||||
events = append(events, "get")
|
||||
close(getCh)
|
||||
}()
|
||||
events = append(events, "close")
|
||||
// sleeping to make sure that Get is called
|
||||
time.Sleep(time.Millisecond * 40)
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
events = append(events, "close")
|
||||
close(closeCh)
|
||||
|
||||
<-getCh
|
||||
require.Equal(t, []string{"close", "get"}, events)
|
||||
})
|
||||
t.Run("test remove while gc, tryClose false", func(t *testing.T) {
|
||||
closeCh := make(chan struct{})
|
||||
removeCh := make(chan struct{})
|
||||
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, false, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
go c.GC()
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
var events []string
|
||||
go func() {
|
||||
ok, err := c.Remove(ctx, "id")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
events = append(events, "remove")
|
||||
close(removeCh)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
events = append(events, "close")
|
||||
close(closeCh)
|
||||
|
||||
<-removeCh
|
||||
require.Equal(t, []string{"close", "remove"}, events)
|
||||
})
|
||||
t.Run("test remove while gc, tryClose true", func(t *testing.T) {
|
||||
closeCh := make(chan struct{})
|
||||
removeCh := make(chan struct{})
|
||||
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, true, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
go c.GC()
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
var events []string
|
||||
go func() {
|
||||
ok, err := c.Remove(ctx, "id")
|
||||
require.NoError(t, err)
|
||||
require.False(t, ok)
|
||||
events = append(events, "remove")
|
||||
close(removeCh)
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
events = append(events, "close")
|
||||
close(closeCh)
|
||||
|
||||
<-removeCh
|
||||
require.Equal(t, []string{"close", "remove"}, events)
|
||||
})
|
||||
t.Run("test gc while remove, tryClose true", func(t *testing.T) {
|
||||
closeCh := make(chan struct{})
|
||||
removeCh := make(chan struct{})
|
||||
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, true, closeCh), nil
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
val, err := c.Get(context.TODO(), "id")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
assert.Equal(t, 1, c.Len())
|
||||
go func() {
|
||||
ok, err := c.Remove(ctx, "id")
|
||||
require.NoError(t, err)
|
||||
require.True(t, ok)
|
||||
close(removeCh)
|
||||
}()
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
c.GC()
|
||||
close(closeCh)
|
||||
<-removeCh
|
||||
})
|
||||
}
|
||||
|
||||
func TestOCacheCancelWhenRemove(t *testing.T) {
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}, WithTTL(time.Millisecond*10))
|
||||
stopLoad := make(chan struct{})
|
||||
var err error
|
||||
go func() {
|
||||
_, err = c.Get(context.TODO(), "id")
|
||||
stopLoad <- struct{}{}
|
||||
}()
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
c.Close()
|
||||
<-stopLoad
|
||||
require.Equal(t, context.Canceled, err)
|
||||
}
|
||||
|
||||
func TestOCacheFuzzy(t *testing.T) {
|
||||
t.Run("test many objects gc, get and remove simultaneously, close after", func(t *testing.T) {
|
||||
tryCloseIds := make(map[string]bool)
|
||||
max := 2000
|
||||
getId := func(i int) string {
|
||||
return fmt.Sprintf("id%d", i)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
if i%2 == 1 {
|
||||
tryCloseIds[getId(i)] = true
|
||||
} else {
|
||||
tryCloseIds[getId(i)] = false
|
||||
}
|
||||
}
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, tryCloseIds[id], nil), nil
|
||||
}, WithTTL(time.Nanosecond))
|
||||
|
||||
stopGC := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-stopGC:
|
||||
return
|
||||
default:
|
||||
c.GC()
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
for i := 0; i < max; i++ {
|
||||
val, err := c.Get(context.TODO(), getId(i))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 10; j++ {
|
||||
for i := 0; i < max; i++ {
|
||||
c.Remove(ctx, getId(i))
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
close(stopGC)
|
||||
err := c.Close()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, c.Len())
|
||||
})
|
||||
t.Run("test many objects gc, get, remove and close simultaneously", func(t *testing.T) {
|
||||
tryCloseIds := make(map[string]bool)
|
||||
max := 2000
|
||||
getId := func(i int) string {
|
||||
return fmt.Sprintf("id%d", i)
|
||||
}
|
||||
for i := 0; i < max; i++ {
|
||||
if i%2 == 1 {
|
||||
tryCloseIds[getId(i)] = true
|
||||
} else {
|
||||
tryCloseIds[getId(i)] = false
|
||||
}
|
||||
}
|
||||
c := New(func(ctx context.Context, id string) (value Object, err error) {
|
||||
return NewTestObject(id, tryCloseIds[id], nil), nil
|
||||
}, WithTTL(time.Nanosecond))
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c.GC()
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for j := 0; j < 10; j++ {
|
||||
for i := 0; i < max; i++ {
|
||||
val, err := c.Get(context.TODO(), getId(i))
|
||||
if err == ErrClosed {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, val)
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
for j := 0; j < 10; j++ {
|
||||
for i := 0; i < max; i++ {
|
||||
c.Remove(ctx, getId(i))
|
||||
}
|
||||
}
|
||||
}()
|
||||
time.Sleep(time.Millisecond)
|
||||
err := c.Close()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, c.Len())
|
||||
})
|
||||
}
|
||||
|
||||
@ -2,8 +2,8 @@ package fileblockstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anytypeio/any-sync/commonfile/fileproto/fileprotoerr"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonfile/fileproto/fileprotoerr"
|
||||
blocks "github.com/ipfs/go-block-format"
|
||||
"github.com/ipfs/go-cid"
|
||||
)
|
||||
@ -21,6 +21,7 @@ type ctxKey uint
|
||||
|
||||
const (
|
||||
ctxKeySpaceId ctxKey = iota
|
||||
ctxKeyFileId
|
||||
)
|
||||
|
||||
type BlockStore interface {
|
||||
@ -48,3 +49,12 @@ func CtxGetSpaceId(ctx context.Context) (spaceId string) {
|
||||
spaceId, _ = ctx.Value(ctxKeySpaceId).(string)
|
||||
return
|
||||
}
|
||||
|
||||
func CtxWithFileId(ctx context.Context, spaceId string) context.Context {
|
||||
return context.WithValue(ctx, ctxKeyFileId, spaceId)
|
||||
}
|
||||
|
||||
func CtxGetFileId(ctx context.Context) (spaceId string) {
|
||||
spaceId, _ = ctx.Value(ctxKeyFileId).(string)
|
||||
return
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,5 +1,5 @@
|
||||
// Code generated by protoc-gen-go-drpc. DO NOT EDIT.
|
||||
// protoc-gen-go-drpc version: v0.0.32
|
||||
// protoc-gen-go-drpc version: v0.0.33
|
||||
// source: commonfile/fileproto/protos/file.proto
|
||||
|
||||
package fileproto
|
||||
@ -40,10 +40,14 @@ func (drpcEncoding_File_commonfile_fileproto_protos_file_proto) JSONUnmarshal(bu
|
||||
type DRPCFileClient interface {
|
||||
DRPCConn() drpc.Conn
|
||||
|
||||
GetBlocks(ctx context.Context) (DRPCFile_GetBlocksClient, error)
|
||||
PushBlock(ctx context.Context, in *PushBlockRequest) (*PushBlockResponse, error)
|
||||
DeleteBlocks(ctx context.Context, in *DeleteBlocksRequest) (*DeleteBlocksResponse, error)
|
||||
BlockGet(ctx context.Context, in *BlockGetRequest) (*BlockGetResponse, error)
|
||||
BlockPush(ctx context.Context, in *BlockPushRequest) (*BlockPushResponse, error)
|
||||
BlocksCheck(ctx context.Context, in *BlocksCheckRequest) (*BlocksCheckResponse, error)
|
||||
BlocksBind(ctx context.Context, in *BlocksBindRequest) (*BlocksBindResponse, error)
|
||||
FilesDelete(ctx context.Context, in *FilesDeleteRequest) (*FilesDeleteResponse, error)
|
||||
FilesInfo(ctx context.Context, in *FilesInfoRequest) (*FilesInfoResponse, error)
|
||||
Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error)
|
||||
SpaceInfo(ctx context.Context, in *SpaceInfoRequest) (*SpaceInfoResponse, error)
|
||||
}
|
||||
|
||||
type drpcFileClient struct {
|
||||
@ -56,53 +60,54 @@ func NewDRPCFileClient(cc drpc.Conn) DRPCFileClient {
|
||||
|
||||
func (c *drpcFileClient) DRPCConn() drpc.Conn { return c.cc }
|
||||
|
||||
func (c *drpcFileClient) GetBlocks(ctx context.Context) (DRPCFile_GetBlocksClient, error) {
|
||||
stream, err := c.cc.NewStream(ctx, "/anyFile.File/GetBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
x := &drpcFile_GetBlocksClient{stream}
|
||||
return x, nil
|
||||
}
|
||||
|
||||
type DRPCFile_GetBlocksClient interface {
|
||||
drpc.Stream
|
||||
Send(*GetBlockRequest) error
|
||||
Recv() (*GetBlockResponse, error)
|
||||
}
|
||||
|
||||
type drpcFile_GetBlocksClient struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksClient) Send(m *GetBlockRequest) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksClient) Recv() (*GetBlockResponse, error) {
|
||||
m := new(GetBlockResponse)
|
||||
if err := x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksClient) RecvMsg(m *GetBlockResponse) error {
|
||||
return x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) PushBlock(ctx context.Context, in *PushBlockRequest) (*PushBlockResponse, error) {
|
||||
out := new(PushBlockResponse)
|
||||
err := c.cc.Invoke(ctx, "/anyFile.File/PushBlock", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
func (c *drpcFileClient) BlockGet(ctx context.Context, in *BlockGetRequest) (*BlockGetResponse, error) {
|
||||
out := new(BlockGetResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/BlockGet", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) DeleteBlocks(ctx context.Context, in *DeleteBlocksRequest) (*DeleteBlocksResponse, error) {
|
||||
out := new(DeleteBlocksResponse)
|
||||
err := c.cc.Invoke(ctx, "/anyFile.File/DeleteBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
func (c *drpcFileClient) BlockPush(ctx context.Context, in *BlockPushRequest) (*BlockPushResponse, error) {
|
||||
out := new(BlockPushResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/BlockPush", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) BlocksCheck(ctx context.Context, in *BlocksCheckRequest) (*BlocksCheckResponse, error) {
|
||||
out := new(BlocksCheckResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/BlocksCheck", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) BlocksBind(ctx context.Context, in *BlocksBindRequest) (*BlocksBindResponse, error) {
|
||||
out := new(BlocksBindResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/BlocksBind", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) FilesDelete(ctx context.Context, in *FilesDeleteRequest) (*FilesDeleteResponse, error) {
|
||||
out := new(FilesDeleteResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/FilesDelete", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) FilesInfo(ctx context.Context, in *FilesInfoRequest) (*FilesInfoResponse, error) {
|
||||
out := new(FilesInfoResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/FilesInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -111,7 +116,16 @@ func (c *drpcFileClient) DeleteBlocks(ctx context.Context, in *DeleteBlocksReque
|
||||
|
||||
func (c *drpcFileClient) Check(ctx context.Context, in *CheckRequest) (*CheckResponse, error) {
|
||||
out := new(CheckResponse)
|
||||
err := c.cc.Invoke(ctx, "/anyFile.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *drpcFileClient) SpaceInfo(ctx context.Context, in *SpaceInfoRequest) (*SpaceInfoResponse, error) {
|
||||
out := new(SpaceInfoResponse)
|
||||
err := c.cc.Invoke(ctx, "/filesync.File/SpaceInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{}, in, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -119,23 +133,39 @@ func (c *drpcFileClient) Check(ctx context.Context, in *CheckRequest) (*CheckRes
|
||||
}
|
||||
|
||||
type DRPCFileServer interface {
|
||||
GetBlocks(DRPCFile_GetBlocksStream) error
|
||||
PushBlock(context.Context, *PushBlockRequest) (*PushBlockResponse, error)
|
||||
DeleteBlocks(context.Context, *DeleteBlocksRequest) (*DeleteBlocksResponse, error)
|
||||
BlockGet(context.Context, *BlockGetRequest) (*BlockGetResponse, error)
|
||||
BlockPush(context.Context, *BlockPushRequest) (*BlockPushResponse, error)
|
||||
BlocksCheck(context.Context, *BlocksCheckRequest) (*BlocksCheckResponse, error)
|
||||
BlocksBind(context.Context, *BlocksBindRequest) (*BlocksBindResponse, error)
|
||||
FilesDelete(context.Context, *FilesDeleteRequest) (*FilesDeleteResponse, error)
|
||||
FilesInfo(context.Context, *FilesInfoRequest) (*FilesInfoResponse, error)
|
||||
Check(context.Context, *CheckRequest) (*CheckResponse, error)
|
||||
SpaceInfo(context.Context, *SpaceInfoRequest) (*SpaceInfoResponse, error)
|
||||
}
|
||||
|
||||
type DRPCFileUnimplementedServer struct{}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) GetBlocks(DRPCFile_GetBlocksStream) error {
|
||||
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) PushBlock(context.Context, *PushBlockRequest) (*PushBlockResponse, error) {
|
||||
func (s *DRPCFileUnimplementedServer) BlockGet(context.Context, *BlockGetRequest) (*BlockGetResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) DeleteBlocks(context.Context, *DeleteBlocksRequest) (*DeleteBlocksResponse, error) {
|
||||
func (s *DRPCFileUnimplementedServer) BlockPush(context.Context, *BlockPushRequest) (*BlockPushResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) BlocksCheck(context.Context, *BlocksCheckRequest) (*BlocksCheckResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) BlocksBind(context.Context, *BlocksBindRequest) (*BlocksBindResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) FilesDelete(context.Context, *FilesDeleteRequest) (*FilesDeleteResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) FilesInfo(context.Context, *FilesInfoRequest) (*FilesInfoResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
@ -143,40 +173,72 @@ func (s *DRPCFileUnimplementedServer) Check(context.Context, *CheckRequest) (*Ch
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
func (s *DRPCFileUnimplementedServer) SpaceInfo(context.Context, *SpaceInfoRequest) (*SpaceInfoResponse, error) {
|
||||
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
|
||||
}
|
||||
|
||||
type DRPCFileDescription struct{}
|
||||
|
||||
func (DRPCFileDescription) NumMethods() int { return 4 }
|
||||
func (DRPCFileDescription) NumMethods() int { return 8 }
|
||||
|
||||
func (DRPCFileDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
|
||||
switch n {
|
||||
case 0:
|
||||
return "/anyFile.File/GetBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
return "/filesync.File/BlockGet", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return nil, srv.(DRPCFileServer).
|
||||
GetBlocks(
|
||||
&drpcFile_GetBlocksStream{in1.(drpc.Stream)},
|
||||
return srv.(DRPCFileServer).
|
||||
BlockGet(
|
||||
ctx,
|
||||
in1.(*BlockGetRequest),
|
||||
)
|
||||
}, DRPCFileServer.GetBlocks, true
|
||||
}, DRPCFileServer.BlockGet, true
|
||||
case 1:
|
||||
return "/anyFile.File/PushBlock", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
return "/filesync.File/BlockPush", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
PushBlock(
|
||||
BlockPush(
|
||||
ctx,
|
||||
in1.(*PushBlockRequest),
|
||||
in1.(*BlockPushRequest),
|
||||
)
|
||||
}, DRPCFileServer.PushBlock, true
|
||||
}, DRPCFileServer.BlockPush, true
|
||||
case 2:
|
||||
return "/anyFile.File/DeleteBlocks", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
return "/filesync.File/BlocksCheck", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
DeleteBlocks(
|
||||
BlocksCheck(
|
||||
ctx,
|
||||
in1.(*DeleteBlocksRequest),
|
||||
in1.(*BlocksCheckRequest),
|
||||
)
|
||||
}, DRPCFileServer.DeleteBlocks, true
|
||||
}, DRPCFileServer.BlocksCheck, true
|
||||
case 3:
|
||||
return "/anyFile.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
return "/filesync.File/BlocksBind", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
BlocksBind(
|
||||
ctx,
|
||||
in1.(*BlocksBindRequest),
|
||||
)
|
||||
}, DRPCFileServer.BlocksBind, true
|
||||
case 4:
|
||||
return "/filesync.File/FilesDelete", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
FilesDelete(
|
||||
ctx,
|
||||
in1.(*FilesDeleteRequest),
|
||||
)
|
||||
}, DRPCFileServer.FilesDelete, true
|
||||
case 5:
|
||||
return "/filesync.File/FilesInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
FilesInfo(
|
||||
ctx,
|
||||
in1.(*FilesInfoRequest),
|
||||
)
|
||||
}, DRPCFileServer.FilesInfo, true
|
||||
case 6:
|
||||
return "/filesync.File/Check", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
Check(
|
||||
@ -184,6 +246,15 @@ func (DRPCFileDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver,
|
||||
in1.(*CheckRequest),
|
||||
)
|
||||
}, DRPCFileServer.Check, true
|
||||
case 7:
|
||||
return "/filesync.File/SpaceInfo", drpcEncoding_File_commonfile_fileproto_protos_file_proto{},
|
||||
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
|
||||
return srv.(DRPCFileServer).
|
||||
SpaceInfo(
|
||||
ctx,
|
||||
in1.(*SpaceInfoRequest),
|
||||
)
|
||||
}, DRPCFileServer.SpaceInfo, true
|
||||
default:
|
||||
return "", nil, nil, nil, false
|
||||
}
|
||||
@ -193,58 +264,96 @@ func DRPCRegisterFile(mux drpc.Mux, impl DRPCFileServer) error {
|
||||
return mux.Register(impl, DRPCFileDescription{})
|
||||
}
|
||||
|
||||
type DRPCFile_GetBlocksStream interface {
|
||||
type DRPCFile_BlockGetStream interface {
|
||||
drpc.Stream
|
||||
Send(*GetBlockResponse) error
|
||||
Recv() (*GetBlockRequest, error)
|
||||
SendAndClose(*BlockGetResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_GetBlocksStream struct {
|
||||
type drpcFile_BlockGetStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksStream) Send(m *GetBlockResponse) error {
|
||||
return x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksStream) Recv() (*GetBlockRequest, error) {
|
||||
m := new(GetBlockRequest)
|
||||
if err := x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (x *drpcFile_GetBlocksStream) RecvMsg(m *GetBlockRequest) error {
|
||||
return x.MsgRecv(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{})
|
||||
}
|
||||
|
||||
type DRPCFile_PushBlockStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*PushBlockResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_PushBlockStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_PushBlockStream) SendAndClose(m *PushBlockResponse) error {
|
||||
func (x *drpcFile_BlockGetStream) SendAndClose(m *BlockGetResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_DeleteBlocksStream interface {
|
||||
type DRPCFile_BlockPushStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*DeleteBlocksResponse) error
|
||||
SendAndClose(*BlockPushResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_DeleteBlocksStream struct {
|
||||
type drpcFile_BlockPushStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_DeleteBlocksStream) SendAndClose(m *DeleteBlocksResponse) error {
|
||||
func (x *drpcFile_BlockPushStream) SendAndClose(m *BlockPushResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_BlocksCheckStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*BlocksCheckResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_BlocksCheckStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_BlocksCheckStream) SendAndClose(m *BlocksCheckResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_BlocksBindStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*BlocksBindResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_BlocksBindStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_BlocksBindStream) SendAndClose(m *BlocksBindResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_FilesDeleteStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*FilesDeleteResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_FilesDeleteStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_FilesDeleteStream) SendAndClose(m *FilesDeleteResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_FilesInfoStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*FilesInfoResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_FilesInfoStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_FilesInfoStream) SendAndClose(m *FilesInfoResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -266,3 +375,19 @@ func (x *drpcFile_CheckStream) SendAndClose(m *CheckResponse) error {
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
type DRPCFile_SpaceInfoStream interface {
|
||||
drpc.Stream
|
||||
SendAndClose(*SpaceInfoResponse) error
|
||||
}
|
||||
|
||||
type drpcFile_SpaceInfoStream struct {
|
||||
drpc.Stream
|
||||
}
|
||||
|
||||
func (x *drpcFile_SpaceInfoStream) SendAndClose(m *SpaceInfoResponse) error {
|
||||
if err := x.MsgSend(m, drpcEncoding_File_commonfile_fileproto_protos_file_proto{}); err != nil {
|
||||
return err
|
||||
}
|
||||
return x.CloseSend()
|
||||
}
|
||||
|
||||
@ -2,12 +2,16 @@ package fileprotoerr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonfile/fileproto"
|
||||
"github.com/anytypeio/any-sync/net/rpc/rpcerr"
|
||||
"github.com/anyproto/any-sync/commonfile/fileproto"
|
||||
"github.com/anyproto/any-sync/net/rpc/rpcerr"
|
||||
)
|
||||
|
||||
var (
|
||||
errGroup = rpcerr.ErrGroup(fileproto.ErrCodes_ErrorOffset)
|
||||
ErrUnexpected = errGroup.Register(fmt.Errorf("unexpected fileproto error"), uint64(fileproto.ErrCodes_Unexpected))
|
||||
ErrCIDNotFound = errGroup.Register(fmt.Errorf("CID not found"), uint64(fileproto.ErrCodes_CIDNotFound))
|
||||
ErrForbidden = errGroup.Register(fmt.Errorf("forbidden"), uint64(fileproto.ErrCodes_Forbidden))
|
||||
ErrSpaceLimitExceeded = errGroup.Register(fmt.Errorf("space limit exceeded"), uint64(fileproto.ErrCodes_SpaceLimitExceeded))
|
||||
ErrQuerySizeExceeded = errGroup.Register(fmt.Errorf("query size exceeded"), uint64(fileproto.ErrCodes_QuerySizeExceeded))
|
||||
ErrWrongHash = errGroup.Register(fmt.Errorf("wrong block hash"), uint64(fileproto.ErrCodes_WrongHash))
|
||||
)
|
||||
|
||||
@ -1,60 +1,123 @@
|
||||
syntax = "proto3";
|
||||
package anyFile;
|
||||
package filesync;
|
||||
|
||||
option go_package = "commonfile/fileproto";
|
||||
|
||||
enum ErrCodes {
|
||||
Unexpected = 0;
|
||||
CIDNotFound = 1;
|
||||
Forbidden = 2;
|
||||
SpaceLimitExceeded = 3;
|
||||
QuerySizeExceeded = 4;
|
||||
WrongHash = 5;
|
||||
ErrorOffset = 200;
|
||||
}
|
||||
|
||||
service File {
|
||||
// GetBlocks streams ipfs blocks from server to client
|
||||
rpc GetBlocks(stream GetBlockRequest) returns (stream GetBlockResponse);
|
||||
// PushBlock pushes one block to server
|
||||
rpc PushBlock(PushBlockRequest) returns (PushBlockResponse);
|
||||
// DeleteBlock deletes block from space
|
||||
rpc DeleteBlocks(DeleteBlocksRequest) returns (DeleteBlocksResponse);
|
||||
// Ping checks the connection
|
||||
// BlockGet gets one block from a server
|
||||
rpc BlockGet(BlockGetRequest) returns (BlockGetResponse);
|
||||
// BlockPush pushes one block to a server
|
||||
rpc BlockPush(BlockPushRequest) returns (BlockPushResponse);
|
||||
// BlocksCheck checks is CIDs exists
|
||||
rpc BlocksCheck(BlocksCheckRequest) returns (BlocksCheckResponse);
|
||||
// BlocksBind binds CIDs to space
|
||||
rpc BlocksBind(BlocksBindRequest) returns (BlocksBindResponse);
|
||||
// FilesDelete deletes files by id
|
||||
rpc FilesDelete(FilesDeleteRequest) returns (FilesDeleteResponse);
|
||||
// FilesInfo return info by given files id
|
||||
rpc FilesInfo(FilesInfoRequest) returns (FilesInfoResponse);
|
||||
// Check checks the connection and credentials
|
||||
rpc Check(CheckRequest) returns (CheckResponse);
|
||||
// SpaceInfo returns usage, limit, etc about space
|
||||
rpc SpaceInfo(SpaceInfoRequest) returns (SpaceInfoResponse);
|
||||
}
|
||||
|
||||
message GetBlockRequest {
|
||||
message BlockGetRequest {
|
||||
string spaceId = 1;
|
||||
bytes cid = 2;
|
||||
}
|
||||
|
||||
message GetBlockResponse {
|
||||
message BlockGetResponse {
|
||||
bytes cid = 1;
|
||||
bytes data = 2;
|
||||
CIDError code = 3;
|
||||
}
|
||||
|
||||
message PushBlockRequest {
|
||||
message BlockPushRequest {
|
||||
string spaceId = 1;
|
||||
bytes cid = 2;
|
||||
bytes data = 3;
|
||||
string fileId = 2;
|
||||
bytes cid = 3;
|
||||
bytes data = 4;
|
||||
}
|
||||
|
||||
message PushBlockResponse {}
|
||||
message BlockPushResponse {}
|
||||
|
||||
message DeleteBlocksRequest {
|
||||
|
||||
message BlocksCheckRequest {
|
||||
string spaceId = 1;
|
||||
repeated bytes cid = 2;
|
||||
repeated bytes cids = 2;
|
||||
}
|
||||
|
||||
message DeleteBlocksResponse {}
|
||||
message BlocksCheckResponse {
|
||||
repeated BlockAvailability blocksAvailability = 1;
|
||||
}
|
||||
|
||||
message BlockAvailability {
|
||||
bytes cid = 1;
|
||||
AvailabilityStatus status = 2;
|
||||
}
|
||||
|
||||
enum AvailabilityStatus {
|
||||
NotExists = 0;
|
||||
Exists = 1;
|
||||
ExistsInSpace = 2;
|
||||
}
|
||||
|
||||
message BlocksBindRequest {
|
||||
string spaceId = 1;
|
||||
string fileId = 2;
|
||||
repeated bytes cids = 3;
|
||||
}
|
||||
|
||||
message BlocksBindResponse {}
|
||||
|
||||
|
||||
message FilesDeleteRequest {
|
||||
string spaceId = 1;
|
||||
repeated string fileIds = 2;
|
||||
}
|
||||
|
||||
message FilesDeleteResponse {}
|
||||
|
||||
message FilesInfoRequest {
|
||||
string spaceId = 1;
|
||||
repeated string fileIds = 2;
|
||||
}
|
||||
|
||||
message FilesInfoResponse {
|
||||
repeated FileInfo filesInfo = 1;
|
||||
}
|
||||
|
||||
message FileInfo {
|
||||
string fileId = 1;
|
||||
uint64 usageBytes = 2;
|
||||
uint32 cidsCount = 3;
|
||||
}
|
||||
|
||||
message CheckRequest {}
|
||||
|
||||
message CheckResponse {
|
||||
repeated string spaceIds = 1;
|
||||
bool allowWrite = 2;
|
||||
}
|
||||
|
||||
|
||||
enum CIDError {
|
||||
CIDErrorOk = 0;
|
||||
CIDErrorNotFound = 1;
|
||||
CIDErrorUnexpected = 2;
|
||||
message SpaceInfoRequest {
|
||||
string spaceId = 1;
|
||||
}
|
||||
|
||||
message SpaceInfoResponse {
|
||||
uint64 limitBytes = 1;
|
||||
uint64 usageBytes = 2;
|
||||
uint64 cidsCount = 3;
|
||||
uint64 filesCount = 4;
|
||||
}
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@ package fileservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonfile/fileblockstore"
|
||||
"github.com/anyproto/any-sync/commonfile/fileblockstore"
|
||||
blocks "github.com/ipfs/go-block-format"
|
||||
"github.com/ipfs/go-blockservice"
|
||||
"github.com/ipfs/go-cid"
|
||||
|
||||
@ -3,9 +3,9 @@ package fileservice
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/app"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anytypeio/any-sync/commonfile/fileblockstore"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonfile/fileblockstore"
|
||||
"github.com/ipfs/go-cid"
|
||||
chunker "github.com/ipfs/go-ipfs-chunker"
|
||||
ipld "github.com/ipfs/go-ipld-format"
|
||||
@ -22,6 +22,10 @@ const CName = "common.commonfile.fileservice"
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
const (
|
||||
ChunkSize = 1 << 20
|
||||
)
|
||||
|
||||
func New() FileService {
|
||||
return &fileService{}
|
||||
}
|
||||
@ -74,7 +78,7 @@ func (fs *fileService) AddFile(ctx context.Context, r io.Reader) (ipld.Node, err
|
||||
Maxlinks: helpers.DefaultLinksPerBlock,
|
||||
CidBuilder: &fs.prefix,
|
||||
}
|
||||
dbh, err := dbp.New(chunker.DefaultSplitter(r))
|
||||
dbh, err := dbp.New(chunker.NewSizeSplitter(r, ChunkSize))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -1,42 +0,0 @@
|
||||
package commonspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/syncacl"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/syncobjectgetter"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
|
||||
"github.com/anytypeio/any-sync/commonspace/settings"
|
||||
)
|
||||
|
||||
type commonSpaceGetter struct {
|
||||
spaceId string
|
||||
aclList *syncacl.SyncAcl
|
||||
treeGetter treegetter.TreeGetter
|
||||
settings settings.SettingsObject
|
||||
}
|
||||
|
||||
func newCommonSpaceGetter(spaceId string, aclList *syncacl.SyncAcl, treeGetter treegetter.TreeGetter, settings settings.SettingsObject) syncobjectgetter.SyncObjectGetter {
|
||||
return &commonSpaceGetter{
|
||||
spaceId: spaceId,
|
||||
aclList: aclList,
|
||||
treeGetter: treeGetter,
|
||||
settings: settings,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *commonSpaceGetter) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
|
||||
if c.aclList.Id() == objectId {
|
||||
obj = c.aclList
|
||||
return
|
||||
}
|
||||
if c.settings.Id() == objectId {
|
||||
obj = c.settings.(syncobjectgetter.SyncObject)
|
||||
return
|
||||
}
|
||||
t, err := c.treeGetter.GetTree(ctx, c.spaceId, objectId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
obj = t.(syncobjectgetter.SyncObject)
|
||||
return
|
||||
}
|
||||
@ -1,8 +1,8 @@
|
||||
package commonspace
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
)
|
||||
|
||||
type commonStorage struct {
|
||||
|
||||
@ -1,85 +0,0 @@
|
||||
//go:generate mockgen -destination mock_confconnector/mock_confconnector.go github.com/anytypeio/any-sync/commonspace/confconnector ConfConnector
|
||||
package confconnector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/anytypeio/any-sync/net/pool"
|
||||
"github.com/anytypeio/any-sync/nodeconf"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type ConfConnector interface {
|
||||
Configuration() nodeconf.Configuration
|
||||
Pool() pool.Pool
|
||||
GetResponsiblePeers(ctx context.Context, spaceId string) ([]peer.Peer, error)
|
||||
DialInactiveResponsiblePeers(ctx context.Context, spaceId string, activeNodeIds []string) ([]peer.Peer, error)
|
||||
}
|
||||
|
||||
type confConnector struct {
|
||||
conf nodeconf.Configuration
|
||||
pool pool.Pool
|
||||
}
|
||||
|
||||
func NewConfConnector(conf nodeconf.Configuration, pool pool.Pool) ConfConnector {
|
||||
return &confConnector{conf: conf, pool: pool}
|
||||
}
|
||||
|
||||
func (s *confConnector) Configuration() nodeconf.Configuration {
|
||||
return s.conf
|
||||
}
|
||||
|
||||
func (s *confConnector) Pool() pool.Pool {
|
||||
return s.pool
|
||||
}
|
||||
|
||||
func (s *confConnector) GetResponsiblePeers(ctx context.Context, spaceId string) ([]peer.Peer, error) {
|
||||
return s.connectOneOrMany(ctx, spaceId, nil, s.pool.Get)
|
||||
}
|
||||
|
||||
func (s *confConnector) DialInactiveResponsiblePeers(ctx context.Context, spaceId string, activeNodeIds []string) ([]peer.Peer, error) {
|
||||
return s.connectOneOrMany(ctx, spaceId, activeNodeIds, s.pool.Dial)
|
||||
}
|
||||
|
||||
func (s *confConnector) connectOneOrMany(
|
||||
ctx context.Context,
|
||||
spaceId string,
|
||||
activeNodeIds []string,
|
||||
connectOne func(context.Context, string) (peer.Peer, error)) (peers []peer.Peer, err error) {
|
||||
var (
|
||||
inactiveNodeIds []string
|
||||
allNodes = s.conf.NodeIds(spaceId)
|
||||
)
|
||||
for _, id := range allNodes {
|
||||
if !slices.Contains(activeNodeIds, id) {
|
||||
inactiveNodeIds = append(inactiveNodeIds, id)
|
||||
}
|
||||
}
|
||||
|
||||
if s.conf.IsResponsible(spaceId) {
|
||||
for _, id := range inactiveNodeIds {
|
||||
var p peer.Peer
|
||||
p, err = connectOne(ctx, id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
peers = append(peers, p)
|
||||
}
|
||||
} else if len(activeNodeIds) == 0 {
|
||||
// that means that all connected ids
|
||||
var p peer.Peer
|
||||
p, err = s.pool.GetOneOf(ctx, allNodes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// if we are dialling someone, we want to dial to the same peer which we cached
|
||||
// thus communication through streams and through diff will go to the same node
|
||||
p, err = connectOne(ctx, p.Id())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
peers = []peer.Peer{p}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,96 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/confconnector (interfaces: ConfConnector)
|
||||
|
||||
// Package mock_confconnector is a generated GoMock package.
|
||||
package mock_confconnector
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
peer "github.com/anytypeio/any-sync/net/peer"
|
||||
pool "github.com/anytypeio/any-sync/net/pool"
|
||||
nodeconf "github.com/anytypeio/any-sync/nodeconf"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockConfConnector is a mock of ConfConnector interface.
|
||||
type MockConfConnector struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockConfConnectorMockRecorder
|
||||
}
|
||||
|
||||
// MockConfConnectorMockRecorder is the mock recorder for MockConfConnector.
|
||||
type MockConfConnectorMockRecorder struct {
|
||||
mock *MockConfConnector
|
||||
}
|
||||
|
||||
// NewMockConfConnector creates a new mock instance.
|
||||
func NewMockConfConnector(ctrl *gomock.Controller) *MockConfConnector {
|
||||
mock := &MockConfConnector{ctrl: ctrl}
|
||||
mock.recorder = &MockConfConnectorMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockConfConnector) EXPECT() *MockConfConnectorMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Configuration mocks base method.
|
||||
func (m *MockConfConnector) Configuration() nodeconf.Configuration {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Configuration")
|
||||
ret0, _ := ret[0].(nodeconf.Configuration)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Configuration indicates an expected call of Configuration.
|
||||
func (mr *MockConfConnectorMockRecorder) Configuration() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Configuration", reflect.TypeOf((*MockConfConnector)(nil).Configuration))
|
||||
}
|
||||
|
||||
// DialInactiveResponsiblePeers mocks base method.
|
||||
func (m *MockConfConnector) DialInactiveResponsiblePeers(arg0 context.Context, arg1 string, arg2 []string) ([]peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DialInactiveResponsiblePeers", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].([]peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DialInactiveResponsiblePeers indicates an expected call of DialInactiveResponsiblePeers.
|
||||
func (mr *MockConfConnectorMockRecorder) DialInactiveResponsiblePeers(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DialInactiveResponsiblePeers", reflect.TypeOf((*MockConfConnector)(nil).DialInactiveResponsiblePeers), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// GetResponsiblePeers mocks base method.
|
||||
func (m *MockConfConnector) GetResponsiblePeers(arg0 context.Context, arg1 string) ([]peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetResponsiblePeers", arg0, arg1)
|
||||
ret0, _ := ret[0].([]peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetResponsiblePeers indicates an expected call of GetResponsiblePeers.
|
||||
func (mr *MockConfConnectorMockRecorder) GetResponsiblePeers(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockConfConnector)(nil).GetResponsiblePeers), arg0, arg1)
|
||||
}
|
||||
|
||||
// Pool mocks base method.
|
||||
func (m *MockConfConnector) Pool() pool.Pool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Pool")
|
||||
ret0, _ := ret[0].(pool.Pool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Pool indicates an expected call of Pool.
|
||||
func (mr *MockConfConnectorMockRecorder) Pool() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Pool", reflect.TypeOf((*MockConfConnector)(nil).Pool))
|
||||
}
|
||||
@ -1,10 +0,0 @@
|
||||
package commonspace
|
||||
|
||||
type ConfigGetter interface {
|
||||
GetSpace() Config
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
GCTTL int `yaml:"gcTTL"`
|
||||
SyncPeriod int `yaml:"syncPeriod"`
|
||||
}
|
||||
11
commonspace/config/config.go
Normal file
11
commonspace/config/config.go
Normal file
@ -0,0 +1,11 @@
|
||||
package config
|
||||
|
||||
type ConfigGetter interface {
|
||||
GetSpace() Config
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
GCTTL int `yaml:"gcTTL"`
|
||||
SyncPeriod int `yaml:"syncPeriod"`
|
||||
KeepTreeDataInMemory bool `yaml:"keepTreeDataInMemory"`
|
||||
}
|
||||
34
commonspace/credentialprovider/credentialprovider.go
Normal file
34
commonspace/credentialprovider/credentialprovider.go
Normal file
@ -0,0 +1,34 @@
|
||||
//go:generate mockgen -destination mock_credentialprovider/mock_credentialprovider.go github.com/anyproto/any-sync/commonspace/credentialprovider CredentialProvider
|
||||
package credentialprovider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
)
|
||||
|
||||
const CName = "common.commonspace.credentialprovider"
|
||||
|
||||
func NewNoOp() CredentialProvider {
|
||||
return &noOpProvider{}
|
||||
}
|
||||
|
||||
type CredentialProvider interface {
|
||||
app.Component
|
||||
GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error)
|
||||
}
|
||||
|
||||
type noOpProvider struct {
|
||||
}
|
||||
|
||||
func (n noOpProvider) Init(a *app.App) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n noOpProvider) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
|
||||
return nil, nil
|
||||
}
|
||||
@ -0,0 +1,80 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anyproto/any-sync/commonspace/credentialprovider (interfaces: CredentialProvider)
|
||||
|
||||
// Package mock_credentialprovider is a generated GoMock package.
|
||||
package mock_credentialprovider
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
app "github.com/anyproto/any-sync/app"
|
||||
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockCredentialProvider is a mock of CredentialProvider interface.
|
||||
type MockCredentialProvider struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockCredentialProviderMockRecorder
|
||||
}
|
||||
|
||||
// MockCredentialProviderMockRecorder is the mock recorder for MockCredentialProvider.
|
||||
type MockCredentialProviderMockRecorder struct {
|
||||
mock *MockCredentialProvider
|
||||
}
|
||||
|
||||
// NewMockCredentialProvider creates a new mock instance.
|
||||
func NewMockCredentialProvider(ctrl *gomock.Controller) *MockCredentialProvider {
|
||||
mock := &MockCredentialProvider{ctrl: ctrl}
|
||||
mock.recorder = &MockCredentialProviderMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockCredentialProvider) EXPECT() *MockCredentialProviderMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// GetCredential mocks base method.
|
||||
func (m *MockCredentialProvider) GetCredential(arg0 context.Context, arg1 *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetCredential", arg0, arg1)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetCredential indicates an expected call of GetCredential.
|
||||
func (mr *MockCredentialProviderMockRecorder) GetCredential(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1)
|
||||
}
|
||||
|
||||
// Init mocks base method.
|
||||
func (m *MockCredentialProvider) Init(arg0 *app.App) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Init", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Init indicates an expected call of Init.
|
||||
func (mr *MockCredentialProviderMockRecorder) Init(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockCredentialProvider)(nil).Init), arg0)
|
||||
}
|
||||
|
||||
// Name mocks base method.
|
||||
func (m *MockCredentialProvider) Name() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Name")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Name indicates an expected call of Name.
|
||||
func (mr *MockCredentialProviderMockRecorder) Name() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCredentialProvider)(nil).Name))
|
||||
}
|
||||
278
commonspace/deletion_test.go
Normal file
278
commonspace/deletion_test.go
Normal file
@ -0,0 +1,278 @@
|
||||
package commonspace
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/settings"
|
||||
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func addIncorrectSnapshot(settingsObject settings.SettingsObject, acc *accountdata.AccountKeys, partialIds map[string]struct{}, newId string) (err error) {
|
||||
factory := settingsstate.NewChangeFactory()
|
||||
bytes, err := factory.CreateObjectDeleteChange(newId, &settingsstate.State{DeletedIds: partialIds}, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch, err := settingsObject.PrepareChange(objecttree.SignableChangeContent{
|
||||
Data: bytes,
|
||||
Key: acc.SignKey,
|
||||
IsSnapshot: true,
|
||||
IsEncrypted: false,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res, err := settingsObject.AddRawChanges(context.Background(), objecttree.RawChangesPayload{
|
||||
NewHeads: []string{ch.Id},
|
||||
RawChanges: []*treechangeproto.RawTreeChangeWithId{ch},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if res.Mode != objecttree.Rebuild {
|
||||
return fmt.Errorf("incorrect mode: %d", res.Mode)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestSpaceDeleteIds(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
acc := fx.account.Account()
|
||||
rk := crypto.NewAES()
|
||||
ctx := context.Background()
|
||||
totalObjs := 1500
|
||||
|
||||
// creating space
|
||||
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
|
||||
SigningKey: acc.SignKey,
|
||||
SpaceType: "type",
|
||||
ReadKey: rk.Bytes(),
|
||||
ReplicationKey: 10,
|
||||
MasterKey: acc.PeerKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sp)
|
||||
|
||||
// initializing space
|
||||
spc, err := fx.spaceService.NewSpace(ctx, sp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, spc)
|
||||
// adding space to tree manager
|
||||
fx.treeManager.space = spc
|
||||
err = spc.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
close(fx.treeManager.waitLoad)
|
||||
|
||||
var ids []string
|
||||
for i := 0; i < totalObjs; i++ {
|
||||
// creating a tree
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
|
||||
PrivKey: acc.SignKey,
|
||||
ChangeType: "some",
|
||||
SpaceId: spc.Id(),
|
||||
IsEncrypted: false,
|
||||
Seed: bytes,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
|
||||
require.NoError(t, err)
|
||||
ids = append(ids, tr.Id())
|
||||
tr.Close()
|
||||
}
|
||||
// deleting trees
|
||||
for _, id := range ids {
|
||||
err = spc.DeleteTree(ctx, id)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
spc.Close()
|
||||
require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
|
||||
}
|
||||
|
||||
func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string {
|
||||
bytes := make([]byte, 32)
|
||||
rand.Read(bytes)
|
||||
doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
|
||||
PrivKey: acc.SignKey,
|
||||
ChangeType: "some",
|
||||
SpaceId: spc.Id(),
|
||||
IsEncrypted: false,
|
||||
Seed: bytes,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
|
||||
require.NoError(t, err)
|
||||
tr.Close()
|
||||
return tr.Id()
|
||||
}
|
||||
|
||||
func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
acc := fx.account.Account()
|
||||
rk := crypto.NewAES()
|
||||
ctx := context.Background()
|
||||
totalObjs := 1500
|
||||
partialObjs := 300
|
||||
|
||||
// creating space
|
||||
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
|
||||
SigningKey: acc.SignKey,
|
||||
SpaceType: "type",
|
||||
ReadKey: rk.Bytes(),
|
||||
ReplicationKey: 10,
|
||||
MasterKey: acc.PeerKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sp)
|
||||
|
||||
// initializing space
|
||||
spc, err := fx.spaceService.NewSpace(ctx, sp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, spc)
|
||||
// adding space to tree manager
|
||||
fx.treeManager.space = spc
|
||||
err = spc.Init(ctx)
|
||||
close(fx.treeManager.waitLoad)
|
||||
require.NoError(t, err)
|
||||
|
||||
settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
|
||||
var ids []string
|
||||
for i := 0; i < totalObjs; i++ {
|
||||
id := createTree(t, ctx, spc, acc)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
// copying storage, so we will have all the trees locally
|
||||
inmemory := spc.Storage().(*commonStorage).SpaceStorage.(*spacestorage.InMemorySpaceStorage)
|
||||
storageCopy := inmemory.CopyStorage()
|
||||
treesCopy := inmemory.AllTrees()
|
||||
|
||||
// deleting trees
|
||||
for _, id := range ids {
|
||||
err = spc.DeleteTree(ctx, id)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
mapIds := map[string]struct{}{}
|
||||
for _, id := range ids[:partialObjs] {
|
||||
mapIds[id] = struct{}{}
|
||||
}
|
||||
// adding snapshot that breaks the state
|
||||
err = addIncorrectSnapshot(settingsObject, acc, mapIds, ids[partialObjs])
|
||||
require.NoError(t, err)
|
||||
// copying the contents of the settings tree
|
||||
treesCopy[settingsObject.Id()] = settingsObject.Storage()
|
||||
storageCopy.SetTrees(treesCopy)
|
||||
spc.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// now we replace the storage, so the trees are back, but the settings object says that they are deleted
|
||||
fx.storageProvider.(*spacestorage.InMemorySpaceStorageProvider).SetStorage(storageCopy)
|
||||
|
||||
spc, err = fx.spaceService.NewSpace(ctx, sp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, spc)
|
||||
fx.treeManager.waitLoad = make(chan struct{})
|
||||
fx.treeManager.space = spc
|
||||
fx.treeManager.deletedIds = nil
|
||||
err = spc.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
close(fx.treeManager.waitLoad)
|
||||
|
||||
// waiting until everything is deleted
|
||||
time.Sleep(3 * time.Second)
|
||||
require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
|
||||
|
||||
// checking that new snapshot will contain all the changes
|
||||
settingsObject = spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
|
||||
settings.DoSnapshot = func(treeLen int) bool {
|
||||
return true
|
||||
}
|
||||
id := createTree(t, ctx, spc, acc)
|
||||
err = spc.DeleteTree(ctx, id)
|
||||
require.NoError(t, err)
|
||||
delIds := settingsObject.Root().Model.(*spacesyncproto.SettingsData).Snapshot.DeletedIds
|
||||
require.Equal(t, totalObjs+1, len(delIds))
|
||||
}
|
||||
|
||||
func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
acc := fx.account.Account()
|
||||
rk := crypto.NewAES()
|
||||
ctx := context.Background()
|
||||
totalObjs := 1500
|
||||
|
||||
// creating space
|
||||
sp, err := fx.spaceService.CreateSpace(ctx, SpaceCreatePayload{
|
||||
SigningKey: acc.SignKey,
|
||||
SpaceType: "type",
|
||||
ReadKey: rk.Bytes(),
|
||||
ReplicationKey: 10,
|
||||
MasterKey: acc.PeerKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, sp)
|
||||
|
||||
// initializing space
|
||||
spc, err := fx.spaceService.NewSpace(ctx, sp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, spc)
|
||||
// adding space to tree manager
|
||||
fx.treeManager.space = spc
|
||||
err = spc.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
close(fx.treeManager.waitLoad)
|
||||
|
||||
settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
|
||||
var ids []string
|
||||
for i := 0; i < totalObjs; i++ {
|
||||
id := createTree(t, ctx, spc, acc)
|
||||
ids = append(ids, id)
|
||||
}
|
||||
// copying storage, so we will have the same contents, except for empty trees
|
||||
inmemory := spc.Storage().(*commonStorage).SpaceStorage.(*spacestorage.InMemorySpaceStorage)
|
||||
storageCopy := inmemory.CopyStorage()
|
||||
|
||||
// deleting trees, this will prepare the document to have all the deletion changes
|
||||
for _, id := range ids {
|
||||
err = spc.DeleteTree(ctx, id)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
treesMap := map[string]treestorage.TreeStorage{}
|
||||
// copying the contents of the settings tree
|
||||
treesMap[settingsObject.Id()] = settingsObject.Storage()
|
||||
storageCopy.SetTrees(treesMap)
|
||||
spc.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
// now we replace the storage, so the trees are back, but the settings object says that they are deleted
|
||||
fx.storageProvider.(*spacestorage.InMemorySpaceStorageProvider).SetStorage(storageCopy)
|
||||
|
||||
spc, err = fx.spaceService.NewSpace(ctx, sp)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, spc)
|
||||
fx.treeManager.space = spc
|
||||
fx.treeManager.waitLoad = make(chan struct{})
|
||||
fx.treeManager.deletedIds = nil
|
||||
fx.treeManager.markedIds = nil
|
||||
err = spc.Init(ctx)
|
||||
require.NoError(t, err)
|
||||
close(fx.treeManager.waitLoad)
|
||||
|
||||
// waiting until everything is deleted
|
||||
time.Sleep(3 * time.Second)
|
||||
require.Equal(t, len(ids), len(fx.treeManager.markedIds))
|
||||
require.Zero(t, len(fx.treeManager.deletedIds))
|
||||
}
|
||||
@ -1,59 +1,73 @@
|
||||
//go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anytypeio/any-sync/commonspace/settings/deletionstate DeletionState
|
||||
//go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anyproto/any-sync/commonspace/deletionstate ObjectDeletionState
|
||||
package deletionstate
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"go.uber.org/zap"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
const CName = "common.commonspace.deletionstate"
|
||||
|
||||
type StateUpdateObserver func(ids []string)
|
||||
|
||||
type DeletionState interface {
|
||||
type ObjectDeletionState interface {
|
||||
app.Component
|
||||
AddObserver(observer StateUpdateObserver)
|
||||
Add(ids []string) (err error)
|
||||
Add(ids map[string]struct{})
|
||||
GetQueued() (ids []string)
|
||||
Delete(id string) (err error)
|
||||
Exists(id string) bool
|
||||
FilterJoin(ids ...[]string) (filtered []string)
|
||||
CreateDeleteChange(id string, isSnapshot bool) (res []byte, err error)
|
||||
Filter(ids []string) (filtered []string)
|
||||
}
|
||||
|
||||
type deletionState struct {
|
||||
type objectDeletionState struct {
|
||||
sync.RWMutex
|
||||
log logger.CtxLogger
|
||||
queued map[string]struct{}
|
||||
deleted map[string]struct{}
|
||||
stateUpdateObservers []StateUpdateObserver
|
||||
storage spacestorage.SpaceStorage
|
||||
}
|
||||
|
||||
func NewDeletionState(storage spacestorage.SpaceStorage) DeletionState {
|
||||
return &deletionState{
|
||||
func (st *objectDeletionState) Init(a *app.App) (err error) {
|
||||
st.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *objectDeletionState) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func New() ObjectDeletionState {
|
||||
return &objectDeletionState{
|
||||
log: log,
|
||||
queued: map[string]struct{}{},
|
||||
deleted: map[string]struct{}{},
|
||||
storage: storage,
|
||||
}
|
||||
}
|
||||
|
||||
func (st *deletionState) AddObserver(observer StateUpdateObserver) {
|
||||
func (st *objectDeletionState) AddObserver(observer StateUpdateObserver) {
|
||||
st.Lock()
|
||||
defer st.Unlock()
|
||||
st.stateUpdateObservers = append(st.stateUpdateObservers, observer)
|
||||
}
|
||||
|
||||
func (st *deletionState) Add(ids []string) (err error) {
|
||||
func (st *objectDeletionState) Add(ids map[string]struct{}) {
|
||||
var added []string
|
||||
st.Lock()
|
||||
defer func() {
|
||||
st.Unlock()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, ob := range st.stateUpdateObservers {
|
||||
ob(ids)
|
||||
ob(added)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, id := range ids {
|
||||
for id := range ids {
|
||||
if _, exists := st.deleted[id]; exists {
|
||||
continue
|
||||
}
|
||||
@ -62,9 +76,10 @@ func (st *deletionState) Add(ids []string) (err error) {
|
||||
}
|
||||
|
||||
var status string
|
||||
status, err = st.storage.TreeDeletedStatus(id)
|
||||
status, err := st.storage.TreeDeletedStatus(id)
|
||||
if err != nil {
|
||||
return
|
||||
st.log.Warn("failed to get deleted status", zap.String("treeId", id), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
|
||||
switch status {
|
||||
@ -73,17 +88,18 @@ func (st *deletionState) Add(ids []string) (err error) {
|
||||
case spacestorage.TreeDeletedStatusDeleted:
|
||||
st.deleted[id] = struct{}{}
|
||||
default:
|
||||
st.queued[id] = struct{}{}
|
||||
err = st.storage.SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued)
|
||||
err := st.storage.SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued)
|
||||
if err != nil {
|
||||
return
|
||||
st.log.Warn("failed to set deleted status", zap.String("treeId", id), zap.Error(err))
|
||||
continue
|
||||
}
|
||||
st.queued[id] = struct{}{}
|
||||
}
|
||||
added = append(added, id)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (st *deletionState) GetQueued() (ids []string) {
|
||||
func (st *objectDeletionState) GetQueued() (ids []string) {
|
||||
st.RLock()
|
||||
defer st.RUnlock()
|
||||
ids = make([]string, 0, len(st.queued))
|
||||
@ -93,7 +109,7 @@ func (st *deletionState) GetQueued() (ids []string) {
|
||||
return
|
||||
}
|
||||
|
||||
func (st *deletionState) Delete(id string) (err error) {
|
||||
func (st *objectDeletionState) Delete(id string) (err error) {
|
||||
st.Lock()
|
||||
defer st.Unlock()
|
||||
delete(st.queued, id)
|
||||
@ -105,44 +121,24 @@ func (st *deletionState) Delete(id string) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (st *deletionState) Exists(id string) bool {
|
||||
func (st *objectDeletionState) Exists(id string) bool {
|
||||
st.RLock()
|
||||
defer st.RUnlock()
|
||||
return st.exists(id)
|
||||
}
|
||||
|
||||
func (st *deletionState) FilterJoin(ids ...[]string) (filtered []string) {
|
||||
func (st *objectDeletionState) Filter(ids []string) (filtered []string) {
|
||||
st.RLock()
|
||||
defer st.RUnlock()
|
||||
filter := func(ids []string) {
|
||||
for _, id := range ids {
|
||||
if !st.exists(id) {
|
||||
filtered = append(filtered, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, arr := range ids {
|
||||
filter(arr)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (st *deletionState) CreateDeleteChange(id string, isSnapshot bool) (res []byte, err error) {
|
||||
content := &spacesyncproto.SpaceSettingsContent_ObjectDelete{
|
||||
ObjectDelete: &spacesyncproto.ObjectDelete{Id: id},
|
||||
}
|
||||
change := &spacesyncproto.SettingsData{
|
||||
Content: []*spacesyncproto.SpaceSettingsContent{
|
||||
{Value: content},
|
||||
},
|
||||
Snapshot: nil,
|
||||
}
|
||||
// TODO: add snapshot logic
|
||||
res, err = change.Marshal()
|
||||
return
|
||||
}
|
||||
|
||||
func (st *deletionState) exists(id string) bool {
|
||||
func (st *objectDeletionState) exists(id string) bool {
|
||||
if _, exists := st.deleted[id]; exists {
|
||||
return true
|
||||
}
|
||||
@ -1,23 +1,25 @@
|
||||
package deletionstate
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fixture struct {
|
||||
ctrl *gomock.Controller
|
||||
delState *deletionState
|
||||
delState *objectDeletionState
|
||||
spaceStorage *mock_spacestorage.MockSpaceStorage
|
||||
}
|
||||
|
||||
func newFixture(t *testing.T) *fixture {
|
||||
ctrl := gomock.NewController(t)
|
||||
spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl)
|
||||
delState := NewDeletionState(spaceStorage).(*deletionState)
|
||||
delState := New().(*objectDeletionState)
|
||||
delState.storage = spaceStorage
|
||||
return &fixture{
|
||||
ctrl: ctrl,
|
||||
delState: delState,
|
||||
@ -36,8 +38,7 @@ func TestDeletionState_Add(t *testing.T) {
|
||||
id := "newId"
|
||||
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil)
|
||||
fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil)
|
||||
err := fx.delState.Add([]string{id})
|
||||
require.NoError(t, err)
|
||||
fx.delState.Add(map[string]struct{}{id: {}})
|
||||
require.Contains(t, fx.delState.queued, id)
|
||||
})
|
||||
|
||||
@ -46,8 +47,7 @@ func TestDeletionState_Add(t *testing.T) {
|
||||
defer fx.stop()
|
||||
id := "newId"
|
||||
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusQueued, nil)
|
||||
err := fx.delState.Add([]string{id})
|
||||
require.NoError(t, err)
|
||||
fx.delState.Add(map[string]struct{}{id: {}})
|
||||
require.Contains(t, fx.delState.queued, id)
|
||||
})
|
||||
|
||||
@ -56,8 +56,7 @@ func TestDeletionState_Add(t *testing.T) {
|
||||
defer fx.stop()
|
||||
id := "newId"
|
||||
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return(spacestorage.TreeDeletedStatusDeleted, nil)
|
||||
err := fx.delState.Add([]string{id})
|
||||
require.NoError(t, err)
|
||||
fx.delState.Add(map[string]struct{}{id: {}})
|
||||
require.Contains(t, fx.delState.deleted, id)
|
||||
})
|
||||
}
|
||||
@ -70,6 +69,7 @@ func TestDeletionState_GetQueued(t *testing.T) {
|
||||
fx.delState.queued["id2"] = struct{}{}
|
||||
|
||||
queued := fx.delState.GetQueued()
|
||||
sort.Strings(queued)
|
||||
require.Equal(t, []string{"id1", "id2"}, queued)
|
||||
}
|
||||
|
||||
@ -80,8 +80,8 @@ func TestDeletionState_FilterJoin(t *testing.T) {
|
||||
fx.delState.queued["id1"] = struct{}{}
|
||||
fx.delState.queued["id2"] = struct{}{}
|
||||
|
||||
filtered := fx.delState.FilterJoin([]string{"id1"}, []string{"id3", "id2"}, []string{"id4"})
|
||||
require.Equal(t, []string{"id3", "id4"}, filtered)
|
||||
filtered := fx.delState.Filter([]string{"id3", "id2"})
|
||||
require.Equal(t, []string{"id3"}, filtered)
|
||||
}
|
||||
|
||||
func TestDeletionState_AddObserver(t *testing.T) {
|
||||
@ -96,8 +96,7 @@ func TestDeletionState_AddObserver(t *testing.T) {
|
||||
id := "newId"
|
||||
fx.spaceStorage.EXPECT().TreeDeletedStatus(id).Return("", nil)
|
||||
fx.spaceStorage.EXPECT().SetTreeDeletedStatus(id, spacestorage.TreeDeletedStatusQueued).Return(nil)
|
||||
err := fx.delState.Add([]string{id})
|
||||
require.NoError(t, err)
|
||||
fx.delState.Add(map[string]struct{}{id: {}})
|
||||
require.Contains(t, fx.delState.queued, id)
|
||||
require.Equal(t, []string{id}, queued)
|
||||
}
|
||||
@ -0,0 +1,144 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anyproto/any-sync/commonspace/deletionstate (interfaces: ObjectDeletionState)
|
||||
|
||||
// Package mock_deletionstate is a generated GoMock package.
|
||||
package mock_deletionstate
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
app "github.com/anyproto/any-sync/app"
|
||||
deletionstate "github.com/anyproto/any-sync/commonspace/deletionstate"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
|
||||
type MockObjectDeletionState struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockObjectDeletionStateMockRecorder
|
||||
}
|
||||
|
||||
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
|
||||
type MockObjectDeletionStateMockRecorder struct {
|
||||
mock *MockObjectDeletionState
|
||||
}
|
||||
|
||||
// NewMockObjectDeletionState creates a new mock instance.
|
||||
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
|
||||
mock := &MockObjectDeletionState{ctrl: ctrl}
|
||||
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Add mocks base method.
|
||||
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Add", arg0)
|
||||
}
|
||||
|
||||
// Add indicates an expected call of Add.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
|
||||
}
|
||||
|
||||
// AddObserver mocks base method.
|
||||
func (m *MockObjectDeletionState) AddObserver(arg0 deletionstate.StateUpdateObserver) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "AddObserver", arg0)
|
||||
}
|
||||
|
||||
// AddObserver indicates an expected call of AddObserver.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
func (m *MockObjectDeletionState) Delete(arg0 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Delete", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Delete indicates an expected call of Delete.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
|
||||
}
|
||||
|
||||
// Exists mocks base method.
|
||||
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Exists", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Exists indicates an expected call of Exists.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
|
||||
}
|
||||
|
||||
// Filter mocks base method.
|
||||
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Filter", arg0)
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Filter indicates an expected call of Filter.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
|
||||
}
|
||||
|
||||
// GetQueued mocks base method.
|
||||
func (m *MockObjectDeletionState) GetQueued() []string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetQueued")
|
||||
ret0, _ := ret[0].([]string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetQueued indicates an expected call of GetQueued.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
|
||||
}
|
||||
|
||||
// Init mocks base method.
|
||||
func (m *MockObjectDeletionState) Init(arg0 *app.App) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Init", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Init indicates an expected call of Init.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Init(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectDeletionState)(nil).Init), arg0)
|
||||
}
|
||||
|
||||
// Name mocks base method.
|
||||
func (m *MockObjectDeletionState) Name() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Name")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Name indicates an expected call of Name.
|
||||
func (mr *MockObjectDeletionStateMockRecorder) Name() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectDeletionState)(nil).Name))
|
||||
}
|
||||
@ -2,63 +2,67 @@ package headsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/commonspace/confconnector"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/synctree"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
|
||||
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/syncstatus"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/anytypeio/any-sync/net/rpc/rpcerr"
|
||||
"go.uber.org/zap"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/credentialprovider"
|
||||
"github.com/anyproto/any-sync/commonspace/deletionstate"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
|
||||
"github.com/anyproto/any-sync/commonspace/object/treemanager"
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/net/rpc/rpcerr"
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type DiffSyncer interface {
|
||||
Sync(ctx context.Context) error
|
||||
RemoveObjects(ids []string)
|
||||
UpdateHeads(id string, heads []string)
|
||||
Init(deletionState deletionstate.DeletionState)
|
||||
Init()
|
||||
Close() error
|
||||
}
|
||||
|
||||
func newDiffSyncer(
|
||||
spaceId string,
|
||||
diff ldiff.Diff,
|
||||
confConnector confconnector.ConfConnector,
|
||||
cache treegetter.TreeGetter,
|
||||
storage spacestorage.SpaceStorage,
|
||||
clientFactory spacesyncproto.ClientFactory,
|
||||
syncStatus syncstatus.StatusUpdater,
|
||||
log *zap.Logger) DiffSyncer {
|
||||
func newDiffSyncer(hs *headSync) DiffSyncer {
|
||||
return &diffSyncer{
|
||||
diff: diff,
|
||||
spaceId: spaceId,
|
||||
cache: cache,
|
||||
storage: storage,
|
||||
confConnector: confConnector,
|
||||
clientFactory: clientFactory,
|
||||
diff: hs.diff,
|
||||
spaceId: hs.spaceId,
|
||||
treeManager: hs.treeManager,
|
||||
storage: hs.storage,
|
||||
peerManager: hs.peerManager,
|
||||
clientFactory: spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient),
|
||||
credentialProvider: hs.credentialProvider,
|
||||
log: log,
|
||||
syncStatus: syncStatus,
|
||||
syncStatus: hs.syncStatus,
|
||||
deletionState: hs.deletionState,
|
||||
syncAcl: hs.syncAcl,
|
||||
}
|
||||
}
|
||||
|
||||
type diffSyncer struct {
|
||||
spaceId string
|
||||
diff ldiff.Diff
|
||||
confConnector confconnector.ConfConnector
|
||||
cache treegetter.TreeGetter
|
||||
peerManager peermanager.PeerManager
|
||||
treeManager treemanager.TreeManager
|
||||
storage spacestorage.SpaceStorage
|
||||
clientFactory spacesyncproto.ClientFactory
|
||||
log *zap.Logger
|
||||
deletionState deletionstate.DeletionState
|
||||
log logger.CtxLogger
|
||||
deletionState deletionstate.ObjectDeletionState
|
||||
credentialProvider credentialprovider.CredentialProvider
|
||||
syncStatus syncstatus.StatusUpdater
|
||||
treeSyncer treemanager.TreeSyncer
|
||||
syncAcl syncacl.SyncAcl
|
||||
}
|
||||
|
||||
func (d *diffSyncer) Init(deletionState deletionstate.DeletionState) {
|
||||
d.deletionState = deletionState
|
||||
func (d *diffSyncer) Init() {
|
||||
d.deletionState.AddObserver(d.RemoveObjects)
|
||||
d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager)
|
||||
}
|
||||
|
||||
func (d *diffSyncer) RemoveObjects(ids []string) {
|
||||
@ -84,75 +88,91 @@ func (d *diffSyncer) UpdateHeads(id string, heads []string) {
|
||||
}
|
||||
|
||||
func (d *diffSyncer) Sync(ctx context.Context) error {
|
||||
// TODO: split diffsyncer into components
|
||||
st := time.Now()
|
||||
// diffing with responsible peers according to configuration
|
||||
peers, err := d.confConnector.GetResponsiblePeers(ctx, d.spaceId)
|
||||
peers, err := d.peerManager.GetResponsiblePeers(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
var peerIds = make([]string, 0, len(peers))
|
||||
for _, p := range peers {
|
||||
if err := d.syncWithPeer(ctx, p); err != nil {
|
||||
d.log.Error("can't sync with peer", zap.String("peer", p.Id()), zap.Error(err))
|
||||
peerIds = append(peerIds, p.Id())
|
||||
}
|
||||
d.log.DebugCtx(ctx, "start diffsync", zap.Strings("peerIds", peerIds))
|
||||
for _, p := range peers {
|
||||
if err = d.syncWithPeer(peer.CtxWithPeerId(ctx, p.Id()), p); err != nil {
|
||||
d.log.ErrorCtx(ctx, "can't sync with peer", zap.String("peer", p.Id()), zap.Error(err))
|
||||
}
|
||||
}
|
||||
d.log.Info("synced", zap.String("spaceId", d.spaceId), zap.Duration("dur", time.Since(st)))
|
||||
d.log.InfoCtx(ctx, "diff done", zap.String("spaceId", d.spaceId), zap.Duration("dur", time.Since(st)))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) {
|
||||
ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id()))
|
||||
conn, err := p.AcquireDrpcConn(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer p.ReleaseDrpcConn(conn)
|
||||
|
||||
var (
|
||||
cl = d.clientFactory.Client(p)
|
||||
cl = d.clientFactory.Client(conn)
|
||||
rdiff = NewRemoteDiff(d.spaceId, cl)
|
||||
stateCounter = d.syncStatus.StateCounter()
|
||||
syncAclId = d.syncAcl.Id()
|
||||
)
|
||||
|
||||
newIds, changedIds, removedIds, err := d.diff.Diff(ctx, rdiff)
|
||||
err = rpcerr.Unwrap(err)
|
||||
if err != nil && err != spacesyncproto.ErrSpaceMissing {
|
||||
if err == spacesyncproto.ErrSpaceIsDeleted {
|
||||
d.log.Debug("got space deleted while syncing")
|
||||
d.treeSyncer.SyncAll(ctx, p.Id(), []string{d.storage.SpaceSettingsId()}, nil)
|
||||
}
|
||||
d.syncStatus.SetNodesOnline(p.Id(), false)
|
||||
return err
|
||||
return fmt.Errorf("diff error: %v", err)
|
||||
}
|
||||
d.syncStatus.SetNodesOnline(p.Id(), true)
|
||||
|
||||
if err == spacesyncproto.ErrSpaceMissing {
|
||||
return d.sendPushSpaceRequest(ctx, cl)
|
||||
return d.sendPushSpaceRequest(ctx, p.Id(), cl)
|
||||
}
|
||||
|
||||
totalLen := len(newIds) + len(changedIds) + len(removedIds)
|
||||
// not syncing ids which were removed through settings document
|
||||
filteredIds := d.deletionState.FilterJoin(newIds, changedIds, removedIds)
|
||||
missingIds := d.deletionState.Filter(newIds)
|
||||
existingIds := append(d.deletionState.Filter(removedIds), d.deletionState.Filter(changedIds)...)
|
||||
d.syncStatus.RemoveAllExcept(p.Id(), existingIds, stateCounter)
|
||||
|
||||
d.syncStatus.RemoveAllExcept(p.Id(), filteredIds, stateCounter)
|
||||
prevExistingLen := len(existingIds)
|
||||
existingIds = slice.DiscardFromSlice(existingIds, func(s string) bool {
|
||||
return s == syncAclId
|
||||
})
|
||||
// if we removed acl head from the list
|
||||
if len(existingIds) < prevExistingLen {
|
||||
if syncErr := d.syncAcl.SyncWithPeer(ctx, p.Id()); syncErr != nil {
|
||||
log.Warn("failed to send acl sync message to peer", zap.String("aclId", syncAclId))
|
||||
}
|
||||
}
|
||||
|
||||
ctx = peer.CtxWithPeerId(ctx, p.Id())
|
||||
d.pingTreesInCache(ctx, filteredIds)
|
||||
|
||||
d.log.Info("sync done:", zap.Int("newIds", len(newIds)),
|
||||
// treeSyncer should not get acl id, that's why we filter existing ids before
|
||||
err = d.treeSyncer.SyncAll(ctx, p.Id(), existingIds, missingIds)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.log.Info("sync done:",
|
||||
zap.Int("newIds", len(newIds)),
|
||||
zap.Int("changedIds", len(changedIds)),
|
||||
zap.Int("removedIds", len(removedIds)),
|
||||
zap.Int("already deleted ids", totalLen-len(filteredIds)))
|
||||
zap.Int("already deleted ids", totalLen-prevExistingLen-len(missingIds)),
|
||||
zap.String("peerId", p.Id()),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
func (d *diffSyncer) pingTreesInCache(ctx context.Context, trees []string) {
|
||||
for _, tId := range trees {
|
||||
tree, err := d.cache.GetTree(ctx, d.spaceId, tId)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
syncTree, ok := tree.(synctree.SyncTree)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
// the idea why we call it directly is that if we try to get it from cache
|
||||
// it may be already there (i.e. loaded)
|
||||
// and build func will not be called, thus we won't sync the tree
|
||||
// therefore we just do it manually
|
||||
syncTree.Ping()
|
||||
}
|
||||
}
|
||||
|
||||
func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto.DRPCSpaceSyncClient) (err error) {
|
||||
func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, peerId string, cl spacesyncproto.DRPCSpaceSyncClient) (err error) {
|
||||
aclStorage, err := d.storage.AclStorage()
|
||||
if err != nil {
|
||||
return
|
||||
@ -177,6 +197,10 @@ func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto
|
||||
return
|
||||
}
|
||||
|
||||
cred, err := d.credentialProvider.GetCredential(ctx, header)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
spacePayload := &spacesyncproto.SpacePayload{
|
||||
SpaceHeader: header,
|
||||
AclPayload: root.Payload,
|
||||
@ -186,6 +210,31 @@ func (d *diffSyncer) sendPushSpaceRequest(ctx context.Context, cl spacesyncproto
|
||||
}
|
||||
_, err = cl.SpacePush(ctx, &spacesyncproto.SpacePushRequest{
|
||||
Payload: spacePayload,
|
||||
Credential: cred,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if e := d.subscribe(ctx, peerId); e != nil {
|
||||
d.log.WarnCtx(ctx, "error subscribing for space", zap.Error(e))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (d *diffSyncer) subscribe(ctx context.Context, peerId string) (err error) {
|
||||
var msg = &spacesyncproto.SpaceSubscription{
|
||||
SpaceIds: []string{d.spaceId},
|
||||
Action: spacesyncproto.SpaceSubscriptionAction_Subscribe,
|
||||
}
|
||||
payload, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return d.peerManager.SendPeer(ctx, peerId, &spacesyncproto.ObjectSyncMessage{
|
||||
Payload: payload,
|
||||
})
|
||||
}
|
||||
|
||||
func (d *diffSyncer) Close() error {
|
||||
return d.treeSyncer.Close()
|
||||
}
|
||||
|
||||
@ -1,170 +1,200 @@
|
||||
package headsync
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/app/ldiff/mock_ldiff"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anytypeio/any-sync/commonspace/confconnector/mock_confconnector"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
mock_treestorage "github.com/anytypeio/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/treegetter/mock_treegetter"
|
||||
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate/mock_deletionstate"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/syncstatus"
|
||||
"github.com/anytypeio/any-sync/net/peer"
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/libp2p/go-libp2p/core/sec"
|
||||
"github.com/stretchr/testify/require"
|
||||
"storj.io/drpc"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"storj.io/drpc"
|
||||
)
|
||||
|
||||
type pushSpaceRequestMatcher struct {
|
||||
spaceId string
|
||||
aclRootId string
|
||||
settingsId string
|
||||
credential []byte
|
||||
spaceHeader *spacesyncproto.RawSpaceHeaderWithId
|
||||
}
|
||||
|
||||
func newPushSpaceRequestMatcher(
|
||||
spaceId string,
|
||||
aclRootId string,
|
||||
settingsId string,
|
||||
credential []byte,
|
||||
spaceHeader *spacesyncproto.RawSpaceHeaderWithId) *pushSpaceRequestMatcher {
|
||||
return &pushSpaceRequestMatcher{
|
||||
spaceId: spaceId,
|
||||
aclRootId: aclRootId,
|
||||
settingsId: settingsId,
|
||||
credential: credential,
|
||||
spaceHeader: spaceHeader,
|
||||
}
|
||||
}
|
||||
|
||||
func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
|
||||
res, ok := x.(*spacesyncproto.SpacePushRequest)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId
|
||||
return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
|
||||
}
|
||||
|
||||
func (p pushSpaceRequestMatcher) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
type mockPeer struct{}
|
||||
type mockPeer struct {
|
||||
}
|
||||
|
||||
func (m mockPeer) Id() string {
|
||||
return "mockId"
|
||||
return "peerId"
|
||||
}
|
||||
|
||||
func (m mockPeer) LastUsage() time.Time {
|
||||
return time.Time{}
|
||||
func (m mockPeer) Context() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
func (m mockPeer) Secure() sec.SecureConn {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockPeer) UpdateLastUsage() {
|
||||
}
|
||||
|
||||
func (m mockPeer) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockPeer) Closed() <-chan struct{} {
|
||||
return make(chan struct{})
|
||||
}
|
||||
|
||||
func (m mockPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
|
||||
func (m mockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newPushSpaceRequestMatcher(
|
||||
spaceId string,
|
||||
aclRootId string,
|
||||
settingsId string,
|
||||
spaceHeader *spacesyncproto.RawSpaceHeaderWithId) *pushSpaceRequestMatcher {
|
||||
return &pushSpaceRequestMatcher{
|
||||
spaceId: spaceId,
|
||||
aclRootId: aclRootId,
|
||||
settingsId: settingsId,
|
||||
spaceHeader: spaceHeader,
|
||||
}
|
||||
func (m mockPeer) ReleaseDrpcConn(conn drpc.Conn) {
|
||||
return
|
||||
}
|
||||
|
||||
func TestDiffSyncer_Sync(t *testing.T) {
|
||||
// setup
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
func (m mockPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
diffMock := mock_ldiff.NewMockDiff(ctrl)
|
||||
connectorMock := mock_confconnector.NewMockConfConnector(ctrl)
|
||||
cacheMock := mock_treegetter.NewMockTreeGetter(ctrl)
|
||||
stMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
|
||||
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
|
||||
factory := spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
|
||||
return clientMock
|
||||
func (m mockPeer) IsClosed() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (m mockPeer) Close() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (fx *headSyncFixture) initDiffSyncer(t *testing.T) {
|
||||
fx.init(t)
|
||||
fx.diffSyncer = newDiffSyncer(fx.headSync).(*diffSyncer)
|
||||
fx.diffSyncer.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
|
||||
return fx.clientMock
|
||||
})
|
||||
delState := mock_deletionstate.NewMockDeletionState(ctrl)
|
||||
spaceId := "spaceId"
|
||||
aclRootId := "aclRootId"
|
||||
l := logger.NewNamed(spaceId)
|
||||
diffSyncer := newDiffSyncer(spaceId, diffMock, connectorMock, cacheMock, stMock, factory, syncstatus.NewNoOpSyncStatus(), l)
|
||||
delState.EXPECT().AddObserver(gomock.Any())
|
||||
diffSyncer.Init(delState)
|
||||
fx.deletionStateMock.EXPECT().AddObserver(gomock.Any())
|
||||
fx.treeManagerMock.EXPECT().NewTreeSyncer(fx.spaceState.SpaceId, fx.treeManagerMock).Return(fx.treeSyncerMock)
|
||||
fx.diffSyncer.Init()
|
||||
}
|
||||
|
||||
func TestDiffSyncer(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("diff syncer sync", func(t *testing.T) {
|
||||
connectorMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any(), spaceId).
|
||||
Return([]peer.Peer{mockPeer{}}, nil)
|
||||
diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))).
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
mPeer := mockPeer{}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return([]peer.Peer{mPeer}, nil)
|
||||
fx.diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
|
||||
Return([]string{"new"}, []string{"changed"}, nil, nil)
|
||||
delState.EXPECT().FilterJoin(gomock.Any()).Return([]string{"new", "changed"})
|
||||
for _, arg := range []string{"new", "changed"} {
|
||||
cacheMock.EXPECT().
|
||||
GetTree(gomock.Any(), spaceId, arg).
|
||||
Return(nil, nil)
|
||||
}
|
||||
require.NoError(t, diffSyncer.Sync(ctx))
|
||||
fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
|
||||
fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1)
|
||||
fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
|
||||
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
|
||||
require.NoError(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
|
||||
t.Run("diff syncer sync, acl changed", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
mPeer := mockPeer{}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return([]peer.Peer{mPeer}, nil)
|
||||
fx.diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
|
||||
Return([]string{"new"}, []string{"changed"}, nil, nil)
|
||||
fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
|
||||
fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed", "aclId"}).Times(1)
|
||||
fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
|
||||
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
|
||||
fx.aclMock.EXPECT().SyncWithPeer(gomock.Any(), mPeer.Id()).Return(nil)
|
||||
require.NoError(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
|
||||
t.Run("diff syncer sync conf error", func(t *testing.T) {
|
||||
connectorMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any(), spaceId).
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
ctx := context.Background()
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return(nil, fmt.Errorf("some error"))
|
||||
|
||||
require.Error(t, diffSyncer.Sync(ctx))
|
||||
require.Error(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
|
||||
t.Run("deletion state remove objects", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
deletedId := "id"
|
||||
delState.EXPECT().Exists(deletedId).Return(true)
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.deletionStateMock.EXPECT().Exists(deletedId).Return(true)
|
||||
|
||||
// this should not result in any mock being called
|
||||
diffSyncer.UpdateHeads(deletedId, []string{"someHead"})
|
||||
fx.diffSyncer.UpdateHeads(deletedId, []string{"someHead"})
|
||||
})
|
||||
|
||||
t.Run("update heads updates diff", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
newId := "newId"
|
||||
newHeads := []string{"h1", "h2"}
|
||||
hash := "hash"
|
||||
diffMock.EXPECT().Set(ldiff.Element{
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.diffMock.EXPECT().Set(ldiff.Element{
|
||||
Id: newId,
|
||||
Head: concatStrings(newHeads),
|
||||
})
|
||||
diffMock.EXPECT().Hash().Return(hash)
|
||||
delState.EXPECT().Exists(newId).Return(false)
|
||||
stMock.EXPECT().WriteSpaceHash(hash)
|
||||
diffSyncer.UpdateHeads(newId, newHeads)
|
||||
fx.diffMock.EXPECT().Hash().Return(hash)
|
||||
fx.deletionStateMock.EXPECT().Exists(newId).Return(false)
|
||||
fx.storageMock.EXPECT().WriteSpaceHash(hash)
|
||||
fx.diffSyncer.UpdateHeads(newId, newHeads)
|
||||
})
|
||||
|
||||
t.Run("diff syncer sync space missing", func(t *testing.T) {
|
||||
aclStorageMock := mock_liststorage.NewMockListStorage(ctrl)
|
||||
settingsStorage := mock_treestorage.NewMockTreeStorage(ctrl)
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
aclStorageMock := mock_liststorage.NewMockListStorage(fx.ctrl)
|
||||
settingsStorage := mock_treestorage.NewMockTreeStorage(fx.ctrl)
|
||||
settingsId := "settingsId"
|
||||
aclRoot := &aclrecordproto.RawAclRecordWithId{
|
||||
aclRootId := "aclRootId"
|
||||
aclRoot := &consensusproto.RawRecordWithId{
|
||||
Id: aclRootId,
|
||||
}
|
||||
settingsRoot := &treechangeproto.RawTreeChangeWithId{
|
||||
@ -172,38 +202,65 @@ func TestDiffSyncer_Sync(t *testing.T) {
|
||||
}
|
||||
spaceHeader := &spacesyncproto.RawSpaceHeaderWithId{}
|
||||
spaceSettingsId := "spaceSettingsId"
|
||||
credential := []byte("credential")
|
||||
|
||||
connectorMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any(), spaceId).
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return([]peer.Peer{mockPeer{}}, nil)
|
||||
diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))).
|
||||
fx.diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
|
||||
Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing)
|
||||
|
||||
stMock.EXPECT().AclStorage().Return(aclStorageMock, nil)
|
||||
stMock.EXPECT().SpaceHeader().Return(spaceHeader, nil)
|
||||
stMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId)
|
||||
stMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil)
|
||||
fx.storageMock.EXPECT().AclStorage().Return(aclStorageMock, nil)
|
||||
fx.storageMock.EXPECT().SpaceHeader().Return(spaceHeader, nil)
|
||||
fx.storageMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId)
|
||||
fx.storageMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil)
|
||||
|
||||
settingsStorage.EXPECT().Root().Return(settingsRoot, nil)
|
||||
aclStorageMock.EXPECT().
|
||||
Root().
|
||||
Return(aclRoot, nil)
|
||||
clientMock.EXPECT().
|
||||
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(spaceId, aclRootId, settingsId, spaceHeader)).
|
||||
fx.credentialProviderMock.EXPECT().
|
||||
GetCredential(gomock.Any(), spaceHeader).
|
||||
Return(credential, nil)
|
||||
fx.clientMock.EXPECT().
|
||||
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(fx.spaceState.SpaceId, aclRootId, settingsId, credential, spaceHeader)).
|
||||
Return(nil, nil)
|
||||
fx.peerManagerMock.EXPECT().SendPeer(gomock.Any(), "peerId", gomock.Any())
|
||||
|
||||
require.NoError(t, diffSyncer.Sync(ctx))
|
||||
require.NoError(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
|
||||
t.Run("diff syncer sync other error", func(t *testing.T) {
|
||||
connectorMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any(), spaceId).
|
||||
t.Run("diff syncer sync unexpected", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return([]peer.Peer{mockPeer{}}, nil)
|
||||
diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))).
|
||||
fx.diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
|
||||
Return(nil, nil, nil, spacesyncproto.ErrUnexpected)
|
||||
|
||||
require.NoError(t, diffSyncer.Sync(ctx))
|
||||
require.NoError(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
|
||||
t.Run("diff syncer sync space is deleted error", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.initDiffSyncer(t)
|
||||
defer fx.stop()
|
||||
mPeer := mockPeer{}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.peerManagerMock.EXPECT().
|
||||
GetResponsiblePeers(gomock.Any()).
|
||||
Return([]peer.Peer{mPeer}, nil)
|
||||
fx.diffMock.EXPECT().
|
||||
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
|
||||
Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted)
|
||||
fx.storageMock.EXPECT().SpaceSettingsId().Return("settingsId")
|
||||
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil)
|
||||
|
||||
require.NoError(t, fx.diffSyncer.Sync(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,95 +1,152 @@
|
||||
//go:generate mockgen -destination mock_headsync/mock_headsync.go github.com/anytypeio/any-sync/commonspace/headsync DiffSyncer
|
||||
//go:generate mockgen -destination mock_headsync/mock_headsync.go github.com/anyproto/any-sync/commonspace/headsync DiffSyncer
|
||||
package headsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/commonspace/confconnector"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/treegetter"
|
||||
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/syncstatus"
|
||||
"github.com/anytypeio/any-sync/util/periodicsync"
|
||||
"go.uber.org/zap"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
config2 "github.com/anyproto/any-sync/commonspace/config"
|
||||
"github.com/anyproto/any-sync/commonspace/credentialprovider"
|
||||
"github.com/anyproto/any-sync/commonspace/deletionstate"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
|
||||
"github.com/anyproto/any-sync/commonspace/object/treemanager"
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestate"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/net/peer"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/anyproto/any-sync/util/periodicsync"
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed(CName)
|
||||
|
||||
const CName = "common.commonspace.headsync"
|
||||
|
||||
type TreeHeads struct {
|
||||
Id string
|
||||
Heads []string
|
||||
}
|
||||
|
||||
type HeadSync interface {
|
||||
Init(objectIds []string, deletionState deletionstate.DeletionState)
|
||||
|
||||
app.ComponentRunnable
|
||||
ExternalIds() []string
|
||||
DebugAllHeads() (res []TreeHeads)
|
||||
AllIds() []string
|
||||
UpdateHeads(id string, heads []string)
|
||||
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
|
||||
RemoveObjects(ids []string)
|
||||
AllIds() []string
|
||||
DebugAllHeads() (res []TreeHeads)
|
||||
|
||||
Close() (err error)
|
||||
}
|
||||
|
||||
type headSync struct {
|
||||
spaceId string
|
||||
spaceIsDeleted *atomic.Bool
|
||||
syncPeriod int
|
||||
|
||||
periodicSync periodicsync.PeriodicSync
|
||||
storage spacestorage.SpaceStorage
|
||||
diff ldiff.Diff
|
||||
log *zap.Logger
|
||||
log logger.CtxLogger
|
||||
syncer DiffSyncer
|
||||
|
||||
syncPeriod int
|
||||
configuration nodeconf.NodeConf
|
||||
peerManager peermanager.PeerManager
|
||||
treeManager treemanager.TreeManager
|
||||
credentialProvider credentialprovider.CredentialProvider
|
||||
syncStatus syncstatus.StatusService
|
||||
deletionState deletionstate.ObjectDeletionState
|
||||
syncAcl syncacl.SyncAcl
|
||||
}
|
||||
|
||||
func NewHeadSync(
|
||||
spaceId string,
|
||||
syncPeriod int,
|
||||
storage spacestorage.SpaceStorage,
|
||||
confConnector confconnector.ConfConnector,
|
||||
cache treegetter.TreeGetter,
|
||||
syncStatus syncstatus.StatusUpdater,
|
||||
log *zap.Logger) HeadSync {
|
||||
func New() HeadSync {
|
||||
return &headSync{}
|
||||
}
|
||||
|
||||
diff := ldiff.New(16, 16)
|
||||
l := log.With(zap.String("spaceId", spaceId))
|
||||
factory := spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient)
|
||||
syncer := newDiffSyncer(spaceId, diff, confConnector, cache, storage, factory, syncStatus, l)
|
||||
periodicSync := periodicsync.NewPeriodicSync(syncPeriod, time.Minute, syncer.Sync, l)
|
||||
var createDiffSyncer = newDiffSyncer
|
||||
|
||||
return &headSync{
|
||||
spaceId: spaceId,
|
||||
storage: storage,
|
||||
syncer: syncer,
|
||||
periodicSync: periodicSync,
|
||||
diff: diff,
|
||||
log: log,
|
||||
syncPeriod: syncPeriod,
|
||||
func (h *headSync) Init(a *app.App) (err error) {
|
||||
shared := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
|
||||
cfg := a.MustComponent("config").(config2.ConfigGetter)
|
||||
h.syncAcl = a.MustComponent(syncacl.CName).(syncacl.SyncAcl)
|
||||
h.spaceId = shared.SpaceId
|
||||
h.spaceIsDeleted = shared.SpaceIsDeleted
|
||||
h.syncPeriod = cfg.GetSpace().SyncPeriod
|
||||
h.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
|
||||
h.log = log.With(zap.String("spaceId", h.spaceId))
|
||||
h.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
|
||||
h.diff = ldiff.New(16, 16)
|
||||
h.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
|
||||
h.credentialProvider = a.MustComponent(credentialprovider.CName).(credentialprovider.CredentialProvider)
|
||||
h.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
|
||||
h.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
|
||||
h.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
|
||||
h.syncer = createDiffSyncer(h)
|
||||
sync := func(ctx context.Context) (err error) {
|
||||
// for clients cancelling the sync process
|
||||
if h.spaceIsDeleted.Load() && !h.configuration.IsResponsible(h.spaceId) {
|
||||
return spacesyncproto.ErrSpaceIsDeleted
|
||||
}
|
||||
return h.syncer.Sync(ctx)
|
||||
}
|
||||
h.periodicSync = periodicsync.NewPeriodicSync(h.syncPeriod, time.Minute, sync, h.log)
|
||||
h.syncAcl.SetHeadUpdater(h)
|
||||
// TODO: move to run?
|
||||
h.syncer.Init()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *headSync) Init(objectIds []string, deletionState deletionstate.DeletionState) {
|
||||
d.fillDiff(objectIds)
|
||||
d.syncer.Init(deletionState)
|
||||
d.periodicSync.Run()
|
||||
func (h *headSync) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
func (d *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
|
||||
return HandleRangeRequest(ctx, d.diff, req)
|
||||
func (h *headSync) Run(ctx context.Context) (err error) {
|
||||
initialIds, err := h.storage.StoredIds()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.fillDiff(initialIds)
|
||||
h.periodicSync.Run()
|
||||
return
|
||||
}
|
||||
|
||||
func (d *headSync) UpdateHeads(id string, heads []string) {
|
||||
d.syncer.UpdateHeads(id, heads)
|
||||
func (h *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
|
||||
if h.spaceIsDeleted.Load() {
|
||||
peerId, err := peer.CtxPeerId(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// stop receiving all request for sync from clients
|
||||
if !slices.Contains(h.configuration.NodeIds(h.spaceId), peerId) {
|
||||
return nil, spacesyncproto.ErrSpaceIsDeleted
|
||||
}
|
||||
}
|
||||
return HandleRangeRequest(ctx, h.diff, req)
|
||||
}
|
||||
|
||||
func (d *headSync) AllIds() []string {
|
||||
return d.diff.Ids()
|
||||
func (h *headSync) UpdateHeads(id string, heads []string) {
|
||||
h.syncer.UpdateHeads(id, heads)
|
||||
}
|
||||
|
||||
func (d *headSync) DebugAllHeads() (res []TreeHeads) {
|
||||
els := d.diff.Elements()
|
||||
func (h *headSync) AllIds() []string {
|
||||
return h.diff.Ids()
|
||||
}
|
||||
|
||||
func (h *headSync) ExternalIds() []string {
|
||||
settingsId := h.storage.SpaceSettingsId()
|
||||
return slice.DiscardFromSlice(h.AllIds(), func(id string) bool {
|
||||
return id == settingsId
|
||||
})
|
||||
}
|
||||
|
||||
func (h *headSync) DebugAllHeads() (res []TreeHeads) {
|
||||
els := h.diff.Elements()
|
||||
for _, el := range els {
|
||||
idHead := TreeHeads{
|
||||
Id: el.Id,
|
||||
@ -100,19 +157,19 @@ func (d *headSync) DebugAllHeads() (res []TreeHeads) {
|
||||
return
|
||||
}
|
||||
|
||||
func (d *headSync) RemoveObjects(ids []string) {
|
||||
d.syncer.RemoveObjects(ids)
|
||||
func (h *headSync) RemoveObjects(ids []string) {
|
||||
h.syncer.RemoveObjects(ids)
|
||||
}
|
||||
|
||||
func (d *headSync) Close() (err error) {
|
||||
d.periodicSync.Close()
|
||||
return nil
|
||||
func (h *headSync) Close(ctx context.Context) (err error) {
|
||||
h.periodicSync.Close()
|
||||
return h.syncer.Close()
|
||||
}
|
||||
|
||||
func (d *headSync) fillDiff(objectIds []string) {
|
||||
func (h *headSync) fillDiff(objectIds []string) {
|
||||
var els = make([]ldiff.Element, 0, len(objectIds))
|
||||
for _, id := range objectIds {
|
||||
st, err := d.storage.TreeStorage(id)
|
||||
st, err := h.storage.TreeStorage(id)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -125,32 +182,12 @@ func (d *headSync) fillDiff(objectIds []string) {
|
||||
Head: concatStrings(heads),
|
||||
})
|
||||
}
|
||||
d.diff.Set(els...)
|
||||
if err := d.storage.WriteSpaceHash(d.diff.Hash()); err != nil {
|
||||
d.log.Error("can't write space hash", zap.Error(err))
|
||||
els = append(els, ldiff.Element{
|
||||
Id: h.syncAcl.Id(),
|
||||
Head: h.syncAcl.Head().Id,
|
||||
})
|
||||
h.diff.Set(els...)
|
||||
if err := h.storage.WriteSpaceHash(h.diff.Hash()); err != nil {
|
||||
h.log.Error("can't write space hash", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func concatStrings(strs []string) string {
|
||||
var (
|
||||
b strings.Builder
|
||||
totalLen int
|
||||
)
|
||||
for _, s := range strs {
|
||||
totalLen += len(s)
|
||||
}
|
||||
|
||||
b.Grow(totalLen)
|
||||
for _, s := range strs {
|
||||
b.WriteString(s)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func splitString(str string) (res []string) {
|
||||
const cidLen = 59
|
||||
for i := 0; i < len(str); i += cidLen {
|
||||
res = append(res, str[i:i+cidLen])
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,70 +1,190 @@
|
||||
package headsync
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/app/ldiff/mock_ldiff"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
"github.com/anytypeio/any-sync/commonspace/headsync/mock_headsync"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/settings/deletionstate/mock_deletionstate"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacestorage/mock_spacestorage"
|
||||
"github.com/anytypeio/any-sync/util/periodicsync/mock_periodicsync"
|
||||
"github.com/golang/mock/gomock"
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
|
||||
"github.com/anyproto/any-sync/commonspace/config"
|
||||
"github.com/anyproto/any-sync/commonspace/credentialprovider"
|
||||
"github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
|
||||
"github.com/anyproto/any-sync/commonspace/deletionstate"
|
||||
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
|
||||
"github.com/anyproto/any-sync/commonspace/headsync/mock_headsync"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/object/treemanager"
|
||||
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestate"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/nodeconf"
|
||||
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDiffService(t *testing.T) {
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
type mockConfig struct {
|
||||
}
|
||||
|
||||
spaceId := "spaceId"
|
||||
l := logger.NewNamed("sync")
|
||||
pSyncMock := mock_periodicsync.NewMockPeriodicSync(ctrl)
|
||||
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
|
||||
treeStorageMock := mock_treestorage.NewMockTreeStorage(ctrl)
|
||||
diffMock := mock_ldiff.NewMockDiff(ctrl)
|
||||
syncer := mock_headsync.NewMockDiffSyncer(ctrl)
|
||||
delState := mock_deletionstate.NewMockDeletionState(ctrl)
|
||||
syncPeriod := 1
|
||||
initId := "initId"
|
||||
func (m mockConfig) Init(a *app.App) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
service := &headSync{
|
||||
spaceId: spaceId,
|
||||
storage: storageMock,
|
||||
periodicSync: pSyncMock,
|
||||
syncer: syncer,
|
||||
diff: diffMock,
|
||||
log: l,
|
||||
syncPeriod: syncPeriod,
|
||||
func (m mockConfig) Name() (name string) {
|
||||
return "config"
|
||||
}
|
||||
|
||||
func (m mockConfig) GetSpace() config.Config {
|
||||
return config.Config{}
|
||||
}
|
||||
|
||||
type headSyncFixture struct {
|
||||
spaceState *spacestate.SpaceState
|
||||
ctrl *gomock.Controller
|
||||
app *app.App
|
||||
|
||||
configurationMock *mock_nodeconf.MockService
|
||||
storageMock *mock_spacestorage.MockSpaceStorage
|
||||
peerManagerMock *mock_peermanager.MockPeerManager
|
||||
credentialProviderMock *mock_credentialprovider.MockCredentialProvider
|
||||
syncStatus syncstatus.StatusService
|
||||
treeManagerMock *mock_treemanager.MockTreeManager
|
||||
deletionStateMock *mock_deletionstate.MockObjectDeletionState
|
||||
diffSyncerMock *mock_headsync.MockDiffSyncer
|
||||
treeSyncerMock *mock_treemanager.MockTreeSyncer
|
||||
diffMock *mock_ldiff.MockDiff
|
||||
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
|
||||
aclMock *mock_syncacl.MockSyncAcl
|
||||
headSync *headSync
|
||||
diffSyncer *diffSyncer
|
||||
}
|
||||
|
||||
func newHeadSyncFixture(t *testing.T) *headSyncFixture {
|
||||
spaceState := &spacestate.SpaceState{
|
||||
SpaceId: "spaceId",
|
||||
SpaceIsDeleted: &atomic.Bool{},
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
configurationMock := mock_nodeconf.NewMockService(ctrl)
|
||||
configurationMock.EXPECT().Name().AnyTimes().Return(nodeconf.CName)
|
||||
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
|
||||
storageMock.EXPECT().Name().AnyTimes().Return(spacestorage.CName)
|
||||
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl)
|
||||
peerManagerMock.EXPECT().Name().AnyTimes().Return(peermanager.CName)
|
||||
credentialProviderMock := mock_credentialprovider.NewMockCredentialProvider(ctrl)
|
||||
credentialProviderMock.EXPECT().Name().AnyTimes().Return(credentialprovider.CName)
|
||||
syncStatus := syncstatus.NewNoOpSyncStatus()
|
||||
treeManagerMock := mock_treemanager.NewMockTreeManager(ctrl)
|
||||
treeManagerMock.EXPECT().Name().AnyTimes().Return(treemanager.CName)
|
||||
deletionStateMock := mock_deletionstate.NewMockObjectDeletionState(ctrl)
|
||||
deletionStateMock.EXPECT().Name().AnyTimes().Return(deletionstate.CName)
|
||||
diffSyncerMock := mock_headsync.NewMockDiffSyncer(ctrl)
|
||||
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl)
|
||||
diffMock := mock_ldiff.NewMockDiff(ctrl)
|
||||
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
|
||||
aclMock := mock_syncacl.NewMockSyncAcl(ctrl)
|
||||
aclMock.EXPECT().Name().AnyTimes().Return(syncacl.CName)
|
||||
aclMock.EXPECT().SetHeadUpdater(gomock.Any()).AnyTimes()
|
||||
hs := &headSync{}
|
||||
a := &app.App{}
|
||||
a.Register(spaceState).
|
||||
Register(aclMock).
|
||||
Register(mockConfig{}).
|
||||
Register(configurationMock).
|
||||
Register(storageMock).
|
||||
Register(peerManagerMock).
|
||||
Register(credentialProviderMock).
|
||||
Register(syncStatus).
|
||||
Register(treeManagerMock).
|
||||
Register(deletionStateMock).
|
||||
Register(hs)
|
||||
return &headSyncFixture{
|
||||
spaceState: spaceState,
|
||||
ctrl: ctrl,
|
||||
app: a,
|
||||
configurationMock: configurationMock,
|
||||
storageMock: storageMock,
|
||||
peerManagerMock: peerManagerMock,
|
||||
credentialProviderMock: credentialProviderMock,
|
||||
syncStatus: syncStatus,
|
||||
treeManagerMock: treeManagerMock,
|
||||
deletionStateMock: deletionStateMock,
|
||||
headSync: hs,
|
||||
diffSyncerMock: diffSyncerMock,
|
||||
treeSyncerMock: treeSyncerMock,
|
||||
diffMock: diffMock,
|
||||
clientMock: clientMock,
|
||||
aclMock: aclMock,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("init", func(t *testing.T) {
|
||||
storageMock.EXPECT().TreeStorage(initId).Return(treeStorageMock, nil)
|
||||
treeStorageMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil)
|
||||
syncer.EXPECT().Init(delState)
|
||||
diffMock.EXPECT().Set(ldiff.Element{
|
||||
Id: initId,
|
||||
func (fx *headSyncFixture) init(t *testing.T) {
|
||||
createDiffSyncer = func(hs *headSync) DiffSyncer {
|
||||
return fx.diffSyncerMock
|
||||
}
|
||||
fx.diffSyncerMock.EXPECT().Init()
|
||||
err := fx.headSync.Init(fx.app)
|
||||
require.NoError(t, err)
|
||||
fx.headSync.diff = fx.diffMock
|
||||
}
|
||||
|
||||
func (fx *headSyncFixture) stop() {
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestHeadSync(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("run close", func(t *testing.T) {
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.init(t)
|
||||
defer fx.stop()
|
||||
|
||||
ids := []string{"id1"}
|
||||
treeMock := mock_treestorage.NewMockTreeStorage(fx.ctrl)
|
||||
fx.storageMock.EXPECT().StoredIds().Return(ids, nil)
|
||||
fx.storageMock.EXPECT().TreeStorage(ids[0]).Return(treeMock, nil)
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
|
||||
fx.aclMock.EXPECT().Head().AnyTimes().Return(&list.AclRecord{Id: "headId"})
|
||||
treeMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil)
|
||||
fx.diffMock.EXPECT().Set(ldiff.Element{
|
||||
Id: "id1",
|
||||
Head: "h1h2",
|
||||
})
|
||||
hash := "123"
|
||||
diffMock.EXPECT().Hash().Return(hash)
|
||||
storageMock.EXPECT().WriteSpaceHash(hash)
|
||||
pSyncMock.EXPECT().Run()
|
||||
service.Init([]string{initId}, delState)
|
||||
fx.diffMock.EXPECT().Hash().Return("hash")
|
||||
fx.storageMock.EXPECT().WriteSpaceHash("hash").Return(nil)
|
||||
fx.diffSyncerMock.EXPECT().Sync(gomock.Any()).Return(nil)
|
||||
fx.diffSyncerMock.EXPECT().Close().Return(nil)
|
||||
err := fx.headSync.Run(ctx)
|
||||
require.NoError(t, err)
|
||||
err = fx.headSync.Close(ctx)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("update heads", func(t *testing.T) {
|
||||
syncer.EXPECT().UpdateHeads(initId, []string{"h1", "h2"})
|
||||
service.UpdateHeads(initId, []string{"h1", "h2"})
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.init(t)
|
||||
defer fx.stop()
|
||||
|
||||
fx.diffSyncerMock.EXPECT().UpdateHeads("id1", []string{"h1"})
|
||||
fx.headSync.UpdateHeads("id1", []string{"h1"})
|
||||
})
|
||||
|
||||
t.Run("remove objects", func(t *testing.T) {
|
||||
syncer.EXPECT().RemoveObjects([]string{"h1", "h2"})
|
||||
service.RemoveObjects([]string{"h1", "h2"})
|
||||
})
|
||||
fx := newHeadSyncFixture(t)
|
||||
fx.init(t)
|
||||
defer fx.stop()
|
||||
|
||||
t.Run("close", func(t *testing.T) {
|
||||
pSyncMock.EXPECT().Close()
|
||||
service.Close()
|
||||
fx.diffSyncerMock.EXPECT().RemoveObjects([]string{"id1"})
|
||||
fx.headSync.RemoveObjects([]string{"id1"})
|
||||
})
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/headsync (interfaces: DiffSyncer)
|
||||
// Source: github.com/anyproto/any-sync/commonspace/headsync (interfaces: DiffSyncer)
|
||||
|
||||
// Package mock_headsync is a generated GoMock package.
|
||||
package mock_headsync
|
||||
@ -8,8 +8,7 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
deletionstate "github.com/anytypeio/any-sync/commonspace/settings/deletionstate"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockDiffSyncer is a mock of DiffSyncer interface.
|
||||
@ -35,16 +34,30 @@ func (m *MockDiffSyncer) EXPECT() *MockDiffSyncerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Init mocks base method.
|
||||
func (m *MockDiffSyncer) Init(arg0 deletionstate.DeletionState) {
|
||||
// Close mocks base method.
|
||||
func (m *MockDiffSyncer) Close() error {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Init", arg0)
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockDiffSyncerMockRecorder) Close() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDiffSyncer)(nil).Close))
|
||||
}
|
||||
|
||||
// Init mocks base method.
|
||||
func (m *MockDiffSyncer) Init() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Init")
|
||||
}
|
||||
|
||||
// Init indicates an expected call of Init.
|
||||
func (mr *MockDiffSyncerMockRecorder) Init(arg0 interface{}) *gomock.Call {
|
||||
func (mr *MockDiffSyncerMockRecorder) Init() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init))
|
||||
}
|
||||
|
||||
// RemoveObjects mocks base method.
|
||||
|
||||
@ -2,8 +2,8 @@ package headsync
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
)
|
||||
|
||||
type Client interface {
|
||||
|
||||
@ -3,8 +3,8 @@ package headsync
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/app/ldiff"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/app/ldiff"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
|
||||
27
commonspace/headsync/util.go
Normal file
27
commonspace/headsync/util.go
Normal file
@ -0,0 +1,27 @@
|
||||
package headsync
|
||||
|
||||
import "strings"
|
||||
|
||||
func concatStrings(strs []string) string {
|
||||
var (
|
||||
b strings.Builder
|
||||
totalLen int
|
||||
)
|
||||
for _, s := range strs {
|
||||
totalLen += len(s)
|
||||
}
|
||||
|
||||
b.Grow(totalLen)
|
||||
for _, s := range strs {
|
||||
b.WriteString(s)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func splitString(str string) (res []string) {
|
||||
const cidLen = 59
|
||||
for i := 0; i < len(str); i += cidLen {
|
||||
res = append(res, str[i:i+cidLen])
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -1,14 +1,36 @@
|
||||
package accountdata
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"crypto/rand"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type AccountData struct { // TODO: create a convenient constructor for this
|
||||
Identity []byte // public key
|
||||
PeerKey signingkey.PrivKey
|
||||
SignKey signingkey.PrivKey
|
||||
EncKey encryptionkey.PrivKey
|
||||
type AccountKeys struct {
|
||||
PeerKey crypto.PrivKey
|
||||
SignKey crypto.PrivKey
|
||||
PeerId string
|
||||
}
|
||||
|
||||
func New(peerKey, signKey crypto.PrivKey) *AccountKeys {
|
||||
return &AccountKeys{
|
||||
PeerKey: peerKey,
|
||||
SignKey: signKey,
|
||||
PeerId: peerKey.GetPublic().PeerId(),
|
||||
}
|
||||
}
|
||||
|
||||
func NewRandom() (*AccountKeys, error) {
|
||||
peerKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
signKey, _, err := crypto.GenerateEd25519Key(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AccountKeys{
|
||||
PeerKey: peerKey,
|
||||
SignKey: signKey,
|
||||
PeerId: peerKey.GetPublic().PeerId(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -1,12 +0,0 @@
|
||||
package aclrecordproto
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
)
|
||||
|
||||
func AclReadKeyDerive(signKey []byte, encKey []byte) (*symmetric.Key, error) {
|
||||
concBuf := make([]byte, 0, len(signKey)+len(encKey))
|
||||
concBuf = append(concBuf, signKey...)
|
||||
concBuf = append(concBuf, encKey...)
|
||||
return symmetric.DeriveFromBytes(concBuf)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@ -2,117 +2,105 @@ syntax = "proto3";
|
||||
package aclrecord;
|
||||
option go_package = "commonspace/object/acl/aclrecordproto";
|
||||
|
||||
message RawAclRecord {
|
||||
bytes payload = 1;
|
||||
bytes signature = 2;
|
||||
bytes acceptorIdentity = 3;
|
||||
bytes acceptorSignature = 4;
|
||||
}
|
||||
|
||||
message RawAclRecordWithId {
|
||||
bytes payload = 1;
|
||||
string id = 2;
|
||||
}
|
||||
|
||||
message AclRecord {
|
||||
string prevId = 1;
|
||||
bytes identity = 2;
|
||||
bytes data = 3;
|
||||
uint64 currentReadKeyHash = 4;
|
||||
int64 timestamp = 5;
|
||||
}
|
||||
|
||||
// AclRoot is a root of access control list
|
||||
message AclRoot {
|
||||
bytes identity = 1;
|
||||
bytes encryptionKey = 2;
|
||||
bytes masterKey = 2;
|
||||
string spaceId = 3;
|
||||
bytes encryptedReadKey = 4;
|
||||
string derivationScheme = 5;
|
||||
uint64 currentReadKeyHash = 6;
|
||||
int64 timestamp = 7;
|
||||
int64 timestamp = 5;
|
||||
bytes identitySignature = 6;
|
||||
}
|
||||
|
||||
message AclContentValue {
|
||||
oneof value {
|
||||
AclUserAdd userAdd = 1;
|
||||
AclUserRemove userRemove = 2;
|
||||
AclUserPermissionChange userPermissionChange = 3;
|
||||
AclUserInvite userInvite = 4;
|
||||
AclUserJoin userJoin = 5;
|
||||
}
|
||||
// AclAccountInvite contains the public invite key, the private part of which is sent to the user directly
|
||||
message AclAccountInvite {
|
||||
bytes inviteKey = 1;
|
||||
}
|
||||
|
||||
message AclData {
|
||||
repeated AclContentValue aclContent = 1;
|
||||
// AclAccountRequestJoin contains the reference to the invite record and the data of the person who wants to join, confirmed by the private invite key
|
||||
message AclAccountRequestJoin {
|
||||
bytes inviteIdentity = 1;
|
||||
string inviteRecordId = 2;
|
||||
bytes inviteIdentitySignature = 3;
|
||||
bytes metadata = 4;
|
||||
}
|
||||
|
||||
message AclState {
|
||||
repeated uint64 readKeyHashes = 1;
|
||||
repeated AclUserState userStates = 2;
|
||||
map<string, AclUserInvite> invites = 3;
|
||||
}
|
||||
|
||||
message AclUserState {
|
||||
// AclAccountRequestAccept contains the reference to join record and all read keys, encrypted with the identity of the requestor
|
||||
message AclAccountRequestAccept {
|
||||
bytes identity = 1;
|
||||
bytes encryptionKey = 2;
|
||||
AclUserPermissions permissions = 3;
|
||||
}
|
||||
|
||||
message AclUserAdd {
|
||||
bytes identity = 1;
|
||||
bytes encryptionKey = 2;
|
||||
repeated bytes encryptedReadKeys = 3;
|
||||
string requestRecordId = 2;
|
||||
repeated AclReadKeyWithRecord encryptedReadKeys = 3;
|
||||
AclUserPermissions permissions = 4;
|
||||
}
|
||||
|
||||
message AclUserInvite {
|
||||
bytes acceptPublicKey = 1;
|
||||
uint64 encryptSymKeyHash = 2;
|
||||
repeated bytes encryptedReadKeys = 3;
|
||||
AclUserPermissions permissions = 4;
|
||||
// AclAccountRequestDecline contains the reference to join record
|
||||
message AclAccountRequestDecline {
|
||||
string requestRecordId = 1;
|
||||
}
|
||||
|
||||
message AclUserJoin {
|
||||
// AclAccountInviteRevoke revokes the invite record
|
||||
message AclAccountInviteRevoke {
|
||||
string inviteRecordId = 1;
|
||||
}
|
||||
|
||||
// AclReadKeys are a read key with record id
|
||||
message AclReadKeyWithRecord {
|
||||
string recordId = 1;
|
||||
bytes encryptedReadKey = 2;
|
||||
}
|
||||
|
||||
// AclEncryptedReadKeys are new key for specific identity
|
||||
message AclEncryptedReadKey {
|
||||
bytes identity = 1;
|
||||
bytes encryptionKey = 2;
|
||||
bytes acceptSignature = 3;
|
||||
bytes acceptPubKey = 4;
|
||||
repeated bytes encryptedReadKeys = 5;
|
||||
bytes encryptedReadKey = 2;
|
||||
}
|
||||
|
||||
message AclUserRemove {
|
||||
bytes identity = 1;
|
||||
repeated AclReadKeyReplace readKeyReplaces = 2;
|
||||
}
|
||||
|
||||
message AclReadKeyReplace {
|
||||
bytes identity = 1;
|
||||
bytes encryptionKey = 2;
|
||||
bytes encryptedReadKey = 3;
|
||||
}
|
||||
|
||||
message AclUserPermissionChange {
|
||||
// AclAccountPermissionChange changes permissions of specific account
|
||||
message AclAccountPermissionChange {
|
||||
bytes identity = 1;
|
||||
AclUserPermissions permissions = 2;
|
||||
}
|
||||
|
||||
enum AclUserPermissions {
|
||||
Admin = 0;
|
||||
Writer = 1;
|
||||
Reader = 2;
|
||||
// AclReadKeyChange changes the key for a space
|
||||
message AclReadKeyChange {
|
||||
repeated AclEncryptedReadKey accountKeys = 1;
|
||||
}
|
||||
|
||||
message AclSyncMessage {
|
||||
AclSyncContentValue content = 2;
|
||||
// AclAccountRemove removes an account and changes read key for space
|
||||
message AclAccountRemove {
|
||||
repeated bytes identities = 1;
|
||||
repeated AclEncryptedReadKey accountKeys = 2;
|
||||
}
|
||||
|
||||
// AclSyncContentValue provides different types for acl sync
|
||||
message AclSyncContentValue {
|
||||
// AclAccountRequestRemove adds a request to remove an account
|
||||
message AclAccountRequestRemove {
|
||||
}
|
||||
|
||||
// AclContentValue contains possible values for Acl
|
||||
message AclContentValue {
|
||||
oneof value {
|
||||
AclAddRecords addRecords = 1;
|
||||
AclAccountInvite invite = 1;
|
||||
AclAccountInviteRevoke inviteRevoke = 2;
|
||||
AclAccountRequestJoin requestJoin = 3;
|
||||
AclAccountRequestAccept requestAccept = 4;
|
||||
AclAccountPermissionChange permissionChange = 5;
|
||||
AclAccountRemove accountRemove = 6;
|
||||
AclReadKeyChange readKeyChange = 7;
|
||||
AclAccountRequestDecline requestDecline = 8;
|
||||
AclAccountRequestRemove accountRequestRemove = 9;
|
||||
}
|
||||
}
|
||||
|
||||
message AclAddRecords {
|
||||
repeated RawAclRecordWithId records = 1;
|
||||
// AclData contains different acl content
|
||||
message AclData {
|
||||
repeated AclContentValue aclContent = 1;
|
||||
}
|
||||
|
||||
// AclUserPermissions contains different possible user roles
|
||||
enum AclUserPermissions {
|
||||
None = 0;
|
||||
Owner = 1;
|
||||
Admin = 2;
|
||||
Writer = 3;
|
||||
Reader = 4;
|
||||
}
|
||||
@ -1,166 +1,499 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/util/cidutil"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/anyproto/any-sync/util/cidutil"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
// remove interface
|
||||
type RootContent struct {
|
||||
PrivKey crypto.PrivKey
|
||||
MasterKey crypto.PrivKey
|
||||
SpaceId string
|
||||
EncryptedReadKey []byte
|
||||
}
|
||||
|
||||
type RequestJoinPayload struct {
|
||||
InviteRecordId string
|
||||
InviteKey crypto.PrivKey
|
||||
Metadata []byte
|
||||
}
|
||||
|
||||
type RequestAcceptPayload struct {
|
||||
RequestRecordId string
|
||||
Permissions AclPermissions
|
||||
}
|
||||
|
||||
type PermissionChangePayload struct {
|
||||
Identity crypto.PubKey
|
||||
Permissions AclPermissions
|
||||
}
|
||||
|
||||
type AccountRemovePayload struct {
|
||||
Identities []crypto.PubKey
|
||||
ReadKey crypto.SymKey
|
||||
}
|
||||
|
||||
type InviteResult struct {
|
||||
InviteRec *consensusproto.RawRecord
|
||||
InviteKey crypto.PrivKey
|
||||
}
|
||||
|
||||
type AclRecordBuilder interface {
|
||||
ConvertFromRaw(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error)
|
||||
BuildUserJoin(acceptPrivKeyBytes []byte, encSymKeyBytes []byte, state *AclState) (rec *aclrecordproto.RawAclRecord, err error)
|
||||
UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error)
|
||||
Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error)
|
||||
|
||||
BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error)
|
||||
BuildInvite() (res InviteResult, err error)
|
||||
BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error)
|
||||
BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error)
|
||||
}
|
||||
|
||||
type aclRecordBuilder struct {
|
||||
id string
|
||||
keychain *keychain.Keychain
|
||||
keyStorage crypto.KeyStorage
|
||||
accountKeys *accountdata.AccountKeys
|
||||
verifier AcceptorVerifier
|
||||
state *AclState
|
||||
}
|
||||
|
||||
func newAclRecordBuilder(id string, keychain *keychain.Keychain) AclRecordBuilder {
|
||||
func NewAclRecordBuilder(id string, keyStorage crypto.KeyStorage, keys *accountdata.AccountKeys, verifier AcceptorVerifier) AclRecordBuilder {
|
||||
return &aclRecordBuilder{
|
||||
id: id,
|
||||
keychain: keychain,
|
||||
keyStorage: keyStorage,
|
||||
accountKeys: keys,
|
||||
verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildUserJoin(acceptPrivKeyBytes []byte, encSymKeyBytes []byte, state *AclState) (rec *aclrecordproto.RawAclRecord, err error) {
|
||||
acceptPrivKey, err := signingkey.NewSigningEd25519PrivKeyFromBytes(acceptPrivKeyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
acceptPubKeyBytes, err := acceptPrivKey.GetPublic().Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
encSymKey, err := symmetric.FromBytes(encSymKeyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
invite, err := state.Invite(acceptPubKeyBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
encPrivKey, signPrivKey := state.UserKeys()
|
||||
var symKeys [][]byte
|
||||
for _, rk := range invite.EncryptedReadKeys {
|
||||
dec, err := encSymKey.Decrypt(rk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newEnc, err := encPrivKey.GetPublic().Encrypt(dec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
symKeys = append(symKeys, newEnc)
|
||||
}
|
||||
idSignature, err := acceptPrivKey.Sign(state.Identity())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
encPubKeyBytes, err := encPrivKey.GetPublic().Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
userJoin := &aclrecordproto.AclUserJoin{
|
||||
Identity: state.Identity(),
|
||||
EncryptionKey: encPubKeyBytes,
|
||||
AcceptSignature: idSignature,
|
||||
AcceptPubKey: acceptPubKeyBytes,
|
||||
EncryptedReadKeys: symKeys,
|
||||
}
|
||||
func (a *aclRecordBuilder) buildRecord(aclContent *aclrecordproto.AclContentValue) (rawRec *consensusproto.RawRecord, err error) {
|
||||
aclData := &aclrecordproto.AclData{AclContent: []*aclrecordproto.AclContentValue{
|
||||
{Value: &aclrecordproto.AclContentValue_UserJoin{UserJoin: userJoin}},
|
||||
aclContent,
|
||||
}}
|
||||
marshalledJoin, err := aclData.Marshal()
|
||||
marshalledData, err := aclData.Marshal()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
aclRecord := &aclrecordproto.AclRecord{
|
||||
PrevId: state.LastRecordId(),
|
||||
Identity: state.Identity(),
|
||||
Data: marshalledJoin,
|
||||
CurrentReadKeyHash: state.CurrentReadKeyHash(),
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
}
|
||||
marshalledRecord, err := aclRecord.Marshal()
|
||||
protoKey, err := a.accountKeys.SignKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
recSignature, err := signPrivKey.Sign(marshalledRecord)
|
||||
rec := &consensusproto.Record{
|
||||
PrevId: a.state.lastRecordId,
|
||||
Identity: protoKey,
|
||||
Data: marshalledData,
|
||||
Timestamp: time.Now().Unix(),
|
||||
}
|
||||
marshalledRec, err := rec.Marshal()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rec = &aclrecordproto.RawAclRecord{
|
||||
Payload: marshalledRecord,
|
||||
Signature: recSignature,
|
||||
signature, err := a.accountKeys.SignKey.Sign(marshalledRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rawRec = &consensusproto.RawRecord{
|
||||
Payload: marshalledRec,
|
||||
Signature: signature,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) ConvertFromRaw(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error) {
|
||||
rawRec := &aclrecordproto.RawAclRecord{}
|
||||
func (a *aclRecordBuilder) BuildInvite() (res InviteResult, err error) {
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
privKey, pubKey, err := crypto.GenerateRandomEd25519KeyPair()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
invitePubKey, err := pubKey.Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
inviteRec := &aclrecordproto.AclAccountInvite{InviteKey: invitePubKey}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_Invite{Invite: inviteRec}}
|
||||
rawRec, err := a.buildRecord(content)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res.InviteKey = privKey
|
||||
res.InviteRec = rawRec
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
_, exists := a.state.inviteKeys[inviteRecordId]
|
||||
if !exists {
|
||||
err = ErrNoSuchInvite
|
||||
return
|
||||
}
|
||||
revokeRec := &aclrecordproto.AclAccountInviteRevoke{InviteRecordId: inviteRecordId}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_InviteRevoke{InviteRevoke: revokeRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
key, exists := a.state.inviteKeys[payload.InviteRecordId]
|
||||
if !exists {
|
||||
err = ErrNoSuchInvite
|
||||
return
|
||||
}
|
||||
if !payload.InviteKey.GetPublic().Equals(key) {
|
||||
err = ErrIncorrectInviteKey
|
||||
}
|
||||
rawIdentity, err := a.accountKeys.SignKey.GetPublic().Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
signature, err := payload.InviteKey.Sign(rawIdentity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
protoIdentity, err := a.accountKeys.SignKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
joinRec := &aclrecordproto.AclAccountRequestJoin{
|
||||
InviteIdentity: protoIdentity,
|
||||
InviteRecordId: payload.InviteRecordId,
|
||||
InviteIdentitySignature: signature,
|
||||
Metadata: payload.Metadata,
|
||||
}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestJoin{RequestJoin: joinRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
request, exists := a.state.requestRecords[payload.RequestRecordId]
|
||||
if !exists {
|
||||
err = ErrNoSuchRequest
|
||||
return
|
||||
}
|
||||
var encryptedReadKeys []*aclrecordproto.AclReadKeyWithRecord
|
||||
for keyId, key := range a.state.userReadKeys {
|
||||
rawKey, err := key.Raw()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := request.RequestIdentity.Encrypt(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
encryptedReadKeys = append(encryptedReadKeys, &aclrecordproto.AclReadKeyWithRecord{
|
||||
RecordId: keyId,
|
||||
EncryptedReadKey: enc,
|
||||
})
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
requestIdentityProto, err := request.RequestIdentity.Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
acceptRec := &aclrecordproto.AclAccountRequestAccept{
|
||||
Identity: requestIdentityProto,
|
||||
RequestRecordId: payload.RequestRecordId,
|
||||
EncryptedReadKeys: encryptedReadKeys,
|
||||
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
|
||||
}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestAccept{RequestAccept: acceptRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
_, exists := a.state.requestRecords[requestRecordId]
|
||||
if !exists {
|
||||
err = ErrNoSuchRequest
|
||||
return
|
||||
}
|
||||
declineRec := &aclrecordproto.AclAccountRequestDecline{RequestRecordId: requestRecordId}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestDecline{RequestDecline: declineRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
permissions := a.state.Permissions(a.state.pubKey)
|
||||
if !permissions.CanManageAccounts() || payload.Identity.Equals(a.state.pubKey) {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
if payload.Permissions.IsOwner() {
|
||||
err = ErrIsOwner
|
||||
return
|
||||
}
|
||||
protoIdentity, err := payload.Identity.Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
permissionRec := &aclrecordproto.AclAccountPermissionChange{
|
||||
Identity: protoIdentity,
|
||||
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
|
||||
}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_PermissionChange{PermissionChange: permissionRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
rawKey, err := newKey.Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(rawKey) != crypto.KeyBytes {
|
||||
err = ErrIncorrectReadKey
|
||||
return
|
||||
}
|
||||
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
|
||||
for _, st := range a.state.userStates {
|
||||
protoIdentity, err := st.PubKey.Marshall()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := st.PubKey.Encrypt(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
|
||||
Identity: protoIdentity,
|
||||
EncryptedReadKey: enc,
|
||||
})
|
||||
}
|
||||
readRec := &aclrecordproto.AclReadKeyChange{AccountKeys: aclReadKeys}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_ReadKeyChange{ReadKeyChange: readRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error) {
|
||||
deletedMap := map[string]struct{}{}
|
||||
for _, key := range payload.Identities {
|
||||
permissions := a.state.Permissions(key)
|
||||
if permissions.IsOwner() {
|
||||
return nil, ErrInsufficientPermissions
|
||||
}
|
||||
if permissions.NoPermissions() {
|
||||
return nil, ErrNoSuchAccount
|
||||
}
|
||||
deletedMap[mapKeyFromPubKey(key)] = struct{}{}
|
||||
}
|
||||
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
|
||||
err = ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
rawKey, err := payload.ReadKey.Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(rawKey) != crypto.KeyBytes {
|
||||
err = ErrIncorrectReadKey
|
||||
return
|
||||
}
|
||||
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
|
||||
for _, st := range a.state.userStates {
|
||||
if _, exists := deletedMap[mapKeyFromPubKey(st.PubKey)]; exists {
|
||||
continue
|
||||
}
|
||||
protoIdentity, err := st.PubKey.Marshall()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
enc, err := st.PubKey.Encrypt(rawKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
|
||||
Identity: protoIdentity,
|
||||
EncryptedReadKey: enc,
|
||||
})
|
||||
}
|
||||
var marshalledIdentities [][]byte
|
||||
for _, key := range payload.Identities {
|
||||
protoIdentity, err := key.Marshall()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
marshalledIdentities = append(marshalledIdentities, protoIdentity)
|
||||
}
|
||||
removeRec := &aclrecordproto.AclAccountRemove{AccountKeys: aclReadKeys, Identities: marshalledIdentities}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRemove{AccountRemove: removeRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error) {
|
||||
permissions := a.state.Permissions(a.state.pubKey)
|
||||
if permissions.NoPermissions() {
|
||||
err = ErrNoSuchAccount
|
||||
return
|
||||
}
|
||||
if permissions.IsOwner() {
|
||||
err = ErrIsOwner
|
||||
return
|
||||
}
|
||||
removeRec := &aclrecordproto.AclAccountRequestRemove{}
|
||||
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRequestRemove{AccountRequestRemove: removeRec}}
|
||||
return a.buildRecord(content)
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error) {
|
||||
aclRecord := &consensusproto.Record{}
|
||||
err = proto.Unmarshal(rawRecord.Payload, aclRecord)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
pubKey, err := a.keyStorage.PubKeyFromProto(aclRecord.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
aclData := &aclrecordproto.AclData{}
|
||||
err = proto.Unmarshal(aclRecord.Data, aclData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rec = &AclRecord{
|
||||
PrevId: aclRecord.PrevId,
|
||||
Timestamp: aclRecord.Timestamp,
|
||||
Data: aclRecord.Data,
|
||||
Signature: rawRecord.Signature,
|
||||
Identity: pubKey,
|
||||
Model: aclData,
|
||||
}
|
||||
res, err := pubKey.Verify(rawRecord.Payload, rawRecord.Signature)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !res {
|
||||
err = ErrInvalidSignature
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error) {
|
||||
var (
|
||||
rawRec = &consensusproto.RawRecord{}
|
||||
pubKey crypto.PubKey
|
||||
)
|
||||
err = proto.Unmarshal(rawIdRecord.Payload, rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if rawIdRecord.Id == a.id {
|
||||
aclRoot := &aclrecordproto.AclRoot{}
|
||||
err = proto.Unmarshal(rawRec.Payload, aclRoot)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
pubKey, err = a.keyStorage.PubKeyFromProto(aclRoot.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rec = &AclRecord{
|
||||
Id: rawIdRecord.Id,
|
||||
CurrentReadKeyHash: aclRoot.CurrentReadKeyHash,
|
||||
Timestamp: aclRoot.Timestamp,
|
||||
Signature: rawRec.Signature,
|
||||
Identity: aclRoot.Identity,
|
||||
Identity: pubKey,
|
||||
Model: aclRoot,
|
||||
}
|
||||
} else {
|
||||
aclRecord := &aclrecordproto.AclRecord{}
|
||||
err = a.verifier.VerifyAcceptor(rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
aclRecord := &consensusproto.Record{}
|
||||
err = proto.Unmarshal(rawRec.Payload, aclRecord)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
rec = &AclRecord{
|
||||
Id: rawIdRecord.Id,
|
||||
PrevId: aclRecord.PrevId,
|
||||
CurrentReadKeyHash: aclRecord.CurrentReadKeyHash,
|
||||
Timestamp: aclRecord.Timestamp,
|
||||
Data: aclRecord.Data,
|
||||
Signature: rawRec.Signature,
|
||||
Identity: aclRecord.Identity,
|
||||
}
|
||||
}
|
||||
|
||||
err = verifyRaw(a.keychain, rawRec, rawIdRecord, rec.Identity)
|
||||
return
|
||||
}
|
||||
|
||||
func verifyRaw(
|
||||
keychain *keychain.Keychain,
|
||||
rawRec *aclrecordproto.RawAclRecord,
|
||||
recWithId *aclrecordproto.RawAclRecordWithId,
|
||||
identity []byte) (err error) {
|
||||
identityKey, err := keychain.GetOrAdd(string(identity))
|
||||
pubKey, err = a.keyStorage.PubKeyFromProto(aclRecord.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
aclData := &aclrecordproto.AclData{}
|
||||
err = proto.Unmarshal(aclRecord.Data, aclData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rec = &AclRecord{
|
||||
Id: rawIdRecord.Id,
|
||||
PrevId: aclRecord.PrevId,
|
||||
Timestamp: aclRecord.Timestamp,
|
||||
Data: aclRecord.Data,
|
||||
Signature: rawRec.Signature,
|
||||
Identity: pubKey,
|
||||
Model: aclData,
|
||||
}
|
||||
}
|
||||
|
||||
err = verifyRaw(pubKey, rawRec, rawIdRecord)
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclRecordBuilder) BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error) {
|
||||
rawIdentity, err := content.PrivKey.GetPublic().Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
identity, err := content.PrivKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
masterKey, err := content.MasterKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
identitySignature, err := content.MasterKey.Sign(rawIdentity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var timestamp int64
|
||||
if content.EncryptedReadKey != nil {
|
||||
timestamp = time.Now().Unix()
|
||||
}
|
||||
aclRoot := &aclrecordproto.AclRoot{
|
||||
Identity: identity,
|
||||
SpaceId: content.SpaceId,
|
||||
EncryptedReadKey: content.EncryptedReadKey,
|
||||
MasterKey: masterKey,
|
||||
IdentitySignature: identitySignature,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
return marshalAclRoot(aclRoot, content.PrivKey)
|
||||
}
|
||||
|
||||
func verifyRaw(
|
||||
pubKey crypto.PubKey,
|
||||
rawRec *consensusproto.RawRecord,
|
||||
recWithId *consensusproto.RawRecordWithId) (err error) {
|
||||
// verifying signature
|
||||
res, err := identityKey.Verify(rawRec.Payload, rawRec.Signature)
|
||||
res, err := pubKey.Verify(rawRec.Payload, rawRec.Signature)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -175,3 +508,31 @@ func verifyRaw(
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWithId *consensusproto.RawRecordWithId, err error) {
|
||||
marshalledRoot, err := aclRoot.Marshal()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
signature, err := key.Sign(marshalledRoot)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
raw := &consensusproto.RawRecord{
|
||||
Payload: marshalledRoot,
|
||||
Signature: signature,
|
||||
}
|
||||
marshalledRaw, err := raw.Marshal()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
aclHeadId, err := cidutil.NewCidFromBytes(marshalledRaw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
rawWithId = &consensusproto.RawRecordWithId{
|
||||
Payload: marshalledRaw,
|
||||
Id: aclHeadId,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,50 +0,0 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
acllistbuilder2 "github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/util/cidutil"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
|
||||
st, err := acllistbuilder2.NewListStorageWithTestName("userjoinexample.yml")
|
||||
require.NoError(t, err, "building storage should not result in error")
|
||||
|
||||
testKeychain := st.(*acllistbuilder2.AclListStorageBuilder).GetKeychain()
|
||||
identity := testKeychain.GeneratedIdentities["D"]
|
||||
signPrivKey := testKeychain.SigningKeysByYAMLName["D"]
|
||||
encPrivKey := testKeychain.EncryptionKeysByYAMLName["D"]
|
||||
acc := &accountdata.AccountData{
|
||||
Identity: []byte(identity),
|
||||
SignKey: signPrivKey,
|
||||
EncKey: encPrivKey,
|
||||
}
|
||||
|
||||
aclList, err := BuildAclListWithIdentity(acc, st)
|
||||
require.NoError(t, err, "building acl list should be without error")
|
||||
recordBuilder := newAclRecordBuilder(aclList.Id(), keychain.NewKeychain())
|
||||
rk, err := testKeychain.GetKey("key.Read.EncKey").(*acllistbuilder2.SymKey).Key.Raw()
|
||||
require.NoError(t, err)
|
||||
privKey, err := testKeychain.GetKey("key.Sign.Onetime1").(signingkey.PrivKey).Raw()
|
||||
require.NoError(t, err)
|
||||
|
||||
userJoin, err := recordBuilder.BuildUserJoin(privKey, rk, aclList.AclState())
|
||||
require.NoError(t, err)
|
||||
marshalledJoin, err := userJoin.Marshal()
|
||||
require.NoError(t, err)
|
||||
id, err := cidutil.NewCidFromBytes(marshalledJoin)
|
||||
require.NoError(t, err)
|
||||
rawRec := &aclrecordproto.RawAclRecordWithId{
|
||||
Payload: marshalledJoin,
|
||||
Id: id,
|
||||
}
|
||||
res, err := aclList.AddRawRecord(rawRec)
|
||||
require.True(t, res)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, aclrecordproto.AclUserPermissions_Writer, aclList.AclState().UserStates()[identity].Permissions)
|
||||
}
|
||||
@ -1,120 +1,142 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/app/logger"
|
||||
aclrecordproto2 "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/util/keys"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"go.uber.org/zap"
|
||||
"hash/fnv"
|
||||
)
|
||||
|
||||
var log = logger.NewNamed("acllist").Sugar()
|
||||
var log = logger.NewNamedSugared("common.commonspace.acllist")
|
||||
|
||||
var (
|
||||
ErrNoSuchUser = errors.New("no such user")
|
||||
ErrNoSuchAccount = errors.New("no such account")
|
||||
ErrPendingRequest = errors.New("already exists pending request")
|
||||
ErrUnexpectedContentType = errors.New("unexpected content type")
|
||||
ErrIncorrectIdentity = errors.New("incorrect identity")
|
||||
ErrIncorrectInviteKey = errors.New("incorrect invite key")
|
||||
ErrFailedToDecrypt = errors.New("failed to decrypt key")
|
||||
ErrUserRemoved = errors.New("user was removed from the document")
|
||||
ErrDocumentForbidden = errors.New("your user was forbidden access to the document")
|
||||
ErrUserAlreadyExists = errors.New("user already exists")
|
||||
ErrNoSuchRecord = errors.New("no such record")
|
||||
ErrNoSuchRequest = errors.New("no such request")
|
||||
ErrNoSuchInvite = errors.New("no such invite")
|
||||
ErrOldInvite = errors.New("invite is too old")
|
||||
ErrInsufficientPermissions = errors.New("insufficient permissions")
|
||||
ErrIsOwner = errors.New("can't be made by owner")
|
||||
ErrIncorrectNumberOfAccounts = errors.New("incorrect number of accounts")
|
||||
ErrDuplicateAccounts = errors.New("duplicate accounts")
|
||||
ErrNoReadKey = errors.New("acl state doesn't have a read key")
|
||||
ErrIncorrectReadKey = errors.New("incorrect read key")
|
||||
ErrInvalidSignature = errors.New("signature is invalid")
|
||||
ErrIncorrectRoot = errors.New("incorrect root")
|
||||
ErrIncorrectRecordSequence = errors.New("incorrect prev id of a record")
|
||||
)
|
||||
|
||||
type UserPermissionPair struct {
|
||||
Identity string
|
||||
Permission aclrecordproto2.AclUserPermissions
|
||||
Identity crypto.PubKey
|
||||
Permission aclrecordproto.AclUserPermissions
|
||||
}
|
||||
|
||||
type AclState struct {
|
||||
id string
|
||||
currentReadKeyHash uint64
|
||||
userReadKeys map[uint64]*symmetric.Key
|
||||
userStates map[string]*aclrecordproto2.AclUserState
|
||||
userInvites map[string]*aclrecordproto2.AclUserInvite
|
||||
encryptionKey encryptionkey.PrivKey
|
||||
signingKey signingkey.PrivKey
|
||||
currentReadKeyId string
|
||||
// userReadKeys is a map recordId -> read key which tells us about every read key
|
||||
userReadKeys map[string]crypto.SymKey
|
||||
// userStates is a map pubKey -> state which defines current user state
|
||||
userStates map[string]AclUserState
|
||||
// statesAtRecord is a map recordId -> state which define user state at particular record
|
||||
// probably this can grow rather large at some point, so we can maybe optimise later to have:
|
||||
// - map pubKey -> []recordIds (where recordIds is an array where such identity permissions were changed)
|
||||
statesAtRecord map[string][]AclUserState
|
||||
// inviteKeys is a map recordId -> invite
|
||||
inviteKeys map[string]crypto.PubKey
|
||||
// requestRecords is a map recordId -> RequestRecord
|
||||
requestRecords map[string]RequestRecord
|
||||
// pendingRequests is a map pubKey -> recordId
|
||||
pendingRequests map[string]string
|
||||
key crypto.PrivKey
|
||||
pubKey crypto.PubKey
|
||||
keyStore crypto.KeyStorage
|
||||
totalReadKeys int
|
||||
|
||||
identity string
|
||||
permissionsAtRecord map[string][]UserPermissionPair
|
||||
lastRecordId string
|
||||
|
||||
keychain *keychain.Keychain
|
||||
contentValidator ContentValidator
|
||||
}
|
||||
|
||||
func newAclStateWithKeys(
|
||||
id string,
|
||||
signingKey signingkey.PrivKey,
|
||||
encryptionKey encryptionkey.PrivKey) (*AclState, error) {
|
||||
identity, err := signingKey.GetPublic().Raw()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &AclState{
|
||||
key crypto.PrivKey) (*AclState, error) {
|
||||
st := &AclState{
|
||||
id: id,
|
||||
identity: string(identity),
|
||||
signingKey: signingKey,
|
||||
encryptionKey: encryptionKey,
|
||||
userReadKeys: make(map[uint64]*symmetric.Key),
|
||||
userStates: make(map[string]*aclrecordproto2.AclUserState),
|
||||
userInvites: make(map[string]*aclrecordproto2.AclUserInvite),
|
||||
permissionsAtRecord: make(map[string][]UserPermissionPair),
|
||||
}, nil
|
||||
key: key,
|
||||
pubKey: key.GetPublic(),
|
||||
userReadKeys: make(map[string]crypto.SymKey),
|
||||
userStates: make(map[string]AclUserState),
|
||||
statesAtRecord: make(map[string][]AclUserState),
|
||||
inviteKeys: make(map[string]crypto.PubKey),
|
||||
requestRecords: make(map[string]RequestRecord),
|
||||
pendingRequests: make(map[string]string),
|
||||
keyStore: crypto.NewKeyStorage(),
|
||||
}
|
||||
st.contentValidator = &contentValidator{
|
||||
keyStore: st.keyStore,
|
||||
aclState: st,
|
||||
}
|
||||
return st, nil
|
||||
}
|
||||
|
||||
func newAclState(id string) *AclState {
|
||||
return &AclState{
|
||||
st := &AclState{
|
||||
id: id,
|
||||
userReadKeys: make(map[uint64]*symmetric.Key),
|
||||
userStates: make(map[string]*aclrecordproto2.AclUserState),
|
||||
userInvites: make(map[string]*aclrecordproto2.AclUserInvite),
|
||||
permissionsAtRecord: make(map[string][]UserPermissionPair),
|
||||
userReadKeys: make(map[string]crypto.SymKey),
|
||||
userStates: make(map[string]AclUserState),
|
||||
statesAtRecord: make(map[string][]AclUserState),
|
||||
inviteKeys: make(map[string]crypto.PubKey),
|
||||
requestRecords: make(map[string]RequestRecord),
|
||||
pendingRequests: make(map[string]string),
|
||||
keyStore: crypto.NewKeyStorage(),
|
||||
}
|
||||
st.contentValidator = &contentValidator{
|
||||
keyStore: st.keyStore,
|
||||
aclState: st,
|
||||
}
|
||||
return st
|
||||
}
|
||||
|
||||
func (st *AclState) CurrentReadKeyHash() uint64 {
|
||||
return st.currentReadKeyHash
|
||||
func (st *AclState) Validator() ContentValidator {
|
||||
return st.contentValidator
|
||||
}
|
||||
|
||||
func (st *AclState) CurrentReadKey() (*symmetric.Key, error) {
|
||||
key, exists := st.userReadKeys[st.currentReadKeyHash]
|
||||
func (st *AclState) CurrentReadKeyId() string {
|
||||
return st.currentReadKeyId
|
||||
}
|
||||
|
||||
func (st *AclState) CurrentReadKey() (crypto.SymKey, error) {
|
||||
key, exists := st.userReadKeys[st.CurrentReadKeyId()]
|
||||
if !exists {
|
||||
return nil, ErrNoReadKey
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (st *AclState) UserReadKeys() map[uint64]*symmetric.Key {
|
||||
func (st *AclState) UserReadKeys() map[string]crypto.SymKey {
|
||||
return st.userReadKeys
|
||||
}
|
||||
|
||||
func (st *AclState) PermissionsAtRecord(id string, identity string) (UserPermissionPair, error) {
|
||||
permissions, ok := st.permissionsAtRecord[id]
|
||||
func (st *AclState) StateAtRecord(id string, pubKey crypto.PubKey) (AclUserState, error) {
|
||||
userState, ok := st.statesAtRecord[id]
|
||||
if !ok {
|
||||
log.Errorf("missing record at id %s", id)
|
||||
return UserPermissionPair{}, ErrNoSuchRecord
|
||||
return AclUserState{}, ErrNoSuchRecord
|
||||
}
|
||||
|
||||
for _, perm := range permissions {
|
||||
if perm.Identity == identity {
|
||||
for _, perm := range userState {
|
||||
if perm.PubKey.Equals(pubKey) {
|
||||
return perm, nil
|
||||
}
|
||||
}
|
||||
return UserPermissionPair{}, ErrNoSuchUser
|
||||
return AclUserState{}, ErrNoSuchAccount
|
||||
}
|
||||
|
||||
func (st *AclState) applyRecord(record *AclRecord) (err error) {
|
||||
@ -127,338 +149,316 @@ func (st *AclState) applyRecord(record *AclRecord) (err error) {
|
||||
err = ErrIncorrectRecordSequence
|
||||
return
|
||||
}
|
||||
// if the record is root record
|
||||
if record.Id == st.id {
|
||||
root, ok := record.Model.(*aclrecordproto2.AclRoot)
|
||||
if !ok {
|
||||
return ErrIncorrectRoot
|
||||
}
|
||||
err = st.applyRoot(root)
|
||||
err = st.applyRoot(record)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
st.permissionsAtRecord[record.Id] = []UserPermissionPair{
|
||||
{Identity: string(root.Identity), Permission: aclrecordproto2.AclUserPermissions_Admin},
|
||||
st.statesAtRecord[record.Id] = []AclUserState{
|
||||
st.userStates[mapKeyFromPubKey(record.Identity)],
|
||||
}
|
||||
return
|
||||
}
|
||||
aclData := &aclrecordproto2.AclData{}
|
||||
|
||||
if record.Model != nil {
|
||||
aclData = record.Model.(*aclrecordproto2.AclData)
|
||||
} else {
|
||||
// if the model is not cached
|
||||
if record.Model == nil {
|
||||
aclData := &aclrecordproto.AclData{}
|
||||
err = proto.Unmarshal(record.Data, aclData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
record.Model = aclData
|
||||
}
|
||||
|
||||
err = st.applyChangeData(aclData, record.CurrentReadKeyHash, record.Identity)
|
||||
// applying records contents
|
||||
err = st.applyChangeData(record)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// getting all permissions for users at record
|
||||
var permissions []UserPermissionPair
|
||||
// getting all states for users at record and saving them
|
||||
var states []AclUserState
|
||||
for _, state := range st.userStates {
|
||||
permission := UserPermissionPair{
|
||||
Identity: string(state.Identity),
|
||||
Permission: state.Permissions,
|
||||
states = append(states, state)
|
||||
}
|
||||
permissions = append(permissions, permission)
|
||||
}
|
||||
|
||||
st.permissionsAtRecord[record.Id] = permissions
|
||||
st.statesAtRecord[record.Id] = states
|
||||
return
|
||||
}
|
||||
|
||||
func (st *AclState) applyRoot(root *aclrecordproto2.AclRoot) (err error) {
|
||||
if st.signingKey != nil && st.encryptionKey != nil && st.identity == string(root.Identity) {
|
||||
err = st.saveReadKeyFromRoot(root)
|
||||
func (st *AclState) applyRoot(record *AclRecord) (err error) {
|
||||
if st.key != nil && st.pubKey.Equals(record.Identity) {
|
||||
err = st.saveReadKeyFromRoot(record)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// adding user to the list
|
||||
userState := &aclrecordproto2.AclUserState{
|
||||
Identity: root.Identity,
|
||||
EncryptionKey: root.EncryptionKey,
|
||||
Permissions: aclrecordproto2.AclUserPermissions_Admin,
|
||||
userState := AclUserState{
|
||||
PubKey: record.Identity,
|
||||
Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Owner),
|
||||
}
|
||||
st.currentReadKeyHash = root.CurrentReadKeyHash
|
||||
st.userStates[string(root.Identity)] = userState
|
||||
st.currentReadKeyId = record.Id
|
||||
st.userStates[mapKeyFromPubKey(record.Identity)] = userState
|
||||
st.totalReadKeys++
|
||||
return
|
||||
}
|
||||
|
||||
func (st *AclState) saveReadKeyFromRoot(root *aclrecordproto2.AclRoot) (err error) {
|
||||
var readKey *symmetric.Key
|
||||
if len(root.GetDerivationScheme()) != 0 {
|
||||
var encPrivKey []byte
|
||||
encPrivKey, err = st.encryptionKey.Raw()
|
||||
if err != nil {
|
||||
return
|
||||
func (st *AclState) saveReadKeyFromRoot(record *AclRecord) (err error) {
|
||||
var readKey crypto.SymKey
|
||||
root, ok := record.Model.(*aclrecordproto.AclRoot)
|
||||
if !ok {
|
||||
return ErrIncorrectRoot
|
||||
}
|
||||
var signPrivKey []byte
|
||||
signPrivKey, err = st.signingKey.Raw()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
readKey, err = aclrecordproto2.AclReadKeyDerive(signPrivKey, encPrivKey)
|
||||
if root.EncryptedReadKey == nil {
|
||||
readKey, err = st.deriveKey()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
readKey, _, err = st.decryptReadKeyAndHash(root.EncryptedReadKey)
|
||||
readKey, err = st.decryptReadKey(root.EncryptedReadKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hasher := fnv.New64()
|
||||
_, err = hasher.Write(readKey.Bytes())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if hasher.Sum64() != root.CurrentReadKeyHash {
|
||||
return ErrIncorrectRoot
|
||||
}
|
||||
st.userReadKeys[root.CurrentReadKeyHash] = readKey
|
||||
|
||||
st.userReadKeys[record.Id] = readKey
|
||||
return
|
||||
}
|
||||
|
||||
func (st *AclState) applyChangeData(changeData *aclrecordproto2.AclData, hash uint64, identity []byte) (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if hash != st.currentReadKeyHash {
|
||||
st.totalReadKeys++
|
||||
st.currentReadKeyHash = hash
|
||||
}
|
||||
}()
|
||||
|
||||
if !st.isUserJoin(changeData) {
|
||||
// we check signature when we add this to the List, so no need to do it here
|
||||
if _, exists := st.userStates[string(identity)]; !exists {
|
||||
err = ErrNoSuchUser
|
||||
return
|
||||
}
|
||||
|
||||
if !st.HasPermission(identity, aclrecordproto2.AclUserPermissions_Admin) {
|
||||
err = fmt.Errorf("user %s must have admin permissions", identity)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
for _, ch := range changeData.GetAclContent() {
|
||||
if err = st.applyChangeContent(ch); err != nil {
|
||||
func (st *AclState) applyChangeData(record *AclRecord) (err error) {
|
||||
model := record.Model.(*aclrecordproto.AclData)
|
||||
for _, ch := range model.GetAclContent() {
|
||||
if err = st.applyChangeContent(ch, record.Id, record.Identity); err != nil {
|
||||
log.Info("error while applying changes: %v; ignore", zap.Error(err))
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyChangeContent(ch *aclrecordproto2.AclContentValue) error {
|
||||
func (st *AclState) applyChangeContent(ch *aclrecordproto.AclContentValue, recordId string, authorIdentity crypto.PubKey) error {
|
||||
switch {
|
||||
case ch.GetUserPermissionChange() != nil:
|
||||
return st.applyUserPermissionChange(ch.GetUserPermissionChange())
|
||||
case ch.GetUserAdd() != nil:
|
||||
return st.applyUserAdd(ch.GetUserAdd())
|
||||
case ch.GetUserRemove() != nil:
|
||||
return st.applyUserRemove(ch.GetUserRemove())
|
||||
case ch.GetUserInvite() != nil:
|
||||
return st.applyUserInvite(ch.GetUserInvite())
|
||||
case ch.GetUserJoin() != nil:
|
||||
return st.applyUserJoin(ch.GetUserJoin())
|
||||
case ch.GetPermissionChange() != nil:
|
||||
return st.applyPermissionChange(ch.GetPermissionChange(), recordId, authorIdentity)
|
||||
case ch.GetInvite() != nil:
|
||||
return st.applyInvite(ch.GetInvite(), recordId, authorIdentity)
|
||||
case ch.GetInviteRevoke() != nil:
|
||||
return st.applyInviteRevoke(ch.GetInviteRevoke(), recordId, authorIdentity)
|
||||
case ch.GetRequestJoin() != nil:
|
||||
return st.applyRequestJoin(ch.GetRequestJoin(), recordId, authorIdentity)
|
||||
case ch.GetRequestAccept() != nil:
|
||||
return st.applyRequestAccept(ch.GetRequestAccept(), recordId, authorIdentity)
|
||||
case ch.GetRequestDecline() != nil:
|
||||
return st.applyRequestDecline(ch.GetRequestDecline(), recordId, authorIdentity)
|
||||
case ch.GetAccountRemove() != nil:
|
||||
return st.applyAccountRemove(ch.GetAccountRemove(), recordId, authorIdentity)
|
||||
case ch.GetReadKeyChange() != nil:
|
||||
return st.applyReadKeyChange(ch.GetReadKeyChange(), recordId, authorIdentity)
|
||||
case ch.GetAccountRequestRemove() != nil:
|
||||
return st.applyRequestRemove(ch.GetAccountRequestRemove(), recordId, authorIdentity)
|
||||
default:
|
||||
return fmt.Errorf("unexpected change type: %v", ch)
|
||||
return ErrUnexpectedContentType
|
||||
}
|
||||
}
|
||||
|
||||
func (st *AclState) applyUserPermissionChange(ch *aclrecordproto2.AclUserPermissionChange) error {
|
||||
chIdentity := string(ch.Identity)
|
||||
state, exists := st.userStates[chIdentity]
|
||||
if !exists {
|
||||
return ErrNoSuchUser
|
||||
func (st *AclState) applyPermissionChange(ch *aclrecordproto.AclAccountPermissionChange, recordId string, authorIdentity crypto.PubKey) error {
|
||||
chIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state.Permissions = ch.Permissions
|
||||
err = st.contentValidator.ValidatePermissionChange(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stringKey := mapKeyFromPubKey(chIdentity)
|
||||
state, _ := st.userStates[stringKey]
|
||||
state.Permissions = AclPermissions(ch.Permissions)
|
||||
st.userStates[stringKey] = state
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyUserInvite(ch *aclrecordproto2.AclUserInvite) error {
|
||||
st.userInvites[string(ch.AcceptPublicKey)] = ch
|
||||
func (st *AclState) applyInvite(ch *aclrecordproto.AclAccountInvite, recordId string, authorIdentity crypto.PubKey) error {
|
||||
inviteKey, err := st.keyStore.PubKeyFromProto(ch.InviteKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = st.contentValidator.ValidateInvite(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st.inviteKeys[recordId] = inviteKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyUserJoin(ch *aclrecordproto2.AclUserJoin) error {
|
||||
invite, exists := st.userInvites[string(ch.AcceptPubKey)]
|
||||
if !exists {
|
||||
return fmt.Errorf("no such invite with such public key %s", keys.EncodeBytesToString(ch.AcceptPubKey))
|
||||
}
|
||||
chIdentity := string(ch.Identity)
|
||||
|
||||
if _, exists = st.userStates[chIdentity]; exists {
|
||||
return ErrUserAlreadyExists
|
||||
}
|
||||
|
||||
// validating signature
|
||||
signature := ch.GetAcceptSignature()
|
||||
verificationKey, err := signingkey.NewSigningEd25519PubKeyFromBytes(invite.AcceptPublicKey)
|
||||
func (st *AclState) applyInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateInviteRevoke(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("public key verifying invite accepts is given in incorrect format: %v", err)
|
||||
return err
|
||||
}
|
||||
delete(st.inviteKeys, ch.InviteRecordId)
|
||||
return nil
|
||||
}
|
||||
|
||||
res, err := verificationKey.Verify(ch.Identity, signature)
|
||||
func (st *AclState) applyRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateRequestJoin(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verification returned error: %w", err)
|
||||
return err
|
||||
}
|
||||
if !res {
|
||||
return ErrInvalidSignature
|
||||
st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
|
||||
st.requestRecords[recordId] = RequestRecord{
|
||||
RequestIdentity: authorIdentity,
|
||||
RequestMetadata: ch.Metadata,
|
||||
Type: RequestTypeJoin,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// if ourselves -> we need to decrypt the read keys
|
||||
if st.identity == chIdentity {
|
||||
func (st *AclState) applyRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateRequestAccept(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acceptIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
record, _ := st.requestRecords[ch.RequestRecordId]
|
||||
st.userStates[mapKeyFromPubKey(acceptIdentity)] = AclUserState{
|
||||
PubKey: acceptIdentity,
|
||||
Permissions: AclPermissions(ch.Permissions),
|
||||
RequestMetadata: record.RequestMetadata,
|
||||
}
|
||||
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
|
||||
if !st.pubKey.Equals(acceptIdentity) {
|
||||
return nil
|
||||
}
|
||||
for _, key := range ch.EncryptedReadKeys {
|
||||
key, hash, err := st.decryptReadKeyAndHash(key)
|
||||
decrypted, err := st.key.Decrypt(key.EncryptedReadKey)
|
||||
if err != nil {
|
||||
return ErrFailedToDecrypt
|
||||
return err
|
||||
}
|
||||
|
||||
st.userReadKeys[hash] = key
|
||||
}
|
||||
}
|
||||
|
||||
// adding user to the list
|
||||
userState := &aclrecordproto2.AclUserState{
|
||||
Identity: ch.Identity,
|
||||
EncryptionKey: ch.EncryptionKey,
|
||||
Permissions: invite.Permissions,
|
||||
}
|
||||
st.userStates[chIdentity] = userState
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyUserAdd(ch *aclrecordproto2.AclUserAdd) error {
|
||||
chIdentity := string(ch.Identity)
|
||||
if _, exists := st.userStates[chIdentity]; exists {
|
||||
return ErrUserAlreadyExists
|
||||
}
|
||||
|
||||
st.userStates[chIdentity] = &aclrecordproto2.AclUserState{
|
||||
Identity: ch.Identity,
|
||||
EncryptionKey: ch.EncryptionKey,
|
||||
Permissions: ch.Permissions,
|
||||
}
|
||||
|
||||
if chIdentity == st.identity {
|
||||
for _, key := range ch.EncryptedReadKeys {
|
||||
key, hash, err := st.decryptReadKeyAndHash(key)
|
||||
sym, err := crypto.UnmarshallAESKey(decrypted)
|
||||
if err != nil {
|
||||
return ErrFailedToDecrypt
|
||||
}
|
||||
|
||||
st.userReadKeys[hash] = key
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyUserRemove(ch *aclrecordproto2.AclUserRemove) error {
|
||||
chIdentity := string(ch.Identity)
|
||||
if chIdentity == st.identity {
|
||||
return ErrDocumentForbidden
|
||||
}
|
||||
|
||||
if _, exists := st.userStates[chIdentity]; !exists {
|
||||
return ErrNoSuchUser
|
||||
}
|
||||
|
||||
delete(st.userStates, chIdentity)
|
||||
|
||||
for _, replace := range ch.ReadKeyReplaces {
|
||||
repIdentity := string(replace.Identity)
|
||||
// if this is our identity then we have to decrypt the key
|
||||
if repIdentity == st.identity {
|
||||
key, hash, err := st.decryptReadKeyAndHash(replace.EncryptedReadKey)
|
||||
if err != nil {
|
||||
return ErrFailedToDecrypt
|
||||
}
|
||||
|
||||
st.userReadKeys[hash] = key
|
||||
break
|
||||
return err
|
||||
}
|
||||
st.userReadKeys[key.RecordId] = sym
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) decryptReadKeyAndHash(msg []byte) (*symmetric.Key, uint64, error) {
|
||||
decrypted, err := st.encryptionKey.Decrypt(msg)
|
||||
func (st *AclState) applyRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateRequestDecline(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return nil, 0, ErrFailedToDecrypt
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := symmetric.FromBytes(decrypted)
|
||||
if err != nil {
|
||||
return nil, 0, ErrFailedToDecrypt
|
||||
}
|
||||
|
||||
hasher := fnv.New64()
|
||||
hasher.Write(decrypted)
|
||||
return key, hasher.Sum64(), nil
|
||||
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
|
||||
delete(st.requestRecords, ch.RequestRecordId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) HasPermission(identity []byte, permission aclrecordproto2.AclUserPermissions) bool {
|
||||
state, exists := st.userStates[string(identity)]
|
||||
func (st *AclState) applyRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateRequestRemove(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st.requestRecords[recordId] = RequestRecord{
|
||||
RequestIdentity: authorIdentity,
|
||||
Type: RequestTypeRemove,
|
||||
}
|
||||
st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) applyAccountRemove(ch *aclrecordproto.AclAccountRemove, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateAccountRemove(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rawIdentity := range ch.Identities {
|
||||
identity, err := st.keyStore.PubKeyFromProto(rawIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idKey := mapKeyFromPubKey(identity)
|
||||
delete(st.userStates, idKey)
|
||||
delete(st.pendingRequests, idKey)
|
||||
}
|
||||
return st.updateReadKey(ch.AccountKeys, recordId)
|
||||
}
|
||||
|
||||
func (st *AclState) applyReadKeyChange(ch *aclrecordproto.AclReadKeyChange, recordId string, authorIdentity crypto.PubKey) error {
|
||||
err := st.contentValidator.ValidateReadKeyChange(ch, authorIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return st.updateReadKey(ch.AccountKeys, recordId)
|
||||
}
|
||||
|
||||
func (st *AclState) updateReadKey(keys []*aclrecordproto.AclEncryptedReadKey, recordId string) error {
|
||||
for _, accKey := range keys {
|
||||
identity, _ := st.keyStore.PubKeyFromProto(accKey.Identity)
|
||||
if st.pubKey.Equals(identity) {
|
||||
res, err := st.decryptReadKey(accKey.EncryptedReadKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st.userReadKeys[recordId] = res
|
||||
}
|
||||
}
|
||||
st.currentReadKeyId = recordId
|
||||
return nil
|
||||
}
|
||||
|
||||
func (st *AclState) decryptReadKey(msg []byte) (crypto.SymKey, error) {
|
||||
decrypted, err := st.key.Decrypt(msg)
|
||||
if err != nil {
|
||||
return nil, ErrFailedToDecrypt
|
||||
}
|
||||
key, err := crypto.UnmarshallAESKey(decrypted)
|
||||
if err != nil {
|
||||
return nil, ErrFailedToDecrypt
|
||||
}
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func (st *AclState) Permissions(identity crypto.PubKey) AclPermissions {
|
||||
state, exists := st.userStates[mapKeyFromPubKey(identity)]
|
||||
if !exists {
|
||||
return false
|
||||
return AclPermissions(aclrecordproto.AclUserPermissions_None)
|
||||
}
|
||||
|
||||
return state.Permissions == permission
|
||||
return state.Permissions
|
||||
}
|
||||
|
||||
func (st *AclState) isUserJoin(data *aclrecordproto2.AclData) bool {
|
||||
// if we have a UserJoin, then it should always be the first one applied
|
||||
return data.GetAclContent() != nil && data.GetAclContent()[0].GetUserJoin() != nil
|
||||
}
|
||||
|
||||
func (st *AclState) isUserAdd(data *aclrecordproto2.AclData, identity []byte) bool {
|
||||
// if we have a UserAdd, then it should always be the first one applied
|
||||
userAdd := data.GetAclContent()[0].GetUserAdd()
|
||||
return data.GetAclContent() != nil && userAdd != nil && bytes.Compare(userAdd.GetIdentity(), identity) == 0
|
||||
}
|
||||
|
||||
func (st *AclState) UserStates() map[string]*aclrecordproto2.AclUserState {
|
||||
return st.userStates
|
||||
}
|
||||
|
||||
func (st *AclState) Invite(acceptPubKey []byte) (invite *aclrecordproto2.AclUserInvite, err error) {
|
||||
invite, exists := st.userInvites[string(acceptPubKey)]
|
||||
if !exists {
|
||||
err = ErrNoSuchInvite
|
||||
return
|
||||
func (st *AclState) JoinRecords() (records []RequestRecord) {
|
||||
for _, recId := range st.pendingRequests {
|
||||
rec := st.requestRecords[recId]
|
||||
if rec.Type == RequestTypeJoin {
|
||||
records = append(records, rec)
|
||||
}
|
||||
if len(invite.EncryptedReadKeys) != st.totalReadKeys {
|
||||
err = ErrOldInvite
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (st *AclState) UserKeys() (encKey encryptionkey.PrivKey, signKey signingkey.PrivKey) {
|
||||
return st.encryptionKey, st.signingKey
|
||||
}
|
||||
|
||||
func (st *AclState) Identity() []byte {
|
||||
return []byte(st.identity)
|
||||
func (st *AclState) RemoveRecords() (records []RequestRecord) {
|
||||
for _, recId := range st.pendingRequests {
|
||||
rec := st.requestRecords[recId]
|
||||
if rec.Type == RequestTypeRemove {
|
||||
records = append(records, rec)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (st *AclState) LastRecordId() string {
|
||||
return st.lastRecordId
|
||||
}
|
||||
|
||||
func (st *AclState) deriveKey() (crypto.SymKey, error) {
|
||||
keyBytes, err := st.key.Raw()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return crypto.DeriveSymmetricKey(keyBytes, crypto.AnysyncSpacePath)
|
||||
}
|
||||
|
||||
func mapKeyFromPubKey(pubKey crypto.PubKey) string {
|
||||
return string(pubKey.Storage())
|
||||
}
|
||||
|
||||
@ -1,21 +1,18 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type aclStateBuilder struct {
|
||||
signPrivKey signingkey.PrivKey
|
||||
encPrivKey encryptionkey.PrivKey
|
||||
privKey crypto.PrivKey
|
||||
id string
|
||||
}
|
||||
|
||||
func newAclStateBuilderWithIdentity(accountData *accountdata.AccountData) *aclStateBuilder {
|
||||
func newAclStateBuilderWithIdentity(keys *accountdata.AccountKeys) *aclStateBuilder {
|
||||
return &aclStateBuilder{
|
||||
signPrivKey: accountData.SignKey,
|
||||
encPrivKey: accountData.EncKey,
|
||||
privKey: keys.SignKey,
|
||||
}
|
||||
}
|
||||
|
||||
@ -28,8 +25,8 @@ func (sb *aclStateBuilder) Init(id string) {
|
||||
}
|
||||
|
||||
func (sb *aclStateBuilder) Build(records []*AclRecord) (state *AclState, err error) {
|
||||
if sb.encPrivKey != nil && sb.signPrivKey != nil {
|
||||
state, err = newAclStateWithKeys(sb.id, sb.signPrivKey, sb.encPrivKey)
|
||||
if sb.privKey != nil {
|
||||
state, err = newAclStateWithKeys(sb.id, sb.privKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,20 +1,25 @@
|
||||
//go:generate mockgen -destination mock_list/mock_list.go github.com/anytypeio/any-sync/commonspace/object/acl/list AclList
|
||||
//go:generate mockgen -destination mock_list/mock_list.go github.com/anyproto/any-sync/commonspace/object/acl/list AclList
|
||||
package list
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"sync"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/anyproto/any-sync/util/cidutil"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type IterFunc = func(record *AclRecord) (IsContinue bool)
|
||||
|
||||
var ErrIncorrectCID = errors.New("incorrect CID")
|
||||
var (
|
||||
ErrIncorrectCID = errors.New("incorrect CID")
|
||||
ErrRecordAlreadyExists = errors.New("record already exists")
|
||||
)
|
||||
|
||||
type RWLocker interface {
|
||||
sync.Locker
|
||||
@ -22,48 +27,97 @@ type RWLocker interface {
|
||||
RUnlock()
|
||||
}
|
||||
|
||||
type AcceptorVerifier interface {
|
||||
VerifyAcceptor(rec *consensusproto.RawRecord) (err error)
|
||||
}
|
||||
|
||||
type NoOpAcceptorVerifier struct {
|
||||
}
|
||||
|
||||
func (n NoOpAcceptorVerifier) VerifyAcceptor(rec *consensusproto.RawRecord) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
type AclList interface {
|
||||
RWLocker
|
||||
Id() string
|
||||
Root() *aclrecordproto.RawAclRecordWithId
|
||||
Root() *consensusproto.RawRecordWithId
|
||||
Records() []*AclRecord
|
||||
AclState() *AclState
|
||||
IsAfter(first string, second string) (bool, error)
|
||||
HasHead(head string) bool
|
||||
Head() *AclRecord
|
||||
|
||||
RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error)
|
||||
Get(id string) (*AclRecord, error)
|
||||
GetIndex(idx int) (*AclRecord, error)
|
||||
Iterate(iterFunc IterFunc)
|
||||
IterateFrom(startId string, iterFunc IterFunc)
|
||||
|
||||
AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error)
|
||||
KeyStorage() crypto.KeyStorage
|
||||
RecordBuilder() AclRecordBuilder
|
||||
|
||||
Close() (err error)
|
||||
ValidateRawRecord(record *consensusproto.RawRecord) (err error)
|
||||
AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error)
|
||||
AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error)
|
||||
|
||||
Close(ctx context.Context) (err error)
|
||||
}
|
||||
|
||||
type aclList struct {
|
||||
root *aclrecordproto.RawAclRecordWithId
|
||||
root *consensusproto.RawRecordWithId
|
||||
records []*AclRecord
|
||||
indexes map[string]int
|
||||
id string
|
||||
|
||||
stateBuilder *aclStateBuilder
|
||||
recordBuilder AclRecordBuilder
|
||||
keyStorage crypto.KeyStorage
|
||||
aclState *AclState
|
||||
keychain *keychain.Keychain
|
||||
storage liststorage.ListStorage
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func BuildAclListWithIdentity(acc *accountdata.AccountData, storage liststorage.ListStorage) (AclList, error) {
|
||||
builder := newAclStateBuilderWithIdentity(acc)
|
||||
return build(storage.Id(), builder, newAclRecordBuilder(storage.Id(), keychain.NewKeychain()), storage)
|
||||
type internalDeps struct {
|
||||
storage liststorage.ListStorage
|
||||
keyStorage crypto.KeyStorage
|
||||
stateBuilder *aclStateBuilder
|
||||
recordBuilder AclRecordBuilder
|
||||
acceptorVerifier AcceptorVerifier
|
||||
}
|
||||
|
||||
func BuildAclList(storage liststorage.ListStorage) (AclList, error) {
|
||||
return build(storage.Id(), newAclStateBuilder(), newAclRecordBuilder(storage.Id(), keychain.NewKeychain()), storage)
|
||||
func BuildAclListWithIdentity(acc *accountdata.AccountKeys, storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
|
||||
keyStorage := crypto.NewKeyStorage()
|
||||
deps := internalDeps{
|
||||
storage: storage,
|
||||
keyStorage: keyStorage,
|
||||
stateBuilder: newAclStateBuilderWithIdentity(acc),
|
||||
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, acc, verifier),
|
||||
acceptorVerifier: verifier,
|
||||
}
|
||||
return build(deps)
|
||||
}
|
||||
|
||||
func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder, storage liststorage.ListStorage) (list AclList, err error) {
|
||||
func BuildAclList(storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
|
||||
keyStorage := crypto.NewKeyStorage()
|
||||
deps := internalDeps{
|
||||
storage: storage,
|
||||
keyStorage: keyStorage,
|
||||
stateBuilder: newAclStateBuilder(),
|
||||
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, nil, verifier),
|
||||
acceptorVerifier: verifier,
|
||||
}
|
||||
return build(deps)
|
||||
}
|
||||
|
||||
func build(deps internalDeps) (list AclList, err error) {
|
||||
var (
|
||||
storage = deps.storage
|
||||
id = deps.storage.Id()
|
||||
recBuilder = deps.recordBuilder
|
||||
stateBuilder = deps.stateBuilder
|
||||
)
|
||||
head, err := storage.Head()
|
||||
if err != nil {
|
||||
return
|
||||
@ -74,7 +128,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
|
||||
return
|
||||
}
|
||||
|
||||
record, err := recBuilder.ConvertFromRaw(rawRecordWithId)
|
||||
record, err := recBuilder.UnmarshallWithId(rawRecordWithId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -86,7 +140,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
|
||||
return
|
||||
}
|
||||
|
||||
record, err = recBuilder.ConvertFromRaw(rawRecordWithId)
|
||||
record, err = recBuilder.UnmarshallWithId(rawRecordWithId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -116,6 +170,7 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
|
||||
return
|
||||
}
|
||||
|
||||
recBuilder.(*aclRecordBuilder).state = state
|
||||
list = &aclList{
|
||||
root: rootWithId,
|
||||
records: records,
|
||||
@ -129,15 +184,37 @@ func build(id string, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclList) RecordBuilder() AclRecordBuilder {
|
||||
return a.recordBuilder
|
||||
}
|
||||
|
||||
func (a *aclList) Records() []*AclRecord {
|
||||
return a.records
|
||||
}
|
||||
|
||||
func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error) {
|
||||
if _, ok := a.indexes[rawRec.Id]; ok {
|
||||
func (a *aclList) ValidateRawRecord(rawRec *consensusproto.RawRecord) (err error) {
|
||||
record, err := a.recordBuilder.Unmarshall(rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
record, err := a.recordBuilder.ConvertFromRaw(rawRec)
|
||||
return a.aclState.Validator().ValidateAclRecordContents(record)
|
||||
}
|
||||
|
||||
func (a *aclList) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
|
||||
for _, rec := range rawRecords {
|
||||
err = a.AddRawRecord(rec)
|
||||
if err != nil && err != ErrRecordAlreadyExists {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclList) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
|
||||
if _, ok := a.indexes[rawRec.Id]; ok {
|
||||
return ErrRecordAlreadyExists
|
||||
}
|
||||
record, err := a.recordBuilder.UnmarshallWithId(rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -152,15 +229,6 @@ func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added
|
||||
if err = a.storage.SetHead(rawRec.Id); err != nil {
|
||||
return
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (a *aclList) IsValidNext(rawRec *aclrecordproto.RawAclRecordWithId) (err error) {
|
||||
_, err = a.recordBuilder.ConvertFromRaw(rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
// TODO: change state and add "check" method for records
|
||||
return
|
||||
}
|
||||
|
||||
@ -168,7 +236,7 @@ func (a *aclList) Id() string {
|
||||
return a.id
|
||||
}
|
||||
|
||||
func (a *aclList) Root() *aclrecordproto.RawAclRecordWithId {
|
||||
func (a *aclList) Root() *consensusproto.RawRecordWithId {
|
||||
return a.root
|
||||
}
|
||||
|
||||
@ -176,6 +244,10 @@ func (a *aclList) AclState() *AclState {
|
||||
return a.aclState
|
||||
}
|
||||
|
||||
func (a *aclList) KeyStorage() crypto.KeyStorage {
|
||||
return a.keyStorage
|
||||
}
|
||||
|
||||
func (a *aclList) IsAfter(first string, second string) (bool, error) {
|
||||
firstRec, okFirst := a.indexes[first]
|
||||
secondRec, okSecond := a.indexes[second]
|
||||
@ -189,14 +261,27 @@ func (a *aclList) Head() *AclRecord {
|
||||
return a.records[len(a.records)-1]
|
||||
}
|
||||
|
||||
func (a *aclList) HasHead(head string) bool {
|
||||
_, exists := a.indexes[head]
|
||||
return exists
|
||||
}
|
||||
|
||||
func (a *aclList) Get(id string) (*AclRecord, error) {
|
||||
recIdx, ok := a.indexes[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("no such record")
|
||||
return nil, ErrNoSuchRecord
|
||||
}
|
||||
return a.records[recIdx], nil
|
||||
}
|
||||
|
||||
func (a *aclList) GetIndex(idx int) (*AclRecord, error) {
|
||||
// TODO: when we add snapshots we will have to monitor record num in snapshots
|
||||
if idx < 0 || idx >= len(a.records) {
|
||||
return nil, ErrNoSuchRecord
|
||||
}
|
||||
return a.records[idx], nil
|
||||
}
|
||||
|
||||
func (a *aclList) Iterate(iterFunc IterFunc) {
|
||||
for _, rec := range a.records {
|
||||
if !iterFunc(rec) {
|
||||
@ -205,6 +290,21 @@ func (a *aclList) Iterate(iterFunc IterFunc) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *aclList) RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error) {
|
||||
recIdx, ok := a.indexes[id]
|
||||
if !ok {
|
||||
return nil, ErrNoSuchRecord
|
||||
}
|
||||
for i := recIdx + 1; i < len(a.records); i++ {
|
||||
rawRec, err := a.storage.GetRawRecord(ctx, a.records[i].Id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records = append(records, rawRec)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
|
||||
recIdx, ok := a.indexes[startId]
|
||||
if !ok {
|
||||
@ -217,6 +317,21 @@ func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *aclList) Close() (err error) {
|
||||
func (a *aclList) Close(ctx context.Context) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func WrapAclRecord(rawRec *consensusproto.RawRecord) *consensusproto.RawRecordWithId {
|
||||
payload, err := rawRec.Marshal()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
id, err := cidutil.NewCidFromBytes(payload)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &consensusproto.RawRecordWithId{
|
||||
Payload: payload,
|
||||
Id: id,
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,91 +1,293 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAclList_AclState_UserInviteAndJoin(t *testing.T) {
|
||||
st, err := acllistbuilder.NewListStorageWithTestName("userjoinexample.yml")
|
||||
require.NoError(t, err, "building storage should not result in error")
|
||||
|
||||
keychain := st.(*acllistbuilder.AclListStorageBuilder).GetKeychain()
|
||||
|
||||
aclList, err := BuildAclList(st)
|
||||
require.NoError(t, err, "building acl list should be without error")
|
||||
|
||||
idA := keychain.GetIdentity("A")
|
||||
idB := keychain.GetIdentity("B")
|
||||
idC := keychain.GetIdentity("C")
|
||||
|
||||
// checking final state
|
||||
assert.Equal(t, aclrecordproto.AclUserPermissions_Admin, aclList.AclState().UserStates()[idA].Permissions)
|
||||
assert.Equal(t, aclrecordproto.AclUserPermissions_Writer, aclList.AclState().UserStates()[idB].Permissions)
|
||||
assert.Equal(t, aclrecordproto.AclUserPermissions_Reader, aclList.AclState().UserStates()[idC].Permissions)
|
||||
assert.Equal(t, aclList.Head().CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash())
|
||||
|
||||
var records []*AclRecord
|
||||
aclList.Iterate(func(record *AclRecord) (IsContinue bool) {
|
||||
records = append(records, record)
|
||||
return true
|
||||
})
|
||||
|
||||
// checking permissions at specific records
|
||||
assert.Equal(t, 3, len(records))
|
||||
|
||||
_, err = aclList.AclState().PermissionsAtRecord(records[1].Id, idB)
|
||||
assert.Error(t, err, "B should have no permissions at record 1")
|
||||
|
||||
perm, err := aclList.AclState().PermissionsAtRecord(records[2].Id, idB)
|
||||
assert.NoError(t, err, "should have no error with permissions of B in the record 2")
|
||||
assert.Equal(t, UserPermissionPair{
|
||||
Identity: idB,
|
||||
Permission: aclrecordproto.AclUserPermissions_Writer,
|
||||
}, perm)
|
||||
type aclFixture struct {
|
||||
ownerKeys *accountdata.AccountKeys
|
||||
accountKeys *accountdata.AccountKeys
|
||||
ownerAcl *aclList
|
||||
accountAcl *aclList
|
||||
spaceId string
|
||||
}
|
||||
|
||||
func TestAclList_AclState_UserJoinAndRemove(t *testing.T) {
|
||||
st, err := acllistbuilder.NewListStorageWithTestName("userremoveexample.yml")
|
||||
require.NoError(t, err, "building storage should not result in error")
|
||||
|
||||
keychain := st.(*acllistbuilder.AclListStorageBuilder).GetKeychain()
|
||||
|
||||
aclList, err := BuildAclList(st)
|
||||
require.NoError(t, err, "building acl list should be without error")
|
||||
|
||||
idA := keychain.GetIdentity("A")
|
||||
idB := keychain.GetIdentity("B")
|
||||
idC := keychain.GetIdentity("C")
|
||||
|
||||
// checking final state
|
||||
assert.Equal(t, aclrecordproto.AclUserPermissions_Admin, aclList.AclState().UserStates()[idA].Permissions)
|
||||
assert.Equal(t, aclrecordproto.AclUserPermissions_Reader, aclList.AclState().UserStates()[idC].Permissions)
|
||||
assert.Equal(t, aclList.Head().CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash())
|
||||
|
||||
_, exists := aclList.AclState().UserStates()[idB]
|
||||
assert.Equal(t, false, exists)
|
||||
|
||||
var records []*AclRecord
|
||||
aclList.Iterate(func(record *AclRecord) (IsContinue bool) {
|
||||
records = append(records, record)
|
||||
return true
|
||||
})
|
||||
|
||||
// checking permissions at specific records
|
||||
assert.Equal(t, 4, len(records))
|
||||
|
||||
assert.NotEqual(t, records[2].CurrentReadKeyHash, aclList.AclState().CurrentReadKeyHash())
|
||||
|
||||
perm, err := aclList.AclState().PermissionsAtRecord(records[2].Id, idB)
|
||||
assert.NoError(t, err, "should have no error with permissions of B in the record 2")
|
||||
assert.Equal(t, UserPermissionPair{
|
||||
Identity: idB,
|
||||
Permission: aclrecordproto.AclUserPermissions_Writer,
|
||||
}, perm)
|
||||
|
||||
_, err = aclList.AclState().PermissionsAtRecord(records[3].Id, idB)
|
||||
assert.Error(t, err, "B should have no permissions at record 3, because user should be removed")
|
||||
func newFixture(t *testing.T) *aclFixture {
|
||||
ownerKeys, err := accountdata.NewRandom()
|
||||
require.NoError(t, err)
|
||||
accountKeys, err := accountdata.NewRandom()
|
||||
require.NoError(t, err)
|
||||
spaceId := "spaceId"
|
||||
ownerAcl, err := NewTestDerivedAcl(spaceId, ownerKeys)
|
||||
require.NoError(t, err)
|
||||
accountAcl, err := NewTestAclWithRoot(accountKeys, ownerAcl.Root())
|
||||
require.NoError(t, err)
|
||||
return &aclFixture{
|
||||
ownerKeys: ownerKeys,
|
||||
accountKeys: accountKeys,
|
||||
ownerAcl: ownerAcl.(*aclList),
|
||||
accountAcl: accountAcl.(*aclList),
|
||||
spaceId: spaceId,
|
||||
}
|
||||
}
|
||||
|
||||
func (fx *aclFixture) addRec(t *testing.T, rec *consensusproto.RawRecordWithId) {
|
||||
err := fx.ownerAcl.AddRawRecord(rec)
|
||||
require.NoError(t, err)
|
||||
err = fx.accountAcl.AddRawRecord(rec)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func (fx *aclFixture) inviteAccount(t *testing.T, perms AclPermissions) {
|
||||
var (
|
||||
ownerAcl = fx.ownerAcl
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountAcl = fx.accountAcl
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
// building invite
|
||||
inv, err := ownerAcl.RecordBuilder().BuildInvite()
|
||||
require.NoError(t, err)
|
||||
inviteRec := WrapAclRecord(inv.InviteRec)
|
||||
fx.addRec(t, inviteRec)
|
||||
|
||||
// building request join
|
||||
requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
|
||||
InviteRecordId: inviteRec.Id,
|
||||
InviteKey: inv.InviteKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requestJoinRec := WrapAclRecord(requestJoin)
|
||||
fx.addRec(t, requestJoinRec)
|
||||
|
||||
// building request accept
|
||||
requestAccept, err := ownerAcl.RecordBuilder().BuildRequestAccept(RequestAcceptPayload{
|
||||
RequestRecordId: requestJoinRec.Id,
|
||||
Permissions: perms,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// validate
|
||||
err = ownerAcl.ValidateRawRecord(requestAccept)
|
||||
require.NoError(t, err)
|
||||
requestAcceptRec := WrapAclRecord(requestAccept)
|
||||
fx.addRec(t, requestAcceptRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).CanWrite())
|
||||
require.Equal(t, 0, len(ownerState.pendingRequests))
|
||||
require.Equal(t, 0, len(accountState.pendingRequests))
|
||||
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, accountState.Permissions(accountState.pubKey).CanWrite())
|
||||
|
||||
_, err = ownerState.StateAtRecord(requestJoinRec.Id, accountState.pubKey)
|
||||
require.Equal(t, ErrNoSuchAccount, err)
|
||||
stateAtRec, err := ownerState.StateAtRecord(requestAcceptRec.Id, accountState.pubKey)
|
||||
require.NoError(t, err)
|
||||
require.True(t, stateAtRec.Permissions == perms)
|
||||
}
|
||||
|
||||
func TestAclList_BuildRoot(t *testing.T) {
|
||||
randomKeys, err := accountdata.NewRandom()
|
||||
require.NoError(t, err)
|
||||
randomAcl, err := NewTestDerivedAcl("spaceId", randomKeys)
|
||||
require.NoError(t, err)
|
||||
fmt.Println(randomAcl.Id())
|
||||
}
|
||||
|
||||
func TestAclList_InvitePipeline(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
|
||||
}
|
||||
|
||||
func TestAclList_InviteRevoke(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
// building invite
|
||||
inv, err := fx.ownerAcl.RecordBuilder().BuildInvite()
|
||||
require.NoError(t, err)
|
||||
inviteRec := WrapAclRecord(inv.InviteRec)
|
||||
fx.addRec(t, inviteRec)
|
||||
|
||||
// building invite revoke
|
||||
inviteRevoke, err := fx.ownerAcl.RecordBuilder().BuildInviteRevoke(ownerState.lastRecordId)
|
||||
require.NoError(t, err)
|
||||
inviteRevokeRec := WrapAclRecord(inviteRevoke)
|
||||
fx.addRec(t, inviteRevokeRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.Empty(t, ownerState.inviteKeys)
|
||||
require.Empty(t, accountState.inviteKeys)
|
||||
}
|
||||
|
||||
func TestAclList_RequestDecline(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerAcl = fx.ownerAcl
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountAcl = fx.accountAcl
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
// building invite
|
||||
inv, err := ownerAcl.RecordBuilder().BuildInvite()
|
||||
require.NoError(t, err)
|
||||
inviteRec := WrapAclRecord(inv.InviteRec)
|
||||
fx.addRec(t, inviteRec)
|
||||
|
||||
// building request join
|
||||
requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
|
||||
InviteRecordId: inviteRec.Id,
|
||||
InviteKey: inv.InviteKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requestJoinRec := WrapAclRecord(requestJoin)
|
||||
fx.addRec(t, requestJoinRec)
|
||||
|
||||
// building request decline
|
||||
requestDecline, err := ownerAcl.RecordBuilder().BuildRequestDecline(ownerState.lastRecordId)
|
||||
require.NoError(t, err)
|
||||
requestDeclineRec := WrapAclRecord(requestDecline)
|
||||
fx.addRec(t, requestDeclineRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.Empty(t, ownerState.pendingRequests)
|
||||
require.Empty(t, accountState.pendingRequests)
|
||||
}
|
||||
|
||||
func TestAclList_Remove(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
|
||||
|
||||
newReadKey := crypto.NewAES()
|
||||
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
|
||||
Identities: []crypto.PubKey{fx.accountKeys.SignKey.GetPublic()},
|
||||
ReadKey: newReadKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeRec := WrapAclRecord(remove)
|
||||
fx.addRec(t, removeRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
|
||||
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
|
||||
require.Equal(t, 0, len(ownerState.pendingRequests))
|
||||
require.Equal(t, 0, len(accountState.pendingRequests))
|
||||
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.Nil(t, accountState.userReadKeys[removeRec.Id])
|
||||
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
|
||||
}
|
||||
|
||||
func TestAclList_ReadKeyChange(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
|
||||
|
||||
newReadKey := crypto.NewAES()
|
||||
readKeyChange, err := fx.ownerAcl.RecordBuilder().BuildReadKeyChange(newReadKey)
|
||||
require.NoError(t, err)
|
||||
readKeyRec := WrapAclRecord(readKeyChange)
|
||||
fx.addRec(t, readKeyRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).CanManageAccounts())
|
||||
require.True(t, ownerState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
|
||||
require.True(t, accountState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
|
||||
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
|
||||
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
|
||||
readKey, err := ownerState.CurrentReadKey()
|
||||
require.NoError(t, err)
|
||||
require.True(t, newReadKey.Equals(readKey))
|
||||
require.Equal(t, 0, len(ownerState.pendingRequests))
|
||||
require.Equal(t, 0, len(accountState.pendingRequests))
|
||||
}
|
||||
|
||||
func TestAclList_PermissionChange(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
|
||||
|
||||
permissionChange, err := fx.ownerAcl.RecordBuilder().BuildPermissionChange(PermissionChangePayload{
|
||||
Identity: fx.accountKeys.SignKey.GetPublic(),
|
||||
Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Writer),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
permissionChangeRec := WrapAclRecord(permissionChange)
|
||||
fx.addRec(t, permissionChangeRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
|
||||
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, accountState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
|
||||
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
|
||||
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
|
||||
require.Equal(t, 0, len(ownerState.pendingRequests))
|
||||
require.Equal(t, 0, len(accountState.pendingRequests))
|
||||
}
|
||||
|
||||
func TestAclList_RequestRemove(t *testing.T) {
|
||||
fx := newFixture(t)
|
||||
var (
|
||||
ownerState = fx.ownerAcl.aclState
|
||||
accountState = fx.accountAcl.aclState
|
||||
)
|
||||
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
|
||||
|
||||
removeRequest, err := fx.accountAcl.RecordBuilder().BuildRequestRemove()
|
||||
require.NoError(t, err)
|
||||
removeRequestRec := WrapAclRecord(removeRequest)
|
||||
fx.addRec(t, removeRequestRec)
|
||||
|
||||
recs := fx.accountAcl.AclState().RemoveRecords()
|
||||
require.Len(t, recs, 1)
|
||||
require.True(t, accountState.pubKey.Equals(recs[0].RequestIdentity))
|
||||
|
||||
newReadKey := crypto.NewAES()
|
||||
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
|
||||
Identities: []crypto.PubKey{recs[0].RequestIdentity},
|
||||
ReadKey: newReadKey,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
removeRec := WrapAclRecord(remove)
|
||||
fx.addRec(t, removeRec)
|
||||
|
||||
// checking acl state
|
||||
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
|
||||
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
|
||||
require.Equal(t, 0, len(ownerState.pendingRequests))
|
||||
require.Equal(t, 0, len(accountState.pendingRequests))
|
||||
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
|
||||
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
|
||||
require.Nil(t, accountState.userReadKeys[removeRec.Id])
|
||||
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
|
||||
}
|
||||
|
||||
41
commonspace/object/acl/list/listutils.go
Normal file
41
commonspace/object/acl/list/listutils.go
Normal file
@ -0,0 +1,41 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
func NewTestDerivedAcl(spaceId string, keys *accountdata.AccountKeys) (AclList, error) {
|
||||
builder := NewAclRecordBuilder("", crypto.NewKeyStorage(), keys, NoOpAcceptorVerifier{})
|
||||
masterKey, _, err := crypto.GenerateRandomEd25519KeyPair()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root, err := builder.BuildRoot(RootContent{
|
||||
PrivKey: keys.SignKey,
|
||||
SpaceId: spaceId,
|
||||
MasterKey: masterKey,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
|
||||
root,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
|
||||
}
|
||||
|
||||
func NewTestAclWithRoot(keys *accountdata.AccountKeys, root *consensusproto.RawRecordWithId) (AclList, error) {
|
||||
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
|
||||
root,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
|
||||
}
|
||||
@ -1,15 +1,17 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/object/acl/list (interfaces: AclList)
|
||||
// Source: github.com/anyproto/any-sync/commonspace/object/acl/list (interfaces: AclList)
|
||||
|
||||
// Package mock_list is a generated GoMock package.
|
||||
package mock_list
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
aclrecordproto "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
list "github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
list "github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
crypto "github.com/anyproto/any-sync/util/crypto"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockAclList is a mock of AclList interface.
|
||||
@ -50,12 +52,11 @@ func (mr *MockAclListMockRecorder) AclState() *gomock.Call {
|
||||
}
|
||||
|
||||
// AddRawRecord mocks base method.
|
||||
func (m *MockAclList) AddRawRecord(arg0 *aclrecordproto.RawAclRecordWithId) (bool, error) {
|
||||
func (m *MockAclList) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddRawRecord", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddRawRecord indicates an expected call of AddRawRecord.
|
||||
@ -64,18 +65,32 @@ func (mr *MockAclListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockAclList) Close() error {
|
||||
// AddRawRecords mocks base method.
|
||||
func (m *MockAclList) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close")
|
||||
ret := m.ctrl.Call(m, "AddRawRecords", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddRawRecords indicates an expected call of AddRawRecords.
|
||||
func (mr *MockAclListMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockAclList)(nil).AddRawRecords), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockAclList) Close(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockAclListMockRecorder) Close() *gomock.Call {
|
||||
func (mr *MockAclListMockRecorder) Close(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close))
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close), arg0)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
@ -93,6 +108,35 @@ func (mr *MockAclListMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAclList)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetIndex mocks base method.
|
||||
func (m *MockAclList) GetIndex(arg0 int) (*list.AclRecord, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetIndex", arg0)
|
||||
ret0, _ := ret[0].(*list.AclRecord)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetIndex indicates an expected call of GetIndex.
|
||||
func (mr *MockAclListMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockAclList)(nil).GetIndex), arg0)
|
||||
}
|
||||
|
||||
// HasHead mocks base method.
|
||||
func (m *MockAclList) HasHead(arg0 string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HasHead", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HasHead indicates an expected call of HasHead.
|
||||
func (mr *MockAclListMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockAclList)(nil).HasHead), arg0)
|
||||
}
|
||||
|
||||
// Head mocks base method.
|
||||
func (m *MockAclList) Head() *list.AclRecord {
|
||||
m.ctrl.T.Helper()
|
||||
@ -160,6 +204,20 @@ func (mr *MockAclListMockRecorder) IterateFrom(arg0, arg1 interface{}) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockAclList)(nil).IterateFrom), arg0, arg1)
|
||||
}
|
||||
|
||||
// KeyStorage mocks base method.
|
||||
func (m *MockAclList) KeyStorage() crypto.KeyStorage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeyStorage")
|
||||
ret0, _ := ret[0].(crypto.KeyStorage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// KeyStorage indicates an expected call of KeyStorage.
|
||||
func (mr *MockAclListMockRecorder) KeyStorage() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyStorage", reflect.TypeOf((*MockAclList)(nil).KeyStorage))
|
||||
}
|
||||
|
||||
// Lock mocks base method.
|
||||
func (m *MockAclList) Lock() {
|
||||
m.ctrl.T.Helper()
|
||||
@ -196,6 +254,20 @@ func (mr *MockAclListMockRecorder) RUnlock() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockAclList)(nil).RUnlock))
|
||||
}
|
||||
|
||||
// RecordBuilder mocks base method.
|
||||
func (m *MockAclList) RecordBuilder() list.AclRecordBuilder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RecordBuilder")
|
||||
ret0, _ := ret[0].(list.AclRecordBuilder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RecordBuilder indicates an expected call of RecordBuilder.
|
||||
func (mr *MockAclListMockRecorder) RecordBuilder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockAclList)(nil).RecordBuilder))
|
||||
}
|
||||
|
||||
// Records mocks base method.
|
||||
func (m *MockAclList) Records() []*list.AclRecord {
|
||||
m.ctrl.T.Helper()
|
||||
@ -210,11 +282,26 @@ func (mr *MockAclListMockRecorder) Records() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records))
|
||||
}
|
||||
|
||||
// RecordsAfter mocks base method.
|
||||
func (m *MockAclList) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
|
||||
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RecordsAfter indicates an expected call of RecordsAfter.
|
||||
func (mr *MockAclListMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockAclList)(nil).RecordsAfter), arg0, arg1)
|
||||
}
|
||||
|
||||
// Root mocks base method.
|
||||
func (m *MockAclList) Root() *aclrecordproto.RawAclRecordWithId {
|
||||
func (m *MockAclList) Root() *consensusproto.RawRecordWithId {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Root")
|
||||
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId)
|
||||
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
|
||||
return ret0
|
||||
}
|
||||
|
||||
@ -235,3 +322,17 @@ func (mr *MockAclListMockRecorder) Unlock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockAclList)(nil).Unlock))
|
||||
}
|
||||
|
||||
// ValidateRawRecord mocks base method.
|
||||
func (m *MockAclList) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
|
||||
func (mr *MockAclListMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockAclList)(nil).ValidateRawRecord), arg0)
|
||||
}
|
||||
|
||||
69
commonspace/object/acl/list/models.go
Normal file
69
commonspace/object/acl/list/models.go
Normal file
@ -0,0 +1,69 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type AclRecord struct {
|
||||
Id string
|
||||
PrevId string
|
||||
Timestamp int64
|
||||
Data []byte
|
||||
Identity crypto.PubKey
|
||||
Model interface{}
|
||||
Signature []byte
|
||||
}
|
||||
|
||||
type RequestRecord struct {
|
||||
RequestIdentity crypto.PubKey
|
||||
RequestMetadata []byte
|
||||
Type RequestType
|
||||
}
|
||||
|
||||
type AclUserState struct {
|
||||
PubKey crypto.PubKey
|
||||
Permissions AclPermissions
|
||||
RequestMetadata []byte
|
||||
}
|
||||
|
||||
type RequestType int
|
||||
|
||||
const (
|
||||
RequestTypeRemove RequestType = iota
|
||||
RequestTypeJoin
|
||||
)
|
||||
|
||||
type AclPermissions aclrecordproto.AclUserPermissions
|
||||
|
||||
func (p AclPermissions) NoPermissions() bool {
|
||||
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_None
|
||||
}
|
||||
|
||||
func (p AclPermissions) IsOwner() bool {
|
||||
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_Owner
|
||||
}
|
||||
|
||||
func (p AclPermissions) CanWrite() bool {
|
||||
switch aclrecordproto.AclUserPermissions(p) {
|
||||
case aclrecordproto.AclUserPermissions_Admin:
|
||||
return true
|
||||
case aclrecordproto.AclUserPermissions_Writer:
|
||||
return true
|
||||
case aclrecordproto.AclUserPermissions_Owner:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (p AclPermissions) CanManageAccounts() bool {
|
||||
switch aclrecordproto.AclUserPermissions(p) {
|
||||
case aclrecordproto.AclUserPermissions_Admin:
|
||||
return true
|
||||
case aclrecordproto.AclUserPermissions_Owner:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -1,12 +0,0 @@
|
||||
package list
|
||||
|
||||
type AclRecord struct {
|
||||
Id string
|
||||
PrevId string
|
||||
CurrentReadKeyHash uint64
|
||||
Timestamp int64
|
||||
Data []byte
|
||||
Identity []byte
|
||||
Model interface{}
|
||||
Signature []byte
|
||||
}
|
||||
218
commonspace/object/acl/list/validator.go
Normal file
218
commonspace/object/acl/list/validator.go
Normal file
@ -0,0 +1,218 @@
|
||||
package list
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type ContentValidator interface {
|
||||
ValidateAclRecordContents(ch *AclRecord) (err error)
|
||||
ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error)
|
||||
ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error)
|
||||
}
|
||||
|
||||
type contentValidator struct {
|
||||
keyStore crypto.KeyStorage
|
||||
aclState *AclState
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateAclRecordContents(ch *AclRecord) (err error) {
|
||||
if ch.PrevId != c.aclState.lastRecordId {
|
||||
return ErrIncorrectRecordSequence
|
||||
}
|
||||
aclData := ch.Model.(*aclrecordproto.AclData)
|
||||
for _, content := range aclData.AclContent {
|
||||
err = c.validateAclRecordContent(content, ch.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) validateAclRecordContent(ch *aclrecordproto.AclContentValue, authorIdentity crypto.PubKey) (err error) {
|
||||
switch {
|
||||
case ch.GetPermissionChange() != nil:
|
||||
return c.ValidatePermissionChange(ch.GetPermissionChange(), authorIdentity)
|
||||
case ch.GetInvite() != nil:
|
||||
return c.ValidateInvite(ch.GetInvite(), authorIdentity)
|
||||
case ch.GetInviteRevoke() != nil:
|
||||
return c.ValidateInviteRevoke(ch.GetInviteRevoke(), authorIdentity)
|
||||
case ch.GetRequestJoin() != nil:
|
||||
return c.ValidateRequestJoin(ch.GetRequestJoin(), authorIdentity)
|
||||
case ch.GetRequestAccept() != nil:
|
||||
return c.ValidateRequestAccept(ch.GetRequestAccept(), authorIdentity)
|
||||
case ch.GetRequestDecline() != nil:
|
||||
return c.ValidateRequestDecline(ch.GetRequestDecline(), authorIdentity)
|
||||
case ch.GetAccountRemove() != nil:
|
||||
return c.ValidateAccountRemove(ch.GetAccountRemove(), authorIdentity)
|
||||
case ch.GetAccountRequestRemove() != nil:
|
||||
return c.ValidateRequestRemove(ch.GetAccountRequestRemove(), authorIdentity)
|
||||
case ch.GetReadKeyChange() != nil:
|
||||
return c.ValidateReadKeyChange(ch.GetReadKeyChange(), authorIdentity)
|
||||
default:
|
||||
return ErrUnexpectedContentType
|
||||
}
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
chIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, exists := c.aclState.userStates[mapKeyFromPubKey(chIdentity)]
|
||||
if !exists {
|
||||
return ErrNoSuchAccount
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
_, err = c.keyStore.PubKeyFromProto(ch.InviteKey)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
_, exists := c.aclState.inviteKeys[ch.InviteRecordId]
|
||||
if !exists {
|
||||
return ErrNoSuchInvite
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error) {
|
||||
inviteKey, exists := c.aclState.inviteKeys[ch.InviteRecordId]
|
||||
if !exists {
|
||||
return ErrNoSuchInvite
|
||||
}
|
||||
inviteIdentity, err := c.keyStore.PubKeyFromProto(ch.InviteIdentity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(inviteIdentity)]; exists {
|
||||
return ErrPendingRequest
|
||||
}
|
||||
if !authorIdentity.Equals(inviteIdentity) {
|
||||
return ErrIncorrectIdentity
|
||||
}
|
||||
rawInviteIdentity, err := inviteIdentity.Raw()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := inviteKey.Verify(rawInviteIdentity, ch.InviteIdentitySignature)
|
||||
if err != nil {
|
||||
return ErrInvalidSignature
|
||||
}
|
||||
if !ok {
|
||||
return ErrInvalidSignature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
record, exists := c.aclState.requestRecords[ch.RequestRecordId]
|
||||
if !exists {
|
||||
return ErrNoSuchRequest
|
||||
}
|
||||
acceptIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !acceptIdentity.Equals(record.RequestIdentity) {
|
||||
return ErrIncorrectIdentity
|
||||
}
|
||||
if ch.Permissions == aclrecordproto.AclUserPermissions_Owner {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
_, exists := c.aclState.requestRecords[ch.RequestRecordId]
|
||||
if !exists {
|
||||
return ErrNoSuchRequest
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error) {
|
||||
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
seenIdentities := map[string]struct{}{}
|
||||
for _, rawIdentity := range ch.Identities {
|
||||
identity, err := c.keyStore.PubKeyFromProto(rawIdentity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if identity.Equals(authorIdentity) {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
permissions := c.aclState.Permissions(identity)
|
||||
if permissions.NoPermissions() {
|
||||
return ErrNoSuchAccount
|
||||
}
|
||||
if permissions.IsOwner() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
idKey := mapKeyFromPubKey(identity)
|
||||
if _, exists := seenIdentities[idKey]; exists {
|
||||
return ErrDuplicateAccounts
|
||||
}
|
||||
seenIdentities[mapKeyFromPubKey(identity)] = struct{}{}
|
||||
}
|
||||
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates)-len(ch.Identities))
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error) {
|
||||
if c.aclState.Permissions(authorIdentity).NoPermissions() {
|
||||
return ErrInsufficientPermissions
|
||||
}
|
||||
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(authorIdentity)]; exists {
|
||||
return ErrPendingRequest
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (c *contentValidator) ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error) {
|
||||
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates))
|
||||
}
|
||||
|
||||
func (c *contentValidator) validateAccountReadKeys(accountKeys []*aclrecordproto.AclEncryptedReadKey, usersNum int) (err error) {
|
||||
if len(accountKeys) != usersNum {
|
||||
return ErrIncorrectNumberOfAccounts
|
||||
}
|
||||
for _, encKeys := range accountKeys {
|
||||
identity, err := c.keyStore.PubKeyFromProto(encKeys.Identity)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, exists := c.aclState.userStates[mapKeyFromPubKey(identity)]
|
||||
if !exists {
|
||||
return ErrNoSuchAccount
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -3,24 +3,26 @@ package liststorage
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
|
||||
"sync"
|
||||
)
|
||||
|
||||
type inMemoryAclListStorage struct {
|
||||
id string
|
||||
root *aclrecordproto.RawAclRecordWithId
|
||||
root *consensusproto.RawRecordWithId
|
||||
head string
|
||||
records map[string]*aclrecordproto.RawAclRecordWithId
|
||||
records map[string]*consensusproto.RawRecordWithId
|
||||
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
func NewInMemoryAclListStorage(
|
||||
id string,
|
||||
records []*aclrecordproto.RawAclRecordWithId) (ListStorage, error) {
|
||||
records []*consensusproto.RawRecordWithId) (ListStorage, error) {
|
||||
|
||||
allRecords := make(map[string]*aclrecordproto.RawAclRecordWithId)
|
||||
allRecords := make(map[string]*consensusproto.RawRecordWithId)
|
||||
for _, ch := range records {
|
||||
allRecords[ch.Id] = ch
|
||||
}
|
||||
@ -41,7 +43,7 @@ func (t *inMemoryAclListStorage) Id() string {
|
||||
return t.id
|
||||
}
|
||||
|
||||
func (t *inMemoryAclListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) {
|
||||
func (t *inMemoryAclListStorage) Root() (*consensusproto.RawRecordWithId, error) {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
return t.root, nil
|
||||
@ -60,7 +62,7 @@ func (t *inMemoryAclListStorage) SetHead(head string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclrecordproto.RawAclRecordWithId) error {
|
||||
func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *consensusproto.RawRecordWithId) error {
|
||||
t.Lock()
|
||||
defer t.Unlock()
|
||||
// TODO: better to do deep copy
|
||||
@ -68,7 +70,7 @@ func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclre
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*aclrecordproto.RawAclRecordWithId, error) {
|
||||
func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*consensusproto.RawRecordWithId, error) {
|
||||
t.RLock()
|
||||
defer t.RUnlock()
|
||||
if res, exists := t.records[recordId]; exists {
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
//go:generate mockgen -destination mock_liststorage/mock_liststorage.go github.com/anytypeio/any-sync/commonspace/object/acl/liststorage ListStorage
|
||||
//go:generate mockgen -destination mock_liststorage/mock_liststorage.go github.com/anyproto/any-sync/commonspace/object/acl/liststorage ListStorage
|
||||
package liststorage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -13,12 +14,16 @@ var (
|
||||
ErrUnknownRecord = errors.New("record doesn't exist")
|
||||
)
|
||||
|
||||
type Exporter interface {
|
||||
ListStorage(root *consensusproto.RawRecordWithId) (ListStorage, error)
|
||||
}
|
||||
|
||||
type ListStorage interface {
|
||||
Id() string
|
||||
Root() (*aclrecordproto.RawAclRecordWithId, error)
|
||||
Root() (*consensusproto.RawRecordWithId, error)
|
||||
Head() (string, error)
|
||||
SetHead(headId string) error
|
||||
|
||||
GetRawRecord(ctx context.Context, id string) (*aclrecordproto.RawAclRecordWithId, error)
|
||||
AddRawRecord(ctx context.Context, rec *aclrecordproto.RawAclRecordWithId) error
|
||||
GetRawRecord(ctx context.Context, id string) (*consensusproto.RawRecordWithId, error)
|
||||
AddRawRecord(ctx context.Context, rec *consensusproto.RawRecordWithId) error
|
||||
}
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/object/acl/liststorage (interfaces: ListStorage)
|
||||
// Source: github.com/anyproto/any-sync/commonspace/object/acl/liststorage (interfaces: ListStorage)
|
||||
|
||||
// Package mock_liststorage is a generated GoMock package.
|
||||
package mock_liststorage
|
||||
@ -8,8 +8,8 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
aclrecordproto "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockListStorage is a mock of ListStorage interface.
|
||||
@ -36,7 +36,7 @@ func (m *MockListStorage) EXPECT() *MockListStorageMockRecorder {
|
||||
}
|
||||
|
||||
// AddRawRecord mocks base method.
|
||||
func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *aclrecordproto.RawAclRecordWithId) error {
|
||||
func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *consensusproto.RawRecordWithId) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddRawRecord", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
@ -50,10 +50,10 @@ func (mr *MockListStorageMockRecorder) AddRawRecord(arg0, arg1 interface{}) *gom
|
||||
}
|
||||
|
||||
// GetRawRecord mocks base method.
|
||||
func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*aclrecordproto.RawAclRecordWithId, error) {
|
||||
func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*consensusproto.RawRecordWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetRawRecord", arg0, arg1)
|
||||
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId)
|
||||
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
@ -94,10 +94,10 @@ func (mr *MockListStorageMockRecorder) Id() *gomock.Call {
|
||||
}
|
||||
|
||||
// Root mocks base method.
|
||||
func (m *MockListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) {
|
||||
func (m *MockListStorage) Root() (*consensusproto.RawRecordWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Root")
|
||||
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId)
|
||||
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
120
commonspace/object/acl/syncacl/aclsyncprotocol.go
Normal file
120
commonspace/object/acl/syncacl/aclsyncprotocol.go
Normal file
@ -0,0 +1,120 @@
|
||||
//go:generate mockgen -destination mock_syncacl/mock_syncacl.go github.com/anyproto/any-sync/commonspace/object/acl/syncacl SyncAcl,SyncClient,RequestFactory,AclSyncProtocol
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type AclSyncProtocol interface {
|
||||
HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error)
|
||||
FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error)
|
||||
FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error)
|
||||
}
|
||||
|
||||
type aclSyncProtocol struct {
|
||||
log logger.CtxLogger
|
||||
spaceId string
|
||||
aclList list.AclList
|
||||
reqFactory RequestFactory
|
||||
}
|
||||
|
||||
func (a *aclSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error) {
|
||||
isEmptyUpdate := len(update.Records) == 0
|
||||
log := a.log.With(
|
||||
zap.String("senderId", senderId),
|
||||
zap.String("update head", update.Head),
|
||||
zap.Int("len(update records)", len(update.Records)))
|
||||
log.DebugCtx(ctx, "received acl head update message")
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ErrorCtx(ctx, "acl head update finished with error", zap.Error(err))
|
||||
} else if request != nil {
|
||||
cnt := request.Content.GetFullSyncRequest()
|
||||
log.DebugCtx(ctx, "returning acl full sync request", zap.String("request head", cnt.Head))
|
||||
} else {
|
||||
if !isEmptyUpdate {
|
||||
log.DebugCtx(ctx, "acl head update finished correctly")
|
||||
}
|
||||
}
|
||||
}()
|
||||
if isEmptyUpdate {
|
||||
headEquals := a.aclList.Head().Id == update.Head
|
||||
log.DebugCtx(ctx, "is empty acl head update", zap.Bool("headEquals", headEquals))
|
||||
if headEquals {
|
||||
return
|
||||
}
|
||||
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
|
||||
}
|
||||
if a.aclList.HasHead(update.Head) {
|
||||
return
|
||||
}
|
||||
err = a.aclList.AddRawRecords(update.Records)
|
||||
if err == list.ErrIncorrectRecordSequence {
|
||||
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *aclSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error) {
|
||||
log := a.log.With(
|
||||
zap.String("senderId", senderId),
|
||||
zap.String("request head", request.Head),
|
||||
zap.Int("len(request records)", len(request.Records)))
|
||||
log.DebugCtx(ctx, "received acl full sync request message")
|
||||
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ErrorCtx(ctx, "acl full sync request finished with error", zap.Error(err))
|
||||
} else if response != nil {
|
||||
cnt := response.Content.GetFullSyncResponse()
|
||||
log.DebugCtx(ctx, "acl full sync response sent", zap.String("response head", cnt.Head), zap.Int("len(response records)", len(cnt.Records)))
|
||||
}
|
||||
}()
|
||||
if !a.aclList.HasHead(request.Head) {
|
||||
if len(request.Records) > 0 {
|
||||
// in this case we can try to add some records
|
||||
err = a.aclList.AddRawRecords(request.Records)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
// here it is impossible for us to do anything, we can't return records after head as defined in request, because we don't have it
|
||||
return nil, list.ErrIncorrectRecordSequence
|
||||
}
|
||||
}
|
||||
return a.reqFactory.CreateFullSyncResponse(a.aclList, request.Head)
|
||||
}
|
||||
|
||||
func (a *aclSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error) {
|
||||
log := a.log.With(
|
||||
zap.String("senderId", senderId),
|
||||
zap.String("response head", response.Head),
|
||||
zap.Int("len(response records)", len(response.Records)))
|
||||
log.DebugCtx(ctx, "received acl full sync response message")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
log.ErrorCtx(ctx, "acl full sync response failed", zap.Error(err))
|
||||
} else {
|
||||
log.DebugCtx(ctx, "acl full sync response succeeded")
|
||||
}
|
||||
}()
|
||||
if a.aclList.HasHead(response.Head) {
|
||||
return
|
||||
}
|
||||
return a.aclList.AddRawRecords(response.Records)
|
||||
}
|
||||
|
||||
func newAclSyncProtocol(spaceId string, aclList list.AclList, reqFactory RequestFactory) *aclSyncProtocol {
|
||||
return &aclSyncProtocol{
|
||||
log: log.With(zap.String("spaceId", spaceId), zap.String("aclId", aclList.Id())),
|
||||
spaceId: spaceId,
|
||||
aclList: aclList,
|
||||
reqFactory: reqFactory,
|
||||
}
|
||||
}
|
||||
213
commonspace/object/acl/syncacl/aclsyncprotocol_test.go
Normal file
213
commonspace/object/acl/syncacl/aclsyncprotocol_test.go
Normal file
@ -0,0 +1,213 @@
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type aclSyncProtocolFixture struct {
|
||||
log logger.CtxLogger
|
||||
spaceId string
|
||||
senderId string
|
||||
aclId string
|
||||
aclMock *mock_list.MockAclList
|
||||
reqFactory *mock_syncacl.MockRequestFactory
|
||||
ctrl *gomock.Controller
|
||||
syncProtocol AclSyncProtocol
|
||||
}
|
||||
|
||||
func newSyncProtocolFixture(t *testing.T) *aclSyncProtocolFixture {
|
||||
ctrl := gomock.NewController(t)
|
||||
aclList := mock_list.NewMockAclList(ctrl)
|
||||
spaceId := "spaceId"
|
||||
reqFactory := mock_syncacl.NewMockRequestFactory(ctrl)
|
||||
aclList.EXPECT().Id().Return("aclId")
|
||||
syncProtocol := newAclSyncProtocol(spaceId, aclList, reqFactory)
|
||||
return &aclSyncProtocolFixture{
|
||||
log: log,
|
||||
spaceId: spaceId,
|
||||
senderId: "senderId",
|
||||
aclId: "aclId",
|
||||
aclMock: aclList,
|
||||
reqFactory: reqFactory,
|
||||
ctrl: ctrl,
|
||||
syncProtocol: syncProtocol,
|
||||
}
|
||||
}
|
||||
|
||||
func (fx *aclSyncProtocolFixture) stop() {
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestHeadUpdate(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fullRequest := &consensusproto.LogSyncMessage{
|
||||
Content: &consensusproto.LogSyncContentValue{
|
||||
Value: &consensusproto.LogSyncContentValue_FullSyncRequest{
|
||||
FullSyncRequest: &consensusproto.LogFullSyncRequest{},
|
||||
},
|
||||
},
|
||||
}
|
||||
t.Run("head update non empty all heads added", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(false)
|
||||
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(nil)
|
||||
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
|
||||
require.Nil(t, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("head update results in full request", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(false)
|
||||
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(list.ErrIncorrectRecordSequence)
|
||||
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
|
||||
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
|
||||
require.Equal(t, fullRequest, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("head update old heads", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(true)
|
||||
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
|
||||
require.Nil(t, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("head update empty equals", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h1"})
|
||||
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
|
||||
require.Nil(t, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("head update empty results in full request", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h2"})
|
||||
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
|
||||
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
|
||||
require.Equal(t, fullRequest, req)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFullSyncRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
fullResponse := &consensusproto.LogSyncMessage{
|
||||
Content: &consensusproto.LogSyncContentValue{
|
||||
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
|
||||
FullSyncResponse: &consensusproto.LogFullSyncResponse{},
|
||||
},
|
||||
},
|
||||
}
|
||||
t.Run("full sync request non empty all heads added", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullRequest := &consensusproto.LogFullSyncRequest{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(false)
|
||||
fx.aclMock.EXPECT().AddRawRecords(fullRequest.Records).Return(nil)
|
||||
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
|
||||
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
|
||||
require.Equal(t, fullResponse, resp)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("full sync request non empty head exists", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullRequest := &consensusproto.LogFullSyncRequest{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(true)
|
||||
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
|
||||
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
|
||||
require.Equal(t, fullResponse, resp)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("full sync request empty head not exists", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
fullRequest := &consensusproto.LogFullSyncRequest{
|
||||
Head: "h1",
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(false)
|
||||
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
|
||||
require.Nil(t, resp)
|
||||
require.Error(t, list.ErrIncorrectRecordSequence, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestFullSyncResponse(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("full sync response no heads", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullResponse := &consensusproto.LogFullSyncResponse{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(false)
|
||||
fx.aclMock.EXPECT().AddRawRecords(fullResponse.Records).Return(nil)
|
||||
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("full sync response has heads", func(t *testing.T) {
|
||||
fx := newSyncProtocolFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullResponse := &consensusproto.LogFullSyncResponse{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.aclMock.EXPECT().HasHead("h1").Return(true)
|
||||
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,5 @@
|
||||
package headupdater
|
||||
|
||||
type HeadUpdater interface {
|
||||
UpdateHeads(id string, heads []string)
|
||||
}
|
||||
694
commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go
Normal file
694
commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go
Normal file
@ -0,0 +1,694 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anyproto/any-sync/commonspace/object/acl/syncacl (interfaces: SyncAcl,SyncClient,RequestFactory,AclSyncProtocol)
|
||||
|
||||
// Package mock_syncacl is a generated GoMock package.
|
||||
package mock_syncacl
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
app "github.com/anyproto/any-sync/app"
|
||||
list "github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
headupdater "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
|
||||
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
crypto "github.com/anyproto/any-sync/util/crypto"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockSyncAcl is a mock of SyncAcl interface.
|
||||
type MockSyncAcl struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSyncAclMockRecorder
|
||||
}
|
||||
|
||||
// MockSyncAclMockRecorder is the mock recorder for MockSyncAcl.
|
||||
type MockSyncAclMockRecorder struct {
|
||||
mock *MockSyncAcl
|
||||
}
|
||||
|
||||
// NewMockSyncAcl creates a new mock instance.
|
||||
func NewMockSyncAcl(ctrl *gomock.Controller) *MockSyncAcl {
|
||||
mock := &MockSyncAcl{ctrl: ctrl}
|
||||
mock.recorder = &MockSyncAclMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSyncAcl) EXPECT() *MockSyncAclMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AclState mocks base method.
|
||||
func (m *MockSyncAcl) AclState() *list.AclState {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AclState")
|
||||
ret0, _ := ret[0].(*list.AclState)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AclState indicates an expected call of AclState.
|
||||
func (mr *MockSyncAclMockRecorder) AclState() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AclState", reflect.TypeOf((*MockSyncAcl)(nil).AclState))
|
||||
}
|
||||
|
||||
// AddRawRecord mocks base method.
|
||||
func (m *MockSyncAcl) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddRawRecord", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddRawRecord indicates an expected call of AddRawRecord.
|
||||
func (mr *MockSyncAclMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecord), arg0)
|
||||
}
|
||||
|
||||
// AddRawRecords mocks base method.
|
||||
func (m *MockSyncAcl) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AddRawRecords", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AddRawRecords indicates an expected call of AddRawRecords.
|
||||
func (mr *MockSyncAclMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecords), arg0)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockSyncAcl) Close(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Close", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Close indicates an expected call of Close.
|
||||
func (mr *MockSyncAclMockRecorder) Close(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSyncAcl)(nil).Close), arg0)
|
||||
}
|
||||
|
||||
// Get mocks base method.
|
||||
func (m *MockSyncAcl) Get(arg0 string) (*list.AclRecord, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Get", arg0)
|
||||
ret0, _ := ret[0].(*list.AclRecord)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Get indicates an expected call of Get.
|
||||
func (mr *MockSyncAclMockRecorder) Get(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSyncAcl)(nil).Get), arg0)
|
||||
}
|
||||
|
||||
// GetIndex mocks base method.
|
||||
func (m *MockSyncAcl) GetIndex(arg0 int) (*list.AclRecord, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetIndex", arg0)
|
||||
ret0, _ := ret[0].(*list.AclRecord)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetIndex indicates an expected call of GetIndex.
|
||||
func (mr *MockSyncAclMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockSyncAcl)(nil).GetIndex), arg0)
|
||||
}
|
||||
|
||||
// HandleMessage mocks base method.
|
||||
func (m *MockSyncAcl) HandleMessage(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HandleMessage indicates an expected call of HandleMessage.
|
||||
func (mr *MockSyncAclMockRecorder) HandleMessage(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncAcl)(nil).HandleMessage), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// HandleRequest mocks base method.
|
||||
func (m *MockSyncAcl) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// HandleRequest indicates an expected call of HandleRequest.
|
||||
func (mr *MockSyncAclMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncAcl)(nil).HandleRequest), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// HasHead mocks base method.
|
||||
func (m *MockSyncAcl) HasHead(arg0 string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HasHead", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// HasHead indicates an expected call of HasHead.
|
||||
func (mr *MockSyncAclMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockSyncAcl)(nil).HasHead), arg0)
|
||||
}
|
||||
|
||||
// Head mocks base method.
|
||||
func (m *MockSyncAcl) Head() *list.AclRecord {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Head")
|
||||
ret0, _ := ret[0].(*list.AclRecord)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Head indicates an expected call of Head.
|
||||
func (mr *MockSyncAclMockRecorder) Head() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Head", reflect.TypeOf((*MockSyncAcl)(nil).Head))
|
||||
}
|
||||
|
||||
// Id mocks base method.
|
||||
func (m *MockSyncAcl) Id() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Id")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Id indicates an expected call of Id.
|
||||
func (mr *MockSyncAclMockRecorder) Id() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSyncAcl)(nil).Id))
|
||||
}
|
||||
|
||||
// Init mocks base method.
|
||||
func (m *MockSyncAcl) Init(arg0 *app.App) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Init", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Init indicates an expected call of Init.
|
||||
func (mr *MockSyncAclMockRecorder) Init(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockSyncAcl)(nil).Init), arg0)
|
||||
}
|
||||
|
||||
// IsAfter mocks base method.
|
||||
func (m *MockSyncAcl) IsAfter(arg0, arg1 string) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsAfter", arg0, arg1)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsAfter indicates an expected call of IsAfter.
|
||||
func (mr *MockSyncAclMockRecorder) IsAfter(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAfter", reflect.TypeOf((*MockSyncAcl)(nil).IsAfter), arg0, arg1)
|
||||
}
|
||||
|
||||
// Iterate mocks base method.
|
||||
func (m *MockSyncAcl) Iterate(arg0 func(*list.AclRecord) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Iterate", arg0)
|
||||
}
|
||||
|
||||
// Iterate indicates an expected call of Iterate.
|
||||
func (mr *MockSyncAclMockRecorder) Iterate(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockSyncAcl)(nil).Iterate), arg0)
|
||||
}
|
||||
|
||||
// IterateFrom mocks base method.
|
||||
func (m *MockSyncAcl) IterateFrom(arg0 string, arg1 func(*list.AclRecord) bool) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "IterateFrom", arg0, arg1)
|
||||
}
|
||||
|
||||
// IterateFrom indicates an expected call of IterateFrom.
|
||||
func (mr *MockSyncAclMockRecorder) IterateFrom(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockSyncAcl)(nil).IterateFrom), arg0, arg1)
|
||||
}
|
||||
|
||||
// KeyStorage mocks base method.
|
||||
func (m *MockSyncAcl) KeyStorage() crypto.KeyStorage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "KeyStorage")
|
||||
ret0, _ := ret[0].(crypto.KeyStorage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// KeyStorage indicates an expected call of KeyStorage.
|
||||
func (mr *MockSyncAclMockRecorder) KeyStorage() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyStorage", reflect.TypeOf((*MockSyncAcl)(nil).KeyStorage))
|
||||
}
|
||||
|
||||
// Lock mocks base method.
|
||||
func (m *MockSyncAcl) Lock() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Lock")
|
||||
}
|
||||
|
||||
// Lock indicates an expected call of Lock.
|
||||
func (mr *MockSyncAclMockRecorder) Lock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockSyncAcl)(nil).Lock))
|
||||
}
|
||||
|
||||
// Name mocks base method.
|
||||
func (m *MockSyncAcl) Name() string {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Name")
|
||||
ret0, _ := ret[0].(string)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Name indicates an expected call of Name.
|
||||
func (mr *MockSyncAclMockRecorder) Name() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockSyncAcl)(nil).Name))
|
||||
}
|
||||
|
||||
// RLock mocks base method.
|
||||
func (m *MockSyncAcl) RLock() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RLock")
|
||||
}
|
||||
|
||||
// RLock indicates an expected call of RLock.
|
||||
func (mr *MockSyncAclMockRecorder) RLock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RLock", reflect.TypeOf((*MockSyncAcl)(nil).RLock))
|
||||
}
|
||||
|
||||
// RUnlock mocks base method.
|
||||
func (m *MockSyncAcl) RUnlock() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "RUnlock")
|
||||
}
|
||||
|
||||
// RUnlock indicates an expected call of RUnlock.
|
||||
func (mr *MockSyncAclMockRecorder) RUnlock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockSyncAcl)(nil).RUnlock))
|
||||
}
|
||||
|
||||
// RecordBuilder mocks base method.
|
||||
func (m *MockSyncAcl) RecordBuilder() list.AclRecordBuilder {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RecordBuilder")
|
||||
ret0, _ := ret[0].(list.AclRecordBuilder)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// RecordBuilder indicates an expected call of RecordBuilder.
|
||||
func (mr *MockSyncAclMockRecorder) RecordBuilder() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockSyncAcl)(nil).RecordBuilder))
|
||||
}
|
||||
|
||||
// Records mocks base method.
|
||||
func (m *MockSyncAcl) Records() []*list.AclRecord {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Records")
|
||||
ret0, _ := ret[0].([]*list.AclRecord)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Records indicates an expected call of Records.
|
||||
func (mr *MockSyncAclMockRecorder) Records() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockSyncAcl)(nil).Records))
|
||||
}
|
||||
|
||||
// RecordsAfter mocks base method.
|
||||
func (m *MockSyncAcl) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
|
||||
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// RecordsAfter indicates an expected call of RecordsAfter.
|
||||
func (mr *MockSyncAclMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockSyncAcl)(nil).RecordsAfter), arg0, arg1)
|
||||
}
|
||||
|
||||
// Root mocks base method.
|
||||
func (m *MockSyncAcl) Root() *consensusproto.RawRecordWithId {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Root")
|
||||
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Root indicates an expected call of Root.
|
||||
func (mr *MockSyncAclMockRecorder) Root() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Root", reflect.TypeOf((*MockSyncAcl)(nil).Root))
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockSyncAcl) Run(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockSyncAclMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSyncAcl)(nil).Run), arg0)
|
||||
}
|
||||
|
||||
// SetHeadUpdater mocks base method.
|
||||
func (m *MockSyncAcl) SetHeadUpdater(arg0 headupdater.HeadUpdater) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetHeadUpdater", arg0)
|
||||
}
|
||||
|
||||
// SetHeadUpdater indicates an expected call of SetHeadUpdater.
|
||||
func (mr *MockSyncAclMockRecorder) SetHeadUpdater(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeadUpdater", reflect.TypeOf((*MockSyncAcl)(nil).SetHeadUpdater), arg0)
|
||||
}
|
||||
|
||||
// SyncWithPeer mocks base method.
|
||||
func (m *MockSyncAcl) SyncWithPeer(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SyncWithPeer", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SyncWithPeer indicates an expected call of SyncWithPeer.
|
||||
func (mr *MockSyncAclMockRecorder) SyncWithPeer(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncWithPeer", reflect.TypeOf((*MockSyncAcl)(nil).SyncWithPeer), arg0, arg1)
|
||||
}
|
||||
|
||||
// Unlock mocks base method.
|
||||
func (m *MockSyncAcl) Unlock() {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Unlock")
|
||||
}
|
||||
|
||||
// Unlock indicates an expected call of Unlock.
|
||||
func (mr *MockSyncAclMockRecorder) Unlock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockSyncAcl)(nil).Unlock))
|
||||
}
|
||||
|
||||
// ValidateRawRecord mocks base method.
|
||||
func (m *MockSyncAcl) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
|
||||
func (mr *MockSyncAclMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).ValidateRawRecord), arg0)
|
||||
}
|
||||
|
||||
// MockSyncClient is a mock of SyncClient interface.
|
||||
type MockSyncClient struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockSyncClientMockRecorder
|
||||
}
|
||||
|
||||
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient.
|
||||
type MockSyncClientMockRecorder struct {
|
||||
mock *MockSyncClient
|
||||
}
|
||||
|
||||
// NewMockSyncClient creates a new mock instance.
|
||||
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient {
|
||||
mock := &MockSyncClient{ctrl: ctrl}
|
||||
mock.recorder = &MockSyncClientMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Broadcast mocks base method.
|
||||
func (m *MockSyncClient) Broadcast(arg0 *consensusproto.LogSyncMessage) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "Broadcast", arg0)
|
||||
}
|
||||
|
||||
// Broadcast indicates an expected call of Broadcast.
|
||||
func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0)
|
||||
}
|
||||
|
||||
// CreateFullSyncRequest mocks base method.
|
||||
func (m *MockSyncClient) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
|
||||
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// CreateFullSyncResponse mocks base method.
|
||||
func (m *MockSyncClient) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
|
||||
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1)
|
||||
}
|
||||
|
||||
// CreateHeadUpdate mocks base method.
|
||||
func (m *MockSyncClient) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
|
||||
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1)
|
||||
}
|
||||
|
||||
// QueueRequest mocks base method.
|
||||
func (m *MockSyncClient) QueueRequest(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// QueueRequest indicates an expected call of QueueRequest.
|
||||
func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// SendRequest mocks base method.
|
||||
func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// SendRequest indicates an expected call of SendRequest.
|
||||
func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// SendUpdate mocks base method.
|
||||
func (m *MockSyncClient) SendUpdate(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// SendUpdate indicates an expected call of SendUpdate.
|
||||
func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockRequestFactory is a mock of RequestFactory interface.
|
||||
type MockRequestFactory struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRequestFactoryMockRecorder
|
||||
}
|
||||
|
||||
// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory.
|
||||
type MockRequestFactoryMockRecorder struct {
|
||||
mock *MockRequestFactory
|
||||
}
|
||||
|
||||
// NewMockRequestFactory creates a new mock instance.
|
||||
func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory {
|
||||
mock := &MockRequestFactory{ctrl: ctrl}
|
||||
mock.recorder = &MockRequestFactoryMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// CreateFullSyncRequest mocks base method.
|
||||
func (m *MockRequestFactory) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
|
||||
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// CreateFullSyncResponse mocks base method.
|
||||
func (m *MockRequestFactory) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
|
||||
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1)
|
||||
}
|
||||
|
||||
// CreateHeadUpdate mocks base method.
|
||||
func (m *MockRequestFactory) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
|
||||
func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1)
|
||||
}
|
||||
|
||||
// MockAclSyncProtocol is a mock of AclSyncProtocol interface.
|
||||
type MockAclSyncProtocol struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockAclSyncProtocolMockRecorder
|
||||
}
|
||||
|
||||
// MockAclSyncProtocolMockRecorder is the mock recorder for MockAclSyncProtocol.
|
||||
type MockAclSyncProtocolMockRecorder struct {
|
||||
mock *MockAclSyncProtocol
|
||||
}
|
||||
|
||||
// NewMockAclSyncProtocol creates a new mock instance.
|
||||
func NewMockAclSyncProtocol(ctrl *gomock.Controller) *MockAclSyncProtocol {
|
||||
mock := &MockAclSyncProtocol{ctrl: ctrl}
|
||||
mock.recorder = &MockAclSyncProtocolMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockAclSyncProtocol) EXPECT() *MockAclSyncProtocolMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// FullSyncRequest mocks base method.
|
||||
func (m *MockAclSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncRequest) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FullSyncRequest indicates an expected call of FullSyncRequest.
|
||||
func (mr *MockAclSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// FullSyncResponse mocks base method.
|
||||
func (m *MockAclSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncResponse) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// FullSyncResponse indicates an expected call of FullSyncResponse.
|
||||
func (mr *MockAclSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2)
|
||||
}
|
||||
|
||||
// HeadUpdate mocks base method.
|
||||
func (m *MockAclSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *consensusproto.LogHeadUpdate) (*consensusproto.LogSyncMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2)
|
||||
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// HeadUpdate indicates an expected call of HeadUpdate.
|
||||
func (mr *MockAclSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockAclSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2)
|
||||
}
|
||||
54
commonspace/object/acl/syncacl/requestfactory.go
Normal file
54
commonspace/object/acl/syncacl/requestfactory.go
Normal file
@ -0,0 +1,54 @@
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
)
|
||||
|
||||
type RequestFactory interface {
|
||||
CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage)
|
||||
CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error)
|
||||
CreateFullSyncResponse(l list.AclList, theirHead string) (*consensusproto.LogSyncMessage, error)
|
||||
}
|
||||
|
||||
type requestFactory struct{}
|
||||
|
||||
func NewRequestFactory() RequestFactory {
|
||||
return &requestFactory{}
|
||||
}
|
||||
|
||||
func (r *requestFactory) CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage) {
|
||||
return consensusproto.WrapHeadUpdate(&consensusproto.LogHeadUpdate{
|
||||
Head: l.Head().Id,
|
||||
Records: added,
|
||||
}, l.Root())
|
||||
}
|
||||
|
||||
func (r *requestFactory) CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error) {
|
||||
if !l.HasHead(theirHead) {
|
||||
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
|
||||
Head: l.Head().Id,
|
||||
}, l.Root()), nil
|
||||
}
|
||||
records, err := l.RecordsAfter(context.Background(), theirHead)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
|
||||
Head: l.Head().Id,
|
||||
Records: records,
|
||||
}, l.Root()), nil
|
||||
}
|
||||
|
||||
func (r *requestFactory) CreateFullSyncResponse(l list.AclList, theirHead string) (resp *consensusproto.LogSyncMessage, err error) {
|
||||
records, err := l.RecordsAfter(context.Background(), theirHead)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return consensusproto.WrapFullResponse(&consensusproto.LogFullSyncResponse{
|
||||
Head: l.Head().Id,
|
||||
Records: records,
|
||||
}, l.Root()), nil
|
||||
}
|
||||
@ -1,21 +1,130 @@
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/objectsync"
|
||||
"github.com/anytypeio/any-sync/commonspace/objectsync/synchandler"
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
|
||||
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
|
||||
|
||||
"github.com/anyproto/any-sync/accountservice"
|
||||
"github.com/anyproto/any-sync/app"
|
||||
"github.com/anyproto/any-sync/app/logger"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/requestmanager"
|
||||
"github.com/anyproto/any-sync/commonspace/spacestorage"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
)
|
||||
|
||||
type SyncAcl struct {
|
||||
const CName = "common.acl.syncacl"
|
||||
|
||||
var (
|
||||
log = logger.NewNamed(CName)
|
||||
|
||||
ErrSyncAclClosed = errors.New("sync acl is closed")
|
||||
)
|
||||
|
||||
type SyncAcl interface {
|
||||
app.ComponentRunnable
|
||||
list.AclList
|
||||
synchandler.SyncHandler
|
||||
streamPool objectsync.StreamPool
|
||||
syncobjectgetter.SyncObject
|
||||
SetHeadUpdater(updater headupdater.HeadUpdater)
|
||||
SyncWithPeer(ctx context.Context, peerId string) (err error)
|
||||
}
|
||||
|
||||
func NewSyncAcl(aclList list.AclList, streamPool objectsync.StreamPool) *SyncAcl {
|
||||
return &SyncAcl{
|
||||
AclList: aclList,
|
||||
SyncHandler: nil,
|
||||
streamPool: streamPool,
|
||||
}
|
||||
func New() SyncAcl {
|
||||
return &syncAcl{}
|
||||
}
|
||||
|
||||
type syncAcl struct {
|
||||
list.AclList
|
||||
syncClient SyncClient
|
||||
syncHandler synchandler.SyncHandler
|
||||
headUpdater headupdater.HeadUpdater
|
||||
isClosed bool
|
||||
}
|
||||
|
||||
func (s *syncAcl) Run(ctx context.Context) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *syncAcl) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
|
||||
return s.syncHandler.HandleRequest(ctx, senderId, request)
|
||||
}
|
||||
|
||||
func (s *syncAcl) SetHeadUpdater(updater headupdater.HeadUpdater) {
|
||||
s.headUpdater = updater
|
||||
}
|
||||
|
||||
func (s *syncAcl) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
return s.syncHandler.HandleMessage(ctx, senderId, request)
|
||||
}
|
||||
|
||||
func (s *syncAcl) Init(a *app.App) (err error) {
|
||||
storage := a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
|
||||
aclStorage, err := storage.AclStorage()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acc := a.MustComponent(accountservice.CName).(accountservice.Service)
|
||||
s.AclList, err = list.BuildAclListWithIdentity(acc.Account(), aclStorage, list.NoOpAcceptorVerifier{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
spaceId := storage.Id()
|
||||
requestManager := a.MustComponent(requestmanager.CName).(requestmanager.RequestManager)
|
||||
peerManager := a.MustComponent(peermanager.CName).(peermanager.PeerManager)
|
||||
syncStatus := a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
|
||||
s.syncClient = NewSyncClient(spaceId, requestManager, peerManager)
|
||||
s.syncHandler = newSyncAclHandler(storage.Id(), s, s.syncClient, syncStatus)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *syncAcl) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
|
||||
if s.isClosed {
|
||||
return ErrSyncAclClosed
|
||||
}
|
||||
err = s.AclList.AddRawRecord(rawRec)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
headUpdate := s.syncClient.CreateHeadUpdate(s, []*consensusproto.RawRecordWithId{rawRec})
|
||||
s.headUpdater.UpdateHeads(s.Id(), []string{rawRec.Id})
|
||||
s.syncClient.Broadcast(headUpdate)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *syncAcl) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
|
||||
if s.isClosed {
|
||||
return ErrSyncAclClosed
|
||||
}
|
||||
err = s.AclList.AddRawRecords(rawRecords)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
headUpdate := s.syncClient.CreateHeadUpdate(s, rawRecords)
|
||||
s.headUpdater.UpdateHeads(s.Id(), []string{rawRecords[len(rawRecords)-1].Id})
|
||||
s.syncClient.Broadcast(headUpdate)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *syncAcl) SyncWithPeer(ctx context.Context, peerId string) (err error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
headUpdate := s.syncClient.CreateHeadUpdate(s, nil)
|
||||
return s.syncClient.SendUpdate(peerId, headUpdate)
|
||||
}
|
||||
|
||||
func (s *syncAcl) Close(ctx context.Context) (err error) {
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
s.isClosed = true
|
||||
return
|
||||
}
|
||||
|
||||
func (s *syncAcl) Name() (name string) {
|
||||
return CName
|
||||
}
|
||||
|
||||
@ -2,30 +2,81 @@ package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/spacesyncproto"
|
||||
"errors"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMessageIsRequest = errors.New("message is request")
|
||||
ErrMessageIsNotRequest = errors.New("message is not request")
|
||||
)
|
||||
|
||||
type syncAclHandler struct {
|
||||
acl list.AclList
|
||||
aclList list.AclList
|
||||
syncClient SyncClient
|
||||
syncProtocol AclSyncProtocol
|
||||
syncStatus syncstatus.StatusUpdater
|
||||
spaceId string
|
||||
}
|
||||
|
||||
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, req *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
aclMsg := &aclrecordproto.AclSyncMessage{}
|
||||
if err = aclMsg.Unmarshal(req.Payload); err != nil {
|
||||
func newSyncAclHandler(spaceId string, aclList list.AclList, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler {
|
||||
return &syncAclHandler{
|
||||
aclList: aclList,
|
||||
syncClient: syncClient,
|
||||
syncProtocol: newAclSyncProtocol(spaceId, aclList, syncClient),
|
||||
syncStatus: syncStatus,
|
||||
spaceId: spaceId,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
|
||||
unmarshalled := &consensusproto.LogSyncMessage{}
|
||||
err = proto.Unmarshal(message.Payload, unmarshalled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
content := aclMsg.GetContent()
|
||||
content := unmarshalled.GetContent()
|
||||
head := consensusproto.GetHead(unmarshalled)
|
||||
s.syncStatus.HeadsReceive(senderId, s.aclList.Id(), []string{head})
|
||||
s.aclList.Lock()
|
||||
defer s.aclList.Unlock()
|
||||
switch {
|
||||
case content.GetAddRecords() != nil:
|
||||
return s.handleAddRecords(ctx, senderId, content.GetAddRecords())
|
||||
default:
|
||||
return fmt.Errorf("unexpected aclSync message: %T", content.Value)
|
||||
case content.GetHeadUpdate() != nil:
|
||||
var syncReq *consensusproto.LogSyncMessage
|
||||
syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate())
|
||||
if err != nil || syncReq == nil {
|
||||
return
|
||||
}
|
||||
return s.syncClient.QueueRequest(senderId, syncReq)
|
||||
case content.GetFullSyncRequest() != nil:
|
||||
return ErrMessageIsRequest
|
||||
case content.GetFullSyncResponse() != nil:
|
||||
return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *syncAclHandler) handleAddRecords(ctx context.Context, senderId string, addRecord *aclrecordproto.AclAddRecords) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (s *syncAclHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
|
||||
unmarshalled := &consensusproto.LogSyncMessage{}
|
||||
err = proto.Unmarshal(request.Payload, unmarshalled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest()
|
||||
if fullSyncRequest == nil {
|
||||
return nil, ErrMessageIsNotRequest
|
||||
}
|
||||
s.aclList.Lock()
|
||||
defer s.aclList.Unlock()
|
||||
aclResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return spacesyncproto.MarshallSyncMessage(aclResp, s.spaceId, s.aclList.Id())
|
||||
}
|
||||
|
||||
233
commonspace/object/acl/syncacl/syncaclhandler_test.go
Normal file
233
commonspace/object/acl/syncacl/syncaclhandler_test.go
Normal file
@ -0,0 +1,233 @@
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/commonspace/syncstatus"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testAclMock struct {
|
||||
*mock_list.MockAclList
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func newTestAclMock(mockAcl *mock_list.MockAclList) *testAclMock {
|
||||
return &testAclMock{
|
||||
MockAclList: mockAcl,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *testAclMock) Lock() {
|
||||
t.m.Lock()
|
||||
}
|
||||
|
||||
func (t *testAclMock) RLock() {
|
||||
t.m.RLock()
|
||||
}
|
||||
|
||||
func (t *testAclMock) Unlock() {
|
||||
t.m.Unlock()
|
||||
}
|
||||
|
||||
func (t *testAclMock) RUnlock() {
|
||||
t.m.RUnlock()
|
||||
}
|
||||
|
||||
func (t *testAclMock) TryLock() bool {
|
||||
return t.m.TryLock()
|
||||
}
|
||||
|
||||
func (t *testAclMock) TryRLock() bool {
|
||||
return t.m.TryRLock()
|
||||
}
|
||||
|
||||
type syncHandlerFixture struct {
|
||||
ctrl *gomock.Controller
|
||||
syncClientMock *mock_syncacl.MockSyncClient
|
||||
aclMock *testAclMock
|
||||
syncProtocolMock *mock_syncacl.MockAclSyncProtocol
|
||||
spaceId string
|
||||
senderId string
|
||||
aclId string
|
||||
|
||||
syncHandler *syncAclHandler
|
||||
}
|
||||
|
||||
func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
|
||||
ctrl := gomock.NewController(t)
|
||||
aclMock := newTestAclMock(mock_list.NewMockAclList(ctrl))
|
||||
syncClientMock := mock_syncacl.NewMockSyncClient(ctrl)
|
||||
syncProtocolMock := mock_syncacl.NewMockAclSyncProtocol(ctrl)
|
||||
spaceId := "spaceId"
|
||||
|
||||
syncHandler := &syncAclHandler{
|
||||
aclList: aclMock,
|
||||
syncClient: syncClientMock,
|
||||
syncProtocol: syncProtocolMock,
|
||||
syncStatus: syncstatus.NewNoOpSyncStatus(),
|
||||
spaceId: spaceId,
|
||||
}
|
||||
return &syncHandlerFixture{
|
||||
ctrl: ctrl,
|
||||
syncClientMock: syncClientMock,
|
||||
aclMock: aclMock,
|
||||
syncProtocolMock: syncProtocolMock,
|
||||
spaceId: spaceId,
|
||||
senderId: "senderId",
|
||||
aclId: "aclId",
|
||||
syncHandler: syncHandler,
|
||||
}
|
||||
}
|
||||
|
||||
func (fx *syncHandlerFixture) stop() {
|
||||
fx.ctrl.Finish()
|
||||
}
|
||||
|
||||
func TestSyncAclHandler_HandleMessage(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("handle head update, request returned", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
syncReq := &consensusproto.LogSyncMessage{}
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil)
|
||||
fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, syncReq).Return(nil)
|
||||
|
||||
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("handle head update, no request", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil)
|
||||
|
||||
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
t.Run("handle head update, returned error", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
expectedErr := fmt.Errorf("some error")
|
||||
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, expectedErr)
|
||||
|
||||
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
|
||||
require.Error(t, expectedErr, err)
|
||||
})
|
||||
t.Run("handle full sync request is forbidden", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullRequest := &consensusproto.LogFullSyncRequest{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
|
||||
require.Error(t, ErrMessageIsRequest, err)
|
||||
})
|
||||
t.Run("handle full sync response, no error", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullResponse := &consensusproto.LogFullSyncResponse{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapFullResponse(fullResponse, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil)
|
||||
|
||||
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSyncAclHandler_HandleRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
t.Run("handle full sync request, no error", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
fullRequest := &consensusproto.LogFullSyncRequest{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
fullResp := &consensusproto.LogSyncMessage{
|
||||
Content: &consensusproto.LogSyncContentValue{
|
||||
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
|
||||
FullSyncResponse: &consensusproto.LogFullSyncResponse{
|
||||
Head: "returnedHead",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(fullResp, nil)
|
||||
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
|
||||
require.NoError(t, err)
|
||||
unmarshalled := &consensusproto.LogSyncMessage{}
|
||||
err = proto.Unmarshal(res.Payload, unmarshalled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
require.Equal(t, "returnedHead", consensusproto.GetHead(unmarshalled))
|
||||
})
|
||||
t.Run("handle other message returns error", func(t *testing.T) {
|
||||
fx := newSyncHandlerFixture(t)
|
||||
defer fx.stop()
|
||||
chWithId := &consensusproto.RawRecordWithId{}
|
||||
headUpdate := &consensusproto.LogHeadUpdate{
|
||||
Head: "h1",
|
||||
Records: []*consensusproto.RawRecordWithId{chWithId},
|
||||
}
|
||||
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
|
||||
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
|
||||
|
||||
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
|
||||
_, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
|
||||
require.Error(t, ErrMessageIsNotRequest, err)
|
||||
})
|
||||
}
|
||||
70
commonspace/object/acl/syncacl/syncclient.go
Normal file
70
commonspace/object/acl/syncacl/syncclient.go
Normal file
@ -0,0 +1,70 @@
|
||||
package syncacl
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/peermanager"
|
||||
"github.com/anyproto/any-sync/commonspace/requestmanager"
|
||||
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
|
||||
"github.com/anyproto/any-sync/consensus/consensusproto"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type SyncClient interface {
|
||||
RequestFactory
|
||||
Broadcast(msg *consensusproto.LogSyncMessage)
|
||||
SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error)
|
||||
QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error)
|
||||
SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
|
||||
}
|
||||
|
||||
type syncClient struct {
|
||||
RequestFactory
|
||||
spaceId string
|
||||
requestManager requestmanager.RequestManager
|
||||
peerManager peermanager.PeerManager
|
||||
}
|
||||
|
||||
func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient {
|
||||
return &syncClient{
|
||||
RequestFactory: &requestFactory{},
|
||||
spaceId: spaceId,
|
||||
requestManager: requestManager,
|
||||
peerManager: peerManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *syncClient) Broadcast(msg *consensusproto.LogSyncMessage) {
|
||||
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = s.peerManager.Broadcast(context.Background(), objMsg)
|
||||
if err != nil {
|
||||
log.Debug("broadcast error", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
func (s *syncClient) SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
|
||||
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.peerManager.SendPeer(context.Background(), peerId, objMsg)
|
||||
}
|
||||
|
||||
func (s *syncClient) SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
|
||||
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.requestManager.SendRequest(ctx, peerId, objMsg)
|
||||
}
|
||||
|
||||
func (s *syncClient) QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
|
||||
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return s.requestManager.QueueRequest(peerId, objMsg)
|
||||
}
|
||||
@ -1,194 +0,0 @@
|
||||
package acllistbuilder
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/util/keys"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"hash/fnv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type SymKey struct {
|
||||
Hash uint64
|
||||
Key *symmetric.Key
|
||||
}
|
||||
|
||||
type YAMLKeychain struct {
|
||||
SigningKeysByYAMLName map[string]signingkey.PrivKey
|
||||
SigningKeysByRealIdentity map[string]signingkey.PrivKey
|
||||
EncryptionKeysByYAMLName map[string]encryptionkey.PrivKey
|
||||
ReadKeysByYAMLName map[string]*SymKey
|
||||
ReadKeysByHash map[uint64]*SymKey
|
||||
GeneratedIdentities map[string]string
|
||||
}
|
||||
|
||||
func NewKeychain() *YAMLKeychain {
|
||||
return &YAMLKeychain{
|
||||
SigningKeysByYAMLName: map[string]signingkey.PrivKey{},
|
||||
SigningKeysByRealIdentity: map[string]signingkey.PrivKey{},
|
||||
EncryptionKeysByYAMLName: map[string]encryptionkey.PrivKey{},
|
||||
GeneratedIdentities: map[string]string{},
|
||||
ReadKeysByYAMLName: map[string]*SymKey{},
|
||||
ReadKeysByHash: map[uint64]*SymKey{},
|
||||
}
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) ParseKeys(keys *Keys) {
|
||||
for _, encKey := range keys.Enc {
|
||||
k.AddEncryptionKey(encKey)
|
||||
}
|
||||
|
||||
for _, signKey := range keys.Sign {
|
||||
k.AddSigningKey(signKey)
|
||||
}
|
||||
|
||||
for _, readKey := range keys.Read {
|
||||
k.AddReadKey(readKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) AddEncryptionKey(key *Key) {
|
||||
if _, exists := k.EncryptionKeysByYAMLName[key.Name]; exists {
|
||||
return
|
||||
}
|
||||
var (
|
||||
newPrivKey encryptionkey.PrivKey
|
||||
err error
|
||||
)
|
||||
if key.Value == "generated" {
|
||||
newPrivKey, _, err = encryptionkey.GenerateRandomRSAKeyPair(2048)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
newPrivKey, err = keys.DecodeKeyFromString(key.Value, encryptionkey.NewEncryptionRsaPrivKeyFromBytes, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
k.EncryptionKeysByYAMLName[key.Name] = newPrivKey
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) AddSigningKey(key *Key) {
|
||||
if _, exists := k.SigningKeysByYAMLName[key.Name]; exists {
|
||||
return
|
||||
}
|
||||
var (
|
||||
newPrivKey signingkey.PrivKey
|
||||
pubKey signingkey.PubKey
|
||||
err error
|
||||
)
|
||||
if key.Value == "generated" {
|
||||
newPrivKey, pubKey, err = signingkey.GenerateRandomEd25519KeyPair()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
newPrivKey, err = keys.DecodeKeyFromString(key.Value, signingkey.NewSigningEd25519PrivKeyFromBytes, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
pubKey = newPrivKey.GetPublic()
|
||||
}
|
||||
|
||||
k.SigningKeysByYAMLName[key.Name] = newPrivKey
|
||||
rawPubKey, err := pubKey.Raw()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
encoded := string(rawPubKey)
|
||||
|
||||
k.SigningKeysByRealIdentity[encoded] = newPrivKey
|
||||
k.GeneratedIdentities[key.Name] = encoded
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) AddReadKey(key *Key) {
|
||||
if _, exists := k.ReadKeysByYAMLName[key.Name]; exists {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
rkey *symmetric.Key
|
||||
err error
|
||||
)
|
||||
if key.Value == "generated" {
|
||||
rkey, err = symmetric.NewRandom()
|
||||
if err != nil {
|
||||
panic("should be able to generate symmetric key")
|
||||
}
|
||||
} else if key.Value == "derived" {
|
||||
signKey, _ := k.SigningKeysByYAMLName[key.Name].Raw()
|
||||
encKey, _ := k.EncryptionKeysByYAMLName[key.Name].Raw()
|
||||
rkey, err = aclrecordproto.AclReadKeyDerive(signKey, encKey)
|
||||
if err != nil {
|
||||
panic("should be able to derive symmetric key")
|
||||
}
|
||||
} else {
|
||||
rkey, err = symmetric.FromString(key.Value)
|
||||
if err != nil {
|
||||
panic("should be able to parse symmetric key")
|
||||
}
|
||||
}
|
||||
|
||||
hasher := fnv.New64()
|
||||
hasher.Write(rkey.Bytes())
|
||||
|
||||
k.ReadKeysByYAMLName[key.Name] = &SymKey{
|
||||
Hash: hasher.Sum64(),
|
||||
Key: rkey,
|
||||
}
|
||||
k.ReadKeysByHash[hasher.Sum64()] = &SymKey{
|
||||
Hash: hasher.Sum64(),
|
||||
Key: rkey,
|
||||
}
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) AddKey(key *Key) {
|
||||
parts := strings.Split(key.Name, ".")
|
||||
if len(parts) != 3 {
|
||||
panic("cannot parse a key")
|
||||
}
|
||||
|
||||
switch parts[1] {
|
||||
case "Signature":
|
||||
k.AddSigningKey(key)
|
||||
case "Enc":
|
||||
k.AddEncryptionKey(key)
|
||||
case "Read":
|
||||
k.AddReadKey(key)
|
||||
default:
|
||||
panic("incorrect format")
|
||||
}
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) GetKey(key string) interface{} {
|
||||
parts := strings.Split(key, ".")
|
||||
if len(parts) != 3 {
|
||||
panic("cannot parse a key")
|
||||
}
|
||||
name := parts[2]
|
||||
|
||||
switch parts[1] {
|
||||
case "Sign":
|
||||
if key, exists := k.SigningKeysByYAMLName[name]; exists {
|
||||
return key
|
||||
}
|
||||
case "Enc":
|
||||
if key, exists := k.EncryptionKeysByYAMLName[name]; exists {
|
||||
return key
|
||||
}
|
||||
case "Read":
|
||||
if key, exists := k.ReadKeysByYAMLName[name]; exists {
|
||||
return key
|
||||
}
|
||||
default:
|
||||
panic("incorrect format")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *YAMLKeychain) GetIdentity(name string) string {
|
||||
return k.GeneratedIdentities[name]
|
||||
}
|
||||
@ -1,295 +0,0 @@
|
||||
package acllistbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/yamltests"
|
||||
"github.com/anytypeio/any-sync/util/cidutil"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/encryptionkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"gopkg.in/yaml.v3"
|
||||
"io/ioutil"
|
||||
"path"
|
||||
"time"
|
||||
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
type AclListStorageBuilder struct {
|
||||
liststorage.ListStorage
|
||||
keychain *YAMLKeychain
|
||||
}
|
||||
|
||||
func NewAclListStorageBuilder(keychain *YAMLKeychain) *AclListStorageBuilder {
|
||||
return &AclListStorageBuilder{
|
||||
keychain: keychain,
|
||||
}
|
||||
}
|
||||
|
||||
func NewListStorageWithTestName(name string) (liststorage.ListStorage, error) {
|
||||
filePath := path.Join(yamltests.Path(), name)
|
||||
return NewAclListStorageBuilderFromFile(filePath)
|
||||
}
|
||||
|
||||
func NewAclListStorageBuilderFromFile(file string) (*AclListStorageBuilder, error) {
|
||||
content, err := ioutil.ReadFile(file)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ymlTree := YMLList{}
|
||||
err = yaml.Unmarshal(content, &ymlTree)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tb := NewAclListStorageBuilder(NewKeychain())
|
||||
tb.Parse(&ymlTree)
|
||||
|
||||
return tb, nil
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) createRaw(rec proto.Marshaler, identity []byte) *aclrecordproto.RawAclRecordWithId {
|
||||
protoMarshalled, err := rec.Marshal()
|
||||
if err != nil {
|
||||
panic("should be able to marshal final acl message!")
|
||||
}
|
||||
|
||||
signature, err := t.keychain.SigningKeysByRealIdentity[string(identity)].Sign(protoMarshalled)
|
||||
if err != nil {
|
||||
panic("should be able to sign final acl message!")
|
||||
}
|
||||
|
||||
rawRec := &aclrecordproto.RawAclRecord{
|
||||
Payload: protoMarshalled,
|
||||
Signature: signature,
|
||||
}
|
||||
|
||||
rawMarshalled, err := proto.Marshal(rawRec)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
id, _ := cidutil.NewCidFromBytes(rawMarshalled)
|
||||
|
||||
return &aclrecordproto.RawAclRecordWithId{
|
||||
Payload: rawMarshalled,
|
||||
Id: id,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) GetKeychain() *YAMLKeychain {
|
||||
return t.keychain
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) Parse(l *YMLList) {
|
||||
// Just to clarify - we are generating new identities for the ones that
|
||||
// are specified in the yml file, because our identities should be Ed25519
|
||||
// the same thing is happening for the encryption keys
|
||||
t.keychain.ParseKeys(&l.Keys)
|
||||
rawRoot := t.parseRoot(l.Root)
|
||||
var err error
|
||||
t.ListStorage, err = liststorage.NewInMemoryAclListStorage(rawRoot.Id, []*aclrecordproto.RawAclRecordWithId{rawRoot})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
prevId := rawRoot.Id
|
||||
for _, rec := range l.Records {
|
||||
newRecord := t.parseRecord(rec, prevId)
|
||||
rawRecord := t.createRaw(newRecord, newRecord.Identity)
|
||||
err = t.AddRawRecord(context.Background(), rawRecord)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
prevId = rawRecord.Id
|
||||
}
|
||||
t.SetHead(prevId)
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) parseRecord(rec *Record, prevId string) *aclrecordproto.AclRecord {
|
||||
k := t.keychain.GetKey(rec.ReadKey).(*SymKey)
|
||||
var aclChangeContents []*aclrecordproto.AclContentValue
|
||||
for _, ch := range rec.AclChanges {
|
||||
aclChangeContent := t.parseAclChange(ch)
|
||||
aclChangeContents = append(aclChangeContents, aclChangeContent)
|
||||
}
|
||||
data := &aclrecordproto.AclData{
|
||||
AclContent: aclChangeContents,
|
||||
}
|
||||
bytes, _ := data.Marshal()
|
||||
|
||||
return &aclrecordproto.AclRecord{
|
||||
PrevId: prevId,
|
||||
Identity: []byte(t.keychain.GetIdentity(rec.Identity)),
|
||||
Data: bytes,
|
||||
CurrentReadKeyHash: k.Hash,
|
||||
Timestamp: time.Now().UnixNano(),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) parseAclChange(ch *AclChange) (convCh *aclrecordproto.AclContentValue) {
|
||||
switch {
|
||||
case ch.UserAdd != nil:
|
||||
add := ch.UserAdd
|
||||
|
||||
encKey := t.keychain.GetKey(add.EncryptionKey).(encryptionkey.PrivKey)
|
||||
rawKey, _ := encKey.GetPublic().Raw()
|
||||
|
||||
convCh = &aclrecordproto.AclContentValue{
|
||||
Value: &aclrecordproto.AclContentValue_UserAdd{
|
||||
UserAdd: &aclrecordproto.AclUserAdd{
|
||||
Identity: []byte(t.keychain.GetIdentity(add.Identity)),
|
||||
EncryptionKey: rawKey,
|
||||
EncryptedReadKeys: t.encryptReadKeysWithPubKey(add.EncryptedReadKeys, encKey),
|
||||
Permissions: t.convertPermission(add.Permission),
|
||||
},
|
||||
},
|
||||
}
|
||||
case ch.UserJoin != nil:
|
||||
join := ch.UserJoin
|
||||
|
||||
encKey := t.keychain.GetKey(join.EncryptionKey).(encryptionkey.PrivKey)
|
||||
rawKey, _ := encKey.GetPublic().Raw()
|
||||
|
||||
idKey, _ := t.keychain.SigningKeysByYAMLName[join.Identity].GetPublic().Raw()
|
||||
signKey := t.keychain.GetKey(join.AcceptKey).(signingkey.PrivKey)
|
||||
signature, err := signKey.Sign(idKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
acceptPubKey, _ := signKey.GetPublic().Raw()
|
||||
|
||||
convCh = &aclrecordproto.AclContentValue{
|
||||
Value: &aclrecordproto.AclContentValue_UserJoin{
|
||||
UserJoin: &aclrecordproto.AclUserJoin{
|
||||
Identity: []byte(t.keychain.GetIdentity(join.Identity)),
|
||||
EncryptionKey: rawKey,
|
||||
AcceptSignature: signature,
|
||||
AcceptPubKey: acceptPubKey,
|
||||
EncryptedReadKeys: t.encryptReadKeysWithPubKey(join.EncryptedReadKeys, encKey),
|
||||
},
|
||||
},
|
||||
}
|
||||
case ch.UserInvite != nil:
|
||||
invite := ch.UserInvite
|
||||
rawAcceptKey, _ := t.keychain.GetKey(invite.AcceptKey).(signingkey.PrivKey).GetPublic().Raw()
|
||||
hash := t.keychain.GetKey(invite.EncryptionKey).(*SymKey).Hash
|
||||
encKey := t.keychain.ReadKeysByHash[hash]
|
||||
|
||||
convCh = &aclrecordproto.AclContentValue{
|
||||
Value: &aclrecordproto.AclContentValue_UserInvite{
|
||||
UserInvite: &aclrecordproto.AclUserInvite{
|
||||
AcceptPublicKey: rawAcceptKey,
|
||||
EncryptSymKeyHash: hash,
|
||||
EncryptedReadKeys: t.encryptReadKeysWithSymKey(invite.EncryptedReadKeys, encKey.Key),
|
||||
Permissions: t.convertPermission(invite.Permissions),
|
||||
},
|
||||
},
|
||||
}
|
||||
case ch.UserPermissionChange != nil:
|
||||
permissionChange := ch.UserPermissionChange
|
||||
|
||||
convCh = &aclrecordproto.AclContentValue{
|
||||
Value: &aclrecordproto.AclContentValue_UserPermissionChange{
|
||||
UserPermissionChange: &aclrecordproto.AclUserPermissionChange{
|
||||
Identity: []byte(t.keychain.GetIdentity(permissionChange.Identity)),
|
||||
Permissions: t.convertPermission(permissionChange.Permission),
|
||||
},
|
||||
},
|
||||
}
|
||||
case ch.UserRemove != nil:
|
||||
remove := ch.UserRemove
|
||||
|
||||
newReadKey := t.keychain.GetKey(remove.NewReadKey).(*SymKey)
|
||||
|
||||
var replaces []*aclrecordproto.AclReadKeyReplace
|
||||
for _, id := range remove.IdentitiesLeft {
|
||||
encKey := t.keychain.EncryptionKeysByYAMLName[id]
|
||||
rawEncKey, _ := encKey.GetPublic().Raw()
|
||||
encReadKey, err := encKey.GetPublic().Encrypt(newReadKey.Key.Bytes())
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
replaces = append(replaces, &aclrecordproto.AclReadKeyReplace{
|
||||
Identity: []byte(t.keychain.GetIdentity(id)),
|
||||
EncryptionKey: rawEncKey,
|
||||
EncryptedReadKey: encReadKey,
|
||||
})
|
||||
}
|
||||
|
||||
convCh = &aclrecordproto.AclContentValue{
|
||||
Value: &aclrecordproto.AclContentValue_UserRemove{
|
||||
UserRemove: &aclrecordproto.AclUserRemove{
|
||||
Identity: []byte(t.keychain.GetIdentity(remove.RemovedIdentity)),
|
||||
ReadKeyReplaces: replaces,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
if convCh == nil {
|
||||
panic("cannot have empty acl change")
|
||||
}
|
||||
|
||||
return convCh
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) encryptReadKeysWithPubKey(keys []string, encKey encryptionkey.PrivKey) (enc [][]byte) {
|
||||
for _, k := range keys {
|
||||
realKey := t.keychain.GetKey(k).(*SymKey).Key.Bytes()
|
||||
res, err := encKey.GetPublic().Encrypt(realKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
enc = append(enc, res)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) encryptReadKeysWithSymKey(keys []string, key *symmetric.Key) (enc [][]byte) {
|
||||
for _, k := range keys {
|
||||
realKey := t.keychain.GetKey(k).(*SymKey).Key.Bytes()
|
||||
res, err := key.Encrypt(realKey)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
enc = append(enc, res)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) convertPermission(perm string) aclrecordproto.AclUserPermissions {
|
||||
switch perm {
|
||||
case "admin":
|
||||
return aclrecordproto.AclUserPermissions_Admin
|
||||
case "writer":
|
||||
return aclrecordproto.AclUserPermissions_Writer
|
||||
case "reader":
|
||||
return aclrecordproto.AclUserPermissions_Reader
|
||||
default:
|
||||
panic(fmt.Sprintf("incorrect permission: %s", perm))
|
||||
}
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) traverseFromHead(f func(rec *aclrecordproto.AclRecord, id string) error) (err error) {
|
||||
panic("this was removed, add if needed")
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) parseRoot(root *Root) (rawRoot *aclrecordproto.RawAclRecordWithId) {
|
||||
rawSignKey, _ := t.keychain.SigningKeysByYAMLName[root.Identity].GetPublic().Raw()
|
||||
rawEncKey, _ := t.keychain.EncryptionKeysByYAMLName[root.Identity].GetPublic().Raw()
|
||||
readKey := t.keychain.ReadKeysByYAMLName[root.Identity]
|
||||
aclRoot := &aclrecordproto.AclRoot{
|
||||
Identity: rawSignKey,
|
||||
EncryptionKey: rawEncKey,
|
||||
SpaceId: root.SpaceId,
|
||||
EncryptedReadKey: nil,
|
||||
DerivationScheme: "scheme",
|
||||
CurrentReadKeyHash: readKey.Hash,
|
||||
}
|
||||
return t.createRaw(aclRoot, rawSignKey)
|
||||
}
|
||||
@ -1,11 +0,0 @@
|
||||
//go:build ((!linux && !darwin) || android || ios || nographviz) && !amd64
|
||||
// +build !linux,!darwin android ios nographviz
|
||||
// +build !amd64
|
||||
|
||||
package acllistbuilder
|
||||
|
||||
import "fmt"
|
||||
|
||||
func (t *AclListStorageBuilder) Graph() (string, error) {
|
||||
return "", fmt.Errorf("building graphs is not supported")
|
||||
}
|
||||
@ -1,120 +0,0 @@
|
||||
//go:build (linux || darwin) && !android && !ios && !nographviz && (amd64 || arm64)
|
||||
// +build linux darwin
|
||||
// +build !android
|
||||
// +build !ios
|
||||
// +build !nographviz
|
||||
// +build amd64 arm64
|
||||
|
||||
package acllistbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/awalterschulze/gographviz"
|
||||
)
|
||||
|
||||
// To quickly look at visualized string you can use https://dreampuf.github.io/GraphvizOnline
|
||||
|
||||
type EdgeParameters struct {
|
||||
style string
|
||||
color string
|
||||
label string
|
||||
}
|
||||
|
||||
func (t *AclListStorageBuilder) Graph() (string, error) {
|
||||
// TODO: check updates on https://github.com/goccy/go-graphviz/issues/52 or make a fix yourself to use better library here
|
||||
graph := gographviz.NewGraph()
|
||||
graph.SetName("G")
|
||||
graph.SetDir(true)
|
||||
var nodes = make(map[string]struct{})
|
||||
|
||||
var addNodes = func(r *aclrecordproto.AclRecord, id string) error {
|
||||
style := "solid"
|
||||
|
||||
var chSymbs []string
|
||||
aclData := &aclrecordproto.AclData{}
|
||||
err := proto.Unmarshal(r.GetData(), aclData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, chc := range aclData.AclContent {
|
||||
tp := fmt.Sprintf("%T", chc.Value)
|
||||
tp = strings.Replace(tp, "AclChangeAclContentValueValueOf", "", 1)
|
||||
res := ""
|
||||
for _, ts := range tp {
|
||||
if unicode.IsUpper(ts) {
|
||||
res += string(ts)
|
||||
}
|
||||
}
|
||||
chSymbs = append(chSymbs, res)
|
||||
}
|
||||
|
||||
shortId := id
|
||||
label := fmt.Sprintf("Id: %s\nChanges: %s\n",
|
||||
shortId,
|
||||
strings.Join(chSymbs, ","),
|
||||
)
|
||||
e := graph.AddNode("G", "\""+id+"\"", map[string]string{
|
||||
"label": "\"" + label + "\"",
|
||||
"style": "\"" + style + "\"",
|
||||
})
|
||||
if e != nil {
|
||||
return e
|
||||
}
|
||||
nodes[id] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
var createEdge = func(firstId, secondId string, params EdgeParameters) error {
|
||||
_, exists := nodes[firstId]
|
||||
if !exists {
|
||||
return fmt.Errorf("no such node")
|
||||
}
|
||||
_, exists = nodes[secondId]
|
||||
if !exists {
|
||||
return fmt.Errorf("no previous node")
|
||||
}
|
||||
|
||||
err := graph.AddEdge("\""+firstId+"\"", "\""+secondId+"\"", true, map[string]string{
|
||||
"color": params.color,
|
||||
"style": params.style,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var addLinks = func(r *aclrecordproto.AclRecord, id string) error {
|
||||
if r.PrevId == "" {
|
||||
return nil
|
||||
}
|
||||
err := createEdge(id, r.PrevId, EdgeParameters{
|
||||
style: "dashed",
|
||||
color: "red",
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
err := t.traverseFromHead(addNodes)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
err = t.traverseFromHead(addLinks)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return graph.String(), nil
|
||||
}
|
||||
@ -1,70 +0,0 @@
|
||||
package acllistbuilder
|
||||
|
||||
type Key struct {
|
||||
Name string `yaml:"name"`
|
||||
Value string `yaml:"value"`
|
||||
}
|
||||
|
||||
type Keys struct {
|
||||
Derived string `yaml:"Derived"`
|
||||
Enc []*Key `yaml:"Enc"`
|
||||
Sign []*Key `yaml:"Sign"`
|
||||
Read []*Key `yaml:"Read"`
|
||||
}
|
||||
|
||||
type AclChange struct {
|
||||
UserAdd *struct {
|
||||
Identity string `yaml:"identity"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
|
||||
Permission string `yaml:"permission"`
|
||||
} `yaml:"userAdd"`
|
||||
|
||||
UserJoin *struct {
|
||||
Identity string `yaml:"identity"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
AcceptKey string `yaml:"acceptKey"`
|
||||
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
|
||||
} `yaml:"userJoin"`
|
||||
|
||||
UserInvite *struct {
|
||||
AcceptKey string `yaml:"acceptKey"`
|
||||
EncryptionKey string `yaml:"encryptionKey"`
|
||||
EncryptedReadKeys []string `yaml:"encryptedReadKeys"`
|
||||
Permissions string `yaml:"permissions"`
|
||||
} `yaml:"userInvite"`
|
||||
|
||||
UserRemove *struct {
|
||||
RemovedIdentity string `yaml:"removedIdentity"`
|
||||
NewReadKey string `yaml:"newReadKey"`
|
||||
IdentitiesLeft []string `yaml:"identitiesLeft"`
|
||||
} `yaml:"userRemove"`
|
||||
|
||||
UserPermissionChange *struct {
|
||||
Identity string `yaml:"identity"`
|
||||
Permission string `yaml:"permission"`
|
||||
}
|
||||
}
|
||||
|
||||
type Record struct {
|
||||
Identity string `yaml:"identity"`
|
||||
AclChanges []*AclChange `yaml:"aclChanges"`
|
||||
ReadKey string `yaml:"readKey"`
|
||||
}
|
||||
|
||||
type Header struct {
|
||||
FirstChangeId string `yaml:"firstChangeId"`
|
||||
IsWorkspace bool `yaml:"isWorkspace"`
|
||||
}
|
||||
|
||||
type Root struct {
|
||||
Identity string `yaml:"identity"`
|
||||
SpaceId string `yaml:"spaceId"`
|
||||
}
|
||||
|
||||
type YMLList struct {
|
||||
Root *Root
|
||||
Records []*Record `yaml:"records"`
|
||||
|
||||
Keys Keys `yaml:"keys"`
|
||||
}
|
||||
@ -1,15 +0,0 @@
|
||||
package yamltests
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
)
|
||||
|
||||
var (
|
||||
_, b, _, _ = runtime.Caller(0)
|
||||
basepath = filepath.Dir(b)
|
||||
)
|
||||
|
||||
func Path() string {
|
||||
return basepath
|
||||
}
|
||||
@ -1,53 +0,0 @@
|
||||
root:
|
||||
identity: A
|
||||
spaceId: space
|
||||
records:
|
||||
- identity: A
|
||||
aclChanges:
|
||||
- userInvite:
|
||||
acceptKey: key.Sign.Onetime1
|
||||
encryptionKey: key.Read.EncKey
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
permissions: writer
|
||||
- userAdd:
|
||||
identity: C
|
||||
permission: reader
|
||||
encryptionKey: key.Enc.C
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
readKey: key.Read.A
|
||||
- identity: B
|
||||
aclChanges:
|
||||
- userJoin:
|
||||
identity: B
|
||||
encryptionKey: key.Enc.B
|
||||
acceptKey: key.Sign.Onetime1
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
readKey: key.Read.A
|
||||
keys:
|
||||
Enc:
|
||||
- name: A
|
||||
value: generated
|
||||
- name: B
|
||||
value: generated
|
||||
- name: C
|
||||
value: generated
|
||||
- name: D
|
||||
value: generated
|
||||
- name: Onetime1
|
||||
value: generated
|
||||
Sign:
|
||||
- name: A
|
||||
value: generated
|
||||
- name: B
|
||||
value: generated
|
||||
- name: C
|
||||
value: generated
|
||||
- name: D
|
||||
value: generated
|
||||
- name: Onetime1
|
||||
value: generated
|
||||
Read:
|
||||
- name: A
|
||||
value: derived
|
||||
- name: EncKey
|
||||
value: generated
|
||||
@ -1,58 +0,0 @@
|
||||
root:
|
||||
identity: A
|
||||
spaceId: space
|
||||
records:
|
||||
- identity: A
|
||||
aclChanges:
|
||||
- userInvite:
|
||||
acceptKey: key.Sign.Onetime1
|
||||
encryptionKey: key.Read.EncKey
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
permissions: writer
|
||||
- userAdd:
|
||||
identity: C
|
||||
permission: reader
|
||||
encryptionKey: key.Enc.C
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
readKey: key.Read.A
|
||||
- identity: B
|
||||
aclChanges:
|
||||
- userJoin:
|
||||
identity: B
|
||||
encryptionKey: key.Enc.B
|
||||
acceptKey: key.Sign.Onetime1
|
||||
encryptedReadKeys: [key.Read.A]
|
||||
readKey: key.Read.A
|
||||
- identity: A
|
||||
aclChanges:
|
||||
- userRemove:
|
||||
removedIdentity: B
|
||||
newReadKey: key.Read.2
|
||||
identitiesLeft: [A, C]
|
||||
readKey: key.Read.2
|
||||
keys:
|
||||
Enc:
|
||||
- name: A
|
||||
value: generated
|
||||
- name: B
|
||||
value: generated
|
||||
- name: C
|
||||
value: generated
|
||||
- name: Onetime1
|
||||
value: generated
|
||||
Sign:
|
||||
- name: A
|
||||
value: generated
|
||||
- name: B
|
||||
value: generated
|
||||
- name: C
|
||||
value: generated
|
||||
- name: Onetime1
|
||||
value: generated
|
||||
Read:
|
||||
- name: A
|
||||
value: derived
|
||||
- name: 2
|
||||
value: generated
|
||||
- name: EncKey
|
||||
value: generated
|
||||
@ -1,28 +0,0 @@
|
||||
package keychain
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
)
|
||||
|
||||
type Keychain struct {
|
||||
keys map[string]signingkey.PubKey
|
||||
}
|
||||
|
||||
func NewKeychain() *Keychain {
|
||||
return &Keychain{
|
||||
keys: make(map[string]signingkey.PubKey),
|
||||
}
|
||||
}
|
||||
|
||||
func (k *Keychain) GetOrAdd(identity string) (signingkey.PubKey, error) {
|
||||
if key, exists := k.keys[identity]; exists {
|
||||
return key, nil
|
||||
}
|
||||
res, err := signingkey.NewSigningEd25519PubKeyFromBytes([]byte(identity))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
k.keys[identity] = res.(signingkey.PubKey)
|
||||
return res.(signingkey.PubKey), nil
|
||||
}
|
||||
@ -2,10 +2,11 @@ package syncobjectgetter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonspace/objectsync/synchandler"
|
||||
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
|
||||
)
|
||||
|
||||
type SyncObject interface {
|
||||
Id() string
|
||||
synchandler.SyncHandler
|
||||
}
|
||||
|
||||
|
||||
80
commonspace/object/tree/exporter/treeexporter.go
Normal file
80
commonspace/object/tree/exporter/treeexporter.go
Normal file
@ -0,0 +1,80 @@
|
||||
package exporter
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type DataConverter interface {
|
||||
Unmarshall(decrypted []byte) (any, error)
|
||||
Marshall(model any) ([]byte, error)
|
||||
}
|
||||
|
||||
type TreeExporterParams struct {
|
||||
ListStorageExporter liststorage.Exporter
|
||||
TreeStorageExporter treestorage.Exporter
|
||||
DataConverter DataConverter
|
||||
}
|
||||
|
||||
type TreeExporter interface {
|
||||
ExportUnencrypted(tree objecttree.ReadableObjectTree) (err error)
|
||||
}
|
||||
|
||||
type treeExporter struct {
|
||||
listExporter liststorage.Exporter
|
||||
treeExporter treestorage.Exporter
|
||||
converter DataConverter
|
||||
}
|
||||
|
||||
func NewTreeExporter(params TreeExporterParams) TreeExporter {
|
||||
return &treeExporter{
|
||||
listExporter: params.ListStorageExporter,
|
||||
treeExporter: params.TreeStorageExporter,
|
||||
converter: params.DataConverter,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *treeExporter) ExportUnencrypted(tree objecttree.ReadableObjectTree) (err error) {
|
||||
lst := tree.AclList()
|
||||
// this exports root which should be enough before we implement acls
|
||||
_, err = t.listExporter.ListStorage(lst.Root())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
treeStorage, err := t.treeExporter.TreeStorage(tree.Header())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), tree.Header())
|
||||
putStorage := func(change *objecttree.Change) (err error) {
|
||||
var raw *treechangeproto.RawTreeChangeWithId
|
||||
raw, err = changeBuilder.Marshall(change)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return treeStorage.AddRawChange(raw)
|
||||
}
|
||||
err = tree.IterateRoot(t.converter.Unmarshall, func(change *objecttree.Change) bool {
|
||||
if change.Id == tree.Id() {
|
||||
err = putStorage(change)
|
||||
return err == nil
|
||||
}
|
||||
var data []byte
|
||||
data, err = t.converter.Marshall(change.Model)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
// that means that change is unencrypted
|
||||
change.ReadKeyId = ""
|
||||
change.Data = data
|
||||
err = putStorage(change)
|
||||
return err == nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return treeStorage.SetHeads(tree.Heads())
|
||||
}
|
||||
28
commonspace/object/tree/exporter/treeimport.go
Normal file
28
commonspace/object/tree/exporter/treeimport.go
Normal file
@ -0,0 +1,28 @@
|
||||
package exporter
|
||||
|
||||
import (
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
)
|
||||
|
||||
type TreeImportParams struct {
|
||||
ListStorage liststorage.ListStorage
|
||||
TreeStorage treestorage.TreeStorage
|
||||
BeforeId string
|
||||
IncludeBeforeId bool
|
||||
}
|
||||
|
||||
func ImportHistoryTree(params TreeImportParams) (tree objecttree.ReadableObjectTree, err error) {
|
||||
aclList, err := list.BuildAclList(params.ListStorage, list.NoOpAcceptorVerifier{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return objecttree.BuildNonVerifiableHistoryTree(objecttree.HistoryTreeParams{
|
||||
TreeStorage: params.TreeStorage,
|
||||
AclList: aclList,
|
||||
BeforeId: params.BeforeId,
|
||||
IncludeBeforeId: params.IncludeBeforeId,
|
||||
})
|
||||
}
|
||||
@ -2,7 +2,9 @@ package objecttree
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -17,44 +19,51 @@ type Change struct {
|
||||
AclHeadId string
|
||||
Id string
|
||||
SnapshotId string
|
||||
IsSnapshot bool
|
||||
Timestamp int64
|
||||
ReadKeyHash uint64
|
||||
Identity string
|
||||
ReadKeyId string
|
||||
Identity crypto.PubKey
|
||||
Data []byte
|
||||
Model interface{}
|
||||
Signature []byte
|
||||
|
||||
// iterator helpers
|
||||
visited bool
|
||||
branchesFinished bool
|
||||
|
||||
Signature []byte
|
||||
IsSnapshot bool
|
||||
}
|
||||
|
||||
func NewChange(id string, ch *treechangeproto.TreeChange, signature []byte) *Change {
|
||||
func NewChange(id string, identity crypto.PubKey, ch *treechangeproto.TreeChange, signature []byte) *Change {
|
||||
return &Change{
|
||||
Next: nil,
|
||||
PreviousIds: ch.TreeHeadIds,
|
||||
AclHeadId: ch.AclHeadId,
|
||||
Timestamp: ch.Timestamp,
|
||||
ReadKeyHash: ch.CurrentReadKeyHash,
|
||||
ReadKeyId: ch.ReadKeyId,
|
||||
Id: id,
|
||||
Data: ch.ChangesData,
|
||||
SnapshotId: ch.SnapshotBaseId,
|
||||
IsSnapshot: ch.IsSnapshot,
|
||||
Identity: string(ch.Identity),
|
||||
Identity: identity,
|
||||
Signature: signature,
|
||||
}
|
||||
}
|
||||
|
||||
func NewChangeFromRoot(id string, ch *treechangeproto.RootChange, signature []byte) *Change {
|
||||
func NewChangeFromRoot(id string, identity crypto.PubKey, ch *treechangeproto.RootChange, signature []byte) *Change {
|
||||
changeInfo := &treechangeproto.TreeChangeInfo{
|
||||
ChangeType: ch.ChangeType,
|
||||
ChangePayload: ch.ChangePayload,
|
||||
}
|
||||
data, _ := proto.Marshal(changeInfo)
|
||||
return &Change{
|
||||
Next: nil,
|
||||
AclHeadId: ch.AclHeadId,
|
||||
Id: id,
|
||||
IsSnapshot: true,
|
||||
Identity: string(ch.Identity),
|
||||
Timestamp: ch.Timestamp,
|
||||
Identity: identity,
|
||||
Signature: signature,
|
||||
Data: data,
|
||||
Model: changeInfo,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -2,13 +2,10 @@ package objecttree
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/util/cidutil"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/util/cidutil"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrEmptyChange = errors.New("change payload should not be empty")
|
||||
@ -17,42 +14,64 @@ type BuilderContent struct {
|
||||
TreeHeadIds []string
|
||||
AclHeadId string
|
||||
SnapshotBaseId string
|
||||
CurrentReadKeyHash uint64
|
||||
Identity []byte
|
||||
ReadKeyId string
|
||||
IsSnapshot bool
|
||||
SigningKey signingkey.PrivKey
|
||||
ReadKey *symmetric.Key
|
||||
PrivKey crypto.PrivKey
|
||||
ReadKey crypto.SymKey
|
||||
Content []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type InitialContent struct {
|
||||
AclHeadId string
|
||||
Identity []byte
|
||||
SigningKey signingkey.PrivKey
|
||||
PrivKey crypto.PrivKey
|
||||
SpaceId string
|
||||
Seed []byte
|
||||
ChangeType string
|
||||
ChangePayload []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
type ChangeBuilder interface {
|
||||
ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error)
|
||||
BuildContent(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
|
||||
BuildInitialContent(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
|
||||
BuildRaw(ch *Change) (*treechangeproto.RawTreeChangeWithId, error)
|
||||
SetRootRawChange(rawIdChange *treechangeproto.RawTreeChangeWithId)
|
||||
type nonVerifiableChangeBuilder struct {
|
||||
ChangeBuilder
|
||||
}
|
||||
|
||||
func (c *nonVerifiableChangeBuilder) BuildRoot(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return c.ChangeBuilder.BuildRoot(payload)
|
||||
}
|
||||
|
||||
func (c *nonVerifiableChangeBuilder) Unmarshall(rawChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
|
||||
return c.ChangeBuilder.Unmarshall(rawChange, false)
|
||||
}
|
||||
|
||||
func (c *nonVerifiableChangeBuilder) Build(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return c.ChangeBuilder.Build(payload)
|
||||
}
|
||||
|
||||
func (c *nonVerifiableChangeBuilder) Marshall(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return c.ChangeBuilder.Marshall(ch)
|
||||
}
|
||||
|
||||
type ChangeBuilder interface {
|
||||
Unmarshall(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error)
|
||||
Build(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
|
||||
BuildRoot(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error)
|
||||
Marshall(ch *Change) (*treechangeproto.RawTreeChangeWithId, error)
|
||||
}
|
||||
|
||||
type newChangeFunc = func(id string, identity crypto.PubKey, ch *treechangeproto.TreeChange, signature []byte) *Change
|
||||
|
||||
type changeBuilder struct {
|
||||
rootChange *treechangeproto.RawTreeChangeWithId
|
||||
keys *keychain.Keychain
|
||||
keys crypto.KeyStorage
|
||||
newChange newChangeFunc
|
||||
}
|
||||
|
||||
func NewChangeBuilder(keys *keychain.Keychain, rootChange *treechangeproto.RawTreeChangeWithId) ChangeBuilder {
|
||||
return &changeBuilder{keys: keys, rootChange: rootChange}
|
||||
func NewChangeBuilder(keys crypto.KeyStorage, rootChange *treechangeproto.RawTreeChangeWithId) ChangeBuilder {
|
||||
return &changeBuilder{keys: keys, rootChange: rootChange, newChange: NewChange}
|
||||
}
|
||||
|
||||
func (c *changeBuilder) ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
|
||||
func (c *changeBuilder) Unmarshall(rawIdChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
|
||||
if rawIdChange.GetRawChange() == nil {
|
||||
err = ErrEmptyChange
|
||||
return
|
||||
@ -77,15 +96,9 @@ func (c *changeBuilder) ConvertFromRaw(rawIdChange *treechangeproto.RawTreeChang
|
||||
}
|
||||
|
||||
if verify {
|
||||
var identityKey signingkey.PubKey
|
||||
identityKey, err = c.keys.GetOrAdd(ch.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// verifying signature
|
||||
var res bool
|
||||
res, err = identityKey.Verify(raw.Payload, raw.Signature)
|
||||
res, err = ch.Identity.Verify(raw.Payload, raw.Signature)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -101,43 +114,41 @@ func (c *changeBuilder) SetRootRawChange(rawIdChange *treechangeproto.RawTreeCha
|
||||
c.rootChange = rawIdChange
|
||||
}
|
||||
|
||||
func (c *changeBuilder) BuildInitialContent(payload InitialContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
func (c *changeBuilder) BuildRoot(payload InitialContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
identity, err := payload.PrivKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
change := &treechangeproto.RootChange{
|
||||
AclHeadId: payload.AclHeadId,
|
||||
Timestamp: payload.Timestamp,
|
||||
Identity: payload.Identity,
|
||||
Identity: identity,
|
||||
ChangeType: payload.ChangeType,
|
||||
ChangePayload: payload.ChangePayload,
|
||||
SpaceId: payload.SpaceId,
|
||||
Seed: payload.Seed,
|
||||
}
|
||||
|
||||
marshalledChange, err := proto.Marshal(change)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
signature, err := payload.SigningKey.Sign(marshalledChange)
|
||||
signature, err := payload.PrivKey.Sign(marshalledChange)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
raw := &treechangeproto.RawTreeChange{
|
||||
Payload: marshalledChange,
|
||||
Signature: signature,
|
||||
}
|
||||
|
||||
marshalledRawChange, err := proto.Marshal(raw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id, err := cidutil.NewCidFromBytes(marshalledRawChange)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ch = NewChangeFromRoot(id, change, signature)
|
||||
|
||||
ch = NewChangeFromRoot(id, payload.PrivKey.GetPublic(), change, signature)
|
||||
rawIdChange = &treechangeproto.RawTreeChangeWithId{
|
||||
RawChange: marshalledRawChange,
|
||||
Id: id,
|
||||
@ -145,14 +156,18 @@ func (c *changeBuilder) BuildInitialContent(payload InitialContent) (ch *Change,
|
||||
return
|
||||
}
|
||||
|
||||
func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
func (c *changeBuilder) Build(payload BuilderContent) (ch *Change, rawIdChange *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
identity, err := payload.PrivKey.GetPublic().Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
change := &treechangeproto.TreeChange{
|
||||
TreeHeadIds: payload.TreeHeadIds,
|
||||
AclHeadId: payload.AclHeadId,
|
||||
SnapshotBaseId: payload.SnapshotBaseId,
|
||||
CurrentReadKeyHash: payload.CurrentReadKeyHash,
|
||||
Timestamp: int64(time.Now().Nanosecond()),
|
||||
Identity: payload.Identity,
|
||||
ReadKeyId: payload.ReadKeyId,
|
||||
Timestamp: payload.Timestamp,
|
||||
Identity: identity,
|
||||
IsSnapshot: payload.IsSnapshot,
|
||||
}
|
||||
if payload.ReadKey != nil {
|
||||
@ -165,34 +180,27 @@ func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdC
|
||||
} else {
|
||||
change.ChangesData = payload.Content
|
||||
}
|
||||
|
||||
marshalledChange, err := proto.Marshal(change)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
signature, err := payload.SigningKey.Sign(marshalledChange)
|
||||
signature, err := payload.PrivKey.Sign(marshalledChange)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
raw := &treechangeproto.RawTreeChange{
|
||||
Payload: marshalledChange,
|
||||
Signature: signature,
|
||||
}
|
||||
|
||||
marshalledRawChange, err := proto.Marshal(raw)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
id, err := cidutil.NewCidFromBytes(marshalledRawChange)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ch = NewChange(id, change, signature)
|
||||
|
||||
ch = c.newChange(id, payload.PrivKey.GetPublic(), change, signature)
|
||||
rawIdChange = &treechangeproto.RawTreeChangeWithId{
|
||||
RawChange: marshalledRawChange,
|
||||
Id: id,
|
||||
@ -200,18 +208,22 @@ func (c *changeBuilder) BuildContent(payload BuilderContent) (ch *Change, rawIdC
|
||||
return
|
||||
}
|
||||
|
||||
func (c *changeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
if ch.Id == c.rootChange.Id {
|
||||
func (c *changeBuilder) Marshall(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
if c.isRoot(ch.Id) {
|
||||
return c.rootChange, nil
|
||||
}
|
||||
identity, err := ch.Identity.Marshall()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
treeChange := &treechangeproto.TreeChange{
|
||||
TreeHeadIds: ch.PreviousIds,
|
||||
AclHeadId: ch.AclHeadId,
|
||||
SnapshotBaseId: ch.SnapshotId,
|
||||
ChangesData: ch.Data,
|
||||
CurrentReadKeyHash: ch.ReadKeyHash,
|
||||
ReadKeyId: ch.ReadKeyId,
|
||||
Timestamp: ch.Timestamp,
|
||||
Identity: []byte(ch.Identity),
|
||||
Identity: identity,
|
||||
IsSnapshot: ch.IsSnapshot,
|
||||
}
|
||||
var marshalled []byte
|
||||
@ -236,22 +248,36 @@ func (c *changeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChange
|
||||
}
|
||||
|
||||
func (c *changeBuilder) unmarshallRawChange(raw *treechangeproto.RawTreeChange, id string) (ch *Change, err error) {
|
||||
if c.rootChange.Id == id {
|
||||
var key crypto.PubKey
|
||||
if c.isRoot(id) {
|
||||
unmarshalled := &treechangeproto.RootChange{}
|
||||
err = proto.Unmarshal(raw.Payload, unmarshalled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch = NewChangeFromRoot(id, unmarshalled, raw.Signature)
|
||||
key, err = c.keys.PubKeyFromProto(unmarshalled.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch = NewChangeFromRoot(id, key, unmarshalled, raw.Signature)
|
||||
return
|
||||
}
|
||||
|
||||
unmarshalled := &treechangeproto.TreeChange{}
|
||||
err = proto.Unmarshal(raw.Payload, unmarshalled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
ch = NewChange(id, unmarshalled, raw.Signature)
|
||||
key, err = c.keys.PubKeyFromProto(unmarshalled.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ch = c.newChange(id, key, unmarshalled, raw.Signature)
|
||||
return
|
||||
}
|
||||
|
||||
func (c *changeBuilder) isRoot(id string) bool {
|
||||
if c.rootChange != nil {
|
||||
return c.rootChange.Id == id
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
60
commonspace/object/tree/objecttree/historytree.go
Normal file
60
commonspace/object/tree/objecttree/historytree.go
Normal file
@ -0,0 +1,60 @@
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
var ErrLoadBeforeRoot = errors.New("can't load before root")
|
||||
|
||||
type HistoryTree interface {
|
||||
ReadableObjectTree
|
||||
}
|
||||
|
||||
type historyTree struct {
|
||||
*objectTree
|
||||
}
|
||||
|
||||
func (h *historyTree) rebuildFromStorage(params HistoryTreeParams) (err error) {
|
||||
err = h.rebuild(params)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
h.aclList.RLock()
|
||||
defer h.aclList.RUnlock()
|
||||
state := h.aclList.AclState()
|
||||
|
||||
return h.readKeysFromAclState(state)
|
||||
}
|
||||
|
||||
func (h *historyTree) rebuild(params HistoryTreeParams) (err error) {
|
||||
var (
|
||||
beforeId = params.BeforeId
|
||||
include = params.IncludeBeforeId
|
||||
full = params.BuildFullTree
|
||||
)
|
||||
h.treeBuilder.Reset()
|
||||
if full {
|
||||
h.tree, err = h.treeBuilder.BuildFull()
|
||||
return
|
||||
}
|
||||
if beforeId == h.Id() && !include {
|
||||
return ErrLoadBeforeRoot
|
||||
}
|
||||
|
||||
heads := []string{beforeId}
|
||||
if beforeId == "" {
|
||||
heads, err = h.treeStorage.Heads()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
} else if !include {
|
||||
beforeChange, err := h.treeBuilder.loadChange(beforeId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
heads = beforeChange.PreviousIds
|
||||
}
|
||||
|
||||
h.tree, err = h.treeBuilder.build(heads, nil, nil)
|
||||
return
|
||||
}
|
||||
@ -1,5 +1,5 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/anytypeio/any-sync/commonspace/object/tree/objecttree (interfaces: ObjectTree)
|
||||
// Source: github.com/anyproto/any-sync/commonspace/object/tree/objecttree (interfaces: ObjectTree)
|
||||
|
||||
// Package mock_objecttree is a generated GoMock package.
|
||||
package mock_objecttree
|
||||
@ -7,11 +7,13 @@ package mock_objecttree
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
objecttree "github.com/anytypeio/any-sync/commonspace/object/tree/objecttree"
|
||||
treechangeproto "github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
treestorage "github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
list "github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
|
||||
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
|
||||
// MockObjectTree is a mock of ObjectTree interface.
|
||||
@ -37,6 +39,20 @@ func (m *MockObjectTree) EXPECT() *MockObjectTreeMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// AclList mocks base method.
|
||||
func (m *MockObjectTree) AclList() list.AclList {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "AclList")
|
||||
ret0, _ := ret[0].(list.AclList)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// AclList indicates an expected call of AclList.
|
||||
func (mr *MockObjectTreeMockRecorder) AclList() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AclList", reflect.TypeOf((*MockObjectTree)(nil).AclList))
|
||||
}
|
||||
|
||||
// AddContent mocks base method.
|
||||
func (m *MockObjectTree) AddContent(arg0 context.Context, arg1 objecttree.SignableChangeContent) (objecttree.AddResult, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@ -67,6 +83,20 @@ func (mr *MockObjectTreeMockRecorder) AddRawChanges(arg0, arg1 interface{}) *gom
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawChanges", reflect.TypeOf((*MockObjectTree)(nil).AddRawChanges), arg0, arg1)
|
||||
}
|
||||
|
||||
// ChangeInfo mocks base method.
|
||||
func (m *MockObjectTree) ChangeInfo() *treechangeproto.TreeChangeInfo {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ChangeInfo")
|
||||
ret0, _ := ret[0].(*treechangeproto.TreeChangeInfo)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// ChangeInfo indicates an expected call of ChangeInfo.
|
||||
func (mr *MockObjectTreeMockRecorder) ChangeInfo() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ChangeInfo", reflect.TypeOf((*MockObjectTree)(nil).ChangeInfo))
|
||||
}
|
||||
|
||||
// ChangesAfterCommonSnapshot mocks base method.
|
||||
func (m *MockObjectTree) ChangesAfterCommonSnapshot(arg0, arg1 []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@ -96,19 +126,19 @@ func (mr *MockObjectTreeMockRecorder) Close() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectTree)(nil).Close))
|
||||
}
|
||||
|
||||
// DebugDump mocks base method.
|
||||
func (m *MockObjectTree) DebugDump(arg0 objecttree.DescriptionParser) (string, error) {
|
||||
// Debug mocks base method.
|
||||
func (m *MockObjectTree) Debug(arg0 objecttree.DescriptionParser) (objecttree.DebugInfo, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DebugDump", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret := m.ctrl.Call(m, "Debug", arg0)
|
||||
ret0, _ := ret[0].(objecttree.DebugInfo)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// DebugDump indicates an expected call of DebugDump.
|
||||
func (mr *MockObjectTreeMockRecorder) DebugDump(arg0 interface{}) *gomock.Call {
|
||||
// Debug indicates an expected call of Debug.
|
||||
func (mr *MockObjectTreeMockRecorder) Debug(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DebugDump", reflect.TypeOf((*MockObjectTree)(nil).DebugDump), arg0)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockObjectTree)(nil).Debug), arg0)
|
||||
}
|
||||
|
||||
// Delete mocks base method.
|
||||
@ -125,6 +155,21 @@ func (mr *MockObjectTreeMockRecorder) Delete() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectTree)(nil).Delete))
|
||||
}
|
||||
|
||||
// GetChange mocks base method.
|
||||
func (m *MockObjectTree) GetChange(arg0 string) (*objecttree.Change, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChange", arg0)
|
||||
ret0, _ := ret[0].(*objecttree.Change)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChange indicates an expected call of GetChange.
|
||||
func (mr *MockObjectTreeMockRecorder) GetChange(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChange", reflect.TypeOf((*MockObjectTree)(nil).GetChange), arg0)
|
||||
}
|
||||
|
||||
// HasChanges mocks base method.
|
||||
func (m *MockObjectTree) HasChanges(arg0 ...string) bool {
|
||||
m.ctrl.T.Helper()
|
||||
@ -213,6 +258,20 @@ func (mr *MockObjectTreeMockRecorder) IterateRoot(arg0, arg1 interface{}) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateRoot", reflect.TypeOf((*MockObjectTree)(nil).IterateRoot), arg0, arg1)
|
||||
}
|
||||
|
||||
// Len mocks base method.
|
||||
func (m *MockObjectTree) Len() int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Len")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Len indicates an expected call of Len.
|
||||
func (mr *MockObjectTreeMockRecorder) Len() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Len", reflect.TypeOf((*MockObjectTree)(nil).Len))
|
||||
}
|
||||
|
||||
// Lock mocks base method.
|
||||
func (m *MockObjectTree) Lock() {
|
||||
m.ctrl.T.Helper()
|
||||
@ -225,6 +284,21 @@ func (mr *MockObjectTreeMockRecorder) Lock() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockObjectTree)(nil).Lock))
|
||||
}
|
||||
|
||||
// PrepareChange mocks base method.
|
||||
func (m *MockObjectTree) PrepareChange(arg0 objecttree.SignableChangeContent) (*treechangeproto.RawTreeChangeWithId, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "PrepareChange", arg0)
|
||||
ret0, _ := ret[0].(*treechangeproto.RawTreeChangeWithId)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// PrepareChange indicates an expected call of PrepareChange.
|
||||
func (mr *MockObjectTreeMockRecorder) PrepareChange(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PrepareChange", reflect.TypeOf((*MockObjectTree)(nil).PrepareChange), arg0)
|
||||
}
|
||||
|
||||
// RLock mocks base method.
|
||||
func (m *MockObjectTree) RLock() {
|
||||
m.ctrl.T.Helper()
|
||||
@ -291,6 +365,49 @@ func (mr *MockObjectTreeMockRecorder) Storage() *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockObjectTree)(nil).Storage))
|
||||
}
|
||||
|
||||
// TryClose mocks base method.
|
||||
func (m *MockObjectTree) TryClose(arg0 time.Duration) (bool, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryClose", arg0)
|
||||
ret0, _ := ret[0].(bool)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// TryClose indicates an expected call of TryClose.
|
||||
func (mr *MockObjectTreeMockRecorder) TryClose(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockObjectTree)(nil).TryClose), arg0)
|
||||
}
|
||||
|
||||
// TryLock mocks base method.
|
||||
func (m *MockObjectTree) TryLock() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryLock")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TryLock indicates an expected call of TryLock.
|
||||
func (mr *MockObjectTreeMockRecorder) TryLock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryLock", reflect.TypeOf((*MockObjectTree)(nil).TryLock))
|
||||
}
|
||||
|
||||
// TryRLock mocks base method.
|
||||
func (m *MockObjectTree) TryRLock() bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "TryRLock")
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// TryRLock indicates an expected call of TryRLock.
|
||||
func (mr *MockObjectTreeMockRecorder) TryRLock() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryRLock", reflect.TypeOf((*MockObjectTree)(nil).TryRLock))
|
||||
}
|
||||
|
||||
// Unlock mocks base method.
|
||||
func (m *MockObjectTree) Unlock() {
|
||||
m.ctrl.T.Helper()
|
||||
@ -316,3 +433,18 @@ func (mr *MockObjectTreeMockRecorder) UnmarshalledHeader() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnmarshalledHeader", reflect.TypeOf((*MockObjectTree)(nil).UnmarshalledHeader))
|
||||
}
|
||||
|
||||
// UnpackChange mocks base method.
|
||||
func (m *MockObjectTree) UnpackChange(arg0 *treechangeproto.RawTreeChangeWithId) ([]byte, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UnpackChange", arg0)
|
||||
ret0, _ := ret[0].([]byte)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UnpackChange indicates an expected call of UnpackChange.
|
||||
func (mr *MockObjectTreeMockRecorder) UnpackChange(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UnpackChange", reflect.TypeOf((*MockObjectTree)(nil).UnpackChange), arg0)
|
||||
}
|
||||
|
||||
@ -1,28 +1,33 @@
|
||||
//go:generate mockgen -destination mock_objecttree/mock_objecttree.go github.com/anytypeio/any-sync/commonspace/object/tree/objecttree ObjectTree
|
||||
//go:generate mockgen -destination mock_objecttree/mock_objecttree.go github.com/anyproto/any-sync/commonspace/object/tree/objecttree ObjectTree
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
list2 "github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"github.com/anytypeio/any-sync/util/slice"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
)
|
||||
|
||||
type RWLocker interface {
|
||||
sync.Locker
|
||||
RLock()
|
||||
RUnlock()
|
||||
TryRLock() bool
|
||||
TryLock() bool
|
||||
}
|
||||
|
||||
var (
|
||||
ErrHasInvalidChanges = errors.New("the change is invalid")
|
||||
ErrNoCommonSnapshot = errors.New("trees doesn't have a common snapshot")
|
||||
ErrNoChangeInTree = errors.New("no such change in tree")
|
||||
ErrMissingKey = errors.New("missing current read key")
|
||||
)
|
||||
|
||||
type AddResultSummary int
|
||||
@ -43,19 +48,29 @@ type RawChangesPayload struct {
|
||||
type ChangeIterateFunc = func(change *Change) bool
|
||||
type ChangeConvertFunc = func(decrypted []byte) (any, error)
|
||||
|
||||
type ObjectTree interface {
|
||||
type ReadableObjectTree interface {
|
||||
RWLocker
|
||||
|
||||
Id() string
|
||||
Header() *treechangeproto.RawTreeChangeWithId
|
||||
UnmarshalledHeader() *Change
|
||||
ChangeInfo() *treechangeproto.TreeChangeInfo
|
||||
Heads() []string
|
||||
Root() *Change
|
||||
HasChanges(...string) bool
|
||||
DebugDump(parser DescriptionParser) (string, error)
|
||||
Len() int
|
||||
|
||||
AclList() list.AclList
|
||||
|
||||
HasChanges(...string) bool
|
||||
GetChange(string) (*Change, error)
|
||||
|
||||
Debug(parser DescriptionParser) (DebugInfo, error)
|
||||
IterateRoot(convert ChangeConvertFunc, iterate ChangeIterateFunc) error
|
||||
IterateFrom(id string, convert ChangeConvertFunc, iterate ChangeIterateFunc) error
|
||||
}
|
||||
|
||||
type ObjectTree interface {
|
||||
ReadableObjectTree
|
||||
|
||||
SnapshotPath() []string
|
||||
ChangesAfterCommonSnapshot(snapshotPath, heads []string) ([]*treechangeproto.RawTreeChangeWithId, error)
|
||||
@ -65,8 +80,12 @@ type ObjectTree interface {
|
||||
AddContent(ctx context.Context, content SignableChangeContent) (AddResult, error)
|
||||
AddRawChanges(ctx context.Context, changes RawChangesPayload) (AddResult, error)
|
||||
|
||||
UnpackChange(raw *treechangeproto.RawTreeChangeWithId) (data []byte, err error)
|
||||
PrepareChange(content SignableChangeContent) (res *treechangeproto.RawTreeChangeWithId, err error)
|
||||
|
||||
Delete() error
|
||||
Close() error
|
||||
TryClose(objectTTL time.Duration) (bool, error)
|
||||
}
|
||||
|
||||
type objectTree struct {
|
||||
@ -75,14 +94,15 @@ type objectTree struct {
|
||||
validator ObjectTreeValidator
|
||||
rawChangeLoader *rawChangeLoader
|
||||
treeBuilder *treeBuilder
|
||||
aclList list2.AclList
|
||||
aclList list.AclList
|
||||
|
||||
id string
|
||||
rawRoot *treechangeproto.RawTreeChangeWithId
|
||||
root *Change
|
||||
tree *Tree
|
||||
|
||||
keys map[uint64]*symmetric.Key
|
||||
keys map[string]crypto.SymKey
|
||||
currentReadKey crypto.SymKey
|
||||
|
||||
// buffers
|
||||
difSnapshotBuf []*treechangeproto.RawTreeChangeWithId
|
||||
@ -95,41 +115,26 @@ type objectTree struct {
|
||||
sync.RWMutex
|
||||
}
|
||||
|
||||
type objectTreeDeps struct {
|
||||
changeBuilder ChangeBuilder
|
||||
treeBuilder *treeBuilder
|
||||
treeStorage treestorage.TreeStorage
|
||||
validator ObjectTreeValidator
|
||||
rawChangeLoader *rawChangeLoader
|
||||
aclList list2.AclList
|
||||
}
|
||||
|
||||
func defaultObjectTreeDeps(
|
||||
rootChange *treechangeproto.RawTreeChangeWithId,
|
||||
treeStorage treestorage.TreeStorage,
|
||||
aclList list2.AclList) objectTreeDeps {
|
||||
|
||||
keychain := keychain.NewKeychain()
|
||||
changeBuilder := NewChangeBuilder(keychain, rootChange)
|
||||
treeBuilder := newTreeBuilder(treeStorage, changeBuilder)
|
||||
return objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: treeBuilder,
|
||||
treeStorage: treeStorage,
|
||||
validator: newTreeValidator(),
|
||||
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
|
||||
aclList: aclList,
|
||||
}
|
||||
}
|
||||
|
||||
func (ot *objectTree) rebuildFromStorage(theirHeads []string, newChanges []*Change) (err error) {
|
||||
oldTree := ot.tree
|
||||
ot.treeBuilder.Reset()
|
||||
|
||||
ot.tree, err = ot.treeBuilder.Build(theirHeads, newChanges)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// in case there are new heads
|
||||
if theirHeads != nil && oldTree != nil {
|
||||
// checking that old root is still in tree
|
||||
rootCh, rootExists := ot.tree.attached[oldTree.RootId()]
|
||||
|
||||
// checking the case where theirHeads were actually below prevHeads
|
||||
// so if we did load some extra data in the tree, let's reduce it to old root
|
||||
if slice.UnsortedEquals(oldTree.headIds, ot.tree.headIds) && rootExists && ot.tree.RootId() != oldTree.RootId() {
|
||||
ot.tree.makeRootAndRemove(rootCh)
|
||||
}
|
||||
}
|
||||
|
||||
// during building the tree we may have marked some changes as possible roots,
|
||||
// but obviously they are not roots, because of the way how we construct the tree
|
||||
ot.tree.clearPossibleRoots()
|
||||
@ -143,6 +148,14 @@ func (ot *objectTree) Id() string {
|
||||
return ot.id
|
||||
}
|
||||
|
||||
func (ot *objectTree) Len() int {
|
||||
return ot.tree.Len()
|
||||
}
|
||||
|
||||
func (ot *objectTree) AclList() list.AclList {
|
||||
return ot.aclList
|
||||
}
|
||||
|
||||
func (ot *objectTree) Header() *treechangeproto.RawTreeChangeWithId {
|
||||
return ot.rawRoot
|
||||
}
|
||||
@ -151,10 +164,21 @@ func (ot *objectTree) UnmarshalledHeader() *Change {
|
||||
return ot.root
|
||||
}
|
||||
|
||||
func (ot *objectTree) ChangeInfo() *treechangeproto.TreeChangeInfo {
|
||||
return ot.root.Model.(*treechangeproto.TreeChangeInfo)
|
||||
}
|
||||
|
||||
func (ot *objectTree) Storage() treestorage.TreeStorage {
|
||||
return ot.treeStorage
|
||||
}
|
||||
|
||||
func (ot *objectTree) GetChange(id string) (*Change, error) {
|
||||
if ch, ok := ot.tree.attached[id]; ok {
|
||||
return ch, nil
|
||||
}
|
||||
return nil, ErrNoChangeInTree
|
||||
}
|
||||
|
||||
func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeContent) (res AddResult, err error) {
|
||||
payload, err := ot.prepareBuilderContent(content)
|
||||
if err != nil {
|
||||
@ -165,7 +189,7 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
|
||||
oldHeads := make([]string, 0, len(ot.tree.Heads()))
|
||||
oldHeads = append(oldHeads, ot.tree.Heads()...)
|
||||
|
||||
objChange, rawChange, err := ot.changeBuilder.BuildContent(payload)
|
||||
objChange, rawChange, err := ot.changeBuilder.Build(payload)
|
||||
if content.IsSnapshot {
|
||||
// clearing tree, because we already saved everything in the last snapshot
|
||||
ot.tree = &Tree{}
|
||||
@ -175,7 +199,7 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = ot.treeStorage.TransactionAdd([]*treechangeproto.RawTreeChangeWithId{rawChange}, []string{objChange.Id})
|
||||
err = ot.treeStorage.AddRawChangesSetHeads([]*treechangeproto.RawTreeChangeWithId{rawChange}, []string{objChange.Id})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -196,39 +220,61 @@ func (ot *objectTree) AddContent(ctx context.Context, content SignableChangeCont
|
||||
return
|
||||
}
|
||||
|
||||
func (ot *objectTree) UnpackChange(raw *treechangeproto.RawTreeChangeWithId) (data []byte, err error) {
|
||||
unmarshalled, err := ot.changeBuilder.Unmarshall(raw, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
data = unmarshalled.Data
|
||||
return
|
||||
}
|
||||
|
||||
func (ot *objectTree) PrepareChange(content SignableChangeContent) (res *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
payload, err := ot.prepareBuilderContent(content)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, res, err = ot.changeBuilder.Build(payload)
|
||||
return
|
||||
}
|
||||
|
||||
func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt BuilderContent, err error) {
|
||||
ot.aclList.RLock()
|
||||
defer ot.aclList.RUnlock()
|
||||
|
||||
var (
|
||||
state = ot.aclList.AclState() // special method for own keys
|
||||
readKey *symmetric.Key
|
||||
readKeyHash uint64
|
||||
readKey crypto.SymKey
|
||||
pubKey = content.Key.GetPublic()
|
||||
readKeyId string
|
||||
)
|
||||
canWrite := state.HasPermission(content.Identity, aclrecordproto.AclUserPermissions_Writer) ||
|
||||
state.HasPermission(content.Identity, aclrecordproto.AclUserPermissions_Admin)
|
||||
if !canWrite {
|
||||
err = list2.ErrInsufficientPermissions
|
||||
if !state.Permissions(pubKey).CanWrite() {
|
||||
err = list.ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
|
||||
if content.IsEncrypted {
|
||||
readKeyHash = state.CurrentReadKeyHash()
|
||||
readKey, err = state.CurrentReadKey()
|
||||
if err != nil {
|
||||
readKeyId = state.CurrentReadKeyId()
|
||||
if ot.currentReadKey == nil {
|
||||
err = ErrMissingKey
|
||||
return
|
||||
}
|
||||
readKey = ot.currentReadKey
|
||||
}
|
||||
timestamp := content.Timestamp
|
||||
if timestamp <= 0 {
|
||||
timestamp = time.Now().Unix()
|
||||
}
|
||||
cnt = BuilderContent{
|
||||
TreeHeadIds: ot.tree.Heads(),
|
||||
AclHeadId: ot.aclList.Head().Id,
|
||||
SnapshotBaseId: ot.tree.RootId(),
|
||||
CurrentReadKeyHash: readKeyHash,
|
||||
Identity: content.Identity,
|
||||
ReadKeyId: readKeyId,
|
||||
IsSnapshot: content.IsSnapshot,
|
||||
SigningKey: content.Key,
|
||||
PrivKey: content.Key,
|
||||
ReadKey: readKey,
|
||||
Content: content.Data,
|
||||
Timestamp: timestamp,
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -248,7 +294,11 @@ func (ot *objectTree) AddRawChanges(ctx context.Context, changesPayload RawChang
|
||||
addResult.Mode = Rebuild
|
||||
}
|
||||
|
||||
err = ot.treeStorage.TransactionAdd(addResult.Added, addResult.Heads)
|
||||
err = ot.treeStorage.AddRawChangesSetHeads(addResult.Added, addResult.Heads)
|
||||
if err != nil {
|
||||
// rolling back all changes made to inmemory state
|
||||
ot.rebuildFromStorage(nil, nil)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -279,7 +329,7 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
|
||||
if unAttached, exists := ot.tree.unAttached[ch.Id]; exists {
|
||||
change = unAttached
|
||||
} else {
|
||||
change, err = ot.changeBuilder.ConvertFromRaw(ch, true)
|
||||
change, err = ot.changeBuilder.Unmarshall(ch, true)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -311,7 +361,7 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
|
||||
}
|
||||
|
||||
// checks if we need to go to database
|
||||
isOldSnapshot := func(ch *Change) bool {
|
||||
snapshotNotInTree := func(ch *Change) bool {
|
||||
if ch.SnapshotId == ot.tree.RootId() {
|
||||
return false
|
||||
}
|
||||
@ -326,26 +376,12 @@ func (ot *objectTree) addRawChanges(ctx context.Context, changesPayload RawChang
|
||||
|
||||
shouldRebuildFromStorage := false
|
||||
// checking if we have some changes with different snapshot and then rebuilding
|
||||
for idx, ch := range ot.newChangesBuf {
|
||||
if isOldSnapshot(ch) {
|
||||
var exists bool
|
||||
// checking if it exists in the storage, if yes, then at some point it was added to the tree
|
||||
// thus we don't need to look at this change
|
||||
exists, err = ot.treeStorage.HasChange(ctx, ch.Id)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if exists {
|
||||
// marking as nil to delete after
|
||||
ot.newChangesBuf[idx] = nil
|
||||
continue
|
||||
}
|
||||
// we haven't seen the change, and it refers to old snapshot, so we should rebuild
|
||||
for _, ch := range ot.newChangesBuf {
|
||||
if snapshotNotInTree(ch) {
|
||||
shouldRebuildFromStorage = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// discarding all previously seen changes
|
||||
ot.newChangesBuf = slice.DiscardFromSlice(ot.newChangesBuf, func(ch *Change) bool { return ch == nil })
|
||||
|
||||
if shouldRebuildFromStorage {
|
||||
err = ot.rebuildFromStorage(changesPayload.NewHeads, ot.newChangesBuf)
|
||||
@ -430,7 +466,7 @@ func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesA
|
||||
// if we got some changes that we need to convert to raw
|
||||
if _, exists := alreadyConverted[ch]; !exists {
|
||||
var raw *treechangeproto.RawTreeChangeWithId
|
||||
raw, err = ot.changeBuilder.BuildRaw(ch)
|
||||
raw, err = ot.changeBuilder.Marshall(ch)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -442,6 +478,11 @@ func (ot *objectTree) createAddResult(oldHeads []string, mode Mode, treeChangesA
|
||||
|
||||
var added []*treechangeproto.RawTreeChangeWithId
|
||||
added, err = getAddedChanges(treeChangesAdded)
|
||||
if !ot.treeBuilder.keepInMemoryData {
|
||||
for _, ch := range treeChangesAdded {
|
||||
ch.Data = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -465,13 +506,13 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
|
||||
}
|
||||
decrypt := func(c *Change) (decrypted []byte, err error) {
|
||||
// the change is not encrypted
|
||||
if c.ReadKeyHash == 0 {
|
||||
if c.ReadKeyId == "" {
|
||||
decrypted = c.Data
|
||||
return
|
||||
}
|
||||
readKey, exists := ot.keys[c.ReadKeyHash]
|
||||
readKey, exists := ot.keys[c.ReadKeyId]
|
||||
if !exists {
|
||||
err = list2.ErrNoReadKey
|
||||
err = list.ErrNoReadKey
|
||||
return
|
||||
}
|
||||
|
||||
@ -508,22 +549,8 @@ func (ot *objectTree) IterateFrom(id string, convert ChangeConvertFunc, iterate
|
||||
}
|
||||
|
||||
func (ot *objectTree) HasChanges(chs ...string) bool {
|
||||
hasChange := func(s string) bool {
|
||||
_, attachedExists := ot.tree.attached[s]
|
||||
if attachedExists {
|
||||
return attachedExists
|
||||
}
|
||||
|
||||
has, err := ot.treeStorage.HasChange(context.Background(), s)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return has
|
||||
}
|
||||
|
||||
for _, ch := range chs {
|
||||
if !hasChange(ch) {
|
||||
if _, attachedExists := ot.tree.attached[ch]; !attachedExists {
|
||||
return false
|
||||
}
|
||||
}
|
||||
@ -539,6 +566,10 @@ func (ot *objectTree) Root() *Change {
|
||||
return ot.tree.Root()
|
||||
}
|
||||
|
||||
func (ot *objectTree) TryClose(objectTTL time.Duration) (bool, error) {
|
||||
return true, ot.Close()
|
||||
}
|
||||
|
||||
func (ot *objectTree) Close() error {
|
||||
return nil
|
||||
}
|
||||
@ -586,19 +617,7 @@ func (ot *objectTree) ChangesAfterCommonSnapshot(theirPath, theirHeads []string)
|
||||
}
|
||||
}
|
||||
|
||||
if commonSnapshot == ot.tree.RootId() {
|
||||
return ot.getChangesFromTree(theirHeads)
|
||||
} else {
|
||||
return ot.getChangesFromDB(commonSnapshot, theirHeads)
|
||||
}
|
||||
}
|
||||
|
||||
func (ot *objectTree) getChangesFromTree(theirHeads []string) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return ot.rawChangeLoader.LoadFromTree(ot.tree, theirHeads)
|
||||
}
|
||||
|
||||
func (ot *objectTree) getChangesFromDB(commonSnapshot string, theirHeads []string) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return ot.rawChangeLoader.LoadFromStorage(commonSnapshot, ot.tree.headIds, theirHeads)
|
||||
return ot.rawChangeLoader.Load(commonSnapshot, ot.tree, theirHeads)
|
||||
}
|
||||
|
||||
func (ot *objectTree) snapshotPathIsActual() bool {
|
||||
@ -610,11 +629,9 @@ func (ot *objectTree) validateTree(newChanges []*Change) error {
|
||||
defer ot.aclList.RUnlock()
|
||||
state := ot.aclList.AclState()
|
||||
|
||||
// just not to take lock many times, updating the key map from aclList
|
||||
if len(ot.keys) != len(state.UserReadKeys()) {
|
||||
for key, value := range state.UserReadKeys() {
|
||||
ot.keys[key] = value
|
||||
}
|
||||
err := ot.readKeysFromAclState(state)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(newChanges) == 0 {
|
||||
return ot.validator.ValidateFullTree(ot.tree, ot.aclList)
|
||||
@ -623,6 +640,26 @@ func (ot *objectTree) validateTree(newChanges []*Change) error {
|
||||
return ot.validator.ValidateNewChanges(ot.tree, ot.aclList, newChanges)
|
||||
}
|
||||
|
||||
func (ot *objectTree) DebugDump(parser DescriptionParser) (string, error) {
|
||||
return ot.tree.Graph(parser)
|
||||
func (ot *objectTree) readKeysFromAclState(state *list.AclState) (err error) {
|
||||
// just not to take lock many times, updating the key map from aclList
|
||||
if len(ot.keys) == len(state.UserReadKeys()) {
|
||||
return nil
|
||||
}
|
||||
for key, value := range state.UserReadKeys() {
|
||||
treeKey, err := deriveTreeKey(value, ot.id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ot.keys[key] = treeKey
|
||||
}
|
||||
curKey, err := state.CurrentReadKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ot.currentReadKey, err = deriveTreeKey(curKey, ot.id)
|
||||
return err
|
||||
}
|
||||
|
||||
func (ot *objectTree) Debug(parser DescriptionParser) (DebugInfo, error) {
|
||||
return objectTreeDebug{}.debugInfo(ot, parser)
|
||||
}
|
||||
|
||||
@ -2,133 +2,74 @@ package objecttree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/testutils/acllistbuilder"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/accountdata"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/gogo/protobuf/proto"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type mockChangeCreator struct{}
|
||||
|
||||
func (c *mockChangeCreator) createRoot(id, aclId string) *treechangeproto.RawTreeChangeWithId {
|
||||
aclChange := &treechangeproto.RootChange{
|
||||
AclHeadId: aclId,
|
||||
}
|
||||
res, _ := aclChange.Marshal()
|
||||
|
||||
raw := &treechangeproto.RawTreeChange{
|
||||
Payload: res,
|
||||
Signature: nil,
|
||||
}
|
||||
rawMarshalled, _ := raw.Marshal()
|
||||
|
||||
return &treechangeproto.RawTreeChangeWithId{
|
||||
RawChange: rawMarshalled,
|
||||
Id: id,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockChangeCreator) createRaw(id, aclId, snapshotId string, isSnapshot bool, prevIds ...string) *treechangeproto.RawTreeChangeWithId {
|
||||
aclChange := &treechangeproto.TreeChange{
|
||||
TreeHeadIds: prevIds,
|
||||
AclHeadId: aclId,
|
||||
SnapshotBaseId: snapshotId,
|
||||
ChangesData: nil,
|
||||
IsSnapshot: isSnapshot,
|
||||
}
|
||||
res, _ := aclChange.Marshal()
|
||||
|
||||
raw := &treechangeproto.RawTreeChange{
|
||||
Payload: res,
|
||||
Signature: nil,
|
||||
}
|
||||
rawMarshalled, _ := raw.Marshal()
|
||||
|
||||
return &treechangeproto.RawTreeChangeWithId{
|
||||
RawChange: rawMarshalled,
|
||||
Id: id,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mockChangeCreator) createNewTreeStorage(treeId, aclHeadId string) treestorage.TreeStorage {
|
||||
root := c.createRoot(treeId, aclHeadId)
|
||||
treeStorage, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root})
|
||||
return treeStorage
|
||||
}
|
||||
|
||||
type mockChangeBuilder struct {
|
||||
originalBuilder ChangeBuilder
|
||||
}
|
||||
|
||||
func (c *mockChangeBuilder) BuildInitialContent(payload InitialContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockChangeBuilder) SetRootRawChange(rawIdChange *treechangeproto.RawTreeChangeWithId) {
|
||||
c.originalBuilder.SetRootRawChange(rawIdChange)
|
||||
}
|
||||
|
||||
func (c *mockChangeBuilder) ConvertFromRaw(rawChange *treechangeproto.RawTreeChangeWithId, verify bool) (ch *Change, err error) {
|
||||
return c.originalBuilder.ConvertFromRaw(rawChange, false)
|
||||
}
|
||||
|
||||
func (c *mockChangeBuilder) BuildContent(payload BuilderContent) (ch *Change, raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
panic("implement me")
|
||||
}
|
||||
|
||||
func (c *mockChangeBuilder) BuildRaw(ch *Change) (raw *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return c.originalBuilder.BuildRaw(ch)
|
||||
}
|
||||
|
||||
type mockChangeValidator struct{}
|
||||
|
||||
func (m *mockChangeValidator) ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockChangeValidator) ValidateFullTree(tree *Tree, aclList list.AclList) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type testTreeContext struct {
|
||||
aclList list.AclList
|
||||
treeStorage treestorage.TreeStorage
|
||||
changeBuilder *mockChangeBuilder
|
||||
changeCreator *mockChangeCreator
|
||||
changeCreator *MockChangeCreator
|
||||
objTree ObjectTree
|
||||
}
|
||||
|
||||
func prepareAclList(t *testing.T) list.AclList {
|
||||
st, err := acllistbuilder.NewListStorageWithTestName("userjoinexample.yml")
|
||||
require.NoError(t, err, "building storage should not result in error")
|
||||
|
||||
aclList, err := list.BuildAclList(st)
|
||||
func prepareAclList(t *testing.T) (list.AclList, *accountdata.AccountKeys) {
|
||||
randKeys, err := accountdata.NewRandom()
|
||||
require.NoError(t, err)
|
||||
aclList, err := list.NewTestDerivedAcl("spaceId", randKeys)
|
||||
require.NoError(t, err, "building acl list should be without error")
|
||||
|
||||
return aclList
|
||||
return aclList, randKeys
|
||||
}
|
||||
|
||||
func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext {
|
||||
changeCreator := &mockChangeCreator{}
|
||||
treeStorage := changeCreator.createNewTreeStorage("0", aclList.Head().Id)
|
||||
func prepareHistoryTreeDeps(aclList list.AclList) (*MockChangeCreator, objectTreeDeps) {
|
||||
changeCreator := NewMockChangeCreator()
|
||||
treeStorage := changeCreator.CreateNewTreeStorage("0", aclList.Head().Id)
|
||||
root, _ := treeStorage.Root()
|
||||
changeBuilder := &mockChangeBuilder{
|
||||
originalBuilder: NewChangeBuilder(nil, root),
|
||||
changeBuilder := &nonVerifiableChangeBuilder{
|
||||
ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
|
||||
}
|
||||
deps := objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: newTreeBuilder(treeStorage, changeBuilder),
|
||||
treeBuilder: newTreeBuilder(true, treeStorage, changeBuilder),
|
||||
treeStorage: treeStorage,
|
||||
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
|
||||
validator: &mockChangeValidator{},
|
||||
validator: &noOpTreeValidator{},
|
||||
aclList: aclList,
|
||||
}
|
||||
return changeCreator, deps
|
||||
}
|
||||
|
||||
// check build
|
||||
objTree, err := buildObjectTree(deps)
|
||||
func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext {
|
||||
return prepareContext(t, aclList, BuildTestableTree, nil)
|
||||
}
|
||||
|
||||
func prepareEmptyDataTreeContext(t *testing.T, aclList list.AclList, additionalChanges func(changeCreator *MockChangeCreator) RawChangesPayload) testTreeContext {
|
||||
return prepareContext(t, aclList, BuildEmptyDataTestableTree, additionalChanges)
|
||||
}
|
||||
|
||||
func prepareContext(
|
||||
t *testing.T,
|
||||
aclList list.AclList,
|
||||
objTreeBuilder BuildObjectTreeFunc,
|
||||
additionalChanges func(changeCreator *MockChangeCreator) RawChangesPayload) testTreeContext {
|
||||
changeCreator := NewMockChangeCreator()
|
||||
treeStorage := changeCreator.CreateNewTreeStorage("0", aclList.Head().Id)
|
||||
if additionalChanges != nil {
|
||||
payload := additionalChanges(changeCreator)
|
||||
err := treeStorage.AddRawChangesSetHeads(payload.RawChanges, payload.NewHeads)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
objTree, err := objTreeBuilder(treeStorage, aclList)
|
||||
require.NoError(t, err, "building tree should be without error")
|
||||
|
||||
// check tree iterate
|
||||
@ -138,18 +79,71 @@ func prepareTreeContext(t *testing.T, aclList list.AclList) testTreeContext {
|
||||
return true
|
||||
})
|
||||
require.NoError(t, err, "iterate should be without error")
|
||||
if additionalChanges == nil {
|
||||
assert.Equal(t, []string{"0"}, iterChangesId)
|
||||
}
|
||||
return testTreeContext{
|
||||
aclList: aclList,
|
||||
treeStorage: treeStorage,
|
||||
changeBuilder: changeBuilder,
|
||||
changeCreator: changeCreator,
|
||||
objTree: objTree,
|
||||
}
|
||||
}
|
||||
|
||||
func TestObjectTree(t *testing.T) {
|
||||
aclList := prepareAclList(t)
|
||||
aclList, keys := prepareAclList(t)
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("add content", func(t *testing.T) {
|
||||
root, err := CreateObjectTreeRoot(ObjectTreeCreatePayload{
|
||||
PrivKey: keys.SignKey,
|
||||
ChangeType: "changeType",
|
||||
ChangePayload: nil,
|
||||
SpaceId: "spaceId",
|
||||
IsEncrypted: true,
|
||||
}, aclList)
|
||||
require.NoError(t, err)
|
||||
store, _ := treestorage.NewInMemoryTreeStorage(root, []string{root.Id}, []*treechangeproto.RawTreeChangeWithId{root})
|
||||
oTree, err := BuildObjectTree(store, aclList)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("0 timestamp is changed to current", func(t *testing.T) {
|
||||
start := time.Now()
|
||||
res, err := oTree.AddContent(ctx, SignableChangeContent{
|
||||
Data: []byte("some"),
|
||||
Key: keys.SignKey,
|
||||
IsSnapshot: false,
|
||||
IsEncrypted: true,
|
||||
Timestamp: 0,
|
||||
})
|
||||
end := time.Now()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, oTree.Heads(), 1)
|
||||
require.Equal(t, res.Added[0].Id, oTree.Heads()[0])
|
||||
ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, start.Unix(), ch.Timestamp)
|
||||
require.LessOrEqual(t, end.Unix(), ch.Timestamp)
|
||||
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
|
||||
})
|
||||
t.Run("timestamp is set correctly", func(t *testing.T) {
|
||||
someTs := time.Now().Add(time.Hour).Unix()
|
||||
res, err := oTree.AddContent(ctx, SignableChangeContent{
|
||||
Data: []byte("some"),
|
||||
Key: keys.SignKey,
|
||||
IsSnapshot: false,
|
||||
IsEncrypted: true,
|
||||
Timestamp: someTs,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, oTree.Heads(), 1)
|
||||
require.Equal(t, res.Added[0].Id, oTree.Heads()[0])
|
||||
ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, ch.Timestamp, someTs)
|
||||
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("add simple", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
@ -158,8 +152,8 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -203,7 +197,7 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("0", aclList.Head().Id, "", true, ""),
|
||||
changeCreator.CreateRaw("0", aclList.Head().Id, "", true, ""),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -227,7 +221,33 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
RawChanges: rawChanges,
|
||||
}
|
||||
res, err := objTree.AddRawChanges(context.Background(), payload)
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
|
||||
// check result
|
||||
assert.Equal(t, []string{"0"}, res.OldHeads)
|
||||
assert.Equal(t, []string{"0"}, res.Heads)
|
||||
assert.Equal(t, 0, len(res.Added))
|
||||
assert.Equal(t, Nothing, res.Mode)
|
||||
|
||||
// check tree heads
|
||||
assert.Equal(t, []string{"0"}, objTree.Heads())
|
||||
})
|
||||
|
||||
t.Run("add not connected changes", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
|
||||
// this change could in theory replace current snapshot, we should prevent that
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", true, "1"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -253,10 +273,10 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.createRaw("4", aclList.Head().Id, "3", false, "3"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "3", false, "3"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -303,9 +323,9 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -321,18 +341,195 @@ func TestObjectTree(t *testing.T) {
|
||||
assert.Equal(t, true, objTree.(*objectTree).snapshotPathIsActual())
|
||||
})
|
||||
|
||||
t.Run("test empty data tree", func(t *testing.T) {
|
||||
t.Run("empty tree add", func(t *testing.T) {
|
||||
ctx := prepareEmptyDataTreeContext(t, aclList, nil)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChangesFirst := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRawWithData("1", aclList.Head().Id, "0", false, []byte("1"), "0"),
|
||||
changeCreator.CreateRawWithData("2", aclList.Head().Id, "0", false, []byte("2"), "1"),
|
||||
changeCreator.CreateRawWithData("3", aclList.Head().Id, "0", false, []byte("3"), "2"),
|
||||
}
|
||||
rawChangesSecond := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRawWithData("4", aclList.Head().Id, "0", false, []byte("4"), "2"),
|
||||
changeCreator.CreateRawWithData("5", aclList.Head().Id, "0", false, []byte("5"), "1"),
|
||||
changeCreator.CreateRawWithData("6", aclList.Head().Id, "0", false, []byte("6"), "3", "4", "5"),
|
||||
}
|
||||
|
||||
// making them to be saved in unattached
|
||||
_, err := objTree.AddRawChanges(context.Background(), RawChangesPayload{
|
||||
NewHeads: []string{"6"},
|
||||
RawChanges: rawChangesSecond,
|
||||
})
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
// attaching them
|
||||
res, err := objTree.AddRawChanges(context.Background(), RawChangesPayload{
|
||||
NewHeads: []string{"3"},
|
||||
RawChanges: rawChangesFirst,
|
||||
})
|
||||
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
require.Equal(t, "0", objTree.Root().Id)
|
||||
require.Equal(t, []string{"6"}, objTree.Heads())
|
||||
require.Equal(t, 6, len(res.Added))
|
||||
|
||||
// checking that added changes still have data
|
||||
for _, ch := range res.Added {
|
||||
unmarshallRaw := &treechangeproto.RawTreeChange{}
|
||||
proto.Unmarshal(ch.RawChange, unmarshallRaw)
|
||||
treeCh := &treechangeproto.TreeChange{}
|
||||
proto.Unmarshal(unmarshallRaw.Payload, treeCh)
|
||||
require.Equal(t, ch.Id, string(treeCh.ChangesData))
|
||||
}
|
||||
|
||||
// checking that the tree doesn't have data in memory
|
||||
err = objTree.IterateRoot(nil, func(change *Change) bool {
|
||||
if change.Id == "0" {
|
||||
return true
|
||||
}
|
||||
require.Nil(t, change.Data)
|
||||
return true
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("empty tree load", func(t *testing.T) {
|
||||
ctx := prepareEmptyDataTreeContext(t, aclList, func(changeCreator *MockChangeCreator) RawChangesPayload {
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRawWithData("1", aclList.Head().Id, "0", false, []byte("1"), "0"),
|
||||
changeCreator.CreateRawWithData("2", aclList.Head().Id, "0", false, []byte("2"), "1"),
|
||||
changeCreator.CreateRawWithData("3", aclList.Head().Id, "0", false, []byte("3"), "2"),
|
||||
changeCreator.CreateRawWithData("4", aclList.Head().Id, "0", false, []byte("4"), "2"),
|
||||
changeCreator.CreateRawWithData("5", aclList.Head().Id, "0", false, []byte("5"), "1"),
|
||||
changeCreator.CreateRawWithData("6", aclList.Head().Id, "0", false, []byte("6"), "3", "4", "5"),
|
||||
}
|
||||
return RawChangesPayload{NewHeads: []string{"6"}, RawChanges: rawChanges}
|
||||
})
|
||||
ctx.objTree.IterateRoot(nil, func(change *Change) bool {
|
||||
if change.Id == "0" {
|
||||
return true
|
||||
}
|
||||
require.Nil(t, change.Data)
|
||||
return true
|
||||
})
|
||||
rawChanges, err := ctx.objTree.ChangesAfterCommonSnapshot([]string{"0"}, []string{"6"})
|
||||
require.NoError(t, err)
|
||||
for _, ch := range rawChanges {
|
||||
unmarshallRaw := &treechangeproto.RawTreeChange{}
|
||||
proto.Unmarshal(ch.RawChange, unmarshallRaw)
|
||||
treeCh := &treechangeproto.TreeChange{}
|
||||
proto.Unmarshal(unmarshallRaw.Payload, treeCh)
|
||||
require.Equal(t, ch.Id, string(treeCh.ChangesData))
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("rollback when add to storage returns error", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
store := ctx.treeStorage.(*treestorage.InMemoryTreeStorage)
|
||||
addErr := fmt.Errorf("error saving")
|
||||
store.SetReturnErrorOnAdd(addErr)
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{"1"},
|
||||
RawChanges: rawChanges,
|
||||
}
|
||||
_, err := objTree.AddRawChanges(context.Background(), payload)
|
||||
require.Error(t, err, addErr)
|
||||
require.Equal(t, "0", objTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("their heads before common snapshot", func(t *testing.T) {
|
||||
// checking that adding old changes did not affect the tree
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "1", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "1", true, "3", "4", "5"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
RawChanges: rawChanges,
|
||||
}
|
||||
_, err := objTree.AddRawChanges(context.Background(), payload)
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
require.Equal(t, "6", objTree.Root().Id)
|
||||
|
||||
rawChangesPrevious := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
|
||||
}
|
||||
payload = RawChangesPayload{
|
||||
NewHeads: []string{"1"},
|
||||
RawChanges: rawChangesPrevious,
|
||||
}
|
||||
_, err = objTree.AddRawChanges(context.Background(), payload)
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
require.Equal(t, "6", objTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("stored changes will not break the pipeline if heads were not updated", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
store := ctx.treeStorage.(*treestorage.InMemoryTreeStorage)
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
RawChanges: rawChanges,
|
||||
}
|
||||
_, err := objTree.AddRawChanges(context.Background(), payload)
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
require.Equal(t, "1", objTree.Root().Id)
|
||||
|
||||
// creating changes to save in the storage
|
||||
// to imitate the condition where all changes are in the storage
|
||||
// but the head was not updated
|
||||
storageChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "1", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "1", true, "3", "4", "5"),
|
||||
}
|
||||
store.AddRawChangesSetHeads(storageChanges, []string{"1"})
|
||||
|
||||
// updating with subset of those changes to see that everything will still work
|
||||
payload = RawChangesPayload{
|
||||
NewHeads: []string{"6"},
|
||||
RawChanges: storageChanges,
|
||||
}
|
||||
_, err = objTree.AddRawChanges(context.Background(), payload)
|
||||
require.NoError(t, err, "adding changes should be without error")
|
||||
require.Equal(t, "6", objTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("changes from tree after common snapshot complex", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
}
|
||||
|
||||
payload := RawChangesPayload{
|
||||
@ -406,13 +603,13 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
// main difference from tree example
|
||||
changeCreator.createRaw("6", aclList.Head().Id, "0", true, "3", "4", "5"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "0", true, "3", "4", "5"),
|
||||
}
|
||||
|
||||
payload := RawChangesPayload{
|
||||
@ -487,9 +684,9 @@ func TestObjectTree(t *testing.T) {
|
||||
objTree := ctx.objTree
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.createRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
}
|
||||
payload := RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -501,9 +698,9 @@ func TestObjectTree(t *testing.T) {
|
||||
require.Equal(t, "3", objTree.Root().Id)
|
||||
|
||||
rawChanges = []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.createRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.createRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.createRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
}
|
||||
payload = RawChangesPayload{
|
||||
NewHeads: []string{rawChanges[len(rawChanges)-1].Id},
|
||||
@ -542,4 +739,154 @@ func TestObjectTree(t *testing.T) {
|
||||
assert.Equal(t, ch, raw, "the changes in the storage should be the same")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("test history tree not include", func(t *testing.T) {
|
||||
changeCreator, deps := prepareHistoryTreeDeps(aclList)
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
}
|
||||
deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
|
||||
hTree, err := buildHistoryTree(deps, HistoryTreeParams{
|
||||
BeforeId: "6",
|
||||
IncludeBeforeId: false,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// check tree heads
|
||||
assert.Equal(t, []string{"3", "4", "5"}, hTree.Heads())
|
||||
|
||||
// check tree iterate
|
||||
var iterChangesId []string
|
||||
err = hTree.IterateFrom(hTree.Root().Id, nil, func(change *Change) bool {
|
||||
iterChangesId = append(iterChangesId, change.Id)
|
||||
return true
|
||||
})
|
||||
require.NoError(t, err, "iterate should be without error")
|
||||
assert.Equal(t, []string{"0", "1", "2", "3", "4", "5"}, iterChangesId)
|
||||
assert.Equal(t, "0", hTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("test history tree build full", func(t *testing.T) {
|
||||
changeCreator, deps := prepareHistoryTreeDeps(aclList)
|
||||
|
||||
// sequence of snapshots: 5->1->0
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", true, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "1", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "1", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "1", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "1", true, "3", "4"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "5", false, "5"),
|
||||
}
|
||||
deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
|
||||
hTree, err := buildHistoryTree(deps, HistoryTreeParams{
|
||||
BuildFullTree: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// check tree heads
|
||||
assert.Equal(t, []string{"6"}, hTree.Heads())
|
||||
|
||||
// check tree iterate
|
||||
var iterChangesId []string
|
||||
err = hTree.IterateFrom(hTree.Root().Id, nil, func(change *Change) bool {
|
||||
iterChangesId = append(iterChangesId, change.Id)
|
||||
return true
|
||||
})
|
||||
require.NoError(t, err, "iterate should be without error")
|
||||
assert.Equal(t, []string{"0", "1", "2", "3", "4", "5", "6"}, iterChangesId)
|
||||
assert.Equal(t, "0", hTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("test history tree include", func(t *testing.T) {
|
||||
changeCreator, deps := prepareHistoryTreeDeps(aclList)
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
changeCreator.CreateRaw("4", aclList.Head().Id, "0", false, "2"),
|
||||
changeCreator.CreateRaw("5", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("6", aclList.Head().Id, "0", false, "3", "4", "5"),
|
||||
}
|
||||
deps.treeStorage.AddRawChangesSetHeads(rawChanges, []string{"6"})
|
||||
hTree, err := buildHistoryTree(deps, HistoryTreeParams{
|
||||
BeforeId: "6",
|
||||
IncludeBeforeId: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// check tree heads
|
||||
assert.Equal(t, []string{"6"}, hTree.Heads())
|
||||
|
||||
// check tree iterate
|
||||
var iterChangesId []string
|
||||
err = hTree.IterateFrom(hTree.Root().Id, nil, func(change *Change) bool {
|
||||
iterChangesId = append(iterChangesId, change.Id)
|
||||
return true
|
||||
})
|
||||
require.NoError(t, err, "iterate should be without error")
|
||||
assert.Equal(t, []string{"0", "1", "2", "3", "4", "5", "6"}, iterChangesId)
|
||||
assert.Equal(t, "0", hTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("test history tree root", func(t *testing.T) {
|
||||
_, deps := prepareHistoryTreeDeps(aclList)
|
||||
hTree, err := buildHistoryTree(deps, HistoryTreeParams{
|
||||
BeforeId: "0",
|
||||
IncludeBeforeId: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// check tree heads
|
||||
assert.Equal(t, []string{"0"}, hTree.Heads())
|
||||
|
||||
// check tree iterate
|
||||
var iterChangesId []string
|
||||
err = hTree.IterateFrom(hTree.Root().Id, nil, func(change *Change) bool {
|
||||
iterChangesId = append(iterChangesId, change.Id)
|
||||
return true
|
||||
})
|
||||
require.NoError(t, err, "iterate should be without error")
|
||||
assert.Equal(t, []string{"0"}, iterChangesId)
|
||||
assert.Equal(t, "0", hTree.Root().Id)
|
||||
})
|
||||
|
||||
t.Run("validate correct tree", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
ctx.objTree.Header(),
|
||||
changeCreator.CreateRaw("1", aclList.Head().Id, "0", false, "0"),
|
||||
changeCreator.CreateRaw("2", aclList.Head().Id, "0", false, "1"),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
}
|
||||
defaultObjectTreeDeps = nonVerifiableTreeDeps
|
||||
err := ValidateRawTree(treestorage.TreeStorageCreatePayload{
|
||||
RootRawChange: ctx.objTree.Header(),
|
||||
Heads: []string{"3"},
|
||||
Changes: rawChanges,
|
||||
}, ctx.aclList)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("fail to validate not connected tree", func(t *testing.T) {
|
||||
ctx := prepareTreeContext(t, aclList)
|
||||
changeCreator := ctx.changeCreator
|
||||
|
||||
rawChanges := []*treechangeproto.RawTreeChangeWithId{
|
||||
ctx.objTree.Header(),
|
||||
changeCreator.CreateRaw("3", aclList.Head().Id, "0", true, "2"),
|
||||
}
|
||||
defaultObjectTreeDeps = nonVerifiableTreeDeps
|
||||
err := ValidateRawTree(treestorage.TreeStorageCreatePayload{
|
||||
RootRawChange: ctx.objTree.Header(),
|
||||
Heads: []string{"3"},
|
||||
Changes: rawChanges,
|
||||
}, ctx.aclList)
|
||||
require.Equal(t, ErrHasInvalidChanges, err)
|
||||
})
|
||||
}
|
||||
|
||||
25
commonspace/object/tree/objecttree/objecttreedebug.go
Normal file
25
commonspace/object/tree/objecttree/objecttreedebug.go
Normal file
@ -0,0 +1,25 @@
|
||||
package objecttree
|
||||
|
||||
type objectTreeDebug struct {
|
||||
}
|
||||
|
||||
type DebugInfo struct {
|
||||
TreeLen int
|
||||
TreeString string
|
||||
Graphviz string
|
||||
Heads []string
|
||||
SnapshotPath []string
|
||||
}
|
||||
|
||||
func (o objectTreeDebug) debugInfo(ot *objectTree, parser DescriptionParser) (di DebugInfo, err error) {
|
||||
di = DebugInfo{}
|
||||
di.Graphviz, err = ot.tree.Graph(parser)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
di.TreeString = ot.tree.String()
|
||||
di.TreeLen = ot.tree.Len()
|
||||
di.Heads = ot.Heads()
|
||||
di.SnapshotPath = ot.SnapshotPath()
|
||||
return
|
||||
}
|
||||
@ -1,35 +1,132 @@
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/keychain"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anytypeio/any-sync/util/keys/symmetric"
|
||||
"math/rand"
|
||||
"time"
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
type ObjectTreeCreatePayload struct {
|
||||
SignKey signingkey.PrivKey
|
||||
PrivKey crypto.PrivKey
|
||||
ChangeType string
|
||||
ChangePayload []byte
|
||||
SpaceId string
|
||||
Identity []byte
|
||||
IsEncrypted bool
|
||||
Seed []byte
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
func CreateObjectTreeRoot(payload ObjectTreeCreatePayload, aclList list.AclList) (root *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
bytes := make([]byte, 32)
|
||||
_, err = rand.Read(bytes)
|
||||
if err != nil {
|
||||
return
|
||||
type HistoryTreeParams struct {
|
||||
TreeStorage treestorage.TreeStorage
|
||||
AclList list.AclList
|
||||
BeforeId string
|
||||
IncludeBeforeId bool
|
||||
BuildFullTree bool
|
||||
}
|
||||
|
||||
type objectTreeDeps struct {
|
||||
changeBuilder ChangeBuilder
|
||||
treeBuilder *treeBuilder
|
||||
treeStorage treestorage.TreeStorage
|
||||
validator ObjectTreeValidator
|
||||
rawChangeLoader *rawChangeLoader
|
||||
aclList list.AclList
|
||||
}
|
||||
|
||||
type BuildObjectTreeFunc = func(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error)
|
||||
|
||||
var defaultObjectTreeDeps = verifiableTreeDeps
|
||||
|
||||
func verifiableTreeDeps(
|
||||
rootChange *treechangeproto.RawTreeChangeWithId,
|
||||
treeStorage treestorage.TreeStorage,
|
||||
aclList list.AclList) objectTreeDeps {
|
||||
changeBuilder := NewChangeBuilder(crypto.NewKeyStorage(), rootChange)
|
||||
treeBuilder := newTreeBuilder(true, treeStorage, changeBuilder)
|
||||
return objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: treeBuilder,
|
||||
treeStorage: treeStorage,
|
||||
validator: newTreeValidator(),
|
||||
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
|
||||
aclList: aclList,
|
||||
}
|
||||
return createObjectTreeRoot(payload, time.Now().UnixNano(), bytes, aclList)
|
||||
}
|
||||
|
||||
func DeriveObjectTreeRoot(payload ObjectTreeCreatePayload, aclList list.AclList) (root *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
return createObjectTreeRoot(payload, 0, nil, aclList)
|
||||
func emptyDataTreeDeps(
|
||||
rootChange *treechangeproto.RawTreeChangeWithId,
|
||||
treeStorage treestorage.TreeStorage,
|
||||
aclList list.AclList) objectTreeDeps {
|
||||
changeBuilder := NewChangeBuilder(crypto.NewKeyStorage(), rootChange)
|
||||
treeBuilder := newTreeBuilder(false, treeStorage, changeBuilder)
|
||||
return objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: treeBuilder,
|
||||
treeStorage: treeStorage,
|
||||
validator: newTreeValidator(),
|
||||
rawChangeLoader: newStorageLoader(treeStorage, changeBuilder),
|
||||
aclList: aclList,
|
||||
}
|
||||
}
|
||||
|
||||
func nonVerifiableTreeDeps(
|
||||
rootChange *treechangeproto.RawTreeChangeWithId,
|
||||
treeStorage treestorage.TreeStorage,
|
||||
aclList list.AclList) objectTreeDeps {
|
||||
changeBuilder := &nonVerifiableChangeBuilder{NewChangeBuilder(newMockKeyStorage(), rootChange)}
|
||||
treeBuilder := newTreeBuilder(true, treeStorage, changeBuilder)
|
||||
return objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: treeBuilder,
|
||||
treeStorage: treeStorage,
|
||||
validator: &noOpTreeValidator{},
|
||||
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
|
||||
aclList: aclList,
|
||||
}
|
||||
}
|
||||
|
||||
func BuildEmptyDataObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
|
||||
rootChange, err := treeStorage.Root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deps := emptyDataTreeDeps(rootChange, treeStorage, aclList)
|
||||
return buildObjectTree(deps)
|
||||
}
|
||||
|
||||
func BuildTestableTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
|
||||
root, _ := treeStorage.Root()
|
||||
changeBuilder := &nonVerifiableChangeBuilder{
|
||||
ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
|
||||
}
|
||||
deps := objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: newTreeBuilder(true, treeStorage, changeBuilder),
|
||||
treeStorage: treeStorage,
|
||||
rawChangeLoader: newRawChangeLoader(treeStorage, changeBuilder),
|
||||
validator: &noOpTreeValidator{},
|
||||
aclList: aclList,
|
||||
}
|
||||
|
||||
return buildObjectTree(deps)
|
||||
}
|
||||
|
||||
func BuildEmptyDataTestableTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
|
||||
root, _ := treeStorage.Root()
|
||||
changeBuilder := &nonVerifiableChangeBuilder{
|
||||
ChangeBuilder: NewChangeBuilder(newMockKeyStorage(), root),
|
||||
}
|
||||
deps := objectTreeDeps{
|
||||
changeBuilder: changeBuilder,
|
||||
treeBuilder: newTreeBuilder(false, treeStorage, changeBuilder),
|
||||
treeStorage: treeStorage,
|
||||
rawChangeLoader: newStorageLoader(treeStorage, changeBuilder),
|
||||
validator: &noOpTreeValidator{},
|
||||
aclList: aclList,
|
||||
}
|
||||
|
||||
return buildObjectTree(deps)
|
||||
}
|
||||
|
||||
func BuildObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList) (ObjectTree, error) {
|
||||
@ -41,54 +138,25 @@ func BuildObjectTree(treeStorage treestorage.TreeStorage, aclList list.AclList)
|
||||
return buildObjectTree(deps)
|
||||
}
|
||||
|
||||
func CreateDerivedObjectTree(
|
||||
payload ObjectTreeCreatePayload,
|
||||
aclList list.AclList,
|
||||
createStorage treestorage.TreeStorageCreatorFunc) (objTree ObjectTree, err error) {
|
||||
return createObjectTree(payload, 0, nil, aclList, createStorage)
|
||||
func BuildNonVerifiableHistoryTree(params HistoryTreeParams) (HistoryTree, error) {
|
||||
rootChange, err := params.TreeStorage.Root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deps := nonVerifiableTreeDeps(rootChange, params.TreeStorage, params.AclList)
|
||||
return buildHistoryTree(deps, params)
|
||||
}
|
||||
|
||||
func CreateObjectTree(
|
||||
payload ObjectTreeCreatePayload,
|
||||
aclList list.AclList,
|
||||
createStorage treestorage.TreeStorageCreatorFunc) (objTree ObjectTree, err error) {
|
||||
bytes := make([]byte, 32)
|
||||
_, err = rand.Read(bytes)
|
||||
func BuildHistoryTree(params HistoryTreeParams) (HistoryTree, error) {
|
||||
rootChange, err := params.TreeStorage.Root()
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
return createObjectTree(payload, time.Now().UnixNano(), bytes, aclList, createStorage)
|
||||
deps := defaultObjectTreeDeps(rootChange, params.TreeStorage, params.AclList)
|
||||
return buildHistoryTree(deps, params)
|
||||
}
|
||||
|
||||
func createObjectTree(
|
||||
payload ObjectTreeCreatePayload,
|
||||
timestamp int64,
|
||||
seed []byte,
|
||||
aclList list.AclList,
|
||||
createStorage treestorage.TreeStorageCreatorFunc) (objTree ObjectTree, err error) {
|
||||
raw, err := createObjectTreeRoot(payload, timestamp, seed, aclList)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// create storage
|
||||
st, err := createStorage(treestorage.TreeStorageCreatePayload{
|
||||
RootRawChange: raw,
|
||||
Changes: []*treechangeproto.RawTreeChangeWithId{raw},
|
||||
Heads: []string{raw.Id},
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return BuildObjectTree(st, aclList)
|
||||
}
|
||||
|
||||
func createObjectTreeRoot(
|
||||
payload ObjectTreeCreatePayload,
|
||||
timestamp int64,
|
||||
seed []byte,
|
||||
aclList list.AclList) (root *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
func CreateObjectTreeRoot(payload ObjectTreeCreatePayload, aclList list.AclList) (root *treechangeproto.RawTreeChangeWithId, err error) {
|
||||
aclList.RLock()
|
||||
aclHeadId := aclList.Head().Id
|
||||
aclList.RUnlock()
|
||||
@ -98,28 +166,28 @@ func createObjectTreeRoot(
|
||||
}
|
||||
cnt := InitialContent{
|
||||
AclHeadId: aclHeadId,
|
||||
Identity: payload.Identity,
|
||||
SigningKey: payload.SignKey,
|
||||
PrivKey: payload.PrivKey,
|
||||
SpaceId: payload.SpaceId,
|
||||
ChangeType: payload.ChangeType,
|
||||
Timestamp: timestamp,
|
||||
Seed: seed,
|
||||
ChangePayload: payload.ChangePayload,
|
||||
Timestamp: payload.Timestamp,
|
||||
Seed: payload.Seed,
|
||||
}
|
||||
|
||||
_, root, err = NewChangeBuilder(keychain.NewKeychain(), nil).BuildInitialContent(cnt)
|
||||
_, root, err = NewChangeBuilder(crypto.NewKeyStorage(), nil).BuildRoot(cnt)
|
||||
return
|
||||
}
|
||||
|
||||
func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) {
|
||||
objTree := &objectTree{
|
||||
id: deps.treeStorage.Id(),
|
||||
treeStorage: deps.treeStorage,
|
||||
treeBuilder: deps.treeBuilder,
|
||||
validator: deps.validator,
|
||||
aclList: deps.aclList,
|
||||
changeBuilder: deps.changeBuilder,
|
||||
rawChangeLoader: deps.rawChangeLoader,
|
||||
tree: nil,
|
||||
keys: make(map[uint64]*symmetric.Key),
|
||||
keys: make(map[string]crypto.SymKey),
|
||||
newChangesBuf: make([]*Change, 0, 10),
|
||||
difSnapshotBuf: make([]*treechangeproto.RawTreeChangeWithId, 0, 10),
|
||||
notSeenIdxBuf: make([]int, 0, 10),
|
||||
@ -131,14 +199,13 @@ func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
objTree.id = objTree.treeStorage.Id()
|
||||
objTree.rawRoot, err = objTree.treeStorage.Root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// verifying root
|
||||
header, err := objTree.changeBuilder.ConvertFromRaw(objTree.rawRoot, true)
|
||||
header, err := objTree.changeBuilder.Unmarshall(objTree.rawRoot, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -146,3 +213,38 @@ func buildObjectTree(deps objectTreeDeps) (ObjectTree, error) {
|
||||
|
||||
return objTree, nil
|
||||
}
|
||||
|
||||
func buildHistoryTree(deps objectTreeDeps, params HistoryTreeParams) (ht HistoryTree, err error) {
|
||||
objTree := &objectTree{
|
||||
id: deps.treeStorage.Id(),
|
||||
treeStorage: deps.treeStorage,
|
||||
treeBuilder: deps.treeBuilder,
|
||||
validator: deps.validator,
|
||||
aclList: deps.aclList,
|
||||
changeBuilder: deps.changeBuilder,
|
||||
rawChangeLoader: deps.rawChangeLoader,
|
||||
keys: make(map[string]crypto.SymKey),
|
||||
newChangesBuf: make([]*Change, 0, 10),
|
||||
difSnapshotBuf: make([]*treechangeproto.RawTreeChangeWithId, 0, 10),
|
||||
notSeenIdxBuf: make([]int, 0, 10),
|
||||
newSnapshotsBuf: make([]*Change, 0, 10),
|
||||
}
|
||||
|
||||
hTree := &historyTree{objectTree: objTree}
|
||||
err = hTree.rebuildFromStorage(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objTree.id = objTree.treeStorage.Id()
|
||||
objTree.rawRoot, err = objTree.treeStorage.Root()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header, err := objTree.changeBuilder.Unmarshall(objTree.rawRoot, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objTree.root = header
|
||||
return hTree, nil
|
||||
}
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
)
|
||||
|
||||
type ObjectTreeValidator interface {
|
||||
@ -13,6 +16,16 @@ type ObjectTreeValidator interface {
|
||||
ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error
|
||||
}
|
||||
|
||||
type noOpTreeValidator struct{}
|
||||
|
||||
func (n *noOpTreeValidator) ValidateFullTree(tree *Tree, aclList list.AclList) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n *noOpTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclList, newChanges []*Change) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type objectTreeValidator struct{}
|
||||
|
||||
func newTreeValidator() ObjectTreeValidator {
|
||||
@ -39,20 +52,18 @@ func (v *objectTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclLis
|
||||
|
||||
func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c *Change) (err error) {
|
||||
var (
|
||||
perm list.UserPermissionPair
|
||||
userState list.AclUserState
|
||||
state = aclList.AclState()
|
||||
)
|
||||
// checking if the user could write
|
||||
perm, err = state.PermissionsAtRecord(c.AclHeadId, c.Identity)
|
||||
userState, err = state.StateAtRecord(c.AclHeadId, c.Identity)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if perm.Permission != aclrecordproto.AclUserPermissions_Writer && perm.Permission != aclrecordproto.AclUserPermissions_Admin {
|
||||
if !userState.Permissions.CanWrite() {
|
||||
err = list.ErrInsufficientPermissions
|
||||
return
|
||||
}
|
||||
|
||||
if c.Id == tree.RootId() {
|
||||
return
|
||||
}
|
||||
@ -75,3 +86,25 @@ func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func ValidateRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (err error) {
|
||||
treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, []string{payload.RootRawChange.Id}, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tree, err := BuildObjectTree(treeStorage, aclList)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
res, err := tree.AddRawChanges(context.Background(), RawChangesPayload{
|
||||
NewHeads: payload.Heads,
|
||||
RawChanges: payload.Changes,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !slice.UnsortedEquals(res.Heads, payload.Heads) {
|
||||
return ErrHasInvalidChanges
|
||||
}
|
||||
return
|
||||
}
|
||||
@ -2,15 +2,17 @@ package objecttree
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anytypeio/any-sync/util/slice"
|
||||
"time"
|
||||
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
|
||||
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
|
||||
"github.com/anyproto/any-sync/util/slice"
|
||||
)
|
||||
|
||||
type rawChangeLoader struct {
|
||||
treeStorage treestorage.TreeStorage
|
||||
changeBuilder ChangeBuilder
|
||||
alwaysFromStorage bool
|
||||
|
||||
// buffers
|
||||
idStack []string
|
||||
@ -21,6 +23,13 @@ type rawCacheEntry struct {
|
||||
change *Change
|
||||
rawChange *treechangeproto.RawTreeChangeWithId
|
||||
position int
|
||||
removed bool
|
||||
}
|
||||
|
||||
func newStorageLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader {
|
||||
loader := newRawChangeLoader(treeStorage, changeBuilder)
|
||||
loader.alwaysFromStorage = true
|
||||
return loader
|
||||
}
|
||||
|
||||
func newRawChangeLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader {
|
||||
@ -30,7 +39,15 @@ func newRawChangeLoader(treeStorage treestorage.TreeStorage, changeBuilder Chang
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
func (r *rawChangeLoader) Load(commonSnapshot string, t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
if commonSnapshot == t.root.Id && !r.alwaysFromStorage {
|
||||
return r.loadFromTree(t, breakpoints)
|
||||
} else {
|
||||
return r.loadFromStorage(commonSnapshot, t.Heads(), breakpoints)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rawChangeLoader) loadFromTree(t *Tree, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
var stack []*Change
|
||||
for _, h := range t.headIds {
|
||||
stack = append(stack, t.attached[h])
|
||||
@ -39,7 +56,7 @@ func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treech
|
||||
convert := func(chs []*Change) (rawChanges []*treechangeproto.RawTreeChangeWithId, err error) {
|
||||
for _, ch := range chs {
|
||||
var raw *treechangeproto.RawTreeChangeWithId
|
||||
raw, err = r.changeBuilder.BuildRaw(ch)
|
||||
raw, err = r.changeBuilder.Marshall(ch)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -98,7 +115,7 @@ func (r *rawChangeLoader) LoadFromTree(t *Tree, breakpoints []string) ([]*treech
|
||||
return convert(results)
|
||||
}
|
||||
|
||||
func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoints []string) ([]*treechangeproto.RawTreeChangeWithId, error) {
|
||||
// resetting cache
|
||||
r.cache = make(map[string]rawCacheEntry)
|
||||
defer func() {
|
||||
@ -111,7 +128,6 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
entry.position = -1
|
||||
r.cache[b] = entry
|
||||
existingBreakpoints = append(existingBreakpoints, b)
|
||||
}
|
||||
@ -120,8 +136,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
dfs := func(
|
||||
commonSnapshot string,
|
||||
heads []string,
|
||||
startCounter int,
|
||||
shouldVisit func(counter int, mapExists bool) bool,
|
||||
shouldVisit func(entry rawCacheEntry, mapExists bool) bool,
|
||||
visit func(entry rawCacheEntry) rawCacheEntry) bool {
|
||||
|
||||
// resetting stack
|
||||
@ -135,7 +150,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
r.idStack = r.idStack[:len(r.idStack)-1]
|
||||
|
||||
entry, exists := r.cache[id]
|
||||
if !shouldVisit(entry.position, exists) {
|
||||
if !shouldVisit(entry, exists) {
|
||||
continue
|
||||
}
|
||||
if id == commonSnapshot {
|
||||
@ -144,7 +159,6 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
}
|
||||
if !exists {
|
||||
entry, err = r.loadEntry(id)
|
||||
entry.position = -1
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
@ -159,7 +173,7 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
break
|
||||
}
|
||||
prevEntry, exists := r.cache[prev]
|
||||
if !shouldVisit(prevEntry.position, exists) {
|
||||
if !shouldVisit(prevEntry, exists) {
|
||||
continue
|
||||
}
|
||||
r.idStack = append(r.idStack, prev)
|
||||
@ -172,8 +186,8 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
r.idStack = append(r.idStack, heads...)
|
||||
var buffer []*treechangeproto.RawTreeChangeWithId
|
||||
|
||||
rootVisited := dfs(commonSnapshot, heads, 0,
|
||||
func(counter int, mapExists bool) bool {
|
||||
rootVisited := dfs(commonSnapshot, heads,
|
||||
func(_ rawCacheEntry, mapExists bool) bool {
|
||||
return !mapExists
|
||||
},
|
||||
func(entry rawCacheEntry) rawCacheEntry {
|
||||
@ -198,11 +212,13 @@ func (r *rawChangeLoader) LoadFromStorage(commonSnapshot string, heads, breakpoi
|
||||
}
|
||||
|
||||
// marking all visited as nil
|
||||
dfs(commonSnapshot, existingBreakpoints, len(buffer),
|
||||
func(counter int, mapExists bool) bool {
|
||||
return !mapExists || counter < len(buffer)
|
||||
dfs(commonSnapshot, existingBreakpoints,
|
||||
func(entry rawCacheEntry, mapExists bool) bool {
|
||||
// only going through already loaded changes
|
||||
return mapExists && !entry.removed
|
||||
},
|
||||
func(entry rawCacheEntry) rawCacheEntry {
|
||||
entry.removed = true
|
||||
if entry.position != -1 {
|
||||
buffer[entry.position] = nil
|
||||
}
|
||||
@ -226,13 +242,14 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error)
|
||||
return
|
||||
}
|
||||
|
||||
change, err := r.changeBuilder.ConvertFromRaw(rawChange, false)
|
||||
change, err := r.changeBuilder.Unmarshall(rawChange, false)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
entry = rawCacheEntry{
|
||||
change: change,
|
||||
rawChange: rawChange,
|
||||
position: -1,
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -1,16 +0,0 @@
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/commonspace/object/acl/list"
|
||||
"github.com/anytypeio/any-sync/commonspace/object/tree/treestorage"
|
||||
)
|
||||
|
||||
func ValidateRawTree(payload treestorage.TreeStorageCreatePayload, aclList list.AclList) (err error) {
|
||||
treeStorage, err := treestorage.NewInMemoryTreeStorage(payload.RootRawChange, payload.Heads, payload.Changes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = BuildObjectTree(treeStorage, aclList)
|
||||
return
|
||||
}
|
||||
@ -1,13 +1,19 @@
|
||||
package objecttree
|
||||
|
||||
import (
|
||||
"github.com/anytypeio/any-sync/util/keys/asymmetric/signingkey"
|
||||
"github.com/anyproto/any-sync/util/crypto"
|
||||
)
|
||||
|
||||
// SignableChangeContent is a payload to be passed when we are creating change
|
||||
type SignableChangeContent struct {
|
||||
// Data is a data provided by the client
|
||||
Data []byte
|
||||
Key signingkey.PrivKey
|
||||
Identity []byte
|
||||
// Key is the key which will be used to sign the change
|
||||
Key crypto.PrivKey
|
||||
// IsSnapshot tells if the change has snapshot of all previous data
|
||||
IsSnapshot bool
|
||||
// IsEncrypted tells if we encrypt the data with the relevant symmetric key
|
||||
IsEncrypted bool
|
||||
// Timestamp is a timestamp of change, if it is <= 0, then we use current timestamp
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
x
Reference in New Issue
Block a user