Merge pull request #15 from anyproto/new-sync-protocol

This commit is contained in:
Mikhail Rakhmanov 2023-06-09 11:44:35 +02:00 committed by GitHub
commit 90c5ef3311
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
115 changed files with 6473 additions and 3412 deletions

View File

@ -55,6 +55,7 @@ type ComponentStatable interface {
// App is the central part of the application // App is the central part of the application
// It contains and manages all components // It contains and manages all components
type App struct { type App struct {
parent *App
components []Component components []Component
mu sync.RWMutex mu sync.RWMutex
startStat Stat startStat Stat
@ -109,6 +110,16 @@ func VersionDescription() string {
return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState) return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState)
} }
// ChildApp creates a child container which has access to parent's components
// It doesn't call Start on any of the parent's components
func (app *App) ChildApp() *App {
return &App{
parent: app,
deviceState: app.deviceState,
anySyncVersion: app.AnySyncVersion(),
}
}
// Register adds service to registry // Register adds service to registry
// All components will be started in the order they were registered // All components will be started in the order they were registered
func (app *App) Register(s Component) *App { func (app *App) Register(s Component) *App {
@ -128,10 +139,14 @@ func (app *App) Register(s Component) *App {
func (app *App) Component(name string) Component { func (app *App) Component(name string) Component {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
if s.Name() == name { for current != nil {
return s for _, s := range current.components {
if s.Name() == name {
return s
}
} }
current = current.parent
} }
return nil return nil
} }
@ -149,10 +164,14 @@ func (app *App) MustComponent(name string) Component {
func MustComponent[i any](app *App) i { func MustComponent[i any](app *App) i {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
if v, ok := s.(i); ok { for current != nil {
return v for _, s := range current.components {
if v, ok := s.(i); ok {
return v
}
} }
current = current.parent
} }
empty := new(i) empty := new(i)
panic(fmt.Errorf("component with interface %T is not found", empty)) panic(fmt.Errorf("component with interface %T is not found", empty))
@ -162,9 +181,13 @@ func MustComponent[i any](app *App) i {
func (app *App) ComponentNames() (names []string) { func (app *App) ComponentNames() (names []string) {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
names = make([]string, len(app.components)) names = make([]string, 0, len(app.components))
for i, c := range app.components { current := app
names[i] = c.Name() for current != nil {
for _, c := range current.components {
names = append(names, c.Name())
}
current = current.parent
} }
return return
} }

View File

@ -34,6 +34,25 @@ func TestAppServiceRegistry(t *testing.T) {
names := app.ComponentNames() names := app.ComponentNames()
assert.Equal(t, names, []string{"c1", "r1", "s1"}) assert.Equal(t, names, []string{"c1", "r1", "s1"})
}) })
t.Run("Child MustComponent", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
for _, name := range []string{"c1", "r1", "s1", "x1"} {
assert.NotPanics(t, func() { app.MustComponent(name) }, name)
}
assert.Panics(t, func() { app.MustComponent("not-registered") })
})
t.Run("Child ComponentNames", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
names := app.ComponentNames()
assert.Equal(t, names, []string{"x1", "c1", "r1", "s1"})
})
t.Run("Child override", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeRunnable, "s1", nil, nil))
_ = app.MustComponent("s1").(*testRunnable)
})
} }
func TestAppStart(t *testing.T) { func TestAppStart(t *testing.T) {

View File

@ -6,6 +6,9 @@ import (
) )
func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option { func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option {
if reg == nil {
return nil
}
if subsystem == "" { if subsystem == "" {
subsystem = "cache" subsystem = "cache"
} }
@ -13,9 +16,7 @@ func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Optio
subSplit := strings.Split(subsystem, ".") subSplit := strings.Split(subsystem, ".")
namespace = strings.Join(nameSplit, "_") namespace = strings.Join(nameSplit, "_")
subsystem = strings.Join(subSplit, "_") subsystem = strings.Join(subSplit, "_")
if reg == nil {
return nil
}
return func(cache *oCache) { return func(cache *oCache) {
cache.metrics = &metrics{ cache.metrics = &metrics{
hit: prometheus.NewCounter(prometheus.CounterOpts{ hit: prometheus.NewCounter(prometheus.CounterOpts{

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: commonfile/fileproto/protos/file.proto // source: commonfile/fileproto/protos/file.proto
package fileproto package fileproto

View File

@ -1,62 +0,0 @@
package commonspace
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"sync/atomic"
)
type commonGetter struct {
treemanager.TreeManager
spaceId string
reservedObjects []syncobjectgetter.SyncObject
spaceIsClosed *atomic.Bool
}
func newCommonGetter(spaceId string, getter treemanager.TreeManager, spaceIsClosed *atomic.Bool) *commonGetter {
return &commonGetter{
TreeManager: getter,
spaceId: spaceId,
spaceIsClosed: spaceIsClosed,
}
}
func (c *commonGetter) AddObject(object syncobjectgetter.SyncObject) {
c.reservedObjects = append(c.reservedObjects, object)
}
func (c *commonGetter) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) {
if c.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := c.getReservedObject(treeId); obj != nil {
return obj.(objecttree.ObjectTree), nil
}
return c.TreeManager.GetTree(ctx, spaceId, treeId)
}
func (c *commonGetter) getReservedObject(id string) syncobjectgetter.SyncObject {
for _, obj := range c.reservedObjects {
if obj != nil && obj.Id() == id {
return obj
}
}
return nil
}
func (c *commonGetter) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
if c.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := c.getReservedObject(objectId); obj != nil {
return obj, nil
}
t, err := c.TreeManager.GetTree(ctx, c.spaceId, objectId)
if err != nil {
return
}
obj = t.(syncobjectgetter.SyncObject)
return
}

View File

@ -1,4 +1,4 @@
package commonspace package config
type ConfigGetter interface { type ConfigGetter interface {
GetSpace() Config GetSpace() Config

View File

@ -3,6 +3,7 @@ package credentialprovider
import ( import (
"context" "context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
) )
@ -13,12 +14,21 @@ func NewNoOp() CredentialProvider {
} }
type CredentialProvider interface { type CredentialProvider interface {
app.Component
GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error)
} }
type noOpProvider struct { type noOpProvider struct {
} }
func (n noOpProvider) Init(a *app.App) (err error) {
return nil
}
func (n noOpProvider) Name() (name string) {
return CName
}
func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) { func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
return nil, nil return nil, nil
} }

View File

@ -8,6 +8,7 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
app "github.com/anyproto/any-sync/app"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -49,3 +50,31 @@ func (mr *MockCredentialProviderMockRecorder) GetCredential(arg0, arg1 interface
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1)
} }
// Init mocks base method.
func (m *MockCredentialProvider) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockCredentialProviderMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockCredentialProvider)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockCredentialProvider) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockCredentialProviderMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCredentialProvider)(nil).Name))
}

View File

@ -73,13 +73,14 @@ func TestSpaceDeleteIds(t *testing.T) {
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
// creating a tree // creating a tree
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) rand.Read(bytes)
doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey, PrivKey: acc.SignKey,
ChangeType: "some", ChangeType: "some",
SpaceId: spc.Id(), SpaceId: spc.Id(),
@ -88,7 +89,7 @@ func TestSpaceDeleteIds(t *testing.T) {
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}) })
require.NoError(t, err) require.NoError(t, err)
tr, err := spc.PutTree(ctx, doc, nil) tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err) require.NoError(t, err)
ids = append(ids, tr.Id()) ids = append(ids, tr.Id())
tr.Close() tr.Close()
@ -106,7 +107,7 @@ func TestSpaceDeleteIds(t *testing.T) {
func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string { func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string {
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) rand.Read(bytes)
doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey, PrivKey: acc.SignKey,
ChangeType: "some", ChangeType: "some",
SpaceId: spc.Id(), SpaceId: spc.Id(),
@ -115,7 +116,7 @@ func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.A
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}) })
require.NoError(t, err) require.NoError(t, err)
tr, err := spc.PutTree(ctx, doc, nil) tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err) require.NoError(t, err)
tr.Close() tr.Close()
return tr.Id() return tr.Id()
@ -147,9 +148,10 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
// adding space to tree manager // adding space to tree manager
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
close(fx.treeManager.waitLoad)
require.NoError(t, err) require.NoError(t, err)
settingsObject := spc.(*space).settingsObject settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc) id := createTree(t, ctx, spc, acc)
@ -183,17 +185,19 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
spc, err = fx.spaceService.NewSpace(ctx, sp) spc, err = fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, spc) require.NotNil(t, spc)
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.space = spc fx.treeManager.space = spc
fx.treeManager.deletedIds = nil fx.treeManager.deletedIds = nil
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted // waiting until everything is deleted
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
require.Equal(t, len(ids), len(fx.treeManager.deletedIds)) require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
// checking that new snapshot will contain all the changes // checking that new snapshot will contain all the changes
settingsObject = spc.(*space).settingsObject settingsObject = spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
settings.DoSnapshot = func(treeLen int) bool { settings.DoSnapshot = func(treeLen int) bool {
return true return true
} }
@ -230,8 +234,9 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
settingsObject := spc.(*space).settingsObject settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc) id := createTree(t, ctx, spc, acc)
@ -259,10 +264,12 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, spc) require.NotNil(t, spc)
fx.treeManager.space = spc fx.treeManager.space = spc
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.deletedIds = nil fx.treeManager.deletedIds = nil
fx.treeManager.markedIds = nil fx.treeManager.markedIds = nil
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted // waiting until everything is deleted
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)

View File

@ -1,16 +1,22 @@
//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate ObjectDeletionState,StateBuilder,ChangeFactory //go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anyproto/any-sync/commonspace/deletionstate ObjectDeletionState
package settingsstate package deletionstate
import ( import (
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"go.uber.org/zap" "go.uber.org/zap"
"sync" "sync"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.deletionstate"
type StateUpdateObserver func(ids []string) type StateUpdateObserver func(ids []string)
type ObjectDeletionState interface { type ObjectDeletionState interface {
app.Component
AddObserver(observer StateUpdateObserver) AddObserver(observer StateUpdateObserver)
Add(ids map[string]struct{}) Add(ids map[string]struct{})
GetQueued() (ids []string) GetQueued() (ids []string)
@ -28,12 +34,20 @@ type objectDeletionState struct {
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
} }
func NewObjectDeletionState(log logger.CtxLogger, storage spacestorage.SpaceStorage) ObjectDeletionState { func (st *objectDeletionState) Init(a *app.App) (err error) {
st.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
return nil
}
func (st *objectDeletionState) Name() (name string) {
return CName
}
func New() ObjectDeletionState {
return &objectDeletionState{ return &objectDeletionState{
log: log, log: log,
queued: map[string]struct{}{}, queued: map[string]struct{}{},
deleted: map[string]struct{}{}, deleted: map[string]struct{}{},
storage: storage,
} }
} }

View File

@ -1,7 +1,6 @@
package settingsstate package deletionstate
import ( import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -19,7 +18,8 @@ type fixture struct {
func newFixture(t *testing.T) *fixture { func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl) spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := NewObjectDeletionState(logger.NewNamed("test"), spaceStorage).(*objectDeletionState) delState := New().(*objectDeletionState)
delState.storage = spaceStorage
return &fixture{ return &fixture{
ctrl: ctrl, ctrl: ctrl,
delState: delState, delState: delState,

View File

@ -0,0 +1,144 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/deletionstate (interfaces: ObjectDeletionState)
// Package mock_deletionstate is a generated GoMock package.
package mock_deletionstate
import (
reflect "reflect"
app "github.com/anyproto/any-sync/app"
deletionstate "github.com/anyproto/any-sync/commonspace/deletionstate"
gomock "github.com/golang/mock/gomock"
)
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
type MockObjectDeletionState struct {
ctrl *gomock.Controller
recorder *MockObjectDeletionStateMockRecorder
}
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
type MockObjectDeletionStateMockRecorder struct {
mock *MockObjectDeletionState
}
// NewMockObjectDeletionState creates a new mock instance.
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
mock := &MockObjectDeletionState{ctrl: ctrl}
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0)
}
// Add indicates an expected call of Add.
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
}
// AddObserver mocks base method.
func (m *MockObjectDeletionState) AddObserver(arg0 deletionstate.StateUpdateObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddObserver", arg0)
}
// AddObserver indicates an expected call of AddObserver.
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
}
// Delete mocks base method.
func (m *MockObjectDeletionState) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
}
// Exists mocks base method.
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exists", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Exists indicates an expected call of Exists.
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
}
// Filter mocks base method.
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Filter", arg0)
ret0, _ := ret[0].([]string)
return ret0
}
// Filter indicates an expected call of Filter.
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
}
// GetQueued mocks base method.
func (m *MockObjectDeletionState) GetQueued() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetQueued")
ret0, _ := ret[0].([]string)
return ret0
}
// GetQueued indicates an expected call of GetQueued.
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
}
// Init mocks base method.
func (m *MockObjectDeletionState) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockObjectDeletionStateMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectDeletionState)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockObjectDeletionState) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockObjectDeletionStateMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectDeletionState)(nil).Name))
}

View File

@ -6,9 +6,9 @@ import (
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -22,30 +22,22 @@ type DiffSyncer interface {
Sync(ctx context.Context) error Sync(ctx context.Context) error
RemoveObjects(ids []string) RemoveObjects(ids []string)
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
Init(deletionState settingsstate.ObjectDeletionState) Init()
Close() error Close() error
} }
func newDiffSyncer( func newDiffSyncer(hs *headSync) DiffSyncer {
spaceId string,
diff ldiff.Diff,
peerManager peermanager.PeerManager,
cache treemanager.TreeManager,
storage spacestorage.SpaceStorage,
clientFactory spacesyncproto.ClientFactory,
syncStatus syncstatus.StatusUpdater,
credentialProvider credentialprovider.CredentialProvider,
log logger.CtxLogger) DiffSyncer {
return &diffSyncer{ return &diffSyncer{
diff: diff, diff: hs.diff,
spaceId: spaceId, spaceId: hs.spaceId,
treeManager: cache, treeManager: hs.treeManager,
storage: storage, storage: hs.storage,
peerManager: peerManager, peerManager: hs.peerManager,
clientFactory: clientFactory, clientFactory: spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient),
credentialProvider: credentialProvider, credentialProvider: hs.credentialProvider,
log: log, log: log,
syncStatus: syncStatus, syncStatus: hs.syncStatus,
deletionState: hs.deletionState,
} }
} }
@ -57,14 +49,13 @@ type diffSyncer struct {
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
clientFactory spacesyncproto.ClientFactory clientFactory spacesyncproto.ClientFactory
log logger.CtxLogger log logger.CtxLogger
deletionState settingsstate.ObjectDeletionState deletionState deletionstate.ObjectDeletionState
credentialProvider credentialprovider.CredentialProvider credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusUpdater syncStatus syncstatus.StatusUpdater
treeSyncer treemanager.TreeSyncer treeSyncer treemanager.TreeSyncer
} }
func (d *diffSyncer) Init(deletionState settingsstate.ObjectDeletionState) { func (d *diffSyncer) Init() {
d.deletionState = deletionState
d.deletionState.AddObserver(d.RemoveObjects) d.deletionState.AddObserver(d.RemoveObjects)
d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager) d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager)
} }
@ -115,8 +106,14 @@ func (d *diffSyncer) Sync(ctx context.Context) error {
func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) { func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) {
ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id())) ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id()))
conn, err := p.AcquireDrpcConn(ctx)
if err != nil {
return
}
defer p.ReleaseDrpcConn(conn)
var ( var (
cl = d.clientFactory.Client(p) cl = d.clientFactory.Client(conn)
rdiff = NewRemoteDiff(d.spaceId, cl) rdiff = NewRemoteDiff(d.spaceId, cl)
stateCounter = d.syncStatus.StateCounter() stateCounter = d.syncStatus.StateCounter()
) )

View File

@ -5,23 +5,13 @@ import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage" "github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
mock_treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"storj.io/drpc" "storj.io/drpc"
"testing" "testing"
@ -36,60 +26,6 @@ type pushSpaceRequestMatcher struct {
spaceHeader *spacesyncproto.RawSpaceHeaderWithId spaceHeader *spacesyncproto.RawSpaceHeaderWithId
} }
func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
res, ok := x.(*spacesyncproto.SpacePushRequest)
if !ok {
return false
}
return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
}
func (p pushSpaceRequestMatcher) String() string {
return ""
}
type mockPeer struct{}
func (m mockPeer) Addr() string {
return ""
}
func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, m.Close()
}
func (m mockPeer) Id() string {
return "mockId"
}
func (m mockPeer) LastUsage() time.Time {
return time.Time{}
}
func (m mockPeer) Secure() sec.SecureConn {
return nil
}
func (m mockPeer) UpdateLastUsage() {
}
func (m mockPeer) Close() error {
return nil
}
func (m mockPeer) Closed() <-chan struct{} {
return make(chan struct{})
}
func (m mockPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
return nil
}
func (m mockPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
return nil, nil
}
func newPushSpaceRequestMatcher( func newPushSpaceRequestMatcher(
spaceId string, spaceId string,
aclRootId string, aclRootId string,
@ -105,80 +41,134 @@ func newPushSpaceRequestMatcher(
} }
} }
func TestDiffSyncer_Sync(t *testing.T) { func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
// setup res, ok := x.(*spacesyncproto.SpacePushRequest)
ctx := context.Background() if !ok {
ctrl := gomock.NewController(t) return false
defer ctrl.Finish() }
diffMock := mock_ldiff.NewMockDiff(ctrl) return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl) }
cacheMock := mock_treemanager.NewMockTreeManager(ctrl)
stMock := mock_spacestorage.NewMockSpaceStorage(ctrl) func (p pushSpaceRequestMatcher) String() string {
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) return ""
factory := spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { }
return clientMock
type mockPeer struct {
}
func (m mockPeer) Id() string {
return "peerId"
}
func (m mockPeer) Context() context.Context {
return context.Background()
}
func (m mockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
return nil, nil
}
func (m mockPeer) ReleaseDrpcConn(conn drpc.Conn) {
return
}
func (m mockPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
return nil
}
func (m mockPeer) IsClosed() bool {
return false
}
func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return false, err
}
func (m mockPeer) Close() (err error) {
return nil
}
func (fx *headSyncFixture) initDiffSyncer(t *testing.T) {
fx.init(t)
fx.diffSyncer = newDiffSyncer(fx.headSync).(*diffSyncer)
fx.diffSyncer.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
return fx.clientMock
}) })
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl) fx.deletionStateMock.EXPECT().AddObserver(gomock.Any())
credentialProvider := mock_credentialprovider.NewMockCredentialProvider(ctrl) fx.treeManagerMock.EXPECT().NewTreeSyncer(fx.spaceState.SpaceId, fx.treeManagerMock).Return(fx.treeSyncerMock)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) fx.diffSyncer.Init()
spaceId := "spaceId" }
aclRootId := "aclRootId"
l := logger.NewNamed(spaceId) func TestDiffSyncer(t *testing.T) {
diffSyncer := newDiffSyncer(spaceId, diffMock, peerManagerMock, cacheMock, stMock, factory, syncstatus.NewNoOpSyncStatus(), credentialProvider, l) ctx := context.Background()
delState.EXPECT().AddObserver(gomock.Any())
cacheMock.EXPECT().NewTreeSyncer(spaceId, gomock.Any()).Return(treeSyncerMock)
diffSyncer.Init(delState)
t.Run("diff syncer sync", func(t *testing.T) { t.Run("diff syncer sync", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{} mPeer := mockPeer{}
peerManagerMock.EXPECT(). fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil) Return([]peer.Peer{mPeer}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return([]string{"new"}, []string{"changed"}, nil, nil) Return([]string{"new"}, []string{"changed"}, nil, nil)
delState.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1) fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
delState.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1) fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1)
delState.EXPECT().Filter(nil).Return(nil).Times(1) fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil) fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync conf error", func(t *testing.T) { t.Run("diff syncer sync conf error", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
ctx := context.Background()
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return(nil, fmt.Errorf("some error")) Return(nil, fmt.Errorf("some error"))
require.Error(t, diffSyncer.Sync(ctx)) require.Error(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("deletion state remove objects", func(t *testing.T) { t.Run("deletion state remove objects", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
deletedId := "id" deletedId := "id"
delState.EXPECT().Exists(deletedId).Return(true) fx.deletionStateMock.EXPECT().Exists(deletedId).Return(true)
// this should not result in any mock being called // this should not result in any mock being called
diffSyncer.UpdateHeads(deletedId, []string{"someHead"}) fx.diffSyncer.UpdateHeads(deletedId, []string{"someHead"})
}) })
t.Run("update heads updates diff", func(t *testing.T) { t.Run("update heads updates diff", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
newId := "newId" newId := "newId"
newHeads := []string{"h1", "h2"} newHeads := []string{"h1", "h2"}
hash := "hash" hash := "hash"
diffMock.EXPECT().Set(ldiff.Element{ fx.diffMock.EXPECT().Set(ldiff.Element{
Id: newId, Id: newId,
Head: concatStrings(newHeads), Head: concatStrings(newHeads),
}) })
diffMock.EXPECT().Hash().Return(hash) fx.diffMock.EXPECT().Hash().Return(hash)
delState.EXPECT().Exists(newId).Return(false) fx.deletionStateMock.EXPECT().Exists(newId).Return(false)
stMock.EXPECT().WriteSpaceHash(hash) fx.storageMock.EXPECT().WriteSpaceHash(hash)
diffSyncer.UpdateHeads(newId, newHeads) fx.diffSyncer.UpdateHeads(newId, newHeads)
}) })
t.Run("diff syncer sync space missing", func(t *testing.T) { t.Run("diff syncer sync space missing", func(t *testing.T) {
aclStorageMock := mock_liststorage.NewMockListStorage(ctrl) fx := newHeadSyncFixture(t)
settingsStorage := mock_treestorage.NewMockTreeStorage(ctrl) fx.initDiffSyncer(t)
defer fx.stop()
aclStorageMock := mock_liststorage.NewMockListStorage(fx.ctrl)
settingsStorage := mock_treestorage.NewMockTreeStorage(fx.ctrl)
settingsId := "settingsId" settingsId := "settingsId"
aclRootId := "aclRootId"
aclRoot := &aclrecordproto.RawAclRecordWithId{ aclRoot := &aclrecordproto.RawAclRecordWithId{
Id: aclRootId, Id: aclRootId,
} }
@ -189,55 +179,61 @@ func TestDiffSyncer_Sync(t *testing.T) {
spaceSettingsId := "spaceSettingsId" spaceSettingsId := "spaceSettingsId"
credential := []byte("credential") credential := []byte("credential")
peerManagerMock.EXPECT(). fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing) Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing)
stMock.EXPECT().AclStorage().Return(aclStorageMock, nil) fx.storageMock.EXPECT().AclStorage().Return(aclStorageMock, nil)
stMock.EXPECT().SpaceHeader().Return(spaceHeader, nil) fx.storageMock.EXPECT().SpaceHeader().Return(spaceHeader, nil)
stMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId) fx.storageMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId)
stMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil) fx.storageMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil)
settingsStorage.EXPECT().Root().Return(settingsRoot, nil) settingsStorage.EXPECT().Root().Return(settingsRoot, nil)
aclStorageMock.EXPECT(). aclStorageMock.EXPECT().
Root(). Root().
Return(aclRoot, nil) Return(aclRoot, nil)
credentialProvider.EXPECT(). fx.credentialProviderMock.EXPECT().
GetCredential(gomock.Any(), spaceHeader). GetCredential(gomock.Any(), spaceHeader).
Return(credential, nil) Return(credential, nil)
clientMock.EXPECT(). fx.clientMock.EXPECT().
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(spaceId, aclRootId, settingsId, credential, spaceHeader)). SpacePush(gomock.Any(), newPushSpaceRequestMatcher(fx.spaceState.SpaceId, aclRootId, settingsId, credential, spaceHeader)).
Return(nil, nil) Return(nil, nil)
peerManagerMock.EXPECT().SendPeer(gomock.Any(), "mockId", gomock.Any()) fx.peerManagerMock.EXPECT().SendPeer(gomock.Any(), "peerId", gomock.Any())
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync unexpected", func(t *testing.T) { t.Run("diff syncer sync unexpected", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrUnexpected) Return(nil, nil, nil, spacesyncproto.ErrUnexpected)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync space is deleted error", func(t *testing.T) { t.Run("diff syncer sync space is deleted error", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{} mPeer := mockPeer{}
peerManagerMock.EXPECT(). fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil) Return([]peer.Peer{mPeer}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted) Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted)
stMock.EXPECT().SpaceSettingsId().Return("settingsId") fx.storageMock.EXPECT().SpaceSettingsId().Return("settingsId")
treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil) fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
} }

View File

@ -3,123 +3,145 @@ package headsync
import ( import (
"context" "context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
config2 "github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/periodicsync" "github.com/anyproto/any-sync/util/periodicsync"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"strings"
"sync/atomic" "sync/atomic"
"time" "time"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.headsync"
type TreeHeads struct { type TreeHeads struct {
Id string Id string
Heads []string Heads []string
} }
type HeadSync interface { type HeadSync interface {
Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) app.ComponentRunnable
ExternalIds() []string
DebugAllHeads() (res []TreeHeads)
AllIds() []string
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
RemoveObjects(ids []string) RemoveObjects(ids []string)
AllIds() []string
DebugAllHeads() (res []TreeHeads)
Close() (err error)
} }
type headSync struct { type headSync struct {
spaceId string spaceId string
periodicSync periodicsync.PeriodicSync
storage spacestorage.SpaceStorage
diff ldiff.Diff
log logger.CtxLogger
syncer DiffSyncer
configuration nodeconf.NodeConf
spaceIsDeleted *atomic.Bool spaceIsDeleted *atomic.Bool
syncPeriod int
syncPeriod int periodicSync periodicsync.PeriodicSync
storage spacestorage.SpaceStorage
diff ldiff.Diff
log logger.CtxLogger
syncer DiffSyncer
configuration nodeconf.NodeConf
peerManager peermanager.PeerManager
treeManager treemanager.TreeManager
credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusService
deletionState deletionstate.ObjectDeletionState
} }
func NewHeadSync( func New() HeadSync {
spaceId string, return &headSync{}
spaceIsDeleted *atomic.Bool, }
syncPeriod int,
configuration nodeconf.NodeConf,
storage spacestorage.SpaceStorage,
peerManager peermanager.PeerManager,
cache treemanager.TreeManager,
syncStatus syncstatus.StatusUpdater,
credentialProvider credentialprovider.CredentialProvider,
log logger.CtxLogger) HeadSync {
diff := ldiff.New(16, 16) var createDiffSyncer = newDiffSyncer
l := log.With(zap.String("spaceId", spaceId))
factory := spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient) func (h *headSync) Init(a *app.App) (err error) {
syncer := newDiffSyncer(spaceId, diff, peerManager, cache, storage, factory, syncStatus, credentialProvider, l) shared := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
cfg := a.MustComponent("config").(config2.ConfigGetter)
h.spaceId = shared.SpaceId
h.spaceIsDeleted = shared.SpaceIsDeleted
h.syncPeriod = cfg.GetSpace().SyncPeriod
h.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
h.log = log.With(zap.String("spaceId", h.spaceId))
h.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
h.diff = ldiff.New(16, 16)
h.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
h.credentialProvider = a.MustComponent(credentialprovider.CName).(credentialprovider.CredentialProvider)
h.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
h.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
h.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
h.syncer = createDiffSyncer(h)
sync := func(ctx context.Context) (err error) { sync := func(ctx context.Context) (err error) {
// for clients cancelling the sync process // for clients cancelling the sync process
if spaceIsDeleted.Load() && !configuration.IsResponsible(spaceId) { if h.spaceIsDeleted.Load() && !h.configuration.IsResponsible(h.spaceId) {
return spacesyncproto.ErrSpaceIsDeleted return spacesyncproto.ErrSpaceIsDeleted
} }
return syncer.Sync(ctx) return h.syncer.Sync(ctx)
}
periodicSync := periodicsync.NewPeriodicSync(syncPeriod, time.Minute, sync, l)
return &headSync{
spaceId: spaceId,
storage: storage,
syncer: syncer,
periodicSync: periodicSync,
diff: diff,
log: log,
syncPeriod: syncPeriod,
configuration: configuration,
spaceIsDeleted: spaceIsDeleted,
} }
h.periodicSync = periodicsync.NewPeriodicSync(h.syncPeriod, time.Minute, sync, h.log)
// TODO: move to run?
h.syncer.Init()
return nil
} }
func (d *headSync) Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) { func (h *headSync) Name() (name string) {
d.fillDiff(objectIds) return CName
d.syncer.Init(deletionState)
d.periodicSync.Run()
} }
func (d *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { func (h *headSync) Run(ctx context.Context) (err error) {
if d.spaceIsDeleted.Load() { initialIds, err := h.storage.StoredIds()
if err != nil {
return
}
h.fillDiff(initialIds)
h.periodicSync.Run()
return
}
func (h *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
if h.spaceIsDeleted.Load() {
peerId, err := peer.CtxPeerId(ctx) peerId, err := peer.CtxPeerId(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// stop receiving all request for sync from clients // stop receiving all request for sync from clients
if !slices.Contains(d.configuration.NodeIds(d.spaceId), peerId) { if !slices.Contains(h.configuration.NodeIds(h.spaceId), peerId) {
return nil, spacesyncproto.ErrSpaceIsDeleted return nil, spacesyncproto.ErrSpaceIsDeleted
} }
} }
return HandleRangeRequest(ctx, d.diff, req) return HandleRangeRequest(ctx, h.diff, req)
} }
func (d *headSync) UpdateHeads(id string, heads []string) { func (h *headSync) UpdateHeads(id string, heads []string) {
d.syncer.UpdateHeads(id, heads) h.syncer.UpdateHeads(id, heads)
} }
func (d *headSync) AllIds() []string { func (h *headSync) AllIds() []string {
return d.diff.Ids() return h.diff.Ids()
} }
func (d *headSync) DebugAllHeads() (res []TreeHeads) { func (h *headSync) ExternalIds() []string {
els := d.diff.Elements() settingsId := h.storage.SpaceSettingsId()
return slice.DiscardFromSlice(h.AllIds(), func(id string) bool {
return id == settingsId
})
}
func (h *headSync) DebugAllHeads() (res []TreeHeads) {
els := h.diff.Elements()
for _, el := range els { for _, el := range els {
idHead := TreeHeads{ idHead := TreeHeads{
Id: el.Id, Id: el.Id,
@ -130,19 +152,19 @@ func (d *headSync) DebugAllHeads() (res []TreeHeads) {
return return
} }
func (d *headSync) RemoveObjects(ids []string) { func (h *headSync) RemoveObjects(ids []string) {
d.syncer.RemoveObjects(ids) h.syncer.RemoveObjects(ids)
} }
func (d *headSync) Close() (err error) { func (h *headSync) Close(ctx context.Context) (err error) {
d.periodicSync.Close() h.periodicSync.Close()
return d.syncer.Close() return h.syncer.Close()
} }
func (d *headSync) fillDiff(objectIds []string) { func (h *headSync) fillDiff(objectIds []string) {
var els = make([]ldiff.Element, 0, len(objectIds)) var els = make([]ldiff.Element, 0, len(objectIds))
for _, id := range objectIds { for _, id := range objectIds {
st, err := d.storage.TreeStorage(id) st, err := h.storage.TreeStorage(id)
if err != nil { if err != nil {
continue continue
} }
@ -155,32 +177,8 @@ func (d *headSync) fillDiff(objectIds []string) {
Head: concatStrings(heads), Head: concatStrings(heads),
}) })
} }
d.diff.Set(els...) h.diff.Set(els...)
if err := d.storage.WriteSpaceHash(d.diff.Hash()); err != nil { if err := h.storage.WriteSpaceHash(h.diff.Hash()); err != nil {
d.log.Error("can't write space hash", zap.Error(err)) h.log.Error("can't write space hash", zap.Error(err))
} }
} }
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

View File

@ -1,71 +1,179 @@
package headsync package headsync
import ( import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/ldiff/mock_ldiff" "github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync/mock_headsync" "github.com/anyproto/any-sync/commonspace/headsync/mock_headsync"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anyproto/any-sync/util/periodicsync/mock_periodicsync" "github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"sync/atomic"
"testing" "testing"
) )
func TestDiffService(t *testing.T) { type mockConfig struct {
ctrl := gomock.NewController(t) }
defer ctrl.Finish()
spaceId := "spaceId" func (m mockConfig) Init(a *app.App) (err error) {
l := logger.NewNamed("sync") return nil
pSyncMock := mock_periodicsync.NewMockPeriodicSync(ctrl) }
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
treeStorageMock := mock_treestorage.NewMockTreeStorage(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
syncer := mock_headsync.NewMockDiffSyncer(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl)
syncPeriod := 1
initId := "initId"
service := &headSync{ func (m mockConfig) Name() (name string) {
spaceId: spaceId, return "config"
storage: storageMock, }
periodicSync: pSyncMock,
syncer: syncer, func (m mockConfig) GetSpace() config.Config {
diff: diffMock, return config.Config{}
log: l, }
syncPeriod: syncPeriod,
type headSyncFixture struct {
spaceState *spacestate.SpaceState
ctrl *gomock.Controller
app *app.App
configurationMock *mock_nodeconf.MockService
storageMock *mock_spacestorage.MockSpaceStorage
peerManagerMock *mock_peermanager.MockPeerManager
credentialProviderMock *mock_credentialprovider.MockCredentialProvider
syncStatus syncstatus.StatusService
treeManagerMock *mock_treemanager.MockTreeManager
deletionStateMock *mock_deletionstate.MockObjectDeletionState
diffSyncerMock *mock_headsync.MockDiffSyncer
treeSyncerMock *mock_treemanager.MockTreeSyncer
diffMock *mock_ldiff.MockDiff
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
headSync *headSync
diffSyncer *diffSyncer
}
func newHeadSyncFixture(t *testing.T) *headSyncFixture {
spaceState := &spacestate.SpaceState{
SpaceId: "spaceId",
SpaceIsDeleted: &atomic.Bool{},
} }
ctrl := gomock.NewController(t)
configurationMock := mock_nodeconf.NewMockService(ctrl)
configurationMock.EXPECT().Name().AnyTimes().Return(nodeconf.CName)
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
storageMock.EXPECT().Name().AnyTimes().Return(spacestorage.CName)
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl)
peerManagerMock.EXPECT().Name().AnyTimes().Return(peermanager.CName)
credentialProviderMock := mock_credentialprovider.NewMockCredentialProvider(ctrl)
credentialProviderMock.EXPECT().Name().AnyTimes().Return(credentialprovider.CName)
syncStatus := syncstatus.NewNoOpSyncStatus()
treeManagerMock := mock_treemanager.NewMockTreeManager(ctrl)
treeManagerMock.EXPECT().Name().AnyTimes().Return(treemanager.CName)
deletionStateMock := mock_deletionstate.NewMockObjectDeletionState(ctrl)
deletionStateMock.EXPECT().Name().AnyTimes().Return(deletionstate.CName)
diffSyncerMock := mock_headsync.NewMockDiffSyncer(ctrl)
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
hs := &headSync{}
a := &app.App{}
a.Register(spaceState).
Register(mockConfig{}).
Register(configurationMock).
Register(storageMock).
Register(peerManagerMock).
Register(credentialProviderMock).
Register(syncStatus).
Register(treeManagerMock).
Register(deletionStateMock).
Register(hs)
return &headSyncFixture{
spaceState: spaceState,
ctrl: ctrl,
app: a,
configurationMock: configurationMock,
storageMock: storageMock,
peerManagerMock: peerManagerMock,
credentialProviderMock: credentialProviderMock,
syncStatus: syncStatus,
treeManagerMock: treeManagerMock,
deletionStateMock: deletionStateMock,
headSync: hs,
diffSyncerMock: diffSyncerMock,
treeSyncerMock: treeSyncerMock,
diffMock: diffMock,
clientMock: clientMock,
}
}
t.Run("init", func(t *testing.T) { func (fx *headSyncFixture) init(t *testing.T) {
storageMock.EXPECT().TreeStorage(initId).Return(treeStorageMock, nil) createDiffSyncer = func(hs *headSync) DiffSyncer {
treeStorageMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil) return fx.diffSyncerMock
syncer.EXPECT().Init(delState) }
diffMock.EXPECT().Set(ldiff.Element{ fx.diffSyncerMock.EXPECT().Init()
Id: initId, err := fx.headSync.Init(fx.app)
require.NoError(t, err)
fx.headSync.diff = fx.diffMock
}
func (fx *headSyncFixture) stop() {
fx.ctrl.Finish()
}
func TestHeadSync(t *testing.T) {
ctx := context.Background()
t.Run("run close", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.init(t)
defer fx.stop()
ids := []string{"id1"}
treeMock := mock_treestorage.NewMockTreeStorage(fx.ctrl)
fx.storageMock.EXPECT().StoredIds().Return(ids, nil)
fx.storageMock.EXPECT().TreeStorage(ids[0]).Return(treeMock, nil)
treeMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil)
fx.diffMock.EXPECT().Set(ldiff.Element{
Id: "id1",
Head: "h1h2", Head: "h1h2",
}) })
hash := "123" fx.diffMock.EXPECT().Hash().Return("hash")
diffMock.EXPECT().Hash().Return(hash) fx.storageMock.EXPECT().WriteSpaceHash("hash").Return(nil)
storageMock.EXPECT().WriteSpaceHash(hash) fx.diffSyncerMock.EXPECT().Sync(gomock.Any()).Return(nil)
pSyncMock.EXPECT().Run() fx.diffSyncerMock.EXPECT().Close().Return(nil)
service.Init([]string{initId}, delState) err := fx.headSync.Run(ctx)
require.NoError(t, err)
err = fx.headSync.Close(ctx)
require.NoError(t, err)
}) })
t.Run("update heads", func(t *testing.T) { t.Run("update heads", func(t *testing.T) {
syncer.EXPECT().UpdateHeads(initId, []string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.UpdateHeads(initId, []string{"h1", "h2"}) fx.init(t)
defer fx.stop()
fx.diffSyncerMock.EXPECT().UpdateHeads("id1", []string{"h1"})
fx.headSync.UpdateHeads("id1", []string{"h1"})
}) })
t.Run("remove objects", func(t *testing.T) { t.Run("remove objects", func(t *testing.T) {
syncer.EXPECT().RemoveObjects([]string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.RemoveObjects([]string{"h1", "h2"}) fx.init(t)
}) defer fx.stop()
t.Run("close", func(t *testing.T) { fx.diffSyncerMock.EXPECT().RemoveObjects([]string{"id1"})
pSyncMock.EXPECT().Close() fx.headSync.RemoveObjects([]string{"id1"})
syncer.EXPECT().Close()
service.Close()
}) })
} }

View File

@ -8,7 +8,6 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
@ -50,15 +49,15 @@ func (mr *MockDiffSyncerMockRecorder) Close() *gomock.Call {
} }
// Init mocks base method. // Init mocks base method.
func (m *MockDiffSyncer) Init(arg0 settingsstate.ObjectDeletionState) { func (m *MockDiffSyncer) Init() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Init", arg0) m.ctrl.Call(m, "Init")
} }
// Init indicates an expected call of Init. // Init indicates an expected call of Init.
func (mr *MockDiffSyncerMockRecorder) Init(arg0 interface{}) *gomock.Call { func (mr *MockDiffSyncerMockRecorder) Init() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init))
} }
// RemoveObjects mocks base method. // RemoveObjects mocks base method.

View File

@ -0,0 +1,27 @@
package headsync
import "strings"
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

View File

@ -1,21 +1,43 @@
package syncacl package syncacl
import ( import (
"context"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
) )
const CName = "common.acl.syncacl"
func New() *SyncAcl {
return &SyncAcl{}
}
type SyncAcl struct { type SyncAcl struct {
list.AclList list.AclList
synchandler.SyncHandler
messagePool objectsync.MessagePool
} }
func NewSyncAcl(aclList list.AclList, messagePool objectsync.MessagePool) *SyncAcl { func (s *SyncAcl) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
return &SyncAcl{ return nil, nil
AclList: aclList, }
SyncHandler: nil,
messagePool: messagePool, func (s *SyncAcl) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
} return nil
}
func (s *SyncAcl) Init(a *app.App) (err error) {
storage := a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
aclStorage, err := storage.AclStorage()
if err != nil {
return err
}
acc := a.MustComponent(accountservice.CName).(accountservice.Service)
s.AclList, err = list.BuildAclListWithIdentity(acc.Account(), aclStorage)
return err
}
func (s *SyncAcl) Name() (name string) {
return CName
} }

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable) // Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol)
// Package mock_synctree is a generated GoMock package. // Package mock_synctree is a generated GoMock package.
package mock_synctree package mock_synctree
@ -186,6 +186,21 @@ func (mr *MockSyncTreeMockRecorder) HandleMessage(arg0, arg1, arg2 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncTree)(nil).HandleMessage), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncTree)(nil).HandleMessage), arg0, arg1, arg2)
} }
// HandleRequest mocks base method.
func (m *MockSyncTree) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncTreeMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncTree)(nil).HandleRequest), arg0, arg1, arg2)
}
// HasChanges mocks base method. // HasChanges mocks base method.
func (m *MockSyncTree) HasChanges(arg0 ...string) bool { func (m *MockSyncTree) HasChanges(arg0 ...string) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -590,3 +605,287 @@ func (mr *MockHeadNotifiableMockRecorder) UpdateHeads(arg0, arg1 interface{}) *g
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHeads", reflect.TypeOf((*MockHeadNotifiable)(nil).UpdateHeads), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHeads", reflect.TypeOf((*MockHeadNotifiable)(nil).UpdateHeads), arg0, arg1)
} }
// MockSyncClient is a mock of SyncClient interface.
type MockSyncClient struct {
ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder
}
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient.
type MockSyncClientMockRecorder struct {
mock *MockSyncClient
}
// NewMockSyncClient creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient {
mock := &MockSyncClient{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder {
return m.recorder
}
// Broadcast mocks base method.
func (m *MockSyncClient) Broadcast(arg0 *treechangeproto.TreeSyncMessage) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0)
}
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1)
}
// CreateNewTreeRequest mocks base method.
func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest")
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest.
func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest))
}
// QueueRequest mocks base method.
func (m *MockSyncClient) QueueRequest(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// QueueRequest indicates an expected call of QueueRequest.
func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1, arg2)
}
// SendRequest mocks base method.
func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendRequest indicates an expected call of SendRequest.
func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2, arg3)
}
// SendUpdate mocks base method.
func (m *MockSyncClient) SendUpdate(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SendUpdate indicates an expected call of SendUpdate.
func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1, arg2)
}
// MockRequestFactory is a mock of RequestFactory interface.
type MockRequestFactory struct {
ctrl *gomock.Controller
recorder *MockRequestFactoryMockRecorder
}
// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory.
type MockRequestFactoryMockRecorder struct {
mock *MockRequestFactory
}
// NewMockRequestFactory creates a new mock instance.
func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory {
mock := &MockRequestFactory{ctrl: ctrl}
mock.recorder = &MockRequestFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder {
return m.recorder
}
// CreateFullSyncRequest mocks base method.
func (m *MockRequestFactory) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockRequestFactory) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockRequestFactory) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1)
}
// CreateNewTreeRequest mocks base method.
func (m *MockRequestFactory) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest")
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest.
func (mr *MockRequestFactoryMockRecorder) CreateNewTreeRequest() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateNewTreeRequest))
}
// MockTreeSyncProtocol is a mock of TreeSyncProtocol interface.
type MockTreeSyncProtocol struct {
ctrl *gomock.Controller
recorder *MockTreeSyncProtocolMockRecorder
}
// MockTreeSyncProtocolMockRecorder is the mock recorder for MockTreeSyncProtocol.
type MockTreeSyncProtocolMockRecorder struct {
mock *MockTreeSyncProtocol
}
// NewMockTreeSyncProtocol creates a new mock instance.
func NewMockTreeSyncProtocol(ctrl *gomock.Controller) *MockTreeSyncProtocol {
mock := &MockTreeSyncProtocol{ctrl: ctrl}
mock.recorder = &MockTreeSyncProtocolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTreeSyncProtocol) EXPECT() *MockTreeSyncProtocolMockRecorder {
return m.recorder
}
// FullSyncRequest mocks base method.
func (m *MockTreeSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncRequest) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FullSyncRequest indicates an expected call of FullSyncRequest.
func (mr *MockTreeSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2)
}
// FullSyncResponse mocks base method.
func (m *MockTreeSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncResponse) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// FullSyncResponse indicates an expected call of FullSyncResponse.
func (mr *MockTreeSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2)
}
// HeadUpdate mocks base method.
func (m *MockTreeSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeHeadUpdate) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HeadUpdate indicates an expected call of HeadUpdate.
func (mr *MockTreeSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockTreeSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2)
}

View File

@ -1,4 +1,4 @@
package objectsync package synctree
import ( import (
"fmt" "fmt"

View File

@ -0,0 +1,82 @@
package synctree
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(msg *treechangeproto.TreeSyncMessage)
SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error)
QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error)
SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type syncClient struct {
RequestFactory
spaceId string
requestManager requestmanager.RequestManager
peerManager peermanager.PeerManager
}
func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient {
return &syncClient{
RequestFactory: &requestFactory{},
spaceId: spaceId,
requestManager: requestManager,
peerManager: peerManager,
}
}
func (s *syncClient) Broadcast(msg *treechangeproto.TreeSyncMessage) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, msg.RootChange.Id, "")
if err != nil {
return
}
err = s.peerManager.Broadcast(context.Background(), objMsg)
if err != nil {
log.Debug("broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "")
if err != nil {
return
}
return s.peerManager.SendPeer(context.Background(), peerId, objMsg)
}
func (s *syncClient) SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "")
if err != nil {
return
}
return s.requestManager.SendRequest(ctx, peerId, objMsg)
}
func (s *syncClient) QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "")
if err != nil {
return
}
return s.requestManager.QueueRequest(peerId, objMsg)
}
func MarshallTreeMessage(message *treechangeproto.TreeSyncMessage, spaceId, objectId, replyId string) (objMsg *spacesyncproto.ObjectSyncMessage, err error) {
payload, err := message.Marshal()
if err != nil {
return
}
objMsg = &spacesyncproto.ObjectSyncMessage{
ReplyId: replyId,
Payload: payload,
ObjectId: objectId,
SpaceId: spaceId,
}
return
}

View File

@ -1,4 +1,4 @@
//go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable //go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol
package synctree package synctree
import ( import (
@ -11,7 +11,6 @@ import (
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -44,7 +43,7 @@ type SyncTree interface {
type syncTree struct { type syncTree struct {
objecttree.ObjectTree objecttree.ObjectTree
synchandler.SyncHandler synchandler.SyncHandler
syncClient objectsync.SyncClient syncClient SyncClient
syncStatus syncstatus.StatusUpdater syncStatus syncstatus.StatusUpdater
notifiable HeadNotifiable notifiable HeadNotifiable
listener updatelistener.UpdateListener listener updatelistener.UpdateListener
@ -61,7 +60,7 @@ type ResponsiblePeersGetter interface {
type BuildDeps struct { type BuildDeps struct {
SpaceId string SpaceId string
SyncClient objectsync.SyncClient SyncClient SyncClient
Configuration nodeconf.NodeConf Configuration nodeconf.NodeConf
HeadNotifiable HeadNotifiable HeadNotifiable HeadNotifiable
Listener updatelistener.UpdateListener Listener updatelistener.UpdateListener
@ -119,7 +118,7 @@ func buildSyncTree(ctx context.Context, sendUpdate bool, deps BuildDeps) (t Sync
if sendUpdate { if sendUpdate {
headUpdate := syncTree.syncClient.CreateHeadUpdate(t, nil) headUpdate := syncTree.syncClient.CreateHeadUpdate(t, nil)
// send to everybody, because everybody should know that the node or client got new tree // send to everybody, because everybody should know that the node or client got new tree
syncTree.syncClient.Broadcast(ctx, headUpdate) syncTree.syncClient.Broadcast(headUpdate)
} }
return return
} }
@ -156,7 +155,7 @@ func (s *syncTree) AddContent(ctx context.Context, content objecttree.SignableCh
} }
s.syncStatus.HeadsChange(s.Id(), res.Heads) s.syncStatus.HeadsChange(s.Id(), res.Heads)
headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added)
s.syncClient.Broadcast(ctx, headUpdate) s.syncClient.Broadcast(headUpdate)
return return
} }
@ -183,7 +182,7 @@ func (s *syncTree) AddRawChanges(ctx context.Context, changesPayload objecttree.
s.notifiable.UpdateHeads(s.Id(), res.Heads) s.notifiable.UpdateHeads(s.Id(), res.Heads)
} }
headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added)
s.syncClient.Broadcast(ctx, headUpdate) s.syncClient.Broadcast(headUpdate)
} }
return return
} }
@ -239,7 +238,7 @@ func (s *syncTree) SyncWithPeer(ctx context.Context, peerId string) (err error)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
headUpdate := s.syncClient.CreateHeadUpdate(s, nil) headUpdate := s.syncClient.CreateHeadUpdate(s, nil)
return s.syncClient.SendWithReply(ctx, peerId, headUpdate.RootChange.Id, headUpdate, "") return s.syncClient.SendUpdate(peerId, headUpdate.RootChange.Id, headUpdate)
} }
func (s *syncTree) afterBuild() { func (s *syncTree) afterBuild() {

View File

@ -4,11 +4,11 @@ import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener/mock_updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener/mock_updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
@ -18,7 +18,7 @@ import (
type syncTreeMatcher struct { type syncTreeMatcher struct {
objTree objecttree.ObjectTree objTree objecttree.ObjectTree
client objectsync.SyncClient client SyncClient
listener updatelistener.UpdateListener listener updatelistener.UpdateListener
} }
@ -34,8 +34,8 @@ func (s syncTreeMatcher) String() string {
return "" return ""
} }
func syncClientFuncCreator(client objectsync.SyncClient) func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { func syncClientFuncCreator(client SyncClient) func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient {
return func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { return func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient {
return client return client
} }
} }
@ -46,7 +46,7 @@ func Test_BuildSyncTree(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
updateListenerMock := mock_updatelistener.NewMockUpdateListener(ctrl) updateListenerMock := mock_updatelistener.NewMockUpdateListener(ctrl)
syncClientMock := mock_objectsync.NewMockSyncClient(ctrl) syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
objTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) objTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
tr := &syncTree{ tr := &syncTree{
ObjectTree: objTreeMock, ObjectTree: objTreeMock,
@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Update(tr) updateListenerMock.EXPECT().Update(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddRawChanges(ctx, payload) res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)
@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Rebuild(tr) updateListenerMock.EXPECT().Rebuild(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddRawChanges(ctx, payload) res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)
@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) {
Return(expectedRes, nil) Return(expectedRes, nil)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddContent(ctx, content) res, err := tr.AddContent(ctx, content)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)

View File

@ -2,39 +2,66 @@ package synctree
import ( import (
"context" "context"
"errors"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/util/slice"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"go.uber.org/zap"
"sync" "sync"
) )
var (
ErrMessageIsRequest = errors.New("message is request")
ErrMessageIsNotRequest = errors.New("message is not request")
)
type syncTreeHandler struct { type syncTreeHandler struct {
objTree objecttree.ObjectTree objTree objecttree.ObjectTree
syncClient objectsync.SyncClient syncClient SyncClient
syncStatus syncstatus.StatusUpdater syncProtocol TreeSyncProtocol
handlerLock sync.Mutex syncStatus syncstatus.StatusUpdater
spaceId string handlerLock sync.Mutex
queue ReceiveQueue spaceId string
queue ReceiveQueue
} }
const maxQueueSize = 5 const maxQueueSize = 5
func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient objectsync.SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler { func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler {
return &syncTreeHandler{ return &syncTreeHandler{
objTree: objTree, objTree: objTree,
syncClient: syncClient, syncProtocol: newTreeSyncProtocol(spaceId, objTree, syncClient),
syncStatus: syncStatus, syncClient: syncClient,
spaceId: spaceId, syncStatus: syncStatus,
queue: newReceiveQueue(maxQueueSize), spaceId: spaceId,
queue: newReceiveQueue(maxQueueSize),
} }
} }
func (s *syncTreeHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
unmarshalled := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(request.Payload, unmarshalled)
if err != nil {
return
}
fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest()
if fullSyncRequest == nil {
err = ErrMessageIsNotRequest
return
}
s.syncStatus.HeadsReceive(senderId, request.ObjectId, treechangeproto.GetHeads(unmarshalled))
s.objTree.Lock()
defer s.objTree.Unlock()
treeResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest)
if err != nil {
return
}
response, err = MarshallTreeMessage(treeResp, s.spaceId, request.ObjectId, "")
return
}
func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
unmarshalled := &treechangeproto.TreeSyncMessage{} unmarshalled := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(msg.Payload, unmarshalled) err = proto.Unmarshal(msg.Payload, unmarshalled)
@ -54,181 +81,27 @@ func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, ms
func (s *syncTreeHandler) handleMessage(ctx context.Context, senderId string) (err error) { func (s *syncTreeHandler) handleMessage(ctx context.Context, senderId string) (err error) {
s.objTree.Lock() s.objTree.Lock()
defer s.objTree.Unlock() defer s.objTree.Unlock()
msg, replyId, err := s.queue.GetMessage(senderId) msg, _, err := s.queue.GetMessage(senderId)
if err != nil { if err != nil {
return return
} }
defer s.queue.ClearQueue(senderId) defer s.queue.ClearQueue(senderId)
treeId := s.objTree.Id()
content := msg.GetContent() content := msg.GetContent()
switch { switch {
case content.GetHeadUpdate() != nil: case content.GetHeadUpdate() != nil:
return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), replyId) var syncReq *treechangeproto.TreeSyncMessage
syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate())
if err != nil || syncReq == nil {
return
}
return s.syncClient.QueueRequest(senderId, treeId, syncReq)
case content.GetFullSyncRequest() != nil: case content.GetFullSyncRequest() != nil:
return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), replyId) return ErrMessageIsRequest
case content.GetFullSyncResponse() != nil: case content.GetFullSyncResponse() != nil:
return s.handleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse()) return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
} }
return return
} }
func (s *syncTreeHandler) handleHeadUpdate(
ctx context.Context,
senderId string,
update *treechangeproto.TreeHeadUpdate,
replyId string) (err error) {
var (
fullRequest *treechangeproto.TreeSyncMessage
isEmptyUpdate = len(update.Changes) == 0
objTree = s.objTree
treeId = objTree.Id()
)
log := log.With(
zap.Strings("update heads", update.Heads),
zap.String("treeId", treeId),
zap.String("spaceId", s.spaceId),
zap.Int("len(update changes)", len(update.Changes)))
log.DebugCtx(ctx, "received head update message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "head update finished with error", zap.Error(err))
} else if fullRequest != nil {
cnt := fullRequest.Content.GetFullSyncRequest()
log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "sending full sync request")
} else {
if !isEmptyUpdate {
log.DebugCtx(ctx, "head update finished correctly")
}
}
}()
// isEmptyUpdate is sent when the tree is brought up from cache
if isEmptyUpdate {
headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads)
log.DebugCtx(ctx, "is empty update", zap.String("treeId", objTree.Id()), zap.Bool("headEquals", headEquals))
if headEquals {
return
}
// we need to sync in any case
fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
if err != nil {
return
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId)
}
if s.alreadyHasHeads(objTree, update.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: update.Heads,
RawChanges: update.Changes,
})
if err != nil {
return
}
if s.alreadyHasHeads(objTree, update.Heads) {
return
}
fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
if err != nil {
return
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId)
}
func (s *syncTreeHandler) handleFullSyncRequest(
ctx context.Context,
senderId string,
request *treechangeproto.TreeFullSyncRequest,
replyId string) (err error) {
var (
fullResponse *treechangeproto.TreeSyncMessage
header = s.objTree.Header()
objTree = s.objTree
treeId = s.objTree.Id()
)
log := log.With(zap.String("senderId", senderId),
zap.Strings("request heads", request.Heads),
zap.String("treeId", treeId),
zap.String("replyId", replyId),
zap.String("spaceId", s.spaceId),
zap.Int("len(request changes)", len(request.Changes)))
log.DebugCtx(ctx, "received full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err))
s.syncClient.SendWithReply(ctx, senderId, treeId, treechangeproto.WrapError(treechangeproto.ErrFullSync, header), replyId)
return
} else if fullResponse != nil {
cnt := fullResponse.Content.GetFullSyncResponse()
log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "full sync response sent")
}
}()
if len(request.Changes) != 0 && !s.alreadyHasHeads(objTree, request.Heads) {
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: request.Heads,
RawChanges: request.Changes,
})
if err != nil {
return
}
}
fullResponse, err = s.syncClient.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath)
if err != nil {
return
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullResponse, replyId)
}
func (s *syncTreeHandler) handleFullSyncResponse(
ctx context.Context,
senderId string,
response *treechangeproto.TreeFullSyncResponse) (err error) {
var (
objTree = s.objTree
treeId = s.objTree.Id()
)
log := log.With(
zap.Strings("heads", response.Heads),
zap.String("treeId", treeId),
zap.String("spaceId", s.spaceId),
zap.Int("len(changes)", len(response.Changes)))
log.DebugCtx(ctx, "received full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "full sync response succeeded")
}
}()
if s.alreadyHasHeads(objTree, response.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: response.Heads,
RawChanges: response.Changes,
})
return
}
func (s *syncTreeHandler) alreadyHasHeads(t objecttree.ObjectTree, heads []string) bool {
return slice.UnsortedEquals(t.Heads(), heads) || t.HasChanges(heads...)
}

View File

@ -2,20 +2,15 @@ package synctree
import ( import (
"context" "context"
"fmt" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync" "github.com/stretchr/testify/require"
"sync" "sync"
"testing" "testing"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
) )
type testObjTreeMock struct { type testObjTreeMock struct {
@ -55,31 +50,43 @@ func (t *testObjTreeMock) TryRLock() bool {
type syncHandlerFixture struct { type syncHandlerFixture struct {
ctrl *gomock.Controller ctrl *gomock.Controller
syncClientMock *mock_objectsync.MockSyncClient syncClientMock *mock_synctree.MockSyncClient
objectTreeMock *testObjTreeMock objectTreeMock *testObjTreeMock
receiveQueueMock ReceiveQueue receiveQueueMock ReceiveQueue
syncProtocolMock *mock_synctree.MockTreeSyncProtocol
spaceId string
senderId string
treeId string
syncHandler *syncTreeHandler syncHandler *syncTreeHandler
} }
func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture { func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
syncClientMock := mock_objectsync.NewMockSyncClient(ctrl)
objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
syncProtocolMock := mock_synctree.NewMockTreeSyncProtocol(ctrl)
spaceId := "spaceId"
receiveQueue := newReceiveQueue(5) receiveQueue := newReceiveQueue(5)
syncHandler := &syncTreeHandler{ syncHandler := &syncTreeHandler{
objTree: objectTreeMock, objTree: objectTreeMock,
syncClient: syncClientMock, syncClient: syncClientMock,
queue: receiveQueue, syncProtocol: syncProtocolMock,
syncStatus: syncstatus.NewNoOpSyncStatus(), spaceId: spaceId,
queue: receiveQueue,
syncStatus: syncstatus.NewNoOpSyncStatus(),
} }
return &syncHandlerFixture{ return &syncHandlerFixture{
ctrl: ctrl, ctrl: ctrl,
syncClientMock: syncClientMock,
objectTreeMock: objectTreeMock, objectTreeMock: objectTreeMock,
receiveQueueMock: receiveQueue, receiveQueueMock: receiveQueue,
syncProtocolMock: syncProtocolMock,
syncClientMock: syncClientMock,
syncHandler: syncHandler, syncHandler: syncHandler,
spaceId: spaceId,
senderId: "senderId",
treeId: "treeId",
} }
} }
@ -87,341 +94,128 @@ func (fx *syncHandlerFixture) stop() {
fx.ctrl.Finish() fx.ctrl.Finish()
} }
func TestSyncHandler_HandleHeadUpdate(t *testing.T) { func TestSyncTreeHandler_HandleMessage(t *testing.T) {
ctx := context.Background() ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
fullRequest := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{
FullSyncRequest: &treechangeproto.TreeFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) { t.Run("handle head update message", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ headUpdate := &treechangeproto.TreeHeadUpdate{}
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) syncReq := &treechangeproto.TreeSyncMessage{}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil)
fx.objectTreeMock.EXPECT(). fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, fx.treeId, syncReq).Return(nil)
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("head update non empty heads not added", func(t *testing.T) { t.Run("handle head update message, empty sync request", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ headUpdate := &treechangeproto.TreeHeadUpdate{}
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.syncClientMock.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("head update non empty equal heads", func(t *testing.T) { t.Run("handle full sync request returns error", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ fullRequest := &treechangeproto.TreeFullSyncRequest{}
Heads: []string{"h1"}, treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId)
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.Equal(t, err, ErrMessageIsRequest)
}) })
t.Run("head update empty", func(t *testing.T) { t.Run("handle full sync response", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ fullSyncResponse := &treechangeproto.TreeFullSyncResponse{}
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.syncClientMock.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("head update empty equal heads", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
}
func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
fullResponse := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{
FullSyncResponse: &treechangeproto.TreeFullSyncResponse{},
},
},
}
t.Run("full sync request with change", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request with change same heads", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request without change but with reply id", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
objectMsg.RequestId = replyId
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(replyId))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request with add raw changes error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h2"})
fx.objectTreeMock.EXPECT().
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, fmt.Errorf(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Any(), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.Error(t, err)
})
}
func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
t.Run("full sync response with change", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId) objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT(). fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil)
Heads().
Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync response with same heads", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
}) })
} }
func TestSyncTreeHandler_HandleRequest(t *testing.T) {
ctx := context.Background()
t.Run("handle request", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullRequest := &treechangeproto.TreeFullSyncRequest{}
treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
syncResp := &treechangeproto.TreeSyncMessage{}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(syncResp, nil)
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
require.NotNil(t, res)
})
t.Run("handle request", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullRequest := &treechangeproto.TreeFullSyncRequest{}
treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
syncResp := &treechangeproto.TreeSyncMessage{}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(syncResp, nil)
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
require.NotNil(t, res)
})
t.Run("handle other message", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullResponse := &treechangeproto.TreeFullSyncResponse{}
responseMsg := treechangeproto.WrapFullResponse(fullResponse, chWithId)
headUpdate := &treechangeproto.TreeHeadUpdate{}
headUpdateMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
for _, msg := range []*treechangeproto.TreeSyncMessage{responseMsg, headUpdateMsg} {
objectMsg, _ := MarshallTreeMessage(msg, "spaceId", treeId, "")
_, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.Equal(t, err, ErrMessageIsNotRequest)
}
})
}

View File

@ -47,7 +47,7 @@ func (t treeRemoteGetter) getPeers(ctx context.Context) (peerIds []string, err e
func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *treechangeproto.TreeSyncMessage, err error) { func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *treechangeproto.TreeSyncMessage, err error) {
newTreeRequest := t.deps.SyncClient.CreateNewTreeRequest() newTreeRequest := t.deps.SyncClient.CreateNewTreeRequest()
resp, err := t.deps.SyncClient.SendSync(ctx, peerId, t.treeId, newTreeRequest) resp, err := t.deps.SyncClient.SendRequest(ctx, peerId, t.treeId, newTreeRequest)
if err != nil { if err != nil {
return return
} }

View File

@ -0,0 +1,152 @@
package synctree
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap"
)
type TreeSyncProtocol interface {
HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (request *treechangeproto.TreeSyncMessage, err error)
FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (response *treechangeproto.TreeSyncMessage, err error)
FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error)
}
type treeSyncProtocol struct {
log logger.CtxLogger
spaceId string
objTree objecttree.ObjectTree
reqFactory RequestFactory
}
func newTreeSyncProtocol(spaceId string, objTree objecttree.ObjectTree, reqFactory RequestFactory) *treeSyncProtocol {
return &treeSyncProtocol{
log: log.With(zap.String("spaceId", spaceId), zap.String("treeId", objTree.Id())),
spaceId: spaceId,
objTree: objTree,
reqFactory: reqFactory,
}
}
func (t *treeSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (fullRequest *treechangeproto.TreeSyncMessage, err error) {
var (
isEmptyUpdate = len(update.Changes) == 0
objTree = t.objTree
)
log := t.log.With(
zap.String("senderId", senderId),
zap.Strings("update heads", update.Heads),
zap.Int("len(update changes)", len(update.Changes)))
log.DebugCtx(ctx, "received head update message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "head update finished with error", zap.Error(err))
} else if fullRequest != nil {
cnt := fullRequest.Content.GetFullSyncRequest()
log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "returning full sync request")
} else {
if !isEmptyUpdate {
log.DebugCtx(ctx, "head update finished correctly")
}
}
}()
// isEmptyUpdate is sent when the tree is brought up from cache
if isEmptyUpdate {
headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads)
log.DebugCtx(ctx, "is empty update", zap.String("treeId", objTree.Id()), zap.Bool("headEquals", headEquals))
if headEquals {
return
}
// we need to sync in any case
fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
return
}
if t.alreadyHasHeads(objTree, update.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: update.Heads,
RawChanges: update.Changes,
})
if err != nil {
return
}
if t.alreadyHasHeads(objTree, update.Heads) {
return
}
fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
return
}
func (t *treeSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (fullResponse *treechangeproto.TreeSyncMessage, err error) {
var (
objTree = t.objTree
)
log := t.log.With(zap.String("senderId", senderId),
zap.Strings("request heads", request.Heads),
zap.Int("len(request changes)", len(request.Changes)))
log.DebugCtx(ctx, "received full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err))
} else if fullResponse != nil {
cnt := fullResponse.Content.GetFullSyncResponse()
log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "full sync response sent")
}
}()
if len(request.Changes) != 0 && !t.alreadyHasHeads(objTree, request.Heads) {
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: request.Heads,
RawChanges: request.Changes,
})
if err != nil {
return
}
}
fullResponse, err = t.reqFactory.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath)
return
}
func (t *treeSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error) {
var (
objTree = t.objTree
)
log := log.With(
zap.Strings("heads", response.Heads),
zap.Int("len(changes)", len(response.Changes)))
log.DebugCtx(ctx, "received full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "full sync response succeeded")
}
}()
if t.alreadyHasHeads(objTree, response.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: response.Heads,
RawChanges: response.Changes,
})
return
}
func (t *treeSyncProtocol) alreadyHasHeads(ot objecttree.ObjectTree, heads []string) bool {
return slice.UnsortedEquals(ot.Heads(), heads) || ot.HasChanges(heads...)
}

View File

@ -0,0 +1,293 @@
package synctree
import (
"context"
"fmt"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"testing"
)
type treeSyncProtocolFixture struct {
log logger.CtxLogger
spaceId string
senderId string
treeId string
objectTreeMock *testObjTreeMock
reqFactory *mock_synctree.MockRequestFactory
ctrl *gomock.Controller
syncProtocol TreeSyncProtocol
}
func newSyncProtocolFixture(t *testing.T) *treeSyncProtocolFixture {
ctrl := gomock.NewController(t)
objTree := &testObjTreeMock{
MockObjectTree: mock_objecttree.NewMockObjectTree(ctrl),
}
spaceId := "spaceId"
reqFactory := mock_synctree.NewMockRequestFactory(ctrl)
objTree.EXPECT().Id().Return("treeId")
syncProtocol := newTreeSyncProtocol(spaceId, objTree, reqFactory)
return &treeSyncProtocolFixture{
log: log,
spaceId: spaceId,
senderId: "senderId",
treeId: "treeId",
objectTreeMock: objTree,
reqFactory: reqFactory,
ctrl: ctrl,
syncProtocol: syncProtocol,
}
}
func (fx *treeSyncProtocolFixture) stop() {
fx.ctrl.Finish()
}
func TestTreeSyncProtocol_HeadUpdate(t *testing.T) {
ctx := context.Background()
fullRequest := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{
FullSyncRequest: &treechangeproto.TreeFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
t.Run("head update non empty equal heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
t.Run("head update empty", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.reqFactory.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Equal(t, fullRequest, res)
})
t.Run("head update empty equal heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
}
func TestTreeSyncProtocol_FullSyncRequest(t *testing.T) {
ctx := context.Background()
fullResponse := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{
FullSyncResponse: &treechangeproto.TreeFullSyncResponse{},
},
},
}
t.Run("full sync request with change", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request with change same heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request without changes", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request with change, raw changes error", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, fmt.Errorf("addRawChanges error"))
_, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.Error(t, err)
})
}
func TestTreeSyncProtocol_FullSyncResponse(t *testing.T) {
ctx := context.Background()
t.Run("full sync response with change", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse)
require.NoError(t, err)
})
t.Run("full sync response with same heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse)
require.NoError(t, err)
})
}

View File

@ -3,11 +3,11 @@ package synctree
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -82,51 +82,124 @@ func (m *messageLog) addMessage(msg protocolMsg) {
m.batcher.Add(context.Background(), msg) m.batcher.Add(context.Background(), msg)
} }
type requestPeerManager struct {
peerId string
handlers map[string]*testSyncHandler
log *messageLog
}
func newRequestPeerManager(peerId string, log *messageLog) *requestPeerManager {
return &requestPeerManager{
peerId: peerId,
handlers: map[string]*testSyncHandler{},
log: log,
}
}
func (r *requestPeerManager) addHandler(peerId string, handler *testSyncHandler) {
r.handlers[peerId] = handler
}
func (r *requestPeerManager) Run(ctx context.Context) (err error) {
return nil
}
func (r *requestPeerManager) Close(ctx context.Context) (err error) {
return nil
}
func (r *requestPeerManager) SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
panic("should not be called")
}
func (r *requestPeerManager) QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: peerId,
}
r.log.addMessage(pMsg)
return r.handlers[peerId].send(context.Background(), pMsg)
}
func (r *requestPeerManager) Init(a *app.App) (err error) {
return
}
func (r *requestPeerManager) Name() (name string) {
return
}
func (r *requestPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: peerId,
}
r.log.addMessage(pMsg)
return r.handlers[peerId].send(context.Background(), pMsg)
}
func (r *requestPeerManager) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
for _, handler := range r.handlers {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: handler.peerId,
}
r.log.addMessage(pMsg)
handler.send(context.Background(), pMsg)
}
return
}
func (r *requestPeerManager) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) {
return nil, nil
}
// testSyncHandler is the wrapper around individual tree to test sync protocol // testSyncHandler is the wrapper around individual tree to test sync protocol
type testSyncHandler struct { type testSyncHandler struct {
synchandler.SyncHandler synchandler.SyncHandler
batcher *mb.MB[protocolMsg] batcher *mb.MB[protocolMsg]
peerId string peerId string
aclList list.AclList aclList list.AclList
log *messageLog log *messageLog
syncClient objectsync.SyncClient syncClient SyncClient
builder objecttree.BuildObjectTreeFunc builder objecttree.BuildObjectTreeFunc
peerManager *requestPeerManager
} }
// createSyncHandler creates a sync handler when a tree is already created // createSyncHandler creates a sync handler when a tree is already created
func createSyncHandler(peerId, spaceId string, objTree objecttree.ObjectTree, log *messageLog) *testSyncHandler { func createSyncHandler(peerId, spaceId string, objTree objecttree.ObjectTree, log *messageLog) *testSyncHandler {
factory := objectsync.NewRequestFactory() peerManager := newRequestPeerManager(peerId, log)
syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) syncClient := NewSyncClient(spaceId, peerManager, peerManager)
netTree := &broadcastTree{ netTree := &broadcastTree{
ObjectTree: objTree, ObjectTree: objTree,
SyncClient: syncClient, SyncClient: syncClient,
} }
handler := newSyncTreeHandler(spaceId, netTree, syncClient, syncstatus.NewNoOpSyncStatus()) handler := newSyncTreeHandler(spaceId, netTree, syncClient, syncstatus.NewNoOpSyncStatus())
return newTestSyncHandler(peerId, handler) return &testSyncHandler{
SyncHandler: handler,
batcher: mb.New[protocolMsg](0),
peerId: peerId,
peerManager: peerManager,
}
} }
// createEmptySyncHandler creates a sync handler when the tree will be provided later (this emulates the situation when we have no tree) // createEmptySyncHandler creates a sync handler when the tree will be provided later (this emulates the situation when we have no tree)
func createEmptySyncHandler(peerId, spaceId string, builder objecttree.BuildObjectTreeFunc, aclList list.AclList, log *messageLog) *testSyncHandler { func createEmptySyncHandler(peerId, spaceId string, builder objecttree.BuildObjectTreeFunc, aclList list.AclList, log *messageLog) *testSyncHandler {
factory := objectsync.NewRequestFactory() peerManager := newRequestPeerManager(peerId, log)
syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) syncClient := NewSyncClient(spaceId, peerManager, peerManager)
batcher := mb.New[protocolMsg](0) batcher := mb.New[protocolMsg](0)
return &testSyncHandler{ return &testSyncHandler{
batcher: batcher,
peerId: peerId,
aclList: aclList,
log: log,
syncClient: syncClient,
builder: builder,
}
}
func newTestSyncHandler(peerId string, syncHandler synchandler.SyncHandler) *testSyncHandler {
batcher := mb.New[protocolMsg](0)
return &testSyncHandler{
SyncHandler: syncHandler,
batcher: batcher, batcher: batcher,
peerId: peerId, peerId: peerId,
aclList: aclList,
log: log,
syncClient: syncClient,
builder: builder,
peerManager: peerManager,
} }
} }
@ -140,13 +213,8 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re
return return
} }
if unmarshalled.Content.GetFullSyncResponse() == nil { if unmarshalled.Content.GetFullSyncResponse() == nil {
newTreeRequest := objectsync.NewRequestFactory().CreateNewTreeRequest() newTreeRequest := NewRequestFactory().CreateNewTreeRequest()
var objMsg *spacesyncproto.ObjectSyncMessage return h.syncClient.QueueRequest(senderId, request.ObjectId, newTreeRequest)
objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "")
if err != nil {
return
}
return h.manager().SendPeer(context.Background(), senderId, objMsg)
} }
fullSyncResponse := unmarshalled.Content.GetFullSyncResponse() fullSyncResponse := unmarshalled.Content.GetFullSyncResponse()
treeStorage, _ := treestorage.NewInMemoryTreeStorage(unmarshalled.RootChange, []string{unmarshalled.RootChange.Id}, nil) treeStorage, _ := treestorage.NewInMemoryTreeStorage(unmarshalled.RootChange, []string{unmarshalled.RootChange.Id}, nil)
@ -166,20 +234,13 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re
return return
} }
h.SyncHandler = newSyncTreeHandler(request.SpaceId, netTree, h.syncClient, syncstatus.NewNoOpSyncStatus()) h.SyncHandler = newSyncTreeHandler(request.SpaceId, netTree, h.syncClient, syncstatus.NewNoOpSyncStatus())
var objMsg *spacesyncproto.ObjectSyncMessage headUpdate := NewRequestFactory().CreateHeadUpdate(netTree, res.Added)
newTreeRequest := objectsync.NewRequestFactory().CreateHeadUpdate(netTree, res.Added) h.syncClient.Broadcast(headUpdate)
objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "") return nil
if err != nil {
return
}
return h.manager().Broadcast(context.Background(), objMsg)
} }
func (h *testSyncHandler) manager() *testMessagePool { func (h *testSyncHandler) manager() *requestPeerManager {
if h.SyncHandler != nil { return h.peerManager
return h.SyncHandler.(*syncTreeHandler).syncClient.MessagePool().(*testMessagePool)
}
return h.syncClient.MessagePool().(*testMessagePool)
} }
func (h *testSyncHandler) tree() *broadcastTree { func (h *testSyncHandler) tree() *broadcastTree {
@ -211,74 +272,28 @@ func (h *testSyncHandler) run(ctx context.Context, t *testing.T, wg *sync.WaitGr
h.tree().Unlock() h.tree().Unlock()
continue continue
} }
err = h.HandleMessage(ctx, res.senderId, res.msg) if res.description().name == "FullSyncRequest" {
if err != nil { resp, err := h.HandleRequest(ctx, res.senderId, res.msg)
fmt.Println("error handling message", err.Error()) if err != nil {
continue fmt.Println("error handling request", err.Error())
continue
}
h.peerManager.SendPeer(ctx, res.senderId, resp)
} else {
err = h.HandleMessage(ctx, res.senderId, res.msg)
if err != nil {
fmt.Println("error handling message", err.Error())
}
} }
} }
}() }()
} }
// testMessagePool captures all other handlers and sends messages to them
type testMessagePool struct {
peerId string
handlers map[string]*testSyncHandler
log *messageLog
}
func newTestMessagePool(peerId string, log *messageLog) *testMessagePool {
return &testMessagePool{handlers: map[string]*testSyncHandler{}, peerId: peerId, log: log}
}
func (m *testMessagePool) addHandler(peerId string, handler *testSyncHandler) {
m.handlers[peerId] = handler
}
func (m *testMessagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: m.peerId,
receiverId: peerId,
}
m.log.addMessage(pMsg)
return m.handlers[peerId].send(context.Background(), pMsg)
}
func (m *testMessagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
for _, handler := range m.handlers {
pMsg := protocolMsg{
msg: msg,
senderId: m.peerId,
receiverId: handler.peerId,
}
m.log.addMessage(pMsg)
handler.send(context.Background(), pMsg)
}
return
}
func (m *testMessagePool) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) {
panic("should not be called")
}
func (m *testMessagePool) LastUsage() time.Time {
panic("should not be called")
}
func (m *testMessagePool) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
panic("should not be called")
}
func (m *testMessagePool) SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
panic("should not be called")
}
// broadcastTree is the tree that broadcasts changes to everyone when changes are added // broadcastTree is the tree that broadcasts changes to everyone when changes are added
// it is a simplified version of SyncTree which is easier to use in the test environment // it is a simplified version of SyncTree which is easier to use in the test environment
type broadcastTree struct { type broadcastTree struct {
objecttree.ObjectTree objecttree.ObjectTree
objectsync.SyncClient SyncClient
} }
func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.RawChangesPayload) (objecttree.AddResult, error) { func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.RawChangesPayload) (objecttree.AddResult, error) {
@ -287,7 +302,7 @@ func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.Ra
return objecttree.AddResult{}, err return objecttree.AddResult{}, err
} }
upd := b.SyncClient.CreateHeadUpdate(b.ObjectTree, res.Added) upd := b.SyncClient.CreateHeadUpdate(b.ObjectTree, res.Added)
b.SyncClient.Broadcast(ctx, upd) b.SyncClient.Broadcast(upd)
return res, nil return res, nil
} }

View File

@ -0,0 +1,98 @@
package objectmanager
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/spacestate"
"sync/atomic"
)
var (
ErrSpaceClosed = errors.New("space is closed")
)
type ObjectManager interface {
treemanager.TreeManager
AddObject(object syncobjectgetter.SyncObject)
GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error)
}
type objectManager struct {
treemanager.TreeManager
spaceId string
reservedObjects []syncobjectgetter.SyncObject
spaceIsClosed *atomic.Bool
}
func New(manager treemanager.TreeManager) ObjectManager {
return &objectManager{
TreeManager: manager,
}
}
func (o *objectManager) Init(a *app.App) (err error) {
state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
o.spaceId = state.SpaceId
o.spaceIsClosed = state.SpaceIsClosed
settingsObject := a.MustComponent(settings.CName).(settings.Settings).SettingsObject()
acl := a.MustComponent(syncacl.CName).(*syncacl.SyncAcl)
o.AddObject(settingsObject)
o.AddObject(acl)
return nil
}
func (o *objectManager) Run(ctx context.Context) (err error) {
return nil
}
func (o *objectManager) Close(ctx context.Context) (err error) {
return nil
}
func (o *objectManager) AddObject(object syncobjectgetter.SyncObject) {
o.reservedObjects = append(o.reservedObjects, object)
}
func (o *objectManager) Name() string {
return treemanager.CName
}
func (o *objectManager) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) {
if o.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := o.getReservedObject(treeId); obj != nil {
return obj.(objecttree.ObjectTree), nil
}
return o.TreeManager.GetTree(ctx, spaceId, treeId)
}
func (o *objectManager) getReservedObject(id string) syncobjectgetter.SyncObject {
for _, obj := range o.reservedObjects {
if obj != nil && obj.Id() == id {
return obj
}
}
return nil
}
func (o *objectManager) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
if o.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := o.getReservedObject(objectId); obj != nil {
return obj, nil
}
t, err := o.TreeManager.GetTree(ctx, o.spaceId, objectId)
if err != nil {
return
}
obj = t.(syncobjectgetter.SyncObject)
return
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: SyncClient) // Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: ObjectSync)
// Package mock_objectsync is a generated GoMock package. // Package mock_objectsync is a generated GoMock package.
package mock_objectsync package mock_objectsync
@ -7,146 +7,146 @@ package mock_objectsync
import ( import (
context "context" context "context"
reflect "reflect" reflect "reflect"
time "time"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" app "github.com/anyproto/any-sync/app"
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
objectsync "github.com/anyproto/any-sync/commonspace/objectsync" objectsync "github.com/anyproto/any-sync/commonspace/objectsync"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockSyncClient is a mock of SyncClient interface. // MockObjectSync is a mock of ObjectSync interface.
type MockSyncClient struct { type MockObjectSync struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder recorder *MockObjectSyncMockRecorder
} }
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient. // MockObjectSyncMockRecorder is the mock recorder for MockObjectSync.
type MockSyncClientMockRecorder struct { type MockObjectSyncMockRecorder struct {
mock *MockSyncClient mock *MockObjectSync
} }
// NewMockSyncClient creates a new mock instance. // NewMockObjectSync creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient { func NewMockObjectSync(ctrl *gomock.Controller) *MockObjectSync {
mock := &MockSyncClient{ctrl: ctrl} mock := &MockObjectSync{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock} mock.recorder = &MockObjectSyncMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder { func (m *MockObjectSync) EXPECT() *MockObjectSyncMockRecorder {
return m.recorder return m.recorder
} }
// Broadcast mocks base method. // Close mocks base method.
func (m *MockSyncClient) Broadcast(arg0 context.Context, arg1 *treechangeproto.TreeSyncMessage) { func (m *MockObjectSync) Close(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0, arg1) ret := m.ctrl.Call(m, "Close", arg0)
} ret0, _ := ret[0].(error)
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0, arg1)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0 return ret0
} }
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate. // Close indicates an expected call of Close.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectSync)(nil).Close), arg0)
} }
// CreateNewTreeRequest mocks base method. // CloseThread mocks base method.
func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage { func (m *MockObjectSync) CloseThread(arg0 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest") ret := m.ctrl.Call(m, "CloseThread", arg0)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest. // CloseThread indicates an expected call of CloseThread.
func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call { func (mr *MockObjectSyncMockRecorder) CloseThread(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseThread", reflect.TypeOf((*MockObjectSync)(nil).CloseThread), arg0)
} }
// MessagePool mocks base method. // HandleMessage mocks base method.
func (m *MockSyncClient) MessagePool() objectsync.MessagePool { func (m *MockObjectSync) HandleMessage(arg0 context.Context, arg1 objectsync.HandleMessage) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MessagePool") ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(objectsync.MessagePool) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// MessagePool indicates an expected call of MessagePool. // HandleMessage indicates an expected call of HandleMessage.
func (mr *MockSyncClientMockRecorder) MessagePool() *gomock.Call { func (mr *MockObjectSyncMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MessagePool", reflect.TypeOf((*MockSyncClient)(nil).MessagePool)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockObjectSync)(nil).HandleMessage), arg0, arg1)
} }
// SendSync mocks base method. // HandleRequest mocks base method.
func (m *MockSyncClient) SendSync(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { func (m *MockObjectSync) HandleRequest(arg0 context.Context, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendSync", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// SendSync indicates an expected call of SendSync. // HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncClientMockRecorder) SendSync(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) HandleRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSync", reflect.TypeOf((*MockSyncClient)(nil).SendSync), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockObjectSync)(nil).HandleRequest), arg0, arg1)
} }
// SendWithReply mocks base method. // Init mocks base method.
func (m *MockSyncClient) SendWithReply(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage, arg4 string) error { func (m *MockObjectSync) Init(arg0 *app.App) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendWithReply", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// SendWithReply indicates an expected call of SendWithReply. // Init indicates an expected call of Init.
func (mr *MockSyncClientMockRecorder) SendWithReply(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWithReply", reflect.TypeOf((*MockSyncClient)(nil).SendWithReply), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectSync)(nil).Init), arg0)
}
// LastUsage mocks base method.
func (m *MockObjectSync) LastUsage() time.Time {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LastUsage")
ret0, _ := ret[0].(time.Time)
return ret0
}
// LastUsage indicates an expected call of LastUsage.
func (mr *MockObjectSyncMockRecorder) LastUsage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockObjectSync)(nil).LastUsage))
}
// Name mocks base method.
func (m *MockObjectSync) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockObjectSyncMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectSync)(nil).Name))
}
// Run mocks base method.
func (m *MockObjectSync) Run(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
func (mr *MockObjectSyncMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockObjectSync)(nil).Run), arg0)
} }

View File

@ -1,142 +0,0 @@
package objectsync
import (
"context"
"fmt"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
type LastUsage interface {
LastUsage() time.Time
}
// MessagePool can be made generic to work with different streams
type MessagePool interface {
LastUsage
synchandler.SyncHandler
peermanager.PeerManager
SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type MessageHandler func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error)
type responseWaiter struct {
ch chan *spacesyncproto.ObjectSyncMessage
}
type messagePool struct {
sync.Mutex
peermanager.PeerManager
messageHandler MessageHandler
waiters map[string]responseWaiter
waitersMx sync.Mutex
counter atomic.Uint64
lastUsage atomic.Int64
}
func newMessagePool(peerManager peermanager.PeerManager, messageHandler MessageHandler) MessagePool {
s := &messagePool{
PeerManager: peerManager,
messageHandler: messageHandler,
waiters: make(map[string]responseWaiter),
}
return s
}
func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
s.updateLastUsage()
if _, ok := ctx.Deadline(); !ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Minute)
defer cancel()
}
newCounter := s.counter.Add(1)
msg.RequestId = genReplyKey(peerId, msg.ObjectId, newCounter)
log.InfoCtx(ctx, "mpool sendSync", zap.String("requestId", msg.RequestId))
s.waitersMx.Lock()
waiter := responseWaiter{
ch: make(chan *spacesyncproto.ObjectSyncMessage, 1),
}
s.waiters[msg.RequestId] = waiter
s.waitersMx.Unlock()
err = s.SendPeer(ctx, peerId, msg)
if err != nil {
return
}
select {
case <-ctx.Done():
s.waitersMx.Lock()
delete(s.waiters, msg.RequestId)
s.waitersMx.Unlock()
log.With(zap.String("requestId", msg.RequestId)).DebugCtx(ctx, "time elapsed when waiting")
err = fmt.Errorf("sendSync context error: %v", ctx.Err())
case reply = <-waiter.ch:
// success
}
return
}
func (s *messagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
return s.PeerManager.SendPeer(ctx, peerId, msg)
}
func (s *messagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
return s.PeerManager.Broadcast(ctx, msg)
}
func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
if msg.ReplyId != "" {
log.InfoCtx(ctx, "mpool receive reply", zap.String("replyId", msg.ReplyId))
// we got reply, send it to waiter
if s.stopWaiter(msg) {
return
}
log.WarnCtx(ctx, "reply id does not exist", zap.String("replyId", msg.ReplyId))
return
}
return s.messageHandler(ctx, senderId, msg)
}
func (s *messagePool) LastUsage() time.Time {
return time.Unix(s.lastUsage.Load(), 0)
}
func (s *messagePool) updateLastUsage() {
s.lastUsage.Store(time.Now().Unix())
}
func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool {
s.waitersMx.Lock()
waiter, exists := s.waiters[msg.ReplyId]
if exists {
delete(s.waiters, msg.ReplyId)
s.waitersMx.Unlock()
waiter.ch <- msg
return true
}
s.waitersMx.Unlock()
return false
}
func genReplyKey(peerId, treeId string, counter uint64) string {
b := &strings.Builder{}
b.WriteString(peerId)
b.WriteString(".")
b.WriteString(treeId)
b.WriteString(".")
b.WriteString(strconv.FormatUint(counter, 36))
return b.String()
}

View File

@ -1,18 +1,23 @@
//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync SyncClient //go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync ObjectSync
package objectsync package objectsync
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/util/multiqueue"
"github.com/cheggaaa/mb/v3"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
@ -20,138 +25,211 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
var log = logger.NewNamed("common.commonspace.objectsync") const CName = "common.commonspace.objectsync"
var log = logger.NewNamed(CName)
type ObjectSync interface { type ObjectSync interface {
LastUsage LastUsage() time.Time
synchandler.SyncHandler HandleMessage(ctx context.Context, hm HandleMessage) (err error)
SyncClient() SyncClient HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error)
CloseThread(id string) (err error)
app.ComponentRunnable
}
Close() (err error) type HandleMessage struct {
Id uint64
ReceiveTime time.Time
StartHandlingTime time.Time
Deadline time.Time
SenderId string
Message *spacesyncproto.ObjectSyncMessage
PeerCtx context.Context
}
func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field {
return append(fields,
metric.SpaceId(m.Message.SpaceId),
metric.ObjectId(m.Message.ObjectId),
metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)),
metric.TotalDur(time.Since(m.ReceiveTime)),
)
} }
type objectSync struct { type objectSync struct {
spaceId string spaceId string
messagePool MessagePool
syncClient SyncClient
objectGetter syncobjectgetter.SyncObjectGetter objectGetter syncobjectgetter.SyncObjectGetter
configuration nodeconf.NodeConf configuration nodeconf.NodeConf
spaceStorage spacestorage.SpaceStorage spaceStorage spacestorage.SpaceStorage
metric metric.Metric
syncCtx context.Context
cancelSync context.CancelFunc
spaceIsDeleted *atomic.Bool spaceIsDeleted *atomic.Bool
handleQueue multiqueue.MultiQueue[HandleMessage]
} }
func NewObjectSync( func (s *objectSync) Init(a *app.App) (err error) {
spaceId string, s.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
spaceIsDeleted *atomic.Bool, s.objectGetter = a.MustComponent(treemanager.CName).(syncobjectgetter.SyncObjectGetter)
configuration nodeconf.NodeConf, s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
peerManager peermanager.PeerManager, sharedData := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
objectGetter syncobjectgetter.SyncObjectGetter, mc := a.Component(metric.CName)
storage spacestorage.SpaceStorage) ObjectSync { if mc != nil {
syncCtx, cancel := context.WithCancel(context.Background()) s.metric = mc.(metric.Metric)
os := &objectSync{
objectGetter: objectGetter,
spaceStorage: storage,
spaceId: spaceId,
syncCtx: syncCtx,
cancelSync: cancel,
spaceIsDeleted: spaceIsDeleted,
configuration: configuration,
} }
os.messagePool = newMessagePool(peerManager, os.handleMessage) s.spaceIsDeleted = sharedData.SpaceIsDeleted
os.syncClient = NewSyncClient(spaceId, os.messagePool, NewRequestFactory()) s.spaceId = sharedData.SpaceId
return os s.handleQueue = multiqueue.New[HandleMessage](s.processHandleMessage, 100)
return nil
} }
func (s *objectSync) Close() (err error) { func (s *objectSync) Name() (name string) {
s.cancelSync() return CName
return }
func (s *objectSync) Run(ctx context.Context) (err error) {
return nil
}
func (s *objectSync) Close(ctx context.Context) (err error) {
return s.handleQueue.Close()
}
func New() ObjectSync {
return &objectSync{}
} }
func (s *objectSync) LastUsage() time.Time { func (s *objectSync) LastUsage() time.Time {
return s.messagePool.LastUsage() // TODO: add time
return time.Time{}
} }
func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { func (s *objectSync) HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
return s.messagePool.HandleMessage(ctx, senderId, message) peerId, err := peer.CtxPeerId(ctx)
if err != nil {
return nil, err
}
return s.handleRequest(ctx, peerId, req)
} }
func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { func (s *objectSync) HandleMessage(ctx context.Context, hm HandleMessage) (err error) {
log := log.With( threadId := hm.Message.ObjectId
zap.String("objectId", msg.ObjectId), hm.ReceiveTime = time.Now()
zap.String("requestId", msg.RequestId), if hm.PeerCtx == nil {
zap.String("replyId", msg.ReplyId)) hm.PeerCtx = ctx
}
err = s.handleQueue.Add(ctx, threadId, hm)
if err == mb.ErrOverflowed {
log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.spaceId), zap.String("objectId", threadId))
// skip overflowed error
return nil
}
return
}
func (s *objectSync) processHandleMessage(msg HandleMessage) {
var err error
msg.StartHandlingTime = time.Now()
ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId)
ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId))
defer func() {
if s.metric == nil {
return
}
s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields(
zap.Error(err),
)...)
}()
if !msg.Deadline.IsZero() {
now := time.Now()
if now.After(msg.Deadline) {
log.InfoCtx(ctx, "skip message: deadline exceed")
err = context.DeadlineExceeded
return
}
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, msg.Deadline)
defer cancel()
}
if err = s.handleMessage(ctx, msg.SenderId, msg.Message); err != nil {
if msg.Message.ObjectId != "" {
// cleanup thread on error
_ = s.handleQueue.CloseThread(msg.Message.ObjectId)
}
log.InfoCtx(ctx, "handleMessage error", zap.Error(err))
}
}
func (s *objectSync) handleRequest(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
log := log.With(zap.String("objectId", msg.ObjectId))
if s.spaceIsDeleted.Load() { if s.spaceIsDeleted.Load() {
log = log.With(zap.Bool("isDeleted", true)) log = log.With(zap.Bool("isDeleted", true))
// preventing sync with other clients if they are not just syncing the settings tree // preventing sync with other clients if they are not just syncing the settings tree
if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() { if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) return nil, spacesyncproto.ErrSpaceIsDeleted
return fmt.Errorf("can't perform operation with object, space is deleted")
} }
} }
log.DebugCtx(ctx, "handling message") err = s.checkEmptyFullSync(log, msg)
if err != nil {
return nil, err
}
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
return nil, treechangeproto.ErrGetTree
}
return obj.HandleRequest(ctx, senderId, msg)
}
func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
log := log.With(zap.String("objectId", msg.ObjectId))
if s.spaceIsDeleted.Load() {
log = log.With(zap.Bool("isDeleted", true))
// preventing sync with other clients if they are not just syncing the settings tree
if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() {
return spacesyncproto.ErrSpaceIsDeleted
}
}
err = s.checkEmptyFullSync(log, msg)
if err != nil {
return err
}
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
return fmt.Errorf("failed to get object from cache: %w", err)
}
err = obj.HandleMessage(ctx, senderId, msg)
if err != nil {
return fmt.Errorf("failed to handle message: %w", err)
}
return
}
func (s *objectSync) CloseThread(id string) (err error) {
return s.handleQueue.CloseThread(id)
}
func (s *objectSync) checkEmptyFullSync(log logger.CtxLogger, msg *spacesyncproto.ObjectSyncMessage) (err error) {
hasTree, err := s.spaceStorage.HasTree(msg.ObjectId) hasTree, err := s.spaceStorage.HasTree(msg.ObjectId)
if err != nil { if err != nil {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) log.Warn("failed to execute get operation on storage has tree", zap.Error(err))
return fmt.Errorf("falied to execute get operation on storage has tree: %w", err) return spacesyncproto.ErrUnexpected
} }
// in this case we will try to get it from remote, unless the sender also sent us the same request :-) // in this case we will try to get it from remote, unless the sender also sent us the same request :-)
if !hasTree { if !hasTree {
treeMsg := &treechangeproto.TreeSyncMessage{} treeMsg := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(msg.Payload, treeMsg) err = proto.Unmarshal(msg.Payload, treeMsg)
if err != nil { if err != nil {
s.sendError(ctx, nil, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId, msg.RequestId) return nil
return fmt.Errorf("failed to unmarshall tree sync message: %w", err)
} }
// this means that we don't have the tree locally and therefore can't return it // this means that we don't have the tree locally and therefore can't return it
if s.isEmptyFullSyncRequest(treeMsg) { if s.isEmptyFullSyncRequest(treeMsg) {
err = treechangeproto.ErrGetTree return treechangeproto.ErrGetTree
s.sendError(ctx, nil, treechangeproto.ErrGetTree, senderId, msg.ObjectId, msg.RequestId)
return fmt.Errorf("failed to get tree from storage on full sync: %w", err)
} }
} }
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
// TODO: write tests for object sync https://linear.app/anytype/issue/GO-1299/write-tests-for-commonspaceobjectsync
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId)
return fmt.Errorf("failed to get object from cache: %w", err)
}
// TODO: unmarshall earlier
err = obj.HandleMessage(ctx, senderId, msg)
if err != nil {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId)
return fmt.Errorf("failed to handle message: %w", err)
}
return return
} }
func (s *objectSync) SyncClient() SyncClient {
return s.syncClient
}
func (s *objectSync) unmarshallSendError(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage, respErr error, senderId, objectId string) {
unmarshalled := &treechangeproto.TreeSyncMessage{}
err := proto.Unmarshal(msg.Payload, unmarshalled)
if err != nil {
return
}
s.sendError(ctx, unmarshalled.RootChange, respErr, senderId, objectId, msg.RequestId)
}
func (s *objectSync) sendError(ctx context.Context, root *treechangeproto.RawTreeChangeWithId, respErr error, senderId, objectId, replyId string) {
// we don't send errors if have no reply id, this can lead to bugs and also nobody needs this error
if replyId == "" {
return
}
resp := treechangeproto.WrapError(respErr, root)
if err := s.syncClient.SendWithReply(ctx, senderId, objectId, resp, replyId); err != nil {
log.InfoCtx(ctx, "failed to send error to client")
}
}
func (s *objectSync) isEmptyFullSyncRequest(msg *treechangeproto.TreeSyncMessage) bool { func (s *objectSync) isEmptyFullSyncRequest(msg *treechangeproto.TreeSyncMessage) bool {
return msg.GetContent().GetFullSyncRequest() != nil && len(msg.GetContent().GetFullSyncRequest().GetHeads()) == 0 return msg.GetContent().GetFullSyncRequest() != nil && len(msg.GetContent().GetFullSyncRequest().GetHeads()) == 0
} }

View File

@ -1,78 +0,0 @@
package objectsync
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage)
SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error)
SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
MessagePool() MessagePool
}
type syncClient struct {
RequestFactory
spaceId string
messagePool MessagePool
}
func NewSyncClient(
spaceId string,
messagePool MessagePool,
factory RequestFactory) SyncClient {
return &syncClient{
messagePool: messagePool,
RequestFactory: factory,
spaceId: spaceId,
}
}
func (s *syncClient) Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, msg.RootChange.Id, "")
if err != nil {
return
}
err = s.messagePool.Broadcast(ctx, objMsg)
if err != nil {
log.DebugCtx(ctx, "broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "")
if err != nil {
return
}
return s.messagePool.SendSync(ctx, peerId, objMsg)
}
func (s *syncClient) SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, replyId)
if err != nil {
return
}
return s.messagePool.SendPeer(ctx, peerId, objMsg)
}
func (s *syncClient) MessagePool() MessagePool {
return s.messagePool
}
func MarshallTreeMessage(message *treechangeproto.TreeSyncMessage, spaceId, objectId, replyId string) (objMsg *spacesyncproto.ObjectSyncMessage, err error) {
payload, err := message.Marshal()
if err != nil {
return
}
objMsg = &spacesyncproto.ObjectSyncMessage{
ReplyId: replyId,
Payload: payload,
ObjectId: objectId,
SpaceId: spaceId,
}
return
}

View File

@ -6,5 +6,6 @@ import (
) )
type SyncHandler interface { type SyncHandler interface {
HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error)
HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error)
} }

View File

@ -0,0 +1,99 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/objecttreebuilder (interfaces: TreeBuilder)
// Package mock_objecttreebuilder is a generated GoMock package.
package mock_objecttreebuilder
import (
context "context"
reflect "reflect"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
updatelistener "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
objecttreebuilder "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
gomock "github.com/golang/mock/gomock"
)
// MockTreeBuilder is a mock of TreeBuilder interface.
type MockTreeBuilder struct {
ctrl *gomock.Controller
recorder *MockTreeBuilderMockRecorder
}
// MockTreeBuilderMockRecorder is the mock recorder for MockTreeBuilder.
type MockTreeBuilderMockRecorder struct {
mock *MockTreeBuilder
}
// NewMockTreeBuilder creates a new mock instance.
func NewMockTreeBuilder(ctrl *gomock.Controller) *MockTreeBuilder {
mock := &MockTreeBuilder{ctrl: ctrl}
mock.recorder = &MockTreeBuilderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTreeBuilder) EXPECT() *MockTreeBuilderMockRecorder {
return m.recorder
}
// BuildHistoryTree mocks base method.
func (m *MockTreeBuilder) BuildHistoryTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.HistoryTreeOpts) (objecttree.HistoryTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildHistoryTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.HistoryTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BuildHistoryTree indicates an expected call of BuildHistoryTree.
func (mr *MockTreeBuilderMockRecorder) BuildHistoryTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildHistoryTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildHistoryTree), arg0, arg1, arg2)
}
// BuildTree mocks base method.
func (m *MockTreeBuilder) BuildTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.BuildTreeOpts) (objecttree.ObjectTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.ObjectTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BuildTree indicates an expected call of BuildTree.
func (mr *MockTreeBuilderMockRecorder) BuildTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildTree), arg0, arg1, arg2)
}
// CreateTree mocks base method.
func (m *MockTreeBuilder) CreateTree(arg0 context.Context, arg1 objecttree.ObjectTreeCreatePayload) (treestorage.TreeStorageCreatePayload, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateTree", arg0, arg1)
ret0, _ := ret[0].(treestorage.TreeStorageCreatePayload)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateTree indicates an expected call of CreateTree.
func (mr *MockTreeBuilderMockRecorder) CreateTree(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTree", reflect.TypeOf((*MockTreeBuilder)(nil).CreateTree), arg0, arg1)
}
// PutTree mocks base method.
func (m *MockTreeBuilder) PutTree(arg0 context.Context, arg1 treestorage.TreeStorageCreatePayload, arg2 updatelistener.UpdateListener) (objecttree.ObjectTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PutTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.ObjectTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PutTree indicates an expected call of PutTree.
func (mr *MockTreeBuilderMockRecorder) PutTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutTree", reflect.TypeOf((*MockTreeBuilder)(nil).PutTree), arg0, arg1, arg2)
}

View File

@ -0,0 +1,205 @@
//go:generate mockgen -destination mock_objecttreebuilder/mock_objecttreebuilder.go github.com/anyproto/any-sync/commonspace/objecttreebuilder TreeBuilder
package objecttreebuilder
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf"
"go.uber.org/zap"
"sync/atomic"
)
type BuildTreeOpts struct {
Listener updatelistener.UpdateListener
WaitTreeRemoteSync bool
TreeBuilder objecttree.BuildObjectTreeFunc
}
const CName = "common.commonspace.objecttreebuilder"
var log = logger.NewNamed(CName)
var ErrSpaceClosed = errors.New("space is closed")
type HistoryTreeOpts struct {
BeforeId string
Include bool
BuildFullTree bool
}
type TreeBuilder interface {
BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error)
BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error)
CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error)
PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error)
}
type TreeBuilderComponent interface {
app.Component
TreeBuilder
}
func New() TreeBuilderComponent {
return &treeBuilder{}
}
type treeBuilder struct {
syncClient synctree.SyncClient
configuration nodeconf.NodeConf
headsNotifiable synctree.HeadNotifiable
peerManager peermanager.PeerManager
requestManager requestmanager.RequestManager
spaceStorage spacestorage.SpaceStorage
syncStatus syncstatus.StatusUpdater
objectSync objectsync.ObjectSync
log logger.CtxLogger
builder objecttree.BuildObjectTreeFunc
spaceId string
aclList list.AclList
treesUsed *atomic.Int32
isClosed *atomic.Bool
}
func (t *treeBuilder) Init(a *app.App) (err error) {
state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
t.spaceId = state.SpaceId
t.isClosed = state.SpaceIsClosed
t.treesUsed = state.TreesUsed
t.builder = state.TreeBuilderFunc
t.aclList = a.MustComponent(syncacl.CName).(*syncacl.SyncAcl)
t.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
t.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
t.headsNotifiable = a.MustComponent(headsync.CName).(headsync.HeadSync)
t.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusUpdater)
t.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
t.requestManager = a.MustComponent(requestmanager.CName).(requestmanager.RequestManager)
t.objectSync = a.MustComponent(objectsync.CName).(objectsync.ObjectSync)
t.log = log.With(zap.String("spaceId", t.spaceId))
t.syncClient = synctree.NewSyncClient(t.spaceId, t.requestManager, t.peerManager)
return nil
}
func (t *treeBuilder) Name() (name string) {
return CName
}
func (t *treeBuilder) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (ot objecttree.ObjectTree, err error) {
if t.isClosed.Load() {
// TODO: change to real error
err = ErrSpaceClosed
return
}
treeBuilder := opts.TreeBuilder
if treeBuilder == nil {
treeBuilder = t.builder
}
deps := synctree.BuildDeps{
SpaceId: t.spaceId,
SyncClient: t.syncClient,
Configuration: t.configuration,
HeadNotifiable: t.headsNotifiable,
Listener: opts.Listener,
AclList: t.aclList,
SpaceStorage: t.spaceStorage,
OnClose: t.onClose,
SyncStatus: t.syncStatus,
WaitTreeRemoteSync: opts.WaitTreeRemoteSync,
PeerGetter: t.peerManager,
BuildObjectTree: treeBuilder,
}
t.treesUsed.Add(1)
t.log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()))
if ot, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil {
t.treesUsed.Add(-1)
t.log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.Error(err))
return nil, err
}
return
}
func (t *treeBuilder) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (ot objecttree.HistoryTree, err error) {
if t.isClosed.Load() {
// TODO: change to real error
err = ErrSpaceClosed
return
}
params := objecttree.HistoryTreeParams{
AclList: t.aclList,
BeforeId: opts.BeforeId,
IncludeBeforeId: opts.Include,
BuildFullTree: opts.BuildFullTree,
}
params.TreeStorage, err = t.spaceStorage.TreeStorage(id)
if err != nil {
return
}
return objecttree.BuildHistoryTree(params)
}
func (t *treeBuilder) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) {
if t.isClosed.Load() {
err = ErrSpaceClosed
return
}
root, err := objecttree.CreateObjectTreeRoot(payload, t.aclList)
if err != nil {
return
}
res = treestorage.TreeStorageCreatePayload{
RootRawChange: root,
Changes: []*treechangeproto.RawTreeChangeWithId{root},
Heads: []string{root.Id},
}
return
}
func (t *treeBuilder) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (ot objecttree.ObjectTree, err error) {
if t.isClosed.Load() {
err = ErrSpaceClosed
return
}
deps := synctree.BuildDeps{
SpaceId: t.spaceId,
SyncClient: t.syncClient,
Configuration: t.configuration,
HeadNotifiable: t.headsNotifiable,
Listener: listener,
AclList: t.aclList,
SpaceStorage: t.spaceStorage,
OnClose: t.onClose,
SyncStatus: t.syncStatus,
PeerGetter: t.peerManager,
BuildObjectTree: t.builder,
}
ot, err = synctree.PutSyncTree(ctx, payload, deps)
if err != nil {
return
}
t.treesUsed.Add(1)
t.log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", t.treesUsed.Load()))
return
}
func (t *treeBuilder) onClose(id string) {
t.treesUsed.Add(-1)
log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.String("spaceId", t.spaceId))
_ = t.objectSync.CloseThread(id)
}

View File

@ -214,14 +214,15 @@ func validateSpaceStorageCreatePayload(payload spacestorage.SpaceStorageCreatePa
} }
func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, identity crypto.PubKey) (err error) { func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, identity crypto.PubKey) (err error) {
if rawHeaderWithId == nil {
return spacestorage.ErrIncorrectSpaceHeader
}
sepIdx := strings.Index(rawHeaderWithId.Id, ".") sepIdx := strings.Index(rawHeaderWithId.Id, ".")
if sepIdx == -1 { if sepIdx == -1 {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if !cidutil.VerifyCid(rawHeaderWithId.RawHeader, rawHeaderWithId.Id[:sepIdx]) { if !cidutil.VerifyCid(rawHeaderWithId.RawHeader, rawHeaderWithId.Id[:sepIdx]) {
err = objecttree.ErrIncorrectCid return objecttree.ErrIncorrectCid
return
} }
var rawSpaceHeader spacesyncproto.RawSpaceHeader var rawSpaceHeader spacesyncproto.RawSpaceHeader
err = proto.Unmarshal(rawHeaderWithId.RawHeader, &rawSpaceHeader) err = proto.Unmarshal(rawHeaderWithId.RawHeader, &rawSpaceHeader)
@ -239,19 +240,16 @@ func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, i
} }
res, err := payloadIdentity.Verify(rawSpaceHeader.SpaceHeader, rawSpaceHeader.Signature) res, err := payloadIdentity.Verify(rawSpaceHeader.SpaceHeader, rawSpaceHeader.Signature)
if err != nil || !res { if err != nil || !res {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if rawHeaderWithId.Id[sepIdx+1:] != strconv.FormatUint(header.ReplicationKey, 36) { if rawHeaderWithId.Id[sepIdx+1:] != strconv.FormatUint(header.ReplicationKey, 36) {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if identity == nil { if identity == nil {
return return
} }
if !payloadIdentity.Equals(identity) { if !payloadIdentity.Equals(identity) {
err = ErrIncorrectIdentity return ErrIncorrectIdentity
return
} }
return return
} }

View File

@ -8,6 +8,7 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
app "github.com/anyproto/any-sync/app"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
peer "github.com/anyproto/any-sync/net/peer" peer "github.com/anyproto/any-sync/net/peer"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
@ -65,6 +66,34 @@ func (mr *MockPeerManagerMockRecorder) GetResponsiblePeers(arg0 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockPeerManager)(nil).GetResponsiblePeers), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockPeerManager)(nil).GetResponsiblePeers), arg0)
} }
// Init mocks base method.
func (m *MockPeerManager) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockPeerManagerMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockPeerManager)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockPeerManager) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockPeerManagerMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockPeerManager)(nil).Name))
}
// SendPeer mocks base method. // SendPeer mocks base method.
func (m *MockPeerManager) SendPeer(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error { func (m *MockPeerManager) SendPeer(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -8,9 +8,12 @@ import (
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
) )
const CName = "common.commonspace.peermanager" const (
CName = "common.commonspace.peermanager"
)
type PeerManager interface { type PeerManager interface {
app.Component
// SendPeer sends a message to a stream by peerId // SendPeer sends a message to a stream by peerId
SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error)
// Broadcast sends a message to all subscribed peers // Broadcast sends a message to all subscribed peers

View File

@ -0,0 +1,126 @@
package requestmanager
import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/streampool"
"go.uber.org/zap"
"storj.io/drpc"
"sync"
)
const CName = "common.commonspace.requestmanager"
var log = logger.NewNamed(CName)
type RequestManager interface {
app.ComponentRunnable
SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error)
}
func New() RequestManager {
return &requestManager{
workers: 10,
queueSize: 300,
pools: map[string]*streampool.ExecPool{},
}
}
type MessageHandler interface {
HandleMessage(ctx context.Context, hm objectsync.HandleMessage) (err error)
}
type requestManager struct {
sync.Mutex
pools map[string]*streampool.ExecPool
peerPool pool.Pool
workers int
queueSize int
handler MessageHandler
ctx context.Context
cancel context.CancelFunc
clientFactory spacesyncproto.ClientFactory
}
func (r *requestManager) Init(a *app.App) (err error) {
r.ctx, r.cancel = context.WithCancel(context.Background())
r.handler = a.MustComponent(objectsync.CName).(MessageHandler)
r.peerPool = a.MustComponent(pool.CName).(pool.Pool)
r.clientFactory = spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient)
return
}
func (r *requestManager) Name() (name string) {
return CName
}
func (r *requestManager) Run(ctx context.Context) (err error) {
return nil
}
func (r *requestManager) Close(ctx context.Context) (err error) {
r.Lock()
defer r.Unlock()
r.cancel()
for _, p := range r.pools {
_ = p.Close()
}
return nil
}
func (r *requestManager) SendRequest(ctx context.Context, peerId string, req *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
// TODO: limit concurrent sends?
return r.doRequest(ctx, peerId, req)
}
func (r *requestManager) QueueRequest(peerId string, req *spacesyncproto.ObjectSyncMessage) (err error) {
r.Lock()
defer r.Unlock()
pl, exists := r.pools[peerId]
if !exists {
pl = streampool.NewExecPool(r.workers, r.queueSize)
r.pools[peerId] = pl
pl.Run()
}
// TODO: for later think when many clients are there,
// we need to close pools for inactive clients
return pl.TryAdd(func() {
doRequestAndHandle(r, peerId, req)
})
}
var doRequestAndHandle = (*requestManager).requestAndHandle
func (r *requestManager) requestAndHandle(peerId string, req *spacesyncproto.ObjectSyncMessage) {
ctx := r.ctx
resp, err := r.doRequest(ctx, peerId, req)
if err != nil {
log.Warn("failed to send request", zap.Error(err))
return
}
ctx = peer.CtxWithPeerId(ctx, peerId)
_ = r.handler.HandleMessage(ctx, objectsync.HandleMessage{
SenderId: peerId,
Message: resp,
PeerCtx: ctx,
})
}
func (r *requestManager) doRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
pr, err := r.peerPool.Get(ctx, peerId)
if err != nil {
return
}
err = pr.DoDrpc(ctx, func(conn drpc.Conn) error {
cl := r.clientFactory.Client(conn)
resp, err = cl.ObjectSync(ctx, msg)
return err
})
return
}

View File

@ -0,0 +1,189 @@
package requestmanager
import (
"context"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/peer/mock_peer"
"github.com/anyproto/any-sync/net/pool/mock_pool"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"sync"
"testing"
"time"
)
type fixture struct {
requestManager *requestManager
messageHandlerMock *mock_objectsync.MockObjectSync
peerPoolMock *mock_pool.MockPool
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
ctrl *gomock.Controller
}
func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t)
manager := New().(*requestManager)
peerPoolMock := mock_pool.NewMockPool(ctrl)
messageHandlerMock := mock_objectsync.NewMockObjectSync(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
manager.peerPool = peerPoolMock
manager.handler = messageHandlerMock
manager.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
return clientMock
})
manager.ctx, manager.cancel = context.WithCancel(context.Background())
return &fixture{
requestManager: manager,
messageHandlerMock: messageHandlerMock,
peerPoolMock: peerPoolMock,
clientMock: clientMock,
ctrl: ctrl,
}
}
func (fx *fixture) stop() {
fx.ctrl.Finish()
}
func TestRequestManager_SyncRequest(t *testing.T) {
ctx := context.Background()
t.Run("send request", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
peerId := "peerId"
peerMock := mock_peer.NewMockPeer(fx.ctrl)
conn := &drpcconn.Conn{}
msg := &spacesyncproto.ObjectSyncMessage{}
resp := &spacesyncproto.ObjectSyncMessage{}
fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil)
fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil)
peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) {
drpcHandler(conn)
}).Return(nil)
res, err := fx.requestManager.SendRequest(ctx, peerId, msg)
require.NoError(t, err)
require.Equal(t, resp, res)
})
t.Run("request and handle", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
ctx = fx.requestManager.ctx
peerId := "peerId"
peerMock := mock_peer.NewMockPeer(fx.ctrl)
conn := &drpcconn.Conn{}
msg := &spacesyncproto.ObjectSyncMessage{}
resp := &spacesyncproto.ObjectSyncMessage{}
fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil)
fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil)
peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) {
drpcHandler(conn)
}).Return(nil)
fx.messageHandlerMock.EXPECT().HandleMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, msg objectsync.HandleMessage) {
require.Equal(t, peerId, msg.SenderId)
require.Equal(t, resp, msg.Message)
pId, _ := peer.CtxPeerId(msg.PeerCtx)
require.Equal(t, peerId, pId)
}).Return(nil)
fx.requestManager.requestAndHandle(peerId, msg)
})
}
func TestRequestManager_QueueRequest(t *testing.T) {
t.Run("max concurrent reqs for peer, independent reqs for other peer", func(t *testing.T) {
// testing 2 concurrent requests to one peer and simultaneous to another peer
fx := newFixture(t)
defer fx.stop()
fx.requestManager.workers = 2
msgRelease := make(chan struct{})
msgWait := make(chan struct{})
msgs := sync.Map{}
doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) {
msgs.Store(req.ObjectId, struct{}{})
<-msgWait
<-msgRelease
}
otherPeer := "otherPeer"
msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"}
msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"}
msg3 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id3"}
otherMsg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "otherId1"}
// sending requests to first peer
peerId := "peerId"
err := fx.requestManager.QueueRequest(peerId, msg1)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg2)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg3)
require.NoError(t, err)
// waiting until all the messages are loaded
msgWait <- struct{}{}
msgWait <- struct{}{}
_, ok := msgs.Load("id1")
require.True(t, ok)
_, ok = msgs.Load("id2")
require.True(t, ok)
// third message should not be read
_, ok = msgs.Load("id3")
require.False(t, ok)
// request for other peer should pass
err = fx.requestManager.QueueRequest(otherPeer, otherMsg1)
require.NoError(t, err)
msgWait <- struct{}{}
_, ok = msgs.Load("otherId1")
require.True(t, ok)
close(msgRelease)
})
t.Run("no requests after close", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
fx.requestManager.workers = 1
msgRelease := make(chan struct{})
msgWait := make(chan struct{})
msgs := sync.Map{}
doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) {
msgs.Store(req.ObjectId, struct{}{})
<-msgWait
<-msgRelease
}
msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"}
msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"}
// sending requests to first peer
peerId := "peerId"
err := fx.requestManager.QueueRequest(peerId, msg1)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg2)
require.NoError(t, err)
// waiting until all the message is loaded
msgWait <- struct{}{}
_, ok := msgs.Load("id1")
require.True(t, ok)
_, ok = msgs.Load("id2")
require.False(t, ok)
fx.requestManager.Close(context.Background())
close(msgRelease)
// waiting to know if the second one is not taken
// because the manager is now closed
time.Sleep(100 * time.Millisecond)
_, ok = msgs.Load("id2")
require.False(t, ok)
})
}

View File

@ -2,9 +2,9 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -15,11 +15,11 @@ type Deleter interface {
type deleter struct { type deleter struct {
st spacestorage.SpaceStorage st spacestorage.SpaceStorage
state settingsstate.ObjectDeletionState state deletionstate.ObjectDeletionState
getter treemanager.TreeManager getter treemanager.TreeManager
} }
func newDeleter(st spacestorage.SpaceStorage, state settingsstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter { func newDeleter(st spacestorage.SpaceStorage, state deletionstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter {
return &deleter{st, state, getter} return &deleter{st, state, getter}
} }

View File

@ -2,9 +2,9 @@ package settings
import ( import (
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"testing" "testing"
@ -14,7 +14,7 @@ func TestDeleter_Delete(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
st := mock_spacestorage.NewMockSpaceStorage(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
deleter := newDeleter(st, delState, treeManager) deleter := newDeleter(st, delState, treeManager)

View File

@ -2,6 +2,7 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"go.uber.org/zap" "go.uber.org/zap"
@ -20,7 +21,7 @@ func newDeletionManager(
settingsId string, settingsId string,
isResponsible bool, isResponsible bool,
treeManager treemanager.TreeManager, treeManager treemanager.TreeManager,
deletionState settingsstate.ObjectDeletionState, deletionState deletionstate.ObjectDeletionState,
provider SpaceIdsProvider, provider SpaceIdsProvider,
onSpaceDelete func()) DeletionManager { onSpaceDelete func()) DeletionManager {
return &deletionManager{ return &deletionManager{
@ -35,7 +36,7 @@ func newDeletionManager(
} }
type deletionManager struct { type deletionManager struct {
deletionState settingsstate.ObjectDeletionState deletionState deletionstate.ObjectDeletionState
provider SpaceIdsProvider provider SpaceIdsProvider
treeManager treemanager.TreeManager treeManager treemanager.TreeManager
spaceId string spaceId string

View File

@ -2,10 +2,10 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/settings/mock_settings" "github.com/anyproto/any-sync/commonspace/settings/mock_settings"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/golang/mock/gomock" "github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing" "testing"
@ -26,7 +26,7 @@ func TestDeletionManager_UpdateState_NotResponsible(t *testing.T) {
onDeleted := func() { onDeleted := func() {
deleted = true deleted = true
} }
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
delState.EXPECT().Add(state.DeletedIds) delState.EXPECT().Add(state.DeletedIds)
@ -58,7 +58,7 @@ func TestDeletionManager_UpdateState_Responsible(t *testing.T) {
onDeleted := func() { onDeleted := func() {
deleted = true deleted = true
} }
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
provider := mock_settings.NewMockSpaceIdsProvider(ctrl) provider := mock_settings.NewMockSpaceIdsProvider(ctrl)

View File

@ -1,328 +1,122 @@
//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider
package settings package settings
import ( import (
"context" "context"
"errors"
"fmt"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/gogo/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/exp/slices" "sync/atomic"
) )
var log = logger.NewNamed("common.commonspace.settings") const CName = "common.commonspace.settings"
type SettingsObject interface { type Settings interface {
synctree.SyncTree DeleteTree(ctx context.Context, id string) (err error)
Init(ctx context.Context) (err error) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteObject(id string) (err error) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) SettingsObject() SettingsObject
SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) app.ComponentRunnable
} }
var ( func New() Settings {
ErrDeleteSelf = errors.New("cannot delete self") return &settings{}
ErrAlreadyDeleted = errors.New("the object is already deleted")
ErrObjDoesNotExist = errors.New("the object does not exist")
ErrCantDeleteSpace = errors.New("not able to delete space")
)
var (
DoSnapshot = objecttree.DoSnapshot
buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) {
return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{
TreeStorage: objTree.Storage(),
AclList: objTree.AclList(),
BuildFullTree: true,
})
}
)
type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error)
type Deps struct {
BuildFunc BuildTreeFunc
Account accountservice.Service
TreeManager treemanager.TreeManager
Store spacestorage.SpaceStorage
Configuration nodeconf.NodeConf
DeletionState settingsstate.ObjectDeletionState
Provider SpaceIdsProvider
OnSpaceDelete func()
// testing dependencies
builder settingsstate.StateBuilder
del Deleter
delManager DeletionManager
changeFactory settingsstate.ChangeFactory
} }
type settingsObject struct { type settings struct {
synctree.SyncTree account accountservice.Service
account accountservice.Service treeManager treemanager.TreeManager
spaceId string storage spacestorage.SpaceStorage
treeManager treemanager.TreeManager configuration nodeconf.NodeConf
store spacestorage.SpaceStorage deletionState deletionstate.ObjectDeletionState
builder settingsstate.StateBuilder headsync headsync.HeadSync
buildFunc BuildTreeFunc treeBuilder objecttreebuilder.TreeBuilderComponent
loop *deleteLoop spaceIsDeleted *atomic.Bool
state *settingsstate.State settingsObject SettingsObject
deletionState settingsstate.ObjectDeletionState
deletionManager DeletionManager
changeFactory settingsstate.ChangeFactory
} }
func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) { func (s *settings) Init(a *app.App) (err error) {
var ( s.account = a.MustComponent(accountservice.CName).(accountservice.Service)
deleter Deleter s.treeManager = app.MustComponent[treemanager.TreeManager](a)
deletionManager DeletionManager s.headsync = a.MustComponent(headsync.CName).(headsync.HeadSync)
builder settingsstate.StateBuilder s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
changeFactory settingsstate.ChangeFactory s.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
) s.treeBuilder = a.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent)
if deps.del == nil {
deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager)
} else {
deleter = deps.del
}
if deps.delManager == nil {
deletionManager = newDeletionManager(
spaceId,
deps.Store.SpaceSettingsId(),
deps.Configuration.IsResponsible(spaceId),
deps.TreeManager,
deps.DeletionState,
deps.Provider,
deps.OnSpaceDelete)
} else {
deletionManager = deps.delManager
}
if deps.builder == nil {
builder = settingsstate.NewStateBuilder()
} else {
builder = deps.builder
}
if deps.changeFactory == nil {
changeFactory = settingsstate.NewChangeFactory()
} else {
changeFactory = deps.changeFactory
}
loop := newDeleteLoop(func() { sharedState := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
deleter.Delete() s.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
}) s.spaceIsDeleted = sharedState.SpaceIsDeleted
deps.DeletionState.AddObserver(func(ids []string) {
loop.notify()
})
s := &settingsObject{ deps := Deps{
loop: loop, BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) {
spaceId: spaceId, res, err := s.treeBuilder.BuildTree(ctx, id, objecttreebuilder.BuildTreeOpts{
account: deps.Account, Listener: listener,
deletionState: deps.DeletionState, WaitTreeRemoteSync: false,
treeManager: deps.TreeManager, // space settings document should not have empty data
store: deps.Store, TreeBuilder: objecttree.BuildObjectTree,
buildFunc: deps.BuildFunc, })
builder: builder, log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", sharedState.SpaceId))
deletionManager: deletionManager, if err != nil {
changeFactory: changeFactory, return
} }
obj = s t = res.(synctree.SyncTree)
return
}
func (s *settingsObject) updateIds(tr objecttree.ObjectTree) {
var err error
s.state, err = s.builder.Build(tr, s.state)
if err != nil {
log.Error("failed to build state", zap.Error(err))
return
}
log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId))
if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil {
log.Error("failed to update state", zap.Error(err))
}
}
// Update is called as part of UpdateListener interface
func (s *settingsObject) Update(tr objecttree.ObjectTree) {
s.updateIds(tr)
}
// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call)
func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) {
// at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter
s.state = nil
s.updateIds(tr)
}
func (s *settingsObject) Init(ctx context.Context) (err error) {
settingsId := s.store.SpaceSettingsId()
log.Debug("space settings id", zap.String("id", settingsId))
s.SyncTree, err = s.buildFunc(ctx, settingsId, s)
if err != nil {
return
}
// TODO: remove this check when everybody updates
if err = s.checkHistoryState(ctx); err != nil {
return
}
s.loop.Run()
return
}
func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) {
historyTree, err := buildHistoryTree(s.SyncTree)
if err != nil {
return
}
fullState, err := s.builder.Build(historyTree, nil)
if err != nil {
return
}
if len(fullState.DeletedIds) != len(s.state.DeletedIds) {
log.WarnCtx(ctx, "state does not have all deleted ids",
zap.Int("fullstate ids", len(fullState.DeletedIds)),
zap.Int("state ids", len(fullState.DeletedIds)))
s.state = fullState
err = s.deletionManager.UpdateState(context.Background(), s.state)
if err != nil {
return return
} },
Account: s.account,
TreeManager: s.treeManager,
Store: s.storage,
Configuration: s.configuration,
DeletionState: s.deletionState,
Provider: s.headsync,
OnSpaceDelete: s.onSpaceDelete,
} }
return s.settingsObject = NewSettingsObject(deps, sharedState.SpaceId)
return nil
} }
func (s *settingsObject) Close() error { func (s *settings) Name() (name string) {
s.loop.Close() return CName
return s.SyncTree.Close()
} }
func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) { func (s *settings) Run(ctx context.Context) (err error) {
s.Lock() return s.settingsObject.Init(ctx)
defer s.Unlock()
defer func() {
log.Debug("finished adding delete change", zap.Error(err))
}()
err = s.verifyDeleteSpace(raw)
if err != nil {
return
}
res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: []string{raw.Id},
RawChanges: []*treechangeproto.RawTreeChangeWithId{raw},
})
if err != nil {
return
}
if !slices.Contains(res.Heads, raw.Id) {
err = ErrCantDeleteSpace
return
}
return
} }
func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) { func (s *settings) Close(ctx context.Context) (err error) {
accountData := s.account.Account() return s.settingsObject.Close()
data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false)
if err != nil {
return
}
return s.PrepareChange(objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: false,
IsEncrypted: false,
})
} }
func (s *settingsObject) DeleteObject(id string) (err error) { func (s *settings) DeleteTree(ctx context.Context, id string) (err error) {
s.Lock() return s.settingsObject.DeleteObject(id)
defer s.Unlock()
if s.Id() == id {
err = ErrDeleteSelf
return
}
if s.state.Exists(id) {
err = ErrAlreadyDeleted
return nil
}
_, err = s.store.TreeStorage(id)
if err != nil {
err = ErrObjDoesNotExist
return
}
isSnapshot := DoSnapshot(s.Len())
res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot)
if err != nil {
return
}
return s.addContent(res, isSnapshot)
} }
func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) { func (s *settings) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
data, err := s.UnpackChange(raw) return s.settingsObject.SpaceDeleteRawChange()
if err != nil {
return
}
return verifyDeleteContent(data, "")
} }
func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) { func (s *settings) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
accountData := s.account.Account() return s.settingsObject.DeleteSpace(ctx, deleteChange)
res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: isSnapshot,
IsEncrypted: false,
})
if err != nil {
return
}
if res.Mode == objecttree.Rebuild {
s.Rebuild(s)
} else {
s.Update(s)
}
return
} }
func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) { func (s *settings) onSpaceDelete() {
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil) err := s.storage.SetSpaceDeleted()
res, err := changeBuilder.Unmarshall(raw, true)
if err != nil { if err != nil {
return log.Warn("failed to set space deleted")
} }
if !res.Identity.Equals(identity) { s.spaceIsDeleted.Swap(true)
return fmt.Errorf("incorrect identity")
}
return verifyDeleteContent(res.Data, peerId)
} }
func verifyDeleteContent(data []byte, peerId string) (err error) { func (s *settings) SettingsObject() SettingsObject {
content := &spacesyncproto.SettingsData{} return s.settingsObject
err = proto.Unmarshal(data, content)
if err != nil {
return
}
if len(content.GetContent()) != 1 ||
content.GetContent()[0].GetSpaceDelete() == nil ||
(peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") ||
(peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) {
return fmt.Errorf("incorrect delete change payload")
}
return
} }

View File

@ -0,0 +1,329 @@
//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider
package settings
import (
"context"
"errors"
"fmt"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf"
"github.com/gogo/protobuf/proto"
"go.uber.org/zap"
"golang.org/x/exp/slices"
)
var log = logger.NewNamed("common.commonspace.settings")
type SettingsObject interface {
synctree.SyncTree
Init(ctx context.Context) (err error)
DeleteObject(id string) (err error)
DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error)
SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error)
}
var (
ErrDeleteSelf = errors.New("cannot delete self")
ErrAlreadyDeleted = errors.New("the object is already deleted")
ErrObjDoesNotExist = errors.New("the object does not exist")
ErrCantDeleteSpace = errors.New("not able to delete space")
)
var (
DoSnapshot = objecttree.DoSnapshot
buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) {
return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{
TreeStorage: objTree.Storage(),
AclList: objTree.AclList(),
BuildFullTree: true,
})
}
)
type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error)
type Deps struct {
BuildFunc BuildTreeFunc
Account accountservice.Service
TreeManager treemanager.TreeManager
Store spacestorage.SpaceStorage
Configuration nodeconf.NodeConf
DeletionState deletionstate.ObjectDeletionState
Provider SpaceIdsProvider
OnSpaceDelete func()
// testing dependencies
builder settingsstate.StateBuilder
del Deleter
delManager DeletionManager
changeFactory settingsstate.ChangeFactory
}
type settingsObject struct {
synctree.SyncTree
account accountservice.Service
spaceId string
treeManager treemanager.TreeManager
store spacestorage.SpaceStorage
builder settingsstate.StateBuilder
buildFunc BuildTreeFunc
loop *deleteLoop
state *settingsstate.State
deletionState deletionstate.ObjectDeletionState
deletionManager DeletionManager
changeFactory settingsstate.ChangeFactory
}
func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) {
var (
deleter Deleter
deletionManager DeletionManager
builder settingsstate.StateBuilder
changeFactory settingsstate.ChangeFactory
)
if deps.del == nil {
deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager)
} else {
deleter = deps.del
}
if deps.delManager == nil {
deletionManager = newDeletionManager(
spaceId,
deps.Store.SpaceSettingsId(),
deps.Configuration.IsResponsible(spaceId),
deps.TreeManager,
deps.DeletionState,
deps.Provider,
deps.OnSpaceDelete)
} else {
deletionManager = deps.delManager
}
if deps.builder == nil {
builder = settingsstate.NewStateBuilder()
} else {
builder = deps.builder
}
if deps.changeFactory == nil {
changeFactory = settingsstate.NewChangeFactory()
} else {
changeFactory = deps.changeFactory
}
loop := newDeleteLoop(func() {
deleter.Delete()
})
deps.DeletionState.AddObserver(func(ids []string) {
loop.notify()
})
s := &settingsObject{
loop: loop,
spaceId: spaceId,
account: deps.Account,
deletionState: deps.DeletionState,
treeManager: deps.TreeManager,
store: deps.Store,
buildFunc: deps.BuildFunc,
builder: builder,
deletionManager: deletionManager,
changeFactory: changeFactory,
}
obj = s
return
}
func (s *settingsObject) updateIds(tr objecttree.ObjectTree) {
var err error
s.state, err = s.builder.Build(tr, s.state)
if err != nil {
log.Error("failed to build state", zap.Error(err))
return
}
log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId))
if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil {
log.Error("failed to update state", zap.Error(err))
}
}
// Update is called as part of UpdateListener interface
func (s *settingsObject) Update(tr objecttree.ObjectTree) {
s.updateIds(tr)
}
// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call)
func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) {
// at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter
s.state = nil
s.updateIds(tr)
}
func (s *settingsObject) Init(ctx context.Context) (err error) {
settingsId := s.store.SpaceSettingsId()
log.Debug("space settings id", zap.String("id", settingsId))
s.SyncTree, err = s.buildFunc(ctx, settingsId, s)
if err != nil {
return
}
// TODO: remove this check when everybody updates
if err = s.checkHistoryState(ctx); err != nil {
return
}
s.loop.Run()
return
}
func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) {
historyTree, err := buildHistoryTree(s.SyncTree)
if err != nil {
return
}
fullState, err := s.builder.Build(historyTree, nil)
if err != nil {
return
}
if len(fullState.DeletedIds) != len(s.state.DeletedIds) {
log.WarnCtx(ctx, "state does not have all deleted ids",
zap.Int("fullstate ids", len(fullState.DeletedIds)),
zap.Int("state ids", len(fullState.DeletedIds)))
s.state = fullState
err = s.deletionManager.UpdateState(context.Background(), s.state)
if err != nil {
return
}
}
return
}
func (s *settingsObject) Close() error {
s.loop.Close()
return s.SyncTree.Close()
}
func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) {
s.Lock()
defer s.Unlock()
defer func() {
log.Debug("finished adding delete change", zap.Error(err))
}()
err = s.verifyDeleteSpace(raw)
if err != nil {
return
}
res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: []string{raw.Id},
RawChanges: []*treechangeproto.RawTreeChangeWithId{raw},
})
if err != nil {
return
}
if !slices.Contains(res.Heads, raw.Id) {
err = ErrCantDeleteSpace
return
}
return
}
func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) {
accountData := s.account.Account()
data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false)
if err != nil {
return
}
return s.PrepareChange(objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: false,
IsEncrypted: false,
})
}
func (s *settingsObject) DeleteObject(id string) (err error) {
s.Lock()
defer s.Unlock()
if s.Id() == id {
err = ErrDeleteSelf
return
}
if s.state.Exists(id) {
err = ErrAlreadyDeleted
return nil
}
_, err = s.store.TreeStorage(id)
if err != nil {
err = ErrObjDoesNotExist
return
}
isSnapshot := DoSnapshot(s.Len())
res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot)
if err != nil {
return
}
return s.addContent(res, isSnapshot)
}
func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) {
data, err := s.UnpackChange(raw)
if err != nil {
return
}
return verifyDeleteContent(data, "")
}
func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) {
accountData := s.account.Account()
res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: isSnapshot,
IsEncrypted: false,
})
if err != nil {
return
}
if res.Mode == objecttree.Rebuild {
s.Rebuild(s)
} else {
s.Update(s)
}
return
}
func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) {
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil)
res, err := changeBuilder.Unmarshall(raw, true)
if err != nil {
return
}
if !res.Identity.Equals(identity) {
return fmt.Errorf("incorrect identity")
}
return verifyDeleteContent(res.Data, peerId)
}
func verifyDeleteContent(data []byte, peerId string) (err error) {
content := &spacesyncproto.SettingsData{}
err = proto.Unmarshal(data, content)
if err != nil {
return
}
if len(content.GetContent()) != 1 ||
content.GetContent()[0].GetSpaceDelete() == nil ||
(peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") ||
(peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) {
return fmt.Errorf("incorrect delete change payload")
}
return
}

View File

@ -3,6 +3,7 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/accountservice/mock_accountservice" "github.com/anyproto/any-sync/accountservice/mock_accountservice"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
@ -54,7 +55,7 @@ type settingsFixture struct {
deleter *mock_settings.MockDeleter deleter *mock_settings.MockDeleter
syncTree *mock_synctree.MockSyncTree syncTree *mock_synctree.MockSyncTree
historyTree *mock_objecttree.MockObjectTree historyTree *mock_objecttree.MockObjectTree
delState *mock_settingsstate.MockObjectDeletionState delState *mock_deletionstate.MockObjectDeletionState
account *mock_accountservice.MockService account *mock_accountservice.MockService
} }
@ -66,7 +67,7 @@ func newSettingsFixture(t *testing.T) *settingsFixture {
acc := mock_accountservice.NewMockService(ctrl) acc := mock_accountservice.NewMockService(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
st := mock_spacestorage.NewMockSpaceStorage(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
delManager := mock_settings.NewMockDeletionManager(ctrl) delManager := mock_settings.NewMockDeletionManager(ctrl)
stateBuilder := mock_settingsstate.NewMockStateBuilder(ctrl) stateBuilder := mock_settingsstate.NewMockStateBuilder(ctrl)
changeFactory := mock_settingsstate.NewMockChangeFactory(ctrl) changeFactory := mock_settingsstate.NewMockChangeFactory(ctrl)

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: ObjectDeletionState,StateBuilder,ChangeFactory) // Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: StateBuilder,ChangeFactory)
// Package mock_settingsstate is a generated GoMock package. // Package mock_settingsstate is a generated GoMock package.
package mock_settingsstate package mock_settingsstate
@ -12,109 +12,6 @@ import (
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
) )
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
type MockObjectDeletionState struct {
ctrl *gomock.Controller
recorder *MockObjectDeletionStateMockRecorder
}
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
type MockObjectDeletionStateMockRecorder struct {
mock *MockObjectDeletionState
}
// NewMockObjectDeletionState creates a new mock instance.
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
mock := &MockObjectDeletionState{ctrl: ctrl}
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0)
}
// Add indicates an expected call of Add.
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
}
// AddObserver mocks base method.
func (m *MockObjectDeletionState) AddObserver(arg0 settingsstate.StateUpdateObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddObserver", arg0)
}
// AddObserver indicates an expected call of AddObserver.
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
}
// Delete mocks base method.
func (m *MockObjectDeletionState) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
}
// Exists mocks base method.
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exists", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Exists indicates an expected call of Exists.
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
}
// Filter mocks base method.
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Filter", arg0)
ret0, _ := ret[0].([]string)
return ret0
}
// Filter indicates an expected call of Filter.
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
}
// GetQueued mocks base method.
func (m *MockObjectDeletionState) GetQueued() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetQueued")
ret0, _ := ret[0].([]string)
return ret0
}
// GetQueued indicates an expected call of GetQueued.
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
}
// MockStateBuilder is a mock of StateBuilder interface. // MockStateBuilder is a mock of StateBuilder interface.
type MockStateBuilder struct { type MockStateBuilder struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@ -1,3 +1,4 @@
//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate StateBuilder,ChangeFactory
package settingsstate package settingsstate
import "github.com/anyproto/any-sync/commonspace/spacesyncproto" import "github.com/anyproto/any-sync/commonspace/spacesyncproto"

View File

@ -2,44 +2,26 @@ package commonspace
import ( import (
"context" "context"
"errors" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl" "github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/settings" "github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/util/multiqueue"
"github.com/anyproto/any-sync/util/slice"
"github.com/cheggaaa/mb/v3"
"github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
var (
ErrSpaceClosed = errors.New("space is closed")
)
type SpaceCreatePayload struct { type SpaceCreatePayload struct {
// SigningKey is the signing key of the owner // SigningKey is the signing key of the owner
SigningKey crypto.PrivKey SigningKey crypto.PrivKey
@ -55,25 +37,6 @@ type SpaceCreatePayload struct {
MasterKey crypto.PrivKey MasterKey crypto.PrivKey
} }
type HandleMessage struct {
Id uint64
ReceiveTime time.Time
StartHandlingTime time.Time
Deadline time.Time
SenderId string
Message *spacesyncproto.ObjectSyncMessage
PeerCtx context.Context
}
func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field {
return append(fields,
metric.SpaceId(m.Message.SpaceId),
metric.ObjectId(m.Message.ObjectId),
metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)),
metric.TotalDur(time.Since(m.ReceiveTime)),
)
}
type SpaceDerivePayload struct { type SpaceDerivePayload struct {
SigningKey crypto.PrivKey SigningKey crypto.PrivKey
MasterKey crypto.PrivKey MasterKey crypto.PrivKey
@ -99,55 +62,38 @@ type Space interface {
StoredIds() []string StoredIds() []string
DebugAllHeads() []headsync.TreeHeads DebugAllHeads() []headsync.TreeHeads
Description() (SpaceDescription, error) Description() (desc SpaceDescription, err error)
CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) TreeBuilder() objecttreebuilder.TreeBuilder
PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error)
BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error)
DeleteTree(ctx context.Context, id string) (err error)
BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error)
SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
HeadSync() headsync.HeadSync
ObjectSync() objectsync.ObjectSync
SyncStatus() syncstatus.StatusUpdater SyncStatus() syncstatus.StatusUpdater
Storage() spacestorage.SpaceStorage Storage() spacestorage.SpaceStorage
HandleMessage(ctx context.Context, msg HandleMessage) (err error) DeleteTree(ctx context.Context, id string) (err error)
SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error)
HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error)
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
TryClose(objectTTL time.Duration) (close bool, err error) TryClose(objectTTL time.Duration) (close bool, err error)
Close() error Close() error
} }
type space struct { type space struct {
id string
mu sync.RWMutex mu sync.RWMutex
header *spacesyncproto.RawSpaceHeaderWithId header *spacesyncproto.RawSpaceHeaderWithId
objectSync objectsync.ObjectSync state *spacestate.SpaceState
headSync headsync.HeadSync app *app.App
syncStatus syncstatus.StatusUpdater
storage spacestorage.SpaceStorage
treeManager *commonGetter
account accountservice.Service
aclList *syncacl.SyncAcl
configuration nodeconf.NodeConf
settingsObject settings.SettingsObject
peerManager peermanager.PeerManager
treeBuilder objecttree.BuildObjectTreeFunc
metric metric.Metric
handleQueue multiqueue.MultiQueue[HandleMessage] treeBuilder objecttreebuilder.TreeBuilderComponent
headSync headsync.HeadSync
isClosed *atomic.Bool objectSync objectsync.ObjectSync
isDeleted *atomic.Bool syncStatus syncstatus.StatusService
treesUsed *atomic.Int32 settings settings.Settings
} storage spacestorage.SpaceStorage
aclList list.AclList
func (s *space) Id() string {
return s.id
} }
func (s *space) Description() (desc SpaceDescription, err error) { func (s *space) Description() (desc SpaceDescription, err error) {
@ -171,72 +117,60 @@ func (s *space) Description() (desc SpaceDescription, err error) {
return return
} }
func (s *space) StoredIds() []string {
return s.headSync.ExternalIds()
}
func (s *space) DebugAllHeads() []headsync.TreeHeads {
return s.headSync.DebugAllHeads()
}
func (s *space) DeleteTree(ctx context.Context, id string) (err error) {
return s.settings.DeleteTree(ctx, id)
}
func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return s.settings.SpaceDeleteRawChange(ctx)
}
func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
return s.settings.DeleteSpace(ctx, deleteChange)
}
func (s *space) HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error) {
return s.objectSync.HandleMessage(ctx, msg)
}
func (s *space) HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
return s.objectSync.HandleRequest(ctx, req)
}
func (s *space) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
return s.headSync.HandleRangeRequest(ctx, req)
}
func (s *space) TreeBuilder() objecttreebuilder.TreeBuilder {
return s.treeBuilder
}
func (s *space) Id() string {
return s.state.SpaceId
}
func (s *space) Init(ctx context.Context) (err error) { func (s *space) Init(ctx context.Context) (err error) {
log.With(zap.String("spaceId", s.id)).Debug("initializing space") err = s.app.Start(ctx)
s.storage = newCommonStorage(s.storage)
header, err := s.storage.SpaceHeader()
if err != nil { if err != nil {
return return
} }
s.header = header s.treeBuilder = s.app.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent)
initialIds, err := s.storage.StoredIds() s.headSync = s.app.MustComponent(headsync.CName).(headsync.HeadSync)
if err != nil { s.syncStatus = s.app.MustComponent(syncstatus.CName).(syncstatus.StatusService)
return s.settings = s.app.MustComponent(settings.CName).(settings.Settings)
} s.objectSync = s.app.MustComponent(objectsync.CName).(objectsync.ObjectSync)
aclStorage, err := s.storage.AclStorage() s.storage = s.app.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
if err != nil { s.aclList = s.app.MustComponent(syncacl.CName).(list.AclList)
return s.header, err = s.storage.SpaceHeader()
} return
aclList, err := list.BuildAclListWithIdentity(s.account.Account(), aclStorage)
if err != nil {
return
}
s.aclList = syncacl.NewSyncAcl(aclList, s.objectSync.SyncClient().MessagePool())
s.treeManager.AddObject(s.aclList)
deletionState := settingsstate.NewObjectDeletionState(log, s.storage)
deps := settings.Deps{
BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) {
res, err := s.BuildTree(ctx, id, BuildTreeOpts{
Listener: listener,
WaitTreeRemoteSync: false,
// space settings document should not have empty data
treeBuilder: objecttree.BuildObjectTree,
})
log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", s.id))
if err != nil {
return
}
t = res.(synctree.SyncTree)
return
},
Account: s.account,
TreeManager: s.treeManager,
Store: s.storage,
DeletionState: deletionState,
Provider: s.headSync,
Configuration: s.configuration,
OnSpaceDelete: s.onSpaceDelete,
}
s.settingsObject = settings.NewSettingsObject(deps, s.id)
s.headSync.Init(initialIds, deletionState)
err = s.settingsObject.Init(ctx)
if err != nil {
return
}
s.treeManager.AddObject(s.settingsObject)
s.syncStatus.Run()
s.handleQueue = multiqueue.New[HandleMessage](s.handleMessage, 100)
return nil
}
func (s *space) ObjectSync() objectsync.ObjectSync {
return s.objectSync
}
func (s *space) HeadSync() headsync.HeadSync {
return s.headSync
} }
func (s *space) SyncStatus() syncstatus.StatusUpdater { func (s *space) SyncStatus() syncstatus.StatusUpdater {
@ -247,246 +181,25 @@ func (s *space) Storage() spacestorage.SpaceStorage {
return s.storage return s.storage
} }
func (s *space) StoredIds() []string {
return slice.DiscardFromSlice(s.headSync.AllIds(), func(id string) bool {
return id == s.settingsObject.Id()
})
}
func (s *space) DebugAllHeads() []headsync.TreeHeads {
return s.headSync.DebugAllHeads()
}
func (s *space) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
root, err := objecttree.CreateObjectTreeRoot(payload, s.aclList)
if err != nil {
return
}
res = treestorage.TreeStorageCreatePayload{
RootRawChange: root,
Changes: []*treechangeproto.RawTreeChangeWithId{root},
Heads: []string{root.Id},
}
return
}
func (s *space) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
deps := synctree.BuildDeps{
SpaceId: s.id,
SyncClient: s.objectSync.SyncClient(),
Configuration: s.configuration,
HeadNotifiable: s.headSync,
Listener: listener,
AclList: s.aclList,
SpaceStorage: s.storage,
OnClose: s.onObjectClose,
SyncStatus: s.syncStatus,
PeerGetter: s.peerManager,
BuildObjectTree: s.treeBuilder,
}
t, err = synctree.PutSyncTree(ctx, payload, deps)
if err != nil {
return
}
s.treesUsed.Add(1)
log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
return
}
type BuildTreeOpts struct {
Listener updatelistener.UpdateListener
WaitTreeRemoteSync bool
treeBuilder objecttree.BuildObjectTreeFunc
}
type HistoryTreeOpts struct {
BeforeId string
Include bool
BuildFullTree bool
}
func (s *space) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
treeBuilder := opts.treeBuilder
if treeBuilder == nil {
treeBuilder = s.treeBuilder
}
deps := synctree.BuildDeps{
SpaceId: s.id,
SyncClient: s.objectSync.SyncClient(),
Configuration: s.configuration,
HeadNotifiable: s.headSync,
Listener: opts.Listener,
AclList: s.aclList,
SpaceStorage: s.storage,
OnClose: s.onObjectClose,
SyncStatus: s.syncStatus,
WaitTreeRemoteSync: opts.WaitTreeRemoteSync,
PeerGetter: s.peerManager,
BuildObjectTree: treeBuilder,
}
s.treesUsed.Add(1)
log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
if t, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil {
s.treesUsed.Add(-1)
log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id), zap.Error(err))
return nil, err
}
return
}
func (s *space) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
params := objecttree.HistoryTreeParams{
AclList: s.aclList,
BeforeId: opts.BeforeId,
IncludeBeforeId: opts.Include,
BuildFullTree: opts.BuildFullTree,
}
params.TreeStorage, err = s.storage.TreeStorage(id)
if err != nil {
return
}
return objecttree.BuildHistoryTree(params)
}
func (s *space) DeleteTree(ctx context.Context, id string) (err error) {
return s.settingsObject.DeleteObject(id)
}
func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return s.settingsObject.SpaceDeleteRawChange()
}
func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
return s.settingsObject.DeleteSpace(ctx, deleteChange)
}
func (s *space) HandleMessage(ctx context.Context, hm HandleMessage) (err error) {
threadId := hm.Message.ObjectId
hm.ReceiveTime = time.Now()
if hm.Message.ReplyId != "" {
threadId += hm.Message.ReplyId
defer func() {
_ = s.handleQueue.CloseThread(threadId)
}()
}
if hm.PeerCtx == nil {
hm.PeerCtx = ctx
}
err = s.handleQueue.Add(ctx, threadId, hm)
if err == mb.ErrOverflowed {
log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.id), zap.String("objectId", threadId))
// skip overflowed error
return nil
}
return
}
func (s *space) handleMessage(msg HandleMessage) {
var err error
msg.StartHandlingTime = time.Now()
ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId)
ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId))
defer func() {
if s.metric == nil {
return
}
s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields(
zap.Error(err),
)...)
}()
if !msg.Deadline.IsZero() {
now := time.Now()
if now.After(msg.Deadline) {
log.InfoCtx(ctx, "skip message: deadline exceed")
err = context.DeadlineExceeded
return
}
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, msg.Deadline)
defer cancel()
}
if err = s.objectSync.HandleMessage(ctx, msg.SenderId, msg.Message); err != nil {
if msg.Message.ObjectId != "" {
// cleanup thread on error
_ = s.handleQueue.CloseThread(msg.Message.ObjectId)
}
log.InfoCtx(ctx, "handleMessage error", zap.Error(err))
}
}
func (s *space) onObjectClose(id string) {
s.treesUsed.Add(-1)
log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
_ = s.handleQueue.CloseThread(id)
}
func (s *space) onSpaceDelete() {
err := s.storage.SetSpaceDeleted()
if err != nil {
log.Debug("failed to set space deleted")
}
s.isDeleted.Swap(true)
}
func (s *space) Close() error { func (s *space) Close() error {
if s.isClosed.Swap(true) { if s.state.SpaceIsClosed.Swap(true) {
log.Warn("call space.Close on closed space", zap.String("id", s.id)) log.Warn("call space.Close on closed space", zap.String("id", s.state.SpaceId))
return nil return nil
} }
log.With(zap.String("id", s.id)).Debug("space is closing") log := log.With(zap.String("spaceId", s.state.SpaceId))
log.Debug("space is closing")
var mError errs.Group err := s.app.Close(context.Background())
if err := s.handleQueue.Close(); err != nil { log.Debug("space closed")
mError.Add(err) return err
}
if err := s.headSync.Close(); err != nil {
mError.Add(err)
}
if err := s.objectSync.Close(); err != nil {
mError.Add(err)
}
if err := s.settingsObject.Close(); err != nil {
mError.Add(err)
}
if err := s.aclList.Close(); err != nil {
mError.Add(err)
}
if err := s.storage.Close(); err != nil {
mError.Add(err)
}
if err := s.syncStatus.Close(); err != nil {
mError.Add(err)
}
log.With(zap.String("id", s.id)).Debug("space closed")
return mError.Err()
} }
func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) { func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) {
if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL { if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL {
return false, nil return false, nil
} }
locked := s.treesUsed.Load() > 1 locked := s.state.TreesUsed.Load() > 1
log.With(zap.Int32("trees used", s.treesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.id)).Debug("space lock status check") log.With(zap.Int32("trees used", s.state.TreesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.state.SpaceId)).Debug("space lock status check")
if locked { if locked {
return false, nil return false, nil
} }

View File

@ -5,14 +5,22 @@ import (
"github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/objectmanager"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -21,6 +29,7 @@ import (
"github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"storj.io/drpc"
"sync/atomic" "sync/atomic"
) )
@ -45,32 +54,30 @@ type SpaceService interface {
} }
type spaceService struct { type spaceService struct {
config Config config config.Config
account accountservice.Service account accountservice.Service
configurationService nodeconf.Service configurationService nodeconf.Service
storageProvider spacestorage.SpaceStorageProvider storageProvider spacestorage.SpaceStorageProvider
peermanagerProvider peermanager.PeerManagerProvider peerManagerProvider peermanager.PeerManagerProvider
credentialProvider credentialprovider.CredentialProvider credentialProvider credentialprovider.CredentialProvider
treeManager treemanager.TreeManager statusServiceProvider syncstatus.StatusServiceProvider
pool pool.Pool treeManager treemanager.TreeManager
metric metric.Metric pool pool.Pool
metric metric.Metric
app *app.App
} }
func (s *spaceService) Init(a *app.App) (err error) { func (s *spaceService) Init(a *app.App) (err error) {
s.config = a.MustComponent("config").(ConfigGetter).GetSpace() s.config = a.MustComponent("config").(config.ConfigGetter).GetSpace()
s.account = a.MustComponent(accountservice.CName).(accountservice.Service) s.account = a.MustComponent(accountservice.CName).(accountservice.Service)
s.storageProvider = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorageProvider) s.storageProvider = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorageProvider)
s.configurationService = a.MustComponent(nodeconf.CName).(nodeconf.Service) s.configurationService = a.MustComponent(nodeconf.CName).(nodeconf.Service)
s.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager) s.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
s.peermanagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider) s.peerManagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider)
credProvider := a.Component(credentialprovider.CName) s.statusServiceProvider = a.MustComponent(syncstatus.CName).(syncstatus.StatusServiceProvider)
if credProvider != nil {
s.credentialProvider = credProvider.(credentialprovider.CredentialProvider)
} else {
s.credentialProvider = credentialprovider.NewNoOp()
}
s.pool = a.MustComponent(pool.CName).(pool.Pool) s.pool = a.MustComponent(pool.CName).(pool.Pool)
s.metric, _ = a.Component(metric.CName).(metric.Metric) s.metric, _ = a.Component(metric.CName).(metric.Metric)
s.app = a
return nil return nil
} }
@ -138,8 +145,6 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) {
} }
} }
} }
lastConfiguration := s.configurationService
var ( var (
spaceIsClosed = &atomic.Bool{} spaceIsClosed = &atomic.Bool{}
spaceIsDeleted = &atomic.Bool{} spaceIsDeleted = &atomic.Bool{}
@ -149,42 +154,39 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) {
return nil, err return nil, err
} }
spaceIsDeleted.Swap(isDeleted) spaceIsDeleted.Swap(isDeleted)
getter := newCommonGetter(st.Id(), s.treeManager, spaceIsClosed) state := &spacestate.SpaceState{
syncStatus := syncstatus.NewNoOpSyncStatus() SpaceId: st.Id(),
// this will work only for clients, not the best solution, but... SpaceIsDeleted: spaceIsDeleted,
if !lastConfiguration.IsResponsible(st.Id()) { SpaceIsClosed: spaceIsClosed,
// TODO: move it to the client package and add possibility to inject StatusProvider from the client TreesUsed: &atomic.Int32{},
syncStatus = syncstatus.NewSyncStatusProvider(st.Id(), syncstatus.DefaultDeps(lastConfiguration, st))
} }
var builder objecttree.BuildObjectTreeFunc
if s.config.KeepTreeDataInMemory { if s.config.KeepTreeDataInMemory {
builder = objecttree.BuildObjectTree state.TreeBuilderFunc = objecttree.BuildObjectTree
} else { } else {
builder = objecttree.BuildEmptyDataObjectTree state.TreeBuilderFunc = objecttree.BuildEmptyDataObjectTree
} }
peerManager, err := s.peerManagerProvider.NewPeerManager(ctx, id)
peerManager, err := s.peermanagerProvider.NewPeerManager(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
statusService := s.statusServiceProvider.NewStatusService()
spaceApp := s.app.ChildApp()
spaceApp.Register(state).
Register(peerManager).
Register(newCommonStorage(st)).
Register(statusService).
Register(syncacl.New()).
Register(requestmanager.New()).
Register(deletionstate.New()).
Register(settings.New()).
Register(objectmanager.New(s.treeManager)).
Register(objecttreebuilder.New()).
Register(objectsync.New()).
Register(headsync.New())
headSync := headsync.NewHeadSync(id, spaceIsDeleted, s.config.SyncPeriod, lastConfiguration, st, peerManager, getter, syncStatus, s.credentialProvider, log)
objectSync := objectsync.NewObjectSync(id, spaceIsDeleted, lastConfiguration, peerManager, getter, st)
sp := &space{ sp := &space{
id: id, state: state,
objectSync: objectSync, app: spaceApp,
headSync: headSync,
syncStatus: syncStatus,
treeManager: getter,
account: s.account,
configuration: lastConfiguration,
peerManager: peerManager,
storage: st,
treesUsed: &atomic.Int32{},
treeBuilder: builder,
isClosed: spaceIsClosed,
isDeleted: spaceIsDeleted,
metric: s.metric,
} }
return sp, nil return sp, nil
} }
@ -226,8 +228,12 @@ func (s *spaceService) getSpaceStorageFromRemote(ctx context.Context, id string)
return return
} }
cl := spacesyncproto.NewDRPCSpaceSyncClient(p) var res *spacesyncproto.SpacePullResponse
res, err := cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id}) err = p.DoDrpc(ctx, func(conn drpc.Conn) error {
cl := spacesyncproto.NewDRPCSpaceSyncClient(conn)
res, err = cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id})
return err
})
if err != nil { if err != nil {
err = rpcerr.Unwrap(err) err = rpcerr.Unwrap(err)
return return

View File

@ -0,0 +1,25 @@
package spacestate
import (
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"sync/atomic"
)
const CName = "common.commonspace.spacestate"
type SpaceState struct {
SpaceId string
SpaceIsDeleted *atomic.Bool
SpaceIsClosed *atomic.Bool
TreesUsed *atomic.Int32
TreeBuilderFunc objecttree.BuildObjectTreeFunc
}
func (s *SpaceState) Init(a *app.App) (err error) {
return nil
}
func (s *SpaceState) Name() (name string) {
return CName
}

View File

@ -1,6 +1,8 @@
package spacestorage package spacestorage
import ( import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage" "github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
@ -21,6 +23,22 @@ type InMemorySpaceStorage struct {
sync.Mutex sync.Mutex
} }
func (i *InMemorySpaceStorage) Run(ctx context.Context) (err error) {
return nil
}
func (i *InMemorySpaceStorage) Close(ctx context.Context) (err error) {
return nil
}
func (i *InMemorySpaceStorage) Init(a *app.App) (err error) {
return nil
}
func (i *InMemorySpaceStorage) Name() (name string) {
return CName
}
func NewInMemorySpaceStorage(payload SpaceStorageCreatePayload) (SpaceStorage, error) { func NewInMemorySpaceStorage(payload SpaceStorageCreatePayload) (SpaceStorage, error) {
aclStorage, err := liststorage.NewInMemoryAclListStorage(payload.AclWithId.Id, []*aclrecordproto.RawAclRecordWithId{payload.AclWithId}) aclStorage, err := liststorage.NewInMemoryAclListStorage(payload.AclWithId.Id, []*aclrecordproto.RawAclRecordWithId{payload.AclWithId})
if err != nil { if err != nil {
@ -148,10 +166,6 @@ func (i *InMemorySpaceStorage) ReadSpaceHash() (hash string, err error) {
return i.spaceHash, nil return i.spaceHash, nil
} }
func (i *InMemorySpaceStorage) Close() error {
return nil
}
func (i *InMemorySpaceStorage) AllTrees() map[string]treestorage.TreeStorage { func (i *InMemorySpaceStorage) AllTrees() map[string]treestorage.TreeStorage {
i.Lock() i.Lock()
defer i.Unlock() defer i.Unlock()

View File

@ -5,8 +5,10 @@
package mock_spacestorage package mock_spacestorage
import ( import (
context "context"
reflect "reflect" reflect "reflect"
app "github.com/anyproto/any-sync/app"
liststorage "github.com/anyproto/any-sync/commonspace/object/acl/liststorage" liststorage "github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
@ -53,17 +55,17 @@ func (mr *MockSpaceStorageMockRecorder) AclStorage() *gomock.Call {
} }
// Close mocks base method. // Close mocks base method.
func (m *MockSpaceStorage) Close() error { func (m *MockSpaceStorage) Close(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close") ret := m.ctrl.Call(m, "Close", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close.
func (mr *MockSpaceStorageMockRecorder) Close() *gomock.Call { func (mr *MockSpaceStorageMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSpaceStorage)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSpaceStorage)(nil).Close), arg0)
} }
// CreateTreeStorage mocks base method. // CreateTreeStorage mocks base method.
@ -110,6 +112,20 @@ func (mr *MockSpaceStorageMockRecorder) Id() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSpaceStorage)(nil).Id)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSpaceStorage)(nil).Id))
} }
// Init mocks base method.
func (m *MockSpaceStorage) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockSpaceStorageMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockSpaceStorage)(nil).Init), arg0)
}
// IsSpaceDeleted mocks base method. // IsSpaceDeleted mocks base method.
func (m *MockSpaceStorage) IsSpaceDeleted() (bool, error) { func (m *MockSpaceStorage) IsSpaceDeleted() (bool, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -125,6 +141,20 @@ func (mr *MockSpaceStorageMockRecorder) IsSpaceDeleted() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSpaceDeleted", reflect.TypeOf((*MockSpaceStorage)(nil).IsSpaceDeleted)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSpaceDeleted", reflect.TypeOf((*MockSpaceStorage)(nil).IsSpaceDeleted))
} }
// Name mocks base method.
func (m *MockSpaceStorage) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockSpaceStorageMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockSpaceStorage)(nil).Name))
}
// ReadSpaceHash mocks base method. // ReadSpaceHash mocks base method.
func (m *MockSpaceStorage) ReadSpaceHash() (string, error) { func (m *MockSpaceStorage) ReadSpaceHash() (string, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -140,6 +170,20 @@ func (mr *MockSpaceStorageMockRecorder) ReadSpaceHash() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadSpaceHash", reflect.TypeOf((*MockSpaceStorage)(nil).ReadSpaceHash)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadSpaceHash", reflect.TypeOf((*MockSpaceStorage)(nil).ReadSpaceHash))
} }
// Run mocks base method.
func (m *MockSpaceStorage) Run(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
func (mr *MockSpaceStorageMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSpaceStorage)(nil).Run), arg0)
}
// SetSpaceDeleted mocks base method. // SetSpaceDeleted mocks base method.
func (m *MockSpaceStorage) SetSpaceDeleted() error { func (m *MockSpaceStorage) SetSpaceDeleted() error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -27,8 +27,8 @@ const (
TreeDeletedStatusDeleted = "deleted" TreeDeletedStatusDeleted = "deleted"
) )
// TODO: consider moving to some file with all common interfaces etc
type SpaceStorage interface { type SpaceStorage interface {
app.ComponentRunnable
Id() string Id() string
SetSpaceDeleted() error SetSpaceDeleted() error
IsSpaceDeleted() (bool, error) IsSpaceDeleted() (bool, error)
@ -44,8 +44,6 @@ type SpaceStorage interface {
CreateTreeStorage(payload treestorage.TreeStorageCreatePayload) (treestorage.TreeStorage, error) CreateTreeStorage(payload treestorage.TreeStorageCreatePayload) (treestorage.TreeStorage, error)
WriteSpaceHash(hash string) error WriteSpaceHash(hash string) error
ReadSpaceHash() (hash string, err error) ReadSpaceHash() (hash string, err error)
Close() error
} }
type SpaceStorageCreatePayload struct { type SpaceStorageCreatePayload struct {

View File

@ -1 +0,0 @@
package spacestorage

View File

@ -65,6 +65,21 @@ func (mr *MockDRPCSpaceSyncClientMockRecorder) HeadSync(arg0, arg1 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadSync", reflect.TypeOf((*MockDRPCSpaceSyncClient)(nil).HeadSync), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadSync", reflect.TypeOf((*MockDRPCSpaceSyncClient)(nil).HeadSync), arg0, arg1)
} }
// ObjectSync mocks base method.
func (m *MockDRPCSpaceSyncClient) ObjectSync(arg0 context.Context, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ObjectSync", arg0, arg1)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ObjectSync indicates an expected call of ObjectSync.
func (mr *MockDRPCSpaceSyncClientMockRecorder) ObjectSync(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectSync", reflect.TypeOf((*MockDRPCSpaceSyncClient)(nil).ObjectSync), arg0, arg1)
}
// ObjectSyncStream mocks base method. // ObjectSyncStream mocks base method.
func (m *MockDRPCSpaceSyncClient) ObjectSyncStream(arg0 context.Context) (spacesyncproto.DRPCSpaceSync_ObjectSyncStreamClient, error) { func (m *MockDRPCSpaceSyncClient) ObjectSyncStream(arg0 context.Context) (spacesyncproto.DRPCSpaceSync_ObjectSyncStreamClient, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -23,6 +23,8 @@ service SpaceSync {
rpc SpacePull(SpacePullRequest) returns (SpacePullResponse); rpc SpacePull(SpacePullRequest) returns (SpacePullResponse);
// ObjectSyncStream opens object sync stream with node or client // ObjectSyncStream opens object sync stream with node or client
rpc ObjectSyncStream(stream ObjectSyncMessage) returns (stream ObjectSyncMessage); rpc ObjectSyncStream(stream ObjectSyncMessage) returns (stream ObjectSyncMessage);
// ObjectSync sends object sync message and synchronously gets response message
rpc ObjectSync(ObjectSyncMessage) returns (ObjectSyncMessage);
} }
// HeadSyncRange presenting a request for one range // HeadSyncRange presenting a request for one range

View File

@ -1254,75 +1254,75 @@ func init() {
} }
var fileDescriptor_80e49f1f4ac27799 = []byte{ var fileDescriptor_80e49f1f4ac27799 = []byte{
// 1077 bytes of a gzipped FileDescriptorProto // 1083 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x56, 0xcd, 0x6e, 0xdb, 0x46, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x56, 0xcd, 0x6e, 0xdb, 0x46,
0x10, 0x16, 0xe9, 0x5f, 0x8d, 0x65, 0x99, 0xd9, 0x28, 0x89, 0xaa, 0x18, 0x8a, 0xb0, 0x28, 0x0a, 0x10, 0x16, 0xe9, 0x5f, 0x8d, 0x65, 0x85, 0xd9, 0x28, 0x89, 0xaa, 0x18, 0x8a, 0xb0, 0x28, 0x0a,
0x23, 0x07, 0x27, 0xb1, 0x8b, 0x02, 0x49, 0xdb, 0x43, 0x62, 0x3b, 0x0d, 0x51, 0x24, 0x36, 0x56, 0x23, 0x07, 0x27, 0xb1, 0x8b, 0x02, 0x49, 0xdb, 0x43, 0x62, 0x3b, 0x35, 0x51, 0x24, 0x36, 0x56,
0x0d, 0x0a, 0x14, 0xc8, 0x61, 0x4d, 0x8e, 0x2d, 0xb6, 0x14, 0xc9, 0x72, 0x57, 0x89, 0x75, 0xec, 0x0d, 0x0a, 0x14, 0xc8, 0x61, 0x4d, 0x8e, 0x2d, 0xb6, 0x14, 0xc9, 0x72, 0x57, 0x89, 0x75, 0xec,
0xa9, 0xd7, 0x9e, 0xdb, 0x07, 0xe8, 0x0b, 0xf4, 0x21, 0x7a, 0x4c, 0x6f, 0x3d, 0x16, 0xf6, 0x8b, 0xa9, 0xd7, 0x9e, 0xdb, 0x07, 0xe8, 0xab, 0xf4, 0x98, 0xde, 0x7a, 0x2c, 0xec, 0xf7, 0x28, 0x8a,
0x14, 0xbb, 0x5c, 0xfe, 0xc8, 0xa2, 0x02, 0xe4, 0x22, 0xed, 0x7e, 0x33, 0xf3, 0xcd, 0xdf, 0xee, 0x5d, 0x2e, 0x7f, 0x64, 0x51, 0x01, 0x8a, 0x5e, 0xa4, 0xdd, 0x6f, 0x66, 0xbe, 0xf9, 0xdb, 0xdd,
0x0e, 0xe1, 0x91, 0x17, 0x8f, 0xc7, 0x71, 0x24, 0x12, 0xee, 0xe1, 0x03, 0xfd, 0x2b, 0xa6, 0x91, 0x21, 0x3c, 0xf6, 0xe2, 0xf1, 0x38, 0x8e, 0x44, 0xc2, 0x3d, 0x7c, 0xa8, 0x7f, 0xc5, 0x34, 0xf2,
0x97, 0xa4, 0xb1, 0x8c, 0x1f, 0xe8, 0x5f, 0x51, 0xa2, 0xbb, 0x1a, 0x20, 0xcd, 0x02, 0xa0, 0x2e, 0x92, 0x34, 0x96, 0xf1, 0x43, 0xfd, 0x2b, 0x4a, 0x74, 0x47, 0x03, 0xa4, 0x59, 0x00, 0xd4, 0x85,
0x6c, 0xbe, 0x40, 0xee, 0x0f, 0xa7, 0x91, 0xc7, 0x78, 0x74, 0x8e, 0x84, 0xc0, 0xf2, 0x59, 0x1a, 0xcd, 0x23, 0xe4, 0xfe, 0x70, 0x1a, 0x79, 0x8c, 0x47, 0xe7, 0x48, 0x08, 0x2c, 0x9f, 0xa5, 0xf1,
0x8f, 0xbb, 0xd6, 0xc0, 0xda, 0x59, 0x66, 0x7a, 0x4d, 0xda, 0x60, 0xcb, 0xb8, 0x6b, 0x6b, 0xc4, 0xb8, 0x6b, 0x0d, 0xac, 0xed, 0x65, 0xa6, 0xd7, 0xa4, 0x0d, 0xb6, 0x8c, 0xbb, 0xb6, 0x46, 0x6c,
0x96, 0x31, 0xe9, 0xc0, 0x4a, 0x18, 0x8c, 0x03, 0xd9, 0x5d, 0x1a, 0x58, 0x3b, 0x9b, 0x2c, 0xdb, 0x19, 0x93, 0x0e, 0xac, 0x84, 0xc1, 0x38, 0x90, 0xdd, 0xa5, 0x81, 0xb5, 0xbd, 0xc9, 0xb2, 0x0d,
0xd0, 0x0b, 0x68, 0x17, 0x54, 0x28, 0x26, 0xa1, 0x54, 0x5c, 0x23, 0x2e, 0x46, 0x9a, 0xab, 0xc5, 0xbd, 0x80, 0x76, 0x41, 0x85, 0x62, 0x12, 0x4a, 0xc5, 0x35, 0xe2, 0x62, 0xa4, 0xb9, 0x5a, 0x4c,
0xf4, 0x9a, 0x7c, 0x05, 0xeb, 0x18, 0xe2, 0x18, 0x23, 0x29, 0xba, 0xf6, 0x60, 0x69, 0x67, 0x63, 0xaf, 0xc9, 0x17, 0xb0, 0x8e, 0x21, 0x8e, 0x31, 0x92, 0xa2, 0x6b, 0x0f, 0x96, 0xb6, 0x37, 0x76,
0x6f, 0xb0, 0x5b, 0xc6, 0x37, 0x4b, 0x70, 0x94, 0x29, 0xb2, 0xc2, 0x42, 0x79, 0xf6, 0xe2, 0x49, 0x07, 0x3b, 0x65, 0x7c, 0xb3, 0x04, 0x87, 0x99, 0x22, 0x2b, 0x2c, 0x94, 0x67, 0x2f, 0x9e, 0x44,
0x54, 0x78, 0xd6, 0x1b, 0xfa, 0x25, 0xdc, 0xaa, 0x35, 0x54, 0x81, 0x07, 0xbe, 0x76, 0xdf, 0x64, 0x85, 0x67, 0xbd, 0xa1, 0x9f, 0xc3, 0xed, 0x5a, 0x43, 0x15, 0x78, 0xe0, 0x6b, 0xf7, 0x4d, 0x66,
0x76, 0xe0, 0xeb, 0x80, 0x90, 0xfb, 0x3a, 0x95, 0x26, 0xd3, 0x6b, 0xfa, 0x06, 0xb6, 0x4a, 0xe3, 0x07, 0xbe, 0x0e, 0x08, 0xb9, 0xaf, 0x53, 0x69, 0x32, 0xbd, 0xa6, 0x6f, 0xe0, 0x46, 0x69, 0xfc,
0x9f, 0x27, 0x28, 0x24, 0xe9, 0xc2, 0x9a, 0x0e, 0xc9, 0xcd, 0x6d, 0xf3, 0x2d, 0x79, 0x08, 0xab, 0xe3, 0x04, 0x85, 0x24, 0x5d, 0x58, 0xd3, 0x21, 0xb9, 0xb9, 0x6d, 0xbe, 0x25, 0x8f, 0x60, 0x35,
0xa9, 0x2a, 0x53, 0x1e, 0x7b, 0xb7, 0x2e, 0x76, 0xa5, 0xc0, 0x8c, 0x1e, 0xfd, 0x06, 0x9c, 0x4a, 0x55, 0x65, 0xca, 0x63, 0xef, 0xd6, 0xc5, 0xae, 0x14, 0x98, 0xd1, 0xa3, 0x5f, 0x81, 0x53, 0x89,
0x6c, 0x49, 0x1c, 0x09, 0x24, 0xfb, 0xb0, 0x96, 0xea, 0x38, 0x45, 0xd7, 0xd2, 0x34, 0x9f, 0x2c, 0x2d, 0x89, 0x23, 0x81, 0x64, 0x0f, 0xd6, 0x52, 0x1d, 0xa7, 0xe8, 0x5a, 0x9a, 0xe6, 0xa3, 0x85,
0x2c, 0x01, 0xcb, 0x35, 0xe9, 0x1f, 0x16, 0xdc, 0x38, 0x3e, 0xfd, 0x11, 0x3d, 0xa9, 0xa4, 0x2f, 0x25, 0x60, 0xb9, 0x26, 0xfd, 0xcd, 0x82, 0x9b, 0xc7, 0xa7, 0xdf, 0xa3, 0x27, 0x95, 0xf4, 0x25,
0x51, 0x08, 0x7e, 0x8e, 0x1f, 0x08, 0x75, 0x1b, 0x9a, 0x69, 0x96, 0x8f, 0x9b, 0x27, 0x5c, 0x02, 0x0a, 0xc1, 0xcf, 0xf1, 0x03, 0xa1, 0x6e, 0x41, 0x33, 0xcd, 0xf2, 0x71, 0xf3, 0x84, 0x4b, 0x40,
0xca, 0x2e, 0xc5, 0x24, 0x9c, 0xba, 0xbe, 0x2e, 0x65, 0x93, 0xe5, 0x5b, 0x25, 0x49, 0xf8, 0x34, 0xd9, 0xa5, 0x98, 0x84, 0x53, 0xd7, 0xd7, 0xa5, 0x6c, 0xb2, 0x7c, 0xab, 0x24, 0x09, 0x9f, 0x86,
0x8c, 0xb9, 0xdf, 0x5d, 0xd6, 0x7d, 0xcb, 0xb7, 0xa4, 0x07, 0xeb, 0xb1, 0x0e, 0xc0, 0xf5, 0xbb, 0x31, 0xf7, 0xbb, 0xcb, 0xba, 0x6f, 0xf9, 0x96, 0xf4, 0x60, 0x3d, 0xd6, 0x01, 0xb8, 0x7e, 0x77,
0x2b, 0xda, 0xa8, 0xd8, 0x53, 0x04, 0x67, 0xa8, 0x1c, 0x9f, 0x4c, 0xc4, 0x28, 0x2f, 0xe3, 0xa3, 0x45, 0x1b, 0x15, 0x7b, 0x8a, 0xe0, 0x0c, 0x95, 0xe3, 0x93, 0x89, 0x18, 0xe5, 0x65, 0x7c, 0x5c,
0x92, 0x49, 0xc5, 0xb6, 0xb1, 0x77, 0xa7, 0x92, 0x66, 0xa6, 0x9d, 0x89, 0x4b, 0x17, 0x7d, 0x80, 0x32, 0xa9, 0xd8, 0x36, 0x76, 0xef, 0x56, 0xd2, 0xcc, 0xb4, 0x33, 0x71, 0xe9, 0xa2, 0x0f, 0xb0,
0x83, 0x14, 0x7d, 0x8c, 0x64, 0xc0, 0x43, 0x1d, 0x75, 0x8b, 0x55, 0x10, 0x7a, 0x13, 0x6e, 0x54, 0x9f, 0xa2, 0x8f, 0x91, 0x0c, 0x78, 0xa8, 0xa3, 0x6e, 0xb1, 0x0a, 0x42, 0x6f, 0xc1, 0xcd, 0x8a,
0xdc, 0x64, 0xe5, 0xa4, 0xb4, 0xf0, 0x1d, 0x86, 0xb9, 0xef, 0x6b, 0x9d, 0xa7, 0xcf, 0x0b, 0x43, 0x9b, 0xac, 0x9c, 0x94, 0x16, 0xbe, 0xc3, 0x30, 0xf7, 0x7d, 0xad, 0xf3, 0xf4, 0x45, 0x61, 0xa8,
0xa5, 0x63, 0xfa, 0xf0, 0xf1, 0x01, 0xd2, 0x5f, 0x6c, 0x68, 0x55, 0x25, 0xe4, 0x29, 0x6c, 0x68, 0x74, 0x4c, 0x1f, 0xfe, 0x7b, 0x80, 0xf4, 0x27, 0x1b, 0x5a, 0x55, 0x09, 0x79, 0x06, 0x1b, 0xda,
0x1b, 0xd5, 0x36, 0x4c, 0x0d, 0xcf, 0xbd, 0x0a, 0x0f, 0xe3, 0xef, 0x86, 0xa5, 0xc2, 0xf7, 0x81, 0x46, 0xb5, 0x0d, 0x53, 0xc3, 0x73, 0xbf, 0xc2, 0xc3, 0xf8, 0xbb, 0x61, 0xa9, 0xf0, 0x6d, 0x20,
0x1c, 0xb9, 0x3e, 0xab, 0xda, 0xa8, 0xa4, 0xb9, 0x17, 0x1a, 0xc2, 0x3c, 0xe9, 0x12, 0x21, 0x14, 0x47, 0xae, 0xcf, 0xaa, 0x36, 0x2a, 0x69, 0xee, 0x85, 0x86, 0x30, 0x4f, 0xba, 0x44, 0x08, 0x85,
0x5a, 0xe5, 0xae, 0x68, 0xd8, 0x0c, 0x46, 0xf6, 0xa0, 0xa3, 0x29, 0x87, 0x28, 0x65, 0x10, 0x9d, 0x56, 0xb9, 0x2b, 0x1a, 0x36, 0x83, 0x91, 0x5d, 0xe8, 0x68, 0xca, 0x21, 0x4a, 0x19, 0x44, 0xe7,
0x8b, 0x93, 0x99, 0x16, 0xd6, 0xca, 0xc8, 0x17, 0x70, 0xbb, 0x0e, 0x2f, 0xba, 0xbb, 0x40, 0x4a, 0xe2, 0x64, 0xa6, 0x85, 0xb5, 0x32, 0xf2, 0x19, 0xdc, 0xa9, 0xc3, 0x8b, 0xee, 0x2e, 0x90, 0xd2,
0xff, 0xb1, 0x60, 0xa3, 0x92, 0x92, 0x3a, 0x17, 0x81, 0x6e, 0x90, 0x9c, 0x9a, 0xab, 0x5e, 0xec, 0x3f, 0x2d, 0xd8, 0xa8, 0xa4, 0xa4, 0xce, 0x45, 0xa0, 0x1b, 0x24, 0xa7, 0xe6, 0xaa, 0x17, 0x7b,
0xd5, 0x29, 0x94, 0xc1, 0x18, 0x85, 0xe4, 0xe3, 0x44, 0xa7, 0xb6, 0xc4, 0x4a, 0x40, 0x49, 0xb5, 0x75, 0x0a, 0x65, 0x30, 0x46, 0x21, 0xf9, 0x38, 0xd1, 0xa9, 0x2d, 0xb1, 0x12, 0x50, 0x52, 0xed,
0x8f, 0xef, 0xa6, 0x09, 0x9a, 0xb4, 0x4a, 0x80, 0x7c, 0x06, 0x6d, 0x75, 0x28, 0x03, 0x8f, 0xcb, 0xe3, 0x9b, 0x69, 0x82, 0x26, 0xad, 0x12, 0x20, 0x9f, 0x40, 0x5b, 0x1d, 0xca, 0xc0, 0xe3, 0x32,
0x20, 0x8e, 0xbe, 0xc5, 0xa9, 0xce, 0x66, 0x99, 0x5d, 0x43, 0xd5, 0xad, 0x16, 0x88, 0x59, 0xd4, 0x88, 0xa3, 0xaf, 0x71, 0xaa, 0xb3, 0x59, 0x66, 0xd7, 0x50, 0x75, 0xab, 0x05, 0x62, 0x16, 0x75,
0x2d, 0xa6, 0xd7, 0x64, 0x17, 0x48, 0xa5, 0xc4, 0x79, 0x35, 0x56, 0xb5, 0x46, 0x8d, 0x84, 0x9e, 0x8b, 0xe9, 0x35, 0xd9, 0x01, 0x52, 0x29, 0x71, 0x5e, 0x8d, 0x55, 0xad, 0x51, 0x23, 0xa1, 0x27,
0x40, 0x7b, 0xb6, 0x51, 0x64, 0x30, 0xdf, 0xd8, 0xd6, 0x6c, 0xdf, 0x54, 0xf4, 0xc1, 0x79, 0xc4, 0xd0, 0x9e, 0x6d, 0x14, 0x19, 0xcc, 0x37, 0xb6, 0x35, 0xdb, 0x37, 0x15, 0x7d, 0x70, 0x1e, 0x71,
0xe5, 0x24, 0x45, 0xd3, 0xb6, 0x12, 0xa0, 0x87, 0xd0, 0xa9, 0x6b, 0xbd, 0xbe, 0x97, 0xfc, 0xdd, 0x39, 0x49, 0xd1, 0xb4, 0xad, 0x04, 0xe8, 0x01, 0x74, 0xea, 0x5a, 0xaf, 0xef, 0x25, 0x7f, 0x37,
0x0c, 0x6b, 0x09, 0x98, 0x73, 0x6b, 0x17, 0xe7, 0xf6, 0x77, 0x0b, 0x3a, 0xc3, 0x6a, 0x1b, 0x0e, 0xc3, 0x5a, 0x02, 0xe6, 0xdc, 0xda, 0xc5, 0xb9, 0xfd, 0xd5, 0x82, 0xce, 0xb0, 0xda, 0x86, 0xfd,
0xe2, 0x48, 0xaa, 0xa7, 0xed, 0x6b, 0x68, 0x65, 0x97, 0xef, 0x10, 0x43, 0x94, 0x58, 0x73, 0x80, 0x38, 0x92, 0xea, 0x69, 0xfb, 0x12, 0x5a, 0xd9, 0xe5, 0x3b, 0xc0, 0x10, 0x25, 0xd6, 0x1c, 0xe0,
0x8f, 0x2b, 0xe2, 0x17, 0x0d, 0x36, 0xa3, 0x4e, 0x9e, 0x98, 0xec, 0x8c, 0xb5, 0xad, 0xad, 0x6f, 0xe3, 0x8a, 0xf8, 0xa8, 0xc1, 0x66, 0xd4, 0xc9, 0x53, 0x93, 0x9d, 0xb1, 0xb6, 0xb5, 0xf5, 0x9d,
0x5f, 0x3f, 0xfe, 0x85, 0x71, 0x55, 0xf9, 0xd9, 0x1a, 0xac, 0xbc, 0xe5, 0xe1, 0x04, 0x69, 0x1f, 0xeb, 0xc7, 0xbf, 0x30, 0xae, 0x2a, 0x3f, 0x5f, 0x83, 0x95, 0xb7, 0x3c, 0x9c, 0x20, 0xed, 0x43,
0x5a, 0x55, 0x27, 0x73, 0x97, 0x6e, 0xdf, 0x9c, 0x13, 0x23, 0xfe, 0x14, 0x36, 0x7d, 0xbd, 0x4a, 0xab, 0xea, 0x64, 0xee, 0xd2, 0xed, 0x99, 0x73, 0x62, 0xc4, 0x1f, 0xc3, 0xa6, 0xaf, 0x57, 0xe9,
0x4f, 0x10, 0xd3, 0xe2, 0xc5, 0x9a, 0x05, 0xe9, 0x1b, 0xb8, 0x35, 0x93, 0xf0, 0x30, 0xe2, 0x89, 0x09, 0x62, 0x5a, 0xbc, 0x58, 0xb3, 0x20, 0x7d, 0x03, 0xb7, 0x67, 0x12, 0x1e, 0x46, 0x3c, 0x11,
0x18, 0xc5, 0x52, 0x5d, 0x93, 0x4c, 0xd3, 0x77, 0xfd, 0xec, 0xe1, 0x6c, 0xb2, 0x0a, 0x32, 0x4f, 0xa3, 0x58, 0xaa, 0x6b, 0x92, 0x69, 0xfa, 0xae, 0x9f, 0x3d, 0x9c, 0x4d, 0x56, 0x41, 0xe6, 0xe9,
0x6f, 0xd7, 0xd1, 0xff, 0x6a, 0x41, 0x2b, 0xa7, 0x3e, 0xe4, 0x92, 0x93, 0xc7, 0xb0, 0xe6, 0x65, 0xed, 0x3a, 0xfa, 0x9f, 0x2d, 0x68, 0xe5, 0xd4, 0x07, 0x5c, 0x72, 0xf2, 0x04, 0xd6, 0xbc, 0xac,
0x35, 0x35, 0x8f, 0xf1, 0xbd, 0xeb, 0x55, 0xb8, 0x56, 0x7a, 0x96, 0xeb, 0xab, 0x59, 0x26, 0x4c, 0xa6, 0xe6, 0x31, 0xbe, 0x7f, 0xbd, 0x0a, 0xd7, 0x4a, 0xcf, 0x72, 0x7d, 0x35, 0xcb, 0x84, 0x89,
0x74, 0xa6, 0x82, 0x83, 0x45, 0xb6, 0x79, 0x16, 0xac, 0xb0, 0xa0, 0x3f, 0x99, 0x27, 0x69, 0x38, 0xce, 0x54, 0x70, 0xb0, 0xc8, 0x36, 0xcf, 0x82, 0x15, 0x16, 0xf4, 0x07, 0xf3, 0x24, 0x0d, 0x27,
0x39, 0x15, 0x5e, 0x1a, 0x24, 0xea, 0x38, 0xab, 0xbb, 0x64, 0x1e, 0xf0, 0x3c, 0xc5, 0x62, 0x4f, 0xa7, 0xc2, 0x4b, 0x83, 0x44, 0x1d, 0x67, 0x75, 0x97, 0xcc, 0x03, 0x9e, 0xa7, 0x58, 0xec, 0xc9,
0x9e, 0xc0, 0x2a, 0xf7, 0x94, 0x96, 0x76, 0xd6, 0xde, 0xa3, 0x73, 0xce, 0x2a, 0x4c, 0x4f, 0xb5, 0x53, 0x58, 0xe5, 0x9e, 0xd2, 0xd2, 0xce, 0xda, 0xbb, 0x74, 0xce, 0x59, 0x85, 0xe9, 0x99, 0xd6,
0x26, 0x33, 0x16, 0xf7, 0xff, 0xb4, 0x60, 0xfd, 0x28, 0x4d, 0x0f, 0x62, 0x1f, 0x05, 0x69, 0x03, 0x64, 0xc6, 0xe2, 0xc1, 0xef, 0x16, 0xac, 0x1f, 0xa6, 0xe9, 0x7e, 0xec, 0xa3, 0x20, 0x6d, 0x80,
0xbc, 0x8e, 0xf0, 0x22, 0x41, 0x4f, 0xa2, 0xef, 0x34, 0x88, 0x63, 0xde, 0xb4, 0x97, 0x81, 0x10, 0xd7, 0x11, 0x5e, 0x24, 0xe8, 0x49, 0xf4, 0x9d, 0x06, 0x71, 0xcc, 0x9b, 0xf6, 0x32, 0x10, 0x22,
0x41, 0x74, 0xee, 0x58, 0x64, 0xcb, 0x74, 0xee, 0xe8, 0x22, 0x10, 0x52, 0x38, 0x36, 0xb9, 0x09, 0x88, 0xce, 0x1d, 0x8b, 0xdc, 0x30, 0x9d, 0x3b, 0xbc, 0x08, 0x84, 0x14, 0x8e, 0x4d, 0x6e, 0xc1,
0x5b, 0x1a, 0x78, 0x15, 0x4b, 0x37, 0x3a, 0xe0, 0xde, 0x08, 0x9d, 0x25, 0x42, 0xa0, 0xad, 0x41, 0x0d, 0x0d, 0xbc, 0x8a, 0xa5, 0x1b, 0xed, 0x73, 0x6f, 0x84, 0xce, 0x12, 0x21, 0xd0, 0xd6, 0xa0,
0x57, 0x64, 0x1d, 0xf6, 0x9d, 0x65, 0xd2, 0x85, 0x8e, 0xae, 0xb4, 0x78, 0x15, 0x4b, 0xf3, 0xd0, 0x2b, 0xb2, 0x0e, 0xfb, 0xce, 0x32, 0xe9, 0x42, 0x47, 0x57, 0x5a, 0xbc, 0x8a, 0xa5, 0x79, 0x68,
0x06, 0xa7, 0x21, 0x3a, 0x2b, 0xa4, 0x03, 0x0e, 0x43, 0x0f, 0x83, 0x44, 0xba, 0xc2, 0x8d, 0xde, 0x83, 0xd3, 0x10, 0x9d, 0x15, 0xd2, 0x01, 0x87, 0xa1, 0x87, 0x41, 0x22, 0x5d, 0xe1, 0x46, 0x6f,
0xf2, 0x30, 0xf0, 0x9d, 0x55, 0xe5, 0xe9, 0x28, 0x4d, 0xe3, 0xf4, 0xf8, 0xec, 0x4c, 0xa0, 0x74, 0x79, 0x18, 0xf8, 0xce, 0xaa, 0xf2, 0x74, 0x98, 0xa6, 0x71, 0x7a, 0x7c, 0x76, 0x26, 0x50, 0x3a,
0xfc, 0xfb, 0x8f, 0xe1, 0xce, 0x82, 0x64, 0xc8, 0x26, 0x34, 0x0d, 0x7a, 0x8a, 0x4e, 0x43, 0x99, 0xfe, 0x83, 0x27, 0x70, 0x77, 0x41, 0x32, 0x64, 0x13, 0x9a, 0x06, 0x3d, 0x45, 0xa7, 0xa1, 0x4c,
0xbe, 0x8e, 0x44, 0x01, 0x58, 0x7b, 0x7f, 0xd9, 0xd0, 0xcc, 0x6c, 0xa7, 0x91, 0x47, 0x0e, 0x60, 0x5f, 0x47, 0xa2, 0x00, 0xac, 0xdd, 0x7f, 0x6c, 0x68, 0x66, 0xb6, 0xd3, 0xc8, 0x23, 0xfb, 0xb0,
0x3d, 0x9f, 0xa5, 0xa4, 0x57, 0x3b, 0x60, 0xf5, 0xa8, 0xe8, 0xdd, 0xad, 0x1f, 0xbe, 0xd9, 0x88, 0x9e, 0xcf, 0x52, 0xd2, 0xab, 0x1d, 0xb0, 0x7a, 0x54, 0xf4, 0xee, 0xd5, 0x0f, 0xdf, 0x6c, 0x44,
0x78, 0x6e, 0x18, 0xd5, 0xc0, 0x21, 0x77, 0xe7, 0xc6, 0x43, 0x39, 0xed, 0x7a, 0xdb, 0xf5, 0xc2, 0xbc, 0x30, 0x8c, 0x6a, 0xe0, 0x90, 0x7b, 0x73, 0xe3, 0xa1, 0x9c, 0x76, 0xbd, 0xad, 0x7a, 0xe1,
0x39, 0x9e, 0x30, 0xac, 0xe3, 0x29, 0x26, 0x57, 0x1d, 0x4f, 0x65, 0x64, 0x31, 0x70, 0xca, 0x8f, 0x1c, 0x4f, 0x18, 0xd6, 0xf1, 0x14, 0x93, 0xab, 0x8e, 0xa7, 0x32, 0xb2, 0x18, 0x38, 0xe5, 0x47,
0x80, 0xa1, 0x4c, 0x91, 0x8f, 0xc9, 0xf6, 0xdc, 0xa5, 0xaf, 0x7c, 0x21, 0xf4, 0x3e, 0x28, 0xdd, 0xc0, 0x50, 0xa6, 0xc8, 0xc7, 0x64, 0x6b, 0xee, 0xd2, 0x57, 0xbe, 0x10, 0x7a, 0x1f, 0x94, 0x6e,
0xb1, 0x1e, 0x5a, 0xcf, 0x3e, 0xff, 0xfb, 0xb2, 0x6f, 0xbd, 0xbf, 0xec, 0x5b, 0xff, 0x5d, 0xf6, 0x5b, 0x8f, 0x2c, 0x72, 0x04, 0x50, 0x0a, 0xfe, 0x0f, 0xdb, 0xf3, 0x4f, 0xff, 0xb8, 0xec, 0x5b,
0xad, 0xdf, 0xae, 0xfa, 0x8d, 0xf7, 0x57, 0xfd, 0xc6, 0xbf, 0x57, 0xfd, 0xc6, 0x0f, 0xbd, 0xc5, 0xef, 0x2f, 0xfb, 0xd6, 0xdf, 0x97, 0x7d, 0xeb, 0x97, 0xab, 0x7e, 0xe3, 0xfd, 0x55, 0xbf, 0xf1,
0xdf, 0x96, 0xa7, 0xab, 0xfa, 0x6f, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0xd3, 0x01, 0xff, 0xd7, 0x55, 0xbf, 0xf1, 0x5d, 0x6f, 0xf1, 0x57, 0xea, 0xe9, 0xaa, 0xfe, 0xdb, 0xfb, 0x37, 0x00,
0xb5, 0x80, 0x0a, 0x00, 0x00, 0x00, 0xff, 0xff, 0xb6, 0xe1, 0x84, 0x46, 0xca, 0x0a, 0x00, 0x00,
} }
func (m *HeadSyncRange) Marshal() (dAtA []byte, err error) { func (m *HeadSyncRange) Marshal() (dAtA []byte, err error) {

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: commonspace/spacesyncproto/protos/spacesync.proto // source: commonspace/spacesyncproto/protos/spacesync.proto
package spacesyncproto package spacesyncproto
@ -44,6 +44,7 @@ type DRPCSpaceSyncClient interface {
SpacePush(ctx context.Context, in *SpacePushRequest) (*SpacePushResponse, error) SpacePush(ctx context.Context, in *SpacePushRequest) (*SpacePushResponse, error)
SpacePull(ctx context.Context, in *SpacePullRequest) (*SpacePullResponse, error) SpacePull(ctx context.Context, in *SpacePullRequest) (*SpacePullResponse, error)
ObjectSyncStream(ctx context.Context) (DRPCSpaceSync_ObjectSyncStreamClient, error) ObjectSyncStream(ctx context.Context) (DRPCSpaceSync_ObjectSyncStreamClient, error)
ObjectSync(ctx context.Context, in *ObjectSyncMessage) (*ObjectSyncMessage, error)
} }
type drpcSpaceSyncClient struct { type drpcSpaceSyncClient struct {
@ -102,6 +103,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct {
drpc.Stream drpc.Stream
} }
func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream {
return x.Stream
}
func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error { func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error {
return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
} }
@ -118,11 +123,21 @@ func (x *drpcSpaceSync_ObjectSyncStreamClient) RecvMsg(m *ObjectSyncMessage) err
return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
} }
func (c *drpcSpaceSyncClient) ObjectSync(ctx context.Context, in *ObjectSyncMessage) (*ObjectSyncMessage, error) {
out := new(ObjectSyncMessage)
err := c.cc.Invoke(ctx, "/spacesync.SpaceSync/ObjectSync", drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}, in, out)
if err != nil {
return nil, err
}
return out, nil
}
type DRPCSpaceSyncServer interface { type DRPCSpaceSyncServer interface {
HeadSync(context.Context, *HeadSyncRequest) (*HeadSyncResponse, error) HeadSync(context.Context, *HeadSyncRequest) (*HeadSyncResponse, error)
SpacePush(context.Context, *SpacePushRequest) (*SpacePushResponse, error) SpacePush(context.Context, *SpacePushRequest) (*SpacePushResponse, error)
SpacePull(context.Context, *SpacePullRequest) (*SpacePullResponse, error) SpacePull(context.Context, *SpacePullRequest) (*SpacePullResponse, error)
ObjectSyncStream(DRPCSpaceSync_ObjectSyncStreamStream) error ObjectSyncStream(DRPCSpaceSync_ObjectSyncStreamStream) error
ObjectSync(context.Context, *ObjectSyncMessage) (*ObjectSyncMessage, error)
} }
type DRPCSpaceSyncUnimplementedServer struct{} type DRPCSpaceSyncUnimplementedServer struct{}
@ -143,9 +158,13 @@ func (s *DRPCSpaceSyncUnimplementedServer) ObjectSyncStream(DRPCSpaceSync_Object
return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
} }
func (s *DRPCSpaceSyncUnimplementedServer) ObjectSync(context.Context, *ObjectSyncMessage) (*ObjectSyncMessage, error) {
return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented)
}
type DRPCSpaceSyncDescription struct{} type DRPCSpaceSyncDescription struct{}
func (DRPCSpaceSyncDescription) NumMethods() int { return 4 } func (DRPCSpaceSyncDescription) NumMethods() int { return 5 }
func (DRPCSpaceSyncDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { func (DRPCSpaceSyncDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) {
switch n { switch n {
@ -184,6 +203,15 @@ func (DRPCSpaceSyncDescription) Method(n int) (string, drpc.Encoding, drpc.Recei
&drpcSpaceSync_ObjectSyncStreamStream{in1.(drpc.Stream)}, &drpcSpaceSync_ObjectSyncStreamStream{in1.(drpc.Stream)},
) )
}, DRPCSpaceSyncServer.ObjectSyncStream, true }, DRPCSpaceSyncServer.ObjectSyncStream, true
case 4:
return "/spacesync.SpaceSync/ObjectSync", drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{},
func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) {
return srv.(DRPCSpaceSyncServer).
ObjectSync(
ctx,
in1.(*ObjectSyncMessage),
)
}, DRPCSpaceSyncServer.ObjectSync, true
default: default:
return "", nil, nil, nil, false return "", nil, nil, nil, false
} }
@ -266,3 +294,19 @@ func (x *drpcSpaceSync_ObjectSyncStreamStream) Recv() (*ObjectSyncMessage, error
func (x *drpcSpaceSync_ObjectSyncStreamStream) RecvMsg(m *ObjectSyncMessage) error { func (x *drpcSpaceSync_ObjectSyncStreamStream) RecvMsg(m *ObjectSyncMessage) error {
return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{})
} }
type DRPCSpaceSync_ObjectSyncStream interface {
drpc.Stream
SendAndClose(*ObjectSyncMessage) error
}
type drpcSpaceSync_ObjectSyncStream struct {
drpc.Stream
}
func (x *drpcSpaceSync_ObjectSyncStream) SendAndClose(m *ObjectSyncMessage) error {
if err := x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}); err != nil {
return err
}
return x.CloseSend()
}

View File

@ -6,12 +6,15 @@ import (
accountService "github.com/anyproto/any-sync/accountservice" accountService "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
@ -128,6 +131,14 @@ func (m *mockConf) NodeTypes(nodeId string) []nodeconf.NodeType {
type mockPeerManager struct { type mockPeerManager struct {
} }
func (p *mockPeerManager) Init(a *app.App) (err error) {
return nil
}
func (p *mockPeerManager) Name() (name string) {
return peermanager.CName
}
func (p *mockPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { func (p *mockPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
return nil return nil
} }
@ -159,6 +170,25 @@ func (m *mockPeerManagerProvider) NewPeerManager(ctx context.Context, spaceId st
return &mockPeerManager{}, nil return &mockPeerManager{}, nil
} }
//
// Mock StatusServiceProvider
//
type mockStatusServiceProvider struct {
}
func (m *mockStatusServiceProvider) Init(a *app.App) (err error) {
return nil
}
func (m *mockStatusServiceProvider) Name() (name string) {
return syncstatus.CName
}
func (m *mockStatusServiceProvider) NewStatusService() syncstatus.StatusService {
return syncstatus.NewNoOpSyncStatus()
}
// //
// Mock Pool // Mock Pool
// //
@ -166,6 +196,10 @@ func (m *mockPeerManagerProvider) NewPeerManager(ctx context.Context, spaceId st
type mockPool struct { type mockPool struct {
} }
func (m *mockPool) AddPeer(ctx context.Context, p peer.Peer) (err error) {
return nil
}
func (m *mockPool) Init(a *app.App) (err error) { func (m *mockPool) Init(a *app.App) (err error) {
return nil return nil
} }
@ -205,8 +239,8 @@ func (m *mockConfig) Name() (name string) {
return "config" return "config"
} }
func (m *mockConfig) GetSpace() Config { func (m *mockConfig) GetSpace() config.Config {
return Config{ return config.Config{
GCTTL: 60, GCTTL: 60,
SyncPeriod: 20, SyncPeriod: 20,
KeepTreeDataInMemory: true, KeepTreeDataInMemory: true,
@ -236,6 +270,7 @@ type mockTreeManager struct {
cache ocache.OCache cache ocache.OCache
deletedIds []string deletedIds []string
markedIds []string markedIds []string
waitLoad chan struct{}
} }
func (t *mockTreeManager) NewTreeSyncer(spaceId string, treeManager treemanager.TreeManager) treemanager.TreeSyncer { func (t *mockTreeManager) NewTreeSyncer(spaceId string, treeManager treemanager.TreeManager) treemanager.TreeSyncer {
@ -249,7 +284,8 @@ func (t *mockTreeManager) MarkTreeDeleted(ctx context.Context, spaceId, treeId s
func (t *mockTreeManager) Init(a *app.App) (err error) { func (t *mockTreeManager) Init(a *app.App) (err error) {
t.cache = ocache.New(func(ctx context.Context, id string) (value ocache.Object, err error) { t.cache = ocache.New(func(ctx context.Context, id string) (value ocache.Object, err error) {
return t.space.BuildTree(ctx, id, BuildTreeOpts{}) <-t.waitLoad
return t.space.TreeBuilder().BuildTree(ctx, id, objecttreebuilder.BuildTreeOpts{})
}, },
ocache.WithGCPeriod(time.Minute), ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Duration(60)*time.Second)) ocache.WithTTL(time.Duration(60)*time.Second))
@ -318,12 +354,14 @@ func newFixture(t *testing.T) *spaceFixture {
configurationService: &mockConf{}, configurationService: &mockConf{},
storageProvider: spacestorage.NewInMemorySpaceStorageProvider(), storageProvider: spacestorage.NewInMemorySpaceStorageProvider(),
peermanagerProvider: &mockPeerManagerProvider{}, peermanagerProvider: &mockPeerManagerProvider{},
treeManager: &mockTreeManager{}, treeManager: &mockTreeManager{waitLoad: make(chan struct{})},
pool: &mockPool{}, pool: &mockPool{},
spaceService: New(), spaceService: New(),
} }
fx.app.Register(fx.account). fx.app.Register(fx.account).
Register(fx.config). Register(fx.config).
Register(credentialprovider.NewNoOp()).
Register(&mockStatusServiceProvider{}).
Register(fx.configurationService). Register(fx.configurationService).
Register(fx.storageProvider). Register(fx.storageProvider).
Register(fx.peermanagerProvider). Register(fx.peermanagerProvider).

View File

@ -1,9 +1,32 @@
package syncstatus package syncstatus
import (
"context"
"github.com/anyproto/any-sync/app"
)
func NewNoOpSyncStatus() StatusService {
return &noOpSyncStatus{}
}
type noOpSyncStatus struct{} type noOpSyncStatus struct{}
func NewNoOpSyncStatus() StatusUpdater { func (n *noOpSyncStatus) Init(a *app.App) (err error) {
return &noOpSyncStatus{} return nil
}
func (n *noOpSyncStatus) Name() (name string) {
return CName
}
func (n *noOpSyncStatus) Watch(treeId string) (err error) {
return nil
}
func (n *noOpSyncStatus) Unwatch(treeId string) {
}
func (n *noOpSyncStatus) SetUpdateReceiver(updater UpdateReceiver) {
} }
func (n *noOpSyncStatus) HeadsChange(treeId string, heads []string) { func (n *noOpSyncStatus) HeadsChange(treeId string, heads []string) {
@ -22,9 +45,10 @@ func (n *noOpSyncStatus) StateCounter() uint64 {
func (n *noOpSyncStatus) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) { func (n *noOpSyncStatus) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) {
} }
func (n *noOpSyncStatus) Run() { func (n *noOpSyncStatus) Run(ctx context.Context) error {
} return nil
}
func (n *noOpSyncStatus) Close() error {
func (n *noOpSyncStatus) Close(ctx context.Context) error {
return nil return nil
} }

View File

@ -3,6 +3,8 @@ package syncstatus
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/spacestate"
"sync" "sync"
"time" "time"
@ -20,7 +22,9 @@ const (
syncTimeout = time.Second syncTimeout = time.Second
) )
var log = logger.NewNamed("common.commonspace.syncstatus") var log = logger.NewNamed(CName)
const CName = "common.commonspace.syncstatus"
type UpdateReceiver interface { type UpdateReceiver interface {
UpdateTree(ctx context.Context, treeId string, status SyncStatus) (err error) UpdateTree(ctx context.Context, treeId string, status SyncStatus) (err error)
@ -34,9 +38,6 @@ type StatusUpdater interface {
SetNodesOnline(senderId string, online bool) SetNodesOnline(senderId string, online bool)
StateCounter() uint64 StateCounter() uint64
RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64)
Run()
Close() error
} }
type StatusWatcher interface { type StatusWatcher interface {
@ -45,7 +46,13 @@ type StatusWatcher interface {
SetUpdateReceiver(updater UpdateReceiver) SetUpdateReceiver(updater UpdateReceiver)
} }
type StatusProvider interface { type StatusServiceProvider interface {
app.Component
NewStatusService() StatusService
}
type StatusService interface {
app.ComponentRunnable
StatusUpdater StatusUpdater
StatusWatcher StatusWatcher
} }
@ -70,7 +77,7 @@ type treeStatus struct {
heads []string heads []string
} }
type syncStatusProvider struct { type syncStatusService struct {
sync.Mutex sync.Mutex
configuration nodeconf.NodeConf configuration nodeconf.NodeConf
periodicSync periodicsync.PeriodicSync periodicSync periodicsync.PeriodicSync
@ -89,52 +96,45 @@ type syncStatusProvider struct {
updateTimeout time.Duration updateTimeout time.Duration
} }
type SyncStatusDeps struct { func NewSyncStatusProvider() StatusService {
UpdateIntervalSecs int return &syncStatusService{
UpdateTimeout time.Duration treeHeads: map[string]treeHeadsEntry{},
Configuration nodeconf.NodeConf watchers: map[string]struct{}{},
Storage spacestorage.SpaceStorage
}
func DefaultDeps(configuration nodeconf.NodeConf, store spacestorage.SpaceStorage) SyncStatusDeps {
return SyncStatusDeps{
UpdateIntervalSecs: syncUpdateInterval,
UpdateTimeout: syncTimeout,
Configuration: configuration,
Storage: store,
} }
} }
func NewSyncStatusProvider(spaceId string, deps SyncStatusDeps) StatusProvider { func (s *syncStatusService) Init(a *app.App) (err error) {
return &syncStatusProvider{ sharedState := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
spaceId: spaceId, s.updateIntervalSecs = syncUpdateInterval
treeHeads: map[string]treeHeadsEntry{}, s.updateTimeout = syncTimeout
watchers: map[string]struct{}{}, s.spaceId = sharedState.SpaceId
updateIntervalSecs: deps.UpdateIntervalSecs, s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
updateTimeout: deps.UpdateTimeout, s.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
configuration: deps.Configuration, return
storage: deps.Storage,
stateCounter: 0,
}
} }
func (s *syncStatusProvider) SetUpdateReceiver(updater UpdateReceiver) { func (s *syncStatusService) Name() (name string) {
return CName
}
func (s *syncStatusService) SetUpdateReceiver(updater UpdateReceiver) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
s.updateReceiver = updater s.updateReceiver = updater
} }
func (s *syncStatusProvider) Run() { func (s *syncStatusService) Run(ctx context.Context) error {
s.periodicSync = periodicsync.NewPeriodicSync( s.periodicSync = periodicsync.NewPeriodicSync(
s.updateIntervalSecs, s.updateIntervalSecs,
s.updateTimeout, s.updateTimeout,
s.update, s.update,
log) log)
s.periodicSync.Run() s.periodicSync.Run()
return nil
} }
func (s *syncStatusProvider) HeadsChange(treeId string, heads []string) { func (s *syncStatusService) HeadsChange(treeId string, heads []string) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -149,7 +149,7 @@ func (s *syncStatusProvider) HeadsChange(treeId string, heads []string) {
s.stateCounter++ s.stateCounter++
} }
func (s *syncStatusProvider) SetNodesOnline(senderId string, online bool) { func (s *syncStatusService) SetNodesOnline(senderId string, online bool) {
if !s.isSenderResponsible(senderId) { if !s.isSenderResponsible(senderId) {
return return
} }
@ -160,7 +160,7 @@ func (s *syncStatusProvider) SetNodesOnline(senderId string, online bool) {
s.nodesOnline = online s.nodesOnline = online
} }
func (s *syncStatusProvider) update(ctx context.Context) (err error) { func (s *syncStatusService) update(ctx context.Context) (err error) {
s.treeStatusBuf = s.treeStatusBuf[:0] s.treeStatusBuf = s.treeStatusBuf[:0]
s.Lock() s.Lock()
@ -189,7 +189,7 @@ func (s *syncStatusProvider) update(ctx context.Context) (err error) {
return return
} }
func (s *syncStatusProvider) HeadsReceive(senderId, treeId string, heads []string) { func (s *syncStatusService) HeadsReceive(senderId, treeId string, heads []string) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -218,7 +218,7 @@ func (s *syncStatusProvider) HeadsReceive(senderId, treeId string, heads []strin
s.treeHeads[treeId] = curTreeHeads s.treeHeads[treeId] = curTreeHeads
} }
func (s *syncStatusProvider) Watch(treeId string) (err error) { func (s *syncStatusService) Watch(treeId string) (err error) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
_, ok := s.treeHeads[treeId] _, ok := s.treeHeads[treeId]
@ -248,7 +248,7 @@ func (s *syncStatusProvider) Watch(treeId string) (err error) {
return return
} }
func (s *syncStatusProvider) Unwatch(treeId string) { func (s *syncStatusService) Unwatch(treeId string) {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
@ -257,19 +257,14 @@ func (s *syncStatusProvider) Unwatch(treeId string) {
} }
} }
func (s *syncStatusProvider) Close() (err error) { func (s *syncStatusService) StateCounter() uint64 {
s.periodicSync.Close()
return
}
func (s *syncStatusProvider) StateCounter() uint64 {
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
return s.stateCounter return s.stateCounter
} }
func (s *syncStatusProvider) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) { func (s *syncStatusService) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) {
// if sender is not a responsible node, then this should have no effect // if sender is not a responsible node, then this should have no effect
if !s.isSenderResponsible(senderId) { if !s.isSenderResponsible(senderId) {
return return
@ -292,6 +287,11 @@ func (s *syncStatusProvider) RemoveAllExcept(senderId string, differentRemoteIds
} }
} }
func (s *syncStatusProvider) isSenderResponsible(senderId string) bool { func (s *syncStatusService) Close(ctx context.Context) error {
s.periodicSync.Close()
return nil
}
func (s *syncStatusService) isSenderResponsible(senderId string) bool {
return slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) return slices.Contains(s.configuration.NodeIds(s.spaceId), senderId)
} }

View File

@ -10,6 +10,7 @@ import (
"github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"storj.io/drpc"
) )
const CName = "common.coordinator.coordinatorclient" const CName = "common.coordinator.coordinatorclient"
@ -39,42 +40,8 @@ type coordinatorClient struct {
nodeConf nodeconf.Service nodeConf nodeconf.Service
} }
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
SpaceId: spaceId,
DeletionChangeId: deleteRaw.GetId(),
DeletionChangePayload: deleteRaw.GetRawChange(),
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
status = resp.Payload
return
}
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
SpaceId: spaceId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
status = resp.Payload
return
}
func (c *coordinatorClient) Init(a *app.App) (err error) { func (c *coordinatorClient) Init(a *app.App) (err error) {
c.pool = a.MustComponent(pool.CName).(pool.Service).NewPool(CName) c.pool = a.MustComponent(pool.CName).(pool.Service)
c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service) c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service)
return return
} }
@ -83,8 +50,37 @@ func (c *coordinatorClient) Name() (name string) {
return CName return CName
} }
func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{
SpaceId: spaceId,
DeletionChangeId: deleteRaw.GetId(),
DeletionChangePayload: deleteRaw.GetRawChange(),
})
if err != nil {
return rpcerr.Unwrap(err)
}
status = resp.Payload
return nil
})
return
}
func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{
SpaceId: spaceId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
status = resp.Payload
return nil
})
return
}
func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) { func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) {
cl, err := c.client(ctx)
if err != nil { if err != nil {
return return
} }
@ -100,54 +96,56 @@ func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayl
if err != nil { if err != nil {
return return
} }
resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{ err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
SpaceId: payload.SpaceId, resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{
Header: payload.SpaceHeader, SpaceId: payload.SpaceId,
OldIdentity: oldIdentity, Header: payload.SpaceHeader,
NewIdentitySignature: newSignature, OldIdentity: oldIdentity,
NewIdentitySignature: newSignature,
})
if err != nil {
return rpcerr.Unwrap(err)
}
receipt = resp.Receipt
return nil
}) })
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return resp.Receipt, nil
}
func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
AccountIdentity: identity,
SpaceId: spaceId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return resp.Limit, nil
}
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
cl, err := c.client(ctx)
if err != nil {
return
}
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
CurrentId: currentId,
})
if err != nil {
err = rpcerr.Unwrap(err)
return
}
return return
} }
func (c *coordinatorClient) client(ctx context.Context) (coordinatorproto.DRPCCoordinatorClient, error) { func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{
AccountIdentity: identity,
SpaceId: spaceId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
limit = resp.Limit
return nil
})
return
}
func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) {
err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error {
resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{
CurrentId: currentId,
})
if err != nil {
return rpcerr.Unwrap(err)
}
return nil
})
return
}
func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error {
p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers()) p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers())
if err != nil { if err != nil {
return nil, err return err
} }
return coordinatorproto.NewDRPCCoordinatorClient(p), nil return p.DoDrpc(ctx, func(conn drpc.Conn) error {
return f(coordinatorproto.NewDRPCCoordinatorClient(conn))
})
} }

View File

@ -1,5 +1,5 @@
// Code generated by protoc-gen-go-drpc. DO NOT EDIT. // Code generated by protoc-gen-go-drpc. DO NOT EDIT.
// protoc-gen-go-drpc version: v0.0.32 // protoc-gen-go-drpc version: v0.0.33
// source: coordinator/coordinatorproto/protos/coordinator.proto // source: coordinator/coordinatorproto/protos/coordinator.proto
package coordinatorproto package coordinatorproto

13
go.mod
View File

@ -14,6 +14,7 @@ require (
github.com/gogo/protobuf v1.3.2 github.com/gogo/protobuf v1.3.2
github.com/golang/mock v1.6.0 github.com/golang/mock v1.6.0
github.com/google/uuid v1.3.0 github.com/google/uuid v1.3.0
github.com/hashicorp/yamux v0.1.1
github.com/huandu/skiplist v1.2.0 github.com/huandu/skiplist v1.2.0
github.com/ipfs/go-block-format v0.1.2 github.com/ipfs/go-block-format v0.1.2
github.com/ipfs/go-blockservice v0.5.2 github.com/ipfs/go-blockservice v0.5.2
@ -32,7 +33,7 @@ require (
github.com/stretchr/testify v1.8.4 github.com/stretchr/testify v1.8.4
github.com/tyler-smith/go-bip39 v1.1.0 github.com/tyler-smith/go-bip39 v1.1.0
github.com/zeebo/blake3 v0.2.3 github.com/zeebo/blake3 v0.2.3
github.com/zeebo/errs v1.3.0 go.uber.org/atomic v1.11.0
go.uber.org/zap v1.24.0 go.uber.org/zap v1.24.0
golang.org/x/crypto v0.9.0 golang.org/x/crypto v0.9.0
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
@ -55,7 +56,7 @@ require (
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
github.com/golang/protobuf v1.5.3 // indirect github.com/golang/protobuf v1.5.3 // indirect
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect
github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect
github.com/huin/goupnp v1.2.0 // indirect github.com/huin/goupnp v1.2.0 // indirect
github.com/ipfs/bbloom v0.0.4 // indirect github.com/ipfs/bbloom v0.0.4 // indirect
@ -88,7 +89,7 @@ require (
github.com/multiformats/go-multicodec v0.9.0 // indirect github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect github.com/multiformats/go-varint v0.0.7 // indirect
github.com/onsi/ginkgo/v2 v2.9.5 // indirect github.com/onsi/ginkgo/v2 v2.9.7 // indirect
github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect
@ -96,19 +97,19 @@ require (
github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/client_model v0.4.0 // indirect
github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/common v0.44.0 // indirect
github.com/prometheus/procfs v0.10.0 // indirect github.com/prometheus/procfs v0.10.0 // indirect
github.com/quic-go/quic-go v0.34.0 // indirect github.com/quic-go/quic-go v0.35.1 // indirect
github.com/quic-go/webtransport-go v0.5.3 // indirect github.com/quic-go/webtransport-go v0.5.3 // indirect
github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect
github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect
github.com/zeebo/errs v1.3.0 // indirect
go.opentelemetry.io/otel v1.7.0 // indirect go.opentelemetry.io/otel v1.7.0 // indirect
go.opentelemetry.io/otel/trace v1.7.0 // indirect go.opentelemetry.io/otel/trace v1.7.0 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.6.0 // indirect golang.org/x/image v0.6.0 // indirect
golang.org/x/sync v0.2.0 // indirect golang.org/x/sync v0.2.0 // indirect
golang.org/x/sys v0.8.0 // indirect golang.org/x/sys v0.8.0 // indirect
golang.org/x/tools v0.9.1 // indirect golang.org/x/tools v0.9.3 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
lukechampine.com/blake3 v1.2.1 // indirect lukechampine.com/blake3 v1.2.1 // indirect

18
go.sum
View File

@ -67,8 +67,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs=
github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
@ -79,6 +79,8 @@ github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfm
github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48=
github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc=
github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4=
github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE=
github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ=
github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c=
github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U=
github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw= github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw=
@ -242,8 +244,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8=
github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU=
github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY=
github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss=
github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -266,8 +268,8 @@ github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH
github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo=
github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U=
github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E=
github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo=
github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g=
github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU= github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU=
github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU= github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
@ -413,8 +415,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@ -5,21 +5,3 @@ import "errors"
var ( var (
ErrUnableToConnect = errors.New("unable to connect") ErrUnableToConnect = errors.New("unable to connect")
) )
type ConfigGetter interface {
GetNet() Config
}
type Config struct {
Server ServerConfig `yaml:"server"`
Stream StreamConfig `yaml:"stream"`
}
type ServerConfig struct {
ListenAddrs []string `yaml:"listenAddrs"`
}
type StreamConfig struct {
TimeoutMilliseconds int `yaml:"timeoutMilliseconds"`
MaxMsgSizeMb int `yaml:"maxMsgSizeMb"`
}

View File

@ -1,4 +1,4 @@
package timeoutconn package connutil
import ( import (
"errors" "errors"
@ -10,18 +10,18 @@ import (
"go.uber.org/zap" "go.uber.org/zap"
) )
var log = logger.NewNamed("common.net.timeoutconn") var log = logger.NewNamed("common.net.connutil")
type Conn struct { type TimeoutConn struct {
net.Conn net.Conn
timeout time.Duration timeout time.Duration
} }
func NewConn(conn net.Conn, timeout time.Duration) *Conn { func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn {
return &Conn{conn, timeout} return &TimeoutConn{conn, timeout}
} }
func (c *Conn) Write(p []byte) (n int, err error) { func (c *TimeoutConn) Write(p []byte) (n int, err error) {
for { for {
if c.timeout != 0 { if c.timeout != 0 {
if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil { if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil {

30
net/connutil/usage.go Normal file
View File

@ -0,0 +1,30 @@
package connutil
import (
"go.uber.org/atomic"
"net"
"time"
)
func NewLastUsageConn(conn net.Conn) *LastUsageConn {
return &LastUsageConn{Conn: conn}
}
type LastUsageConn struct {
net.Conn
lastUsage atomic.Time
}
func (c *LastUsageConn) Write(p []byte) (n int, err error) {
c.lastUsage.Store(time.Now())
return c.Conn.Write(p)
}
func (c *LastUsageConn) Read(p []byte) (n int, err error) {
c.lastUsage.Store(time.Now())
return c.Conn.Read(p)
}
func (c *LastUsageConn) LastUsage() time.Time {
return c.lastUsage.Load()
}

View File

@ -1,137 +0,0 @@
package dialer
import (
"context"
"errors"
"fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
net2 "github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/timeoutconn"
"github.com/anyproto/any-sync/nodeconf"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcwire"
"sync"
"time"
)
const CName = "common.net.dialer"
var (
ErrAddrsNotFound = errors.New("addrs for peer not found")
ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id")
)
var log = logger.NewNamed(CName)
func New() Dialer {
return &dialer{peerAddrs: map[string][]string{}}
}
type Dialer interface {
Dial(ctx context.Context, peerId string) (peer peer.Peer, err error)
SetPeerAddrs(peerId string, addrs []string)
app.Component
}
type dialer struct {
transport secureservice.SecureService
config net2.Config
nodeConf nodeconf.NodeConf
peerAddrs map[string][]string
mu sync.RWMutex
}
func (d *dialer) Init(a *app.App) (err error) {
d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet()
return
}
func (d *dialer) Name() (name string) {
return CName
}
func (d *dialer) SetPeerAddrs(peerId string, addrs []string) {
d.mu.Lock()
defer d.mu.Unlock()
d.peerAddrs[peerId] = addrs
}
func (d *dialer) getPeerAddrs(peerId string) ([]string, error) {
if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok {
return addrs, nil
}
addrs, ok := d.peerAddrs[peerId]
if !ok || len(addrs) == 0 {
return nil, ErrAddrsNotFound
}
return addrs, nil
}
func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) {
var ctxCancel context.CancelFunc
ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10)
defer ctxCancel()
d.mu.RLock()
defer d.mu.RUnlock()
addrs, err := d.getPeerAddrs(peerId)
if err != nil {
return
}
var (
conn drpc.Conn
sc sec.SecureConn
)
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
for _, addr := range addrs {
conn, sc, err = d.handshake(ctx, addr, peerId)
if err != nil {
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
} else {
break
}
}
if err != nil {
return
}
return peer.NewPeer(sc, conn), nil
}
func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) {
st := time.Now()
// TODO: move dial timeout to config
tcpConn, err := net.DialTimeout("tcp", addr, time.Second*15)
if err != nil {
return nil, nil, fmt.Errorf("dialTimeout error: %v; since start: %v", err, time.Since(st))
}
timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds))
sc, err = d.transport.SecureOutbound(ctx, timeoutConn)
if err != nil {
if he, ok := err.(handshake.HandshakeError); ok {
return nil, nil, he
}
return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st))
}
if peerId != sc.RemotePeer().String() {
return nil, nil, ErrPeerIdIsUnexpected
}
log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr))
conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)},
}})
return conn, sc, err
}

View File

@ -0,0 +1,149 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/net/peer (interfaces: Peer)
// Package mock_peer is a generated GoMock package.
package mock_peer
import (
context "context"
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
drpc "storj.io/drpc"
)
// MockPeer is a mock of Peer interface.
type MockPeer struct {
ctrl *gomock.Controller
recorder *MockPeerMockRecorder
}
// MockPeerMockRecorder is the mock recorder for MockPeer.
type MockPeerMockRecorder struct {
mock *MockPeer
}
// NewMockPeer creates a new mock instance.
func NewMockPeer(ctrl *gomock.Controller) *MockPeer {
mock := &MockPeer{ctrl: ctrl}
mock.recorder = &MockPeerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPeer) EXPECT() *MockPeerMockRecorder {
return m.recorder
}
// AcquireDrpcConn mocks base method.
func (m *MockPeer) AcquireDrpcConn(arg0 context.Context) (drpc.Conn, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AcquireDrpcConn", arg0)
ret0, _ := ret[0].(drpc.Conn)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// AcquireDrpcConn indicates an expected call of AcquireDrpcConn.
func (mr *MockPeerMockRecorder) AcquireDrpcConn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireDrpcConn", reflect.TypeOf((*MockPeer)(nil).AcquireDrpcConn), arg0)
}
// Close mocks base method.
func (m *MockPeer) Close() error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close")
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockPeerMockRecorder) Close() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPeer)(nil).Close))
}
// Context mocks base method.
func (m *MockPeer) Context() context.Context {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Context")
ret0, _ := ret[0].(context.Context)
return ret0
}
// Context indicates an expected call of Context.
func (mr *MockPeerMockRecorder) Context() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockPeer)(nil).Context))
}
// DoDrpc mocks base method.
func (m *MockPeer) DoDrpc(arg0 context.Context, arg1 func(drpc.Conn) error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DoDrpc", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DoDrpc indicates an expected call of DoDrpc.
func (mr *MockPeerMockRecorder) DoDrpc(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoDrpc", reflect.TypeOf((*MockPeer)(nil).DoDrpc), arg0, arg1)
}
// Id mocks base method.
func (m *MockPeer) Id() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Id")
ret0, _ := ret[0].(string)
return ret0
}
// Id indicates an expected call of Id.
func (mr *MockPeerMockRecorder) Id() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockPeer)(nil).Id))
}
// IsClosed mocks base method.
func (m *MockPeer) IsClosed() bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsClosed")
ret0, _ := ret[0].(bool)
return ret0
}
// IsClosed indicates an expected call of IsClosed.
func (mr *MockPeerMockRecorder) IsClosed() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockPeer)(nil).IsClosed))
}
// ReleaseDrpcConn mocks base method.
func (m *MockPeer) ReleaseDrpcConn(arg0 drpc.Conn) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReleaseDrpcConn", arg0)
}
// ReleaseDrpcConn indicates an expected call of ReleaseDrpcConn.
func (mr *MockPeerMockRecorder) ReleaseDrpcConn(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseDrpcConn", reflect.TypeOf((*MockPeer)(nil).ReleaseDrpcConn), arg0)
}
// TryClose mocks base method.
func (m *MockPeer) TryClose(arg0 time.Duration) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryClose", arg0)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// TryClose indicates an expected call of TryClose.
func (mr *MockPeerMockRecorder) TryClose(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockPeer)(nil).TryClose), arg0)
}

View File

@ -1,100 +1,233 @@
//go:generate mockgen -destination mock_peer/mock_peer.go github.com/anyproto/any-sync/net/peer Peer
package peer package peer
import ( import (
"context" "context"
"sync/atomic"
"time"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/libp2p/go-libp2p/core/sec" "github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/net/connutil"
"github.com/anyproto/any-sync/net/rpc"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport"
"go.uber.org/zap" "go.uber.org/zap"
"io"
"net"
"storj.io/drpc" "storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcstream"
"storj.io/drpc/drpcwire"
"sync"
"time"
) )
var log = logger.NewNamed("common.net.peer") var log = logger.NewNamed("common.net.peer")
func NewPeer(sc sec.SecureConn, conn drpc.Conn) Peer { type connCtrl interface {
return &peer{ ServeConn(ctx context.Context, conn net.Conn) (err error)
id: sc.RemotePeer().String(), DrpcConfig() rpc.Config
lastUsage: time.Now().Unix(), }
sc: sc,
Conn: conn, func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) {
ctx := mc.Context()
pr := &peer{
active: map[*subConn]struct{}{},
MultiConn: mc,
ctrl: ctrl,
} }
if pr.id, err = CtxPeerId(ctx); err != nil {
return
}
go pr.acceptLoop()
return pr, nil
} }
type Peer interface { type Peer interface {
Id() string Id() string
LastUsage() time.Time Context() context.Context
UpdateLastUsage()
Addr() string AcquireDrpcConn(ctx context.Context) (drpc.Conn, error)
ReleaseDrpcConn(conn drpc.Conn)
DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error
IsClosed() bool
TryClose(objectTTL time.Duration) (res bool, err error) TryClose(objectTTL time.Duration) (res bool, err error)
ocache.Object
}
type subConn struct {
drpc.Conn drpc.Conn
*connutil.LastUsageConn
} }
type peer struct { type peer struct {
id string id string
ttl time.Duration
lastUsage int64 ctrl connCtrl
sc sec.SecureConn
drpc.Conn // drpc conn pool
inactive []*subConn
active map[*subConn]struct{}
mu sync.Mutex
transport.MultiConn
} }
func (p *peer) Id() string { func (p *peer) Id() string {
return p.id return p.id
} }
func (p *peer) LastUsage() time.Time { func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
select { p.mu.Lock()
case <-p.Closed(): if len(p.inactive) == 0 {
return time.Unix(0, 0) p.mu.Unlock()
default: dconn, err := p.openDrpcConn(ctx)
if err != nil {
return nil, err
}
p.mu.Lock()
p.inactive = append(p.inactive, dconn)
} }
return time.Unix(atomic.LoadInt64(&p.lastUsage), 0) idx := len(p.inactive) - 1
res := p.inactive[idx]
p.inactive = p.inactive[:idx]
p.active[res] = struct{}{}
p.mu.Unlock()
return res, nil
} }
func (p *peer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { func (p *peer) ReleaseDrpcConn(conn drpc.Conn) {
defer p.UpdateLastUsage() p.mu.Lock()
return p.Conn.Invoke(ctx, rpc, enc, in, out) defer p.mu.Unlock()
} sc, ok := conn.(*subConn)
if !ok {
func (p *peer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { return
defer p.UpdateLastUsage()
return p.Conn.NewStream(ctx, rpc, enc)
}
func (p *peer) Read(b []byte) (n int, err error) {
if n, err = p.sc.Read(b); err == nil {
p.UpdateLastUsage()
} }
if _, ok = p.active[sc]; ok {
delete(p.active, sc)
}
p.inactive = append(p.inactive, sc)
return return
} }
func (p *peer) Write(b []byte) (n int, err error) { func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
if n, err = p.sc.Write(b); err == nil { conn, err := p.AcquireDrpcConn(ctx)
p.UpdateLastUsage() if err != nil {
return err
} }
return defer p.ReleaseDrpcConn(conn)
return do(conn)
} }
func (p *peer) UpdateLastUsage() { func (p *peer) openDrpcConn(ctx context.Context) (dconn *subConn, err error) {
atomic.StoreInt64(&p.lastUsage, time.Now().Unix()) conn, err := p.Open(ctx)
if err != nil {
return nil, err
}
if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil {
return nil, err
}
tconn := connutil.NewLastUsageConn(conn)
bufSize := p.ctrl.DrpcConfig().Stream.MaxMsgSizeMb * (1 << 20)
return &subConn{
Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{
Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize},
Stream: drpcstream.Options{MaximumBufferSize: bufSize},
},
}),
LastUsageConn: tconn,
}, nil
}
func (p *peer) acceptLoop() {
var exitErr error
defer func() {
if exitErr != transport.ErrConnClosed {
log.Warn("accept error: close connection", zap.Error(exitErr))
_ = p.MultiConn.Close()
}
}()
for {
conn, err := p.Accept()
if err != nil {
exitErr = err
return
}
go func() {
serveErr := p.serve(conn)
if serveErr != io.EOF && serveErr != transport.ErrConnClosed {
log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr))
}
}()
}
}
var defaultProtoChecker = handshake.ProtoChecker{
AllowedProtoTypes: []handshakeproto.ProtoType{
handshakeproto.ProtoType_DRPC,
},
}
func (p *peer) serve(conn net.Conn) (err error) {
hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20)
if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil {
cancel()
return
}
cancel()
return p.ctrl.ServeConn(p.Context(), conn)
} }
func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) {
p.gc(objectTTL)
if time.Now().Sub(p.LastUsage()) < objectTTL { if time.Now().Sub(p.LastUsage()) < objectTTL {
return false, nil return false, nil
} }
return true, p.Close() return true, p.Close()
} }
func (p *peer) Addr() string { func (p *peer) gc(ttl time.Duration) {
if p.sc != nil { p.mu.Lock()
return p.sc.RemoteAddr().String() defer p.mu.Unlock()
minLastUsage := time.Now().Add(-ttl)
var hasClosed bool
for i, in := range p.inactive {
select {
case <-in.Closed():
p.inactive[i] = nil
hasClosed = true
default:
}
if in.LastUsage().Before(minLastUsage) {
_ = in.Close()
p.inactive[i] = nil
hasClosed = true
}
}
if hasClosed {
inactive := p.inactive
p.inactive = p.inactive[:0]
for _, in := range inactive {
if in != nil {
p.inactive = append(p.inactive, in)
}
}
}
for act := range p.active {
select {
case <-act.Closed():
delete(p.active, act)
default:
}
} }
return ""
} }
func (p *peer) Close() (err error) { func (p *peer) Close() (err error) {
log.Debug("peer close", zap.String("peerId", p.id)) log.Debug("peer close", zap.String("peerId", p.id))
return p.Conn.Close() return p.MultiConn.Close()
} }

192
net/peer/peer_test.go Normal file
View File

@ -0,0 +1,192 @@
package peer
import (
"context"
"github.com/anyproto/any-sync/net/rpc"
"github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/net/transport/mock_transport"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"io"
"net"
_ "net/http/pprof"
"testing"
"time"
)
var ctx = context.Background()
func TestPeer_AcquireDrpcConn(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
in, out := net.Pipe()
go func() {
handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker)
}()
defer out.Close()
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
dc, err := fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
assert.NotEmpty(t, dc)
defer dc.Close()
assert.Len(t, fx.active, 1)
assert.Len(t, fx.inactive, 0)
fx.ReleaseDrpcConn(dc)
assert.Len(t, fx.active, 0)
assert.Len(t, fx.inactive, 1)
dc, err = fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
assert.NotEmpty(t, dc)
assert.Len(t, fx.active, 1)
assert.Len(t, fx.inactive, 0)
}
func TestPeerAccept(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
in, out := net.Pipe()
defer out.Close()
var outHandshakeCh = make(chan error)
go func() {
outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC)
}()
fx.acceptCh <- acceptedConn{conn: in}
cn := <-fx.testCtrl.serveConn
assert.Equal(t, in, cn)
assert.NoError(t, <-outHandshakeCh)
}
func TestPeer_TryClose(t *testing.T) {
t.Run("ttl", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now()
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.False(t, res)
})
t.Run("close", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
lu := time.Now().Add(-time.Second * 2)
fx.mc.EXPECT().LastUsage().Return(lu)
res, err := fx.TryClose(time.Second)
require.NoError(t, err)
assert.True(t, res)
})
t.Run("gc", func(t *testing.T) {
fx := newFixture(t, "p1")
defer fx.finish()
now := time.Now()
fx.mc.EXPECT().LastUsage().Return(now.Add(time.Millisecond * 100))
// make one inactive
in, out := net.Pipe()
go func() {
handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker)
}()
defer out.Close()
fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil)
dc, err := fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
// make one active but closed
in2, out2 := net.Pipe()
go func() {
handshake.IncomingProtoHandshake(ctx, out2, defaultProtoChecker)
}()
defer out2.Close()
fx.mc.EXPECT().Open(gomock.Any()).Return(in2, nil)
dc2, err := fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
_ = dc2.Close()
// make one inactive and closed
in3, out3 := net.Pipe()
go func() {
handshake.IncomingProtoHandshake(ctx, out3, defaultProtoChecker)
}()
defer out3.Close()
fx.mc.EXPECT().Open(gomock.Any()).Return(in3, nil)
dc3, err := fx.AcquireDrpcConn(ctx)
require.NoError(t, err)
fx.ReleaseDrpcConn(dc3)
_ = dc3.Close()
fx.ReleaseDrpcConn(dc)
time.Sleep(time.Millisecond * 100)
res, err := fx.TryClose(time.Millisecond * 50)
require.NoError(t, err)
assert.False(t, res)
})
}
type acceptedConn struct {
conn net.Conn
err error
}
func newFixture(t *testing.T, peerId string) *fixture {
fx := &fixture{
ctrl: gomock.NewController(t),
acceptCh: make(chan acceptedConn),
testCtrl: newTesCtrl(),
}
fx.mc = mock_transport.NewMockMultiConn(fx.ctrl)
ctx := CtxWithPeerId(context.Background(), peerId)
fx.mc.EXPECT().Context().Return(ctx).AnyTimes()
fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) {
ac := <-fx.acceptCh
return ac.conn, ac.err
}).AnyTimes()
fx.mc.EXPECT().Close().AnyTimes()
p, err := NewPeer(fx.mc, fx.testCtrl)
require.NoError(t, err)
fx.peer = p.(*peer)
return fx
}
type fixture struct {
*peer
ctrl *gomock.Controller
mc *mock_transport.MockMultiConn
acceptCh chan acceptedConn
testCtrl *testCtrl
}
func (fx *fixture) finish() {
fx.testCtrl.close()
fx.ctrl.Finish()
}
func newTesCtrl() *testCtrl {
return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)}
}
type testCtrl struct {
serveConn chan net.Conn
closeCh chan struct{}
}
func (t *testCtrl) DrpcConfig() rpc.Config {
return rpc.Config{Stream: rpc.StreamConfig{MaxMsgSizeMb: 10}}
}
func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
t.serveConn <- conn
<-t.closeCh
return io.EOF
}
func (t *testCtrl) close() {
close(t.closeCh)
}

View File

@ -0,0 +1,111 @@
package peerservice
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/server"
"github.com/anyproto/any-sync/net/transport"
"github.com/anyproto/any-sync/net/transport/yamux"
"github.com/anyproto/any-sync/nodeconf"
"go.uber.org/zap"
"sync"
)
const CName = "net.peerservice"
var log = logger.NewNamed(CName)
var (
ErrAddrsNotFound = errors.New("addrs for peer not found")
)
func New() PeerService {
return new(peerService)
}
type PeerService interface {
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
SetPeerAddrs(peerId string, addrs []string)
transport.Accepter
app.Component
}
type peerService struct {
yamux transport.Transport
nodeConf nodeconf.NodeConf
peerAddrs map[string][]string
pool pool.Pool
server server.DRPCServer
mu sync.RWMutex
}
func (p *peerService) Init(a *app.App) (err error) {
p.yamux = a.MustComponent(yamux.CName).(transport.Transport)
p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
p.pool = a.MustComponent(pool.CName).(pool.Pool)
p.server = a.MustComponent(server.CName).(server.DRPCServer)
p.peerAddrs = map[string][]string{}
p.yamux.SetAccepter(p)
return nil
}
func (p *peerService) Name() (name string) {
return CName
}
func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) {
p.mu.RLock()
defer p.mu.RUnlock()
addrs, err := p.getPeerAddrs(peerId)
if err != nil {
return
}
var mc transport.MultiConn
log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs))
for _, addr := range addrs {
mc, err = p.yamux.Dial(ctx, addr)
if err != nil {
log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err))
} else {
break
}
}
if err != nil {
return
}
return peer.NewPeer(mc, p.server)
}
func (p *peerService) Accept(mc transport.MultiConn) (err error) {
pr, err := peer.NewPeer(mc, p.server)
if err != nil {
return err
}
if err = p.pool.AddPeer(context.Background(), pr); err != nil {
_ = pr.Close()
}
return
}
func (p *peerService) SetPeerAddrs(peerId string, addrs []string) {
p.mu.Lock()
defer p.mu.Unlock()
p.peerAddrs[peerId] = addrs
}
func (p *peerService) getPeerAddrs(peerId string) ([]string, error) {
if addrs, ok := p.nodeConf.PeerAddresses(peerId); ok {
return addrs, nil
}
addrs, ok := p.peerAddrs[peerId]
if !ok || len(addrs) == 0 {
return nil, ErrAddrsNotFound
}
return addrs, nil
}

View File

@ -0,0 +1,80 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/net/pool (interfaces: Pool)
// Package mock_pool is a generated GoMock package.
package mock_pool
import (
context "context"
reflect "reflect"
peer "github.com/anyproto/any-sync/net/peer"
gomock "github.com/golang/mock/gomock"
)
// MockPool is a mock of Pool interface.
type MockPool struct {
ctrl *gomock.Controller
recorder *MockPoolMockRecorder
}
// MockPoolMockRecorder is the mock recorder for MockPool.
type MockPoolMockRecorder struct {
mock *MockPool
}
// NewMockPool creates a new mock instance.
func NewMockPool(ctrl *gomock.Controller) *MockPool {
mock := &MockPool{ctrl: ctrl}
mock.recorder = &MockPoolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPool) EXPECT() *MockPoolMockRecorder {
return m.recorder
}
// AddPeer mocks base method.
func (m *MockPool) AddPeer(arg0 context.Context, arg1 peer.Peer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddPeer", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// AddPeer indicates an expected call of AddPeer.
func (mr *MockPoolMockRecorder) AddPeer(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPeer", reflect.TypeOf((*MockPool)(nil).AddPeer), arg0, arg1)
}
// Get mocks base method.
func (m *MockPool) Get(arg0 context.Context, arg1 string) (peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0, arg1)
ret0, _ := ret[0].(peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockPoolMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPool)(nil).Get), arg0, arg1)
}
// GetOneOf mocks base method.
func (m *MockPool) GetOneOf(arg0 context.Context, arg1 []string) (peer.Peer, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOneOf", arg0, arg1)
ret0, _ := ret[0].(peer.Peer)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOneOf indicates an expected call of GetOneOf.
func (mr *MockPoolMockRecorder) GetOneOf(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOneOf", reflect.TypeOf((*MockPool)(nil).GetOneOf), arg0, arg1)
}

View File

@ -1,10 +1,10 @@
//go:generate mockgen -destination mock_pool/mock_pool.go github.com/anyproto/any-sync/net/pool Pool
package pool package pool
import ( import (
"context" "context"
"github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/dialer"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"go.uber.org/zap" "go.uber.org/zap"
@ -13,59 +13,61 @@ import (
// Pool creates and caches outgoing connection // Pool creates and caches outgoing connection
type Pool interface { type Pool interface {
// Get lookups to peer in existing connections or creates and cache new one // Get lookups to peer in existing connections or creates and outgoing new one
Get(ctx context.Context, id string) (peer.Peer, error) Get(ctx context.Context, id string) (peer.Peer, error)
// Dial creates new connection to peer and not use cache // GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list
Dial(ctx context.Context, id string) (peer.Peer, error)
// GetOneOf searches at least one existing connection in cache or creates a new one from a randomly selected id from given list
GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error)
// AddPeer adds incoming peer to the pool
DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) AddPeer(ctx context.Context, p peer.Peer) (err error)
} }
type pool struct { type pool struct {
cache ocache.OCache outgoing ocache.OCache
dialer dialer.Dialer incoming ocache.OCache
} }
func (p *pool) Name() (name string) { func (p *pool) Name() (name string) {
return CName return CName
} }
func (p *pool) Run(ctx context.Context) (err error) { func (p *pool) Get(ctx context.Context, id string) (pr peer.Peer, err error) {
return nil // if we have incoming connection - try to reuse it
if pr, err = p.get(ctx, p.incoming, id); err != nil {
// or try to get or create outgoing
return p.get(ctx, p.outgoing, id)
}
return
} }
func (p *pool) Get(ctx context.Context, id string) (peer.Peer, error) { func (p *pool) get(ctx context.Context, source ocache.OCache, id string) (peer.Peer, error) {
v, err := p.cache.Get(ctx, id) v, err := source.Get(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
pr := v.(peer.Peer) pr := v.(peer.Peer)
select { if !pr.IsClosed() {
case <-pr.Closed():
default:
return pr, nil return pr, nil
} }
_, _ = p.cache.Remove(ctx, id) _, _ = source.Remove(ctx, id)
return p.Get(ctx, id) return p.Get(ctx, id)
} }
func (p *pool) Dial(ctx context.Context, id string) (peer.Peer, error) {
return p.dialer.Dial(ctx, id)
}
func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
// finding existing connection // finding existing connection
for _, peerId := range peerIds { for _, peerId := range peerIds {
if v, err := p.cache.Pick(ctx, peerId); err == nil { if v, err := p.incoming.Pick(ctx, peerId); err == nil {
pr := v.(peer.Peer) pr := v.(peer.Peer)
select { if !pr.IsClosed() {
case <-pr.Closed():
default:
return pr, nil return pr, nil
} }
_, _ = p.cache.Remove(ctx, peerId) _, _ = p.incoming.Remove(ctx, peerId)
}
if v, err := p.outgoing.Pick(ctx, peerId); err == nil {
pr := v.(peer.Peer)
if !pr.IsClosed() {
return pr, nil
}
_, _ = p.outgoing.Remove(ctx, peerId)
} }
} }
// shuffle ids for better consistency // shuffle ids for better consistency
@ -75,8 +77,8 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
// connecting // connecting
var lastErr error var lastErr error
for _, peerId := range peerIds { for _, peerId := range peerIds {
if v, err := p.cache.Get(ctx, peerId); err == nil { if v, err := p.Get(ctx, peerId); err == nil {
return v.(peer.Peer), nil return v, nil
} else { } else {
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err))
lastErr = err lastErr = err
@ -88,27 +90,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error
return nil, lastErr return nil, lastErr
} }
func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) {
// shuffle ids for better consistency if err = p.incoming.Add(pr.Id(), pr); err != nil {
rand.Shuffle(len(peerIds), func(i, j int) { if err == ocache.ErrExists {
peerIds[i], peerIds[j] = peerIds[j], peerIds[i] // in case when an incoming connection with a peer already exists, we close and remove an existing connection
}) if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil {
// connecting _ = v.Close()
var lastErr error _, _ = p.incoming.Remove(ctx, pr.Id())
for _, peerId := range peerIds { return p.incoming.Add(pr.Id(), pr)
if v, err := p.dialer.Dial(ctx, peerId); err == nil { }
return v.(peer.Peer), nil
} else { } else {
log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) return err
lastErr = err
} }
} }
if _, ok := lastErr.(handshake.HandshakeError); !ok { return
lastErr = net.ErrUnableToConnect
}
return nil, lastErr
}
func (p *pool) Close(ctx context.Context) (err error) {
return p.cache.Close()
} }

View File

@ -6,11 +6,11 @@ import (
"fmt" "fmt"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/dialer"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
net2 "net"
"storj.io/drpc" "storj.io/drpc"
"testing" "testing"
"time" "time"
@ -133,6 +133,27 @@ func TestPool_GetOneOf(t *testing.T) {
}) })
} }
func TestPool_AddPeer(t *testing.T) {
t.Run("success", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish()
require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1")))
})
t.Run("two peers", func(t *testing.T) {
fx := newFixture(t)
defer fx.Finish()
p1, p2 := newTestPeer("p1"), newTestPeer("p1")
require.NoError(t, fx.AddPeer(ctx, p1))
require.NoError(t, fx.AddPeer(ctx, p2))
select {
case <-p1.closed:
default:
assert.Truef(t, false, "peer not closed")
}
})
}
func newFixture(t *testing.T) *fixture { func newFixture(t *testing.T) *fixture {
fx := &fixture{ fx := &fixture{
Service: New(), Service: New(),
@ -158,7 +179,7 @@ type fixture struct {
t *testing.T t *testing.T
} }
var _ dialer.Dialer = (*dialerMock)(nil) var _ dialer = (*dialerMock)(nil)
type dialerMock struct { type dialerMock struct {
dial func(ctx context.Context, peerId string) (peer peer.Peer, err error) dial func(ctx context.Context, peerId string) (peer peer.Peer, err error)
@ -181,7 +202,7 @@ func (d *dialerMock) Init(a *app.App) (err error) {
} }
func (d *dialerMock) Name() (name string) { func (d *dialerMock) Name() (name string) {
return dialer.CName return "net.peerservice"
} }
func newTestPeer(id string) *testPeer { func newTestPeer(id string) *testPeer {
@ -196,6 +217,31 @@ type testPeer struct {
closed chan struct{} closed chan struct{}
} }
func (t *testPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
return fmt.Errorf("not implemented")
}
func (t *testPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
return nil, fmt.Errorf("not implemented")
}
func (t *testPeer) ReleaseDrpcConn(conn drpc.Conn) {}
func (t *testPeer) Context() context.Context {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Accept() (conn net2.Conn, err error) {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Open(ctx context.Context) (conn net2.Conn, err error) {
//TODO implement me
panic("implement me")
}
func (t *testPeer) Addr() string { func (t *testPeer) Addr() string {
return "" return ""
} }
@ -204,12 +250,6 @@ func (t *testPeer) Id() string {
return t.id return t.id
} }
func (t *testPeer) LastUsage() time.Time {
return time.Now()
}
func (t *testPeer) UpdateLastUsage() {}
func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, t.Close() return true, t.Close()
} }
@ -224,14 +264,11 @@ func (t *testPeer) Close() error {
return nil return nil
} }
func (t *testPeer) Closed() <-chan struct{} { func (t *testPeer) IsClosed() bool {
return t.closed select {
} case <-t.closed:
return true
func (t *testPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { default:
return fmt.Errorf("call Invoke on test peer") return false
} }
func (t *testPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
return nil, fmt.Errorf("call NewStream on test peer")
} }

View File

@ -6,8 +6,9 @@ import (
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/app/ocache"
"github.com/anyproto/any-sync/metric" "github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/dialer" "github.com/anyproto/any-sync/net/peer"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"go.uber.org/zap"
"time" "time"
) )
@ -23,46 +24,54 @@ func New() Service {
type Service interface { type Service interface {
Pool Pool
NewPool(name string) Pool
app.ComponentRunnable app.ComponentRunnable
} }
type dialer interface {
Dial(ctx context.Context, peerId string) (pr peer.Peer, err error)
}
type poolService struct { type poolService struct {
// default pool // default pool
*pool *pool
dialer dialer.Dialer dialer dialer
metricReg *prometheus.Registry metricReg *prometheus.Registry
} }
func (p *poolService) Init(a *app.App) (err error) { func (p *poolService) Init(a *app.App) (err error) {
p.dialer = a.MustComponent(dialer.CName).(dialer.Dialer) p.dialer = a.MustComponent("net.peerservice").(dialer)
p.pool = &pool{dialer: p.dialer} p.pool = &pool{}
if m := a.Component(metric.CName); m != nil { if m := a.Component(metric.CName); m != nil {
p.metricReg = m.(metric.Metric).Registry() p.metricReg = m.(metric.Metric).Registry()
} }
p.pool.cache = ocache.New( p.pool.outgoing = ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) { func(ctx context.Context, id string) (value ocache.Object, err error) {
return p.dialer.Dial(ctx, id) return p.dialer.Dial(ctx, id)
}, },
ocache.WithLogger(log.Sugar()), ocache.WithLogger(log.Sugar()),
ocache.WithGCPeriod(time.Minute), ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Minute*5), ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", "default"), ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"),
)
p.pool.incoming = ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) {
return nil, ocache.ErrNotExists
},
ocache.WithLogger(log.Sugar()),
ocache.WithGCPeriod(time.Minute),
ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", "incoming"),
) )
return nil return nil
} }
func (p *poolService) NewPool(name string) Pool { func (p *pool) Run(ctx context.Context) (err error) {
return &pool{ return nil
dialer: p.dialer, }
cache: ocache.New(
func(ctx context.Context, id string) (value ocache.Object, err error) { func (p *pool) Close(ctx context.Context) (err error) {
return p.dialer.Dial(ctx, id) if e := p.incoming.Close(); e != nil {
}, log.Warn("close incoming cache error", zap.Error(e))
ocache.WithLogger(log.Sugar()), }
ocache.WithGCPeriod(time.Minute), return p.outgoing.Close()
ocache.WithTTL(time.Minute*5),
ocache.WithPrometheus(p.metricReg, "netpool", name),
),
}
} }

View File

@ -0,0 +1,9 @@
package debugserver
type configGetter interface {
GetDebugServer() Config
}
type Config struct {
ListenAddr string `yaml:"listenAddr"`
}

View File

@ -0,0 +1,70 @@
package debugserver
import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net/rpc"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcstream"
"storj.io/drpc/drpcwire"
)
const CName = "net.rpc.debugserver"
func New() DebugServer {
return &debugServer{}
}
type DebugServer interface {
app.ComponentRunnable
drpc.Mux
}
type debugServer struct {
drpcServer *drpcserver.Server
*drpcmux.Mux
drpcConf rpc.Config
config Config
runCtx context.Context
runCtxCancel context.CancelFunc
}
func (d *debugServer) Init(a *app.App) (err error) {
d.drpcConf = a.MustComponent("config").(rpc.ConfigGetter).GetDrpc()
d.config = a.MustComponent("config").(configGetter).GetDebugServer()
d.Mux = drpcmux.New()
bufSize := d.drpcConf.Stream.MaxMsgSizeMb * (1 << 20)
d.drpcServer = drpcserver.NewWithOptions(d, drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize},
Stream: drpcstream.Options{MaximumBufferSize: bufSize},
}})
return nil
}
func (d *debugServer) Name() (name string) {
return CName
}
func (d *debugServer) Run(ctx context.Context) (err error) {
if d.config.ListenAddr == "" {
return
}
lis, err := net.Listen("tcp", d.config.ListenAddr)
if err != nil {
return
}
d.runCtx, d.runCtxCancel = context.WithCancel(context.Background())
go d.drpcServer.Serve(d.runCtx, lis)
return
}
func (d *debugServer) Close(ctx context.Context) (err error) {
if d.runCtx != nil {
d.runCtxCancel()
}
return nil
}

13
net/rpc/drpcconfig.go Normal file
View File

@ -0,0 +1,13 @@
package rpc
type ConfigGetter interface {
GetDrpc() Config
}
type Config struct {
Stream StreamConfig `yaml:"stream"`
}
type StreamConfig struct {
MaxMsgSizeMb int `yaml:"maxMsgSizeMb"`
}

30
net/rpc/rpctest/peer.go Normal file
View File

@ -0,0 +1,30 @@
package rpctest
import (
"context"
"github.com/anyproto/any-sync/net/connutil"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/transport"
yamux2 "github.com/anyproto/any-sync/net/transport/yamux"
"github.com/hashicorp/yamux"
"net"
)
func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) {
sc, cc := net.Pipe()
var servConn = make(chan transport.MultiConn, 1)
go func() {
sess, err := yamux.Server(sc, yamux.DefaultConfig())
if err != nil {
panic(err)
}
servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess)
}()
sess, err := yamux.Client(cc, yamux.DefaultConfig())
if err != nil {
panic(err)
}
client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess)
serv = <-servConn
return
}

View File

@ -2,84 +2,21 @@ package rpctest
import ( import (
"context" "context"
"errors"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/pool"
"math/rand"
"storj.io/drpc"
"sync" "sync"
"time"
) )
var ErrCantConnect = errors.New("can't connect to test server")
func NewTestPool() *TestPool { func NewTestPool() *TestPool {
return &TestPool{ return &TestPool{peers: map[string]peer.Peer{}}
peers: map[string]peer.Peer{},
}
} }
type TestPool struct { type TestPool struct {
ts *TesServer
peers map[string]peer.Peer peers map[string]peer.Peer
mu sync.Mutex mu sync.Mutex
} ts *TestServer
func (t *TestPool) WithServer(ts *TesServer) *TestPool {
t.mu.Lock()
defer t.mu.Unlock()
t.ts = ts
return t
}
func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
if p, ok := t.peers[id]; ok {
return p, nil
}
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) {
return t.Get(ctx, id)
}
func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
for _, peerId := range peerIds {
if p, ok := t.peers[peerId]; ok {
return p, nil
}
}
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ts == nil {
return nil, ErrCantConnect
}
return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil
}
func (t *TestPool) NewPool(name string) pool.Pool {
return t
}
func (t *TestPool) AddPeer(p peer.Peer) {
t.mu.Lock()
defer t.mu.Unlock()
t.peers[p.Id()] = p
} }
func (t *TestPool) Init(a *app.App) (err error) { func (t *TestPool) Init(a *app.App) (err error) {
@ -90,6 +27,13 @@ func (t *TestPool) Name() (name string) {
return pool.CName return pool.CName
} }
func (t *TestPool) WithServer(ts *TestServer) *TestPool {
t.mu.Lock()
defer t.mu.Unlock()
t.ts = ts
return t
}
func (t *TestPool) Run(ctx context.Context) (err error) { func (t *TestPool) Run(ctx context.Context) (err error) {
return nil return nil
} }
@ -98,25 +42,35 @@ func (t *TestPool) Close(ctx context.Context) (err error) {
return nil return nil
} }
type testPeer struct { func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) {
id string t.mu.Lock()
drpc.Conn defer t.mu.Unlock()
if p, ok := t.peers[id]; ok {
return p, nil
}
if t.ts == nil {
return nil, net.ErrUnableToConnect
}
return t.ts.Dial(id)
} }
func (t testPeer) Addr() string { func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) {
return "" t.mu.Lock()
defer t.mu.Unlock()
for _, peerId := range peerIds {
if p, ok := t.peers[peerId]; ok {
return p, nil
}
}
if t.ts == nil || len(peerIds) == 0 {
return nil, net.ErrUnableToConnect
}
return t.ts.Dial(peerIds[0])
} }
func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { func (t *TestPool) AddPeer(ctx context.Context, p peer.Peer) (err error) {
return true, t.Close() t.mu.Lock()
defer t.mu.Unlock()
t.peers[p.Id()] = p
return nil
} }
func (t testPeer) Id() string {
return t.id
}
func (t testPeer) LastUsage() time.Time {
return time.Now()
}
func (t testPeer) UpdateLastUsage() {}

View File

@ -3,45 +3,69 @@ package rpctest
import ( import (
"context" "context"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/rpc"
"github.com/anyproto/any-sync/net/rpc/server" "github.com/anyproto/any-sync/net/rpc/server"
"net" "net"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
"storj.io/drpc/drpcmux" "storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver" "storj.io/drpc/drpcserver"
) )
func NewTestServer() *TesServer { type mockCtrl struct {
ts := &TesServer{ }
func (m mockCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) {
return nil
}
func (m mockCtrl) DrpcConfig() rpc.Config {
return rpc.Config{}
}
func NewTestServer() *TestServer {
ts := &TestServer{
Mux: drpcmux.New(), Mux: drpcmux.New(),
} }
ts.Server = drpcserver.New(ts.Mux) ts.Server = drpcserver.New(ts.Mux)
return ts return ts
} }
type TesServer struct { type TestServer struct {
*drpcmux.Mux *drpcmux.Mux
*drpcserver.Server *drpcserver.Server
} }
func (ts *TesServer) Init(a *app.App) (err error) { func (s *TestServer) Init(a *app.App) (err error) {
return nil return nil
} }
func (ts *TesServer) Name() (name string) { func (s *TestServer) Name() (name string) {
return server.CName return server.CName
} }
func (ts *TesServer) Run(ctx context.Context) (err error) { func (s *TestServer) Run(ctx context.Context) (err error) {
return nil return nil
} }
func (ts *TesServer) Close(ctx context.Context) (err error) { func (s *TestServer) Close(ctx context.Context) (err error) {
return nil return nil
} }
func (ts *TesServer) Dial(ctx context.Context) drpc.Conn { func (s *TestServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
sc, cc := net.Pipe() return s.Server.ServeOne(ctx, conn)
go ts.Server.ServeOne(ctx, sc) }
return drpcconn.New(cc)
func (s *TestServer) DrpcConfig() rpc.Config {
return rpc.Config{Stream: rpc.StreamConfig{MaxMsgSizeMb: 10}}
}
func (s *TestServer) Dial(peerId string) (clientPeer peer.Peer, err error) {
mcS, mcC := MultiConnPair(peerId+"server", peerId)
// NewPeer runs the accept loop
_, err = peer.NewPeer(mcS, s)
if err != nil {
return
}
// and we ourselves don't call server methods on accept
return peer.NewPeer(mcC, mockCtrl{})
} }

View File

@ -1,134 +0,0 @@
package server
import (
"context"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/secureservice"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/zeebo/errs"
"go.uber.org/zap"
"io"
"net"
"storj.io/drpc"
"storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcwire"
"time"
)
type BaseDrpcServer struct {
drpcServer *drpcserver.Server
transport secureservice.SecureService
listeners []net.Listener
handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
cancel func()
*drpcmux.Mux
}
type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
type Params struct {
BufferSizeMb int
ListenAddrs []string
Wrapper DRPCHandlerWrapper
TimeoutMillis int
Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error)
}
func NewBaseDrpcServer() *BaseDrpcServer {
return &BaseDrpcServer{Mux: drpcmux.New()}
}
func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) {
s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)},
}})
s.handshake = params.Handshake
ctx, s.cancel = context.WithCancel(ctx)
for _, addr := range params.ListenAddrs {
list, err := net.Listen("tcp", addr)
if err != nil {
return err
}
s.listeners = append(s.listeners, list)
go s.serve(ctx, list)
}
return
}
func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) {
l := log.With(zap.String("localAddr", lis.Addr().String()))
l.Info("drpc listener started")
defer func() {
l.Debug("drpc listener stopped")
}()
for {
select {
case <-ctx.Done():
return
default:
}
conn, err := lis.Accept()
if err != nil {
if isTemporary(err) {
l.Debug("listener temporary accept error", zap.Error(err))
select {
case <-time.After(time.Second):
case <-ctx.Done():
return
}
continue
}
l.Error("listener accept error", zap.Error(err))
return
}
go s.serveConn(conn)
}
}
func (s *BaseDrpcServer) serveConn(conn net.Conn) {
l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
var (
ctx = context.Background()
err error
)
if s.handshake != nil {
ctx, conn, err = s.handshake(conn)
if err != nil {
l.Info("handshake error", zap.Error(err))
return
}
if sc, ok := conn.(sec.SecureConn); ok {
ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String())
}
}
ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String())
l.Debug("connection opened")
if err := s.drpcServer.ServeOne(ctx, conn); err != nil {
if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) {
l.Debug("connection closed")
} else {
l.Warn("serve connection error", zap.Error(err))
}
}
}
func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) {
for _, list := range s.listeners {
addrs = append(addrs, list.Addr())
}
return
}
func (s *BaseDrpcServer) Close(ctx context.Context) (err error) {
if s.cancel != nil {
s.cancel()
}
for _, l := range s.listeners {
if e := l.Close(); e != nil {
log.Warn("close listener error", zap.Error(e))
}
}
return
}

View File

@ -5,12 +5,15 @@ import (
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/metric" "github.com/anyproto/any-sync/metric"
anyNet "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/rpc"
"github.com/anyproto/any-sync/net/secureservice" "go.uber.org/zap"
"github.com/libp2p/go-libp2p/core/sec"
"net" "net"
"storj.io/drpc" "storj.io/drpc"
"time" "storj.io/drpc/drpcmanager"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"storj.io/drpc/drpcstream"
"storj.io/drpc/drpcwire"
) )
const CName = "common.net.drpcserver" const CName = "common.net.drpcserver"
@ -18,49 +21,53 @@ const CName = "common.net.drpcserver"
var log = logger.NewNamed(CName) var log = logger.NewNamed(CName)
func New() DRPCServer { func New() DRPCServer {
return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()} return &drpcServer{}
} }
type DRPCServer interface { type DRPCServer interface {
app.ComponentRunnable ServeConn(ctx context.Context, conn net.Conn) (err error)
DrpcConfig() rpc.Config
app.Component
drpc.Mux drpc.Mux
} }
type drpcServer struct { type drpcServer struct {
config anyNet.Config drpcServer *drpcserver.Server
metric metric.Metric *drpcmux.Mux
transport secureservice.SecureService config rpc.Config
*BaseDrpcServer metric metric.Metric
} }
func (s *drpcServer) Init(a *app.App) (err error) { type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler
s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet()
s.metric = a.MustComponent(metric.CName).(metric.Metric)
s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService)
return nil
}
func (s *drpcServer) Name() (name string) { func (s *drpcServer) Name() (name string) {
return CName return CName
} }
func (s *drpcServer) Run(ctx context.Context) (err error) { func (s *drpcServer) Init(a *app.App) (err error) {
params := Params{ s.config = a.MustComponent("config").(rpc.ConfigGetter).GetDrpc()
BufferSizeMb: s.config.Stream.MaxMsgSizeMb, s.metric, _ = a.Component(metric.CName).(metric.Metric)
TimeoutMillis: s.config.Stream.TimeoutMilliseconds, s.Mux = drpcmux.New()
ListenAddrs: s.config.Server.ListenAddrs,
Wrapper: func(handler drpc.Handler) drpc.Handler { var handler drpc.Handler
return s.metric.WrapDRPCHandler(handler) handler = s
}, if s.metric != nil {
Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) { handler = s.metric.WrapDRPCHandler(s)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
return s.transport.SecureInbound(ctx, conn)
},
} }
return s.BaseDrpcServer.Run(ctx, params) bufSize := s.config.Stream.MaxMsgSizeMb * (1 << 20)
s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{
Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize},
Stream: drpcstream.Options{MaximumBufferSize: bufSize},
}})
return
} }
func (s *drpcServer) Close(ctx context.Context) (err error) { func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) {
return s.BaseDrpcServer.Close(ctx) l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String()))
l.Debug("drpc serve peer")
return s.drpcServer.ServeOne(ctx, conn)
}
func (s *drpcServer) DrpcConfig() rpc.Config {
return s.config
} }

View File

@ -5,7 +5,6 @@ import (
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/libp2p/go-libp2p/core/sec"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -19,11 +18,11 @@ type noVerifyChecker struct {
cred *handshakeproto.Credentials cred *handshakeproto.Credentials
} }
func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
return n.cred return n.cred
} }
func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != n.cred.Version { if cred.Version != n.cred.Version {
return nil, handshake.ErrIncompatibleVersion return nil, handshake.ErrIncompatibleVersion
} }
@ -42,8 +41,8 @@ type peerSignVerifier struct {
account *accountdata.AccountKeys account *accountdata.AccountKeys
} }
func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials {
sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String())) sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId))
if err != nil { if err != nil {
log.Warn("can't sign identity credentials", zap.Error(err)) log.Warn("can't sign identity credentials", zap.Error(err))
} }
@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr
} }
} }
func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if cred.Version != p.protoVersion { if cred.Version != p.protoVersion {
return nil, handshake.ErrIncompatibleVersion return nil, handshake.ErrIncompatibleVersion
} }
@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro
if err != nil { if err != nil {
return nil, handshake.ErrInvalidCredentials return nil, handshake.ErrInvalidCredentials
} }
ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign) ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,13 +4,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake"
"github.com/anyproto/any-sync/testutil/accounttest" "github.com/anyproto/any-sync/testutil/accounttest"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net"
"testing" "testing"
) )
@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1) cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(0, a2) cc2 := newPeerSignVerifier(0, a2)
c1 := newTestSC(a2.PeerId) c1 := a2.PeerId
c2 := newTestSC(a1.PeerId) c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1) cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2) cr2 := cc2.MakeCredentials(c2)
@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) {
cc1 := newPeerSignVerifier(0, a1) cc1 := newPeerSignVerifier(0, a1)
cc2 := newPeerSignVerifier(1, a2) cc2 := newPeerSignVerifier(1, a2)
c1 := newTestSC(a2.PeerId) c1 := a2.PeerId
c2 := newTestSC(a1.PeerId) c2 := a1.PeerId
cr1 := cc1.MakeCredentials(c1) cr1 := cc1.MakeCredentials(c1)
cr2 := cc2.MakeCredentials(c2) cr2 := cc2.MakeCredentials(c2)
@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys {
require.NoError(t, as.Init(nil)) require.NoError(t, as.Init(nil))
return as.Account() return as.Account()
} }
func newTestSC(peerId string) sec.SecureConn {
pid, _ := peer.Decode(peerId)
return &testSc{
ID: pid,
}
}
type testSc struct {
net.Conn
peer.ID
}
func (t *testSc) LocalPeer() peer.ID {
return ""
}
func (t *testSc) LocalPrivateKey() crypto.PrivKey {
return nil
}
func (t *testSc) RemotePeer() peer.ID {
return t.ID
}
func (t *testSc) RemotePublicKey() crypto.PubKey {
return nil
}
func (t *testSc) ConnState() network.ConnectionState {
return network.ConnectionState{}
}

View File

@ -0,0 +1,133 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"io"
)
func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
var (
resIdentity []byte
resErr error
)
go func() {
defer close(done)
resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return resIdentity, resErr
case <-ctx.Done():
_ = conn.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = conn
localCred := cc.MakeCredentials(peerId)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck, msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
var (
resIdentity []byte
resError error
)
go func() {
defer close(done)
resIdentity, resError = incomingHandshake(h, conn, peerId, cc)
}()
select {
case <-done:
return resIdentity, resError
case <-ctx.Done():
_ = conn.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = conn
msg, err := h.readMsg(msgTypeCred)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
}

View File

@ -7,7 +7,6 @@ import (
"github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"net" "net"
@ -17,7 +16,7 @@ import (
var noVerifyChecker = &testCredChecker{ var noVerifyChecker = &testCredChecker{
makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify},
checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
return []byte("identity"), nil return []byte("identity"), nil
}, },
} }
@ -32,21 +31,20 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) _, err = noVerifyChecker.CheckCredential("p1", msg.cred)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred)
require.NoError(t, err) require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack // receive ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
@ -59,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -70,13 +68,13 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -86,13 +84,13 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) require.NoError(t, h.writeAck(ErrInvalidCredentials.e))
res := <-handshakeResCh res := <-handshakeResCh
@ -102,16 +100,16 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
@ -121,16 +119,16 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
@ -139,18 +137,18 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack and close conn // read ack and close conn
_, err = h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
@ -160,24 +158,23 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read ack // read ack
_, err = h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
msg, err := h.readMsg() _, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.Error(t, err)
assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
@ -185,21 +182,21 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
_, err = noVerifyChecker.CheckCredential(c2, msg.cred) _, err = noVerifyChecker.CheckCredential("", msg.cred)
require.NoError(t, err) require.NoError(t, err)
// send credential message // send credential message
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// receive ack // receive ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
// send ack // send ack
@ -213,13 +210,13 @@ func TestOutgoingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker) identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// receive credential message // receive credential message
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
ctxCancel() ctxCancel()
res := <-handshakeResCh res := <-handshakeResCh
@ -236,22 +233,22 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
// write ack // write ack
require.NoError(t, h.writeAck(handshakeproto.Error_Null)) require.NoError(t, h.writeAck(handshakeproto.Error_Null))
// wait ack // wait ack
msg, err = h.readMsg() msg, err = h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-handshakeResCh res := <-handshakeResCh
@ -262,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
_ = c2.Close() _ = c2.Close()
@ -273,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials and close conn // write credentials and close conn
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
_ = c2.Close() _ = c2.Close()
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
@ -288,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -302,15 +299,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error // except ack with error
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.cred) require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error)
@ -322,15 +319,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion})
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// except ack with error // except ack with error
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.cred) require.Nil(t, msg.cred)
require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error)
@ -342,21 +339,21 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred // read cred
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
// write cred instead ack // write cred instead ack
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// expect ack with error // expect EOF
msg, err := h.readMsg() _, err = h.readMsg(msgTypeAck)
require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
require.Error(t, res.err) require.Error(t, res.err)
}) })
@ -364,15 +361,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// read cred and close conn // read cred and close conn
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
_ = c2.Close() _ = c2.Close()
@ -383,15 +380,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -405,15 +402,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -427,15 +424,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
msg, err := h.readMsg() msg, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
require.Nil(t, msg.ack) require.Nil(t, msg.ack)
require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type)
@ -450,15 +447,15 @@ func TestIncomingHandshake(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(ctx, c1, noVerifyChecker) identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
h.conn = c2 h.conn = c2
// write credentials // write credentials
require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials("")))
// wait credentials // wait credentials
_, err := h.readMsg() _, err := h.readMsg(msgTypeCred)
require.NoError(t, err) require.NoError(t, err)
ctxCancel() ctxCancel()
res := <-handshakeResCh res := <-handshakeResCh
@ -474,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
c1, c2 := newConnPair(t) c1, c2 := newConnPair(t)
var handshakeResCh = make(chan handshakeRes, 1) var handshakeResCh = make(chan handshakeRes, 1)
go func() { go func() {
identity, err := IncomingHandshake(nil, c1, noVerifyChecker) identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker)
handshakeResCh <- handshakeRes{identity: identity, err: err} handshakeResCh <- handshakeRes{identity: identity, err: err}
}() }()
h := newHandshake() h := newHandshake()
@ -482,7 +479,7 @@ func TestNotAHandshakeMessage(t *testing.T) {
_, err := c2.Write([]byte("some unexpected bytes")) _, err := c2.Write([]byte("some unexpected bytes"))
require.Error(t, err) require.Error(t, err)
res := <-handshakeResCh res := <-handshakeResCh
assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) assert.Error(t, res.err)
} }
func TestEndToEnd(t *testing.T) { func TestEndToEnd(t *testing.T) {
@ -493,11 +490,11 @@ func TestEndToEnd(t *testing.T) {
) )
st := time.Now() st := time.Now()
go func() { go func() {
identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker)
outResCh <- handshakeRes{identity: identity, err: err} outResCh <- handshakeRes{identity: identity, err: err}
}() }()
go func() { go func() {
identity, err := IncomingHandshake(nil, c2, noVerifyChecker) identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker)
inResCh <- handshakeRes{identity: identity, err: err} inResCh <- handshakeRes{identity: identity, err: err}
}() }()
@ -521,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) {
defer close(done) defer close(done)
go func() { go func() {
for { for {
_, _ = OutgoingHandshake(nil, c1, noVerifyChecker) _, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker)
select { select {
case outRes <- struct{}{}: case outRes <- struct{}{}:
case <-done: case <-done:
@ -531,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) {
}() }()
go func() { go func() {
for { for {
_, _ = IncomingHandshake(nil, c2, noVerifyChecker) _, _ = IncomingHandshake(nil, c2, "", noVerifyChecker)
select { select {
case inRes <- struct{}{}: case inRes <- struct{}{}:
case <-done: case <-done:
@ -551,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) {
type testCredChecker struct { type testCredChecker struct {
makeCred *handshakeproto.Credentials makeCred *handshakeproto.Credentials
checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
checkErr error checkErr error
} }
func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials {
return t.makeCred return t.makeCred
} }
func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) {
if t.checkErr != nil { if t.checkErr != nil {
return nil, t.checkErr return nil, t.checkErr
} }
if t.checkCred != nil { if t.checkCred != nil {
return t.checkCred(sc, cred) return t.checkCred(peerId, cred)
} }
return nil, nil return nil, nil
} }

View File

@ -1,11 +1,9 @@
package handshake package handshake
import ( import (
"context"
"encoding/binary" "encoding/binary"
"errors" "errors"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/libp2p/go-libp2p/core/sec"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"io" "io"
"sync" "sync"
@ -14,8 +12,17 @@ import (
const headerSize = 5 // 1 byte for type + 4 byte for uint32 size const headerSize = 5 // 1 byte for type + 4 byte for uint32 size
const ( const (
msgTypeCred = byte(1) msgTypeCred = byte(1)
msgTypeAck = byte(2) msgTypeAck = byte(2)
msgTypeProto = byte(3)
sizeLimit = 200 * 1024 // 200 Kb
)
var (
credMsgTypes = []byte{msgTypeCred, msgTypeAck}
protoMsgTypes = []byte{msgTypeProto, msgTypeAck}
protoMsgTypesAck = []byte{msgTypeAck}
) )
type HandshakeError struct { type HandshakeError struct {
@ -38,154 +45,26 @@ var (
ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed}
ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected}
ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion}
ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto}
ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")}
ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") ErrGotUnexpectedMessage = errors.New("go not a handshake message")
) )
var handshakePool = &sync.Pool{New: func() any { var handshakePool = &sync.Pool{New: func() any {
return &handshake{ return &handshake{
remoteCred: &handshakeproto.Credentials{}, remoteCred: &handshakeproto.Credentials{},
remoteAck: &handshakeproto.Ack{}, remoteAck: &handshakeproto.Ack{},
localAck: &handshakeproto.Ack{}, localAck: &handshakeproto.Ack{},
buf: make([]byte, 0, 1024), remoteProto: &handshakeproto.Proto{},
buf: make([]byte, 0, 1024),
} }
}} }}
type CredentialChecker interface { type CredentialChecker interface {
MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials MakeCredentials(remotePeerId string) *handshakeproto.Credentials
CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error)
}
func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = outgoingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
localCred := cc.MakeCredentials(sc)
if err = h.writeCredentials(localCred); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error == handshakeproto.Error_Null {
return identity, nil
} else {
_ = h.conn.Close()
return nil, HandshakeError{e: msg.ack.Error}
}
}
func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
identity, err = incomingHandshake(h, sc, cc)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = sc.Close()
return nil, ctx.Err()
}
}
func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) {
defer h.release()
h.conn = sc
msg, err := h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack != nil {
return nil, ErrUnexpectedPayload
}
if identity, err = cc.CheckCredential(sc, msg.cred); err != nil {
h.tryWriteErrAndClose(err)
return
}
if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
msg, err = h.readMsg()
if err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack == nil {
err = ErrUnexpectedPayload
h.tryWriteErrAndClose(err)
return nil, err
}
if msg.ack.Error != handshakeproto.Error_Null {
if msg.ack.Error == handshakeproto.Error_InvalidCredentials {
return nil, ErrPeerDeclinedCredentials
}
return nil, HandshakeError{e: msg.ack.Error}
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return nil, err
}
return
} }
func newHandshake() *handshake { func newHandshake() *handshake {
@ -193,11 +72,12 @@ func newHandshake() *handshake {
} }
type handshake struct { type handshake struct {
conn sec.SecureConn conn io.ReadWriteCloser
remoteCred *handshakeproto.Credentials remoteCred *handshakeproto.Credentials
remoteAck *handshakeproto.Ack remoteProto *handshakeproto.Proto
localAck *handshakeproto.Ack remoteAck *handshakeproto.Ack
buf []byte localAck *handshakeproto.Ack
buf []byte
} }
func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) { func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) {
@ -209,8 +89,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro
return h.writeData(msgTypeCred, n) return h.writeData(msgTypeCred, n)
} }
func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) {
h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize]
n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:])
if err != nil {
return err
}
return h.writeData(msgTypeProto, n)
}
func (h *handshake) tryWriteErrAndClose(err error) { func (h *handshake) tryWriteErrAndClose(err error) {
if err == ErrGotNotAHandshakeMessage { if err == ErrUnexpectedPayload {
// if we got unexpected message - just close the connection // if we got unexpected message - just close the connection
_ = h.conn.Close() _ = h.conn.Close()
return return
@ -243,21 +132,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) {
} }
type message struct { type message struct {
cred *handshakeproto.Credentials cred *handshakeproto.Credentials
ack *handshakeproto.Ack proto *handshakeproto.Proto
ack *handshakeproto.Ack
} }
func (h *handshake) readMsg() (msg message, err error) { func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) {
h.buf = slices.Grow(h.buf, headerSize)[:headerSize] h.buf = slices.Grow(h.buf, headerSize)[:headerSize]
if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil { if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil {
return return
} }
tp := h.buf[0] tp := h.buf[0]
if tp != msgTypeCred && tp != msgTypeAck { if !slices.Contains(allowedTypes, tp) {
err = ErrGotNotAHandshakeMessage err = ErrUnexpectedPayload
return return
} }
size := binary.LittleEndian.Uint32(h.buf[1:headerSize]) size := binary.LittleEndian.Uint32(h.buf[1:headerSize])
if size > sizeLimit {
err = ErrGotUnexpectedMessage
return
}
h.buf = slices.Grow(h.buf, int(size))[:size] h.buf = slices.Grow(h.buf, int(size))[:size]
if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil { if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil {
return return
@ -273,6 +167,11 @@ func (h *handshake) readMsg() (msg message, err error) {
return return
} }
msg.ack = h.remoteAck msg.ack = h.remoteAck
case msgTypeProto:
if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil {
return
}
msg.proto = h.remoteProto
} }
return return
} }
@ -284,5 +183,6 @@ func (h *handshake) release() {
h.remoteAck.Error = 0 h.remoteAck.Error = 0
h.remoteCred.Type = 0 h.remoteCred.Type = 0
h.remoteCred.Payload = h.remoteCred.Payload[:0] h.remoteCred.Payload = h.remoteCred.Payload[:0]
h.remoteProto.Proto = 0
handshakePool.Put(h) handshakePool.Put(h)
} }

View File

@ -59,6 +59,7 @@ const (
Error_SkipVerifyNotAllowed Error = 4 Error_SkipVerifyNotAllowed Error = 4
Error_DeadlineExceeded Error = 5 Error_DeadlineExceeded Error = 5
Error_IncompatibleVersion Error = 6 Error_IncompatibleVersion Error = 6
Error_IncompatibleProto Error = 7
) )
var Error_name = map[int32]string{ var Error_name = map[int32]string{
@ -69,6 +70,7 @@ var Error_name = map[int32]string{
4: "SkipVerifyNotAllowed", 4: "SkipVerifyNotAllowed",
5: "DeadlineExceeded", 5: "DeadlineExceeded",
6: "IncompatibleVersion", 6: "IncompatibleVersion",
7: "IncompatibleProto",
} }
var Error_value = map[string]int32{ var Error_value = map[string]int32{
@ -79,6 +81,7 @@ var Error_value = map[string]int32{
"SkipVerifyNotAllowed": 4, "SkipVerifyNotAllowed": 4,
"DeadlineExceeded": 5, "DeadlineExceeded": 5,
"IncompatibleVersion": 6, "IncompatibleVersion": 6,
"IncompatibleProto": 7,
} }
func (x Error) String() string { func (x Error) String() string {
@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{1} return fileDescriptor_60283fc75f020893, []int{1}
} }
type ProtoType int32
const (
ProtoType_DRPC ProtoType = 0
)
var ProtoType_name = map[int32]string{
0: "DRPC",
}
var ProtoType_value = map[string]int32{
"DRPC": 0,
}
func (x ProtoType) String() string {
return proto.EnumName(ProtoType_name, int32(x))
}
func (ProtoType) EnumDescriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{2}
}
type Credentials struct { type Credentials struct {
Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"`
Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"`
@ -247,12 +272,58 @@ func (m *Ack) GetError() Error {
return Error_Null return Error_Null
} }
type Proto struct {
Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"`
}
func (m *Proto) Reset() { *m = Proto{} }
func (m *Proto) String() string { return proto.CompactTextString(m) }
func (*Proto) ProtoMessage() {}
func (*Proto) Descriptor() ([]byte, []int) {
return fileDescriptor_60283fc75f020893, []int{3}
}
func (m *Proto) XXX_Unmarshal(b []byte) error {
return m.Unmarshal(b)
}
func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
if deterministic {
return xxx_messageInfo_Proto.Marshal(b, m, deterministic)
} else {
b = b[:cap(b)]
n, err := m.MarshalToSizedBuffer(b)
if err != nil {
return nil, err
}
return b[:n], nil
}
}
func (m *Proto) XXX_Merge(src proto.Message) {
xxx_messageInfo_Proto.Merge(m, src)
}
func (m *Proto) XXX_Size() int {
return m.Size()
}
func (m *Proto) XXX_DiscardUnknown() {
xxx_messageInfo_Proto.DiscardUnknown(m)
}
var xxx_messageInfo_Proto proto.InternalMessageInfo
func (m *Proto) GetProto() ProtoType {
if m != nil {
return m.Proto
}
return ProtoType_DRPC
}
func init() { func init() {
proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value) proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value)
proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value) proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value)
proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value)
proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials") proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials")
proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds") proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds")
proto.RegisterType((*Ack)(nil), "anyHandshake.Ack") proto.RegisterType((*Ack)(nil), "anyHandshake.Ack")
proto.RegisterType((*Proto)(nil), "anyHandshake.Proto")
} }
func init() { func init() {
@ -260,32 +331,35 @@ func init() {
} }
var fileDescriptor_60283fc75f020893 = []byte{ var fileDescriptor_60283fc75f020893 = []byte{
// 395 bytes of a gzipped FileDescriptorProto // 439 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31, 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40,
0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53, 0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7,
0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5, 0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5,
0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88, 0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f,
0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65, 0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf,
0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23, 0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe,
0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60, 0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c,
0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d, 0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed,
0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd, 0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef,
0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77, 0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65,
0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba, 0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8,
0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91, 0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46,
0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08, 0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72,
0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d, 0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27,
0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1, 0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05,
0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78, 0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e,
0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f, 0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64,
0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70, 0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04,
0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc, 0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73,
0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80, 0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62,
0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34, 0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81,
0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe, 0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53,
0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96, 0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85,
0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00, 0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3,
0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00, 0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd,
0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c,
0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53,
0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00,
} }
func (m *Credentials) Marshal() (dAtA []byte, err error) { func (m *Credentials) Marshal() (dAtA []byte, err error) {
@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) {
return len(dAtA) - i, nil return len(dAtA) - i, nil
} }
func (m *Proto) Marshal() (dAtA []byte, err error) {
size := m.Size()
dAtA = make([]byte, size)
n, err := m.MarshalToSizedBuffer(dAtA[:size])
if err != nil {
return nil, err
}
return dAtA[:n], nil
}
func (m *Proto) MarshalTo(dAtA []byte) (int, error) {
size := m.Size()
return m.MarshalToSizedBuffer(dAtA[:size])
}
func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) {
i := len(dAtA)
_ = i
var l int
_ = l
if m.Proto != 0 {
i = encodeVarintHandshake(dAtA, i, uint64(m.Proto))
i--
dAtA[i] = 0x8
}
return len(dAtA) - i, nil
}
func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int {
offset -= sovHandshake(v) offset -= sovHandshake(v)
base := offset base := offset
@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) {
return n return n
} }
func (m *Proto) Size() (n int) {
if m == nil {
return 0
}
var l int
_ = l
if m.Proto != 0 {
n += 1 + sovHandshake(uint64(m.Proto))
}
return n
}
func sovHandshake(x uint64) (n int) { func sovHandshake(x uint64) (n int) {
return (math_bits.Len64(x|1) + 6) / 7 return (math_bits.Len64(x|1) + 6) / 7
} }
@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error {
} }
return nil return nil
} }
func (m *Proto) Unmarshal(dAtA []byte) error {
l := len(dAtA)
iNdEx := 0
for iNdEx < l {
preIndex := iNdEx
var wire uint64
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
wire |= uint64(b&0x7F) << shift
if b < 0x80 {
break
}
}
fieldNum := int32(wire >> 3)
wireType := int(wire & 0x7)
if wireType == 4 {
return fmt.Errorf("proto: Proto: wiretype end group for non-group")
}
if fieldNum <= 0 {
return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire)
}
switch fieldNum {
case 1:
if wireType != 0 {
return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType)
}
m.Proto = 0
for shift := uint(0); ; shift += 7 {
if shift >= 64 {
return ErrIntOverflowHandshake
}
if iNdEx >= l {
return io.ErrUnexpectedEOF
}
b := dAtA[iNdEx]
iNdEx++
m.Proto |= ProtoType(b&0x7F) << shift
if b < 0x80 {
break
}
}
default:
iNdEx = preIndex
skippy, err := skipHandshake(dAtA[iNdEx:])
if err != nil {
return err
}
if (skippy < 0) || (iNdEx+skippy) < 0 {
return ErrInvalidLengthHandshake
}
if (iNdEx + skippy) > l {
return io.ErrUnexpectedEOF
}
iNdEx += skippy
}
}
if iNdEx > l {
return io.ErrUnexpectedEOF
}
return nil
}
func skipHandshake(dAtA []byte) (n int, err error) { func skipHandshake(dAtA []byte) (n int, err error) {
l := len(dAtA) l := len(dAtA)
iNdEx := 0 iNdEx := 0

View File

@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto";
/* /*
CREDENTIALS HANDSHAKE
Alice opens a new connection with Bob Alice opens a new connection with Bob
1. TLS handshake done successfully; both sides know local and remote peer identifiers. 1. TLS handshake done successfully; both sides know local and remote peer identifiers.
@ -68,4 +70,20 @@ enum Error {
SkipVerifyNotAllowed = 4; SkipVerifyNotAllowed = 4;
DeadlineExceeded = 5; DeadlineExceeded = 5;
IncompatibleVersion = 6; IncompatibleVersion = 6;
IncompatibleProto = 7;
}
/*
PROTO HANDSHAKE
*/
message Proto {
ProtoType proto = 1;
}
enum ProtoType {
DRPC = 0;
} }

View File

@ -0,0 +1,97 @@
package handshake
import (
"context"
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"golang.org/x/exp/slices"
"net"
)
type ProtoChecker struct {
AllowedProtoTypes []handshakeproto.ProtoType
}
func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
err = outgoingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = conn.Close()
return ctx.Err()
}
}
func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) {
defer h.release()
h.conn = conn
localProto := &handshakeproto.Proto{
Proto: pt,
}
if err = h.writeProto(localProto); err != nil {
h.tryWriteErrAndClose(err)
return
}
msg, err := h.readMsg(msgTypeAck)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if msg.ack.Error == handshakeproto.Error_IncompatibleProto {
return ErrRemoteIncompatibleProto
}
if msg.ack.Error == handshakeproto.Error_Null {
return nil
}
return HandshakeError{e: msg.ack.Error}
}
func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
if ctx == nil {
ctx = context.Background()
}
h := newHandshake()
done := make(chan struct{})
go func() {
defer close(done)
protoType, err = incomingProtoHandshake(h, conn, pt)
}()
select {
case <-done:
return
case <-ctx.Done():
_ = conn.Close()
return 0, ctx.Err()
}
}
func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) {
defer h.release()
h.conn = conn
msg, err := h.readMsg(msgTypeProto)
if err != nil {
h.tryWriteErrAndClose(err)
return
}
if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) {
err = ErrIncompatibleProto
h.tryWriteErrAndClose(err)
return
}
if err = h.writeAck(handshakeproto.Error_Null); err != nil {
h.tryWriteErrAndClose(err)
return 0, err
} else {
return msg.proto.Proto, nil
}
}

View File

@ -0,0 +1,121 @@
package handshake
import (
"github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing"
"time"
)
type protoRes struct {
protoType handshakeproto.ProtoType
err error
}
func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker {
return ProtoChecker{AllowedProtoTypes: types}
}
func TestIncomingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error)
res := <-protoResCh
require.NoError(t, res.err)
assert.Equal(t, handshakeproto.ProtoType(1), res.protoType)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1))
protoResCh <- protoRes{protoType: protoType, err: err}
}()
h := newHandshake()
h.conn = c2
// write desired proto
require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0}))
msg, err := h.readMsg(msgTypeAck)
require.NoError(t, err)
assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error)
res := <-protoResCh
require.Error(t, res.err, ErrIncompatibleProto.Error())
})
}
func TestOutgoingProtoHandshake(t *testing.T) {
t.Run("success", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_Null))
res := <-protoResCh
assert.NoError(t, res.err)
})
t.Run("incompatible", func(t *testing.T) {
c1, c2 := newConnPair(t)
var protoResCh = make(chan protoRes, 1)
go func() {
err := OutgoingProtoHandshake(nil, c1, 1)
protoResCh <- protoRes{err: err}
}()
h := newHandshake()
h.conn = c2
msg, err := h.readMsg(msgTypeProto)
require.NoError(t, err)
assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto)
require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto))
res := <-protoResCh
assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error())
})
}
func TestEndToEndProto(t *testing.T) {
c1, c2 := newConnPair(t)
var (
inResCh = make(chan protoRes, 1)
outResCh = make(chan protoRes, 1)
)
st := time.Now()
go func() {
err := OutgoingProtoHandshake(nil, c1, 0)
outResCh <- protoRes{err: err}
}()
go func() {
protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1))
inResCh <- protoRes{protoType: protoType, err: err}
}()
outRes := <-outResCh
assert.NoError(t, outRes.err)
inRes := <-inResCh
assert.NoError(t, inRes.err)
assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType)
t.Log("dur", time.Since(st))
}

Some files were not shown because too many files have changed in this diff Show More