diff --git a/app/app.go b/app/app.go index a290a82b..89e7ab46 100644 --- a/app/app.go +++ b/app/app.go @@ -55,6 +55,7 @@ type ComponentStatable interface { // App is the central part of the application // It contains and manages all components type App struct { + parent *App components []Component mu sync.RWMutex startStat Stat @@ -109,6 +110,16 @@ func VersionDescription() string { return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState) } +// ChildApp creates a child container which has access to parent's components +// It doesn't call Start on any of the parent's components +func (app *App) ChildApp() *App { + return &App{ + parent: app, + deviceState: app.deviceState, + anySyncVersion: app.AnySyncVersion(), + } +} + // Register adds service to registry // All components will be started in the order they were registered func (app *App) Register(s Component) *App { @@ -128,10 +139,14 @@ func (app *App) Register(s Component) *App { func (app *App) Component(name string) Component { app.mu.RLock() defer app.mu.RUnlock() - for _, s := range app.components { - if s.Name() == name { - return s + current := app + for current != nil { + for _, s := range current.components { + if s.Name() == name { + return s + } } + current = current.parent } return nil } @@ -149,10 +164,14 @@ func (app *App) MustComponent(name string) Component { func MustComponent[i any](app *App) i { app.mu.RLock() defer app.mu.RUnlock() - for _, s := range app.components { - if v, ok := s.(i); ok { - return v + current := app + for current != nil { + for _, s := range current.components { + if v, ok := s.(i); ok { + return v + } } + current = current.parent } empty := new(i) panic(fmt.Errorf("component with interface %T is not found", empty)) @@ -162,9 +181,13 @@ func MustComponent[i any](app *App) i { func (app *App) ComponentNames() (names []string) { app.mu.RLock() defer app.mu.RUnlock() - names = make([]string, len(app.components)) - for i, c := range app.components { - names[i] = c.Name() + names = make([]string, 0, len(app.components)) + current := app + for current != nil { + for _, c := range current.components { + names = append(names, c.Name()) + } + current = current.parent } return } diff --git a/app/app_test.go b/app/app_test.go index cdc52445..0c122b24 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -34,6 +34,25 @@ func TestAppServiceRegistry(t *testing.T) { names := app.ComponentNames() assert.Equal(t, names, []string{"c1", "r1", "s1"}) }) + t.Run("Child MustComponent", func(t *testing.T) { + app := app.ChildApp() + app.Register(newTestService(testTypeComponent, "x1", nil, nil)) + for _, name := range []string{"c1", "r1", "s1", "x1"} { + assert.NotPanics(t, func() { app.MustComponent(name) }, name) + } + assert.Panics(t, func() { app.MustComponent("not-registered") }) + }) + t.Run("Child ComponentNames", func(t *testing.T) { + app := app.ChildApp() + app.Register(newTestService(testTypeComponent, "x1", nil, nil)) + names := app.ComponentNames() + assert.Equal(t, names, []string{"x1", "c1", "r1", "s1"}) + }) + t.Run("Child override", func(t *testing.T) { + app := app.ChildApp() + app.Register(newTestService(testTypeRunnable, "s1", nil, nil)) + _ = app.MustComponent("s1").(*testRunnable) + }) } func TestAppStart(t *testing.T) { diff --git a/app/ocache/metrics.go b/app/ocache/metrics.go index b520dff2..c2dc04c5 100644 --- a/app/ocache/metrics.go +++ b/app/ocache/metrics.go @@ -6,6 +6,9 @@ import ( ) func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Option { + if reg == nil { + return nil + } if subsystem == "" { subsystem = "cache" } @@ -13,9 +16,7 @@ func WithPrometheus(reg *prometheus.Registry, namespace, subsystem string) Optio subSplit := strings.Split(subsystem, ".") namespace = strings.Join(nameSplit, "_") subsystem = strings.Join(subSplit, "_") - if reg == nil { - return nil - } + return func(cache *oCache) { cache.metrics = &metrics{ hit: prometheus.NewCounter(prometheus.CounterOpts{ diff --git a/commonfile/fileproto/file_drpc.pb.go b/commonfile/fileproto/file_drpc.pb.go index 2f9ee69d..a03c22cd 100644 --- a/commonfile/fileproto/file_drpc.pb.go +++ b/commonfile/fileproto/file_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonfile/fileproto/protos/file.proto package fileproto diff --git a/commonspace/commongetter.go b/commonspace/commongetter.go deleted file mode 100644 index 80476a14..00000000 --- a/commonspace/commongetter.go +++ /dev/null @@ -1,62 +0,0 @@ -package commonspace - -import ( - "context" - "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" - "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" - "github.com/anyproto/any-sync/commonspace/object/treemanager" - "sync/atomic" -) - -type commonGetter struct { - treemanager.TreeManager - spaceId string - reservedObjects []syncobjectgetter.SyncObject - spaceIsClosed *atomic.Bool -} - -func newCommonGetter(spaceId string, getter treemanager.TreeManager, spaceIsClosed *atomic.Bool) *commonGetter { - return &commonGetter{ - TreeManager: getter, - spaceId: spaceId, - spaceIsClosed: spaceIsClosed, - } -} - -func (c *commonGetter) AddObject(object syncobjectgetter.SyncObject) { - c.reservedObjects = append(c.reservedObjects, object) -} - -func (c *commonGetter) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) { - if c.spaceIsClosed.Load() { - return nil, ErrSpaceClosed - } - if obj := c.getReservedObject(treeId); obj != nil { - return obj.(objecttree.ObjectTree), nil - } - return c.TreeManager.GetTree(ctx, spaceId, treeId) -} - -func (c *commonGetter) getReservedObject(id string) syncobjectgetter.SyncObject { - for _, obj := range c.reservedObjects { - if obj != nil && obj.Id() == id { - return obj - } - } - return nil -} - -func (c *commonGetter) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) { - if c.spaceIsClosed.Load() { - return nil, ErrSpaceClosed - } - if obj := c.getReservedObject(objectId); obj != nil { - return obj, nil - } - t, err := c.TreeManager.GetTree(ctx, c.spaceId, objectId) - if err != nil { - return - } - obj = t.(syncobjectgetter.SyncObject) - return -} diff --git a/commonspace/config.go b/commonspace/config/config.go similarity index 91% rename from commonspace/config.go rename to commonspace/config/config.go index e5485068..cce4b548 100644 --- a/commonspace/config.go +++ b/commonspace/config/config.go @@ -1,4 +1,4 @@ -package commonspace +package config type ConfigGetter interface { GetSpace() Config diff --git a/commonspace/credentialprovider/credentialprovider.go b/commonspace/credentialprovider/credentialprovider.go index ce570904..5f12b065 100644 --- a/commonspace/credentialprovider/credentialprovider.go +++ b/commonspace/credentialprovider/credentialprovider.go @@ -3,6 +3,7 @@ package credentialprovider import ( "context" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/spacesyncproto" ) @@ -13,12 +14,21 @@ func NewNoOp() CredentialProvider { } type CredentialProvider interface { + app.Component GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) } type noOpProvider struct { } +func (n noOpProvider) Init(a *app.App) (err error) { + return nil +} + +func (n noOpProvider) Name() (name string) { + return CName +} + func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) { return nil, nil } diff --git a/commonspace/credentialprovider/mock_credentialprovider/mock_credentialprovider.go b/commonspace/credentialprovider/mock_credentialprovider/mock_credentialprovider.go index a1f8dd97..77faf28e 100644 --- a/commonspace/credentialprovider/mock_credentialprovider/mock_credentialprovider.go +++ b/commonspace/credentialprovider/mock_credentialprovider/mock_credentialprovider.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + app "github.com/anyproto/any-sync/app" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" gomock "github.com/golang/mock/gomock" ) @@ -49,3 +50,31 @@ func (mr *MockCredentialProviderMockRecorder) GetCredential(arg0, arg1 interface mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1) } + +// Init mocks base method. +func (m *MockCredentialProvider) Init(arg0 *app.App) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init. +func (mr *MockCredentialProviderMockRecorder) Init(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockCredentialProvider)(nil).Init), arg0) +} + +// Name mocks base method. +func (m *MockCredentialProvider) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockCredentialProviderMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCredentialProvider)(nil).Name)) +} diff --git a/commonspace/deletion_test.go b/commonspace/deletion_test.go index 64d6fa11..33745884 100644 --- a/commonspace/deletion_test.go +++ b/commonspace/deletion_test.go @@ -73,13 +73,14 @@ func TestSpaceDeleteIds(t *testing.T) { fx.treeManager.space = spc err = spc.Init(ctx) require.NoError(t, err) - + close(fx.treeManager.waitLoad) + var ids []string for i := 0; i < totalObjs; i++ { // creating a tree bytes := make([]byte, 32) rand.Read(bytes) - doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ + doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ PrivKey: acc.SignKey, ChangeType: "some", SpaceId: spc.Id(), @@ -88,7 +89,7 @@ func TestSpaceDeleteIds(t *testing.T) { Timestamp: time.Now().Unix(), }) require.NoError(t, err) - tr, err := spc.PutTree(ctx, doc, nil) + tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil) require.NoError(t, err) ids = append(ids, tr.Id()) tr.Close() @@ -106,7 +107,7 @@ func TestSpaceDeleteIds(t *testing.T) { func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string { bytes := make([]byte, 32) rand.Read(bytes) - doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ + doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ PrivKey: acc.SignKey, ChangeType: "some", SpaceId: spc.Id(), @@ -115,7 +116,7 @@ func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.A Timestamp: time.Now().Unix(), }) require.NoError(t, err) - tr, err := spc.PutTree(ctx, doc, nil) + tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil) require.NoError(t, err) tr.Close() return tr.Id() @@ -147,9 +148,10 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) { // adding space to tree manager fx.treeManager.space = spc err = spc.Init(ctx) + close(fx.treeManager.waitLoad) require.NoError(t, err) - settingsObject := spc.(*space).settingsObject + settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject() var ids []string for i := 0; i < totalObjs; i++ { id := createTree(t, ctx, spc, acc) @@ -183,17 +185,19 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) { spc, err = fx.spaceService.NewSpace(ctx, sp) require.NoError(t, err) require.NotNil(t, spc) + fx.treeManager.waitLoad = make(chan struct{}) fx.treeManager.space = spc fx.treeManager.deletedIds = nil err = spc.Init(ctx) require.NoError(t, err) + close(fx.treeManager.waitLoad) // waiting until everything is deleted time.Sleep(3 * time.Second) require.Equal(t, len(ids), len(fx.treeManager.deletedIds)) // checking that new snapshot will contain all the changes - settingsObject = spc.(*space).settingsObject + settingsObject = spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject() settings.DoSnapshot = func(treeLen int) bool { return true } @@ -230,8 +234,9 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) { fx.treeManager.space = spc err = spc.Init(ctx) require.NoError(t, err) + close(fx.treeManager.waitLoad) - settingsObject := spc.(*space).settingsObject + settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject() var ids []string for i := 0; i < totalObjs; i++ { id := createTree(t, ctx, spc, acc) @@ -259,10 +264,12 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) { require.NoError(t, err) require.NotNil(t, spc) fx.treeManager.space = spc + fx.treeManager.waitLoad = make(chan struct{}) fx.treeManager.deletedIds = nil fx.treeManager.markedIds = nil err = spc.Init(ctx) require.NoError(t, err) + close(fx.treeManager.waitLoad) // waiting until everything is deleted time.Sleep(3 * time.Second) diff --git a/commonspace/settings/settingsstate/deletionstate.go b/commonspace/deletionstate/deletionstate.go similarity index 83% rename from commonspace/settings/settingsstate/deletionstate.go rename to commonspace/deletionstate/deletionstate.go index f36f4fd0..f7d1d9a7 100644 --- a/commonspace/settings/settingsstate/deletionstate.go +++ b/commonspace/deletionstate/deletionstate.go @@ -1,16 +1,22 @@ -//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate ObjectDeletionState,StateBuilder,ChangeFactory -package settingsstate +//go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anyproto/any-sync/commonspace/deletionstate ObjectDeletionState +package deletionstate import ( + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/spacestorage" "go.uber.org/zap" "sync" ) +var log = logger.NewNamed(CName) + +const CName = "common.commonspace.deletionstate" + type StateUpdateObserver func(ids []string) type ObjectDeletionState interface { + app.Component AddObserver(observer StateUpdateObserver) Add(ids map[string]struct{}) GetQueued() (ids []string) @@ -28,12 +34,20 @@ type objectDeletionState struct { storage spacestorage.SpaceStorage } -func NewObjectDeletionState(log logger.CtxLogger, storage spacestorage.SpaceStorage) ObjectDeletionState { +func (st *objectDeletionState) Init(a *app.App) (err error) { + st.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + return nil +} + +func (st *objectDeletionState) Name() (name string) { + return CName +} + +func New() ObjectDeletionState { return &objectDeletionState{ log: log, queued: map[string]struct{}{}, deleted: map[string]struct{}{}, - storage: storage, } } diff --git a/commonspace/settings/settingsstate/deletionstate_test.go b/commonspace/deletionstate/deletionstate_test.go similarity index 95% rename from commonspace/settings/settingsstate/deletionstate_test.go rename to commonspace/deletionstate/deletionstate_test.go index ca2ea679..e5489bd8 100644 --- a/commonspace/settings/settingsstate/deletionstate_test.go +++ b/commonspace/deletionstate/deletionstate_test.go @@ -1,7 +1,6 @@ -package settingsstate +package deletionstate import ( - "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/golang/mock/gomock" @@ -19,7 +18,8 @@ type fixture struct { func newFixture(t *testing.T) *fixture { ctrl := gomock.NewController(t) spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl) - delState := NewObjectDeletionState(logger.NewNamed("test"), spaceStorage).(*objectDeletionState) + delState := New().(*objectDeletionState) + delState.storage = spaceStorage return &fixture{ ctrl: ctrl, delState: delState, diff --git a/commonspace/deletionstate/mock_deletionstate/mock_deletionstate.go b/commonspace/deletionstate/mock_deletionstate/mock_deletionstate.go new file mode 100644 index 00000000..c4e9fefb --- /dev/null +++ b/commonspace/deletionstate/mock_deletionstate/mock_deletionstate.go @@ -0,0 +1,144 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/commonspace/deletionstate (interfaces: ObjectDeletionState) + +// Package mock_deletionstate is a generated GoMock package. +package mock_deletionstate + +import ( + reflect "reflect" + + app "github.com/anyproto/any-sync/app" + deletionstate "github.com/anyproto/any-sync/commonspace/deletionstate" + gomock "github.com/golang/mock/gomock" +) + +// MockObjectDeletionState is a mock of ObjectDeletionState interface. +type MockObjectDeletionState struct { + ctrl *gomock.Controller + recorder *MockObjectDeletionStateMockRecorder +} + +// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState. +type MockObjectDeletionStateMockRecorder struct { + mock *MockObjectDeletionState +} + +// NewMockObjectDeletionState creates a new mock instance. +func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState { + mock := &MockObjectDeletionState{ctrl: ctrl} + mock.recorder = &MockObjectDeletionStateMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder { + return m.recorder +} + +// Add mocks base method. +func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Add", arg0) +} + +// Add indicates an expected call of Add. +func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0) +} + +// AddObserver mocks base method. +func (m *MockObjectDeletionState) AddObserver(arg0 deletionstate.StateUpdateObserver) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddObserver", arg0) +} + +// AddObserver indicates an expected call of AddObserver. +func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0) +} + +// Delete mocks base method. +func (m *MockObjectDeletionState) Delete(arg0 string) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Delete", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Delete indicates an expected call of Delete. +func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0) +} + +// Exists mocks base method. +func (m *MockObjectDeletionState) Exists(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Exists", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Exists indicates an expected call of Exists. +func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0) +} + +// Filter mocks base method. +func (m *MockObjectDeletionState) Filter(arg0 []string) []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Filter", arg0) + ret0, _ := ret[0].([]string) + return ret0 +} + +// Filter indicates an expected call of Filter. +func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0) +} + +// GetQueued mocks base method. +func (m *MockObjectDeletionState) GetQueued() []string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetQueued") + ret0, _ := ret[0].([]string) + return ret0 +} + +// GetQueued indicates an expected call of GetQueued. +func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued)) +} + +// Init mocks base method. +func (m *MockObjectDeletionState) Init(arg0 *app.App) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init. +func (mr *MockObjectDeletionStateMockRecorder) Init(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectDeletionState)(nil).Init), arg0) +} + +// Name mocks base method. +func (m *MockObjectDeletionState) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockObjectDeletionStateMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectDeletionState)(nil).Name)) +} diff --git a/commonspace/headsync/diffsyncer.go b/commonspace/headsync/diffsyncer.go index 8b59c743..a0539681 100644 --- a/commonspace/headsync/diffsyncer.go +++ b/commonspace/headsync/diffsyncer.go @@ -6,9 +6,9 @@ import ( "github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/credentialprovider" + "github.com/anyproto/any-sync/commonspace/deletionstate" "github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/peermanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" @@ -22,30 +22,22 @@ type DiffSyncer interface { Sync(ctx context.Context) error RemoveObjects(ids []string) UpdateHeads(id string, heads []string) - Init(deletionState settingsstate.ObjectDeletionState) + Init() Close() error } -func newDiffSyncer( - spaceId string, - diff ldiff.Diff, - peerManager peermanager.PeerManager, - cache treemanager.TreeManager, - storage spacestorage.SpaceStorage, - clientFactory spacesyncproto.ClientFactory, - syncStatus syncstatus.StatusUpdater, - credentialProvider credentialprovider.CredentialProvider, - log logger.CtxLogger) DiffSyncer { +func newDiffSyncer(hs *headSync) DiffSyncer { return &diffSyncer{ - diff: diff, - spaceId: spaceId, - treeManager: cache, - storage: storage, - peerManager: peerManager, - clientFactory: clientFactory, - credentialProvider: credentialProvider, + diff: hs.diff, + spaceId: hs.spaceId, + treeManager: hs.treeManager, + storage: hs.storage, + peerManager: hs.peerManager, + clientFactory: spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient), + credentialProvider: hs.credentialProvider, log: log, - syncStatus: syncStatus, + syncStatus: hs.syncStatus, + deletionState: hs.deletionState, } } @@ -57,14 +49,13 @@ type diffSyncer struct { storage spacestorage.SpaceStorage clientFactory spacesyncproto.ClientFactory log logger.CtxLogger - deletionState settingsstate.ObjectDeletionState + deletionState deletionstate.ObjectDeletionState credentialProvider credentialprovider.CredentialProvider syncStatus syncstatus.StatusUpdater treeSyncer treemanager.TreeSyncer } -func (d *diffSyncer) Init(deletionState settingsstate.ObjectDeletionState) { - d.deletionState = deletionState +func (d *diffSyncer) Init() { d.deletionState.AddObserver(d.RemoveObjects) d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager) } @@ -115,8 +106,14 @@ func (d *diffSyncer) Sync(ctx context.Context) error { func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) { ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id())) + conn, err := p.AcquireDrpcConn(ctx) + if err != nil { + return + } + defer p.ReleaseDrpcConn(conn) + var ( - cl = d.clientFactory.Client(p) + cl = d.clientFactory.Client(conn) rdiff = NewRemoteDiff(d.spaceId, cl) stateCounter = d.syncStatus.StateCounter() ) diff --git a/commonspace/headsync/diffsyncer_test.go b/commonspace/headsync/diffsyncer_test.go index fb1ad49c..7cdb870a 100644 --- a/commonspace/headsync/diffsyncer_test.go +++ b/commonspace/headsync/diffsyncer_test.go @@ -5,23 +5,13 @@ import ( "context" "fmt" "github.com/anyproto/any-sync/app/ldiff" - "github.com/anyproto/any-sync/app/ldiff/mock_ldiff" - "github.com/anyproto/any-sync/app/logger" - "github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" - mock_treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" - "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" - "github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" - "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" + "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" - "github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto" - "github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/net/peer" "github.com/golang/mock/gomock" - "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/require" "storj.io/drpc" "testing" @@ -36,60 +26,6 @@ type pushSpaceRequestMatcher struct { spaceHeader *spacesyncproto.RawSpaceHeaderWithId } -func (p pushSpaceRequestMatcher) Matches(x interface{}) bool { - res, ok := x.(*spacesyncproto.SpacePushRequest) - if !ok { - return false - } - - return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential) -} - -func (p pushSpaceRequestMatcher) String() string { - return "" -} - -type mockPeer struct{} - -func (m mockPeer) Addr() string { - return "" -} - -func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) { - return true, m.Close() -} - -func (m mockPeer) Id() string { - return "mockId" -} - -func (m mockPeer) LastUsage() time.Time { - return time.Time{} -} - -func (m mockPeer) Secure() sec.SecureConn { - return nil -} - -func (m mockPeer) UpdateLastUsage() { -} - -func (m mockPeer) Close() error { - return nil -} - -func (m mockPeer) Closed() <-chan struct{} { - return make(chan struct{}) -} - -func (m mockPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { - return nil -} - -func (m mockPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { - return nil, nil -} - func newPushSpaceRequestMatcher( spaceId string, aclRootId string, @@ -105,80 +41,134 @@ func newPushSpaceRequestMatcher( } } -func TestDiffSyncer_Sync(t *testing.T) { - // setup - ctx := context.Background() - ctrl := gomock.NewController(t) - defer ctrl.Finish() +func (p pushSpaceRequestMatcher) Matches(x interface{}) bool { + res, ok := x.(*spacesyncproto.SpacePushRequest) + if !ok { + return false + } - diffMock := mock_ldiff.NewMockDiff(ctrl) - peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl) - cacheMock := mock_treemanager.NewMockTreeManager(ctrl) - stMock := mock_spacestorage.NewMockSpaceStorage(ctrl) - clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) - factory := spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { - return clientMock + return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential) +} + +func (p pushSpaceRequestMatcher) String() string { + return "" +} + +type mockPeer struct { +} + +func (m mockPeer) Id() string { + return "peerId" +} + +func (m mockPeer) Context() context.Context { + return context.Background() +} + +func (m mockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { + return nil, nil +} + +func (m mockPeer) ReleaseDrpcConn(conn drpc.Conn) { + return +} + +func (m mockPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error { + return nil +} + +func (m mockPeer) IsClosed() bool { + return false +} + +func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) { + return false, err +} + +func (m mockPeer) Close() (err error) { + return nil +} + +func (fx *headSyncFixture) initDiffSyncer(t *testing.T) { + fx.init(t) + fx.diffSyncer = newDiffSyncer(fx.headSync).(*diffSyncer) + fx.diffSyncer.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { + return fx.clientMock }) - treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl) - credentialProvider := mock_credentialprovider.NewMockCredentialProvider(ctrl) - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) - spaceId := "spaceId" - aclRootId := "aclRootId" - l := logger.NewNamed(spaceId) - diffSyncer := newDiffSyncer(spaceId, diffMock, peerManagerMock, cacheMock, stMock, factory, syncstatus.NewNoOpSyncStatus(), credentialProvider, l) - delState.EXPECT().AddObserver(gomock.Any()) - cacheMock.EXPECT().NewTreeSyncer(spaceId, gomock.Any()).Return(treeSyncerMock) - diffSyncer.Init(delState) + fx.deletionStateMock.EXPECT().AddObserver(gomock.Any()) + fx.treeManagerMock.EXPECT().NewTreeSyncer(fx.spaceState.SpaceId, fx.treeManagerMock).Return(fx.treeSyncerMock) + fx.diffSyncer.Init() +} + +func TestDiffSyncer(t *testing.T) { + ctx := context.Background() t.Run("diff syncer sync", func(t *testing.T) { + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() mPeer := mockPeer{} - peerManagerMock.EXPECT(). + fx.peerManagerMock.EXPECT(). GetResponsiblePeers(gomock.Any()). Return([]peer.Peer{mPeer}, nil) - diffMock.EXPECT(). - Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). + fx.diffMock.EXPECT(). + Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))). Return([]string{"new"}, []string{"changed"}, nil, nil) - delState.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1) - delState.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1) - delState.EXPECT().Filter(nil).Return(nil).Times(1) - treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil) - require.NoError(t, diffSyncer.Sync(ctx)) + fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1) + fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1) + fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1) + fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil) + require.NoError(t, fx.diffSyncer.Sync(ctx)) }) t.Run("diff syncer sync conf error", func(t *testing.T) { - peerManagerMock.EXPECT(). + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() + ctx := context.Background() + fx.peerManagerMock.EXPECT(). GetResponsiblePeers(gomock.Any()). Return(nil, fmt.Errorf("some error")) - require.Error(t, diffSyncer.Sync(ctx)) + require.Error(t, fx.diffSyncer.Sync(ctx)) }) t.Run("deletion state remove objects", func(t *testing.T) { + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() deletedId := "id" - delState.EXPECT().Exists(deletedId).Return(true) + fx.deletionStateMock.EXPECT().Exists(deletedId).Return(true) // this should not result in any mock being called - diffSyncer.UpdateHeads(deletedId, []string{"someHead"}) + fx.diffSyncer.UpdateHeads(deletedId, []string{"someHead"}) }) t.Run("update heads updates diff", func(t *testing.T) { + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() newId := "newId" newHeads := []string{"h1", "h2"} hash := "hash" - diffMock.EXPECT().Set(ldiff.Element{ + fx.diffMock.EXPECT().Set(ldiff.Element{ Id: newId, Head: concatStrings(newHeads), }) - diffMock.EXPECT().Hash().Return(hash) - delState.EXPECT().Exists(newId).Return(false) - stMock.EXPECT().WriteSpaceHash(hash) - diffSyncer.UpdateHeads(newId, newHeads) + fx.diffMock.EXPECT().Hash().Return(hash) + fx.deletionStateMock.EXPECT().Exists(newId).Return(false) + fx.storageMock.EXPECT().WriteSpaceHash(hash) + fx.diffSyncer.UpdateHeads(newId, newHeads) }) t.Run("diff syncer sync space missing", func(t *testing.T) { - aclStorageMock := mock_liststorage.NewMockListStorage(ctrl) - settingsStorage := mock_treestorage.NewMockTreeStorage(ctrl) + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() + aclStorageMock := mock_liststorage.NewMockListStorage(fx.ctrl) + settingsStorage := mock_treestorage.NewMockTreeStorage(fx.ctrl) settingsId := "settingsId" + aclRootId := "aclRootId" aclRoot := &aclrecordproto.RawAclRecordWithId{ Id: aclRootId, } @@ -189,55 +179,61 @@ func TestDiffSyncer_Sync(t *testing.T) { spaceSettingsId := "spaceSettingsId" credential := []byte("credential") - peerManagerMock.EXPECT(). + fx.peerManagerMock.EXPECT(). GetResponsiblePeers(gomock.Any()). Return([]peer.Peer{mockPeer{}}, nil) - diffMock.EXPECT(). - Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). + fx.diffMock.EXPECT(). + Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))). Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing) - stMock.EXPECT().AclStorage().Return(aclStorageMock, nil) - stMock.EXPECT().SpaceHeader().Return(spaceHeader, nil) - stMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId) - stMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil) + fx.storageMock.EXPECT().AclStorage().Return(aclStorageMock, nil) + fx.storageMock.EXPECT().SpaceHeader().Return(spaceHeader, nil) + fx.storageMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId) + fx.storageMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil) settingsStorage.EXPECT().Root().Return(settingsRoot, nil) aclStorageMock.EXPECT(). Root(). Return(aclRoot, nil) - credentialProvider.EXPECT(). + fx.credentialProviderMock.EXPECT(). GetCredential(gomock.Any(), spaceHeader). Return(credential, nil) - clientMock.EXPECT(). - SpacePush(gomock.Any(), newPushSpaceRequestMatcher(spaceId, aclRootId, settingsId, credential, spaceHeader)). + fx.clientMock.EXPECT(). + SpacePush(gomock.Any(), newPushSpaceRequestMatcher(fx.spaceState.SpaceId, aclRootId, settingsId, credential, spaceHeader)). Return(nil, nil) - peerManagerMock.EXPECT().SendPeer(gomock.Any(), "mockId", gomock.Any()) + fx.peerManagerMock.EXPECT().SendPeer(gomock.Any(), "peerId", gomock.Any()) - require.NoError(t, diffSyncer.Sync(ctx)) + require.NoError(t, fx.diffSyncer.Sync(ctx)) }) t.Run("diff syncer sync unexpected", func(t *testing.T) { - peerManagerMock.EXPECT(). + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() + fx.peerManagerMock.EXPECT(). GetResponsiblePeers(gomock.Any()). Return([]peer.Peer{mockPeer{}}, nil) - diffMock.EXPECT(). - Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). + fx.diffMock.EXPECT(). + Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))). Return(nil, nil, nil, spacesyncproto.ErrUnexpected) - require.NoError(t, diffSyncer.Sync(ctx)) + require.NoError(t, fx.diffSyncer.Sync(ctx)) }) t.Run("diff syncer sync space is deleted error", func(t *testing.T) { + fx := newHeadSyncFixture(t) + fx.initDiffSyncer(t) + defer fx.stop() mPeer := mockPeer{} - peerManagerMock.EXPECT(). + fx.peerManagerMock.EXPECT(). GetResponsiblePeers(gomock.Any()). Return([]peer.Peer{mPeer}, nil) - diffMock.EXPECT(). - Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). + fx.diffMock.EXPECT(). + Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))). Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted) - stMock.EXPECT().SpaceSettingsId().Return("settingsId") - treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil) + fx.storageMock.EXPECT().SpaceSettingsId().Return("settingsId") + fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil) - require.NoError(t, diffSyncer.Sync(ctx)) + require.NoError(t, fx.diffSyncer.Sync(ctx)) }) } diff --git a/commonspace/headsync/headsync.go b/commonspace/headsync/headsync.go index ddca3e50..18ed7357 100644 --- a/commonspace/headsync/headsync.go +++ b/commonspace/headsync/headsync.go @@ -3,123 +3,145 @@ package headsync import ( "context" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/logger" + config2 "github.com/anyproto/any-sync/commonspace/config" "github.com/anyproto/any-sync/commonspace/credentialprovider" + "github.com/anyproto/any-sync/commonspace/deletionstate" "github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/peermanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate" + "github.com/anyproto/any-sync/commonspace/spacestate" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/util/periodicsync" + "github.com/anyproto/any-sync/util/slice" "go.uber.org/zap" "golang.org/x/exp/slices" - "strings" "sync/atomic" "time" ) +var log = logger.NewNamed(CName) + +const CName = "common.commonspace.headsync" + type TreeHeads struct { Id string Heads []string } type HeadSync interface { - Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) - + app.ComponentRunnable + ExternalIds() []string + DebugAllHeads() (res []TreeHeads) + AllIds() []string UpdateHeads(id string, heads []string) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) RemoveObjects(ids []string) - AllIds() []string - DebugAllHeads() (res []TreeHeads) - - Close() (err error) } type headSync struct { spaceId string - periodicSync periodicsync.PeriodicSync - storage spacestorage.SpaceStorage - diff ldiff.Diff - log logger.CtxLogger - syncer DiffSyncer - configuration nodeconf.NodeConf spaceIsDeleted *atomic.Bool + syncPeriod int - syncPeriod int + periodicSync periodicsync.PeriodicSync + storage spacestorage.SpaceStorage + diff ldiff.Diff + log logger.CtxLogger + syncer DiffSyncer + configuration nodeconf.NodeConf + peerManager peermanager.PeerManager + treeManager treemanager.TreeManager + credentialProvider credentialprovider.CredentialProvider + syncStatus syncstatus.StatusService + deletionState deletionstate.ObjectDeletionState } -func NewHeadSync( - spaceId string, - spaceIsDeleted *atomic.Bool, - syncPeriod int, - configuration nodeconf.NodeConf, - storage spacestorage.SpaceStorage, - peerManager peermanager.PeerManager, - cache treemanager.TreeManager, - syncStatus syncstatus.StatusUpdater, - credentialProvider credentialprovider.CredentialProvider, - log logger.CtxLogger) HeadSync { +func New() HeadSync { + return &headSync{} +} - diff := ldiff.New(16, 16) - l := log.With(zap.String("spaceId", spaceId)) - factory := spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient) - syncer := newDiffSyncer(spaceId, diff, peerManager, cache, storage, factory, syncStatus, credentialProvider, l) +var createDiffSyncer = newDiffSyncer + +func (h *headSync) Init(a *app.App) (err error) { + shared := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + cfg := a.MustComponent("config").(config2.ConfigGetter) + h.spaceId = shared.SpaceId + h.spaceIsDeleted = shared.SpaceIsDeleted + h.syncPeriod = cfg.GetSpace().SyncPeriod + h.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + h.log = log.With(zap.String("spaceId", h.spaceId)) + h.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + h.diff = ldiff.New(16, 16) + h.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager) + h.credentialProvider = a.MustComponent(credentialprovider.CName).(credentialprovider.CredentialProvider) + h.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusService) + h.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager) + h.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState) + h.syncer = createDiffSyncer(h) sync := func(ctx context.Context) (err error) { // for clients cancelling the sync process - if spaceIsDeleted.Load() && !configuration.IsResponsible(spaceId) { + if h.spaceIsDeleted.Load() && !h.configuration.IsResponsible(h.spaceId) { return spacesyncproto.ErrSpaceIsDeleted } - return syncer.Sync(ctx) - } - periodicSync := periodicsync.NewPeriodicSync(syncPeriod, time.Minute, sync, l) - - return &headSync{ - spaceId: spaceId, - storage: storage, - syncer: syncer, - periodicSync: periodicSync, - diff: diff, - log: log, - syncPeriod: syncPeriod, - configuration: configuration, - spaceIsDeleted: spaceIsDeleted, + return h.syncer.Sync(ctx) } + h.periodicSync = periodicsync.NewPeriodicSync(h.syncPeriod, time.Minute, sync, h.log) + // TODO: move to run? + h.syncer.Init() + return nil } -func (d *headSync) Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) { - d.fillDiff(objectIds) - d.syncer.Init(deletionState) - d.periodicSync.Run() +func (h *headSync) Name() (name string) { + return CName } -func (d *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { - if d.spaceIsDeleted.Load() { +func (h *headSync) Run(ctx context.Context) (err error) { + initialIds, err := h.storage.StoredIds() + if err != nil { + return + } + h.fillDiff(initialIds) + h.periodicSync.Run() + return +} + +func (h *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { + if h.spaceIsDeleted.Load() { peerId, err := peer.CtxPeerId(ctx) if err != nil { return nil, err } // stop receiving all request for sync from clients - if !slices.Contains(d.configuration.NodeIds(d.spaceId), peerId) { + if !slices.Contains(h.configuration.NodeIds(h.spaceId), peerId) { return nil, spacesyncproto.ErrSpaceIsDeleted } } - return HandleRangeRequest(ctx, d.diff, req) + return HandleRangeRequest(ctx, h.diff, req) } -func (d *headSync) UpdateHeads(id string, heads []string) { - d.syncer.UpdateHeads(id, heads) +func (h *headSync) UpdateHeads(id string, heads []string) { + h.syncer.UpdateHeads(id, heads) } -func (d *headSync) AllIds() []string { - return d.diff.Ids() +func (h *headSync) AllIds() []string { + return h.diff.Ids() } -func (d *headSync) DebugAllHeads() (res []TreeHeads) { - els := d.diff.Elements() +func (h *headSync) ExternalIds() []string { + settingsId := h.storage.SpaceSettingsId() + return slice.DiscardFromSlice(h.AllIds(), func(id string) bool { + return id == settingsId + }) +} + +func (h *headSync) DebugAllHeads() (res []TreeHeads) { + els := h.diff.Elements() for _, el := range els { idHead := TreeHeads{ Id: el.Id, @@ -130,19 +152,19 @@ func (d *headSync) DebugAllHeads() (res []TreeHeads) { return } -func (d *headSync) RemoveObjects(ids []string) { - d.syncer.RemoveObjects(ids) +func (h *headSync) RemoveObjects(ids []string) { + h.syncer.RemoveObjects(ids) } -func (d *headSync) Close() (err error) { - d.periodicSync.Close() - return d.syncer.Close() +func (h *headSync) Close(ctx context.Context) (err error) { + h.periodicSync.Close() + return h.syncer.Close() } -func (d *headSync) fillDiff(objectIds []string) { +func (h *headSync) fillDiff(objectIds []string) { var els = make([]ldiff.Element, 0, len(objectIds)) for _, id := range objectIds { - st, err := d.storage.TreeStorage(id) + st, err := h.storage.TreeStorage(id) if err != nil { continue } @@ -155,32 +177,8 @@ func (d *headSync) fillDiff(objectIds []string) { Head: concatStrings(heads), }) } - d.diff.Set(els...) - if err := d.storage.WriteSpaceHash(d.diff.Hash()); err != nil { - d.log.Error("can't write space hash", zap.Error(err)) + h.diff.Set(els...) + if err := h.storage.WriteSpaceHash(h.diff.Hash()); err != nil { + h.log.Error("can't write space hash", zap.Error(err)) } } - -func concatStrings(strs []string) string { - var ( - b strings.Builder - totalLen int - ) - for _, s := range strs { - totalLen += len(s) - } - - b.Grow(totalLen) - for _, s := range strs { - b.WriteString(s) - } - return b.String() -} - -func splitString(str string) (res []string) { - const cidLen = 59 - for i := 0; i < len(str); i += cidLen { - res = append(res, str[i:i+cidLen]) - } - return -} diff --git a/commonspace/headsync/headsync_test.go b/commonspace/headsync/headsync_test.go index ae2c419a..4f14d084 100644 --- a/commonspace/headsync/headsync_test.go +++ b/commonspace/headsync/headsync_test.go @@ -1,71 +1,179 @@ package headsync import ( + "context" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff/mock_ldiff" - "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/config" + "github.com/anyproto/any-sync/commonspace/credentialprovider" + "github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider" + "github.com/anyproto/any-sync/commonspace/deletionstate" + "github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate" "github.com/anyproto/any-sync/commonspace/headsync/mock_headsync" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" + "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" + "github.com/anyproto/any-sync/commonspace/peermanager" + "github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager" + "github.com/anyproto/any-sync/commonspace/spacestate" + "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" - "github.com/anyproto/any-sync/util/periodicsync/mock_periodicsync" + "github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto" + "github.com/anyproto/any-sync/commonspace/syncstatus" + "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/nodeconf/mock_nodeconf" "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "sync/atomic" "testing" ) -func TestDiffService(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() +type mockConfig struct { +} - spaceId := "spaceId" - l := logger.NewNamed("sync") - pSyncMock := mock_periodicsync.NewMockPeriodicSync(ctrl) - storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl) - treeStorageMock := mock_treestorage.NewMockTreeStorage(ctrl) - diffMock := mock_ldiff.NewMockDiff(ctrl) - syncer := mock_headsync.NewMockDiffSyncer(ctrl) - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) - syncPeriod := 1 - initId := "initId" +func (m mockConfig) Init(a *app.App) (err error) { + return nil +} - service := &headSync{ - spaceId: spaceId, - storage: storageMock, - periodicSync: pSyncMock, - syncer: syncer, - diff: diffMock, - log: l, - syncPeriod: syncPeriod, +func (m mockConfig) Name() (name string) { + return "config" +} + +func (m mockConfig) GetSpace() config.Config { + return config.Config{} +} + +type headSyncFixture struct { + spaceState *spacestate.SpaceState + ctrl *gomock.Controller + app *app.App + + configurationMock *mock_nodeconf.MockService + storageMock *mock_spacestorage.MockSpaceStorage + peerManagerMock *mock_peermanager.MockPeerManager + credentialProviderMock *mock_credentialprovider.MockCredentialProvider + syncStatus syncstatus.StatusService + treeManagerMock *mock_treemanager.MockTreeManager + deletionStateMock *mock_deletionstate.MockObjectDeletionState + diffSyncerMock *mock_headsync.MockDiffSyncer + treeSyncerMock *mock_treemanager.MockTreeSyncer + diffMock *mock_ldiff.MockDiff + clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient + headSync *headSync + diffSyncer *diffSyncer +} + +func newHeadSyncFixture(t *testing.T) *headSyncFixture { + spaceState := &spacestate.SpaceState{ + SpaceId: "spaceId", + SpaceIsDeleted: &atomic.Bool{}, } + ctrl := gomock.NewController(t) + configurationMock := mock_nodeconf.NewMockService(ctrl) + configurationMock.EXPECT().Name().AnyTimes().Return(nodeconf.CName) + storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl) + storageMock.EXPECT().Name().AnyTimes().Return(spacestorage.CName) + peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl) + peerManagerMock.EXPECT().Name().AnyTimes().Return(peermanager.CName) + credentialProviderMock := mock_credentialprovider.NewMockCredentialProvider(ctrl) + credentialProviderMock.EXPECT().Name().AnyTimes().Return(credentialprovider.CName) + syncStatus := syncstatus.NewNoOpSyncStatus() + treeManagerMock := mock_treemanager.NewMockTreeManager(ctrl) + treeManagerMock.EXPECT().Name().AnyTimes().Return(treemanager.CName) + deletionStateMock := mock_deletionstate.NewMockObjectDeletionState(ctrl) + deletionStateMock.EXPECT().Name().AnyTimes().Return(deletionstate.CName) + diffSyncerMock := mock_headsync.NewMockDiffSyncer(ctrl) + treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl) + diffMock := mock_ldiff.NewMockDiff(ctrl) + clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) + hs := &headSync{} + a := &app.App{} + a.Register(spaceState). + Register(mockConfig{}). + Register(configurationMock). + Register(storageMock). + Register(peerManagerMock). + Register(credentialProviderMock). + Register(syncStatus). + Register(treeManagerMock). + Register(deletionStateMock). + Register(hs) + return &headSyncFixture{ + spaceState: spaceState, + ctrl: ctrl, + app: a, + configurationMock: configurationMock, + storageMock: storageMock, + peerManagerMock: peerManagerMock, + credentialProviderMock: credentialProviderMock, + syncStatus: syncStatus, + treeManagerMock: treeManagerMock, + deletionStateMock: deletionStateMock, + headSync: hs, + diffSyncerMock: diffSyncerMock, + treeSyncerMock: treeSyncerMock, + diffMock: diffMock, + clientMock: clientMock, + } +} - t.Run("init", func(t *testing.T) { - storageMock.EXPECT().TreeStorage(initId).Return(treeStorageMock, nil) - treeStorageMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil) - syncer.EXPECT().Init(delState) - diffMock.EXPECT().Set(ldiff.Element{ - Id: initId, +func (fx *headSyncFixture) init(t *testing.T) { + createDiffSyncer = func(hs *headSync) DiffSyncer { + return fx.diffSyncerMock + } + fx.diffSyncerMock.EXPECT().Init() + err := fx.headSync.Init(fx.app) + require.NoError(t, err) + fx.headSync.diff = fx.diffMock +} + +func (fx *headSyncFixture) stop() { + fx.ctrl.Finish() +} + +func TestHeadSync(t *testing.T) { + ctx := context.Background() + + t.Run("run close", func(t *testing.T) { + fx := newHeadSyncFixture(t) + fx.init(t) + defer fx.stop() + + ids := []string{"id1"} + treeMock := mock_treestorage.NewMockTreeStorage(fx.ctrl) + fx.storageMock.EXPECT().StoredIds().Return(ids, nil) + fx.storageMock.EXPECT().TreeStorage(ids[0]).Return(treeMock, nil) + treeMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil) + fx.diffMock.EXPECT().Set(ldiff.Element{ + Id: "id1", Head: "h1h2", }) - hash := "123" - diffMock.EXPECT().Hash().Return(hash) - storageMock.EXPECT().WriteSpaceHash(hash) - pSyncMock.EXPECT().Run() - service.Init([]string{initId}, delState) + fx.diffMock.EXPECT().Hash().Return("hash") + fx.storageMock.EXPECT().WriteSpaceHash("hash").Return(nil) + fx.diffSyncerMock.EXPECT().Sync(gomock.Any()).Return(nil) + fx.diffSyncerMock.EXPECT().Close().Return(nil) + err := fx.headSync.Run(ctx) + require.NoError(t, err) + err = fx.headSync.Close(ctx) + require.NoError(t, err) }) t.Run("update heads", func(t *testing.T) { - syncer.EXPECT().UpdateHeads(initId, []string{"h1", "h2"}) - service.UpdateHeads(initId, []string{"h1", "h2"}) + fx := newHeadSyncFixture(t) + fx.init(t) + defer fx.stop() + + fx.diffSyncerMock.EXPECT().UpdateHeads("id1", []string{"h1"}) + fx.headSync.UpdateHeads("id1", []string{"h1"}) }) t.Run("remove objects", func(t *testing.T) { - syncer.EXPECT().RemoveObjects([]string{"h1", "h2"}) - service.RemoveObjects([]string{"h1", "h2"}) - }) + fx := newHeadSyncFixture(t) + fx.init(t) + defer fx.stop() - t.Run("close", func(t *testing.T) { - pSyncMock.EXPECT().Close() - syncer.EXPECT().Close() - service.Close() + fx.diffSyncerMock.EXPECT().RemoveObjects([]string{"id1"}) + fx.headSync.RemoveObjects([]string{"id1"}) }) } diff --git a/commonspace/headsync/mock_headsync/mock_headsync.go b/commonspace/headsync/mock_headsync/mock_headsync.go index 7df2fe64..46b16aab 100644 --- a/commonspace/headsync/mock_headsync/mock_headsync.go +++ b/commonspace/headsync/mock_headsync/mock_headsync.go @@ -8,7 +8,6 @@ import ( context "context" reflect "reflect" - settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate" gomock "github.com/golang/mock/gomock" ) @@ -50,15 +49,15 @@ func (mr *MockDiffSyncerMockRecorder) Close() *gomock.Call { } // Init mocks base method. -func (m *MockDiffSyncer) Init(arg0 settingsstate.ObjectDeletionState) { +func (m *MockDiffSyncer) Init() { m.ctrl.T.Helper() - m.ctrl.Call(m, "Init", arg0) + m.ctrl.Call(m, "Init") } // Init indicates an expected call of Init. -func (mr *MockDiffSyncerMockRecorder) Init(arg0 interface{}) *gomock.Call { +func (mr *MockDiffSyncerMockRecorder) Init() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init), arg0) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init)) } // RemoveObjects mocks base method. diff --git a/commonspace/headsync/util.go b/commonspace/headsync/util.go new file mode 100644 index 00000000..549a77f2 --- /dev/null +++ b/commonspace/headsync/util.go @@ -0,0 +1,27 @@ +package headsync + +import "strings" + +func concatStrings(strs []string) string { + var ( + b strings.Builder + totalLen int + ) + for _, s := range strs { + totalLen += len(s) + } + + b.Grow(totalLen) + for _, s := range strs { + b.WriteString(s) + } + return b.String() +} + +func splitString(str string) (res []string) { + const cidLen = 59 + for i := 0; i < len(str); i += cidLen { + res = append(res, str[i:i+cidLen]) + } + return +} diff --git a/commonspace/object/acl/syncacl/syncacl.go b/commonspace/object/acl/syncacl/syncacl.go index dd1e647e..426b16cd 100644 --- a/commonspace/object/acl/syncacl/syncacl.go +++ b/commonspace/object/acl/syncacl/syncacl.go @@ -1,21 +1,43 @@ package syncacl import ( + "context" + "github.com/anyproto/any-sync/accountservice" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/object/acl/list" - "github.com/anyproto/any-sync/commonspace/objectsync" - "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" + "github.com/anyproto/any-sync/commonspace/spacestorage" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" ) +const CName = "common.acl.syncacl" + +func New() *SyncAcl { + return &SyncAcl{} +} + type SyncAcl struct { list.AclList - synchandler.SyncHandler - messagePool objectsync.MessagePool } -func NewSyncAcl(aclList list.AclList, messagePool objectsync.MessagePool) *SyncAcl { - return &SyncAcl{ - AclList: aclList, - SyncHandler: nil, - messagePool: messagePool, - } +func (s *SyncAcl) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) { + return nil, nil +} + +func (s *SyncAcl) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) { + return nil +} + +func (s *SyncAcl) Init(a *app.App) (err error) { + storage := a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + aclStorage, err := storage.AclStorage() + if err != nil { + return err + } + acc := a.MustComponent(accountservice.CName).(accountservice.Service) + s.AclList, err = list.BuildAclListWithIdentity(acc.Account(), aclStorage) + return err +} + +func (s *SyncAcl) Name() (name string) { + return CName } diff --git a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go index 0f9d40ba..792dbee7 100644 --- a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go +++ b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable) +// Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol) // Package mock_synctree is a generated GoMock package. package mock_synctree @@ -186,6 +186,21 @@ func (mr *MockSyncTreeMockRecorder) HandleMessage(arg0, arg1, arg2 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncTree)(nil).HandleMessage), arg0, arg1, arg2) } +// HandleRequest mocks base method. +func (m *MockSyncTree) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HandleRequest indicates an expected call of HandleRequest. +func (mr *MockSyncTreeMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncTree)(nil).HandleRequest), arg0, arg1, arg2) +} + // HasChanges mocks base method. func (m *MockSyncTree) HasChanges(arg0 ...string) bool { m.ctrl.T.Helper() @@ -590,3 +605,287 @@ func (mr *MockHeadNotifiableMockRecorder) UpdateHeads(arg0, arg1 interface{}) *g mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHeads", reflect.TypeOf((*MockHeadNotifiable)(nil).UpdateHeads), arg0, arg1) } + +// MockSyncClient is a mock of SyncClient interface. +type MockSyncClient struct { + ctrl *gomock.Controller + recorder *MockSyncClientMockRecorder +} + +// MockSyncClientMockRecorder is the mock recorder for MockSyncClient. +type MockSyncClientMockRecorder struct { + mock *MockSyncClient +} + +// NewMockSyncClient creates a new mock instance. +func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient { + mock := &MockSyncClient{ctrl: ctrl} + mock.recorder = &MockSyncClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder { + return m.recorder +} + +// Broadcast mocks base method. +func (m *MockSyncClient) Broadcast(arg0 *treechangeproto.TreeSyncMessage) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "Broadcast", arg0) +} + +// Broadcast indicates an expected call of Broadcast. +func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0) +} + +// CreateFullSyncRequest mocks base method. +func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest. +func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2) +} + +// CreateFullSyncResponse mocks base method. +func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse. +func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2) +} + +// CreateHeadUpdate mocks base method. +func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + return ret0 +} + +// CreateHeadUpdate indicates an expected call of CreateHeadUpdate. +func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1) +} + +// CreateNewTreeRequest mocks base method. +func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateNewTreeRequest") + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + return ret0 +} + +// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest. +func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest)) +} + +// QueueRequest mocks base method. +func (m *MockSyncClient) QueueRequest(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// QueueRequest indicates an expected call of QueueRequest. +func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1, arg2) +} + +// SendRequest mocks base method. +func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SendRequest indicates an expected call of SendRequest. +func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2, arg3) +} + +// SendUpdate mocks base method. +func (m *MockSyncClient) SendUpdate(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SendUpdate indicates an expected call of SendUpdate. +func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1, arg2) +} + +// MockRequestFactory is a mock of RequestFactory interface. +type MockRequestFactory struct { + ctrl *gomock.Controller + recorder *MockRequestFactoryMockRecorder +} + +// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory. +type MockRequestFactoryMockRecorder struct { + mock *MockRequestFactory +} + +// NewMockRequestFactory creates a new mock instance. +func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory { + mock := &MockRequestFactory{ctrl: ctrl} + mock.recorder = &MockRequestFactoryMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder { + return m.recorder +} + +// CreateFullSyncRequest mocks base method. +func (m *MockRequestFactory) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest. +func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1, arg2) +} + +// CreateFullSyncResponse mocks base method. +func (m *MockRequestFactory) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse. +func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1, arg2) +} + +// CreateHeadUpdate mocks base method. +func (m *MockRequestFactory) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + return ret0 +} + +// CreateHeadUpdate indicates an expected call of CreateHeadUpdate. +func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1) +} + +// CreateNewTreeRequest mocks base method. +func (m *MockRequestFactory) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateNewTreeRequest") + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + return ret0 +} + +// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest. +func (mr *MockRequestFactoryMockRecorder) CreateNewTreeRequest() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateNewTreeRequest)) +} + +// MockTreeSyncProtocol is a mock of TreeSyncProtocol interface. +type MockTreeSyncProtocol struct { + ctrl *gomock.Controller + recorder *MockTreeSyncProtocolMockRecorder +} + +// MockTreeSyncProtocolMockRecorder is the mock recorder for MockTreeSyncProtocol. +type MockTreeSyncProtocolMockRecorder struct { + mock *MockTreeSyncProtocol +} + +// NewMockTreeSyncProtocol creates a new mock instance. +func NewMockTreeSyncProtocol(ctrl *gomock.Controller) *MockTreeSyncProtocol { + mock := &MockTreeSyncProtocol{ctrl: ctrl} + mock.recorder = &MockTreeSyncProtocolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTreeSyncProtocol) EXPECT() *MockTreeSyncProtocolMockRecorder { + return m.recorder +} + +// FullSyncRequest mocks base method. +func (m *MockTreeSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncRequest) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FullSyncRequest indicates an expected call of FullSyncRequest. +func (mr *MockTreeSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2) +} + +// FullSyncResponse mocks base method. +func (m *MockTreeSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncResponse) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// FullSyncResponse indicates an expected call of FullSyncResponse. +func (mr *MockTreeSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2) +} + +// HeadUpdate mocks base method. +func (m *MockTreeSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeHeadUpdate) (*treechangeproto.TreeSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2) + ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// HeadUpdate indicates an expected call of HeadUpdate. +func (mr *MockTreeSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockTreeSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2) +} diff --git a/commonspace/object/tree/synctree/syncprotocol_test.go b/commonspace/object/tree/synctree/protocolintegration_test.go similarity index 100% rename from commonspace/object/tree/synctree/syncprotocol_test.go rename to commonspace/object/tree/synctree/protocolintegration_test.go diff --git a/commonspace/objectsync/requestfactory.go b/commonspace/object/tree/synctree/requestfactory.go similarity index 99% rename from commonspace/objectsync/requestfactory.go rename to commonspace/object/tree/synctree/requestfactory.go index 1f4f3c7d..8d91add8 100644 --- a/commonspace/objectsync/requestfactory.go +++ b/commonspace/object/tree/synctree/requestfactory.go @@ -1,4 +1,4 @@ -package objectsync +package synctree import ( "fmt" diff --git a/commonspace/object/tree/synctree/syncclient.go b/commonspace/object/tree/synctree/syncclient.go new file mode 100644 index 00000000..13909b3b --- /dev/null +++ b/commonspace/object/tree/synctree/syncclient.go @@ -0,0 +1,82 @@ +package synctree + +import ( + "context" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/commonspace/peermanager" + "github.com/anyproto/any-sync/commonspace/requestmanager" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "go.uber.org/zap" +) + +type SyncClient interface { + RequestFactory + Broadcast(msg *treechangeproto.TreeSyncMessage) + SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) + QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) + SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) +} + +type syncClient struct { + RequestFactory + spaceId string + requestManager requestmanager.RequestManager + peerManager peermanager.PeerManager +} + +func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient { + return &syncClient{ + RequestFactory: &requestFactory{}, + spaceId: spaceId, + requestManager: requestManager, + peerManager: peerManager, + } +} +func (s *syncClient) Broadcast(msg *treechangeproto.TreeSyncMessage) { + objMsg, err := MarshallTreeMessage(msg, s.spaceId, msg.RootChange.Id, "") + if err != nil { + return + } + err = s.peerManager.Broadcast(context.Background(), objMsg) + if err != nil { + log.Debug("broadcast error", zap.Error(err)) + } +} + +func (s *syncClient) SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) { + objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "") + if err != nil { + return + } + return s.peerManager.SendPeer(context.Background(), peerId, objMsg) +} + +func (s *syncClient) SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { + objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "") + if err != nil { + return + } + return s.requestManager.SendRequest(ctx, peerId, objMsg) +} + +func (s *syncClient) QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) { + objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "") + if err != nil { + return + } + return s.requestManager.QueueRequest(peerId, objMsg) +} + +func MarshallTreeMessage(message *treechangeproto.TreeSyncMessage, spaceId, objectId, replyId string) (objMsg *spacesyncproto.ObjectSyncMessage, err error) { + payload, err := message.Marshal() + if err != nil { + return + } + objMsg = &spacesyncproto.ObjectSyncMessage{ + ReplyId: replyId, + Payload: payload, + ObjectId: objectId, + SpaceId: spaceId, + } + return +} diff --git a/commonspace/object/tree/synctree/synctree.go b/commonspace/object/tree/synctree/synctree.go index 2a87030a..cd6ff7a2 100644 --- a/commonspace/object/tree/synctree/synctree.go +++ b/commonspace/object/tree/synctree/synctree.go @@ -1,4 +1,4 @@ -//go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable +//go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol package synctree import ( @@ -11,7 +11,6 @@ import ( "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" - "github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/syncstatus" @@ -44,7 +43,7 @@ type SyncTree interface { type syncTree struct { objecttree.ObjectTree synchandler.SyncHandler - syncClient objectsync.SyncClient + syncClient SyncClient syncStatus syncstatus.StatusUpdater notifiable HeadNotifiable listener updatelistener.UpdateListener @@ -61,7 +60,7 @@ type ResponsiblePeersGetter interface { type BuildDeps struct { SpaceId string - SyncClient objectsync.SyncClient + SyncClient SyncClient Configuration nodeconf.NodeConf HeadNotifiable HeadNotifiable Listener updatelistener.UpdateListener @@ -119,7 +118,7 @@ func buildSyncTree(ctx context.Context, sendUpdate bool, deps BuildDeps) (t Sync if sendUpdate { headUpdate := syncTree.syncClient.CreateHeadUpdate(t, nil) // send to everybody, because everybody should know that the node or client got new tree - syncTree.syncClient.Broadcast(ctx, headUpdate) + syncTree.syncClient.Broadcast(headUpdate) } return } @@ -156,7 +155,7 @@ func (s *syncTree) AddContent(ctx context.Context, content objecttree.SignableCh } s.syncStatus.HeadsChange(s.Id(), res.Heads) headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) - s.syncClient.Broadcast(ctx, headUpdate) + s.syncClient.Broadcast(headUpdate) return } @@ -183,7 +182,7 @@ func (s *syncTree) AddRawChanges(ctx context.Context, changesPayload objecttree. s.notifiable.UpdateHeads(s.Id(), res.Heads) } headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) - s.syncClient.Broadcast(ctx, headUpdate) + s.syncClient.Broadcast(headUpdate) } return } @@ -239,7 +238,7 @@ func (s *syncTree) SyncWithPeer(ctx context.Context, peerId string) (err error) s.Lock() defer s.Unlock() headUpdate := s.syncClient.CreateHeadUpdate(s, nil) - return s.syncClient.SendWithReply(ctx, peerId, headUpdate.RootChange.Id, headUpdate, "") + return s.syncClient.SendUpdate(peerId, headUpdate.RootChange.Id, headUpdate) } func (s *syncTree) afterBuild() { diff --git a/commonspace/object/tree/synctree/synctree_test.go b/commonspace/object/tree/synctree/synctree_test.go index 9900e49b..76791c59 100644 --- a/commonspace/object/tree/synctree/synctree_test.go +++ b/commonspace/object/tree/synctree/synctree_test.go @@ -4,11 +4,11 @@ import ( "context" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener/mock_updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/objectsync" - "github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync" "github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/nodeconf" "github.com/golang/mock/gomock" @@ -18,7 +18,7 @@ import ( type syncTreeMatcher struct { objTree objecttree.ObjectTree - client objectsync.SyncClient + client SyncClient listener updatelistener.UpdateListener } @@ -34,8 +34,8 @@ func (s syncTreeMatcher) String() string { return "" } -func syncClientFuncCreator(client objectsync.SyncClient) func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { - return func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { +func syncClientFuncCreator(client SyncClient) func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient { + return func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient { return client } } @@ -46,7 +46,7 @@ func Test_BuildSyncTree(t *testing.T) { defer ctrl.Finish() updateListenerMock := mock_updatelistener.NewMockUpdateListener(ctrl) - syncClientMock := mock_objectsync.NewMockSyncClient(ctrl) + syncClientMock := mock_synctree.NewMockSyncClient(ctrl) objTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) tr := &syncTree{ ObjectTree: objTreeMock, @@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) { updateListenerMock.EXPECT().Update(tr) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) + syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate)) res, err := tr.AddRawChanges(ctx, payload) require.NoError(t, err) require.Equal(t, expectedRes, res) @@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) { updateListenerMock.EXPECT().Rebuild(tr) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) + syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate)) res, err := tr.AddRawChanges(ctx, payload) require.NoError(t, err) require.Equal(t, expectedRes, res) @@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) { Return(expectedRes, nil) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) - syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) + syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate)) res, err := tr.AddContent(ctx, content) require.NoError(t, err) require.Equal(t, expectedRes, res) diff --git a/commonspace/object/tree/synctree/synctreehandler.go b/commonspace/object/tree/synctree/synctreehandler.go index cdb0aa2a..154330f3 100644 --- a/commonspace/object/tree/synctree/synctreehandler.go +++ b/commonspace/object/tree/synctree/synctreehandler.go @@ -2,39 +2,66 @@ package synctree import ( "context" + "errors" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" - "github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" - "github.com/anyproto/any-sync/util/slice" "github.com/gogo/protobuf/proto" - "go.uber.org/zap" "sync" ) +var ( + ErrMessageIsRequest = errors.New("message is request") + ErrMessageIsNotRequest = errors.New("message is not request") +) + type syncTreeHandler struct { - objTree objecttree.ObjectTree - syncClient objectsync.SyncClient - syncStatus syncstatus.StatusUpdater - handlerLock sync.Mutex - spaceId string - queue ReceiveQueue + objTree objecttree.ObjectTree + syncClient SyncClient + syncProtocol TreeSyncProtocol + syncStatus syncstatus.StatusUpdater + handlerLock sync.Mutex + spaceId string + queue ReceiveQueue } const maxQueueSize = 5 -func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient objectsync.SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler { +func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler { return &syncTreeHandler{ - objTree: objTree, - syncClient: syncClient, - syncStatus: syncStatus, - spaceId: spaceId, - queue: newReceiveQueue(maxQueueSize), + objTree: objTree, + syncProtocol: newTreeSyncProtocol(spaceId, objTree, syncClient), + syncClient: syncClient, + syncStatus: syncStatus, + spaceId: spaceId, + queue: newReceiveQueue(maxQueueSize), } } +func (s *syncTreeHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) { + unmarshalled := &treechangeproto.TreeSyncMessage{} + err = proto.Unmarshal(request.Payload, unmarshalled) + if err != nil { + return + } + fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest() + if fullSyncRequest == nil { + err = ErrMessageIsNotRequest + return + } + s.syncStatus.HeadsReceive(senderId, request.ObjectId, treechangeproto.GetHeads(unmarshalled)) + s.objTree.Lock() + defer s.objTree.Unlock() + treeResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest) + if err != nil { + return + } + response, err = MarshallTreeMessage(treeResp, s.spaceId, request.ObjectId, "") + return +} + func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { unmarshalled := &treechangeproto.TreeSyncMessage{} err = proto.Unmarshal(msg.Payload, unmarshalled) @@ -54,181 +81,27 @@ func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, ms func (s *syncTreeHandler) handleMessage(ctx context.Context, senderId string) (err error) { s.objTree.Lock() defer s.objTree.Unlock() - msg, replyId, err := s.queue.GetMessage(senderId) + msg, _, err := s.queue.GetMessage(senderId) if err != nil { return } defer s.queue.ClearQueue(senderId) + treeId := s.objTree.Id() content := msg.GetContent() switch { case content.GetHeadUpdate() != nil: - return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), replyId) + var syncReq *treechangeproto.TreeSyncMessage + syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate()) + if err != nil || syncReq == nil { + return + } + return s.syncClient.QueueRequest(senderId, treeId, syncReq) case content.GetFullSyncRequest() != nil: - return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), replyId) + return ErrMessageIsRequest case content.GetFullSyncResponse() != nil: - return s.handleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse()) + return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse()) } return } - -func (s *syncTreeHandler) handleHeadUpdate( - ctx context.Context, - senderId string, - update *treechangeproto.TreeHeadUpdate, - replyId string) (err error) { - var ( - fullRequest *treechangeproto.TreeSyncMessage - isEmptyUpdate = len(update.Changes) == 0 - objTree = s.objTree - treeId = objTree.Id() - ) - log := log.With( - zap.Strings("update heads", update.Heads), - zap.String("treeId", treeId), - zap.String("spaceId", s.spaceId), - zap.Int("len(update changes)", len(update.Changes))) - log.DebugCtx(ctx, "received head update message") - - defer func() { - if err != nil { - log.ErrorCtx(ctx, "head update finished with error", zap.Error(err)) - } else if fullRequest != nil { - cnt := fullRequest.Content.GetFullSyncRequest() - log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes))) - log.DebugCtx(ctx, "sending full sync request") - } else { - if !isEmptyUpdate { - log.DebugCtx(ctx, "head update finished correctly") - } - } - }() - - // isEmptyUpdate is sent when the tree is brought up from cache - if isEmptyUpdate { - headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads) - log.DebugCtx(ctx, "is empty update", zap.String("treeId", objTree.Id()), zap.Bool("headEquals", headEquals)) - if headEquals { - return - } - - // we need to sync in any case - fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath) - if err != nil { - return - } - - return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId) - } - - if s.alreadyHasHeads(objTree, update.Heads) { - return - } - - _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ - NewHeads: update.Heads, - RawChanges: update.Changes, - }) - if err != nil { - return - } - - if s.alreadyHasHeads(objTree, update.Heads) { - return - } - - fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath) - if err != nil { - return - } - - return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId) -} - -func (s *syncTreeHandler) handleFullSyncRequest( - ctx context.Context, - senderId string, - request *treechangeproto.TreeFullSyncRequest, - replyId string) (err error) { - var ( - fullResponse *treechangeproto.TreeSyncMessage - header = s.objTree.Header() - objTree = s.objTree - treeId = s.objTree.Id() - ) - - log := log.With(zap.String("senderId", senderId), - zap.Strings("request heads", request.Heads), - zap.String("treeId", treeId), - zap.String("replyId", replyId), - zap.String("spaceId", s.spaceId), - zap.Int("len(request changes)", len(request.Changes))) - log.DebugCtx(ctx, "received full sync request message") - - defer func() { - if err != nil { - log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err)) - s.syncClient.SendWithReply(ctx, senderId, treeId, treechangeproto.WrapError(treechangeproto.ErrFullSync, header), replyId) - return - } else if fullResponse != nil { - cnt := fullResponse.Content.GetFullSyncResponse() - log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes))) - log.DebugCtx(ctx, "full sync response sent") - } - }() - - if len(request.Changes) != 0 && !s.alreadyHasHeads(objTree, request.Heads) { - _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ - NewHeads: request.Heads, - RawChanges: request.Changes, - }) - if err != nil { - return - } - } - fullResponse, err = s.syncClient.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath) - if err != nil { - return - } - - return s.syncClient.SendWithReply(ctx, senderId, treeId, fullResponse, replyId) -} - -func (s *syncTreeHandler) handleFullSyncResponse( - ctx context.Context, - senderId string, - response *treechangeproto.TreeFullSyncResponse) (err error) { - var ( - objTree = s.objTree - treeId = s.objTree.Id() - ) - log := log.With( - zap.Strings("heads", response.Heads), - zap.String("treeId", treeId), - zap.String("spaceId", s.spaceId), - zap.Int("len(changes)", len(response.Changes))) - log.DebugCtx(ctx, "received full sync response message") - - defer func() { - if err != nil { - log.ErrorCtx(ctx, "full sync response failed", zap.Error(err)) - } else { - log.DebugCtx(ctx, "full sync response succeeded") - } - }() - - if s.alreadyHasHeads(objTree, response.Heads) { - return - } - - _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ - NewHeads: response.Heads, - RawChanges: response.Changes, - }) - return -} - -func (s *syncTreeHandler) alreadyHasHeads(t objecttree.ObjectTree, heads []string) bool { - return slice.UnsortedEquals(t.Heads(), heads) || t.HasChanges(heads...) -} diff --git a/commonspace/object/tree/synctree/synctreehandler_test.go b/commonspace/object/tree/synctree/synctreehandler_test.go index c81ca5f4..f03f5ff1 100644 --- a/commonspace/object/tree/synctree/synctreehandler_test.go +++ b/commonspace/object/tree/synctree/synctreehandler_test.go @@ -2,20 +2,15 @@ package synctree import ( "context" - "fmt" - "github.com/anyproto/any-sync/commonspace/objectsync" - "github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/stretchr/testify/require" "sync" "testing" - "github.com/anyproto/any-sync/app/logger" - "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" - "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/golang/mock/gomock" - "github.com/stretchr/testify/require" - "go.uber.org/zap" ) type testObjTreeMock struct { @@ -55,31 +50,43 @@ func (t *testObjTreeMock) TryRLock() bool { type syncHandlerFixture struct { ctrl *gomock.Controller - syncClientMock *mock_objectsync.MockSyncClient + syncClientMock *mock_synctree.MockSyncClient objectTreeMock *testObjTreeMock receiveQueueMock ReceiveQueue + syncProtocolMock *mock_synctree.MockTreeSyncProtocol + spaceId string + senderId string + treeId string syncHandler *syncTreeHandler } func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture { ctrl := gomock.NewController(t) - syncClientMock := mock_objectsync.NewMockSyncClient(ctrl) objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) + syncClientMock := mock_synctree.NewMockSyncClient(ctrl) + syncProtocolMock := mock_synctree.NewMockTreeSyncProtocol(ctrl) + spaceId := "spaceId" receiveQueue := newReceiveQueue(5) syncHandler := &syncTreeHandler{ - objTree: objectTreeMock, - syncClient: syncClientMock, - queue: receiveQueue, - syncStatus: syncstatus.NewNoOpSyncStatus(), + objTree: objectTreeMock, + syncClient: syncClientMock, + syncProtocol: syncProtocolMock, + spaceId: spaceId, + queue: receiveQueue, + syncStatus: syncstatus.NewNoOpSyncStatus(), } return &syncHandlerFixture{ ctrl: ctrl, - syncClientMock: syncClientMock, objectTreeMock: objectTreeMock, receiveQueueMock: receiveQueue, + syncProtocolMock: syncProtocolMock, + syncClientMock: syncClientMock, syncHandler: syncHandler, + spaceId: spaceId, + senderId: "senderId", + treeId: "treeId", } } @@ -87,341 +94,128 @@ func (fx *syncHandlerFixture) stop() { fx.ctrl.Finish() } -func TestSyncHandler_HandleHeadUpdate(t *testing.T) { +func TestSyncTreeHandler_HandleMessage(t *testing.T) { ctx := context.Background() - log = logger.CtxLogger{Logger: zap.NewNop()} - fullRequest := &treechangeproto.TreeSyncMessage{ - Content: &treechangeproto.TreeSyncContentValue{ - Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{ - FullSyncRequest: &treechangeproto.TreeFullSyncRequest{}, - }, - }, - } - t.Run("head update non empty all heads added", func(t *testing.T) { + t.Run("handle head update message", func(t *testing.T) { fx := newSyncHandlerFixture(t) defer fx.stop() treeId := "treeId" - senderId := "senderId" chWithId := &treechangeproto.RawTreeChangeWithId{} - headUpdate := &treechangeproto.TreeHeadUpdate{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } + headUpdate := &treechangeproto.TreeHeadUpdate{} treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2) - fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) - fx.objectTreeMock.EXPECT(). - AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ - NewHeads: []string{"h1"}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, - })). - Return(objecttree.AddResult{}, nil) - fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true) + syncReq := &treechangeproto.TreeSyncMessage{} + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil) + fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, fx.treeId, syncReq).Return(nil) - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) + err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg) require.NoError(t, err) }) - t.Run("head update non empty heads not added", func(t *testing.T) { + t.Run("handle head update message, empty sync request", func(t *testing.T) { fx := newSyncHandlerFixture(t) defer fx.stop() treeId := "treeId" - senderId := "senderId" chWithId := &treechangeproto.RawTreeChangeWithId{} - headUpdate := &treechangeproto.TreeHeadUpdate{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } + headUpdate := &treechangeproto.TreeHeadUpdate{} treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() - fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) - fx.objectTreeMock.EXPECT(). - AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ - NewHeads: []string{"h1"}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, - })). - Return(objecttree.AddResult{}, nil) - fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) - fx.syncClientMock.EXPECT(). - CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). - Return(fullRequest, nil) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq("")) + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil) - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) + err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg) require.NoError(t, err) }) - t.Run("head update non empty equal heads", func(t *testing.T) { + t.Run("handle full sync request returns error", func(t *testing.T) { fx := newSyncHandlerFixture(t) defer fx.stop() treeId := "treeId" - senderId := "senderId" chWithId := &treechangeproto.RawTreeChangeWithId{} - headUpdate := &treechangeproto.TreeHeadUpdate{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") + fullRequest := &treechangeproto.TreeFullSyncRequest{} + treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId) + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) + err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg) + require.Equal(t, err, ErrMessageIsRequest) }) - t.Run("head update empty", func(t *testing.T) { + t.Run("handle full sync response", func(t *testing.T) { fx := newSyncHandlerFixture(t) defer fx.stop() treeId := "treeId" - senderId := "senderId" chWithId := &treechangeproto.RawTreeChangeWithId{} - headUpdate := &treechangeproto.TreeHeadUpdate{ - Heads: []string{"h1"}, - Changes: nil, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() - fx.syncClientMock.EXPECT(). - CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). - Return(fullRequest, nil) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq("")) - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) - - t.Run("head update empty equal heads", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - headUpdate := &treechangeproto.TreeHeadUpdate{ - Heads: []string{"h1"}, - Changes: nil, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) -} - -func TestSyncHandler_HandleFullSyncRequest(t *testing.T) { - ctx := context.Background() - log = logger.CtxLogger{Logger: zap.NewNop()} - fullResponse := &treechangeproto.TreeSyncMessage{ - Content: &treechangeproto.TreeSyncContentValue{ - Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{ - FullSyncResponse: &treechangeproto.TreeFullSyncResponse{}, - }, - }, - } - - t.Run("full sync request with change", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Header().Return(nil) - fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() - fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) - fx.objectTreeMock.EXPECT(). - AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ - NewHeads: []string{"h1"}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, - })). - Return(objecttree.AddResult{}, nil) - fx.syncClientMock.EXPECT(). - CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). - Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq("")) - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) - - t.Run("full sync request with change same heads", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - - fx.objectTreeMock.EXPECT(). - Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Header().Return(nil) - fx.objectTreeMock.EXPECT(). - Heads(). - Return([]string{"h1"}).AnyTimes() - fx.syncClientMock.EXPECT(). - CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). - Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq("")) - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) - - t.Run("full sync request without change but with reply id", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - replyId := "replyId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ - Heads: []string{"h1"}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - objectMsg.RequestId = replyId - - fx.objectTreeMock.EXPECT(). - Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Header().Return(nil) - fx.syncClientMock.EXPECT(). - CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). - Return(fullResponse, nil) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(replyId)) - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) - - t.Run("full sync request with add raw changes error", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - - fx.objectTreeMock.EXPECT(). - Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT().Header().Return(nil) - fx.objectTreeMock.EXPECT(). - Heads(). - Return([]string{"h2"}) - fx.objectTreeMock.EXPECT(). - HasChanges(gomock.Eq([]string{"h1"})). - Return(false) - fx.objectTreeMock.EXPECT(). - AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ - NewHeads: []string{"h1"}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, - })). - Return(objecttree.AddResult{}, fmt.Errorf("")) - fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Any(), gomock.Eq("")) - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.Error(t, err) - }) -} - -func TestSyncHandler_HandleFullSyncResponse(t *testing.T) { - ctx := context.Background() - log = logger.CtxLogger{Logger: zap.NewNop()} - - t.Run("full sync response with change", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - replyId := "replyId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncResponse := &treechangeproto.TreeFullSyncResponse{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } + fullSyncResponse := &treechangeproto.TreeFullSyncResponse{} treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId) + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT(). - Heads(). - Return([]string{"h2"}).AnyTimes() - fx.objectTreeMock.EXPECT(). - HasChanges(gomock.Eq([]string{"h1"})). - Return(false) - fx.objectTreeMock.EXPECT(). - AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ - NewHeads: []string{"h1"}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, - })). - Return(objecttree.AddResult{}, nil) + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil) - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) - require.NoError(t, err) - }) - - t.Run("full sync response with same heads", func(t *testing.T) { - fx := newSyncHandlerFixture(t) - defer fx.stop() - treeId := "treeId" - senderId := "senderId" - replyId := "replyId" - chWithId := &treechangeproto.RawTreeChangeWithId{} - fullSyncResponse := &treechangeproto.TreeFullSyncResponse{ - Heads: []string{"h1"}, - Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, - SnapshotPath: []string{"h1"}, - } - treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) - objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId) - - fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) - fx.objectTreeMock.EXPECT(). - Heads(). - Return([]string{"h1"}).AnyTimes() - - err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) + err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg) require.NoError(t, err) }) } + +func TestSyncTreeHandler_HandleRequest(t *testing.T) { + ctx := context.Background() + + t.Run("handle request", func(t *testing.T) { + fx := newSyncHandlerFixture(t) + defer fx.stop() + treeId := "treeId" + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullRequest := &treechangeproto.TreeFullSyncRequest{} + treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId) + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") + + syncResp := &treechangeproto.TreeSyncMessage{} + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(syncResp, nil) + + res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg) + require.NoError(t, err) + require.NotNil(t, res) + }) + + t.Run("handle request", func(t *testing.T) { + fx := newSyncHandlerFixture(t) + defer fx.stop() + treeId := "treeId" + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullRequest := &treechangeproto.TreeFullSyncRequest{} + treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId) + objectMsg, _ := MarshallTreeMessage(treeMsg, "spaceId", treeId, "") + + syncResp := &treechangeproto.TreeSyncMessage{} + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(syncResp, nil) + + res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg) + require.NoError(t, err) + require.NotNil(t, res) + }) + + t.Run("handle other message", func(t *testing.T) { + fx := newSyncHandlerFixture(t) + defer fx.stop() + treeId := "treeId" + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullResponse := &treechangeproto.TreeFullSyncResponse{} + responseMsg := treechangeproto.WrapFullResponse(fullResponse, chWithId) + headUpdate := &treechangeproto.TreeHeadUpdate{} + headUpdateMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) + for _, msg := range []*treechangeproto.TreeSyncMessage{responseMsg, headUpdateMsg} { + objectMsg, _ := MarshallTreeMessage(msg, "spaceId", treeId, "") + + _, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg) + require.Equal(t, err, ErrMessageIsNotRequest) + } + }) +} diff --git a/commonspace/object/tree/synctree/treeremotegetter.go b/commonspace/object/tree/synctree/treeremotegetter.go index 3006ab5b..bee53231 100644 --- a/commonspace/object/tree/synctree/treeremotegetter.go +++ b/commonspace/object/tree/synctree/treeremotegetter.go @@ -47,7 +47,7 @@ func (t treeRemoteGetter) getPeers(ctx context.Context) (peerIds []string, err e func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *treechangeproto.TreeSyncMessage, err error) { newTreeRequest := t.deps.SyncClient.CreateNewTreeRequest() - resp, err := t.deps.SyncClient.SendSync(ctx, peerId, t.treeId, newTreeRequest) + resp, err := t.deps.SyncClient.SendRequest(ctx, peerId, t.treeId, newTreeRequest) if err != nil { return } diff --git a/commonspace/object/tree/synctree/treesyncprotocol.go b/commonspace/object/tree/synctree/treesyncprotocol.go new file mode 100644 index 00000000..be759259 --- /dev/null +++ b/commonspace/object/tree/synctree/treesyncprotocol.go @@ -0,0 +1,152 @@ +package synctree + +import ( + "context" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/util/slice" + "go.uber.org/zap" +) + +type TreeSyncProtocol interface { + HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (request *treechangeproto.TreeSyncMessage, err error) + FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (response *treechangeproto.TreeSyncMessage, err error) + FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error) +} + +type treeSyncProtocol struct { + log logger.CtxLogger + spaceId string + objTree objecttree.ObjectTree + reqFactory RequestFactory +} + +func newTreeSyncProtocol(spaceId string, objTree objecttree.ObjectTree, reqFactory RequestFactory) *treeSyncProtocol { + return &treeSyncProtocol{ + log: log.With(zap.String("spaceId", spaceId), zap.String("treeId", objTree.Id())), + spaceId: spaceId, + objTree: objTree, + reqFactory: reqFactory, + } +} + +func (t *treeSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (fullRequest *treechangeproto.TreeSyncMessage, err error) { + var ( + isEmptyUpdate = len(update.Changes) == 0 + objTree = t.objTree + ) + log := t.log.With( + zap.String("senderId", senderId), + zap.Strings("update heads", update.Heads), + zap.Int("len(update changes)", len(update.Changes))) + log.DebugCtx(ctx, "received head update message") + + defer func() { + if err != nil { + log.ErrorCtx(ctx, "head update finished with error", zap.Error(err)) + } else if fullRequest != nil { + cnt := fullRequest.Content.GetFullSyncRequest() + log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes))) + log.DebugCtx(ctx, "returning full sync request") + } else { + if !isEmptyUpdate { + log.DebugCtx(ctx, "head update finished correctly") + } + } + }() + + // isEmptyUpdate is sent when the tree is brought up from cache + if isEmptyUpdate { + headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads) + log.DebugCtx(ctx, "is empty update", zap.String("treeId", objTree.Id()), zap.Bool("headEquals", headEquals)) + if headEquals { + return + } + + // we need to sync in any case + fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath) + return + } + + if t.alreadyHasHeads(objTree, update.Heads) { + return + } + + _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ + NewHeads: update.Heads, + RawChanges: update.Changes, + }) + if err != nil { + return + } + + if t.alreadyHasHeads(objTree, update.Heads) { + return + } + + fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath) + return +} + +func (t *treeSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (fullResponse *treechangeproto.TreeSyncMessage, err error) { + var ( + objTree = t.objTree + ) + log := t.log.With(zap.String("senderId", senderId), + zap.Strings("request heads", request.Heads), + zap.Int("len(request changes)", len(request.Changes))) + log.DebugCtx(ctx, "received full sync request message") + + defer func() { + if err != nil { + log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err)) + } else if fullResponse != nil { + cnt := fullResponse.Content.GetFullSyncResponse() + log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes))) + log.DebugCtx(ctx, "full sync response sent") + } + }() + + if len(request.Changes) != 0 && !t.alreadyHasHeads(objTree, request.Heads) { + _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ + NewHeads: request.Heads, + RawChanges: request.Changes, + }) + if err != nil { + return + } + } + fullResponse, err = t.reqFactory.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath) + return +} + +func (t *treeSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error) { + var ( + objTree = t.objTree + ) + log := log.With( + zap.Strings("heads", response.Heads), + zap.Int("len(changes)", len(response.Changes))) + log.DebugCtx(ctx, "received full sync response message") + defer func() { + if err != nil { + log.ErrorCtx(ctx, "full sync response failed", zap.Error(err)) + } else { + log.DebugCtx(ctx, "full sync response succeeded") + } + }() + if t.alreadyHasHeads(objTree, response.Heads) { + return + } + + _, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{ + NewHeads: response.Heads, + RawChanges: response.Changes, + }) + return +} + +func (t *treeSyncProtocol) alreadyHasHeads(ot objecttree.ObjectTree, heads []string) bool { + return slice.UnsortedEquals(ot.Heads(), heads) || ot.HasChanges(heads...) +} diff --git a/commonspace/object/tree/synctree/treesyncprotocol_test.go b/commonspace/object/tree/synctree/treesyncprotocol_test.go new file mode 100644 index 00000000..c80dbe35 --- /dev/null +++ b/commonspace/object/tree/synctree/treesyncprotocol_test.go @@ -0,0 +1,293 @@ +package synctree + +import ( + "context" + "fmt" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "testing" +) + +type treeSyncProtocolFixture struct { + log logger.CtxLogger + spaceId string + senderId string + treeId string + objectTreeMock *testObjTreeMock + reqFactory *mock_synctree.MockRequestFactory + ctrl *gomock.Controller + syncProtocol TreeSyncProtocol +} + +func newSyncProtocolFixture(t *testing.T) *treeSyncProtocolFixture { + ctrl := gomock.NewController(t) + objTree := &testObjTreeMock{ + MockObjectTree: mock_objecttree.NewMockObjectTree(ctrl), + } + spaceId := "spaceId" + reqFactory := mock_synctree.NewMockRequestFactory(ctrl) + objTree.EXPECT().Id().Return("treeId") + syncProtocol := newTreeSyncProtocol(spaceId, objTree, reqFactory) + return &treeSyncProtocolFixture{ + log: log, + spaceId: spaceId, + senderId: "senderId", + treeId: "treeId", + objectTreeMock: objTree, + reqFactory: reqFactory, + ctrl: ctrl, + syncProtocol: syncProtocol, + } +} + +func (fx *treeSyncProtocolFixture) stop() { + fx.ctrl.Finish() +} + +func TestTreeSyncProtocol_HeadUpdate(t *testing.T) { + ctx := context.Background() + fullRequest := &treechangeproto.TreeSyncMessage{ + Content: &treechangeproto.TreeSyncContentValue{ + Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{ + FullSyncRequest: &treechangeproto.TreeFullSyncRequest{}, + }, + }, + } + + t.Run("head update non empty all heads added", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + headUpdate := &treechangeproto.TreeHeadUpdate{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2) + fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) + fx.objectTreeMock.EXPECT(). + AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ + NewHeads: []string{"h1"}, + RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, + })). + Return(objecttree.AddResult{}, nil) + fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true) + + res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.NoError(t, err) + require.Nil(t, res) + }) + + t.Run("head update non empty equal heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + headUpdate := &treechangeproto.TreeHeadUpdate{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() + + res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.NoError(t, err) + require.Nil(t, res) + }) + + t.Run("head update empty", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + headUpdate := &treechangeproto.TreeHeadUpdate{ + Heads: []string{"h1"}, + Changes: nil, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() + fx.reqFactory.EXPECT(). + CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). + Return(fullRequest, nil) + + res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.NoError(t, err) + require.Equal(t, fullRequest, res) + }) + + t.Run("head update empty equal heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + headUpdate := &treechangeproto.TreeHeadUpdate{ + Heads: []string{"h1"}, + Changes: nil, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() + + res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.NoError(t, err) + require.Nil(t, res) + }) +} + +func TestTreeSyncProtocol_FullSyncRequest(t *testing.T) { + ctx := context.Background() + fullResponse := &treechangeproto.TreeSyncMessage{ + Content: &treechangeproto.TreeSyncContentValue{ + Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{ + FullSyncResponse: &treechangeproto.TreeFullSyncResponse{}, + }, + }, + } + + t.Run("full sync request with change", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() + fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) + fx.objectTreeMock.EXPECT(). + AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ + NewHeads: []string{"h1"}, + RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, + })). + Return(objecttree.AddResult{}, nil) + fx.reqFactory.EXPECT(). + CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). + Return(fullResponse, nil) + + res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest) + require.NoError(t, err) + require.Equal(t, fullResponse, res) + }) + + t.Run("full sync request with change same heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT(). + Heads(). + Return([]string{"h1"}).AnyTimes() + fx.reqFactory.EXPECT(). + CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). + Return(fullResponse, nil) + + res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest) + require.NoError(t, err) + require.Equal(t, fullResponse, res) + }) + + t.Run("full sync request without changes", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ + Heads: []string{"h1"}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.reqFactory.EXPECT(). + CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})). + Return(fullResponse, nil) + + res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest) + require.NoError(t, err) + require.Equal(t, fullResponse, res) + }) + + t.Run("full sync request with change, raw changes error", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullSyncRequest := &treechangeproto.TreeFullSyncRequest{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() + fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) + fx.objectTreeMock.EXPECT(). + AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ + NewHeads: []string{"h1"}, + RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, + })). + Return(objecttree.AddResult{}, fmt.Errorf("addRawChanges error")) + + _, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest) + require.Error(t, err) + }) +} + +func TestTreeSyncProtocol_FullSyncResponse(t *testing.T) { + ctx := context.Background() + + t.Run("full sync response with change", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullSyncResponse := &treechangeproto.TreeFullSyncResponse{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT(). + Heads(). + Return([]string{"h2"}).AnyTimes() + fx.objectTreeMock.EXPECT(). + HasChanges(gomock.Eq([]string{"h1"})). + Return(false) + fx.objectTreeMock.EXPECT(). + AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ + NewHeads: []string{"h1"}, + RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, + })). + Return(objecttree.AddResult{}, nil) + + err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse) + require.NoError(t, err) + }) + + t.Run("full sync response with same heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &treechangeproto.RawTreeChangeWithId{} + fullSyncResponse := &treechangeproto.TreeFullSyncResponse{ + Heads: []string{"h1"}, + Changes: []*treechangeproto.RawTreeChangeWithId{chWithId}, + SnapshotPath: []string{"h1"}, + } + + fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId) + fx.objectTreeMock.EXPECT(). + Heads(). + Return([]string{"h1"}).AnyTimes() + + err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse) + require.NoError(t, err) + }) +} diff --git a/commonspace/object/tree/synctree/utils_test.go b/commonspace/object/tree/synctree/utils_test.go index 46936560..6aab234b 100644 --- a/commonspace/object/tree/synctree/utils_test.go +++ b/commonspace/object/tree/synctree/utils_test.go @@ -3,11 +3,11 @@ package synctree import ( "context" "fmt" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" - "github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" @@ -82,51 +82,124 @@ func (m *messageLog) addMessage(msg protocolMsg) { m.batcher.Add(context.Background(), msg) } +type requestPeerManager struct { + peerId string + handlers map[string]*testSyncHandler + log *messageLog +} + +func newRequestPeerManager(peerId string, log *messageLog) *requestPeerManager { + return &requestPeerManager{ + peerId: peerId, + handlers: map[string]*testSyncHandler{}, + log: log, + } +} + +func (r *requestPeerManager) addHandler(peerId string, handler *testSyncHandler) { + r.handlers[peerId] = handler +} + +func (r *requestPeerManager) Run(ctx context.Context) (err error) { + return nil +} + +func (r *requestPeerManager) Close(ctx context.Context) (err error) { + return nil +} + +func (r *requestPeerManager) SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { + panic("should not be called") +} + +func (r *requestPeerManager) QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { + pMsg := protocolMsg{ + msg: msg, + senderId: r.peerId, + receiverId: peerId, + } + r.log.addMessage(pMsg) + return r.handlers[peerId].send(context.Background(), pMsg) +} + +func (r *requestPeerManager) Init(a *app.App) (err error) { + return +} + +func (r *requestPeerManager) Name() (name string) { + return +} + +func (r *requestPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { + pMsg := protocolMsg{ + msg: msg, + senderId: r.peerId, + receiverId: peerId, + } + r.log.addMessage(pMsg) + return r.handlers[peerId].send(context.Background(), pMsg) +} + +func (r *requestPeerManager) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) { + for _, handler := range r.handlers { + pMsg := protocolMsg{ + msg: msg, + senderId: r.peerId, + receiverId: handler.peerId, + } + r.log.addMessage(pMsg) + handler.send(context.Background(), pMsg) + } + return +} + +func (r *requestPeerManager) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) { + return nil, nil +} + // testSyncHandler is the wrapper around individual tree to test sync protocol type testSyncHandler struct { synchandler.SyncHandler - batcher *mb.MB[protocolMsg] - peerId string - aclList list.AclList - log *messageLog - syncClient objectsync.SyncClient - builder objecttree.BuildObjectTreeFunc + batcher *mb.MB[protocolMsg] + peerId string + aclList list.AclList + log *messageLog + syncClient SyncClient + builder objecttree.BuildObjectTreeFunc + peerManager *requestPeerManager } // createSyncHandler creates a sync handler when a tree is already created func createSyncHandler(peerId, spaceId string, objTree objecttree.ObjectTree, log *messageLog) *testSyncHandler { - factory := objectsync.NewRequestFactory() - syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) + peerManager := newRequestPeerManager(peerId, log) + syncClient := NewSyncClient(spaceId, peerManager, peerManager) netTree := &broadcastTree{ ObjectTree: objTree, SyncClient: syncClient, } handler := newSyncTreeHandler(spaceId, netTree, syncClient, syncstatus.NewNoOpSyncStatus()) - return newTestSyncHandler(peerId, handler) + return &testSyncHandler{ + SyncHandler: handler, + batcher: mb.New[protocolMsg](0), + peerId: peerId, + peerManager: peerManager, + } } // createEmptySyncHandler creates a sync handler when the tree will be provided later (this emulates the situation when we have no tree) func createEmptySyncHandler(peerId, spaceId string, builder objecttree.BuildObjectTreeFunc, aclList list.AclList, log *messageLog) *testSyncHandler { - factory := objectsync.NewRequestFactory() - syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) + peerManager := newRequestPeerManager(peerId, log) + syncClient := NewSyncClient(spaceId, peerManager, peerManager) batcher := mb.New[protocolMsg](0) return &testSyncHandler{ - batcher: batcher, - peerId: peerId, - aclList: aclList, - log: log, - syncClient: syncClient, - builder: builder, - } -} - -func newTestSyncHandler(peerId string, syncHandler synchandler.SyncHandler) *testSyncHandler { - batcher := mb.New[protocolMsg](0) - return &testSyncHandler{ - SyncHandler: syncHandler, batcher: batcher, peerId: peerId, + aclList: aclList, + log: log, + syncClient: syncClient, + builder: builder, + peerManager: peerManager, } } @@ -140,13 +213,8 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re return } if unmarshalled.Content.GetFullSyncResponse() == nil { - newTreeRequest := objectsync.NewRequestFactory().CreateNewTreeRequest() - var objMsg *spacesyncproto.ObjectSyncMessage - objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "") - if err != nil { - return - } - return h.manager().SendPeer(context.Background(), senderId, objMsg) + newTreeRequest := NewRequestFactory().CreateNewTreeRequest() + return h.syncClient.QueueRequest(senderId, request.ObjectId, newTreeRequest) } fullSyncResponse := unmarshalled.Content.GetFullSyncResponse() treeStorage, _ := treestorage.NewInMemoryTreeStorage(unmarshalled.RootChange, []string{unmarshalled.RootChange.Id}, nil) @@ -166,20 +234,13 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re return } h.SyncHandler = newSyncTreeHandler(request.SpaceId, netTree, h.syncClient, syncstatus.NewNoOpSyncStatus()) - var objMsg *spacesyncproto.ObjectSyncMessage - newTreeRequest := objectsync.NewRequestFactory().CreateHeadUpdate(netTree, res.Added) - objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "") - if err != nil { - return - } - return h.manager().Broadcast(context.Background(), objMsg) + headUpdate := NewRequestFactory().CreateHeadUpdate(netTree, res.Added) + h.syncClient.Broadcast(headUpdate) + return nil } -func (h *testSyncHandler) manager() *testMessagePool { - if h.SyncHandler != nil { - return h.SyncHandler.(*syncTreeHandler).syncClient.MessagePool().(*testMessagePool) - } - return h.syncClient.MessagePool().(*testMessagePool) +func (h *testSyncHandler) manager() *requestPeerManager { + return h.peerManager } func (h *testSyncHandler) tree() *broadcastTree { @@ -211,74 +272,28 @@ func (h *testSyncHandler) run(ctx context.Context, t *testing.T, wg *sync.WaitGr h.tree().Unlock() continue } - err = h.HandleMessage(ctx, res.senderId, res.msg) - if err != nil { - fmt.Println("error handling message", err.Error()) - continue + if res.description().name == "FullSyncRequest" { + resp, err := h.HandleRequest(ctx, res.senderId, res.msg) + if err != nil { + fmt.Println("error handling request", err.Error()) + continue + } + h.peerManager.SendPeer(ctx, res.senderId, resp) + } else { + err = h.HandleMessage(ctx, res.senderId, res.msg) + if err != nil { + fmt.Println("error handling message", err.Error()) + } } } }() } -// testMessagePool captures all other handlers and sends messages to them -type testMessagePool struct { - peerId string - handlers map[string]*testSyncHandler - log *messageLog -} - -func newTestMessagePool(peerId string, log *messageLog) *testMessagePool { - return &testMessagePool{handlers: map[string]*testSyncHandler{}, peerId: peerId, log: log} -} - -func (m *testMessagePool) addHandler(peerId string, handler *testSyncHandler) { - m.handlers[peerId] = handler -} - -func (m *testMessagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { - pMsg := protocolMsg{ - msg: msg, - senderId: m.peerId, - receiverId: peerId, - } - m.log.addMessage(pMsg) - return m.handlers[peerId].send(context.Background(), pMsg) -} - -func (m *testMessagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) { - for _, handler := range m.handlers { - pMsg := protocolMsg{ - msg: msg, - senderId: m.peerId, - receiverId: handler.peerId, - } - m.log.addMessage(pMsg) - handler.send(context.Background(), pMsg) - } - return -} - -func (m *testMessagePool) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) { - panic("should not be called") -} - -func (m *testMessagePool) LastUsage() time.Time { - panic("should not be called") -} - -func (m *testMessagePool) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) { - panic("should not be called") -} - -func (m *testMessagePool) SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { - panic("should not be called") -} - // broadcastTree is the tree that broadcasts changes to everyone when changes are added // it is a simplified version of SyncTree which is easier to use in the test environment type broadcastTree struct { objecttree.ObjectTree - objectsync.SyncClient + SyncClient } func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.RawChangesPayload) (objecttree.AddResult, error) { @@ -287,7 +302,7 @@ func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.Ra return objecttree.AddResult{}, err } upd := b.SyncClient.CreateHeadUpdate(b.ObjectTree, res.Added) - b.SyncClient.Broadcast(ctx, upd) + b.SyncClient.Broadcast(upd) return res, nil } diff --git a/commonspace/objectmanager/objectmanager.go b/commonspace/objectmanager/objectmanager.go new file mode 100644 index 00000000..49637818 --- /dev/null +++ b/commonspace/objectmanager/objectmanager.go @@ -0,0 +1,98 @@ +package objectmanager + +import ( + "context" + "errors" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/object/acl/syncacl" + "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/settings" + "github.com/anyproto/any-sync/commonspace/spacestate" + "sync/atomic" +) + +var ( + ErrSpaceClosed = errors.New("space is closed") +) + +type ObjectManager interface { + treemanager.TreeManager + AddObject(object syncobjectgetter.SyncObject) + GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) +} + +type objectManager struct { + treemanager.TreeManager + spaceId string + reservedObjects []syncobjectgetter.SyncObject + spaceIsClosed *atomic.Bool +} + +func New(manager treemanager.TreeManager) ObjectManager { + return &objectManager{ + TreeManager: manager, + } +} + +func (o *objectManager) Init(a *app.App) (err error) { + state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + o.spaceId = state.SpaceId + o.spaceIsClosed = state.SpaceIsClosed + settingsObject := a.MustComponent(settings.CName).(settings.Settings).SettingsObject() + acl := a.MustComponent(syncacl.CName).(*syncacl.SyncAcl) + o.AddObject(settingsObject) + o.AddObject(acl) + return nil +} + +func (o *objectManager) Run(ctx context.Context) (err error) { + return nil +} + +func (o *objectManager) Close(ctx context.Context) (err error) { + return nil +} + +func (o *objectManager) AddObject(object syncobjectgetter.SyncObject) { + o.reservedObjects = append(o.reservedObjects, object) +} + +func (o *objectManager) Name() string { + return treemanager.CName +} + +func (o *objectManager) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) { + if o.spaceIsClosed.Load() { + return nil, ErrSpaceClosed + } + if obj := o.getReservedObject(treeId); obj != nil { + return obj.(objecttree.ObjectTree), nil + } + return o.TreeManager.GetTree(ctx, spaceId, treeId) +} + +func (o *objectManager) getReservedObject(id string) syncobjectgetter.SyncObject { + for _, obj := range o.reservedObjects { + if obj != nil && obj.Id() == id { + return obj + } + } + return nil +} + +func (o *objectManager) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) { + if o.spaceIsClosed.Load() { + return nil, ErrSpaceClosed + } + if obj := o.getReservedObject(objectId); obj != nil { + return obj, nil + } + t, err := o.TreeManager.GetTree(ctx, o.spaceId, objectId) + if err != nil { + return + } + obj = t.(syncobjectgetter.SyncObject) + return +} diff --git a/commonspace/objectsync/mock_objectsync/mock_objectsync.go b/commonspace/objectsync/mock_objectsync/mock_objectsync.go index 5aee48e8..2858c6f9 100644 --- a/commonspace/objectsync/mock_objectsync/mock_objectsync.go +++ b/commonspace/objectsync/mock_objectsync/mock_objectsync.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: SyncClient) +// Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: ObjectSync) // Package mock_objectsync is a generated GoMock package. package mock_objectsync @@ -7,146 +7,146 @@ package mock_objectsync import ( context "context" reflect "reflect" + time "time" - objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" - treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + app "github.com/anyproto/any-sync/app" objectsync "github.com/anyproto/any-sync/commonspace/objectsync" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" gomock "github.com/golang/mock/gomock" ) -// MockSyncClient is a mock of SyncClient interface. -type MockSyncClient struct { +// MockObjectSync is a mock of ObjectSync interface. +type MockObjectSync struct { ctrl *gomock.Controller - recorder *MockSyncClientMockRecorder + recorder *MockObjectSyncMockRecorder } -// MockSyncClientMockRecorder is the mock recorder for MockSyncClient. -type MockSyncClientMockRecorder struct { - mock *MockSyncClient +// MockObjectSyncMockRecorder is the mock recorder for MockObjectSync. +type MockObjectSyncMockRecorder struct { + mock *MockObjectSync } -// NewMockSyncClient creates a new mock instance. -func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient { - mock := &MockSyncClient{ctrl: ctrl} - mock.recorder = &MockSyncClientMockRecorder{mock} +// NewMockObjectSync creates a new mock instance. +func NewMockObjectSync(ctrl *gomock.Controller) *MockObjectSync { + mock := &MockObjectSync{ctrl: ctrl} + mock.recorder = &MockObjectSyncMockRecorder{mock} return mock } // EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder { +func (m *MockObjectSync) EXPECT() *MockObjectSyncMockRecorder { return m.recorder } -// Broadcast mocks base method. -func (m *MockSyncClient) Broadcast(arg0 context.Context, arg1 *treechangeproto.TreeSyncMessage) { +// Close mocks base method. +func (m *MockObjectSync) Close(arg0 context.Context) error { m.ctrl.T.Helper() - m.ctrl.Call(m, "Broadcast", arg0, arg1) -} - -// Broadcast indicates an expected call of Broadcast. -func (mr *MockSyncClientMockRecorder) Broadcast(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0, arg1) -} - -// CreateFullSyncRequest mocks base method. -func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2) - ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest. -func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2) -} - -// CreateFullSyncResponse mocks base method. -func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2) - ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse. -func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2) -} - -// CreateHeadUpdate mocks base method. -func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1) - ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret := m.ctrl.Call(m, "Close", arg0) + ret0, _ := ret[0].(error) return ret0 } -// CreateHeadUpdate indicates an expected call of CreateHeadUpdate. -func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call { +// Close indicates an expected call of Close. +func (mr *MockObjectSyncMockRecorder) Close(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectSync)(nil).Close), arg0) } -// CreateNewTreeRequest mocks base method. -func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage { +// CloseThread mocks base method. +func (m *MockObjectSync) CloseThread(arg0 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "CreateNewTreeRequest") - ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) + ret := m.ctrl.Call(m, "CloseThread", arg0) + ret0, _ := ret[0].(error) return ret0 } -// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest. -func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call { +// CloseThread indicates an expected call of CloseThread. +func (mr *MockObjectSyncMockRecorder) CloseThread(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseThread", reflect.TypeOf((*MockObjectSync)(nil).CloseThread), arg0) } -// MessagePool mocks base method. -func (m *MockSyncClient) MessagePool() objectsync.MessagePool { +// HandleMessage mocks base method. +func (m *MockObjectSync) HandleMessage(arg0 context.Context, arg1 objectsync.HandleMessage) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MessagePool") - ret0, _ := ret[0].(objectsync.MessagePool) + ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1) + ret0, _ := ret[0].(error) return ret0 } -// MessagePool indicates an expected call of MessagePool. -func (mr *MockSyncClientMockRecorder) MessagePool() *gomock.Call { +// HandleMessage indicates an expected call of HandleMessage. +func (mr *MockObjectSyncMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MessagePool", reflect.TypeOf((*MockSyncClient)(nil).MessagePool)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockObjectSync)(nil).HandleMessage), arg0, arg1) } -// SendSync mocks base method. -func (m *MockSyncClient) SendSync(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { +// HandleRequest mocks base method. +func (m *MockObjectSync) HandleRequest(arg0 context.Context, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendSync", arg0, arg1, arg2, arg3) + ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1) ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) ret1, _ := ret[1].(error) return ret0, ret1 } -// SendSync indicates an expected call of SendSync. -func (mr *MockSyncClientMockRecorder) SendSync(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { +// HandleRequest indicates an expected call of HandleRequest. +func (mr *MockObjectSyncMockRecorder) HandleRequest(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSync", reflect.TypeOf((*MockSyncClient)(nil).SendSync), arg0, arg1, arg2, arg3) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockObjectSync)(nil).HandleRequest), arg0, arg1) } -// SendWithReply mocks base method. -func (m *MockSyncClient) SendWithReply(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage, arg4 string) error { +// Init mocks base method. +func (m *MockObjectSync) Init(arg0 *app.App) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "SendWithReply", arg0, arg1, arg2, arg3, arg4) + ret := m.ctrl.Call(m, "Init", arg0) ret0, _ := ret[0].(error) return ret0 } -// SendWithReply indicates an expected call of SendWithReply. -func (mr *MockSyncClientMockRecorder) SendWithReply(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { +// Init indicates an expected call of Init. +func (mr *MockObjectSyncMockRecorder) Init(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWithReply", reflect.TypeOf((*MockSyncClient)(nil).SendWithReply), arg0, arg1, arg2, arg3, arg4) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectSync)(nil).Init), arg0) +} + +// LastUsage mocks base method. +func (m *MockObjectSync) LastUsage() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastUsage") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// LastUsage indicates an expected call of LastUsage. +func (mr *MockObjectSyncMockRecorder) LastUsage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockObjectSync)(nil).LastUsage)) +} + +// Name mocks base method. +func (m *MockObjectSync) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockObjectSyncMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectSync)(nil).Name)) +} + +// Run mocks base method. +func (m *MockObjectSync) Run(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Run indicates an expected call of Run. +func (mr *MockObjectSyncMockRecorder) Run(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockObjectSync)(nil).Run), arg0) } diff --git a/commonspace/objectsync/msgpool.go b/commonspace/objectsync/msgpool.go deleted file mode 100644 index 19098ace..00000000 --- a/commonspace/objectsync/msgpool.go +++ /dev/null @@ -1,142 +0,0 @@ -package objectsync - -import ( - "context" - "fmt" - "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" - "github.com/anyproto/any-sync/commonspace/peermanager" - "github.com/anyproto/any-sync/commonspace/spacesyncproto" - "go.uber.org/zap" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" -) - -type LastUsage interface { - LastUsage() time.Time -} - -// MessagePool can be made generic to work with different streams -type MessagePool interface { - LastUsage - synchandler.SyncHandler - peermanager.PeerManager - SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) -} - -type MessageHandler func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) - -type responseWaiter struct { - ch chan *spacesyncproto.ObjectSyncMessage -} - -type messagePool struct { - sync.Mutex - peermanager.PeerManager - messageHandler MessageHandler - waiters map[string]responseWaiter - waitersMx sync.Mutex - counter atomic.Uint64 - lastUsage atomic.Int64 -} - -func newMessagePool(peerManager peermanager.PeerManager, messageHandler MessageHandler) MessagePool { - s := &messagePool{ - PeerManager: peerManager, - messageHandler: messageHandler, - waiters: make(map[string]responseWaiter), - } - return s -} - -func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { - s.updateLastUsage() - if _, ok := ctx.Deadline(); !ok { - var cancel context.CancelFunc - ctx, cancel = context.WithTimeout(ctx, time.Minute) - defer cancel() - } - newCounter := s.counter.Add(1) - msg.RequestId = genReplyKey(peerId, msg.ObjectId, newCounter) - log.InfoCtx(ctx, "mpool sendSync", zap.String("requestId", msg.RequestId)) - s.waitersMx.Lock() - waiter := responseWaiter{ - ch: make(chan *spacesyncproto.ObjectSyncMessage, 1), - } - s.waiters[msg.RequestId] = waiter - s.waitersMx.Unlock() - - err = s.SendPeer(ctx, peerId, msg) - if err != nil { - return - } - select { - case <-ctx.Done(): - s.waitersMx.Lock() - delete(s.waiters, msg.RequestId) - s.waitersMx.Unlock() - - log.With(zap.String("requestId", msg.RequestId)).DebugCtx(ctx, "time elapsed when waiting") - err = fmt.Errorf("sendSync context error: %v", ctx.Err()) - case reply = <-waiter.ch: - // success - } - return -} - -func (s *messagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { - s.updateLastUsage() - return s.PeerManager.SendPeer(ctx, peerId, msg) -} - -func (s *messagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) { - s.updateLastUsage() - return s.PeerManager.Broadcast(ctx, msg) -} - -func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { - s.updateLastUsage() - if msg.ReplyId != "" { - log.InfoCtx(ctx, "mpool receive reply", zap.String("replyId", msg.ReplyId)) - // we got reply, send it to waiter - if s.stopWaiter(msg) { - return - } - log.WarnCtx(ctx, "reply id does not exist", zap.String("replyId", msg.ReplyId)) - return - } - return s.messageHandler(ctx, senderId, msg) -} - -func (s *messagePool) LastUsage() time.Time { - return time.Unix(s.lastUsage.Load(), 0) -} - -func (s *messagePool) updateLastUsage() { - s.lastUsage.Store(time.Now().Unix()) -} - -func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool { - s.waitersMx.Lock() - waiter, exists := s.waiters[msg.ReplyId] - if exists { - delete(s.waiters, msg.ReplyId) - s.waitersMx.Unlock() - waiter.ch <- msg - return true - } - s.waitersMx.Unlock() - return false -} - -func genReplyKey(peerId, treeId string, counter uint64) string { - b := &strings.Builder{} - b.WriteString(peerId) - b.WriteString(".") - b.WriteString(treeId) - b.WriteString(".") - b.WriteString(strconv.FormatUint(counter, 36)) - return b.String() -} diff --git a/commonspace/objectsync/objectsync.go b/commonspace/objectsync/objectsync.go index e56ee6ed..0ec7f344 100644 --- a/commonspace/objectsync/objectsync.go +++ b/commonspace/objectsync/objectsync.go @@ -1,18 +1,23 @@ -//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync SyncClient +//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync ObjectSync package objectsync import ( "context" "fmt" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/spacestate" + "github.com/anyproto/any-sync/metric" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "github.com/gogo/protobuf/proto" "sync/atomic" "time" "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" - "github.com/anyproto/any-sync/commonspace/objectsync/synchandler" - "github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/nodeconf" @@ -20,138 +25,211 @@ import ( "golang.org/x/exp/slices" ) -var log = logger.NewNamed("common.commonspace.objectsync") +const CName = "common.commonspace.objectsync" + +var log = logger.NewNamed(CName) type ObjectSync interface { - LastUsage - synchandler.SyncHandler - SyncClient() SyncClient + LastUsage() time.Time + HandleMessage(ctx context.Context, hm HandleMessage) (err error) + HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) + CloseThread(id string) (err error) + app.ComponentRunnable +} - Close() (err error) +type HandleMessage struct { + Id uint64 + ReceiveTime time.Time + StartHandlingTime time.Time + Deadline time.Time + SenderId string + Message *spacesyncproto.ObjectSyncMessage + PeerCtx context.Context +} + +func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field { + return append(fields, + metric.SpaceId(m.Message.SpaceId), + metric.ObjectId(m.Message.ObjectId), + metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)), + metric.TotalDur(time.Since(m.ReceiveTime)), + ) } type objectSync struct { spaceId string - messagePool MessagePool - syncClient SyncClient objectGetter syncobjectgetter.SyncObjectGetter configuration nodeconf.NodeConf spaceStorage spacestorage.SpaceStorage + metric metric.Metric - syncCtx context.Context - cancelSync context.CancelFunc spaceIsDeleted *atomic.Bool + handleQueue multiqueue.MultiQueue[HandleMessage] } -func NewObjectSync( - spaceId string, - spaceIsDeleted *atomic.Bool, - configuration nodeconf.NodeConf, - peerManager peermanager.PeerManager, - objectGetter syncobjectgetter.SyncObjectGetter, - storage spacestorage.SpaceStorage) ObjectSync { - syncCtx, cancel := context.WithCancel(context.Background()) - os := &objectSync{ - objectGetter: objectGetter, - spaceStorage: storage, - spaceId: spaceId, - syncCtx: syncCtx, - cancelSync: cancel, - spaceIsDeleted: spaceIsDeleted, - configuration: configuration, +func (s *objectSync) Init(a *app.App) (err error) { + s.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + s.objectGetter = a.MustComponent(treemanager.CName).(syncobjectgetter.SyncObjectGetter) + s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + sharedData := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + mc := a.Component(metric.CName) + if mc != nil { + s.metric = mc.(metric.Metric) } - os.messagePool = newMessagePool(peerManager, os.handleMessage) - os.syncClient = NewSyncClient(spaceId, os.messagePool, NewRequestFactory()) - return os + s.spaceIsDeleted = sharedData.SpaceIsDeleted + s.spaceId = sharedData.SpaceId + s.handleQueue = multiqueue.New[HandleMessage](s.processHandleMessage, 100) + return nil } -func (s *objectSync) Close() (err error) { - s.cancelSync() - return +func (s *objectSync) Name() (name string) { + return CName +} + +func (s *objectSync) Run(ctx context.Context) (err error) { + return nil +} + +func (s *objectSync) Close(ctx context.Context) (err error) { + return s.handleQueue.Close() +} + +func New() ObjectSync { + return &objectSync{} } func (s *objectSync) LastUsage() time.Time { - return s.messagePool.LastUsage() + // TODO: add time + return time.Time{} } -func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { - return s.messagePool.HandleMessage(ctx, senderId, message) +func (s *objectSync) HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) { + peerId, err := peer.CtxPeerId(ctx) + if err != nil { + return nil, err + } + return s.handleRequest(ctx, peerId, req) } -func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { - log := log.With( - zap.String("objectId", msg.ObjectId), - zap.String("requestId", msg.RequestId), - zap.String("replyId", msg.ReplyId)) +func (s *objectSync) HandleMessage(ctx context.Context, hm HandleMessage) (err error) { + threadId := hm.Message.ObjectId + hm.ReceiveTime = time.Now() + if hm.PeerCtx == nil { + hm.PeerCtx = ctx + } + err = s.handleQueue.Add(ctx, threadId, hm) + if err == mb.ErrOverflowed { + log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.spaceId), zap.String("objectId", threadId)) + // skip overflowed error + return nil + } + return +} + +func (s *objectSync) processHandleMessage(msg HandleMessage) { + var err error + msg.StartHandlingTime = time.Now() + ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId) + ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId)) + defer func() { + if s.metric == nil { + return + } + s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields( + zap.Error(err), + )...) + }() + + if !msg.Deadline.IsZero() { + now := time.Now() + if now.After(msg.Deadline) { + log.InfoCtx(ctx, "skip message: deadline exceed") + err = context.DeadlineExceeded + return + } + var cancel context.CancelFunc + ctx, cancel = context.WithDeadline(ctx, msg.Deadline) + defer cancel() + } + if err = s.handleMessage(ctx, msg.SenderId, msg.Message); err != nil { + if msg.Message.ObjectId != "" { + // cleanup thread on error + _ = s.handleQueue.CloseThread(msg.Message.ObjectId) + } + log.InfoCtx(ctx, "handleMessage error", zap.Error(err)) + } +} + +func (s *objectSync) handleRequest(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) { + log := log.With(zap.String("objectId", msg.ObjectId)) if s.spaceIsDeleted.Load() { log = log.With(zap.Bool("isDeleted", true)) // preventing sync with other clients if they are not just syncing the settings tree if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() { - s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) - return fmt.Errorf("can't perform operation with object, space is deleted") + return nil, spacesyncproto.ErrSpaceIsDeleted } } - log.DebugCtx(ctx, "handling message") + err = s.checkEmptyFullSync(log, msg) + if err != nil { + return nil, err + } + obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId) + if err != nil { + return nil, treechangeproto.ErrGetTree + } + return obj.HandleRequest(ctx, senderId, msg) +} + +func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { + log := log.With(zap.String("objectId", msg.ObjectId)) + if s.spaceIsDeleted.Load() { + log = log.With(zap.Bool("isDeleted", true)) + // preventing sync with other clients if they are not just syncing the settings tree + if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() { + return spacesyncproto.ErrSpaceIsDeleted + } + } + err = s.checkEmptyFullSync(log, msg) + if err != nil { + return err + } + obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId) + if err != nil { + return fmt.Errorf("failed to get object from cache: %w", err) + } + err = obj.HandleMessage(ctx, senderId, msg) + if err != nil { + return fmt.Errorf("failed to handle message: %w", err) + } + return +} + +func (s *objectSync) CloseThread(id string) (err error) { + return s.handleQueue.CloseThread(id) +} + +func (s *objectSync) checkEmptyFullSync(log logger.CtxLogger, msg *spacesyncproto.ObjectSyncMessage) (err error) { hasTree, err := s.spaceStorage.HasTree(msg.ObjectId) if err != nil { - s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) - return fmt.Errorf("falied to execute get operation on storage has tree: %w", err) + log.Warn("failed to execute get operation on storage has tree", zap.Error(err)) + return spacesyncproto.ErrUnexpected } // in this case we will try to get it from remote, unless the sender also sent us the same request :-) if !hasTree { treeMsg := &treechangeproto.TreeSyncMessage{} err = proto.Unmarshal(msg.Payload, treeMsg) if err != nil { - s.sendError(ctx, nil, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId, msg.RequestId) - return fmt.Errorf("failed to unmarshall tree sync message: %w", err) + return nil } // this means that we don't have the tree locally and therefore can't return it if s.isEmptyFullSyncRequest(treeMsg) { - err = treechangeproto.ErrGetTree - s.sendError(ctx, nil, treechangeproto.ErrGetTree, senderId, msg.ObjectId, msg.RequestId) - return fmt.Errorf("failed to get tree from storage on full sync: %w", err) + return treechangeproto.ErrGetTree } } - obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId) - if err != nil { - // TODO: write tests for object sync https://linear.app/anytype/issue/GO-1299/write-tests-for-commonspaceobjectsync - s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) - return fmt.Errorf("failed to get object from cache: %w", err) - } - // TODO: unmarshall earlier - err = obj.HandleMessage(ctx, senderId, msg) - if err != nil { - s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) - return fmt.Errorf("failed to handle message: %w", err) - } return } -func (s *objectSync) SyncClient() SyncClient { - return s.syncClient -} - -func (s *objectSync) unmarshallSendError(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage, respErr error, senderId, objectId string) { - unmarshalled := &treechangeproto.TreeSyncMessage{} - err := proto.Unmarshal(msg.Payload, unmarshalled) - if err != nil { - return - } - s.sendError(ctx, unmarshalled.RootChange, respErr, senderId, objectId, msg.RequestId) -} - -func (s *objectSync) sendError(ctx context.Context, root *treechangeproto.RawTreeChangeWithId, respErr error, senderId, objectId, replyId string) { - // we don't send errors if have no reply id, this can lead to bugs and also nobody needs this error - if replyId == "" { - return - } - resp := treechangeproto.WrapError(respErr, root) - if err := s.syncClient.SendWithReply(ctx, senderId, objectId, resp, replyId); err != nil { - log.InfoCtx(ctx, "failed to send error to client") - } -} - func (s *objectSync) isEmptyFullSyncRequest(msg *treechangeproto.TreeSyncMessage) bool { return msg.GetContent().GetFullSyncRequest() != nil && len(msg.GetContent().GetFullSyncRequest().GetHeads()) == 0 } diff --git a/commonspace/objectsync/syncclient.go b/commonspace/objectsync/syncclient.go deleted file mode 100644 index aad712bd..00000000 --- a/commonspace/objectsync/syncclient.go +++ /dev/null @@ -1,78 +0,0 @@ -package objectsync - -import ( - "context" - "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" - "github.com/anyproto/any-sync/commonspace/spacesyncproto" - "go.uber.org/zap" -) - -type SyncClient interface { - RequestFactory - Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage) - SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error) - SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) - MessagePool() MessagePool -} - -type syncClient struct { - RequestFactory - spaceId string - messagePool MessagePool -} - -func NewSyncClient( - spaceId string, - messagePool MessagePool, - factory RequestFactory) SyncClient { - return &syncClient{ - messagePool: messagePool, - RequestFactory: factory, - spaceId: spaceId, - } -} - -func (s *syncClient) Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage) { - objMsg, err := MarshallTreeMessage(msg, s.spaceId, msg.RootChange.Id, "") - if err != nil { - return - } - err = s.messagePool.Broadcast(ctx, objMsg) - if err != nil { - log.DebugCtx(ctx, "broadcast error", zap.Error(err)) - } -} - -func (s *syncClient) SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { - objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "") - if err != nil { - return - } - return s.messagePool.SendSync(ctx, peerId, objMsg) -} - -func (s *syncClient) SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error) { - objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, replyId) - if err != nil { - return - } - return s.messagePool.SendPeer(ctx, peerId, objMsg) -} - -func (s *syncClient) MessagePool() MessagePool { - return s.messagePool -} - -func MarshallTreeMessage(message *treechangeproto.TreeSyncMessage, spaceId, objectId, replyId string) (objMsg *spacesyncproto.ObjectSyncMessage, err error) { - payload, err := message.Marshal() - if err != nil { - return - } - objMsg = &spacesyncproto.ObjectSyncMessage{ - ReplyId: replyId, - Payload: payload, - ObjectId: objectId, - SpaceId: spaceId, - } - return -} diff --git a/commonspace/objectsync/synchandler/synchhandler.go b/commonspace/objectsync/synchandler/synchhandler.go index 35aebb1e..090118cd 100644 --- a/commonspace/objectsync/synchandler/synchhandler.go +++ b/commonspace/objectsync/synchandler/synchhandler.go @@ -6,5 +6,6 @@ import ( ) type SyncHandler interface { - HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) + HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) + HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) } diff --git a/commonspace/objecttreebuilder/mock_objecttreebuilder/mock_objecttreebuilder.go b/commonspace/objecttreebuilder/mock_objecttreebuilder/mock_objecttreebuilder.go new file mode 100644 index 00000000..d7cca965 --- /dev/null +++ b/commonspace/objecttreebuilder/mock_objecttreebuilder/mock_objecttreebuilder.go @@ -0,0 +1,99 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/commonspace/objecttreebuilder (interfaces: TreeBuilder) + +// Package mock_objecttreebuilder is a generated GoMock package. +package mock_objecttreebuilder + +import ( + context "context" + reflect "reflect" + + objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + updatelistener "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" + treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" + objecttreebuilder "github.com/anyproto/any-sync/commonspace/objecttreebuilder" + gomock "github.com/golang/mock/gomock" +) + +// MockTreeBuilder is a mock of TreeBuilder interface. +type MockTreeBuilder struct { + ctrl *gomock.Controller + recorder *MockTreeBuilderMockRecorder +} + +// MockTreeBuilderMockRecorder is the mock recorder for MockTreeBuilder. +type MockTreeBuilderMockRecorder struct { + mock *MockTreeBuilder +} + +// NewMockTreeBuilder creates a new mock instance. +func NewMockTreeBuilder(ctrl *gomock.Controller) *MockTreeBuilder { + mock := &MockTreeBuilder{ctrl: ctrl} + mock.recorder = &MockTreeBuilderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTreeBuilder) EXPECT() *MockTreeBuilderMockRecorder { + return m.recorder +} + +// BuildHistoryTree mocks base method. +func (m *MockTreeBuilder) BuildHistoryTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.HistoryTreeOpts) (objecttree.HistoryTree, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BuildHistoryTree", arg0, arg1, arg2) + ret0, _ := ret[0].(objecttree.HistoryTree) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BuildHistoryTree indicates an expected call of BuildHistoryTree. +func (mr *MockTreeBuilderMockRecorder) BuildHistoryTree(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildHistoryTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildHistoryTree), arg0, arg1, arg2) +} + +// BuildTree mocks base method. +func (m *MockTreeBuilder) BuildTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.BuildTreeOpts) (objecttree.ObjectTree, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "BuildTree", arg0, arg1, arg2) + ret0, _ := ret[0].(objecttree.ObjectTree) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// BuildTree indicates an expected call of BuildTree. +func (mr *MockTreeBuilderMockRecorder) BuildTree(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildTree), arg0, arg1, arg2) +} + +// CreateTree mocks base method. +func (m *MockTreeBuilder) CreateTree(arg0 context.Context, arg1 objecttree.ObjectTreeCreatePayload) (treestorage.TreeStorageCreatePayload, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CreateTree", arg0, arg1) + ret0, _ := ret[0].(treestorage.TreeStorageCreatePayload) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CreateTree indicates an expected call of CreateTree. +func (mr *MockTreeBuilderMockRecorder) CreateTree(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTree", reflect.TypeOf((*MockTreeBuilder)(nil).CreateTree), arg0, arg1) +} + +// PutTree mocks base method. +func (m *MockTreeBuilder) PutTree(arg0 context.Context, arg1 treestorage.TreeStorageCreatePayload, arg2 updatelistener.UpdateListener) (objecttree.ObjectTree, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "PutTree", arg0, arg1, arg2) + ret0, _ := ret[0].(objecttree.ObjectTree) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// PutTree indicates an expected call of PutTree. +func (mr *MockTreeBuilderMockRecorder) PutTree(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutTree", reflect.TypeOf((*MockTreeBuilder)(nil).PutTree), arg0, arg1, arg2) +} diff --git a/commonspace/objecttreebuilder/treebuilder.go b/commonspace/objecttreebuilder/treebuilder.go new file mode 100644 index 00000000..058efe16 --- /dev/null +++ b/commonspace/objecttreebuilder/treebuilder.go @@ -0,0 +1,205 @@ +//go:generate mockgen -destination mock_objecttreebuilder/mock_objecttreebuilder.go github.com/anyproto/any-sync/commonspace/objecttreebuilder TreeBuilder +package objecttreebuilder + +import ( + "context" + "errors" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/headsync" + "github.com/anyproto/any-sync/commonspace/object/acl/list" + "github.com/anyproto/any-sync/commonspace/object/acl/syncacl" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" + "github.com/anyproto/any-sync/commonspace/objectsync" + "github.com/anyproto/any-sync/commonspace/peermanager" + "github.com/anyproto/any-sync/commonspace/requestmanager" + "github.com/anyproto/any-sync/commonspace/spacestate" + "github.com/anyproto/any-sync/commonspace/spacestorage" + "github.com/anyproto/any-sync/commonspace/syncstatus" + "github.com/anyproto/any-sync/nodeconf" + "go.uber.org/zap" + "sync/atomic" +) + +type BuildTreeOpts struct { + Listener updatelistener.UpdateListener + WaitTreeRemoteSync bool + TreeBuilder objecttree.BuildObjectTreeFunc +} + +const CName = "common.commonspace.objecttreebuilder" + +var log = logger.NewNamed(CName) +var ErrSpaceClosed = errors.New("space is closed") + +type HistoryTreeOpts struct { + BeforeId string + Include bool + BuildFullTree bool +} + +type TreeBuilder interface { + BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error) + BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error) + CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) + PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error) +} + +type TreeBuilderComponent interface { + app.Component + TreeBuilder +} + +func New() TreeBuilderComponent { + return &treeBuilder{} +} + +type treeBuilder struct { + syncClient synctree.SyncClient + configuration nodeconf.NodeConf + headsNotifiable synctree.HeadNotifiable + peerManager peermanager.PeerManager + requestManager requestmanager.RequestManager + spaceStorage spacestorage.SpaceStorage + syncStatus syncstatus.StatusUpdater + objectSync objectsync.ObjectSync + + log logger.CtxLogger + builder objecttree.BuildObjectTreeFunc + spaceId string + aclList list.AclList + treesUsed *atomic.Int32 + isClosed *atomic.Bool +} + +func (t *treeBuilder) Init(a *app.App) (err error) { + state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + t.spaceId = state.SpaceId + t.isClosed = state.SpaceIsClosed + t.treesUsed = state.TreesUsed + t.builder = state.TreeBuilderFunc + t.aclList = a.MustComponent(syncacl.CName).(*syncacl.SyncAcl) + t.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + t.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + t.headsNotifiable = a.MustComponent(headsync.CName).(headsync.HeadSync) + t.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusUpdater) + t.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager) + t.requestManager = a.MustComponent(requestmanager.CName).(requestmanager.RequestManager) + t.objectSync = a.MustComponent(objectsync.CName).(objectsync.ObjectSync) + t.log = log.With(zap.String("spaceId", t.spaceId)) + t.syncClient = synctree.NewSyncClient(t.spaceId, t.requestManager, t.peerManager) + return nil +} + +func (t *treeBuilder) Name() (name string) { + return CName +} + +func (t *treeBuilder) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (ot objecttree.ObjectTree, err error) { + if t.isClosed.Load() { + // TODO: change to real error + err = ErrSpaceClosed + return + } + treeBuilder := opts.TreeBuilder + if treeBuilder == nil { + treeBuilder = t.builder + } + deps := synctree.BuildDeps{ + SpaceId: t.spaceId, + SyncClient: t.syncClient, + Configuration: t.configuration, + HeadNotifiable: t.headsNotifiable, + Listener: opts.Listener, + AclList: t.aclList, + SpaceStorage: t.spaceStorage, + OnClose: t.onClose, + SyncStatus: t.syncStatus, + WaitTreeRemoteSync: opts.WaitTreeRemoteSync, + PeerGetter: t.peerManager, + BuildObjectTree: treeBuilder, + } + t.treesUsed.Add(1) + t.log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load())) + if ot, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil { + t.treesUsed.Add(-1) + t.log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.Error(err)) + return nil, err + } + return +} + +func (t *treeBuilder) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (ot objecttree.HistoryTree, err error) { + if t.isClosed.Load() { + // TODO: change to real error + err = ErrSpaceClosed + return + } + + params := objecttree.HistoryTreeParams{ + AclList: t.aclList, + BeforeId: opts.BeforeId, + IncludeBeforeId: opts.Include, + BuildFullTree: opts.BuildFullTree, + } + params.TreeStorage, err = t.spaceStorage.TreeStorage(id) + if err != nil { + return + } + return objecttree.BuildHistoryTree(params) +} + +func (t *treeBuilder) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) { + if t.isClosed.Load() { + err = ErrSpaceClosed + return + } + root, err := objecttree.CreateObjectTreeRoot(payload, t.aclList) + if err != nil { + return + } + + res = treestorage.TreeStorageCreatePayload{ + RootRawChange: root, + Changes: []*treechangeproto.RawTreeChangeWithId{root}, + Heads: []string{root.Id}, + } + return +} + +func (t *treeBuilder) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (ot objecttree.ObjectTree, err error) { + if t.isClosed.Load() { + err = ErrSpaceClosed + return + } + deps := synctree.BuildDeps{ + SpaceId: t.spaceId, + SyncClient: t.syncClient, + Configuration: t.configuration, + HeadNotifiable: t.headsNotifiable, + Listener: listener, + AclList: t.aclList, + SpaceStorage: t.spaceStorage, + OnClose: t.onClose, + SyncStatus: t.syncStatus, + PeerGetter: t.peerManager, + BuildObjectTree: t.builder, + } + ot, err = synctree.PutSyncTree(ctx, payload, deps) + if err != nil { + return + } + t.treesUsed.Add(1) + t.log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", t.treesUsed.Load())) + return +} + +func (t *treeBuilder) onClose(id string) { + t.treesUsed.Add(-1) + log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.String("spaceId", t.spaceId)) + _ = t.objectSync.CloseThread(id) +} diff --git a/commonspace/payloads.go b/commonspace/payloads.go index 5d2284d0..8d1c334c 100644 --- a/commonspace/payloads.go +++ b/commonspace/payloads.go @@ -214,14 +214,15 @@ func validateSpaceStorageCreatePayload(payload spacestorage.SpaceStorageCreatePa } func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, identity crypto.PubKey) (err error) { + if rawHeaderWithId == nil { + return spacestorage.ErrIncorrectSpaceHeader + } sepIdx := strings.Index(rawHeaderWithId.Id, ".") if sepIdx == -1 { - err = spacestorage.ErrIncorrectSpaceHeader - return + return spacestorage.ErrIncorrectSpaceHeader } if !cidutil.VerifyCid(rawHeaderWithId.RawHeader, rawHeaderWithId.Id[:sepIdx]) { - err = objecttree.ErrIncorrectCid - return + return objecttree.ErrIncorrectCid } var rawSpaceHeader spacesyncproto.RawSpaceHeader err = proto.Unmarshal(rawHeaderWithId.RawHeader, &rawSpaceHeader) @@ -239,19 +240,16 @@ func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, i } res, err := payloadIdentity.Verify(rawSpaceHeader.SpaceHeader, rawSpaceHeader.Signature) if err != nil || !res { - err = spacestorage.ErrIncorrectSpaceHeader - return + return spacestorage.ErrIncorrectSpaceHeader } if rawHeaderWithId.Id[sepIdx+1:] != strconv.FormatUint(header.ReplicationKey, 36) { - err = spacestorage.ErrIncorrectSpaceHeader - return + return spacestorage.ErrIncorrectSpaceHeader } if identity == nil { return } if !payloadIdentity.Equals(identity) { - err = ErrIncorrectIdentity - return + return ErrIncorrectIdentity } return } diff --git a/commonspace/peermanager/mock_peermanager/mock_peermanager.go b/commonspace/peermanager/mock_peermanager/mock_peermanager.go index a77e0ab1..210fb319 100644 --- a/commonspace/peermanager/mock_peermanager/mock_peermanager.go +++ b/commonspace/peermanager/mock_peermanager/mock_peermanager.go @@ -8,6 +8,7 @@ import ( context "context" reflect "reflect" + app "github.com/anyproto/any-sync/app" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" peer "github.com/anyproto/any-sync/net/peer" gomock "github.com/golang/mock/gomock" @@ -65,6 +66,34 @@ func (mr *MockPeerManagerMockRecorder) GetResponsiblePeers(arg0 interface{}) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockPeerManager)(nil).GetResponsiblePeers), arg0) } +// Init mocks base method. +func (m *MockPeerManager) Init(arg0 *app.App) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init. +func (mr *MockPeerManagerMockRecorder) Init(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockPeerManager)(nil).Init), arg0) +} + +// Name mocks base method. +func (m *MockPeerManager) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockPeerManagerMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockPeerManager)(nil).Name)) +} + // SendPeer mocks base method. func (m *MockPeerManager) SendPeer(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error { m.ctrl.T.Helper() diff --git a/commonspace/peermanager/peermanager.go b/commonspace/peermanager/peermanager.go index a4075f32..0a5750bf 100644 --- a/commonspace/peermanager/peermanager.go +++ b/commonspace/peermanager/peermanager.go @@ -8,9 +8,12 @@ import ( "github.com/anyproto/any-sync/net/peer" ) -const CName = "common.commonspace.peermanager" +const ( + CName = "common.commonspace.peermanager" +) type PeerManager interface { + app.Component // SendPeer sends a message to a stream by peerId SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) // Broadcast sends a message to all subscribed peers diff --git a/commonspace/requestmanager/requestmanager.go b/commonspace/requestmanager/requestmanager.go new file mode 100644 index 00000000..d72ea394 --- /dev/null +++ b/commonspace/requestmanager/requestmanager.go @@ -0,0 +1,126 @@ +package requestmanager + +import ( + "context" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/objectsync" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/pool" + "github.com/anyproto/any-sync/net/streampool" + "go.uber.org/zap" + "storj.io/drpc" + "sync" +) + +const CName = "common.commonspace.requestmanager" + +var log = logger.NewNamed(CName) + +type RequestManager interface { + app.ComponentRunnable + SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) + QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) +} + +func New() RequestManager { + return &requestManager{ + workers: 10, + queueSize: 300, + pools: map[string]*streampool.ExecPool{}, + } +} + +type MessageHandler interface { + HandleMessage(ctx context.Context, hm objectsync.HandleMessage) (err error) +} + +type requestManager struct { + sync.Mutex + pools map[string]*streampool.ExecPool + peerPool pool.Pool + workers int + queueSize int + handler MessageHandler + ctx context.Context + cancel context.CancelFunc + clientFactory spacesyncproto.ClientFactory +} + +func (r *requestManager) Init(a *app.App) (err error) { + r.ctx, r.cancel = context.WithCancel(context.Background()) + r.handler = a.MustComponent(objectsync.CName).(MessageHandler) + r.peerPool = a.MustComponent(pool.CName).(pool.Pool) + r.clientFactory = spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient) + return +} + +func (r *requestManager) Name() (name string) { + return CName +} + +func (r *requestManager) Run(ctx context.Context) (err error) { + return nil +} + +func (r *requestManager) Close(ctx context.Context) (err error) { + r.Lock() + defer r.Unlock() + r.cancel() + for _, p := range r.pools { + _ = p.Close() + } + return nil +} + +func (r *requestManager) SendRequest(ctx context.Context, peerId string, req *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) { + // TODO: limit concurrent sends? + return r.doRequest(ctx, peerId, req) +} + +func (r *requestManager) QueueRequest(peerId string, req *spacesyncproto.ObjectSyncMessage) (err error) { + r.Lock() + defer r.Unlock() + pl, exists := r.pools[peerId] + if !exists { + pl = streampool.NewExecPool(r.workers, r.queueSize) + r.pools[peerId] = pl + pl.Run() + } + // TODO: for later think when many clients are there, + // we need to close pools for inactive clients + return pl.TryAdd(func() { + doRequestAndHandle(r, peerId, req) + }) +} + +var doRequestAndHandle = (*requestManager).requestAndHandle + +func (r *requestManager) requestAndHandle(peerId string, req *spacesyncproto.ObjectSyncMessage) { + ctx := r.ctx + resp, err := r.doRequest(ctx, peerId, req) + if err != nil { + log.Warn("failed to send request", zap.Error(err)) + return + } + ctx = peer.CtxWithPeerId(ctx, peerId) + _ = r.handler.HandleMessage(ctx, objectsync.HandleMessage{ + SenderId: peerId, + Message: resp, + PeerCtx: ctx, + }) +} + +func (r *requestManager) doRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) { + pr, err := r.peerPool.Get(ctx, peerId) + if err != nil { + return + } + err = pr.DoDrpc(ctx, func(conn drpc.Conn) error { + cl := r.clientFactory.Client(conn) + resp, err = cl.ObjectSync(ctx, msg) + return err + }) + return +} diff --git a/commonspace/requestmanager/requestmanager_test.go b/commonspace/requestmanager/requestmanager_test.go new file mode 100644 index 00000000..35d497fc --- /dev/null +++ b/commonspace/requestmanager/requestmanager_test.go @@ -0,0 +1,189 @@ +package requestmanager + +import ( + "context" + "github.com/anyproto/any-sync/commonspace/objectsync" + "github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/peer/mock_peer" + "github.com/anyproto/any-sync/net/pool/mock_pool" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/require" + "storj.io/drpc" + "storj.io/drpc/drpcconn" + "sync" + "testing" + "time" +) + +type fixture struct { + requestManager *requestManager + messageHandlerMock *mock_objectsync.MockObjectSync + peerPoolMock *mock_pool.MockPool + clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient + ctrl *gomock.Controller +} + +func newFixture(t *testing.T) *fixture { + ctrl := gomock.NewController(t) + manager := New().(*requestManager) + peerPoolMock := mock_pool.NewMockPool(ctrl) + messageHandlerMock := mock_objectsync.NewMockObjectSync(ctrl) + clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) + manager.peerPool = peerPoolMock + manager.handler = messageHandlerMock + manager.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { + return clientMock + }) + manager.ctx, manager.cancel = context.WithCancel(context.Background()) + return &fixture{ + requestManager: manager, + messageHandlerMock: messageHandlerMock, + peerPoolMock: peerPoolMock, + clientMock: clientMock, + ctrl: ctrl, + } +} + +func (fx *fixture) stop() { + fx.ctrl.Finish() +} + +func TestRequestManager_SyncRequest(t *testing.T) { + ctx := context.Background() + + t.Run("send request", func(t *testing.T) { + fx := newFixture(t) + defer fx.stop() + + peerId := "peerId" + peerMock := mock_peer.NewMockPeer(fx.ctrl) + conn := &drpcconn.Conn{} + msg := &spacesyncproto.ObjectSyncMessage{} + resp := &spacesyncproto.ObjectSyncMessage{} + fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil) + fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil) + peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) { + drpcHandler(conn) + }).Return(nil) + res, err := fx.requestManager.SendRequest(ctx, peerId, msg) + require.NoError(t, err) + require.Equal(t, resp, res) + }) + + t.Run("request and handle", func(t *testing.T) { + fx := newFixture(t) + defer fx.stop() + ctx = fx.requestManager.ctx + + peerId := "peerId" + peerMock := mock_peer.NewMockPeer(fx.ctrl) + conn := &drpcconn.Conn{} + msg := &spacesyncproto.ObjectSyncMessage{} + resp := &spacesyncproto.ObjectSyncMessage{} + fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil) + fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil) + peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) { + drpcHandler(conn) + }).Return(nil) + fx.messageHandlerMock.EXPECT().HandleMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, msg objectsync.HandleMessage) { + require.Equal(t, peerId, msg.SenderId) + require.Equal(t, resp, msg.Message) + pId, _ := peer.CtxPeerId(msg.PeerCtx) + require.Equal(t, peerId, pId) + }).Return(nil) + fx.requestManager.requestAndHandle(peerId, msg) + }) +} + +func TestRequestManager_QueueRequest(t *testing.T) { + t.Run("max concurrent reqs for peer, independent reqs for other peer", func(t *testing.T) { + // testing 2 concurrent requests to one peer and simultaneous to another peer + fx := newFixture(t) + defer fx.stop() + fx.requestManager.workers = 2 + msgRelease := make(chan struct{}) + msgWait := make(chan struct{}) + msgs := sync.Map{} + doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) { + msgs.Store(req.ObjectId, struct{}{}) + <-msgWait + <-msgRelease + } + otherPeer := "otherPeer" + msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"} + msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"} + msg3 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id3"} + otherMsg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "otherId1"} + + // sending requests to first peer + peerId := "peerId" + err := fx.requestManager.QueueRequest(peerId, msg1) + require.NoError(t, err) + err = fx.requestManager.QueueRequest(peerId, msg2) + require.NoError(t, err) + err = fx.requestManager.QueueRequest(peerId, msg3) + require.NoError(t, err) + + // waiting until all the messages are loaded + msgWait <- struct{}{} + msgWait <- struct{}{} + _, ok := msgs.Load("id1") + require.True(t, ok) + _, ok = msgs.Load("id2") + require.True(t, ok) + // third message should not be read + _, ok = msgs.Load("id3") + require.False(t, ok) + + // request for other peer should pass + err = fx.requestManager.QueueRequest(otherPeer, otherMsg1) + require.NoError(t, err) + msgWait <- struct{}{} + + _, ok = msgs.Load("otherId1") + require.True(t, ok) + close(msgRelease) + }) + + t.Run("no requests after close", func(t *testing.T) { + fx := newFixture(t) + defer fx.stop() + fx.requestManager.workers = 1 + msgRelease := make(chan struct{}) + msgWait := make(chan struct{}) + msgs := sync.Map{} + doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) { + msgs.Store(req.ObjectId, struct{}{}) + <-msgWait + <-msgRelease + } + msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"} + msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"} + + // sending requests to first peer + peerId := "peerId" + err := fx.requestManager.QueueRequest(peerId, msg1) + require.NoError(t, err) + err = fx.requestManager.QueueRequest(peerId, msg2) + require.NoError(t, err) + + // waiting until all the message is loaded + msgWait <- struct{}{} + _, ok := msgs.Load("id1") + require.True(t, ok) + _, ok = msgs.Load("id2") + require.False(t, ok) + + fx.requestManager.Close(context.Background()) + close(msgRelease) + // waiting to know if the second one is not taken + // because the manager is now closed + time.Sleep(100 * time.Millisecond) + _, ok = msgs.Load("id2") + require.False(t, ok) + + }) +} diff --git a/commonspace/settings/deleter.go b/commonspace/settings/deleter.go index 322308de..5d3bd5a1 100644 --- a/commonspace/settings/deleter.go +++ b/commonspace/settings/deleter.go @@ -2,9 +2,9 @@ package settings import ( "context" + "github.com/anyproto/any-sync/commonspace/deletionstate" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/treemanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestorage" "go.uber.org/zap" ) @@ -15,11 +15,11 @@ type Deleter interface { type deleter struct { st spacestorage.SpaceStorage - state settingsstate.ObjectDeletionState + state deletionstate.ObjectDeletionState getter treemanager.TreeManager } -func newDeleter(st spacestorage.SpaceStorage, state settingsstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter { +func newDeleter(st spacestorage.SpaceStorage, state deletionstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter { return &deleter{st, state, getter} } diff --git a/commonspace/settings/deleter_test.go b/commonspace/settings/deleter_test.go index 54feed56..e4a32e84 100644 --- a/commonspace/settings/deleter_test.go +++ b/commonspace/settings/deleter_test.go @@ -2,9 +2,9 @@ package settings import ( "fmt" + "github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/golang/mock/gomock" "testing" @@ -14,7 +14,7 @@ func TestDeleter_Delete(t *testing.T) { ctrl := gomock.NewController(t) treeManager := mock_treemanager.NewMockTreeManager(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl) - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) + delState := mock_deletionstate.NewMockObjectDeletionState(ctrl) deleter := newDeleter(st, delState, treeManager) diff --git a/commonspace/settings/deletionmanager.go b/commonspace/settings/deletionmanager.go index 2d3d47ff..2611d28c 100644 --- a/commonspace/settings/deletionmanager.go +++ b/commonspace/settings/deletionmanager.go @@ -2,6 +2,7 @@ package settings import ( "context" + "github.com/anyproto/any-sync/commonspace/deletionstate" "github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/settings/settingsstate" "go.uber.org/zap" @@ -20,7 +21,7 @@ func newDeletionManager( settingsId string, isResponsible bool, treeManager treemanager.TreeManager, - deletionState settingsstate.ObjectDeletionState, + deletionState deletionstate.ObjectDeletionState, provider SpaceIdsProvider, onSpaceDelete func()) DeletionManager { return &deletionManager{ @@ -35,7 +36,7 @@ func newDeletionManager( } type deletionManager struct { - deletionState settingsstate.ObjectDeletionState + deletionState deletionstate.ObjectDeletionState provider SpaceIdsProvider treeManager treemanager.TreeManager spaceId string diff --git a/commonspace/settings/deletionmanager_test.go b/commonspace/settings/deletionmanager_test.go index 9e6b4f05..69e8830d 100644 --- a/commonspace/settings/deletionmanager_test.go +++ b/commonspace/settings/deletionmanager_test.go @@ -2,10 +2,10 @@ package settings import ( "context" + "github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" "github.com/anyproto/any-sync/commonspace/settings/mock_settings" "github.com/anyproto/any-sync/commonspace/settings/settingsstate" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" "testing" @@ -26,7 +26,7 @@ func TestDeletionManager_UpdateState_NotResponsible(t *testing.T) { onDeleted := func() { deleted = true } - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) + delState := mock_deletionstate.NewMockObjectDeletionState(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl) delState.EXPECT().Add(state.DeletedIds) @@ -58,7 +58,7 @@ func TestDeletionManager_UpdateState_Responsible(t *testing.T) { onDeleted := func() { deleted = true } - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) + delState := mock_deletionstate.NewMockObjectDeletionState(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl) provider := mock_settings.NewMockSpaceIdsProvider(ctrl) diff --git a/commonspace/settings/settings.go b/commonspace/settings/settings.go index 5c2b41cf..eab41709 100644 --- a/commonspace/settings/settings.go +++ b/commonspace/settings/settings.go @@ -1,328 +1,122 @@ -//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider package settings import ( "context" - "errors" - "fmt" - "github.com/anyproto/any-sync/util/crypto" - "github.com/anyproto/any-sync/accountservice" - "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/deletionstate" + "github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/treemanager" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate" + "github.com/anyproto/any-sync/commonspace/objecttreebuilder" + "github.com/anyproto/any-sync/commonspace/spacestate" "github.com/anyproto/any-sync/commonspace/spacestorage" - "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/nodeconf" - "github.com/gogo/protobuf/proto" "go.uber.org/zap" - "golang.org/x/exp/slices" + "sync/atomic" ) -var log = logger.NewNamed("common.commonspace.settings") +const CName = "common.commonspace.settings" -type SettingsObject interface { - synctree.SyncTree - Init(ctx context.Context) (err error) - DeleteObject(id string) (err error) - DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) - SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) +type Settings interface { + DeleteTree(ctx context.Context, id string) (err error) + SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) + DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) + SettingsObject() SettingsObject + app.ComponentRunnable } -var ( - ErrDeleteSelf = errors.New("cannot delete self") - ErrAlreadyDeleted = errors.New("the object is already deleted") - ErrObjDoesNotExist = errors.New("the object does not exist") - ErrCantDeleteSpace = errors.New("not able to delete space") -) - -var ( - DoSnapshot = objecttree.DoSnapshot - buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) { - return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{ - TreeStorage: objTree.Storage(), - AclList: objTree.AclList(), - BuildFullTree: true, - }) - } -) - -type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) - -type Deps struct { - BuildFunc BuildTreeFunc - Account accountservice.Service - TreeManager treemanager.TreeManager - Store spacestorage.SpaceStorage - Configuration nodeconf.NodeConf - DeletionState settingsstate.ObjectDeletionState - Provider SpaceIdsProvider - OnSpaceDelete func() - // testing dependencies - builder settingsstate.StateBuilder - del Deleter - delManager DeletionManager - changeFactory settingsstate.ChangeFactory +func New() Settings { + return &settings{} } -type settingsObject struct { - synctree.SyncTree - account accountservice.Service - spaceId string - treeManager treemanager.TreeManager - store spacestorage.SpaceStorage - builder settingsstate.StateBuilder - buildFunc BuildTreeFunc - loop *deleteLoop +type settings struct { + account accountservice.Service + treeManager treemanager.TreeManager + storage spacestorage.SpaceStorage + configuration nodeconf.NodeConf + deletionState deletionstate.ObjectDeletionState + headsync headsync.HeadSync + treeBuilder objecttreebuilder.TreeBuilderComponent + spaceIsDeleted *atomic.Bool - state *settingsstate.State - deletionState settingsstate.ObjectDeletionState - deletionManager DeletionManager - changeFactory settingsstate.ChangeFactory + settingsObject SettingsObject } -func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) { - var ( - deleter Deleter - deletionManager DeletionManager - builder settingsstate.StateBuilder - changeFactory settingsstate.ChangeFactory - ) - if deps.del == nil { - deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager) - } else { - deleter = deps.del - } - if deps.delManager == nil { - deletionManager = newDeletionManager( - spaceId, - deps.Store.SpaceSettingsId(), - deps.Configuration.IsResponsible(spaceId), - deps.TreeManager, - deps.DeletionState, - deps.Provider, - deps.OnSpaceDelete) - } else { - deletionManager = deps.delManager - } - if deps.builder == nil { - builder = settingsstate.NewStateBuilder() - } else { - builder = deps.builder - } - if deps.changeFactory == nil { - changeFactory = settingsstate.NewChangeFactory() - } else { - changeFactory = deps.changeFactory - } +func (s *settings) Init(a *app.App) (err error) { + s.account = a.MustComponent(accountservice.CName).(accountservice.Service) + s.treeManager = app.MustComponent[treemanager.TreeManager](a) + s.headsync = a.MustComponent(headsync.CName).(headsync.HeadSync) + s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + s.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState) + s.treeBuilder = a.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent) - loop := newDeleteLoop(func() { - deleter.Delete() - }) - deps.DeletionState.AddObserver(func(ids []string) { - loop.notify() - }) + sharedState := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + s.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + s.spaceIsDeleted = sharedState.SpaceIsDeleted - s := &settingsObject{ - loop: loop, - spaceId: spaceId, - account: deps.Account, - deletionState: deps.DeletionState, - treeManager: deps.TreeManager, - store: deps.Store, - buildFunc: deps.BuildFunc, - builder: builder, - deletionManager: deletionManager, - changeFactory: changeFactory, - } - obj = s - return -} - -func (s *settingsObject) updateIds(tr objecttree.ObjectTree) { - var err error - s.state, err = s.builder.Build(tr, s.state) - if err != nil { - log.Error("failed to build state", zap.Error(err)) - return - } - log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId)) - if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil { - log.Error("failed to update state", zap.Error(err)) - } -} - -// Update is called as part of UpdateListener interface -func (s *settingsObject) Update(tr objecttree.ObjectTree) { - s.updateIds(tr) -} - -// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call) -func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) { - // at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter - s.state = nil - s.updateIds(tr) -} - -func (s *settingsObject) Init(ctx context.Context) (err error) { - settingsId := s.store.SpaceSettingsId() - log.Debug("space settings id", zap.String("id", settingsId)) - s.SyncTree, err = s.buildFunc(ctx, settingsId, s) - if err != nil { - return - } - // TODO: remove this check when everybody updates - if err = s.checkHistoryState(ctx); err != nil { - return - } - s.loop.Run() - return -} - -func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) { - historyTree, err := buildHistoryTree(s.SyncTree) - if err != nil { - return - } - fullState, err := s.builder.Build(historyTree, nil) - if err != nil { - return - } - if len(fullState.DeletedIds) != len(s.state.DeletedIds) { - log.WarnCtx(ctx, "state does not have all deleted ids", - zap.Int("fullstate ids", len(fullState.DeletedIds)), - zap.Int("state ids", len(fullState.DeletedIds))) - s.state = fullState - err = s.deletionManager.UpdateState(context.Background(), s.state) - if err != nil { + deps := Deps{ + BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) { + res, err := s.treeBuilder.BuildTree(ctx, id, objecttreebuilder.BuildTreeOpts{ + Listener: listener, + WaitTreeRemoteSync: false, + // space settings document should not have empty data + TreeBuilder: objecttree.BuildObjectTree, + }) + log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", sharedState.SpaceId)) + if err != nil { + return + } + t = res.(synctree.SyncTree) return - } + }, + Account: s.account, + TreeManager: s.treeManager, + Store: s.storage, + Configuration: s.configuration, + DeletionState: s.deletionState, + Provider: s.headsync, + OnSpaceDelete: s.onSpaceDelete, } - return + s.settingsObject = NewSettingsObject(deps, sharedState.SpaceId) + return nil } -func (s *settingsObject) Close() error { - s.loop.Close() - return s.SyncTree.Close() +func (s *settings) Name() (name string) { + return CName } -func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) { - s.Lock() - defer s.Unlock() - defer func() { - log.Debug("finished adding delete change", zap.Error(err)) - }() - err = s.verifyDeleteSpace(raw) - if err != nil { - return - } - res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{ - NewHeads: []string{raw.Id}, - RawChanges: []*treechangeproto.RawTreeChangeWithId{raw}, - }) - if err != nil { - return - } - if !slices.Contains(res.Heads, raw.Id) { - err = ErrCantDeleteSpace - return - } - return +func (s *settings) Run(ctx context.Context) (err error) { + return s.settingsObject.Init(ctx) } -func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) { - accountData := s.account.Account() - data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false) - if err != nil { - return - } - return s.PrepareChange(objecttree.SignableChangeContent{ - Data: data, - Key: accountData.SignKey, - IsSnapshot: false, - IsEncrypted: false, - }) +func (s *settings) Close(ctx context.Context) (err error) { + return s.settingsObject.Close() } -func (s *settingsObject) DeleteObject(id string) (err error) { - s.Lock() - defer s.Unlock() - if s.Id() == id { - err = ErrDeleteSelf - return - } - if s.state.Exists(id) { - err = ErrAlreadyDeleted - return nil - } - _, err = s.store.TreeStorage(id) - if err != nil { - err = ErrObjDoesNotExist - return - } - isSnapshot := DoSnapshot(s.Len()) - res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot) - if err != nil { - return - } - - return s.addContent(res, isSnapshot) +func (s *settings) DeleteTree(ctx context.Context, id string) (err error) { + return s.settingsObject.DeleteObject(id) } -func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) { - data, err := s.UnpackChange(raw) - if err != nil { - return - } - return verifyDeleteContent(data, "") +func (s *settings) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) { + return s.settingsObject.SpaceDeleteRawChange() } -func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) { - accountData := s.account.Account() - res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{ - Data: data, - Key: accountData.SignKey, - IsSnapshot: isSnapshot, - IsEncrypted: false, - }) - if err != nil { - return - } - if res.Mode == objecttree.Rebuild { - s.Rebuild(s) - } else { - s.Update(s) - } - return +func (s *settings) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) { + return s.settingsObject.DeleteSpace(ctx, deleteChange) } -func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) { - changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil) - res, err := changeBuilder.Unmarshall(raw, true) +func (s *settings) onSpaceDelete() { + err := s.storage.SetSpaceDeleted() if err != nil { - return + log.Warn("failed to set space deleted") } - if !res.Identity.Equals(identity) { - return fmt.Errorf("incorrect identity") - } - return verifyDeleteContent(res.Data, peerId) + s.spaceIsDeleted.Swap(true) } -func verifyDeleteContent(data []byte, peerId string) (err error) { - content := &spacesyncproto.SettingsData{} - err = proto.Unmarshal(data, content) - if err != nil { - return - } - if len(content.GetContent()) != 1 || - content.GetContent()[0].GetSpaceDelete() == nil || - (peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") || - (peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) { - return fmt.Errorf("incorrect delete change payload") - } - return +func (s *settings) SettingsObject() SettingsObject { + return s.settingsObject } diff --git a/commonspace/settings/settingsobject.go b/commonspace/settings/settingsobject.go new file mode 100644 index 00000000..e6fd0b39 --- /dev/null +++ b/commonspace/settings/settingsobject.go @@ -0,0 +1,329 @@ +//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider +package settings + +import ( + "context" + "errors" + "fmt" + "github.com/anyproto/any-sync/commonspace/deletionstate" + "github.com/anyproto/any-sync/util/crypto" + + "github.com/anyproto/any-sync/accountservice" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree" + "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" + "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" + "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/settings/settingsstate" + "github.com/anyproto/any-sync/commonspace/spacestorage" + "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "github.com/anyproto/any-sync/nodeconf" + "github.com/gogo/protobuf/proto" + "go.uber.org/zap" + "golang.org/x/exp/slices" +) + +var log = logger.NewNamed("common.commonspace.settings") + +type SettingsObject interface { + synctree.SyncTree + Init(ctx context.Context) (err error) + DeleteObject(id string) (err error) + DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) + SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) +} + +var ( + ErrDeleteSelf = errors.New("cannot delete self") + ErrAlreadyDeleted = errors.New("the object is already deleted") + ErrObjDoesNotExist = errors.New("the object does not exist") + ErrCantDeleteSpace = errors.New("not able to delete space") +) + +var ( + DoSnapshot = objecttree.DoSnapshot + buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) { + return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{ + TreeStorage: objTree.Storage(), + AclList: objTree.AclList(), + BuildFullTree: true, + }) + } +) + +type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) + +type Deps struct { + BuildFunc BuildTreeFunc + Account accountservice.Service + TreeManager treemanager.TreeManager + Store spacestorage.SpaceStorage + Configuration nodeconf.NodeConf + DeletionState deletionstate.ObjectDeletionState + Provider SpaceIdsProvider + OnSpaceDelete func() + // testing dependencies + builder settingsstate.StateBuilder + del Deleter + delManager DeletionManager + changeFactory settingsstate.ChangeFactory +} + +type settingsObject struct { + synctree.SyncTree + account accountservice.Service + spaceId string + treeManager treemanager.TreeManager + store spacestorage.SpaceStorage + builder settingsstate.StateBuilder + buildFunc BuildTreeFunc + loop *deleteLoop + + state *settingsstate.State + deletionState deletionstate.ObjectDeletionState + deletionManager DeletionManager + changeFactory settingsstate.ChangeFactory +} + +func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) { + var ( + deleter Deleter + deletionManager DeletionManager + builder settingsstate.StateBuilder + changeFactory settingsstate.ChangeFactory + ) + if deps.del == nil { + deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager) + } else { + deleter = deps.del + } + if deps.delManager == nil { + deletionManager = newDeletionManager( + spaceId, + deps.Store.SpaceSettingsId(), + deps.Configuration.IsResponsible(spaceId), + deps.TreeManager, + deps.DeletionState, + deps.Provider, + deps.OnSpaceDelete) + } else { + deletionManager = deps.delManager + } + if deps.builder == nil { + builder = settingsstate.NewStateBuilder() + } else { + builder = deps.builder + } + if deps.changeFactory == nil { + changeFactory = settingsstate.NewChangeFactory() + } else { + changeFactory = deps.changeFactory + } + + loop := newDeleteLoop(func() { + deleter.Delete() + }) + deps.DeletionState.AddObserver(func(ids []string) { + loop.notify() + }) + + s := &settingsObject{ + loop: loop, + spaceId: spaceId, + account: deps.Account, + deletionState: deps.DeletionState, + treeManager: deps.TreeManager, + store: deps.Store, + buildFunc: deps.BuildFunc, + builder: builder, + deletionManager: deletionManager, + changeFactory: changeFactory, + } + obj = s + return +} + +func (s *settingsObject) updateIds(tr objecttree.ObjectTree) { + var err error + s.state, err = s.builder.Build(tr, s.state) + if err != nil { + log.Error("failed to build state", zap.Error(err)) + return + } + log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId)) + if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil { + log.Error("failed to update state", zap.Error(err)) + } +} + +// Update is called as part of UpdateListener interface +func (s *settingsObject) Update(tr objecttree.ObjectTree) { + s.updateIds(tr) +} + +// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call) +func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) { + // at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter + s.state = nil + s.updateIds(tr) +} + +func (s *settingsObject) Init(ctx context.Context) (err error) { + settingsId := s.store.SpaceSettingsId() + log.Debug("space settings id", zap.String("id", settingsId)) + s.SyncTree, err = s.buildFunc(ctx, settingsId, s) + if err != nil { + return + } + // TODO: remove this check when everybody updates + if err = s.checkHistoryState(ctx); err != nil { + return + } + s.loop.Run() + return +} + +func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) { + historyTree, err := buildHistoryTree(s.SyncTree) + if err != nil { + return + } + fullState, err := s.builder.Build(historyTree, nil) + if err != nil { + return + } + if len(fullState.DeletedIds) != len(s.state.DeletedIds) { + log.WarnCtx(ctx, "state does not have all deleted ids", + zap.Int("fullstate ids", len(fullState.DeletedIds)), + zap.Int("state ids", len(fullState.DeletedIds))) + s.state = fullState + err = s.deletionManager.UpdateState(context.Background(), s.state) + if err != nil { + return + } + } + return +} + +func (s *settingsObject) Close() error { + s.loop.Close() + return s.SyncTree.Close() +} + +func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) { + s.Lock() + defer s.Unlock() + defer func() { + log.Debug("finished adding delete change", zap.Error(err)) + }() + err = s.verifyDeleteSpace(raw) + if err != nil { + return + } + res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{ + NewHeads: []string{raw.Id}, + RawChanges: []*treechangeproto.RawTreeChangeWithId{raw}, + }) + if err != nil { + return + } + if !slices.Contains(res.Heads, raw.Id) { + err = ErrCantDeleteSpace + return + } + return +} + +func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) { + accountData := s.account.Account() + data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false) + if err != nil { + return + } + return s.PrepareChange(objecttree.SignableChangeContent{ + Data: data, + Key: accountData.SignKey, + IsSnapshot: false, + IsEncrypted: false, + }) +} + +func (s *settingsObject) DeleteObject(id string) (err error) { + s.Lock() + defer s.Unlock() + if s.Id() == id { + err = ErrDeleteSelf + return + } + if s.state.Exists(id) { + err = ErrAlreadyDeleted + return nil + } + _, err = s.store.TreeStorage(id) + if err != nil { + err = ErrObjDoesNotExist + return + } + isSnapshot := DoSnapshot(s.Len()) + res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot) + if err != nil { + return + } + + return s.addContent(res, isSnapshot) +} + +func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) { + data, err := s.UnpackChange(raw) + if err != nil { + return + } + return verifyDeleteContent(data, "") +} + +func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) { + accountData := s.account.Account() + res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{ + Data: data, + Key: accountData.SignKey, + IsSnapshot: isSnapshot, + IsEncrypted: false, + }) + if err != nil { + return + } + if res.Mode == objecttree.Rebuild { + s.Rebuild(s) + } else { + s.Update(s) + } + return +} + +func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) { + changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil) + res, err := changeBuilder.Unmarshall(raw, true) + if err != nil { + return + } + if !res.Identity.Equals(identity) { + return fmt.Errorf("incorrect identity") + } + return verifyDeleteContent(res.Data, peerId) +} + +func verifyDeleteContent(data []byte, peerId string) (err error) { + content := &spacesyncproto.SettingsData{} + err = proto.Unmarshal(data, content) + if err != nil { + return + } + if len(content.GetContent()) != 1 || + content.GetContent()[0].GetSpaceDelete() == nil || + (peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") || + (peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) { + return fmt.Errorf("incorrect delete change payload") + } + return +} diff --git a/commonspace/settings/settings_test.go b/commonspace/settings/settingsobject_test.go similarity index 97% rename from commonspace/settings/settings_test.go rename to commonspace/settings/settingsobject_test.go index 31956c81..9d83d9cd 100644 --- a/commonspace/settings/settings_test.go +++ b/commonspace/settings/settingsobject_test.go @@ -3,6 +3,7 @@ package settings import ( "context" "github.com/anyproto/any-sync/accountservice/mock_accountservice" + "github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate" "github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" @@ -54,7 +55,7 @@ type settingsFixture struct { deleter *mock_settings.MockDeleter syncTree *mock_synctree.MockSyncTree historyTree *mock_objecttree.MockObjectTree - delState *mock_settingsstate.MockObjectDeletionState + delState *mock_deletionstate.MockObjectDeletionState account *mock_accountservice.MockService } @@ -66,7 +67,7 @@ func newSettingsFixture(t *testing.T) *settingsFixture { acc := mock_accountservice.NewMockService(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl) - delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) + delState := mock_deletionstate.NewMockObjectDeletionState(ctrl) delManager := mock_settings.NewMockDeletionManager(ctrl) stateBuilder := mock_settingsstate.NewMockStateBuilder(ctrl) changeFactory := mock_settingsstate.NewMockChangeFactory(ctrl) diff --git a/commonspace/settings/settingsstate/mock_settingsstate/mock_settingsstate.go b/commonspace/settings/settingsstate/mock_settingsstate/mock_settingsstate.go index 0bc9bf23..2bb898cf 100644 --- a/commonspace/settings/settingsstate/mock_settingsstate/mock_settingsstate.go +++ b/commonspace/settings/settingsstate/mock_settingsstate/mock_settingsstate.go @@ -1,5 +1,5 @@ // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: ObjectDeletionState,StateBuilder,ChangeFactory) +// Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: StateBuilder,ChangeFactory) // Package mock_settingsstate is a generated GoMock package. package mock_settingsstate @@ -12,109 +12,6 @@ import ( gomock "github.com/golang/mock/gomock" ) -// MockObjectDeletionState is a mock of ObjectDeletionState interface. -type MockObjectDeletionState struct { - ctrl *gomock.Controller - recorder *MockObjectDeletionStateMockRecorder -} - -// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState. -type MockObjectDeletionStateMockRecorder struct { - mock *MockObjectDeletionState -} - -// NewMockObjectDeletionState creates a new mock instance. -func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState { - mock := &MockObjectDeletionState{ctrl: ctrl} - mock.recorder = &MockObjectDeletionStateMockRecorder{mock} - return mock -} - -// EXPECT returns an object that allows the caller to indicate expected use. -func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder { - return m.recorder -} - -// Add mocks base method. -func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "Add", arg0) -} - -// Add indicates an expected call of Add. -func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0) -} - -// AddObserver mocks base method. -func (m *MockObjectDeletionState) AddObserver(arg0 settingsstate.StateUpdateObserver) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddObserver", arg0) -} - -// AddObserver indicates an expected call of AddObserver. -func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0) -} - -// Delete mocks base method. -func (m *MockObjectDeletionState) Delete(arg0 string) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Delete", arg0) - ret0, _ := ret[0].(error) - return ret0 -} - -// Delete indicates an expected call of Delete. -func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0) -} - -// Exists mocks base method. -func (m *MockObjectDeletionState) Exists(arg0 string) bool { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Exists", arg0) - ret0, _ := ret[0].(bool) - return ret0 -} - -// Exists indicates an expected call of Exists. -func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0) -} - -// Filter mocks base method. -func (m *MockObjectDeletionState) Filter(arg0 []string) []string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Filter", arg0) - ret0, _ := ret[0].([]string) - return ret0 -} - -// Filter indicates an expected call of Filter. -func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0) -} - -// GetQueued mocks base method. -func (m *MockObjectDeletionState) GetQueued() []string { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetQueued") - ret0, _ := ret[0].([]string) - return ret0 -} - -// GetQueued indicates an expected call of GetQueued. -func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued)) -} - // MockStateBuilder is a mock of StateBuilder interface. type MockStateBuilder struct { ctrl *gomock.Controller diff --git a/commonspace/settings/settingsstate/settingsstate.go b/commonspace/settings/settingsstate/settingsstate.go index 0b62c8d1..20619cb2 100644 --- a/commonspace/settings/settingsstate/settingsstate.go +++ b/commonspace/settings/settingsstate/settingsstate.go @@ -1,3 +1,4 @@ +//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate StateBuilder,ChangeFactory package settingsstate import "github.com/anyproto/any-sync/commonspace/spacesyncproto" diff --git a/commonspace/space.go b/commonspace/space.go index 70bf6e0c..ef5119de 100644 --- a/commonspace/space.go +++ b/commonspace/space.go @@ -2,44 +2,26 @@ package commonspace import ( "context" - "errors" - "github.com/anyproto/any-sync/accountservice" - "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/syncacl" - "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" - "github.com/anyproto/any-sync/commonspace/object/tree/synctree" - "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" - "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/objectsync" - "github.com/anyproto/any-sync/commonspace/peermanager" + "github.com/anyproto/any-sync/commonspace/objecttreebuilder" "github.com/anyproto/any-sync/commonspace/settings" - "github.com/anyproto/any-sync/commonspace/settings/settingsstate" + "github.com/anyproto/any-sync/commonspace/spacestate" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" - "github.com/anyproto/any-sync/metric" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/util/crypto" - "github.com/anyproto/any-sync/util/multiqueue" - "github.com/anyproto/any-sync/util/slice" - "github.com/cheggaaa/mb/v3" - "github.com/zeebo/errs" "go.uber.org/zap" "strconv" "strings" "sync" - "sync/atomic" "time" ) -var ( - ErrSpaceClosed = errors.New("space is closed") -) - type SpaceCreatePayload struct { // SigningKey is the signing key of the owner SigningKey crypto.PrivKey @@ -55,25 +37,6 @@ type SpaceCreatePayload struct { MasterKey crypto.PrivKey } -type HandleMessage struct { - Id uint64 - ReceiveTime time.Time - StartHandlingTime time.Time - Deadline time.Time - SenderId string - Message *spacesyncproto.ObjectSyncMessage - PeerCtx context.Context -} - -func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field { - return append(fields, - metric.SpaceId(m.Message.SpaceId), - metric.ObjectId(m.Message.ObjectId), - metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)), - metric.TotalDur(time.Since(m.ReceiveTime)), - ) -} - type SpaceDerivePayload struct { SigningKey crypto.PrivKey MasterKey crypto.PrivKey @@ -99,55 +62,38 @@ type Space interface { StoredIds() []string DebugAllHeads() []headsync.TreeHeads - Description() (SpaceDescription, error) + Description() (desc SpaceDescription, err error) - CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) - PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error) - BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error) - DeleteTree(ctx context.Context, id string) (err error) - BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error) - - SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) - DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) - - HeadSync() headsync.HeadSync - ObjectSync() objectsync.ObjectSync + TreeBuilder() objecttreebuilder.TreeBuilder SyncStatus() syncstatus.StatusUpdater Storage() spacestorage.SpaceStorage - HandleMessage(ctx context.Context, msg HandleMessage) (err error) + DeleteTree(ctx context.Context, id string) (err error) + SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) + DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) + + HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error) + HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) + HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) TryClose(objectTTL time.Duration) (close bool, err error) Close() error } type space struct { - id string mu sync.RWMutex header *spacesyncproto.RawSpaceHeaderWithId - objectSync objectsync.ObjectSync - headSync headsync.HeadSync - syncStatus syncstatus.StatusUpdater - storage spacestorage.SpaceStorage - treeManager *commonGetter - account accountservice.Service - aclList *syncacl.SyncAcl - configuration nodeconf.NodeConf - settingsObject settings.SettingsObject - peerManager peermanager.PeerManager - treeBuilder objecttree.BuildObjectTreeFunc - metric metric.Metric + state *spacestate.SpaceState + app *app.App - handleQueue multiqueue.MultiQueue[HandleMessage] - - isClosed *atomic.Bool - isDeleted *atomic.Bool - treesUsed *atomic.Int32 -} - -func (s *space) Id() string { - return s.id + treeBuilder objecttreebuilder.TreeBuilderComponent + headSync headsync.HeadSync + objectSync objectsync.ObjectSync + syncStatus syncstatus.StatusService + settings settings.Settings + storage spacestorage.SpaceStorage + aclList list.AclList } func (s *space) Description() (desc SpaceDescription, err error) { @@ -171,72 +117,60 @@ func (s *space) Description() (desc SpaceDescription, err error) { return } +func (s *space) StoredIds() []string { + return s.headSync.ExternalIds() +} + +func (s *space) DebugAllHeads() []headsync.TreeHeads { + return s.headSync.DebugAllHeads() +} + +func (s *space) DeleteTree(ctx context.Context, id string) (err error) { + return s.settings.DeleteTree(ctx, id) +} + +func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) { + return s.settings.SpaceDeleteRawChange(ctx) +} + +func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) { + return s.settings.DeleteSpace(ctx, deleteChange) +} + +func (s *space) HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error) { + return s.objectSync.HandleMessage(ctx, msg) +} + +func (s *space) HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) { + return s.objectSync.HandleRequest(ctx, req) +} + +func (s *space) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { + return s.headSync.HandleRangeRequest(ctx, req) +} + +func (s *space) TreeBuilder() objecttreebuilder.TreeBuilder { + return s.treeBuilder +} + +func (s *space) Id() string { + return s.state.SpaceId +} + func (s *space) Init(ctx context.Context) (err error) { - log.With(zap.String("spaceId", s.id)).Debug("initializing space") - s.storage = newCommonStorage(s.storage) - - header, err := s.storage.SpaceHeader() + err = s.app.Start(ctx) if err != nil { return } - s.header = header - initialIds, err := s.storage.StoredIds() - if err != nil { - return - } - aclStorage, err := s.storage.AclStorage() - if err != nil { - return - } - aclList, err := list.BuildAclListWithIdentity(s.account.Account(), aclStorage) - if err != nil { - return - } - s.aclList = syncacl.NewSyncAcl(aclList, s.objectSync.SyncClient().MessagePool()) - s.treeManager.AddObject(s.aclList) - - deletionState := settingsstate.NewObjectDeletionState(log, s.storage) - deps := settings.Deps{ - BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) { - res, err := s.BuildTree(ctx, id, BuildTreeOpts{ - Listener: listener, - WaitTreeRemoteSync: false, - // space settings document should not have empty data - treeBuilder: objecttree.BuildObjectTree, - }) - log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", s.id)) - if err != nil { - return - } - t = res.(synctree.SyncTree) - return - }, - Account: s.account, - TreeManager: s.treeManager, - Store: s.storage, - DeletionState: deletionState, - Provider: s.headSync, - Configuration: s.configuration, - OnSpaceDelete: s.onSpaceDelete, - } - s.settingsObject = settings.NewSettingsObject(deps, s.id) - s.headSync.Init(initialIds, deletionState) - err = s.settingsObject.Init(ctx) - if err != nil { - return - } - s.treeManager.AddObject(s.settingsObject) - s.syncStatus.Run() - s.handleQueue = multiqueue.New[HandleMessage](s.handleMessage, 100) - return nil -} - -func (s *space) ObjectSync() objectsync.ObjectSync { - return s.objectSync -} - -func (s *space) HeadSync() headsync.HeadSync { - return s.headSync + s.treeBuilder = s.app.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent) + s.headSync = s.app.MustComponent(headsync.CName).(headsync.HeadSync) + s.syncStatus = s.app.MustComponent(syncstatus.CName).(syncstatus.StatusService) + s.settings = s.app.MustComponent(settings.CName).(settings.Settings) + s.objectSync = s.app.MustComponent(objectsync.CName).(objectsync.ObjectSync) + s.storage = s.app.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + s.aclList = s.app.MustComponent(syncacl.CName).(list.AclList) + s.header, err = s.storage.SpaceHeader() + return } func (s *space) SyncStatus() syncstatus.StatusUpdater { @@ -247,246 +181,25 @@ func (s *space) Storage() spacestorage.SpaceStorage { return s.storage } -func (s *space) StoredIds() []string { - return slice.DiscardFromSlice(s.headSync.AllIds(), func(id string) bool { - return id == s.settingsObject.Id() - }) -} - -func (s *space) DebugAllHeads() []headsync.TreeHeads { - return s.headSync.DebugAllHeads() -} - -func (s *space) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) { - if s.isClosed.Load() { - err = ErrSpaceClosed - return - } - root, err := objecttree.CreateObjectTreeRoot(payload, s.aclList) - if err != nil { - return - } - - res = treestorage.TreeStorageCreatePayload{ - RootRawChange: root, - Changes: []*treechangeproto.RawTreeChangeWithId{root}, - Heads: []string{root.Id}, - } - return -} - -func (s *space) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error) { - if s.isClosed.Load() { - err = ErrSpaceClosed - return - } - deps := synctree.BuildDeps{ - SpaceId: s.id, - SyncClient: s.objectSync.SyncClient(), - Configuration: s.configuration, - HeadNotifiable: s.headSync, - Listener: listener, - AclList: s.aclList, - SpaceStorage: s.storage, - OnClose: s.onObjectClose, - SyncStatus: s.syncStatus, - PeerGetter: s.peerManager, - BuildObjectTree: s.treeBuilder, - } - t, err = synctree.PutSyncTree(ctx, payload, deps) - if err != nil { - return - } - s.treesUsed.Add(1) - log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id)) - return -} - -type BuildTreeOpts struct { - Listener updatelistener.UpdateListener - WaitTreeRemoteSync bool - treeBuilder objecttree.BuildObjectTreeFunc -} - -type HistoryTreeOpts struct { - BeforeId string - Include bool - BuildFullTree bool -} - -func (s *space) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error) { - if s.isClosed.Load() { - err = ErrSpaceClosed - return - } - treeBuilder := opts.treeBuilder - if treeBuilder == nil { - treeBuilder = s.treeBuilder - } - deps := synctree.BuildDeps{ - SpaceId: s.id, - SyncClient: s.objectSync.SyncClient(), - Configuration: s.configuration, - HeadNotifiable: s.headSync, - Listener: opts.Listener, - AclList: s.aclList, - SpaceStorage: s.storage, - OnClose: s.onObjectClose, - SyncStatus: s.syncStatus, - WaitTreeRemoteSync: opts.WaitTreeRemoteSync, - PeerGetter: s.peerManager, - BuildObjectTree: treeBuilder, - } - s.treesUsed.Add(1) - log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id)) - if t, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil { - s.treesUsed.Add(-1) - log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id), zap.Error(err)) - return nil, err - } - return -} - -func (s *space) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error) { - if s.isClosed.Load() { - err = ErrSpaceClosed - return - } - - params := objecttree.HistoryTreeParams{ - AclList: s.aclList, - BeforeId: opts.BeforeId, - IncludeBeforeId: opts.Include, - BuildFullTree: opts.BuildFullTree, - } - params.TreeStorage, err = s.storage.TreeStorage(id) - if err != nil { - return - } - return objecttree.BuildHistoryTree(params) -} - -func (s *space) DeleteTree(ctx context.Context, id string) (err error) { - return s.settingsObject.DeleteObject(id) -} - -func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) { - return s.settingsObject.SpaceDeleteRawChange() -} - -func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) { - return s.settingsObject.DeleteSpace(ctx, deleteChange) -} - -func (s *space) HandleMessage(ctx context.Context, hm HandleMessage) (err error) { - threadId := hm.Message.ObjectId - hm.ReceiveTime = time.Now() - if hm.Message.ReplyId != "" { - threadId += hm.Message.ReplyId - defer func() { - _ = s.handleQueue.CloseThread(threadId) - }() - } - if hm.PeerCtx == nil { - hm.PeerCtx = ctx - } - err = s.handleQueue.Add(ctx, threadId, hm) - if err == mb.ErrOverflowed { - log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.id), zap.String("objectId", threadId)) - // skip overflowed error - return nil - } - return -} - -func (s *space) handleMessage(msg HandleMessage) { - var err error - msg.StartHandlingTime = time.Now() - ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId) - ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId)) - defer func() { - if s.metric == nil { - return - } - s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields( - zap.Error(err), - )...) - }() - - if !msg.Deadline.IsZero() { - now := time.Now() - if now.After(msg.Deadline) { - log.InfoCtx(ctx, "skip message: deadline exceed") - err = context.DeadlineExceeded - return - } - var cancel context.CancelFunc - ctx, cancel = context.WithDeadline(ctx, msg.Deadline) - defer cancel() - } - - if err = s.objectSync.HandleMessage(ctx, msg.SenderId, msg.Message); err != nil { - if msg.Message.ObjectId != "" { - // cleanup thread on error - _ = s.handleQueue.CloseThread(msg.Message.ObjectId) - } - log.InfoCtx(ctx, "handleMessage error", zap.Error(err)) - } -} - -func (s *space) onObjectClose(id string) { - s.treesUsed.Add(-1) - log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id)) - _ = s.handleQueue.CloseThread(id) -} - -func (s *space) onSpaceDelete() { - err := s.storage.SetSpaceDeleted() - if err != nil { - log.Debug("failed to set space deleted") - } - s.isDeleted.Swap(true) -} - func (s *space) Close() error { - if s.isClosed.Swap(true) { - log.Warn("call space.Close on closed space", zap.String("id", s.id)) + if s.state.SpaceIsClosed.Swap(true) { + log.Warn("call space.Close on closed space", zap.String("id", s.state.SpaceId)) return nil } - log.With(zap.String("id", s.id)).Debug("space is closing") + log := log.With(zap.String("spaceId", s.state.SpaceId)) + log.Debug("space is closing") - var mError errs.Group - if err := s.handleQueue.Close(); err != nil { - mError.Add(err) - } - if err := s.headSync.Close(); err != nil { - mError.Add(err) - } - if err := s.objectSync.Close(); err != nil { - mError.Add(err) - } - if err := s.settingsObject.Close(); err != nil { - mError.Add(err) - } - if err := s.aclList.Close(); err != nil { - mError.Add(err) - } - if err := s.storage.Close(); err != nil { - mError.Add(err) - } - if err := s.syncStatus.Close(); err != nil { - mError.Add(err) - } - log.With(zap.String("id", s.id)).Debug("space closed") - return mError.Err() + err := s.app.Close(context.Background()) + log.Debug("space closed") + return err } func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) { if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL { return false, nil } - locked := s.treesUsed.Load() > 1 - log.With(zap.Int32("trees used", s.treesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.id)).Debug("space lock status check") + locked := s.state.TreesUsed.Load() > 1 + log.With(zap.Int32("trees used", s.state.TreesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.state.SpaceId)).Debug("space lock status check") if locked { return false, nil } diff --git a/commonspace/spaceservice.go b/commonspace/spaceservice.go index 0e770dc1..3e441aea 100644 --- a/commonspace/spaceservice.go +++ b/commonspace/spaceservice.go @@ -5,14 +5,22 @@ import ( "github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/config" "github.com/anyproto/any-sync/commonspace/credentialprovider" + "github.com/anyproto/any-sync/commonspace/deletionstate" "github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" + "github.com/anyproto/any-sync/commonspace/object/acl/syncacl" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/objectmanager" "github.com/anyproto/any-sync/commonspace/objectsync" + "github.com/anyproto/any-sync/commonspace/objecttreebuilder" "github.com/anyproto/any-sync/commonspace/peermanager" + "github.com/anyproto/any-sync/commonspace/requestmanager" + "github.com/anyproto/any-sync/commonspace/settings" + "github.com/anyproto/any-sync/commonspace/spacestate" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/syncstatus" @@ -21,6 +29,7 @@ import ( "github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/nodeconf" + "storj.io/drpc" "sync/atomic" ) @@ -45,32 +54,30 @@ type SpaceService interface { } type spaceService struct { - config Config - account accountservice.Service - configurationService nodeconf.Service - storageProvider spacestorage.SpaceStorageProvider - peermanagerProvider peermanager.PeerManagerProvider - credentialProvider credentialprovider.CredentialProvider - treeManager treemanager.TreeManager - pool pool.Pool - metric metric.Metric + config config.Config + account accountservice.Service + configurationService nodeconf.Service + storageProvider spacestorage.SpaceStorageProvider + peerManagerProvider peermanager.PeerManagerProvider + credentialProvider credentialprovider.CredentialProvider + statusServiceProvider syncstatus.StatusServiceProvider + treeManager treemanager.TreeManager + pool pool.Pool + metric metric.Metric + app *app.App } func (s *spaceService) Init(a *app.App) (err error) { - s.config = a.MustComponent("config").(ConfigGetter).GetSpace() + s.config = a.MustComponent("config").(config.ConfigGetter).GetSpace() s.account = a.MustComponent(accountservice.CName).(accountservice.Service) s.storageProvider = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorageProvider) s.configurationService = a.MustComponent(nodeconf.CName).(nodeconf.Service) s.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager) - s.peermanagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider) - credProvider := a.Component(credentialprovider.CName) - if credProvider != nil { - s.credentialProvider = credProvider.(credentialprovider.CredentialProvider) - } else { - s.credentialProvider = credentialprovider.NewNoOp() - } + s.peerManagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider) + s.statusServiceProvider = a.MustComponent(syncstatus.CName).(syncstatus.StatusServiceProvider) s.pool = a.MustComponent(pool.CName).(pool.Pool) s.metric, _ = a.Component(metric.CName).(metric.Metric) + s.app = a return nil } @@ -138,8 +145,6 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) { } } } - - lastConfiguration := s.configurationService var ( spaceIsClosed = &atomic.Bool{} spaceIsDeleted = &atomic.Bool{} @@ -149,42 +154,39 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) { return nil, err } spaceIsDeleted.Swap(isDeleted) - getter := newCommonGetter(st.Id(), s.treeManager, spaceIsClosed) - syncStatus := syncstatus.NewNoOpSyncStatus() - // this will work only for clients, not the best solution, but... - if !lastConfiguration.IsResponsible(st.Id()) { - // TODO: move it to the client package and add possibility to inject StatusProvider from the client - syncStatus = syncstatus.NewSyncStatusProvider(st.Id(), syncstatus.DefaultDeps(lastConfiguration, st)) + state := &spacestate.SpaceState{ + SpaceId: st.Id(), + SpaceIsDeleted: spaceIsDeleted, + SpaceIsClosed: spaceIsClosed, + TreesUsed: &atomic.Int32{}, } - var builder objecttree.BuildObjectTreeFunc if s.config.KeepTreeDataInMemory { - builder = objecttree.BuildObjectTree + state.TreeBuilderFunc = objecttree.BuildObjectTree } else { - builder = objecttree.BuildEmptyDataObjectTree + state.TreeBuilderFunc = objecttree.BuildEmptyDataObjectTree } - - peerManager, err := s.peermanagerProvider.NewPeerManager(ctx, id) + peerManager, err := s.peerManagerProvider.NewPeerManager(ctx, id) if err != nil { return nil, err } + statusService := s.statusServiceProvider.NewStatusService() + spaceApp := s.app.ChildApp() + spaceApp.Register(state). + Register(peerManager). + Register(newCommonStorage(st)). + Register(statusService). + Register(syncacl.New()). + Register(requestmanager.New()). + Register(deletionstate.New()). + Register(settings.New()). + Register(objectmanager.New(s.treeManager)). + Register(objecttreebuilder.New()). + Register(objectsync.New()). + Register(headsync.New()) - headSync := headsync.NewHeadSync(id, spaceIsDeleted, s.config.SyncPeriod, lastConfiguration, st, peerManager, getter, syncStatus, s.credentialProvider, log) - objectSync := objectsync.NewObjectSync(id, spaceIsDeleted, lastConfiguration, peerManager, getter, st) sp := &space{ - id: id, - objectSync: objectSync, - headSync: headSync, - syncStatus: syncStatus, - treeManager: getter, - account: s.account, - configuration: lastConfiguration, - peerManager: peerManager, - storage: st, - treesUsed: &atomic.Int32{}, - treeBuilder: builder, - isClosed: spaceIsClosed, - isDeleted: spaceIsDeleted, - metric: s.metric, + state: state, + app: spaceApp, } return sp, nil } @@ -226,8 +228,12 @@ func (s *spaceService) getSpaceStorageFromRemote(ctx context.Context, id string) return } - cl := spacesyncproto.NewDRPCSpaceSyncClient(p) - res, err := cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id}) + var res *spacesyncproto.SpacePullResponse + err = p.DoDrpc(ctx, func(conn drpc.Conn) error { + cl := spacesyncproto.NewDRPCSpaceSyncClient(conn) + res, err = cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id}) + return err + }) if err != nil { err = rpcerr.Unwrap(err) return diff --git a/commonspace/spacestate/shareddata.go b/commonspace/spacestate/shareddata.go new file mode 100644 index 00000000..b2a53a50 --- /dev/null +++ b/commonspace/spacestate/shareddata.go @@ -0,0 +1,25 @@ +package spacestate + +import ( + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" + "sync/atomic" +) + +const CName = "common.commonspace.spacestate" + +type SpaceState struct { + SpaceId string + SpaceIsDeleted *atomic.Bool + SpaceIsClosed *atomic.Bool + TreesUsed *atomic.Int32 + TreeBuilderFunc objecttree.BuildObjectTreeFunc +} + +func (s *SpaceState) Init(a *app.App) (err error) { + return nil +} + +func (s *SpaceState) Name() (name string) { + return CName +} diff --git a/commonspace/spacestorage/inmemorystorage.go b/commonspace/spacestorage/inmemorystorage.go index b728a8a6..096f5a84 100644 --- a/commonspace/spacestorage/inmemorystorage.go +++ b/commonspace/spacestorage/inmemorystorage.go @@ -1,6 +1,8 @@ package spacestorage import ( + "context" + "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/liststorage" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" @@ -21,6 +23,22 @@ type InMemorySpaceStorage struct { sync.Mutex } +func (i *InMemorySpaceStorage) Run(ctx context.Context) (err error) { + return nil +} + +func (i *InMemorySpaceStorage) Close(ctx context.Context) (err error) { + return nil +} + +func (i *InMemorySpaceStorage) Init(a *app.App) (err error) { + return nil +} + +func (i *InMemorySpaceStorage) Name() (name string) { + return CName +} + func NewInMemorySpaceStorage(payload SpaceStorageCreatePayload) (SpaceStorage, error) { aclStorage, err := liststorage.NewInMemoryAclListStorage(payload.AclWithId.Id, []*aclrecordproto.RawAclRecordWithId{payload.AclWithId}) if err != nil { @@ -148,10 +166,6 @@ func (i *InMemorySpaceStorage) ReadSpaceHash() (hash string, err error) { return i.spaceHash, nil } -func (i *InMemorySpaceStorage) Close() error { - return nil -} - func (i *InMemorySpaceStorage) AllTrees() map[string]treestorage.TreeStorage { i.Lock() defer i.Unlock() diff --git a/commonspace/spacestorage/mock_spacestorage/mock_spacestorage.go b/commonspace/spacestorage/mock_spacestorage/mock_spacestorage.go index a8410488..bc7f448c 100644 --- a/commonspace/spacestorage/mock_spacestorage/mock_spacestorage.go +++ b/commonspace/spacestorage/mock_spacestorage/mock_spacestorage.go @@ -5,8 +5,10 @@ package mock_spacestorage import ( + context "context" reflect "reflect" + app "github.com/anyproto/any-sync/app" liststorage "github.com/anyproto/any-sync/commonspace/object/acl/liststorage" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" @@ -53,17 +55,17 @@ func (mr *MockSpaceStorageMockRecorder) AclStorage() *gomock.Call { } // Close mocks base method. -func (m *MockSpaceStorage) Close() error { +func (m *MockSpaceStorage) Close(arg0 context.Context) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + ret := m.ctrl.Call(m, "Close", arg0) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockSpaceStorageMockRecorder) Close() *gomock.Call { +func (mr *MockSpaceStorageMockRecorder) Close(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSpaceStorage)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSpaceStorage)(nil).Close), arg0) } // CreateTreeStorage mocks base method. @@ -110,6 +112,20 @@ func (mr *MockSpaceStorageMockRecorder) Id() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSpaceStorage)(nil).Id)) } +// Init mocks base method. +func (m *MockSpaceStorage) Init(arg0 *app.App) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Init", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Init indicates an expected call of Init. +func (mr *MockSpaceStorageMockRecorder) Init(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockSpaceStorage)(nil).Init), arg0) +} + // IsSpaceDeleted mocks base method. func (m *MockSpaceStorage) IsSpaceDeleted() (bool, error) { m.ctrl.T.Helper() @@ -125,6 +141,20 @@ func (mr *MockSpaceStorageMockRecorder) IsSpaceDeleted() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsSpaceDeleted", reflect.TypeOf((*MockSpaceStorage)(nil).IsSpaceDeleted)) } +// Name mocks base method. +func (m *MockSpaceStorage) Name() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Name") + ret0, _ := ret[0].(string) + return ret0 +} + +// Name indicates an expected call of Name. +func (mr *MockSpaceStorageMockRecorder) Name() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockSpaceStorage)(nil).Name)) +} + // ReadSpaceHash mocks base method. func (m *MockSpaceStorage) ReadSpaceHash() (string, error) { m.ctrl.T.Helper() @@ -140,6 +170,20 @@ func (mr *MockSpaceStorageMockRecorder) ReadSpaceHash() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadSpaceHash", reflect.TypeOf((*MockSpaceStorage)(nil).ReadSpaceHash)) } +// Run mocks base method. +func (m *MockSpaceStorage) Run(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Run indicates an expected call of Run. +func (mr *MockSpaceStorageMockRecorder) Run(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSpaceStorage)(nil).Run), arg0) +} + // SetSpaceDeleted mocks base method. func (m *MockSpaceStorage) SetSpaceDeleted() error { m.ctrl.T.Helper() diff --git a/commonspace/spacestorage/spacestorage.go b/commonspace/spacestorage/spacestorage.go index cecd3736..e2807f5a 100644 --- a/commonspace/spacestorage/spacestorage.go +++ b/commonspace/spacestorage/spacestorage.go @@ -27,8 +27,8 @@ const ( TreeDeletedStatusDeleted = "deleted" ) -// TODO: consider moving to some file with all common interfaces etc type SpaceStorage interface { + app.ComponentRunnable Id() string SetSpaceDeleted() error IsSpaceDeleted() (bool, error) @@ -44,8 +44,6 @@ type SpaceStorage interface { CreateTreeStorage(payload treestorage.TreeStorageCreatePayload) (treestorage.TreeStorage, error) WriteSpaceHash(hash string) error ReadSpaceHash() (hash string, err error) - - Close() error } type SpaceStorageCreatePayload struct { diff --git a/commonspace/spacestorage/spacestorage_test.go b/commonspace/spacestorage/spacestorage_test.go deleted file mode 100644 index 7ca39ae0..00000000 --- a/commonspace/spacestorage/spacestorage_test.go +++ /dev/null @@ -1 +0,0 @@ -package spacestorage diff --git a/commonspace/spacesyncproto/mock_spacesyncproto/mock_spacesyncproto.go b/commonspace/spacesyncproto/mock_spacesyncproto/mock_spacesyncproto.go index db8313a2..8d8a1111 100644 --- a/commonspace/spacesyncproto/mock_spacesyncproto/mock_spacesyncproto.go +++ b/commonspace/spacesyncproto/mock_spacesyncproto/mock_spacesyncproto.go @@ -65,6 +65,21 @@ func (mr *MockDRPCSpaceSyncClientMockRecorder) HeadSync(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadSync", reflect.TypeOf((*MockDRPCSpaceSyncClient)(nil).HeadSync), arg0, arg1) } +// ObjectSync mocks base method. +func (m *MockDRPCSpaceSyncClient) ObjectSync(arg0 context.Context, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ObjectSync", arg0, arg1) + ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ObjectSync indicates an expected call of ObjectSync. +func (mr *MockDRPCSpaceSyncClientMockRecorder) ObjectSync(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ObjectSync", reflect.TypeOf((*MockDRPCSpaceSyncClient)(nil).ObjectSync), arg0, arg1) +} + // ObjectSyncStream mocks base method. func (m *MockDRPCSpaceSyncClient) ObjectSyncStream(arg0 context.Context) (spacesyncproto.DRPCSpaceSync_ObjectSyncStreamClient, error) { m.ctrl.T.Helper() diff --git a/commonspace/spacesyncproto/protos/spacesync.proto b/commonspace/spacesyncproto/protos/spacesync.proto index d5b461cf..f7f44ae4 100644 --- a/commonspace/spacesyncproto/protos/spacesync.proto +++ b/commonspace/spacesyncproto/protos/spacesync.proto @@ -23,6 +23,8 @@ service SpaceSync { rpc SpacePull(SpacePullRequest) returns (SpacePullResponse); // ObjectSyncStream opens object sync stream with node or client rpc ObjectSyncStream(stream ObjectSyncMessage) returns (stream ObjectSyncMessage); + // ObjectSync sends object sync message and synchronously gets response message + rpc ObjectSync(ObjectSyncMessage) returns (ObjectSyncMessage); } // HeadSyncRange presenting a request for one range diff --git a/commonspace/spacesyncproto/spacesync.pb.go b/commonspace/spacesyncproto/spacesync.pb.go index 33caed5b..80571bf6 100644 --- a/commonspace/spacesyncproto/spacesync.pb.go +++ b/commonspace/spacesyncproto/spacesync.pb.go @@ -1254,75 +1254,75 @@ func init() { } var fileDescriptor_80e49f1f4ac27799 = []byte{ - // 1077 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x94, 0x56, 0xcd, 0x6e, 0xdb, 0x46, - 0x10, 0x16, 0xe9, 0x5f, 0x8d, 0x65, 0x99, 0xd9, 0x28, 0x89, 0xaa, 0x18, 0x8a, 0xb0, 0x28, 0x0a, - 0x23, 0x07, 0x27, 0xb1, 0x8b, 0x02, 0x49, 0xdb, 0x43, 0x62, 0x3b, 0x0d, 0x51, 0x24, 0x36, 0x56, + // 1083 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xa4, 0x56, 0xcd, 0x6e, 0xdb, 0x46, + 0x10, 0x16, 0xe9, 0x5f, 0x8d, 0x65, 0x85, 0xd9, 0x28, 0x89, 0xaa, 0x18, 0x8a, 0xb0, 0x28, 0x0a, + 0x23, 0x07, 0x27, 0xb1, 0x8b, 0x02, 0x49, 0xdb, 0x43, 0x62, 0x3b, 0x35, 0x51, 0x24, 0x36, 0x56, 0x0d, 0x0a, 0x14, 0xc8, 0x61, 0x4d, 0x8e, 0x2d, 0xb6, 0x14, 0xc9, 0x72, 0x57, 0x89, 0x75, 0xec, - 0xa9, 0xd7, 0x9e, 0xdb, 0x07, 0xe8, 0x0b, 0xf4, 0x21, 0x7a, 0x4c, 0x6f, 0x3d, 0x16, 0xf6, 0x8b, - 0x14, 0xbb, 0x5c, 0xfe, 0xc8, 0xa2, 0x02, 0xe4, 0x22, 0xed, 0x7e, 0x33, 0xf3, 0xcd, 0xdf, 0xee, - 0x0e, 0xe1, 0x91, 0x17, 0x8f, 0xc7, 0x71, 0x24, 0x12, 0xee, 0xe1, 0x03, 0xfd, 0x2b, 0xa6, 0x91, - 0x97, 0xa4, 0xb1, 0x8c, 0x1f, 0xe8, 0x5f, 0x51, 0xa2, 0xbb, 0x1a, 0x20, 0xcd, 0x02, 0xa0, 0x2e, - 0x6c, 0xbe, 0x40, 0xee, 0x0f, 0xa7, 0x91, 0xc7, 0x78, 0x74, 0x8e, 0x84, 0xc0, 0xf2, 0x59, 0x1a, - 0x8f, 0xbb, 0xd6, 0xc0, 0xda, 0x59, 0x66, 0x7a, 0x4d, 0xda, 0x60, 0xcb, 0xb8, 0x6b, 0x6b, 0xc4, - 0x96, 0x31, 0xe9, 0xc0, 0x4a, 0x18, 0x8c, 0x03, 0xd9, 0x5d, 0x1a, 0x58, 0x3b, 0x9b, 0x2c, 0xdb, - 0xd0, 0x0b, 0x68, 0x17, 0x54, 0x28, 0x26, 0xa1, 0x54, 0x5c, 0x23, 0x2e, 0x46, 0x9a, 0xab, 0xc5, - 0xf4, 0x9a, 0x7c, 0x05, 0xeb, 0x18, 0xe2, 0x18, 0x23, 0x29, 0xba, 0xf6, 0x60, 0x69, 0x67, 0x63, - 0x6f, 0xb0, 0x5b, 0xc6, 0x37, 0x4b, 0x70, 0x94, 0x29, 0xb2, 0xc2, 0x42, 0x79, 0xf6, 0xe2, 0x49, - 0x54, 0x78, 0xd6, 0x1b, 0xfa, 0x25, 0xdc, 0xaa, 0x35, 0x54, 0x81, 0x07, 0xbe, 0x76, 0xdf, 0x64, - 0x76, 0xe0, 0xeb, 0x80, 0x90, 0xfb, 0x3a, 0x95, 0x26, 0xd3, 0x6b, 0xfa, 0x06, 0xb6, 0x4a, 0xe3, - 0x9f, 0x27, 0x28, 0x24, 0xe9, 0xc2, 0x9a, 0x0e, 0xc9, 0xcd, 0x6d, 0xf3, 0x2d, 0x79, 0x08, 0xab, - 0xa9, 0x2a, 0x53, 0x1e, 0x7b, 0xb7, 0x2e, 0x76, 0xa5, 0xc0, 0x8c, 0x1e, 0xfd, 0x06, 0x9c, 0x4a, - 0x6c, 0x49, 0x1c, 0x09, 0x24, 0xfb, 0xb0, 0x96, 0xea, 0x38, 0x45, 0xd7, 0xd2, 0x34, 0x9f, 0x2c, - 0x2c, 0x01, 0xcb, 0x35, 0xe9, 0x1f, 0x16, 0xdc, 0x38, 0x3e, 0xfd, 0x11, 0x3d, 0xa9, 0xa4, 0x2f, - 0x51, 0x08, 0x7e, 0x8e, 0x1f, 0x08, 0x75, 0x1b, 0x9a, 0x69, 0x96, 0x8f, 0x9b, 0x27, 0x5c, 0x02, - 0xca, 0x2e, 0xc5, 0x24, 0x9c, 0xba, 0xbe, 0x2e, 0x65, 0x93, 0xe5, 0x5b, 0x25, 0x49, 0xf8, 0x34, - 0x8c, 0xb9, 0xdf, 0x5d, 0xd6, 0x7d, 0xcb, 0xb7, 0xa4, 0x07, 0xeb, 0xb1, 0x0e, 0xc0, 0xf5, 0xbb, - 0x2b, 0xda, 0xa8, 0xd8, 0x53, 0x04, 0x67, 0xa8, 0x1c, 0x9f, 0x4c, 0xc4, 0x28, 0x2f, 0xe3, 0xa3, - 0x92, 0x49, 0xc5, 0xb6, 0xb1, 0x77, 0xa7, 0x92, 0x66, 0xa6, 0x9d, 0x89, 0x4b, 0x17, 0x7d, 0x80, - 0x83, 0x14, 0x7d, 0x8c, 0x64, 0xc0, 0x43, 0x1d, 0x75, 0x8b, 0x55, 0x10, 0x7a, 0x13, 0x6e, 0x54, - 0xdc, 0x64, 0xe5, 0xa4, 0xb4, 0xf0, 0x1d, 0x86, 0xb9, 0xef, 0x6b, 0x9d, 0xa7, 0xcf, 0x0b, 0x43, - 0xa5, 0x63, 0xfa, 0xf0, 0xf1, 0x01, 0xd2, 0x5f, 0x6c, 0x68, 0x55, 0x25, 0xe4, 0x29, 0x6c, 0x68, - 0x1b, 0xd5, 0x36, 0x4c, 0x0d, 0xcf, 0xbd, 0x0a, 0x0f, 0xe3, 0xef, 0x86, 0xa5, 0xc2, 0xf7, 0x81, - 0x1c, 0xb9, 0x3e, 0xab, 0xda, 0xa8, 0xa4, 0xb9, 0x17, 0x1a, 0xc2, 0x3c, 0xe9, 0x12, 0x21, 0x14, - 0x5a, 0xe5, 0xae, 0x68, 0xd8, 0x0c, 0x46, 0xf6, 0xa0, 0xa3, 0x29, 0x87, 0x28, 0x65, 0x10, 0x9d, - 0x8b, 0x93, 0x99, 0x16, 0xd6, 0xca, 0xc8, 0x17, 0x70, 0xbb, 0x0e, 0x2f, 0xba, 0xbb, 0x40, 0x4a, - 0xff, 0xb1, 0x60, 0xa3, 0x92, 0x92, 0x3a, 0x17, 0x81, 0x6e, 0x90, 0x9c, 0x9a, 0xab, 0x5e, 0xec, - 0xd5, 0x29, 0x94, 0xc1, 0x18, 0x85, 0xe4, 0xe3, 0x44, 0xa7, 0xb6, 0xc4, 0x4a, 0x40, 0x49, 0xb5, - 0x8f, 0xef, 0xa6, 0x09, 0x9a, 0xb4, 0x4a, 0x80, 0x7c, 0x06, 0x6d, 0x75, 0x28, 0x03, 0x8f, 0xcb, - 0x20, 0x8e, 0xbe, 0xc5, 0xa9, 0xce, 0x66, 0x99, 0x5d, 0x43, 0xd5, 0xad, 0x16, 0x88, 0x59, 0xd4, - 0x2d, 0xa6, 0xd7, 0x64, 0x17, 0x48, 0xa5, 0xc4, 0x79, 0x35, 0x56, 0xb5, 0x46, 0x8d, 0x84, 0x9e, - 0x40, 0x7b, 0xb6, 0x51, 0x64, 0x30, 0xdf, 0xd8, 0xd6, 0x6c, 0xdf, 0x54, 0xf4, 0xc1, 0x79, 0xc4, - 0xe5, 0x24, 0x45, 0xd3, 0xb6, 0x12, 0xa0, 0x87, 0xd0, 0xa9, 0x6b, 0xbd, 0xbe, 0x97, 0xfc, 0xdd, - 0x0c, 0x6b, 0x09, 0x98, 0x73, 0x6b, 0x17, 0xe7, 0xf6, 0x77, 0x0b, 0x3a, 0xc3, 0x6a, 0x1b, 0x0e, - 0xe2, 0x48, 0xaa, 0xa7, 0xed, 0x6b, 0x68, 0x65, 0x97, 0xef, 0x10, 0x43, 0x94, 0x58, 0x73, 0x80, - 0x8f, 0x2b, 0xe2, 0x17, 0x0d, 0x36, 0xa3, 0x4e, 0x9e, 0x98, 0xec, 0x8c, 0xb5, 0xad, 0xad, 0x6f, - 0x5f, 0x3f, 0xfe, 0x85, 0x71, 0x55, 0xf9, 0xd9, 0x1a, 0xac, 0xbc, 0xe5, 0xe1, 0x04, 0x69, 0x1f, - 0x5a, 0x55, 0x27, 0x73, 0x97, 0x6e, 0xdf, 0x9c, 0x13, 0x23, 0xfe, 0x14, 0x36, 0x7d, 0xbd, 0x4a, - 0x4f, 0x10, 0xd3, 0xe2, 0xc5, 0x9a, 0x05, 0xe9, 0x1b, 0xb8, 0x35, 0x93, 0xf0, 0x30, 0xe2, 0x89, - 0x18, 0xc5, 0x52, 0x5d, 0x93, 0x4c, 0xd3, 0x77, 0xfd, 0xec, 0xe1, 0x6c, 0xb2, 0x0a, 0x32, 0x4f, - 0x6f, 0xd7, 0xd1, 0xff, 0x6a, 0x41, 0x2b, 0xa7, 0x3e, 0xe4, 0x92, 0x93, 0xc7, 0xb0, 0xe6, 0x65, - 0x35, 0x35, 0x8f, 0xf1, 0xbd, 0xeb, 0x55, 0xb8, 0x56, 0x7a, 0x96, 0xeb, 0xab, 0x59, 0x26, 0x4c, - 0x74, 0xa6, 0x82, 0x83, 0x45, 0xb6, 0x79, 0x16, 0xac, 0xb0, 0xa0, 0x3f, 0x99, 0x27, 0x69, 0x38, - 0x39, 0x15, 0x5e, 0x1a, 0x24, 0xea, 0x38, 0xab, 0xbb, 0x64, 0x1e, 0xf0, 0x3c, 0xc5, 0x62, 0x4f, - 0x9e, 0xc0, 0x2a, 0xf7, 0x94, 0x96, 0x76, 0xd6, 0xde, 0xa3, 0x73, 0xce, 0x2a, 0x4c, 0x4f, 0xb5, - 0x26, 0x33, 0x16, 0xf7, 0xff, 0xb4, 0x60, 0xfd, 0x28, 0x4d, 0x0f, 0x62, 0x1f, 0x05, 0x69, 0x03, - 0xbc, 0x8e, 0xf0, 0x22, 0x41, 0x4f, 0xa2, 0xef, 0x34, 0x88, 0x63, 0xde, 0xb4, 0x97, 0x81, 0x10, - 0x41, 0x74, 0xee, 0x58, 0x64, 0xcb, 0x74, 0xee, 0xe8, 0x22, 0x10, 0x52, 0x38, 0x36, 0xb9, 0x09, - 0x5b, 0x1a, 0x78, 0x15, 0x4b, 0x37, 0x3a, 0xe0, 0xde, 0x08, 0x9d, 0x25, 0x42, 0xa0, 0xad, 0x41, - 0x57, 0x64, 0x1d, 0xf6, 0x9d, 0x65, 0xd2, 0x85, 0x8e, 0xae, 0xb4, 0x78, 0x15, 0x4b, 0xf3, 0xd0, - 0x06, 0xa7, 0x21, 0x3a, 0x2b, 0xa4, 0x03, 0x0e, 0x43, 0x0f, 0x83, 0x44, 0xba, 0xc2, 0x8d, 0xde, - 0xf2, 0x30, 0xf0, 0x9d, 0x55, 0xe5, 0xe9, 0x28, 0x4d, 0xe3, 0xf4, 0xf8, 0xec, 0x4c, 0xa0, 0x74, - 0xfc, 0xfb, 0x8f, 0xe1, 0xce, 0x82, 0x64, 0xc8, 0x26, 0x34, 0x0d, 0x7a, 0x8a, 0x4e, 0x43, 0x99, - 0xbe, 0x8e, 0x44, 0x01, 0x58, 0x7b, 0x7f, 0xd9, 0xd0, 0xcc, 0x6c, 0xa7, 0x91, 0x47, 0x0e, 0x60, - 0x3d, 0x9f, 0xa5, 0xa4, 0x57, 0x3b, 0x60, 0xf5, 0xa8, 0xe8, 0xdd, 0xad, 0x1f, 0xbe, 0xd9, 0x88, - 0x78, 0x6e, 0x18, 0xd5, 0xc0, 0x21, 0x77, 0xe7, 0xc6, 0x43, 0x39, 0xed, 0x7a, 0xdb, 0xf5, 0xc2, - 0x39, 0x9e, 0x30, 0xac, 0xe3, 0x29, 0x26, 0x57, 0x1d, 0x4f, 0x65, 0x64, 0x31, 0x70, 0xca, 0x8f, - 0x80, 0xa1, 0x4c, 0x91, 0x8f, 0xc9, 0xf6, 0xdc, 0xa5, 0xaf, 0x7c, 0x21, 0xf4, 0x3e, 0x28, 0xdd, - 0xb1, 0x1e, 0x5a, 0xcf, 0x3e, 0xff, 0xfb, 0xb2, 0x6f, 0xbd, 0xbf, 0xec, 0x5b, 0xff, 0x5d, 0xf6, - 0xad, 0xdf, 0xae, 0xfa, 0x8d, 0xf7, 0x57, 0xfd, 0xc6, 0xbf, 0x57, 0xfd, 0xc6, 0x0f, 0xbd, 0xc5, - 0xdf, 0x96, 0xa7, 0xab, 0xfa, 0x6f, 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0xd3, 0x01, 0xff, - 0xb5, 0x80, 0x0a, 0x00, 0x00, + 0xa9, 0xd7, 0x9e, 0xdb, 0x07, 0xe8, 0xab, 0xf4, 0x98, 0xde, 0x7a, 0x2c, 0xec, 0xf7, 0x28, 0x8a, + 0x5d, 0x2e, 0x7f, 0x64, 0x51, 0x01, 0x8a, 0x5e, 0xa4, 0xdd, 0x6f, 0x66, 0xbe, 0xf9, 0xdb, 0xdd, + 0x21, 0x3c, 0xf6, 0xe2, 0xf1, 0x38, 0x8e, 0x44, 0xc2, 0x3d, 0x7c, 0xa8, 0x7f, 0xc5, 0x34, 0xf2, + 0x92, 0x34, 0x96, 0xf1, 0x43, 0xfd, 0x2b, 0x4a, 0x74, 0x47, 0x03, 0xa4, 0x59, 0x00, 0xd4, 0x85, + 0xcd, 0x23, 0xe4, 0xfe, 0x70, 0x1a, 0x79, 0x8c, 0x47, 0xe7, 0x48, 0x08, 0x2c, 0x9f, 0xa5, 0xf1, + 0xb8, 0x6b, 0x0d, 0xac, 0xed, 0x65, 0xa6, 0xd7, 0xa4, 0x0d, 0xb6, 0x8c, 0xbb, 0xb6, 0x46, 0x6c, + 0x19, 0x93, 0x0e, 0xac, 0x84, 0xc1, 0x38, 0x90, 0xdd, 0xa5, 0x81, 0xb5, 0xbd, 0xc9, 0xb2, 0x0d, + 0xbd, 0x80, 0x76, 0x41, 0x85, 0x62, 0x12, 0x4a, 0xc5, 0x35, 0xe2, 0x62, 0xa4, 0xb9, 0x5a, 0x4c, + 0xaf, 0xc9, 0x17, 0xb0, 0x8e, 0x21, 0x8e, 0x31, 0x92, 0xa2, 0x6b, 0x0f, 0x96, 0xb6, 0x37, 0x76, + 0x07, 0x3b, 0x65, 0x7c, 0xb3, 0x04, 0x87, 0x99, 0x22, 0x2b, 0x2c, 0x94, 0x67, 0x2f, 0x9e, 0x44, + 0x85, 0x67, 0xbd, 0xa1, 0x9f, 0xc3, 0xed, 0x5a, 0x43, 0x15, 0x78, 0xe0, 0x6b, 0xf7, 0x4d, 0x66, + 0x07, 0xbe, 0x0e, 0x08, 0xb9, 0xaf, 0x53, 0x69, 0x32, 0xbd, 0xa6, 0x6f, 0xe0, 0x46, 0x69, 0xfc, + 0xe3, 0x04, 0x85, 0x24, 0x5d, 0x58, 0xd3, 0x21, 0xb9, 0xb9, 0x6d, 0xbe, 0x25, 0x8f, 0x60, 0x35, + 0x55, 0x65, 0xca, 0x63, 0xef, 0xd6, 0xc5, 0xae, 0x14, 0x98, 0xd1, 0xa3, 0x5f, 0x81, 0x53, 0x89, + 0x2d, 0x89, 0x23, 0x81, 0x64, 0x0f, 0xd6, 0x52, 0x1d, 0xa7, 0xe8, 0x5a, 0x9a, 0xe6, 0xa3, 0x85, + 0x25, 0x60, 0xb9, 0x26, 0xfd, 0xcd, 0x82, 0x9b, 0xc7, 0xa7, 0xdf, 0xa3, 0x27, 0x95, 0xf4, 0x25, + 0x0a, 0xc1, 0xcf, 0xf1, 0x03, 0xa1, 0x6e, 0x41, 0x33, 0xcd, 0xf2, 0x71, 0xf3, 0x84, 0x4b, 0x40, + 0xd9, 0xa5, 0x98, 0x84, 0x53, 0xd7, 0xd7, 0xa5, 0x6c, 0xb2, 0x7c, 0xab, 0x24, 0x09, 0x9f, 0x86, + 0x31, 0xf7, 0xbb, 0xcb, 0xba, 0x6f, 0xf9, 0x96, 0xf4, 0x60, 0x3d, 0xd6, 0x01, 0xb8, 0x7e, 0x77, + 0x45, 0x1b, 0x15, 0x7b, 0x8a, 0xe0, 0x0c, 0x95, 0xe3, 0x93, 0x89, 0x18, 0xe5, 0x65, 0x7c, 0x5c, + 0x32, 0xa9, 0xd8, 0x36, 0x76, 0xef, 0x56, 0xd2, 0xcc, 0xb4, 0x33, 0x71, 0xe9, 0xa2, 0x0f, 0xb0, + 0x9f, 0xa2, 0x8f, 0x91, 0x0c, 0x78, 0xa8, 0xa3, 0x6e, 0xb1, 0x0a, 0x42, 0x6f, 0xc1, 0xcd, 0x8a, + 0x9b, 0xac, 0x9c, 0x94, 0x16, 0xbe, 0xc3, 0x30, 0xf7, 0x7d, 0xad, 0xf3, 0xf4, 0x45, 0x61, 0xa8, + 0x74, 0x4c, 0x1f, 0xfe, 0x7b, 0x80, 0xf4, 0x27, 0x1b, 0x5a, 0x55, 0x09, 0x79, 0x06, 0x1b, 0xda, + 0x46, 0xb5, 0x0d, 0x53, 0xc3, 0x73, 0xbf, 0xc2, 0xc3, 0xf8, 0xbb, 0x61, 0xa9, 0xf0, 0x6d, 0x20, + 0x47, 0xae, 0xcf, 0xaa, 0x36, 0x2a, 0x69, 0xee, 0x85, 0x86, 0x30, 0x4f, 0xba, 0x44, 0x08, 0x85, + 0x56, 0xb9, 0x2b, 0x1a, 0x36, 0x83, 0x91, 0x5d, 0xe8, 0x68, 0xca, 0x21, 0x4a, 0x19, 0x44, 0xe7, + 0xe2, 0x64, 0xa6, 0x85, 0xb5, 0x32, 0xf2, 0x19, 0xdc, 0xa9, 0xc3, 0x8b, 0xee, 0x2e, 0x90, 0xd2, + 0x3f, 0x2d, 0xd8, 0xa8, 0xa4, 0xa4, 0xce, 0x45, 0xa0, 0x1b, 0x24, 0xa7, 0xe6, 0xaa, 0x17, 0x7b, + 0x75, 0x0a, 0x65, 0x30, 0x46, 0x21, 0xf9, 0x38, 0xd1, 0xa9, 0x2d, 0xb1, 0x12, 0x50, 0x52, 0xed, + 0xe3, 0x9b, 0x69, 0x82, 0x26, 0xad, 0x12, 0x20, 0x9f, 0x40, 0x5b, 0x1d, 0xca, 0xc0, 0xe3, 0x32, + 0x88, 0xa3, 0xaf, 0x71, 0xaa, 0xb3, 0x59, 0x66, 0xd7, 0x50, 0x75, 0xab, 0x05, 0x62, 0x16, 0x75, + 0x8b, 0xe9, 0x35, 0xd9, 0x01, 0x52, 0x29, 0x71, 0x5e, 0x8d, 0x55, 0xad, 0x51, 0x23, 0xa1, 0x27, + 0xd0, 0x9e, 0x6d, 0x14, 0x19, 0xcc, 0x37, 0xb6, 0x35, 0xdb, 0x37, 0x15, 0x7d, 0x70, 0x1e, 0x71, + 0x39, 0x49, 0xd1, 0xb4, 0xad, 0x04, 0xe8, 0x01, 0x74, 0xea, 0x5a, 0xaf, 0xef, 0x25, 0x7f, 0x37, + 0xc3, 0x5a, 0x02, 0xe6, 0xdc, 0xda, 0xc5, 0xb9, 0xfd, 0xd5, 0x82, 0xce, 0xb0, 0xda, 0x86, 0xfd, + 0x38, 0x92, 0xea, 0x69, 0xfb, 0x12, 0x5a, 0xd9, 0xe5, 0x3b, 0xc0, 0x10, 0x25, 0xd6, 0x1c, 0xe0, + 0xe3, 0x8a, 0xf8, 0xa8, 0xc1, 0x66, 0xd4, 0xc9, 0x53, 0x93, 0x9d, 0xb1, 0xb6, 0xb5, 0xf5, 0x9d, + 0xeb, 0xc7, 0xbf, 0x30, 0xae, 0x2a, 0x3f, 0x5f, 0x83, 0x95, 0xb7, 0x3c, 0x9c, 0x20, 0xed, 0x43, + 0xab, 0xea, 0x64, 0xee, 0xd2, 0xed, 0x99, 0x73, 0x62, 0xc4, 0x1f, 0xc3, 0xa6, 0xaf, 0x57, 0xe9, + 0x09, 0x62, 0x5a, 0xbc, 0x58, 0xb3, 0x20, 0x7d, 0x03, 0xb7, 0x67, 0x12, 0x1e, 0x46, 0x3c, 0x11, + 0xa3, 0x58, 0xaa, 0x6b, 0x92, 0x69, 0xfa, 0xae, 0x9f, 0x3d, 0x9c, 0x4d, 0x56, 0x41, 0xe6, 0xe9, + 0xed, 0x3a, 0xfa, 0x9f, 0x2d, 0x68, 0xe5, 0xd4, 0x07, 0x5c, 0x72, 0xf2, 0x04, 0xd6, 0xbc, 0xac, + 0xa6, 0xe6, 0x31, 0xbe, 0x7f, 0xbd, 0x0a, 0xd7, 0x4a, 0xcf, 0x72, 0x7d, 0x35, 0xcb, 0x84, 0x89, + 0xce, 0x54, 0x70, 0xb0, 0xc8, 0x36, 0xcf, 0x82, 0x15, 0x16, 0xf4, 0x07, 0xf3, 0x24, 0x0d, 0x27, + 0xa7, 0xc2, 0x4b, 0x83, 0x44, 0x1d, 0x67, 0x75, 0x97, 0xcc, 0x03, 0x9e, 0xa7, 0x58, 0xec, 0xc9, + 0x53, 0x58, 0xe5, 0x9e, 0xd2, 0xd2, 0xce, 0xda, 0xbb, 0x74, 0xce, 0x59, 0x85, 0xe9, 0x99, 0xd6, + 0x64, 0xc6, 0xe2, 0xc1, 0xef, 0x16, 0xac, 0x1f, 0xa6, 0xe9, 0x7e, 0xec, 0xa3, 0x20, 0x6d, 0x80, + 0xd7, 0x11, 0x5e, 0x24, 0xe8, 0x49, 0xf4, 0x9d, 0x06, 0x71, 0xcc, 0x9b, 0xf6, 0x32, 0x10, 0x22, + 0x88, 0xce, 0x1d, 0x8b, 0xdc, 0x30, 0x9d, 0x3b, 0xbc, 0x08, 0x84, 0x14, 0x8e, 0x4d, 0x6e, 0xc1, + 0x0d, 0x0d, 0xbc, 0x8a, 0xa5, 0x1b, 0xed, 0x73, 0x6f, 0x84, 0xce, 0x12, 0x21, 0xd0, 0xd6, 0xa0, + 0x2b, 0xb2, 0x0e, 0xfb, 0xce, 0x32, 0xe9, 0x42, 0x47, 0x57, 0x5a, 0xbc, 0x8a, 0xa5, 0x79, 0x68, + 0x83, 0xd3, 0x10, 0x9d, 0x15, 0xd2, 0x01, 0x87, 0xa1, 0x87, 0x41, 0x22, 0x5d, 0xe1, 0x46, 0x6f, + 0x79, 0x18, 0xf8, 0xce, 0xaa, 0xf2, 0x74, 0x98, 0xa6, 0x71, 0x7a, 0x7c, 0x76, 0x26, 0x50, 0x3a, + 0xfe, 0x83, 0x27, 0x70, 0x77, 0x41, 0x32, 0x64, 0x13, 0x9a, 0x06, 0x3d, 0x45, 0xa7, 0xa1, 0x4c, + 0x5f, 0x47, 0xa2, 0x00, 0xac, 0xdd, 0x7f, 0x6c, 0x68, 0x66, 0xb6, 0xd3, 0xc8, 0x23, 0xfb, 0xb0, + 0x9e, 0xcf, 0x52, 0xd2, 0xab, 0x1d, 0xb0, 0x7a, 0x54, 0xf4, 0xee, 0xd5, 0x0f, 0xdf, 0x6c, 0x44, + 0xbc, 0x30, 0x8c, 0x6a, 0xe0, 0x90, 0x7b, 0x73, 0xe3, 0xa1, 0x9c, 0x76, 0xbd, 0xad, 0x7a, 0xe1, + 0x1c, 0x4f, 0x18, 0xd6, 0xf1, 0x14, 0x93, 0xab, 0x8e, 0xa7, 0x32, 0xb2, 0x18, 0x38, 0xe5, 0x47, + 0xc0, 0x50, 0xa6, 0xc8, 0xc7, 0x64, 0x6b, 0xee, 0xd2, 0x57, 0xbe, 0x10, 0x7a, 0x1f, 0x94, 0x6e, + 0x5b, 0x8f, 0x2c, 0x72, 0x04, 0x50, 0x0a, 0xfe, 0x0f, 0xdb, 0xf3, 0x4f, 0xff, 0xb8, 0xec, 0x5b, + 0xef, 0x2f, 0xfb, 0xd6, 0xdf, 0x97, 0x7d, 0xeb, 0x97, 0xab, 0x7e, 0xe3, 0xfd, 0x55, 0xbf, 0xf1, + 0xd7, 0x55, 0xbf, 0xf1, 0x5d, 0x6f, 0xf1, 0x57, 0xea, 0xe9, 0xaa, 0xfe, 0xdb, 0xfb, 0x37, 0x00, + 0x00, 0xff, 0xff, 0xb6, 0xe1, 0x84, 0x46, 0xca, 0x0a, 0x00, 0x00, } func (m *HeadSyncRange) Marshal() (dAtA []byte, err error) { diff --git a/commonspace/spacesyncproto/spacesync_drpc.pb.go b/commonspace/spacesyncproto/spacesync_drpc.pb.go index 2c82a645..11e5d715 100644 --- a/commonspace/spacesyncproto/spacesync_drpc.pb.go +++ b/commonspace/spacesyncproto/spacesync_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: commonspace/spacesyncproto/protos/spacesync.proto package spacesyncproto @@ -44,6 +44,7 @@ type DRPCSpaceSyncClient interface { SpacePush(ctx context.Context, in *SpacePushRequest) (*SpacePushResponse, error) SpacePull(ctx context.Context, in *SpacePullRequest) (*SpacePullResponse, error) ObjectSyncStream(ctx context.Context) (DRPCSpaceSync_ObjectSyncStreamClient, error) + ObjectSync(ctx context.Context, in *ObjectSyncMessage) (*ObjectSyncMessage, error) } type drpcSpaceSyncClient struct { @@ -102,6 +103,10 @@ type drpcSpaceSync_ObjectSyncStreamClient struct { drpc.Stream } +func (x *drpcSpaceSync_ObjectSyncStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcSpaceSync_ObjectSyncStreamClient) Send(m *ObjectSyncMessage) error { return x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) } @@ -118,11 +123,21 @@ func (x *drpcSpaceSync_ObjectSyncStreamClient) RecvMsg(m *ObjectSyncMessage) err return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) } +func (c *drpcSpaceSyncClient) ObjectSync(ctx context.Context, in *ObjectSyncMessage) (*ObjectSyncMessage, error) { + out := new(ObjectSyncMessage) + err := c.cc.Invoke(ctx, "/spacesync.SpaceSync/ObjectSync", drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}, in, out) + if err != nil { + return nil, err + } + return out, nil +} + type DRPCSpaceSyncServer interface { HeadSync(context.Context, *HeadSyncRequest) (*HeadSyncResponse, error) SpacePush(context.Context, *SpacePushRequest) (*SpacePushResponse, error) SpacePull(context.Context, *SpacePullRequest) (*SpacePullResponse, error) ObjectSyncStream(DRPCSpaceSync_ObjectSyncStreamStream) error + ObjectSync(context.Context, *ObjectSyncMessage) (*ObjectSyncMessage, error) } type DRPCSpaceSyncUnimplementedServer struct{} @@ -143,9 +158,13 @@ func (s *DRPCSpaceSyncUnimplementedServer) ObjectSyncStream(DRPCSpaceSync_Object return drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) } +func (s *DRPCSpaceSyncUnimplementedServer) ObjectSync(context.Context, *ObjectSyncMessage) (*ObjectSyncMessage, error) { + return nil, drpcerr.WithCode(errors.New("Unimplemented"), drpcerr.Unimplemented) +} + type DRPCSpaceSyncDescription struct{} -func (DRPCSpaceSyncDescription) NumMethods() int { return 4 } +func (DRPCSpaceSyncDescription) NumMethods() int { return 5 } func (DRPCSpaceSyncDescription) Method(n int) (string, drpc.Encoding, drpc.Receiver, interface{}, bool) { switch n { @@ -184,6 +203,15 @@ func (DRPCSpaceSyncDescription) Method(n int) (string, drpc.Encoding, drpc.Recei &drpcSpaceSync_ObjectSyncStreamStream{in1.(drpc.Stream)}, ) }, DRPCSpaceSyncServer.ObjectSyncStream, true + case 4: + return "/spacesync.SpaceSync/ObjectSync", drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}, + func(srv interface{}, ctx context.Context, in1, in2 interface{}) (drpc.Message, error) { + return srv.(DRPCSpaceSyncServer). + ObjectSync( + ctx, + in1.(*ObjectSyncMessage), + ) + }, DRPCSpaceSyncServer.ObjectSync, true default: return "", nil, nil, nil, false } @@ -266,3 +294,19 @@ func (x *drpcSpaceSync_ObjectSyncStreamStream) Recv() (*ObjectSyncMessage, error func (x *drpcSpaceSync_ObjectSyncStreamStream) RecvMsg(m *ObjectSyncMessage) error { return x.MsgRecv(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}) } + +type DRPCSpaceSync_ObjectSyncStream interface { + drpc.Stream + SendAndClose(*ObjectSyncMessage) error +} + +type drpcSpaceSync_ObjectSyncStream struct { + drpc.Stream +} + +func (x *drpcSpaceSync_ObjectSyncStream) SendAndClose(m *ObjectSyncMessage) error { + if err := x.MsgSend(m, drpcEncoding_File_commonspace_spacesyncproto_protos_spacesync_proto{}); err != nil { + return err + } + return x.CloseSend() +} diff --git a/commonspace/spaceutils_test.go b/commonspace/spaceutils_test.go index 000d571c..cc82cecd 100644 --- a/commonspace/spaceutils_test.go +++ b/commonspace/spaceutils_test.go @@ -6,12 +6,15 @@ import ( accountService "github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/ocache" + "github.com/anyproto/any-sync/commonspace/config" "github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/treemanager" + "github.com/anyproto/any-sync/commonspace/objecttreebuilder" "github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacesyncproto" + "github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/nodeconf" @@ -128,6 +131,14 @@ func (m *mockConf) NodeTypes(nodeId string) []nodeconf.NodeType { type mockPeerManager struct { } +func (p *mockPeerManager) Init(a *app.App) (err error) { + return nil +} + +func (p *mockPeerManager) Name() (name string) { + return peermanager.CName +} + func (p *mockPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { return nil } @@ -159,6 +170,25 @@ func (m *mockPeerManagerProvider) NewPeerManager(ctx context.Context, spaceId st return &mockPeerManager{}, nil } +// +// Mock StatusServiceProvider +// + +type mockStatusServiceProvider struct { +} + +func (m *mockStatusServiceProvider) Init(a *app.App) (err error) { + return nil +} + +func (m *mockStatusServiceProvider) Name() (name string) { + return syncstatus.CName +} + +func (m *mockStatusServiceProvider) NewStatusService() syncstatus.StatusService { + return syncstatus.NewNoOpSyncStatus() +} + // // Mock Pool // @@ -166,6 +196,10 @@ func (m *mockPeerManagerProvider) NewPeerManager(ctx context.Context, spaceId st type mockPool struct { } +func (m *mockPool) AddPeer(ctx context.Context, p peer.Peer) (err error) { + return nil +} + func (m *mockPool) Init(a *app.App) (err error) { return nil } @@ -205,8 +239,8 @@ func (m *mockConfig) Name() (name string) { return "config" } -func (m *mockConfig) GetSpace() Config { - return Config{ +func (m *mockConfig) GetSpace() config.Config { + return config.Config{ GCTTL: 60, SyncPeriod: 20, KeepTreeDataInMemory: true, @@ -236,6 +270,7 @@ type mockTreeManager struct { cache ocache.OCache deletedIds []string markedIds []string + waitLoad chan struct{} } func (t *mockTreeManager) NewTreeSyncer(spaceId string, treeManager treemanager.TreeManager) treemanager.TreeSyncer { @@ -249,7 +284,8 @@ func (t *mockTreeManager) MarkTreeDeleted(ctx context.Context, spaceId, treeId s func (t *mockTreeManager) Init(a *app.App) (err error) { t.cache = ocache.New(func(ctx context.Context, id string) (value ocache.Object, err error) { - return t.space.BuildTree(ctx, id, BuildTreeOpts{}) + <-t.waitLoad + return t.space.TreeBuilder().BuildTree(ctx, id, objecttreebuilder.BuildTreeOpts{}) }, ocache.WithGCPeriod(time.Minute), ocache.WithTTL(time.Duration(60)*time.Second)) @@ -318,12 +354,14 @@ func newFixture(t *testing.T) *spaceFixture { configurationService: &mockConf{}, storageProvider: spacestorage.NewInMemorySpaceStorageProvider(), peermanagerProvider: &mockPeerManagerProvider{}, - treeManager: &mockTreeManager{}, + treeManager: &mockTreeManager{waitLoad: make(chan struct{})}, pool: &mockPool{}, spaceService: New(), } fx.app.Register(fx.account). Register(fx.config). + Register(credentialprovider.NewNoOp()). + Register(&mockStatusServiceProvider{}). Register(fx.configurationService). Register(fx.storageProvider). Register(fx.peermanagerProvider). diff --git a/commonspace/syncstatus/noop.go b/commonspace/syncstatus/noop.go index 586b90d2..79424d3f 100644 --- a/commonspace/syncstatus/noop.go +++ b/commonspace/syncstatus/noop.go @@ -1,9 +1,32 @@ package syncstatus +import ( + "context" + "github.com/anyproto/any-sync/app" +) + +func NewNoOpSyncStatus() StatusService { + return &noOpSyncStatus{} +} + type noOpSyncStatus struct{} -func NewNoOpSyncStatus() StatusUpdater { - return &noOpSyncStatus{} +func (n *noOpSyncStatus) Init(a *app.App) (err error) { + return nil +} + +func (n *noOpSyncStatus) Name() (name string) { + return CName +} + +func (n *noOpSyncStatus) Watch(treeId string) (err error) { + return nil +} + +func (n *noOpSyncStatus) Unwatch(treeId string) { +} + +func (n *noOpSyncStatus) SetUpdateReceiver(updater UpdateReceiver) { } func (n *noOpSyncStatus) HeadsChange(treeId string, heads []string) { @@ -22,9 +45,10 @@ func (n *noOpSyncStatus) StateCounter() uint64 { func (n *noOpSyncStatus) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) { } -func (n *noOpSyncStatus) Run() { -} - -func (n *noOpSyncStatus) Close() error { +func (n *noOpSyncStatus) Run(ctx context.Context) error { + return nil +} + +func (n *noOpSyncStatus) Close(ctx context.Context) error { return nil } diff --git a/commonspace/syncstatus/syncstatus.go b/commonspace/syncstatus/syncstatus.go index 91191573..32c287fa 100644 --- a/commonspace/syncstatus/syncstatus.go +++ b/commonspace/syncstatus/syncstatus.go @@ -3,6 +3,8 @@ package syncstatus import ( "context" "fmt" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/commonspace/spacestate" "sync" "time" @@ -20,7 +22,9 @@ const ( syncTimeout = time.Second ) -var log = logger.NewNamed("common.commonspace.syncstatus") +var log = logger.NewNamed(CName) + +const CName = "common.commonspace.syncstatus" type UpdateReceiver interface { UpdateTree(ctx context.Context, treeId string, status SyncStatus) (err error) @@ -34,9 +38,6 @@ type StatusUpdater interface { SetNodesOnline(senderId string, online bool) StateCounter() uint64 RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) - - Run() - Close() error } type StatusWatcher interface { @@ -45,7 +46,13 @@ type StatusWatcher interface { SetUpdateReceiver(updater UpdateReceiver) } -type StatusProvider interface { +type StatusServiceProvider interface { + app.Component + NewStatusService() StatusService +} + +type StatusService interface { + app.ComponentRunnable StatusUpdater StatusWatcher } @@ -70,7 +77,7 @@ type treeStatus struct { heads []string } -type syncStatusProvider struct { +type syncStatusService struct { sync.Mutex configuration nodeconf.NodeConf periodicSync periodicsync.PeriodicSync @@ -89,52 +96,45 @@ type syncStatusProvider struct { updateTimeout time.Duration } -type SyncStatusDeps struct { - UpdateIntervalSecs int - UpdateTimeout time.Duration - Configuration nodeconf.NodeConf - Storage spacestorage.SpaceStorage -} - -func DefaultDeps(configuration nodeconf.NodeConf, store spacestorage.SpaceStorage) SyncStatusDeps { - return SyncStatusDeps{ - UpdateIntervalSecs: syncUpdateInterval, - UpdateTimeout: syncTimeout, - Configuration: configuration, - Storage: store, +func NewSyncStatusProvider() StatusService { + return &syncStatusService{ + treeHeads: map[string]treeHeadsEntry{}, + watchers: map[string]struct{}{}, } } -func NewSyncStatusProvider(spaceId string, deps SyncStatusDeps) StatusProvider { - return &syncStatusProvider{ - spaceId: spaceId, - treeHeads: map[string]treeHeadsEntry{}, - watchers: map[string]struct{}{}, - updateIntervalSecs: deps.UpdateIntervalSecs, - updateTimeout: deps.UpdateTimeout, - configuration: deps.Configuration, - storage: deps.Storage, - stateCounter: 0, - } +func (s *syncStatusService) Init(a *app.App) (err error) { + sharedState := a.MustComponent(spacestate.CName).(*spacestate.SpaceState) + s.updateIntervalSecs = syncUpdateInterval + s.updateTimeout = syncTimeout + s.spaceId = sharedState.SpaceId + s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + s.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage) + return } -func (s *syncStatusProvider) SetUpdateReceiver(updater UpdateReceiver) { +func (s *syncStatusService) Name() (name string) { + return CName +} + +func (s *syncStatusService) SetUpdateReceiver(updater UpdateReceiver) { s.Lock() defer s.Unlock() s.updateReceiver = updater } -func (s *syncStatusProvider) Run() { +func (s *syncStatusService) Run(ctx context.Context) error { s.periodicSync = periodicsync.NewPeriodicSync( s.updateIntervalSecs, s.updateTimeout, s.update, log) s.periodicSync.Run() + return nil } -func (s *syncStatusProvider) HeadsChange(treeId string, heads []string) { +func (s *syncStatusService) HeadsChange(treeId string, heads []string) { s.Lock() defer s.Unlock() @@ -149,7 +149,7 @@ func (s *syncStatusProvider) HeadsChange(treeId string, heads []string) { s.stateCounter++ } -func (s *syncStatusProvider) SetNodesOnline(senderId string, online bool) { +func (s *syncStatusService) SetNodesOnline(senderId string, online bool) { if !s.isSenderResponsible(senderId) { return } @@ -160,7 +160,7 @@ func (s *syncStatusProvider) SetNodesOnline(senderId string, online bool) { s.nodesOnline = online } -func (s *syncStatusProvider) update(ctx context.Context) (err error) { +func (s *syncStatusService) update(ctx context.Context) (err error) { s.treeStatusBuf = s.treeStatusBuf[:0] s.Lock() @@ -189,7 +189,7 @@ func (s *syncStatusProvider) update(ctx context.Context) (err error) { return } -func (s *syncStatusProvider) HeadsReceive(senderId, treeId string, heads []string) { +func (s *syncStatusService) HeadsReceive(senderId, treeId string, heads []string) { s.Lock() defer s.Unlock() @@ -218,7 +218,7 @@ func (s *syncStatusProvider) HeadsReceive(senderId, treeId string, heads []strin s.treeHeads[treeId] = curTreeHeads } -func (s *syncStatusProvider) Watch(treeId string) (err error) { +func (s *syncStatusService) Watch(treeId string) (err error) { s.Lock() defer s.Unlock() _, ok := s.treeHeads[treeId] @@ -248,7 +248,7 @@ func (s *syncStatusProvider) Watch(treeId string) (err error) { return } -func (s *syncStatusProvider) Unwatch(treeId string) { +func (s *syncStatusService) Unwatch(treeId string) { s.Lock() defer s.Unlock() @@ -257,19 +257,14 @@ func (s *syncStatusProvider) Unwatch(treeId string) { } } -func (s *syncStatusProvider) Close() (err error) { - s.periodicSync.Close() - return -} - -func (s *syncStatusProvider) StateCounter() uint64 { +func (s *syncStatusService) StateCounter() uint64 { s.Lock() defer s.Unlock() return s.stateCounter } -func (s *syncStatusProvider) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) { +func (s *syncStatusService) RemoveAllExcept(senderId string, differentRemoteIds []string, stateCounter uint64) { // if sender is not a responsible node, then this should have no effect if !s.isSenderResponsible(senderId) { return @@ -292,6 +287,11 @@ func (s *syncStatusProvider) RemoveAllExcept(senderId string, differentRemoteIds } } -func (s *syncStatusProvider) isSenderResponsible(senderId string) bool { +func (s *syncStatusService) Close(ctx context.Context) error { + s.periodicSync.Close() + return nil +} + +func (s *syncStatusService) isSenderResponsible(senderId string) bool { return slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) } diff --git a/coordinator/coordinatorclient/coordinatorclient.go b/coordinator/coordinatorclient/coordinatorclient.go index 4847ade7..b804d9a5 100644 --- a/coordinator/coordinatorclient/coordinatorclient.go +++ b/coordinator/coordinatorclient/coordinatorclient.go @@ -10,6 +10,7 @@ import ( "github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/util/crypto" + "storj.io/drpc" ) const CName = "common.coordinator.coordinatorclient" @@ -39,42 +40,8 @@ type coordinatorClient struct { nodeConf nodeconf.Service } -func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{ - SpaceId: spaceId, - DeletionChangeId: deleteRaw.GetId(), - DeletionChangePayload: deleteRaw.GetRawChange(), - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - status = resp.Payload - return -} - -func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{ - SpaceId: spaceId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - status = resp.Payload - return -} - func (c *coordinatorClient) Init(a *app.App) (err error) { - c.pool = a.MustComponent(pool.CName).(pool.Service).NewPool(CName) + c.pool = a.MustComponent(pool.CName).(pool.Service) c.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.Service) return } @@ -83,8 +50,37 @@ func (c *coordinatorClient) Name() (name string) { return CName } +func (c *coordinatorClient) ChangeStatus(ctx context.Context, spaceId string, deleteRaw *treechangeproto.RawTreeChangeWithId) (status *coordinatorproto.SpaceStatusPayload, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceStatusChange(ctx, &coordinatorproto.SpaceStatusChangeRequest{ + SpaceId: spaceId, + DeletionChangeId: deleteRaw.GetId(), + DeletionChangePayload: deleteRaw.GetRawChange(), + }) + if err != nil { + return rpcerr.Unwrap(err) + } + status = resp.Payload + return nil + }) + return +} + +func (c *coordinatorClient) StatusCheck(ctx context.Context, spaceId string) (status *coordinatorproto.SpaceStatusPayload, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceStatusCheck(ctx, &coordinatorproto.SpaceStatusCheckRequest{ + SpaceId: spaceId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + status = resp.Payload + return nil + }) + return +} + func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayload) (receipt *coordinatorproto.SpaceReceiptWithSignature, err error) { - cl, err := c.client(ctx) if err != nil { return } @@ -100,54 +96,56 @@ func (c *coordinatorClient) SpaceSign(ctx context.Context, payload SpaceSignPayl if err != nil { return } - resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{ - SpaceId: payload.SpaceId, - Header: payload.SpaceHeader, - OldIdentity: oldIdentity, - NewIdentitySignature: newSignature, + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.SpaceSign(ctx, &coordinatorproto.SpaceSignRequest{ + SpaceId: payload.SpaceId, + Header: payload.SpaceHeader, + OldIdentity: oldIdentity, + NewIdentitySignature: newSignature, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + receipt = resp.Receipt + return nil }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - return resp.Receipt, nil -} - -func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{ - AccountIdentity: identity, - SpaceId: spaceId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } - return resp.Limit, nil -} - -func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) { - cl, err := c.client(ctx) - if err != nil { - return - } - resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{ - CurrentId: currentId, - }) - if err != nil { - err = rpcerr.Unwrap(err) - return - } return } -func (c *coordinatorClient) client(ctx context.Context) (coordinatorproto.DRPCCoordinatorClient, error) { +func (c *coordinatorClient) FileLimitCheck(ctx context.Context, spaceId string, identity []byte) (limit uint64, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err := cl.FileLimitCheck(ctx, &coordinatorproto.FileLimitCheckRequest{ + AccountIdentity: identity, + SpaceId: spaceId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + limit = resp.Limit + return nil + }) + return +} + +func (c *coordinatorClient) NetworkConfiguration(ctx context.Context, currentId string) (resp *coordinatorproto.NetworkConfigurationResponse, err error) { + err = c.doClient(ctx, func(cl coordinatorproto.DRPCCoordinatorClient) error { + resp, err = cl.NetworkConfiguration(ctx, &coordinatorproto.NetworkConfigurationRequest{ + CurrentId: currentId, + }) + if err != nil { + return rpcerr.Unwrap(err) + } + return nil + }) + return +} + +func (c *coordinatorClient) doClient(ctx context.Context, f func(cl coordinatorproto.DRPCCoordinatorClient) error) error { p, err := c.pool.GetOneOf(ctx, c.nodeConf.CoordinatorPeers()) if err != nil { - return nil, err + return err } - return coordinatorproto.NewDRPCCoordinatorClient(p), nil + return p.DoDrpc(ctx, func(conn drpc.Conn) error { + return f(coordinatorproto.NewDRPCCoordinatorClient(conn)) + }) } diff --git a/coordinator/coordinatorproto/coordinator_drpc.pb.go b/coordinator/coordinatorproto/coordinator_drpc.pb.go index 75e73a7b..0ed69ea2 100644 --- a/coordinator/coordinatorproto/coordinator_drpc.pb.go +++ b/coordinator/coordinatorproto/coordinator_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: coordinator/coordinatorproto/protos/coordinator.proto package coordinatorproto diff --git a/go.mod b/go.mod index 2859cec7..0581372a 100644 --- a/go.mod +++ b/go.mod @@ -14,6 +14,7 @@ require ( github.com/gogo/protobuf v1.3.2 github.com/golang/mock v1.6.0 github.com/google/uuid v1.3.0 + github.com/hashicorp/yamux v0.1.1 github.com/huandu/skiplist v1.2.0 github.com/ipfs/go-block-format v0.1.2 github.com/ipfs/go-blockservice v0.5.2 @@ -32,7 +33,7 @@ require ( github.com/stretchr/testify v1.8.4 github.com/tyler-smith/go-bip39 v1.1.0 github.com/zeebo/blake3 v0.2.3 - github.com/zeebo/errs v1.3.0 + go.uber.org/atomic v1.11.0 go.uber.org/zap v1.24.0 golang.org/x/crypto v0.9.0 golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 @@ -55,7 +56,7 @@ require ( github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect github.com/golang/protobuf v1.5.3 // indirect - github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 // indirect + github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 // indirect github.com/hashicorp/golang-lru v0.5.4 // indirect github.com/huin/goupnp v1.2.0 // indirect github.com/ipfs/bbloom v0.0.4 // indirect @@ -88,7 +89,7 @@ require ( github.com/multiformats/go-multicodec v0.9.0 // indirect github.com/multiformats/go-multistream v0.4.1 // indirect github.com/multiformats/go-varint v0.0.7 // indirect - github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/onsi/ginkgo/v2 v2.9.7 // indirect github.com/opentracing/opentracing-go v1.2.0 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -96,19 +97,19 @@ require ( github.com/prometheus/client_model v0.4.0 // indirect github.com/prometheus/common v0.44.0 // indirect github.com/prometheus/procfs v0.10.0 // indirect - github.com/quic-go/quic-go v0.34.0 // indirect + github.com/quic-go/quic-go v0.35.1 // indirect github.com/quic-go/webtransport-go v0.5.3 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 // indirect github.com/whyrusleeping/chunker v0.0.0-20181014151217-fe64bd25879f // indirect + github.com/zeebo/errs v1.3.0 // indirect go.opentelemetry.io/otel v1.7.0 // indirect go.opentelemetry.io/otel/trace v1.7.0 // indirect - go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect golang.org/x/image v0.6.0 // indirect golang.org/x/sync v0.2.0 // indirect golang.org/x/sys v0.8.0 // indirect - golang.org/x/tools v0.9.1 // indirect + golang.org/x/tools v0.9.3 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/protobuf v1.30.0 // indirect lukechampine.com/blake3 v1.2.1 // indirect diff --git a/go.sum b/go.sum index da266a43..04c9a7c6 100644 --- a/go.sum +++ b/go.sum @@ -67,8 +67,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.7/go.mod h1:n+brtR0CgQNWTVd5ZUFpTBC8YFBDLK/h/bpaJ8/DtOE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= -github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3 h1:2XF1Vzq06X+inNqgJ9tRnGuw+ZVCB3FazXODD6JE1R8= -github.com/google/pprof v0.0.0-20230510103437-eeec1cb781c3/go.mod h1:79YE0hCXdHag9sBkw2o+N/YnZtTkXi0UT9Nnixa5eYk= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9SN1TigNLn9ZnF3W4SYRKq2gAHs= +github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= @@ -79,6 +79,8 @@ github.com/gxed/hashland/keccakpg v0.0.1/go.mod h1:kRzw3HkwxFU1mpmPP8v1WyQzwdGfm github.com/gxed/hashland/murmur3 v0.0.1/go.mod h1:KjXop02n4/ckmZSnY2+HKcLud/tcmvhST0bie/0lS48= github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/yamux v0.1.1 h1:yrQxtgseBDrq9Y652vSRDvsKCJKOUD+GzTS4Y0Y8pvE= +github.com/hashicorp/yamux v0.1.1/go.mod h1:CtWFDAQgb7dxtzFs4tWbplKIe2jSi3+5vKbgIO0SLnQ= github.com/huandu/go-assert v1.1.5 h1:fjemmA7sSfYHJD7CUqs9qTwwfdNAx7/j2/ZlHXzNB3c= github.com/huandu/go-assert v1.1.5/go.mod h1:yOLvuqZwmcHIC5rIzrBhT7D3Q9c3GFnd0JrPVhn/06U= github.com/huandu/skiplist v1.2.0 h1:gox56QD77HzSC0w+Ws3MH3iie755GBJU1OER3h5VsYw= @@ -242,8 +244,8 @@ github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS github.com/multiformats/go-varint v0.0.7 h1:sWSGR+f/eu5ABZA2ZpYKBILXTTs9JWpdEM/nEGOHFS8= github.com/multiformats/go-varint v0.0.7/go.mod h1:r8PUYw/fD/SjBCiKOoDlGF6QawOELpZAu9eioSos/OU= github.com/nfnt/resize v0.0.0-20160724205520-891127d8d1b5 h1:BvoENQQU+fZ9uukda/RzCAL/191HHwJA5b13R6diVlY= -github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= -github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= +github.com/onsi/ginkgo/v2 v2.9.7 h1:06xGQy5www2oN160RtEZoTvnP2sPhEfePYmCDc2szss= +github.com/onsi/ginkgo/v2 v2.9.7/go.mod h1:cxrmXWykAwTwhQsJOPfdIDiJ+l2RYq7U8hFU+M/1uw0= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -266,8 +268,8 @@ github.com/prometheus/procfs v0.10.0/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPH github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= github.com/quic-go/qtls-go1-19 v0.3.2 h1:tFxjCFcTQzK+oMxG6Zcvp4Dq8dx4yD3dDiIiyc86Z5U= github.com/quic-go/qtls-go1-20 v0.2.2 h1:WLOPx6OY/hxtTxKV1Zrq20FtXtDEkeY00CGQm8GEa3E= -github.com/quic-go/quic-go v0.34.0 h1:OvOJ9LFjTySgwOTYUZmNoq0FzVicP8YujpV0kB7m2lU= -github.com/quic-go/quic-go v0.34.0/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= +github.com/quic-go/quic-go v0.35.1 h1:b0kzj6b/cQAf05cT0CkQubHM31wiA+xH3IBkxP62poo= +github.com/quic-go/quic-go v0.35.1/go.mod h1:+4CVgVppm0FNjpG3UcX8Joi/frKOH7/ciD5yGcwOO1g= github.com/quic-go/webtransport-go v0.5.3 h1:5XMlzemqB4qmOlgIus5zB45AcZ2kCgCy2EptUrfOPWU= github.com/quic-go/webtransport-go v0.5.3/go.mod h1:OhmmgJIzTTqXK5xvtuX0oBpLV2GkLWNDA+UeTGJXErU= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -413,8 +415,8 @@ golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.9.1 h1:8WMNJAz3zrtPmnYC7ISf5dEn3MT0gY7jBJfw27yrrLo= -golang.org/x/tools v0.9.1/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= +golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM= +golang.org/x/tools v0.9.3/go.mod h1:owI94Op576fPu3cIGQeHs3joujW/2Oc6MtlxbF5dfNc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/net/config.go b/net/config.go index 333261dd..64863419 100644 --- a/net/config.go +++ b/net/config.go @@ -5,21 +5,3 @@ import "errors" var ( ErrUnableToConnect = errors.New("unable to connect") ) - -type ConfigGetter interface { - GetNet() Config -} - -type Config struct { - Server ServerConfig `yaml:"server"` - Stream StreamConfig `yaml:"stream"` -} - -type ServerConfig struct { - ListenAddrs []string `yaml:"listenAddrs"` -} - -type StreamConfig struct { - TimeoutMilliseconds int `yaml:"timeoutMilliseconds"` - MaxMsgSizeMb int `yaml:"maxMsgSizeMb"` -} diff --git a/net/timeoutconn/conn.go b/net/connutil/timeout.go similarity index 82% rename from net/timeoutconn/conn.go rename to net/connutil/timeout.go index 11e80709..057c7b8c 100644 --- a/net/timeoutconn/conn.go +++ b/net/connutil/timeout.go @@ -1,4 +1,4 @@ -package timeoutconn +package connutil import ( "errors" @@ -10,18 +10,18 @@ import ( "go.uber.org/zap" ) -var log = logger.NewNamed("common.net.timeoutconn") +var log = logger.NewNamed("common.net.connutil") -type Conn struct { +type TimeoutConn struct { net.Conn timeout time.Duration } -func NewConn(conn net.Conn, timeout time.Duration) *Conn { - return &Conn{conn, timeout} +func NewTimeout(conn net.Conn, timeout time.Duration) *TimeoutConn { + return &TimeoutConn{conn, timeout} } -func (c *Conn) Write(p []byte) (n int, err error) { +func (c *TimeoutConn) Write(p []byte) (n int, err error) { for { if c.timeout != 0 { if e := c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); e != nil { diff --git a/net/connutil/usage.go b/net/connutil/usage.go new file mode 100644 index 00000000..826d9c74 --- /dev/null +++ b/net/connutil/usage.go @@ -0,0 +1,30 @@ +package connutil + +import ( + "go.uber.org/atomic" + "net" + "time" +) + +func NewLastUsageConn(conn net.Conn) *LastUsageConn { + return &LastUsageConn{Conn: conn} +} + +type LastUsageConn struct { + net.Conn + lastUsage atomic.Time +} + +func (c *LastUsageConn) Write(p []byte) (n int, err error) { + c.lastUsage.Store(time.Now()) + return c.Conn.Write(p) +} + +func (c *LastUsageConn) Read(p []byte) (n int, err error) { + c.lastUsage.Store(time.Now()) + return c.Conn.Read(p) +} + +func (c *LastUsageConn) LastUsage() time.Time { + return c.lastUsage.Load() +} diff --git a/net/dialer/dialer.go b/net/dialer/dialer.go deleted file mode 100644 index aa65da75..00000000 --- a/net/dialer/dialer.go +++ /dev/null @@ -1,137 +0,0 @@ -package dialer - -import ( - "context" - "errors" - "fmt" - "github.com/anyproto/any-sync/app" - "github.com/anyproto/any-sync/app/logger" - net2 "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/anyproto/any-sync/net/secureservice/handshake" - "github.com/anyproto/any-sync/net/timeoutconn" - "github.com/anyproto/any-sync/nodeconf" - "github.com/libp2p/go-libp2p/core/sec" - "go.uber.org/zap" - "net" - "storj.io/drpc" - "storj.io/drpc/drpcconn" - "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcwire" - "sync" - "time" -) - -const CName = "common.net.dialer" - -var ( - ErrAddrsNotFound = errors.New("addrs for peer not found") - ErrPeerIdIsUnexpected = errors.New("expected to connect with other peer id") -) - -var log = logger.NewNamed(CName) - -func New() Dialer { - return &dialer{peerAddrs: map[string][]string{}} -} - -type Dialer interface { - Dial(ctx context.Context, peerId string) (peer peer.Peer, err error) - SetPeerAddrs(peerId string, addrs []string) - app.Component -} - -type dialer struct { - transport secureservice.SecureService - config net2.Config - nodeConf nodeconf.NodeConf - peerAddrs map[string][]string - - mu sync.RWMutex -} - -func (d *dialer) Init(a *app.App) (err error) { - d.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - d.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) - d.config = a.MustComponent("config").(net2.ConfigGetter).GetNet() - return -} - -func (d *dialer) Name() (name string) { - return CName -} - -func (d *dialer) SetPeerAddrs(peerId string, addrs []string) { - d.mu.Lock() - defer d.mu.Unlock() - d.peerAddrs[peerId] = addrs -} - -func (d *dialer) getPeerAddrs(peerId string) ([]string, error) { - if addrs, ok := d.nodeConf.PeerAddresses(peerId); ok { - return addrs, nil - } - addrs, ok := d.peerAddrs[peerId] - if !ok || len(addrs) == 0 { - return nil, ErrAddrsNotFound - } - return addrs, nil -} - -func (d *dialer) Dial(ctx context.Context, peerId string) (p peer.Peer, err error) { - var ctxCancel context.CancelFunc - ctx, ctxCancel = context.WithTimeout(ctx, time.Second*10) - defer ctxCancel() - d.mu.RLock() - defer d.mu.RUnlock() - - addrs, err := d.getPeerAddrs(peerId) - if err != nil { - return - } - - var ( - conn drpc.Conn - sc sec.SecureConn - ) - log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) - for _, addr := range addrs { - conn, sc, err = d.handshake(ctx, addr, peerId) - if err != nil { - log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) - } else { - break - } - } - if err != nil { - return - } - return peer.NewPeer(sc, conn), nil -} - -func (d *dialer) handshake(ctx context.Context, addr, peerId string) (conn drpc.Conn, sc sec.SecureConn, err error) { - st := time.Now() - // TODO: move dial timeout to config - tcpConn, err := net.DialTimeout("tcp", addr, time.Second*15) - if err != nil { - return nil, nil, fmt.Errorf("dialTimeout error: %v; since start: %v", err, time.Since(st)) - } - - timeoutConn := timeoutconn.NewConn(tcpConn, time.Millisecond*time.Duration(d.config.Stream.TimeoutMilliseconds)) - sc, err = d.transport.SecureOutbound(ctx, timeoutConn) - if err != nil { - if he, ok := err.(handshake.HandshakeError); ok { - return nil, nil, he - } - return nil, nil, fmt.Errorf("tls handshaeke error: %v; since start: %v", err, time.Since(st)) - } - if peerId != sc.RemotePeer().String() { - return nil, nil, ErrPeerIdIsUnexpected - } - log.Info("connected with remote host", zap.String("serverPeer", sc.RemotePeer().String()), zap.String("addr", addr)) - conn = drpcconn.NewWithOptions(sc, drpcconn.Options{Manager: drpcmanager.Options{ - Reader: drpcwire.ReaderOptions{MaximumBufferSize: d.config.Stream.MaxMsgSizeMb * (1 << 20)}, - }}) - return conn, sc, err -} diff --git a/net/peer/mock_peer/mock_peer.go b/net/peer/mock_peer/mock_peer.go new file mode 100644 index 00000000..dc0a5b6a --- /dev/null +++ b/net/peer/mock_peer/mock_peer.go @@ -0,0 +1,149 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/net/peer (interfaces: Peer) + +// Package mock_peer is a generated GoMock package. +package mock_peer + +import ( + context "context" + reflect "reflect" + time "time" + + gomock "github.com/golang/mock/gomock" + drpc "storj.io/drpc" +) + +// MockPeer is a mock of Peer interface. +type MockPeer struct { + ctrl *gomock.Controller + recorder *MockPeerMockRecorder +} + +// MockPeerMockRecorder is the mock recorder for MockPeer. +type MockPeerMockRecorder struct { + mock *MockPeer +} + +// NewMockPeer creates a new mock instance. +func NewMockPeer(ctrl *gomock.Controller) *MockPeer { + mock := &MockPeer{ctrl: ctrl} + mock.recorder = &MockPeerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPeer) EXPECT() *MockPeerMockRecorder { + return m.recorder +} + +// AcquireDrpcConn mocks base method. +func (m *MockPeer) AcquireDrpcConn(arg0 context.Context) (drpc.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AcquireDrpcConn", arg0) + ret0, _ := ret[0].(drpc.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// AcquireDrpcConn indicates an expected call of AcquireDrpcConn. +func (mr *MockPeerMockRecorder) AcquireDrpcConn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AcquireDrpcConn", reflect.TypeOf((*MockPeer)(nil).AcquireDrpcConn), arg0) +} + +// Close mocks base method. +func (m *MockPeer) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockPeerMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPeer)(nil).Close)) +} + +// Context mocks base method. +func (m *MockPeer) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockPeerMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockPeer)(nil).Context)) +} + +// DoDrpc mocks base method. +func (m *MockPeer) DoDrpc(arg0 context.Context, arg1 func(drpc.Conn) error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DoDrpc", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// DoDrpc indicates an expected call of DoDrpc. +func (mr *MockPeerMockRecorder) DoDrpc(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DoDrpc", reflect.TypeOf((*MockPeer)(nil).DoDrpc), arg0, arg1) +} + +// Id mocks base method. +func (m *MockPeer) Id() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Id") + ret0, _ := ret[0].(string) + return ret0 +} + +// Id indicates an expected call of Id. +func (mr *MockPeerMockRecorder) Id() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockPeer)(nil).Id)) +} + +// IsClosed mocks base method. +func (m *MockPeer) IsClosed() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClosed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsClosed indicates an expected call of IsClosed. +func (mr *MockPeerMockRecorder) IsClosed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockPeer)(nil).IsClosed)) +} + +// ReleaseDrpcConn mocks base method. +func (m *MockPeer) ReleaseDrpcConn(arg0 drpc.Conn) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReleaseDrpcConn", arg0) +} + +// ReleaseDrpcConn indicates an expected call of ReleaseDrpcConn. +func (mr *MockPeerMockRecorder) ReleaseDrpcConn(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReleaseDrpcConn", reflect.TypeOf((*MockPeer)(nil).ReleaseDrpcConn), arg0) +} + +// TryClose mocks base method. +func (m *MockPeer) TryClose(arg0 time.Duration) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryClose", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryClose indicates an expected call of TryClose. +func (mr *MockPeerMockRecorder) TryClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockPeer)(nil).TryClose), arg0) +} diff --git a/net/peer/peer.go b/net/peer/peer.go index 5bb8022a..bd0d9b34 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -1,100 +1,233 @@ +//go:generate mockgen -destination mock_peer/mock_peer.go github.com/anyproto/any-sync/net/peer Peer package peer import ( "context" - "sync/atomic" - "time" - "github.com/anyproto/any-sync/app/logger" - "github.com/libp2p/go-libp2p/core/sec" + "github.com/anyproto/any-sync/app/ocache" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/rpc" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/anyproto/any-sync/net/transport" "go.uber.org/zap" + "io" + "net" "storj.io/drpc" + "storj.io/drpc/drpcconn" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" + "sync" + "time" ) var log = logger.NewNamed("common.net.peer") -func NewPeer(sc sec.SecureConn, conn drpc.Conn) Peer { - return &peer{ - id: sc.RemotePeer().String(), - lastUsage: time.Now().Unix(), - sc: sc, - Conn: conn, +type connCtrl interface { + ServeConn(ctx context.Context, conn net.Conn) (err error) + DrpcConfig() rpc.Config +} + +func NewPeer(mc transport.MultiConn, ctrl connCtrl) (p Peer, err error) { + ctx := mc.Context() + pr := &peer{ + active: map[*subConn]struct{}{}, + MultiConn: mc, + ctrl: ctrl, } + if pr.id, err = CtxPeerId(ctx); err != nil { + return + } + go pr.acceptLoop() + return pr, nil } type Peer interface { Id() string - LastUsage() time.Time - UpdateLastUsage() - Addr() string + Context() context.Context + + AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) + ReleaseDrpcConn(conn drpc.Conn) + DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error + + IsClosed() bool + TryClose(objectTTL time.Duration) (res bool, err error) + + ocache.Object +} + +type subConn struct { drpc.Conn + *connutil.LastUsageConn } type peer struct { - id string - ttl time.Duration - lastUsage int64 - sc sec.SecureConn - drpc.Conn + id string + + ctrl connCtrl + + // drpc conn pool + inactive []*subConn + active map[*subConn]struct{} + + mu sync.Mutex + + transport.MultiConn } func (p *peer) Id() string { return p.id } -func (p *peer) LastUsage() time.Time { - select { - case <-p.Closed(): - return time.Unix(0, 0) - default: +func (p *peer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { + p.mu.Lock() + if len(p.inactive) == 0 { + p.mu.Unlock() + dconn, err := p.openDrpcConn(ctx) + if err != nil { + return nil, err + } + p.mu.Lock() + p.inactive = append(p.inactive, dconn) } - return time.Unix(atomic.LoadInt64(&p.lastUsage), 0) + idx := len(p.inactive) - 1 + res := p.inactive[idx] + p.inactive = p.inactive[:idx] + p.active[res] = struct{}{} + p.mu.Unlock() + return res, nil } -func (p *peer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { - defer p.UpdateLastUsage() - return p.Conn.Invoke(ctx, rpc, enc, in, out) -} - -func (p *peer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { - defer p.UpdateLastUsage() - return p.Conn.NewStream(ctx, rpc, enc) -} - -func (p *peer) Read(b []byte) (n int, err error) { - if n, err = p.sc.Read(b); err == nil { - p.UpdateLastUsage() +func (p *peer) ReleaseDrpcConn(conn drpc.Conn) { + p.mu.Lock() + defer p.mu.Unlock() + sc, ok := conn.(*subConn) + if !ok { + return } + if _, ok = p.active[sc]; ok { + delete(p.active, sc) + } + p.inactive = append(p.inactive, sc) return } -func (p *peer) Write(b []byte) (n int, err error) { - if n, err = p.sc.Write(b); err == nil { - p.UpdateLastUsage() +func (p *peer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error { + conn, err := p.AcquireDrpcConn(ctx) + if err != nil { + return err } - return + defer p.ReleaseDrpcConn(conn) + return do(conn) } -func (p *peer) UpdateLastUsage() { - atomic.StoreInt64(&p.lastUsage, time.Now().Unix()) +func (p *peer) openDrpcConn(ctx context.Context) (dconn *subConn, err error) { + conn, err := p.Open(ctx) + if err != nil { + return nil, err + } + if err = handshake.OutgoingProtoHandshake(ctx, conn, handshakeproto.ProtoType_DRPC); err != nil { + return nil, err + } + tconn := connutil.NewLastUsageConn(conn) + bufSize := p.ctrl.DrpcConfig().Stream.MaxMsgSizeMb * (1 << 20) + return &subConn{ + Conn: drpcconn.NewWithOptions(conn, drpcconn.Options{ + Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize}, + Stream: drpcstream.Options{MaximumBufferSize: bufSize}, + }, + }), + LastUsageConn: tconn, + }, nil +} + +func (p *peer) acceptLoop() { + var exitErr error + defer func() { + if exitErr != transport.ErrConnClosed { + log.Warn("accept error: close connection", zap.Error(exitErr)) + _ = p.MultiConn.Close() + } + }() + for { + conn, err := p.Accept() + if err != nil { + exitErr = err + return + } + go func() { + serveErr := p.serve(conn) + if serveErr != io.EOF && serveErr != transport.ErrConnClosed { + log.InfoCtx(p.Context(), "serve connection error", zap.Error(serveErr)) + } + }() + } +} + +var defaultProtoChecker = handshake.ProtoChecker{ + AllowedProtoTypes: []handshakeproto.ProtoType{ + handshakeproto.ProtoType_DRPC, + }, +} + +func (p *peer) serve(conn net.Conn) (err error) { + hsCtx, cancel := context.WithTimeout(p.Context(), time.Second*20) + if _, err = handshake.IncomingProtoHandshake(hsCtx, conn, defaultProtoChecker); err != nil { + cancel() + return + } + cancel() + return p.ctrl.ServeConn(p.Context(), conn) } func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { + p.gc(objectTTL) if time.Now().Sub(p.LastUsage()) < objectTTL { return false, nil } return true, p.Close() } -func (p *peer) Addr() string { - if p.sc != nil { - return p.sc.RemoteAddr().String() +func (p *peer) gc(ttl time.Duration) { + p.mu.Lock() + defer p.mu.Unlock() + minLastUsage := time.Now().Add(-ttl) + var hasClosed bool + for i, in := range p.inactive { + select { + case <-in.Closed(): + p.inactive[i] = nil + hasClosed = true + default: + } + if in.LastUsage().Before(minLastUsage) { + _ = in.Close() + p.inactive[i] = nil + hasClosed = true + } + } + if hasClosed { + inactive := p.inactive + p.inactive = p.inactive[:0] + for _, in := range inactive { + if in != nil { + p.inactive = append(p.inactive, in) + } + } + } + for act := range p.active { + select { + case <-act.Closed(): + delete(p.active, act) + default: + } } - return "" } func (p *peer) Close() (err error) { log.Debug("peer close", zap.String("peerId", p.id)) - return p.Conn.Close() + return p.MultiConn.Close() } diff --git a/net/peer/peer_test.go b/net/peer/peer_test.go new file mode 100644 index 00000000..c0046923 --- /dev/null +++ b/net/peer/peer_test.go @@ -0,0 +1,192 @@ +package peer + +import ( + "context" + "github.com/anyproto/any-sync/net/rpc" + "github.com/anyproto/any-sync/net/secureservice/handshake" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/anyproto/any-sync/net/transport/mock_transport" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net" + _ "net/http/pprof" + "testing" + "time" +) + +var ctx = context.Background() + +func TestPeer_AcquireDrpcConn(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dc, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + defer dc.Close() + + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) + + fx.ReleaseDrpcConn(dc) + + assert.Len(t, fx.active, 0) + assert.Len(t, fx.inactive, 1) + + dc, err = fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + assert.NotEmpty(t, dc) + assert.Len(t, fx.active, 1) + assert.Len(t, fx.inactive, 0) +} + +func TestPeerAccept(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + in, out := net.Pipe() + defer out.Close() + + var outHandshakeCh = make(chan error) + go func() { + outHandshakeCh <- handshake.OutgoingProtoHandshake(ctx, out, handshakeproto.ProtoType_DRPC) + }() + fx.acceptCh <- acceptedConn{conn: in} + cn := <-fx.testCtrl.serveConn + assert.Equal(t, in, cn) + assert.NoError(t, <-outHandshakeCh) +} + +func TestPeer_TryClose(t *testing.T) { + t.Run("ttl", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now() + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.False(t, res) + }) + t.Run("close", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + lu := time.Now().Add(-time.Second * 2) + fx.mc.EXPECT().LastUsage().Return(lu) + res, err := fx.TryClose(time.Second) + require.NoError(t, err) + assert.True(t, res) + }) + t.Run("gc", func(t *testing.T) { + fx := newFixture(t, "p1") + defer fx.finish() + now := time.Now() + fx.mc.EXPECT().LastUsage().Return(now.Add(time.Millisecond * 100)) + + // make one inactive + in, out := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out, defaultProtoChecker) + }() + defer out.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in, nil) + dc, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + + // make one active but closed + in2, out2 := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out2, defaultProtoChecker) + }() + defer out2.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in2, nil) + dc2, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + _ = dc2.Close() + + // make one inactive and closed + in3, out3 := net.Pipe() + go func() { + handshake.IncomingProtoHandshake(ctx, out3, defaultProtoChecker) + }() + defer out3.Close() + fx.mc.EXPECT().Open(gomock.Any()).Return(in3, nil) + dc3, err := fx.AcquireDrpcConn(ctx) + require.NoError(t, err) + fx.ReleaseDrpcConn(dc3) + _ = dc3.Close() + fx.ReleaseDrpcConn(dc) + + time.Sleep(time.Millisecond * 100) + + res, err := fx.TryClose(time.Millisecond * 50) + require.NoError(t, err) + assert.False(t, res) + }) +} + +type acceptedConn struct { + conn net.Conn + err error +} + +func newFixture(t *testing.T, peerId string) *fixture { + fx := &fixture{ + ctrl: gomock.NewController(t), + acceptCh: make(chan acceptedConn), + testCtrl: newTesCtrl(), + } + fx.mc = mock_transport.NewMockMultiConn(fx.ctrl) + ctx := CtxWithPeerId(context.Background(), peerId) + fx.mc.EXPECT().Context().Return(ctx).AnyTimes() + fx.mc.EXPECT().Accept().DoAndReturn(func() (net.Conn, error) { + ac := <-fx.acceptCh + return ac.conn, ac.err + }).AnyTimes() + fx.mc.EXPECT().Close().AnyTimes() + p, err := NewPeer(fx.mc, fx.testCtrl) + require.NoError(t, err) + fx.peer = p.(*peer) + return fx +} + +type fixture struct { + *peer + ctrl *gomock.Controller + mc *mock_transport.MockMultiConn + acceptCh chan acceptedConn + testCtrl *testCtrl +} + +func (fx *fixture) finish() { + fx.testCtrl.close() + fx.ctrl.Finish() +} + +func newTesCtrl() *testCtrl { + return &testCtrl{closeCh: make(chan struct{}), serveConn: make(chan net.Conn, 10)} +} + +type testCtrl struct { + serveConn chan net.Conn + closeCh chan struct{} +} + +func (t *testCtrl) DrpcConfig() rpc.Config { + return rpc.Config{Stream: rpc.StreamConfig{MaxMsgSizeMb: 10}} +} + +func (t *testCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) { + t.serveConn <- conn + <-t.closeCh + return io.EOF +} + +func (t *testCtrl) close() { + close(t.closeCh) +} diff --git a/net/peerservice/peerservice.go b/net/peerservice/peerservice.go new file mode 100644 index 00000000..eda1d8c8 --- /dev/null +++ b/net/peerservice/peerservice.go @@ -0,0 +1,111 @@ +package peerservice + +import ( + "context" + "errors" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/pool" + "github.com/anyproto/any-sync/net/rpc/server" + "github.com/anyproto/any-sync/net/transport" + "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/anyproto/any-sync/nodeconf" + "go.uber.org/zap" + "sync" +) + +const CName = "net.peerservice" + +var log = logger.NewNamed(CName) + +var ( + ErrAddrsNotFound = errors.New("addrs for peer not found") +) + +func New() PeerService { + return new(peerService) +} + +type PeerService interface { + Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) + SetPeerAddrs(peerId string, addrs []string) + transport.Accepter + app.Component +} + +type peerService struct { + yamux transport.Transport + nodeConf nodeconf.NodeConf + peerAddrs map[string][]string + pool pool.Pool + server server.DRPCServer + mu sync.RWMutex +} + +func (p *peerService) Init(a *app.App) (err error) { + p.yamux = a.MustComponent(yamux.CName).(transport.Transport) + p.nodeConf = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf) + p.pool = a.MustComponent(pool.CName).(pool.Pool) + p.server = a.MustComponent(server.CName).(server.DRPCServer) + p.peerAddrs = map[string][]string{} + p.yamux.SetAccepter(p) + return nil +} + +func (p *peerService) Name() (name string) { + return CName +} + +func (p *peerService) Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) { + p.mu.RLock() + defer p.mu.RUnlock() + + addrs, err := p.getPeerAddrs(peerId) + if err != nil { + return + } + + var mc transport.MultiConn + log.InfoCtx(ctx, "dial", zap.String("peerId", peerId), zap.Strings("addrs", addrs)) + for _, addr := range addrs { + mc, err = p.yamux.Dial(ctx, addr) + if err != nil { + log.InfoCtx(ctx, "can't connect to host", zap.String("addr", addr), zap.Error(err)) + } else { + break + } + } + if err != nil { + return + } + return peer.NewPeer(mc, p.server) +} + +func (p *peerService) Accept(mc transport.MultiConn) (err error) { + pr, err := peer.NewPeer(mc, p.server) + if err != nil { + return err + } + if err = p.pool.AddPeer(context.Background(), pr); err != nil { + _ = pr.Close() + } + return +} + +func (p *peerService) SetPeerAddrs(peerId string, addrs []string) { + p.mu.Lock() + defer p.mu.Unlock() + p.peerAddrs[peerId] = addrs +} + +func (p *peerService) getPeerAddrs(peerId string) ([]string, error) { + if addrs, ok := p.nodeConf.PeerAddresses(peerId); ok { + return addrs, nil + } + addrs, ok := p.peerAddrs[peerId] + if !ok || len(addrs) == 0 { + return nil, ErrAddrsNotFound + } + return addrs, nil +} diff --git a/net/pool/mock_pool/mock_pool.go b/net/pool/mock_pool/mock_pool.go new file mode 100644 index 00000000..be884903 --- /dev/null +++ b/net/pool/mock_pool/mock_pool.go @@ -0,0 +1,80 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/net/pool (interfaces: Pool) + +// Package mock_pool is a generated GoMock package. +package mock_pool + +import ( + context "context" + reflect "reflect" + + peer "github.com/anyproto/any-sync/net/peer" + gomock "github.com/golang/mock/gomock" +) + +// MockPool is a mock of Pool interface. +type MockPool struct { + ctrl *gomock.Controller + recorder *MockPoolMockRecorder +} + +// MockPoolMockRecorder is the mock recorder for MockPool. +type MockPoolMockRecorder struct { + mock *MockPool +} + +// NewMockPool creates a new mock instance. +func NewMockPool(ctrl *gomock.Controller) *MockPool { + mock := &MockPool{ctrl: ctrl} + mock.recorder = &MockPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPool) EXPECT() *MockPoolMockRecorder { + return m.recorder +} + +// AddPeer mocks base method. +func (m *MockPool) AddPeer(arg0 context.Context, arg1 peer.Peer) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "AddPeer", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddPeer indicates an expected call of AddPeer. +func (mr *MockPoolMockRecorder) AddPeer(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddPeer", reflect.TypeOf((*MockPool)(nil).AddPeer), arg0, arg1) +} + +// Get mocks base method. +func (m *MockPool) Get(arg0 context.Context, arg1 string) (peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Get", arg0, arg1) + ret0, _ := ret[0].(peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Get indicates an expected call of Get. +func (mr *MockPoolMockRecorder) Get(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockPool)(nil).Get), arg0, arg1) +} + +// GetOneOf mocks base method. +func (m *MockPool) GetOneOf(arg0 context.Context, arg1 []string) (peer.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOneOf", arg0, arg1) + ret0, _ := ret[0].(peer.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetOneOf indicates an expected call of GetOneOf. +func (mr *MockPoolMockRecorder) GetOneOf(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOneOf", reflect.TypeOf((*MockPool)(nil).GetOneOf), arg0, arg1) +} diff --git a/net/pool/pool.go b/net/pool/pool.go index b6c0d7df..7e936e1b 100644 --- a/net/pool/pool.go +++ b/net/pool/pool.go @@ -1,10 +1,10 @@ +//go:generate mockgen -destination mock_pool/mock_pool.go github.com/anyproto/any-sync/net/pool Pool package pool import ( "context" "github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/dialer" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/secureservice/handshake" "go.uber.org/zap" @@ -13,59 +13,61 @@ import ( // Pool creates and caches outgoing connection type Pool interface { - // Get lookups to peer in existing connections or creates and cache new one + // Get lookups to peer in existing connections or creates and outgoing new one Get(ctx context.Context, id string) (peer.Peer, error) - // Dial creates new connection to peer and not use cache - Dial(ctx context.Context, id string) (peer.Peer, error) - // GetOneOf searches at least one existing connection in cache or creates a new one from a randomly selected id from given list + // GetOneOf searches at least one existing connection in outgoing or creates a new one from a randomly selected id from given list GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) - - DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) + // AddPeer adds incoming peer to the pool + AddPeer(ctx context.Context, p peer.Peer) (err error) } type pool struct { - cache ocache.OCache - dialer dialer.Dialer + outgoing ocache.OCache + incoming ocache.OCache } func (p *pool) Name() (name string) { return CName } -func (p *pool) Run(ctx context.Context) (err error) { - return nil +func (p *pool) Get(ctx context.Context, id string) (pr peer.Peer, err error) { + // if we have incoming connection - try to reuse it + if pr, err = p.get(ctx, p.incoming, id); err != nil { + // or try to get or create outgoing + return p.get(ctx, p.outgoing, id) + } + return } -func (p *pool) Get(ctx context.Context, id string) (peer.Peer, error) { - v, err := p.cache.Get(ctx, id) +func (p *pool) get(ctx context.Context, source ocache.OCache, id string) (peer.Peer, error) { + v, err := source.Get(ctx, id) if err != nil { return nil, err } pr := v.(peer.Peer) - select { - case <-pr.Closed(): - default: + if !pr.IsClosed() { return pr, nil } - _, _ = p.cache.Remove(ctx, id) + _, _ = source.Remove(ctx, id) return p.Get(ctx, id) } -func (p *pool) Dial(ctx context.Context, id string) (peer.Peer, error) { - return p.dialer.Dial(ctx, id) -} - func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { // finding existing connection for _, peerId := range peerIds { - if v, err := p.cache.Pick(ctx, peerId); err == nil { + if v, err := p.incoming.Pick(ctx, peerId); err == nil { pr := v.(peer.Peer) - select { - case <-pr.Closed(): - default: + if !pr.IsClosed() { return pr, nil } - _, _ = p.cache.Remove(ctx, peerId) + _, _ = p.incoming.Remove(ctx, peerId) + } + if v, err := p.outgoing.Pick(ctx, peerId); err == nil { + pr := v.(peer.Peer) + if !pr.IsClosed() { + return pr, nil + } + _, _ = p.outgoing.Remove(ctx, peerId) } } // shuffle ids for better consistency @@ -75,8 +77,8 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error // connecting var lastErr error for _, peerId := range peerIds { - if v, err := p.cache.Get(ctx, peerId); err == nil { - return v.(peer.Peer), nil + if v, err := p.Get(ctx, peerId); err == nil { + return v, nil } else { log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) lastErr = err @@ -88,27 +90,18 @@ func (p *pool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error return nil, lastErr } -func (p *pool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - // shuffle ids for better consistency - rand.Shuffle(len(peerIds), func(i, j int) { - peerIds[i], peerIds[j] = peerIds[j], peerIds[i] - }) - // connecting - var lastErr error - for _, peerId := range peerIds { - if v, err := p.dialer.Dial(ctx, peerId); err == nil { - return v.(peer.Peer), nil +func (p *pool) AddPeer(ctx context.Context, pr peer.Peer) (err error) { + if err = p.incoming.Add(pr.Id(), pr); err != nil { + if err == ocache.ErrExists { + // in case when an incoming connection with a peer already exists, we close and remove an existing connection + if v, e := p.incoming.Pick(ctx, pr.Id()); e == nil { + _ = v.Close() + _, _ = p.incoming.Remove(ctx, pr.Id()) + return p.incoming.Add(pr.Id(), pr) + } } else { - log.Debug("unable to connect", zap.String("peerId", peerId), zap.Error(err)) - lastErr = err + return err } } - if _, ok := lastErr.(handshake.HandshakeError); !ok { - lastErr = net.ErrUnableToConnect - } - return nil, lastErr -} - -func (p *pool) Close(ctx context.Context) (err error) { - return p.cache.Close() + return } diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index c82533e8..c93c9aec 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -6,11 +6,11 @@ import ( "fmt" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/dialer" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + net2 "net" "storj.io/drpc" "testing" "time" @@ -133,6 +133,27 @@ func TestPool_GetOneOf(t *testing.T) { }) } +func TestPool_AddPeer(t *testing.T) { + t.Run("success", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + require.NoError(t, fx.AddPeer(ctx, newTestPeer("p1"))) + }) + t.Run("two peers", func(t *testing.T) { + fx := newFixture(t) + defer fx.Finish() + p1, p2 := newTestPeer("p1"), newTestPeer("p1") + require.NoError(t, fx.AddPeer(ctx, p1)) + require.NoError(t, fx.AddPeer(ctx, p2)) + select { + case <-p1.closed: + default: + assert.Truef(t, false, "peer not closed") + } + }) + +} + func newFixture(t *testing.T) *fixture { fx := &fixture{ Service: New(), @@ -158,7 +179,7 @@ type fixture struct { t *testing.T } -var _ dialer.Dialer = (*dialerMock)(nil) +var _ dialer = (*dialerMock)(nil) type dialerMock struct { dial func(ctx context.Context, peerId string) (peer peer.Peer, err error) @@ -181,7 +202,7 @@ func (d *dialerMock) Init(a *app.App) (err error) { } func (d *dialerMock) Name() (name string) { - return dialer.CName + return "net.peerservice" } func newTestPeer(id string) *testPeer { @@ -196,6 +217,31 @@ type testPeer struct { closed chan struct{} } +func (t *testPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error { + return fmt.Errorf("not implemented") +} + +func (t *testPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) { + return nil, fmt.Errorf("not implemented") +} + +func (t *testPeer) ReleaseDrpcConn(conn drpc.Conn) {} + +func (t *testPeer) Context() context.Context { + //TODO implement me + panic("implement me") +} + +func (t *testPeer) Accept() (conn net2.Conn, err error) { + //TODO implement me + panic("implement me") +} + +func (t *testPeer) Open(ctx context.Context) (conn net2.Conn, err error) { + //TODO implement me + panic("implement me") +} + func (t *testPeer) Addr() string { return "" } @@ -204,12 +250,6 @@ func (t *testPeer) Id() string { return t.id } -func (t *testPeer) LastUsage() time.Time { - return time.Now() -} - -func (t *testPeer) UpdateLastUsage() {} - func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { return true, t.Close() } @@ -224,14 +264,11 @@ func (t *testPeer) Close() error { return nil } -func (t *testPeer) Closed() <-chan struct{} { - return t.closed -} - -func (t *testPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error { - return fmt.Errorf("call Invoke on test peer") -} - -func (t *testPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) { - return nil, fmt.Errorf("call NewStream on test peer") +func (t *testPeer) IsClosed() bool { + select { + case <-t.closed: + return true + default: + return false + } } diff --git a/net/pool/poolservice.go b/net/pool/poolservice.go index 9b69ae24..2f84e5d0 100644 --- a/net/pool/poolservice.go +++ b/net/pool/poolservice.go @@ -6,8 +6,9 @@ import ( "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/ocache" "github.com/anyproto/any-sync/metric" - "github.com/anyproto/any-sync/net/dialer" + "github.com/anyproto/any-sync/net/peer" "github.com/prometheus/client_golang/prometheus" + "go.uber.org/zap" "time" ) @@ -23,46 +24,54 @@ func New() Service { type Service interface { Pool - NewPool(name string) Pool app.ComponentRunnable } +type dialer interface { + Dial(ctx context.Context, peerId string) (pr peer.Peer, err error) +} + type poolService struct { // default pool *pool - dialer dialer.Dialer + dialer dialer metricReg *prometheus.Registry } func (p *poolService) Init(a *app.App) (err error) { - p.dialer = a.MustComponent(dialer.CName).(dialer.Dialer) - p.pool = &pool{dialer: p.dialer} + p.dialer = a.MustComponent("net.peerservice").(dialer) + p.pool = &pool{} if m := a.Component(metric.CName); m != nil { p.metricReg = m.(metric.Metric).Registry() } - p.pool.cache = ocache.New( + p.pool.outgoing = ocache.New( func(ctx context.Context, id string) (value ocache.Object, err error) { return p.dialer.Dial(ctx, id) }, ocache.WithLogger(log.Sugar()), ocache.WithGCPeriod(time.Minute), ocache.WithTTL(time.Minute*5), - ocache.WithPrometheus(p.metricReg, "netpool", "default"), + ocache.WithPrometheus(p.metricReg, "netpool", "outgoing"), + ) + p.pool.incoming = ocache.New( + func(ctx context.Context, id string) (value ocache.Object, err error) { + return nil, ocache.ErrNotExists + }, + ocache.WithLogger(log.Sugar()), + ocache.WithGCPeriod(time.Minute), + ocache.WithTTL(time.Minute*5), + ocache.WithPrometheus(p.metricReg, "netpool", "incoming"), ) return nil } -func (p *poolService) NewPool(name string) Pool { - return &pool{ - dialer: p.dialer, - cache: ocache.New( - func(ctx context.Context, id string) (value ocache.Object, err error) { - return p.dialer.Dial(ctx, id) - }, - ocache.WithLogger(log.Sugar()), - ocache.WithGCPeriod(time.Minute), - ocache.WithTTL(time.Minute*5), - ocache.WithPrometheus(p.metricReg, "netpool", name), - ), - } +func (p *pool) Run(ctx context.Context) (err error) { + return nil +} + +func (p *pool) Close(ctx context.Context) (err error) { + if e := p.incoming.Close(); e != nil { + log.Warn("close incoming cache error", zap.Error(e)) + } + return p.outgoing.Close() } diff --git a/net/rpc/debugserver/config.go b/net/rpc/debugserver/config.go new file mode 100644 index 00000000..53342ece --- /dev/null +++ b/net/rpc/debugserver/config.go @@ -0,0 +1,9 @@ +package debugserver + +type configGetter interface { + GetDebugServer() Config +} + +type Config struct { + ListenAddr string `yaml:"listenAddr"` +} diff --git a/net/rpc/debugserver/debugserver.go b/net/rpc/debugserver/debugserver.go new file mode 100644 index 00000000..4a313ea8 --- /dev/null +++ b/net/rpc/debugserver/debugserver.go @@ -0,0 +1,70 @@ +package debugserver + +import ( + "context" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/rpc" + "net" + "storj.io/drpc" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" +) + +const CName = "net.rpc.debugserver" + +func New() DebugServer { + return &debugServer{} +} + +type DebugServer interface { + app.ComponentRunnable + drpc.Mux +} + +type debugServer struct { + drpcServer *drpcserver.Server + *drpcmux.Mux + drpcConf rpc.Config + config Config + runCtx context.Context + runCtxCancel context.CancelFunc +} + +func (d *debugServer) Init(a *app.App) (err error) { + d.drpcConf = a.MustComponent("config").(rpc.ConfigGetter).GetDrpc() + d.config = a.MustComponent("config").(configGetter).GetDebugServer() + d.Mux = drpcmux.New() + bufSize := d.drpcConf.Stream.MaxMsgSizeMb * (1 << 20) + d.drpcServer = drpcserver.NewWithOptions(d, drpcserver.Options{Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize}, + Stream: drpcstream.Options{MaximumBufferSize: bufSize}, + }}) + return nil +} + +func (d *debugServer) Name() (name string) { + return CName +} + +func (d *debugServer) Run(ctx context.Context) (err error) { + if d.config.ListenAddr == "" { + return + } + lis, err := net.Listen("tcp", d.config.ListenAddr) + if err != nil { + return + } + d.runCtx, d.runCtxCancel = context.WithCancel(context.Background()) + go d.drpcServer.Serve(d.runCtx, lis) + return +} + +func (d *debugServer) Close(ctx context.Context) (err error) { + if d.runCtx != nil { + d.runCtxCancel() + } + return nil +} diff --git a/net/rpc/drpcconfig.go b/net/rpc/drpcconfig.go new file mode 100644 index 00000000..551313f6 --- /dev/null +++ b/net/rpc/drpcconfig.go @@ -0,0 +1,13 @@ +package rpc + +type ConfigGetter interface { + GetDrpc() Config +} + +type Config struct { + Stream StreamConfig `yaml:"stream"` +} + +type StreamConfig struct { + MaxMsgSizeMb int `yaml:"maxMsgSizeMb"` +} diff --git a/net/rpc/rpctest/peer.go b/net/rpc/rpctest/peer.go new file mode 100644 index 00000000..a5fef8be --- /dev/null +++ b/net/rpc/rpctest/peer.go @@ -0,0 +1,30 @@ +package rpctest + +import ( + "context" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/transport" + yamux2 "github.com/anyproto/any-sync/net/transport/yamux" + "github.com/hashicorp/yamux" + "net" +) + +func MultiConnPair(peerIdServ, peerIdClient string) (serv, client transport.MultiConn) { + sc, cc := net.Pipe() + var servConn = make(chan transport.MultiConn, 1) + go func() { + sess, err := yamux.Server(sc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + servConn <- yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdServ), connutil.NewLastUsageConn(sc), "", sess) + }() + sess, err := yamux.Client(cc, yamux.DefaultConfig()) + if err != nil { + panic(err) + } + client = yamux2.NewMultiConn(peer.CtxWithPeerId(context.Background(), peerIdClient), connutil.NewLastUsageConn(cc), "", sess) + serv = <-servConn + return +} diff --git a/net/rpc/rpctest/pool.go b/net/rpc/rpctest/pool.go index a22a88cc..fc70d16b 100644 --- a/net/rpc/rpctest/pool.go +++ b/net/rpc/rpctest/pool.go @@ -2,84 +2,21 @@ package rpctest import ( "context" - "errors" "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/pool" - "math/rand" - "storj.io/drpc" "sync" - "time" ) -var ErrCantConnect = errors.New("can't connect to test server") - func NewTestPool() *TestPool { - return &TestPool{ - peers: map[string]peer.Peer{}, - } + return &TestPool{peers: map[string]peer.Peer{}} } type TestPool struct { - ts *TesServer peers map[string]peer.Peer mu sync.Mutex -} - -func (t *TestPool) WithServer(ts *TesServer) *TestPool { - t.mu.Lock() - defer t.mu.Unlock() - t.ts = ts - return t -} - -func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - if p, ok := t.peers[id]; ok { - return p, nil - } - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: id, Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) Dial(ctx context.Context, id string) (peer.Peer, error) { - return t.Get(ctx, id) -} - -func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - for _, peerId := range peerIds { - if p, ok := t.peers[peerId]; ok { - return p, nil - } - } - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) DialOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { - t.mu.Lock() - defer t.mu.Unlock() - if t.ts == nil { - return nil, ErrCantConnect - } - return &testPeer{id: peerIds[rand.Intn(len(peerIds))], Conn: t.ts.Dial(ctx)}, nil -} - -func (t *TestPool) NewPool(name string) pool.Pool { - return t -} - -func (t *TestPool) AddPeer(p peer.Peer) { - t.mu.Lock() - defer t.mu.Unlock() - t.peers[p.Id()] = p + ts *TestServer } func (t *TestPool) Init(a *app.App) (err error) { @@ -90,6 +27,13 @@ func (t *TestPool) Name() (name string) { return pool.CName } +func (t *TestPool) WithServer(ts *TestServer) *TestPool { + t.mu.Lock() + defer t.mu.Unlock() + t.ts = ts + return t +} + func (t *TestPool) Run(ctx context.Context) (err error) { return nil } @@ -98,25 +42,35 @@ func (t *TestPool) Close(ctx context.Context) (err error) { return nil } -type testPeer struct { - id string - drpc.Conn +func (t *TestPool) Get(ctx context.Context, id string) (peer.Peer, error) { + t.mu.Lock() + defer t.mu.Unlock() + if p, ok := t.peers[id]; ok { + return p, nil + } + if t.ts == nil { + return nil, net.ErrUnableToConnect + } + return t.ts.Dial(id) } -func (t testPeer) Addr() string { - return "" +func (t *TestPool) GetOneOf(ctx context.Context, peerIds []string) (peer.Peer, error) { + t.mu.Lock() + defer t.mu.Unlock() + for _, peerId := range peerIds { + if p, ok := t.peers[peerId]; ok { + return p, nil + } + } + if t.ts == nil || len(peerIds) == 0 { + return nil, net.ErrUnableToConnect + } + return t.ts.Dial(peerIds[0]) } -func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { - return true, t.Close() +func (t *TestPool) AddPeer(ctx context.Context, p peer.Peer) (err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.peers[p.Id()] = p + return nil } - -func (t testPeer) Id() string { - return t.id -} - -func (t testPeer) LastUsage() time.Time { - return time.Now() -} - -func (t testPeer) UpdateLastUsage() {} diff --git a/net/rpc/rpctest/server.go b/net/rpc/rpctest/server.go index 4731a8e2..de6bf68b 100644 --- a/net/rpc/rpctest/server.go +++ b/net/rpc/rpctest/server.go @@ -3,45 +3,69 @@ package rpctest import ( "context" "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/peer" + "github.com/anyproto/any-sync/net/rpc" "github.com/anyproto/any-sync/net/rpc/server" "net" - "storj.io/drpc" - "storj.io/drpc/drpcconn" "storj.io/drpc/drpcmux" "storj.io/drpc/drpcserver" ) -func NewTestServer() *TesServer { - ts := &TesServer{ +type mockCtrl struct { +} + +func (m mockCtrl) ServeConn(ctx context.Context, conn net.Conn) (err error) { + return nil +} + +func (m mockCtrl) DrpcConfig() rpc.Config { + return rpc.Config{} +} + +func NewTestServer() *TestServer { + ts := &TestServer{ Mux: drpcmux.New(), } ts.Server = drpcserver.New(ts.Mux) return ts } -type TesServer struct { +type TestServer struct { *drpcmux.Mux *drpcserver.Server } -func (ts *TesServer) Init(a *app.App) (err error) { +func (s *TestServer) Init(a *app.App) (err error) { return nil } -func (ts *TesServer) Name() (name string) { +func (s *TestServer) Name() (name string) { return server.CName } -func (ts *TesServer) Run(ctx context.Context) (err error) { +func (s *TestServer) Run(ctx context.Context) (err error) { return nil } -func (ts *TesServer) Close(ctx context.Context) (err error) { +func (s *TestServer) Close(ctx context.Context) (err error) { return nil } -func (ts *TesServer) Dial(ctx context.Context) drpc.Conn { - sc, cc := net.Pipe() - go ts.Server.ServeOne(ctx, sc) - return drpcconn.New(cc) +func (s *TestServer) ServeConn(ctx context.Context, conn net.Conn) (err error) { + return s.Server.ServeOne(ctx, conn) +} + +func (s *TestServer) DrpcConfig() rpc.Config { + return rpc.Config{Stream: rpc.StreamConfig{MaxMsgSizeMb: 10}} +} + +func (s *TestServer) Dial(peerId string) (clientPeer peer.Peer, err error) { + mcS, mcC := MultiConnPair(peerId+"server", peerId) + // NewPeer runs the accept loop + _, err = peer.NewPeer(mcS, s) + if err != nil { + return + } + // and we ourselves don't call server methods on accept + return peer.NewPeer(mcC, mockCtrl{}) } diff --git a/net/rpc/server/baseserver.go b/net/rpc/server/baseserver.go deleted file mode 100644 index cb3047ed..00000000 --- a/net/rpc/server/baseserver.go +++ /dev/null @@ -1,134 +0,0 @@ -package server - -import ( - "context" - "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" - "github.com/zeebo/errs" - "go.uber.org/zap" - "io" - "net" - "storj.io/drpc" - "storj.io/drpc/drpcmanager" - "storj.io/drpc/drpcmux" - "storj.io/drpc/drpcserver" - "storj.io/drpc/drpcwire" - "time" -) - -type BaseDrpcServer struct { - drpcServer *drpcserver.Server - transport secureservice.SecureService - listeners []net.Listener - handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) - cancel func() - *drpcmux.Mux -} - -type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler - -type Params struct { - BufferSizeMb int - ListenAddrs []string - Wrapper DRPCHandlerWrapper - TimeoutMillis int - Handshake func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) -} - -func NewBaseDrpcServer() *BaseDrpcServer { - return &BaseDrpcServer{Mux: drpcmux.New()} -} - -func (s *BaseDrpcServer) Run(ctx context.Context, params Params) (err error) { - s.drpcServer = drpcserver.NewWithOptions(params.Wrapper(s.Mux), drpcserver.Options{Manager: drpcmanager.Options{ - Reader: drpcwire.ReaderOptions{MaximumBufferSize: params.BufferSizeMb * (1 << 20)}, - }}) - s.handshake = params.Handshake - ctx, s.cancel = context.WithCancel(ctx) - for _, addr := range params.ListenAddrs { - list, err := net.Listen("tcp", addr) - if err != nil { - return err - } - s.listeners = append(s.listeners, list) - go s.serve(ctx, list) - } - return -} - -func (s *BaseDrpcServer) serve(ctx context.Context, lis net.Listener) { - l := log.With(zap.String("localAddr", lis.Addr().String())) - l.Info("drpc listener started") - defer func() { - l.Debug("drpc listener stopped") - }() - for { - select { - case <-ctx.Done(): - return - default: - } - conn, err := lis.Accept() - if err != nil { - if isTemporary(err) { - l.Debug("listener temporary accept error", zap.Error(err)) - select { - case <-time.After(time.Second): - case <-ctx.Done(): - return - } - continue - } - l.Error("listener accept error", zap.Error(err)) - return - } - go s.serveConn(conn) - } -} - -func (s *BaseDrpcServer) serveConn(conn net.Conn) { - l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) - var ( - ctx = context.Background() - err error - ) - if s.handshake != nil { - ctx, conn, err = s.handshake(conn) - if err != nil { - l.Info("handshake error", zap.Error(err)) - return - } - if sc, ok := conn.(sec.SecureConn); ok { - ctx = peer.CtxWithPeerId(ctx, sc.RemotePeer().String()) - } - } - ctx = peer.CtxWithPeerAddr(ctx, conn.RemoteAddr().String()) - l.Debug("connection opened") - if err := s.drpcServer.ServeOne(ctx, conn); err != nil { - if errs.Is(err, context.Canceled) || errs.Is(err, io.EOF) { - l.Debug("connection closed") - } else { - l.Warn("serve connection error", zap.Error(err)) - } - } -} - -func (s *BaseDrpcServer) ListenAddrs() (addrs []net.Addr) { - for _, list := range s.listeners { - addrs = append(addrs, list.Addr()) - } - return -} - -func (s *BaseDrpcServer) Close(ctx context.Context) (err error) { - if s.cancel != nil { - s.cancel() - } - for _, l := range s.listeners { - if e := l.Close(); e != nil { - log.Warn("close listener error", zap.Error(e)) - } - } - return -} diff --git a/net/rpc/server/drpcserver.go b/net/rpc/server/drpcserver.go index 1874d16a..8fc112ca 100644 --- a/net/rpc/server/drpcserver.go +++ b/net/rpc/server/drpcserver.go @@ -5,12 +5,15 @@ import ( "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/metric" - anyNet "github.com/anyproto/any-sync/net" - "github.com/anyproto/any-sync/net/secureservice" - "github.com/libp2p/go-libp2p/core/sec" + "github.com/anyproto/any-sync/net/rpc" + "go.uber.org/zap" "net" "storj.io/drpc" - "time" + "storj.io/drpc/drpcmanager" + "storj.io/drpc/drpcmux" + "storj.io/drpc/drpcserver" + "storj.io/drpc/drpcstream" + "storj.io/drpc/drpcwire" ) const CName = "common.net.drpcserver" @@ -18,49 +21,53 @@ const CName = "common.net.drpcserver" var log = logger.NewNamed(CName) func New() DRPCServer { - return &drpcServer{BaseDrpcServer: NewBaseDrpcServer()} + return &drpcServer{} } type DRPCServer interface { - app.ComponentRunnable + ServeConn(ctx context.Context, conn net.Conn) (err error) + DrpcConfig() rpc.Config + app.Component drpc.Mux } type drpcServer struct { - config anyNet.Config - metric metric.Metric - transport secureservice.SecureService - *BaseDrpcServer + drpcServer *drpcserver.Server + *drpcmux.Mux + config rpc.Config + metric metric.Metric } -func (s *drpcServer) Init(a *app.App) (err error) { - s.config = a.MustComponent("config").(anyNet.ConfigGetter).GetNet() - s.metric = a.MustComponent(metric.CName).(metric.Metric) - s.transport = a.MustComponent(secureservice.CName).(secureservice.SecureService) - return nil -} +type DRPCHandlerWrapper func(handler drpc.Handler) drpc.Handler func (s *drpcServer) Name() (name string) { return CName } -func (s *drpcServer) Run(ctx context.Context) (err error) { - params := Params{ - BufferSizeMb: s.config.Stream.MaxMsgSizeMb, - TimeoutMillis: s.config.Stream.TimeoutMilliseconds, - ListenAddrs: s.config.Server.ListenAddrs, - Wrapper: func(handler drpc.Handler) drpc.Handler { - return s.metric.WrapDRPCHandler(handler) - }, - Handshake: func(conn net.Conn) (cCtx context.Context, sc sec.SecureConn, err error) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - return s.transport.SecureInbound(ctx, conn) - }, +func (s *drpcServer) Init(a *app.App) (err error) { + s.config = a.MustComponent("config").(rpc.ConfigGetter).GetDrpc() + s.metric, _ = a.Component(metric.CName).(metric.Metric) + s.Mux = drpcmux.New() + + var handler drpc.Handler + handler = s + if s.metric != nil { + handler = s.metric.WrapDRPCHandler(s) } - return s.BaseDrpcServer.Run(ctx, params) + bufSize := s.config.Stream.MaxMsgSizeMb * (1 << 20) + s.drpcServer = drpcserver.NewWithOptions(handler, drpcserver.Options{Manager: drpcmanager.Options{ + Reader: drpcwire.ReaderOptions{MaximumBufferSize: bufSize}, + Stream: drpcstream.Options{MaximumBufferSize: bufSize}, + }}) + return } -func (s *drpcServer) Close(ctx context.Context) (err error) { - return s.BaseDrpcServer.Close(ctx) +func (s *drpcServer) ServeConn(ctx context.Context, conn net.Conn) (err error) { + l := log.With(zap.String("remoteAddr", conn.RemoteAddr().String())).With(zap.String("localAddr", conn.LocalAddr().String())) + l.Debug("drpc serve peer") + return s.drpcServer.ServeOne(ctx, conn) +} + +func (s *drpcServer) DrpcConfig() rpc.Config { + return s.config } diff --git a/net/secureservice/credential.go b/net/secureservice/credential.go index 5e97e8fc..9d992430 100644 --- a/net/secureservice/credential.go +++ b/net/secureservice/credential.go @@ -5,7 +5,6 @@ import ( "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" "github.com/anyproto/any-sync/util/crypto" - "github.com/libp2p/go-libp2p/core/sec" "go.uber.org/zap" ) @@ -19,11 +18,11 @@ type noVerifyChecker struct { cred *handshakeproto.Credentials } -func (n noVerifyChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { +func (n noVerifyChecker) MakeCredentials(remotePeerId string) *handshakeproto.Credentials { return n.cred } -func (n noVerifyChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (n noVerifyChecker) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if cred.Version != n.cred.Version { return nil, handshake.ErrIncompatibleVersion } @@ -42,8 +41,8 @@ type peerSignVerifier struct { account *accountdata.AccountKeys } -func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { - sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + sc.RemotePeer().String())) +func (p *peerSignVerifier) MakeCredentials(remotePeerId string) *handshakeproto.Credentials { + sign, err := p.account.SignKey.Sign([]byte(p.account.PeerId + remotePeerId)) if err != nil { log.Warn("can't sign identity credentials", zap.Error(err)) } @@ -61,7 +60,7 @@ func (p *peerSignVerifier) MakeCredentials(sc sec.SecureConn) *handshakeproto.Cr } } -func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (p *peerSignVerifier) CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if cred.Version != p.protoVersion { return nil, handshake.ErrIncompatibleVersion } @@ -76,7 +75,7 @@ func (p *peerSignVerifier) CheckCredential(sc sec.SecureConn, cred *handshakepro if err != nil { return nil, handshake.ErrInvalidCredentials } - ok, err := pubKey.Verify([]byte((sc.RemotePeer().String() + p.account.PeerId)), msg.Sign) + ok, err := pubKey.Verify([]byte((remotePeerId + p.account.PeerId)), msg.Sign) if err != nil { return nil, err } diff --git a/net/secureservice/credential_test.go b/net/secureservice/credential_test.go index e64e173a..50b24ece 100644 --- a/net/secureservice/credential_test.go +++ b/net/secureservice/credential_test.go @@ -4,13 +4,8 @@ import ( "github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/testutil/accounttest" - "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/network" - "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "net" "testing" ) @@ -23,8 +18,8 @@ func TestPeerSignVerifier_CheckCredential(t *testing.T) { cc1 := newPeerSignVerifier(0, a1) cc2 := newPeerSignVerifier(0, a2) - c1 := newTestSC(a2.PeerId) - c2 := newTestSC(a1.PeerId) + c1 := a2.PeerId + c2 := a1.PeerId cr1 := cc1.MakeCredentials(c1) cr2 := cc2.MakeCredentials(c2) @@ -48,8 +43,8 @@ func TestIncompatibleVersion(t *testing.T) { cc1 := newPeerSignVerifier(0, a1) cc2 := newPeerSignVerifier(1, a2) - c1 := newTestSC(a2.PeerId) - c2 := newTestSC(a1.PeerId) + c1 := a2.PeerId + c2 := a1.PeerId cr1 := cc1.MakeCredentials(c1) cr2 := cc2.MakeCredentials(c2) @@ -68,35 +63,3 @@ func newTestAccData(t *testing.T) *accountdata.AccountKeys { require.NoError(t, as.Init(nil)) return as.Account() } - -func newTestSC(peerId string) sec.SecureConn { - pid, _ := peer.Decode(peerId) - return &testSc{ - ID: pid, - } -} - -type testSc struct { - net.Conn - peer.ID -} - -func (t *testSc) LocalPeer() peer.ID { - return "" -} - -func (t *testSc) LocalPrivateKey() crypto.PrivKey { - return nil -} - -func (t *testSc) RemotePeer() peer.ID { - return t.ID -} - -func (t *testSc) RemotePublicKey() crypto.PubKey { - return nil -} - -func (t *testSc) ConnState() network.ConnectionState { - return network.ConnectionState{} -} diff --git a/net/secureservice/handshake/credential.go b/net/secureservice/handshake/credential.go new file mode 100644 index 00000000..8a4760a0 --- /dev/null +++ b/net/secureservice/handshake/credential.go @@ -0,0 +1,133 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "io" +) + +func OutgoingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + var ( + resIdentity []byte + resErr error + ) + go func() { + defer close(done) + resIdentity, resErr = outgoingHandshake(h, conn, peerId, cc) + }() + select { + case <-done: + return resIdentity, resErr + case <-ctx.Done(): + _ = conn.Close() + return nil, ctx.Err() + } +} + +func outgoingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = conn + localCred := cc.MakeCredentials(peerId) + if err = h.writeCredentials(localCred); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck, msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack != nil { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + + if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error == handshakeproto.Error_Null { + return identity, nil + } else { + _ = h.conn.Close() + return nil, HandshakeError{e: msg.ack.Error} + } +} + +func IncomingHandshake(ctx context.Context, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + var ( + resIdentity []byte + resError error + ) + go func() { + defer close(done) + resIdentity, resError = incomingHandshake(h, conn, peerId, cc) + }() + select { + case <-done: + return resIdentity, resError + case <-ctx.Done(): + _ = conn.Close() + return nil, ctx.Err() + } +} + +func incomingHandshake(h *handshake, conn io.ReadWriteCloser, peerId string, cc CredentialChecker) (identity []byte, err error) { + defer h.release() + h.conn = conn + + msg, err := h.readMsg(msgTypeCred) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if identity, err = cc.CheckCredential(peerId, msg.cred); err != nil { + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeCredentials(cc.MakeCredentials(peerId)); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + + msg, err = h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + if msg.ack.Error != handshakeproto.Error_Null { + if msg.ack.Error == handshakeproto.Error_InvalidCredentials { + return nil, ErrPeerDeclinedCredentials + } + return nil, HandshakeError{e: msg.ack.Error} + } + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return nil, err + } + return +} diff --git a/net/secureservice/handshake/handshake_test.go b/net/secureservice/handshake/credential_test.go similarity index 80% rename from net/secureservice/handshake/handshake_test.go rename to net/secureservice/handshake/credential_test.go index 0d8b16d7..49865848 100644 --- a/net/secureservice/handshake/handshake_test.go +++ b/net/secureservice/handshake/credential_test.go @@ -7,7 +7,6 @@ import ( "github.com/libp2p/go-libp2p/core/crypto" "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" - "github.com/libp2p/go-libp2p/core/sec" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "net" @@ -17,7 +16,7 @@ import ( var noVerifyChecker = &testCredChecker{ makeCred: &handshakeproto.Credentials{Type: handshakeproto.CredentialsType_SkipVerify}, - checkCred: func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { + checkCred: func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { return []byte("identity"), nil }, } @@ -32,21 +31,20 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred, msgTypeAck, msgTypeProto) require.NoError(t, err) - require.Nil(t, msg.ack) - _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + _, err = noVerifyChecker.CheckCredential("p1", msg.cred) require.NoError(t, err) // send credential message - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -59,7 +57,7 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -70,13 +68,13 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -86,13 +84,13 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.NoError(t, h.writeAck(ErrInvalidCredentials.e)) res := <-handshakeResCh @@ -102,16 +100,16 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := OutgoingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, ErrInvalidCredentials.e, msg.ack.Error) res := <-handshakeResCh @@ -121,16 +119,16 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials and close conn - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = c2.Close() res := <-handshakeResCh require.Error(t, res.err) @@ -139,18 +137,18 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read ack and close conn - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) _ = c2.Close() res := <-handshakeResCh @@ -160,24 +158,23 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read ack - _, err = h.readMsg() + _, err = h.readMsg(msgTypeAck) require.NoError(t, err) // write cred instead ack - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - msg, err := h.readMsg() - require.NoError(t, err) - assert.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -185,21 +182,21 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) - _, err = noVerifyChecker.CheckCredential(c2, msg.cred) + _, err = noVerifyChecker.CheckCredential("", msg.cred) require.NoError(t, err) // send credential message - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // receive ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) require.Equal(t, handshakeproto.Error_Null, msg.ack.Error) // send ack @@ -213,13 +210,13 @@ func TestOutgoingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := OutgoingHandshake(ctx, c1, noVerifyChecker) + identity, err := OutgoingHandshake(ctx, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // receive credential message - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -236,22 +233,22 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) // write ack require.NoError(t, h.writeAck(handshakeproto.Error_Null)) // wait ack - msg, err = h.readMsg() + msg, err = h.readMsg(msgTypeAck) require.NoError(t, err) assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) res := <-handshakeResCh @@ -262,7 +259,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() _ = c2.Close() @@ -273,13 +270,13 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials and close conn - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) _ = c2.Close() res := <-handshakeResCh require.Error(t, res.err) @@ -288,7 +285,7 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -302,15 +299,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) + identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrInvalidCredentials}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_InvalidCredentials, msg.ack.Error) @@ -322,15 +319,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) + identity, err := IncomingHandshake(nil, c1, "", &testCredChecker{makeCred: noVerifyChecker.makeCred, checkErr: ErrIncompatibleVersion}) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // except ack with error - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeAck) require.NoError(t, err) require.Nil(t, msg.cred) require.Equal(t, handshakeproto.Error_IncompatibleVersion, msg.ack.Error) @@ -342,21 +339,21 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read cred - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) // write cred instead ack - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) - // expect ack with error - msg, err := h.readMsg() - require.Equal(t, handshakeproto.Error_UnexpectedPayload, msg.ack.Error) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) + // expect EOF + _, err = h.readMsg(msgTypeAck) + require.Error(t, err) res := <-handshakeResCh require.Error(t, res.err) }) @@ -364,15 +361,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // read cred and close conn - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) _ = c2.Close() @@ -383,15 +380,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -405,15 +402,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -427,15 +424,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials - msg, err := h.readMsg() + msg, err := h.readMsg(msgTypeCred) require.NoError(t, err) require.Nil(t, msg.ack) require.Equal(t, handshakeproto.CredentialsType_SkipVerify, msg.cred.Type) @@ -450,15 +447,15 @@ func TestIncomingHandshake(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(ctx, c1, noVerifyChecker) + identity, err := IncomingHandshake(ctx, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() h.conn = c2 // write credentials - require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(c2))) + require.NoError(t, h.writeCredentials(noVerifyChecker.MakeCredentials(""))) // wait credentials - _, err := h.readMsg() + _, err := h.readMsg(msgTypeCred) require.NoError(t, err) ctxCancel() res := <-handshakeResCh @@ -474,7 +471,7 @@ func TestNotAHandshakeMessage(t *testing.T) { c1, c2 := newConnPair(t) var handshakeResCh = make(chan handshakeRes, 1) go func() { - identity, err := IncomingHandshake(nil, c1, noVerifyChecker) + identity, err := IncomingHandshake(nil, c1, "", noVerifyChecker) handshakeResCh <- handshakeRes{identity: identity, err: err} }() h := newHandshake() @@ -482,7 +479,7 @@ func TestNotAHandshakeMessage(t *testing.T) { _, err := c2.Write([]byte("some unexpected bytes")) require.Error(t, err) res := <-handshakeResCh - assert.EqualError(t, res.err, ErrGotNotAHandshakeMessage.Error()) + assert.Error(t, res.err) } func TestEndToEnd(t *testing.T) { @@ -493,11 +490,11 @@ func TestEndToEnd(t *testing.T) { ) st := time.Now() go func() { - identity, err := OutgoingHandshake(nil, c1, noVerifyChecker) + identity, err := OutgoingHandshake(nil, c1, "", noVerifyChecker) outResCh <- handshakeRes{identity: identity, err: err} }() go func() { - identity, err := IncomingHandshake(nil, c2, noVerifyChecker) + identity, err := IncomingHandshake(nil, c2, "", noVerifyChecker) inResCh <- handshakeRes{identity: identity, err: err} }() @@ -521,7 +518,7 @@ func BenchmarkHandshake(b *testing.B) { defer close(done) go func() { for { - _, _ = OutgoingHandshake(nil, c1, noVerifyChecker) + _, _ = OutgoingHandshake(nil, c1, "", noVerifyChecker) select { case outRes <- struct{}{}: case <-done: @@ -531,7 +528,7 @@ func BenchmarkHandshake(b *testing.B) { }() go func() { for { - _, _ = IncomingHandshake(nil, c2, noVerifyChecker) + _, _ = IncomingHandshake(nil, c2, "", noVerifyChecker) select { case inRes <- struct{}{}: case <-done: @@ -551,20 +548,20 @@ func BenchmarkHandshake(b *testing.B) { type testCredChecker struct { makeCred *handshakeproto.Credentials - checkCred func(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) + checkCred func(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) checkErr error } -func (t *testCredChecker) MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials { +func (t *testCredChecker) MakeCredentials(peerId string) *handshakeproto.Credentials { return t.makeCred } -func (t *testCredChecker) CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) { +func (t *testCredChecker) CheckCredential(peerId string, cred *handshakeproto.Credentials) (identity []byte, err error) { if t.checkErr != nil { return nil, t.checkErr } if t.checkCred != nil { - return t.checkCred(sc, cred) + return t.checkCred(peerId, cred) } return nil, nil } diff --git a/net/secureservice/handshake/handshake.go b/net/secureservice/handshake/handshake.go index 04a9de72..42aa2a8c 100644 --- a/net/secureservice/handshake/handshake.go +++ b/net/secureservice/handshake/handshake.go @@ -1,11 +1,9 @@ package handshake import ( - "context" "encoding/binary" "errors" "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" - "github.com/libp2p/go-libp2p/core/sec" "golang.org/x/exp/slices" "io" "sync" @@ -14,8 +12,17 @@ import ( const headerSize = 5 // 1 byte for type + 4 byte for uint32 size const ( - msgTypeCred = byte(1) - msgTypeAck = byte(2) + msgTypeCred = byte(1) + msgTypeAck = byte(2) + msgTypeProto = byte(3) + + sizeLimit = 200 * 1024 // 200 Kb +) + +var ( + credMsgTypes = []byte{msgTypeCred, msgTypeAck} + protoMsgTypes = []byte{msgTypeProto, msgTypeAck} + protoMsgTypesAck = []byte{msgTypeAck} ) type HandshakeError struct { @@ -38,154 +45,26 @@ var ( ErrSkipVerifyNotAllowed = HandshakeError{e: handshakeproto.Error_SkipVerifyNotAllowed} ErrUnexpected = HandshakeError{e: handshakeproto.Error_Unexpected} - ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleVersion = HandshakeError{e: handshakeproto.Error_IncompatibleVersion} + ErrIncompatibleProto = HandshakeError{e: handshakeproto.Error_IncompatibleProto} + ErrRemoteIncompatibleProto = HandshakeError{Err: errors.New("remote peer declined the proto")} - ErrGotNotAHandshakeMessage = errors.New("go not a handshake message") + ErrGotUnexpectedMessage = errors.New("go not a handshake message") ) var handshakePool = &sync.Pool{New: func() any { return &handshake{ - remoteCred: &handshakeproto.Credentials{}, - remoteAck: &handshakeproto.Ack{}, - localAck: &handshakeproto.Ack{}, - buf: make([]byte, 0, 1024), + remoteCred: &handshakeproto.Credentials{}, + remoteAck: &handshakeproto.Ack{}, + localAck: &handshakeproto.Ack{}, + remoteProto: &handshakeproto.Proto{}, + buf: make([]byte, 0, 1024), } }} type CredentialChecker interface { - MakeCredentials(sc sec.SecureConn) *handshakeproto.Credentials - CheckCredential(sc sec.SecureConn, cred *handshakeproto.Credentials) (identity []byte, err error) -} - -func OutgoingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = outgoingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func outgoingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - localCred := cc.MakeCredentials(sc) - if err = h.writeCredentials(localCred); err != nil { - h.tryWriteErrAndClose(err) - return - } - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error == handshakeproto.Error_Null { - return identity, nil - } else { - _ = h.conn.Close() - return nil, HandshakeError{e: msg.ack.Error} - } -} - -func IncomingHandshake(ctx context.Context, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - if ctx == nil { - ctx = context.Background() - } - h := newHandshake() - done := make(chan struct{}) - go func() { - defer close(done) - identity, err = incomingHandshake(h, sc, cc) - }() - select { - case <-done: - return - case <-ctx.Done(): - _ = sc.Close() - return nil, ctx.Err() - } -} - -func incomingHandshake(h *handshake, sc sec.SecureConn, cc CredentialChecker) (identity []byte, err error) { - defer h.release() - h.conn = sc - - msg, err := h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return - } - if msg.ack != nil { - return nil, ErrUnexpectedPayload - } - if identity, err = cc.CheckCredential(sc, msg.cred); err != nil { - h.tryWriteErrAndClose(err) - return - } - - if err = h.writeCredentials(cc.MakeCredentials(sc)); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - - msg, err = h.readMsg() - if err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack == nil { - err = ErrUnexpectedPayload - h.tryWriteErrAndClose(err) - return nil, err - } - if msg.ack.Error != handshakeproto.Error_Null { - if msg.ack.Error == handshakeproto.Error_InvalidCredentials { - return nil, ErrPeerDeclinedCredentials - } - return nil, HandshakeError{e: msg.ack.Error} - } - if err = h.writeAck(handshakeproto.Error_Null); err != nil { - h.tryWriteErrAndClose(err) - return nil, err - } - return + MakeCredentials(remotePeerId string) *handshakeproto.Credentials + CheckCredential(remotePeerId string, cred *handshakeproto.Credentials) (identity []byte, err error) } func newHandshake() *handshake { @@ -193,11 +72,12 @@ func newHandshake() *handshake { } type handshake struct { - conn sec.SecureConn - remoteCred *handshakeproto.Credentials - remoteAck *handshakeproto.Ack - localAck *handshakeproto.Ack - buf []byte + conn io.ReadWriteCloser + remoteCred *handshakeproto.Credentials + remoteProto *handshakeproto.Proto + remoteAck *handshakeproto.Ack + localAck *handshakeproto.Ack + buf []byte } func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err error) { @@ -209,8 +89,17 @@ func (h *handshake) writeCredentials(cred *handshakeproto.Credentials) (err erro return h.writeData(msgTypeCred, n) } +func (h *handshake) writeProto(proto *handshakeproto.Proto) (err error) { + h.buf = slices.Grow(h.buf, proto.Size()+headerSize)[:proto.Size()+headerSize] + n, err := proto.MarshalToSizedBuffer(h.buf[headerSize:]) + if err != nil { + return err + } + return h.writeData(msgTypeProto, n) +} + func (h *handshake) tryWriteErrAndClose(err error) { - if err == ErrGotNotAHandshakeMessage { + if err == ErrUnexpectedPayload { // if we got unexpected message - just close the connection _ = h.conn.Close() return @@ -243,21 +132,26 @@ func (h *handshake) writeData(tp byte, size int) (err error) { } type message struct { - cred *handshakeproto.Credentials - ack *handshakeproto.Ack + cred *handshakeproto.Credentials + proto *handshakeproto.Proto + ack *handshakeproto.Ack } -func (h *handshake) readMsg() (msg message, err error) { +func (h *handshake) readMsg(allowedTypes ...byte) (msg message, err error) { h.buf = slices.Grow(h.buf, headerSize)[:headerSize] if _, err = io.ReadFull(h.conn, h.buf[:headerSize]); err != nil { return } tp := h.buf[0] - if tp != msgTypeCred && tp != msgTypeAck { - err = ErrGotNotAHandshakeMessage + if !slices.Contains(allowedTypes, tp) { + err = ErrUnexpectedPayload return } size := binary.LittleEndian.Uint32(h.buf[1:headerSize]) + if size > sizeLimit { + err = ErrGotUnexpectedMessage + return + } h.buf = slices.Grow(h.buf, int(size))[:size] if _, err = io.ReadFull(h.conn, h.buf[:size]); err != nil { return @@ -273,6 +167,11 @@ func (h *handshake) readMsg() (msg message, err error) { return } msg.ack = h.remoteAck + case msgTypeProto: + if err = h.remoteProto.Unmarshal(h.buf[:size]); err != nil { + return + } + msg.proto = h.remoteProto } return } @@ -284,5 +183,6 @@ func (h *handshake) release() { h.remoteAck.Error = 0 h.remoteCred.Type = 0 h.remoteCred.Payload = h.remoteCred.Payload[:0] + h.remoteProto.Proto = 0 handshakePool.Put(h) } diff --git a/net/secureservice/handshake/handshakeproto/handshake.pb.go b/net/secureservice/handshake/handshakeproto/handshake.pb.go index 3d868ef0..e9d6dfcb 100644 --- a/net/secureservice/handshake/handshakeproto/handshake.pb.go +++ b/net/secureservice/handshake/handshakeproto/handshake.pb.go @@ -59,6 +59,7 @@ const ( Error_SkipVerifyNotAllowed Error = 4 Error_DeadlineExceeded Error = 5 Error_IncompatibleVersion Error = 6 + Error_IncompatibleProto Error = 7 ) var Error_name = map[int32]string{ @@ -69,6 +70,7 @@ var Error_name = map[int32]string{ 4: "SkipVerifyNotAllowed", 5: "DeadlineExceeded", 6: "IncompatibleVersion", + 7: "IncompatibleProto", } var Error_value = map[string]int32{ @@ -79,6 +81,7 @@ var Error_value = map[string]int32{ "SkipVerifyNotAllowed": 4, "DeadlineExceeded": 5, "IncompatibleVersion": 6, + "IncompatibleProto": 7, } func (x Error) String() string { @@ -89,6 +92,28 @@ func (Error) EnumDescriptor() ([]byte, []int) { return fileDescriptor_60283fc75f020893, []int{1} } +type ProtoType int32 + +const ( + ProtoType_DRPC ProtoType = 0 +) + +var ProtoType_name = map[int32]string{ + 0: "DRPC", +} + +var ProtoType_value = map[string]int32{ + "DRPC": 0, +} + +func (x ProtoType) String() string { + return proto.EnumName(ProtoType_name, int32(x)) +} + +func (ProtoType) EnumDescriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{2} +} + type Credentials struct { Type CredentialsType `protobuf:"varint,1,opt,name=type,proto3,enum=anyHandshake.CredentialsType" json:"type,omitempty"` Payload []byte `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` @@ -247,12 +272,58 @@ func (m *Ack) GetError() Error { return Error_Null } +type Proto struct { + Proto ProtoType `protobuf:"varint,1,opt,name=proto,proto3,enum=anyHandshake.ProtoType" json:"proto,omitempty"` +} + +func (m *Proto) Reset() { *m = Proto{} } +func (m *Proto) String() string { return proto.CompactTextString(m) } +func (*Proto) ProtoMessage() {} +func (*Proto) Descriptor() ([]byte, []int) { + return fileDescriptor_60283fc75f020893, []int{3} +} +func (m *Proto) XXX_Unmarshal(b []byte) error { + return m.Unmarshal(b) +} +func (m *Proto) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + if deterministic { + return xxx_messageInfo_Proto.Marshal(b, m, deterministic) + } else { + b = b[:cap(b)] + n, err := m.MarshalToSizedBuffer(b) + if err != nil { + return nil, err + } + return b[:n], nil + } +} +func (m *Proto) XXX_Merge(src proto.Message) { + xxx_messageInfo_Proto.Merge(m, src) +} +func (m *Proto) XXX_Size() int { + return m.Size() +} +func (m *Proto) XXX_DiscardUnknown() { + xxx_messageInfo_Proto.DiscardUnknown(m) +} + +var xxx_messageInfo_Proto proto.InternalMessageInfo + +func (m *Proto) GetProto() ProtoType { + if m != nil { + return m.Proto + } + return ProtoType_DRPC +} + func init() { proto.RegisterEnum("anyHandshake.CredentialsType", CredentialsType_name, CredentialsType_value) proto.RegisterEnum("anyHandshake.Error", Error_name, Error_value) + proto.RegisterEnum("anyHandshake.ProtoType", ProtoType_name, ProtoType_value) proto.RegisterType((*Credentials)(nil), "anyHandshake.Credentials") proto.RegisterType((*PayloadSignedPeerIds)(nil), "anyHandshake.PayloadSignedPeerIds") proto.RegisterType((*Ack)(nil), "anyHandshake.Ack") + proto.RegisterType((*Proto)(nil), "anyHandshake.Proto") } func init() { @@ -260,32 +331,35 @@ func init() { } var fileDescriptor_60283fc75f020893 = []byte{ - // 395 bytes of a gzipped FileDescriptorProto - 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcd, 0x6e, 0x13, 0x31, - 0x10, 0xc7, 0xd7, 0x4d, 0x52, 0xaa, 0x21, 0x2d, 0xee, 0x34, 0xc0, 0x0a, 0x89, 0x55, 0x94, 0x53, - 0xc8, 0x21, 0xe1, 0xeb, 0x05, 0x02, 0x2d, 0x22, 0x97, 0xaa, 0xda, 0x42, 0x0f, 0xdc, 0xdc, 0xf5, - 0xd0, 0x5a, 0x31, 0xf6, 0xca, 0x76, 0x43, 0xf7, 0x2d, 0xb8, 0xf2, 0x46, 0x1c, 0x7b, 0xe4, 0x88, - 0x92, 0x17, 0x41, 0x71, 0x12, 0x92, 0x70, 0xea, 0xc5, 0x9e, 0x8f, 0x9f, 0xfd, 0xff, 0x8f, 0x65, - 0x18, 0x1a, 0x0a, 0x03, 0x4f, 0xc5, 0x8d, 0x23, 0x4f, 0x6e, 0xa2, 0x0a, 0x1a, 0x5c, 0x0b, 0x23, - 0xfd, 0xb5, 0x18, 0x6f, 0x44, 0xa5, 0xb3, 0xc1, 0x0e, 0xe2, 0xea, 0xd7, 0xd5, 0x7e, 0x2c, 0x60, - 0x53, 0x98, 0xea, 0xe3, 0xaa, 0xd6, 0x09, 0xf0, 0xf0, 0xbd, 0x23, 0x49, 0x26, 0x28, 0xa1, 0x3d, - 0xbe, 0x82, 0x7a, 0xa8, 0x4a, 0x4a, 0x59, 0x9b, 0x75, 0x0f, 0x5e, 0x3f, 0xef, 0x6f, 0xb2, 0xfd, - 0x0d, 0xf0, 0x53, 0x55, 0x52, 0x1e, 0x51, 0x4c, 0xe1, 0x41, 0x29, 0x2a, 0x6d, 0x85, 0x4c, 0x77, - 0xda, 0xac, 0xdb, 0xcc, 0x57, 0xe9, 0xbc, 0x33, 0x21, 0xe7, 0x95, 0x35, 0x69, 0xad, 0xcd, 0xba, - 0xfb, 0xf9, 0x2a, 0xed, 0x7c, 0x80, 0xd6, 0xd9, 0x02, 0x3a, 0x57, 0x57, 0x86, 0xe4, 0x19, 0x91, - 0x1b, 0x49, 0x8f, 0xcf, 0x60, 0x4f, 0x45, 0x89, 0x50, 0x45, 0x0b, 0xcd, 0xfc, 0x5f, 0x8e, 0x08, - 0x75, 0xaf, 0xae, 0xcc, 0x52, 0x24, 0xc6, 0x9d, 0x97, 0x50, 0x1b, 0x16, 0x63, 0x7c, 0x01, 0x0d, - 0x72, 0xce, 0xba, 0xa5, 0xed, 0xa3, 0x6d, 0xdb, 0x27, 0xf3, 0x56, 0xbe, 0x20, 0x7a, 0x6f, 0xe1, - 0xd1, 0x7f, 0x63, 0xe0, 0x01, 0xc0, 0xf9, 0x58, 0x95, 0x17, 0xe4, 0xd4, 0xd7, 0x8a, 0x27, 0x78, - 0x08, 0xfb, 0x5b, 0xae, 0x38, 0xeb, 0xfd, 0x64, 0xd0, 0x88, 0xd7, 0xe0, 0x1e, 0xd4, 0x4f, 0x6f, - 0xb4, 0xe6, 0xc9, 0xfc, 0xd8, 0x67, 0x43, 0xb7, 0x25, 0x15, 0x81, 0x24, 0x67, 0xf8, 0x04, 0x70, - 0x64, 0x26, 0x42, 0x2b, 0xb9, 0x21, 0xc0, 0x77, 0xf0, 0x31, 0x1c, 0xae, 0xb9, 0xe5, 0xd4, 0xbc, - 0x86, 0x29, 0xb4, 0xd6, 0xaa, 0xa7, 0x36, 0x0c, 0xb5, 0xb6, 0xdf, 0x49, 0xf2, 0x3a, 0xb6, 0x80, - 0x1f, 0x93, 0x90, 0x5a, 0x19, 0x3a, 0xb9, 0x2d, 0x88, 0x24, 0x49, 0xde, 0xc0, 0xa7, 0x70, 0x34, - 0x32, 0x85, 0xfd, 0x56, 0x8a, 0xa0, 0x2e, 0x35, 0x5d, 0x2c, 0x5e, 0x92, 0xef, 0xbe, 0x3b, 0xfe, - 0x35, 0xcd, 0xd8, 0xdd, 0x34, 0x63, 0x7f, 0xa6, 0x19, 0xfb, 0x31, 0xcb, 0x92, 0xbb, 0x59, 0x96, - 0xfc, 0x9e, 0x65, 0xc9, 0x97, 0xde, 0xfd, 0x3f, 0xcb, 0xe5, 0x6e, 0xdc, 0xde, 0xfc, 0x0d, 0x00, - 0x00, 0xff, 0xff, 0xbf, 0x78, 0x2f, 0x36, 0x61, 0x02, 0x00, 0x00, + // 439 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x8c, 0x92, 0xcf, 0x6e, 0xd3, 0x40, + 0x10, 0xc6, 0xbd, 0x4d, 0xdc, 0x96, 0x21, 0x2d, 0xdb, 0x6d, 0x4a, 0x2d, 0x24, 0xac, 0x28, 0xa7, + 0x10, 0x89, 0x84, 0x7f, 0xe2, 0x1e, 0x9a, 0x22, 0x72, 0xa9, 0x22, 0x17, 0x7a, 0xe0, 0xb6, 0xf5, + 0x0e, 0xed, 0x2a, 0xcb, 0xda, 0xda, 0xdd, 0x86, 0xfa, 0x2d, 0x78, 0x14, 0x1e, 0x83, 0x63, 0x8f, + 0x1c, 0x51, 0xf2, 0x22, 0xc8, 0x9b, 0xa4, 0x71, 0x38, 0xf5, 0x62, 0xef, 0xcc, 0xfc, 0xc6, 0xdf, + 0x7c, 0xe3, 0x85, 0x81, 0x46, 0xd7, 0xb7, 0x98, 0xde, 0x18, 0xb4, 0x68, 0xa6, 0x32, 0xc5, 0xfe, + 0x35, 0xd7, 0xc2, 0x5e, 0xf3, 0x49, 0xe5, 0x94, 0x9b, 0xcc, 0x65, 0x7d, 0xff, 0xb4, 0xeb, 0x6c, + 0xcf, 0x27, 0x58, 0x83, 0xeb, 0xe2, 0xd3, 0x2a, 0xd7, 0x76, 0xf0, 0xf8, 0xc4, 0xa0, 0x40, 0xed, + 0x24, 0x57, 0x96, 0xbd, 0x86, 0xba, 0x2b, 0x72, 0x8c, 0x48, 0x8b, 0x74, 0xf6, 0xdf, 0x3c, 0xef, + 0x55, 0xd9, 0x5e, 0x05, 0xfc, 0x5c, 0xe4, 0x98, 0x78, 0x94, 0x45, 0xb0, 0x93, 0xf3, 0x42, 0x65, + 0x5c, 0x44, 0x5b, 0x2d, 0xd2, 0x69, 0x24, 0xab, 0xb0, 0xac, 0x4c, 0xd1, 0x58, 0x99, 0xe9, 0xa8, + 0xd6, 0x22, 0x9d, 0xbd, 0x64, 0x15, 0xb6, 0x3f, 0x42, 0x73, 0xbc, 0x80, 0xce, 0xe5, 0x95, 0x46, + 0x31, 0x46, 0x34, 0x23, 0x61, 0xd9, 0x33, 0xd8, 0x95, 0x5e, 0xc2, 0x15, 0x7e, 0x84, 0x46, 0x72, + 0x1f, 0x33, 0x06, 0x75, 0x2b, 0xaf, 0xf4, 0x52, 0xc4, 0x9f, 0xdb, 0xaf, 0xa0, 0x36, 0x48, 0x27, + 0xec, 0x05, 0x84, 0x68, 0x4c, 0x66, 0x96, 0x63, 0x1f, 0x6e, 0x8e, 0x7d, 0x5a, 0x96, 0x92, 0x05, + 0xd1, 0x7e, 0x0f, 0xe1, 0xd8, 0xaf, 0xe1, 0x25, 0x84, 0x7e, 0x1f, 0xcb, 0x9e, 0xe3, 0xcd, 0x1e, + 0xcf, 0x78, 0x93, 0x0b, 0xaa, 0xfb, 0x0e, 0x9e, 0xfc, 0x67, 0x9f, 0xed, 0x03, 0x9c, 0x4f, 0x64, + 0x7e, 0x81, 0x46, 0x7e, 0x2b, 0x68, 0xc0, 0x0e, 0x60, 0x6f, 0xc3, 0x0d, 0x25, 0xdd, 0x5f, 0x04, + 0x42, 0x2f, 0xcf, 0x76, 0xa1, 0x7e, 0x76, 0xa3, 0x14, 0x0d, 0xca, 0xb6, 0x2f, 0x1a, 0x6f, 0x73, + 0x4c, 0x1d, 0x0a, 0x4a, 0xd8, 0x53, 0x60, 0x23, 0x3d, 0xe5, 0x4a, 0x8a, 0x8a, 0x00, 0xdd, 0x62, + 0x47, 0x70, 0xb0, 0xe6, 0x96, 0xdb, 0xa2, 0x35, 0x16, 0x41, 0x73, 0xad, 0x7a, 0x96, 0xb9, 0x81, + 0x52, 0xd9, 0x0f, 0x14, 0xb4, 0xce, 0x9a, 0x40, 0x87, 0xc8, 0x85, 0x92, 0x1a, 0x4f, 0x6f, 0x53, + 0x44, 0x81, 0x82, 0x86, 0xec, 0x18, 0x0e, 0x47, 0x3a, 0xcd, 0xbe, 0xe7, 0xdc, 0xc9, 0x4b, 0x85, + 0x17, 0x8b, 0x3f, 0x40, 0xb7, 0xcb, 0xef, 0x57, 0x0b, 0xde, 0x31, 0xdd, 0xe9, 0x1e, 0xc1, 0xa3, + 0x7b, 0xf3, 0xe5, 0xd4, 0xc3, 0x64, 0x7c, 0x42, 0x83, 0x0f, 0xc3, 0xdf, 0xb3, 0x98, 0xdc, 0xcd, + 0x62, 0xf2, 0x77, 0x16, 0x93, 0x9f, 0xf3, 0x38, 0xb8, 0x9b, 0xc7, 0xc1, 0x9f, 0x79, 0x1c, 0x7c, + 0xed, 0x3e, 0xfc, 0x4a, 0x5e, 0x6e, 0xfb, 0xd7, 0xdb, 0x7f, 0x01, 0x00, 0x00, 0xff, 0xff, 0x53, + 0x32, 0xf7, 0x79, 0xc7, 0x02, 0x00, 0x00, } func (m *Credentials) Marshal() (dAtA []byte, err error) { @@ -393,6 +467,34 @@ func (m *Ack) MarshalToSizedBuffer(dAtA []byte) (int, error) { return len(dAtA) - i, nil } +func (m *Proto) Marshal() (dAtA []byte, err error) { + size := m.Size() + dAtA = make([]byte, size) + n, err := m.MarshalToSizedBuffer(dAtA[:size]) + if err != nil { + return nil, err + } + return dAtA[:n], nil +} + +func (m *Proto) MarshalTo(dAtA []byte) (int, error) { + size := m.Size() + return m.MarshalToSizedBuffer(dAtA[:size]) +} + +func (m *Proto) MarshalToSizedBuffer(dAtA []byte) (int, error) { + i := len(dAtA) + _ = i + var l int + _ = l + if m.Proto != 0 { + i = encodeVarintHandshake(dAtA, i, uint64(m.Proto)) + i-- + dAtA[i] = 0x8 + } + return len(dAtA) - i, nil +} + func encodeVarintHandshake(dAtA []byte, offset int, v uint64) int { offset -= sovHandshake(v) base := offset @@ -452,6 +554,18 @@ func (m *Ack) Size() (n int) { return n } +func (m *Proto) Size() (n int) { + if m == nil { + return 0 + } + var l int + _ = l + if m.Proto != 0 { + n += 1 + sovHandshake(uint64(m.Proto)) + } + return n +} + func sovHandshake(x uint64) (n int) { return (math_bits.Len64(x|1) + 6) / 7 } @@ -767,6 +881,75 @@ func (m *Ack) Unmarshal(dAtA []byte) error { } return nil } +func (m *Proto) Unmarshal(dAtA []byte) error { + l := len(dAtA) + iNdEx := 0 + for iNdEx < l { + preIndex := iNdEx + var wire uint64 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + wire |= uint64(b&0x7F) << shift + if b < 0x80 { + break + } + } + fieldNum := int32(wire >> 3) + wireType := int(wire & 0x7) + if wireType == 4 { + return fmt.Errorf("proto: Proto: wiretype end group for non-group") + } + if fieldNum <= 0 { + return fmt.Errorf("proto: Proto: illegal tag %d (wire type %d)", fieldNum, wire) + } + switch fieldNum { + case 1: + if wireType != 0 { + return fmt.Errorf("proto: wrong wireType = %d for field Proto", wireType) + } + m.Proto = 0 + for shift := uint(0); ; shift += 7 { + if shift >= 64 { + return ErrIntOverflowHandshake + } + if iNdEx >= l { + return io.ErrUnexpectedEOF + } + b := dAtA[iNdEx] + iNdEx++ + m.Proto |= ProtoType(b&0x7F) << shift + if b < 0x80 { + break + } + } + default: + iNdEx = preIndex + skippy, err := skipHandshake(dAtA[iNdEx:]) + if err != nil { + return err + } + if (skippy < 0) || (iNdEx+skippy) < 0 { + return ErrInvalidLengthHandshake + } + if (iNdEx + skippy) > l { + return io.ErrUnexpectedEOF + } + iNdEx += skippy + } + } + + if iNdEx > l { + return io.ErrUnexpectedEOF + } + return nil +} func skipHandshake(dAtA []byte) (n int, err error) { l := len(dAtA) iNdEx := 0 diff --git a/net/secureservice/handshake/handshakeproto/protos/handshake.proto b/net/secureservice/handshake/handshakeproto/protos/handshake.proto index cca5822e..1ea66b28 100644 --- a/net/secureservice/handshake/handshakeproto/protos/handshake.proto +++ b/net/secureservice/handshake/handshakeproto/protos/handshake.proto @@ -5,6 +5,8 @@ option go_package = "net/secureservice/handshake/handshakeproto"; /* +CREDENTIALS HANDSHAKE + Alice opens a new connection with Bob 1. TLS handshake done successfully; both sides know local and remote peer identifiers. @@ -68,4 +70,20 @@ enum Error { SkipVerifyNotAllowed = 4; DeadlineExceeded = 5; IncompatibleVersion = 6; + IncompatibleProto = 7; +} + + +/* + +PROTO HANDSHAKE + + */ + +message Proto { + ProtoType proto = 1; +} + +enum ProtoType { + DRPC = 0; } \ No newline at end of file diff --git a/net/secureservice/handshake/proto.go b/net/secureservice/handshake/proto.go new file mode 100644 index 00000000..45e95ab5 --- /dev/null +++ b/net/secureservice/handshake/proto.go @@ -0,0 +1,97 @@ +package handshake + +import ( + "context" + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "golang.org/x/exp/slices" + "net" +) + +type ProtoChecker struct { + AllowedProtoTypes []handshakeproto.ProtoType +} + +func OutgoingProtoHandshake(ctx context.Context, conn net.Conn, pt handshakeproto.ProtoType) (err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + err = outgoingProtoHandshake(h, conn, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = conn.Close() + return ctx.Err() + } +} + +func outgoingProtoHandshake(h *handshake, conn net.Conn, pt handshakeproto.ProtoType) (err error) { + defer h.release() + h.conn = conn + localProto := &handshakeproto.Proto{ + Proto: pt, + } + if err = h.writeProto(localProto); err != nil { + h.tryWriteErrAndClose(err) + return + } + msg, err := h.readMsg(msgTypeAck) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if msg.ack.Error == handshakeproto.Error_IncompatibleProto { + return ErrRemoteIncompatibleProto + } + if msg.ack.Error == handshakeproto.Error_Null { + return nil + } + return HandshakeError{e: msg.ack.Error} +} + +func IncomingProtoHandshake(ctx context.Context, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + if ctx == nil { + ctx = context.Background() + } + h := newHandshake() + done := make(chan struct{}) + go func() { + defer close(done) + protoType, err = incomingProtoHandshake(h, conn, pt) + }() + select { + case <-done: + return + case <-ctx.Done(): + _ = conn.Close() + return 0, ctx.Err() + } +} + +func incomingProtoHandshake(h *handshake, conn net.Conn, pt ProtoChecker) (protoType handshakeproto.ProtoType, err error) { + defer h.release() + h.conn = conn + + msg, err := h.readMsg(msgTypeProto) + if err != nil { + h.tryWriteErrAndClose(err) + return + } + if !slices.Contains(pt.AllowedProtoTypes, msg.proto.Proto) { + err = ErrIncompatibleProto + h.tryWriteErrAndClose(err) + return + } + + if err = h.writeAck(handshakeproto.Error_Null); err != nil { + h.tryWriteErrAndClose(err) + return 0, err + } else { + return msg.proto.Proto, nil + } +} diff --git a/net/secureservice/handshake/proto_test.go b/net/secureservice/handshake/proto_test.go new file mode 100644 index 00000000..f689e372 --- /dev/null +++ b/net/secureservice/handshake/proto_test.go @@ -0,0 +1,121 @@ +package handshake + +import ( + "github.com/anyproto/any-sync/net/secureservice/handshake/handshakeproto" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "testing" + "time" +) + +type protoRes struct { + protoType handshakeproto.ProtoType + err error +} + +func newProtoChecker(types ...handshakeproto.ProtoType) ProtoChecker { + return ProtoChecker{AllowedProtoTypes: types} +} +func TestIncomingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: handshakeproto.ProtoType(1)})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_Null, msg.ack.Error) + res := <-protoResCh + require.NoError(t, res.err) + assert.Equal(t, handshakeproto.ProtoType(1), res.protoType) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + protoType, err := IncomingProtoHandshake(nil, c1, newProtoChecker(1)) + protoResCh <- protoRes{protoType: protoType, err: err} + }() + h := newHandshake() + h.conn = c2 + + // write desired proto + require.NoError(t, h.writeProto(&handshakeproto.Proto{Proto: 0})) + msg, err := h.readMsg(msgTypeAck) + require.NoError(t, err) + assert.Equal(t, handshakeproto.Error_IncompatibleProto, msg.ack.Error) + res := <-protoResCh + require.Error(t, res.err, ErrIncompatibleProto.Error()) + }) +} + +func TestOutgoingProtoHandshake(t *testing.T) { + t.Run("success", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_Null)) + + res := <-protoResCh + assert.NoError(t, res.err) + }) + t.Run("incompatible", func(t *testing.T) { + c1, c2 := newConnPair(t) + var protoResCh = make(chan protoRes, 1) + go func() { + err := OutgoingProtoHandshake(nil, c1, 1) + protoResCh <- protoRes{err: err} + }() + h := newHandshake() + h.conn = c2 + + msg, err := h.readMsg(msgTypeProto) + require.NoError(t, err) + assert.Equal(t, handshakeproto.ProtoType(1), msg.proto.Proto) + require.NoError(t, h.writeAck(handshakeproto.Error_IncompatibleProto)) + + res := <-protoResCh + assert.EqualError(t, res.err, ErrRemoteIncompatibleProto.Error()) + }) +} + +func TestEndToEndProto(t *testing.T) { + c1, c2 := newConnPair(t) + var ( + inResCh = make(chan protoRes, 1) + outResCh = make(chan protoRes, 1) + ) + st := time.Now() + go func() { + err := OutgoingProtoHandshake(nil, c1, 0) + outResCh <- protoRes{err: err} + }() + go func() { + protoType, err := IncomingProtoHandshake(nil, c2, newProtoChecker(0, 1)) + inResCh <- protoRes{protoType: protoType, err: err} + }() + + outRes := <-outResCh + assert.NoError(t, outRes.err) + + inRes := <-inResCh + assert.NoError(t, inRes.err) + assert.Equal(t, handshakeproto.ProtoType(0), inRes.protoType) + t.Log("dur", time.Since(st)) +} diff --git a/net/secureservice/secureservice.go b/net/secureservice/secureservice.go index d7d86040..d20d7109 100644 --- a/net/secureservice/secureservice.go +++ b/net/secureservice/secureservice.go @@ -2,6 +2,7 @@ package secureservice import ( "context" + "crypto/tls" commonaccount "github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app/logger" @@ -10,9 +11,9 @@ import ( "github.com/anyproto/any-sync/net/secureservice/handshake" "github.com/anyproto/any-sync/nodeconf" "github.com/libp2p/go-libp2p/core/crypto" - "github.com/libp2p/go-libp2p/core/sec" libp2ptls "github.com/libp2p/go-libp2p/p2p/security/tls" "go.uber.org/zap" + "io" "net" ) @@ -20,13 +21,21 @@ const CName = "common.net.secure" var log = logger.NewNamed(CName) +const ( + // ProtoVersion 0 - first any-sync version with raw tcp connections + // ProtoVersion 1 - version with yamux over tcp and quic + ProtoVersion = 1 +) + func New() SecureService { return &secureService{} } type SecureService interface { - SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) - SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) + SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) + SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) + HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, remotePeerId string) (cctx context.Context, err error) + ServerTlsConfig() (*tls.Config, error) app.Component } @@ -67,6 +76,8 @@ func (s *secureService) Init(a *app.App) (err error) { return } + s.protoVersion = ProtoVersion + log.Info("secure service init", zap.String("peerId", account.Account().PeerId)) return nil } @@ -75,25 +86,28 @@ func (s *secureService) Name() (name string) { return CName } -func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, sc sec.SecureConn, err error) { - sc, err = s.p2pTr.SecureInbound(ctx, conn, "") +func (s *secureService) SecureInbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) { + sc, err := s.p2pTr.SecureInbound(ctx, conn, "") if err != nil { - return nil, nil, handshake.HandshakeError{ + return nil, handshake.HandshakeError{ Err: err, } } + return s.HandshakeInbound(ctx, sc, sc.RemotePeer().String()) +} - identity, err := handshake.IncomingHandshake(ctx, sc, s.inboundChecker) +func (s *secureService) HandshakeInbound(ctx context.Context, conn io.ReadWriteCloser, peerId string) (cctx context.Context, err error) { + identity, err := handshake.IncomingHandshake(ctx, conn, peerId, s.inboundChecker) if err != nil { - return nil, nil, err + return nil, err } cctx = context.Background() - cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) + cctx = peer.CtxWithPeerId(cctx, peerId) cctx = peer.CtxWithIdentity(cctx, identity) return } -func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec.SecureConn, error) { +func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (cctx context.Context, err error) { sc, err := s.p2pTr.SecureOutbound(ctx, conn, "") if err != nil { return nil, handshake.HandshakeError{Err: err} @@ -106,10 +120,22 @@ func (s *secureService) SecureOutbound(ctx context.Context, conn net.Conn) (sec. } else { checker = s.noVerifyChecker } - // ignore identity for outgoing connection because we don't need it at this moment - _, err = handshake.OutgoingHandshake(ctx, sc, checker) + identity, err := handshake.OutgoingHandshake(ctx, sc, sc.RemotePeer().String(), checker) if err != nil { return nil, err } - return sc, nil + cctx = context.Background() + cctx = peer.CtxWithPeerId(cctx, sc.RemotePeer().String()) + cctx = peer.CtxWithIdentity(cctx, identity) + return cctx, nil +} + +func (s *secureService) ServerTlsConfig() (*tls.Config, error) { + p2pIdn, err := libp2ptls.NewIdentity(s.key) + if err != nil { + return nil, err + } + conf, _ := p2pIdn.ConfigForPeer("") + conf.NextProtos = []string{"anysync"} + return conf, nil } diff --git a/net/secureservice/secureservice_test.go b/net/secureservice/secureservice_test.go index 2f435a65..6aee985d 100644 --- a/net/secureservice/secureservice_test.go +++ b/net/secureservice/secureservice_test.go @@ -32,16 +32,18 @@ func TestHandshake(t *testing.T) { resCh := make(chan acceptRes) go func() { var ar acceptRes - ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) + ar.ctx, ar.err = fxS.SecureInbound(ctx, sc) resCh <- ar }() fxC := newFixture(t, nc, nc.GetAccountService(1), 0) defer fxC.Finish(t) - secConn, err := fxC.SecureOutbound(ctx, cc) + cctx, err := fxC.SecureOutbound(ctx, cc) require.NoError(t, err) - assert.Equal(t, nc.GetAccountService(0).Account().PeerId, secConn.RemotePeer().String()) + ctxPeerId, err := peer.CtxPeerId(cctx) + require.NoError(t, err) + assert.Equal(t, nc.GetAccountService(0).Account().PeerId, ctxPeerId) res := <-resCh require.NoError(t, res.err) peerId, err := peer.CtxPeerId(res.ctx) @@ -67,7 +69,7 @@ func TestHandshakeIncompatibleVersion(t *testing.T) { resCh := make(chan acceptRes) go func() { var ar acceptRes - ar.ctx, ar.conn, ar.err = fxS.SecureInbound(ctx, sc) + ar.ctx, ar.err = fxS.SecureInbound(ctx, sc) resCh <- ar }() fxC := newFixture(t, nc, nc.GetAccountService(1), 1) diff --git a/net/streampool/sendpool.go b/net/streampool/sendpool.go index 0bff0765..642aeb9a 100644 --- a/net/streampool/sendpool.go +++ b/net/streampool/sendpool.go @@ -10,7 +10,10 @@ import ( // workers - how many processes will execute tasks // maxSize - limit for queue size func NewExecPool(workers, maxSize int) *ExecPool { + ctx, cancel := context.WithCancel(context.Background()) ss := &ExecPool{ + ctx: ctx, + cancel: cancel, workers: workers, batch: mb.New[func()](maxSize), } @@ -19,6 +22,8 @@ func NewExecPool(workers, maxSize int) *ExecPool { // ExecPool needed for parallel execution of the incoming send tasks type ExecPool struct { + ctx context.Context + cancel context.CancelFunc workers int batch *mb.MB[func()] } @@ -39,7 +44,7 @@ func (ss *ExecPool) Run() { func (ss *ExecPool) sendLoop() { for { - f, err := ss.batch.WaitOne(context.Background()) + f, err := ss.batch.WaitOne(ss.ctx) if err != nil { log.Debug("close send loop", zap.Error(err)) return @@ -49,5 +54,6 @@ func (ss *ExecPool) sendLoop() { } func (ss *ExecPool) Close() (err error) { + ss.cancel() return ss.batch.Close() } diff --git a/net/streampool/stream.go b/net/streampool/stream.go index 5dff0cb9..1d59af4b 100644 --- a/net/streampool/stream.go +++ b/net/streampool/stream.go @@ -3,7 +3,7 @@ package streampool import ( "context" "github.com/anyproto/any-sync/app/logger" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "storj.io/drpc" "sync/atomic" @@ -17,17 +17,12 @@ type stream struct { streamId uint32 closed atomic.Bool l logger.CtxLogger - queue multiqueue.MultiQueue[drpc.Message] + queue *mb.MB[drpc.Message] tags []string } func (sr *stream) write(msg drpc.Message) (err error) { - var queueId string - if qId, ok := msg.(MessageQueueId); ok { - queueId = qId.MessageQueueId() - msg = qId.DrpcMessage() - } - return sr.queue.Add(sr.stream.Context(), queueId, msg) + return sr.queue.Add(sr.stream.Context(), msg) } func (sr *stream) readLoop() error { @@ -50,13 +45,21 @@ func (sr *stream) readLoop() error { } } -func (sr *stream) writeToStream(msg drpc.Message) { - if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { - sr.l.Warn("msg send error", zap.Error(err)) - sr.streamClose() - return +func (sr *stream) writeLoop() { + for { + msg, err := sr.queue.WaitOne(sr.peerCtx) + if err != nil { + if err != mb.ErrClosed { + sr.streamClose() + } + return + } + if err := sr.stream.MsgSend(msg, EncodingProto); err != nil { + sr.l.Warn("msg send error", zap.Error(err)) + sr.streamClose() + return + } } - return } func (sr *stream) streamClose() { diff --git a/net/streampool/streampool.go b/net/streampool/streampool.go index 59ee6e4c..54c513b9 100644 --- a/net/streampool/streampool.go +++ b/net/streampool/streampool.go @@ -4,7 +4,7 @@ import ( "fmt" "github.com/anyproto/any-sync/net" "github.com/anyproto/any-sync/net/peer" - "github.com/anyproto/any-sync/util/multiqueue" + "github.com/cheggaaa/mb/v3" "go.uber.org/zap" "golang.org/x/exp/slices" "golang.org/x/net/context" @@ -74,6 +74,9 @@ func (s *streamPool) ReadStream(drpcStream drpc.Stream, tags ...string) error { if err != nil { return err } + go func() { + st.writeLoop() + }() return st.readLoop() } @@ -85,6 +88,9 @@ func (s *streamPool) AddStream(drpcStream drpc.Stream, tags ...string) error { go func() { _ = st.readLoop() }() + go func() { + st.writeLoop() + }() return nil } @@ -122,7 +128,7 @@ func (s *streamPool) addStream(drpcStream drpc.Stream, tags ...string) (*stream, l: log.With(zap.String("peerId", peerId), zap.Uint32("streamId", streamId)), tags: tags, } - st.queue = multiqueue.New[drpc.Message](st.writeToStream, s.writeQueueSize) + st.queue = mb.New[drpc.Message](s.writeQueueSize) s.streams[streamId] = st s.streamIdsByPeer[peerId] = append(s.streamIdsByPeer[peerId], streamId) for _, tag := range tags { @@ -244,6 +250,8 @@ func (s *streamPool) openStream(ctx context.Context, p peer.Peer) *openingProces close(op.ch) delete(s.opening, p.Id()) }() + // in case there was no peerId in context + ctx := peer.CtxWithPeerId(ctx, p.Id()) // open new stream and add to pool st, tags, err := s.handler.OpenStream(ctx, p) if err != nil { @@ -364,21 +372,3 @@ func removeStream(m map[string][]uint32, key string, streamId uint32) { m[key] = streamIds } } - -// WithQueueId wraps the message and adds queueId -func WithQueueId(msg drpc.Message, queueId string) drpc.Message { - return &messageWithQueueId{queueId: queueId, Message: msg} -} - -type messageWithQueueId struct { - drpc.Message - queueId string -} - -func (m messageWithQueueId) MessageQueueId() string { - return m.queueId -} - -func (m messageWithQueueId) DrpcMessage() drpc.Message { - return m.Message -} diff --git a/net/streampool/streampool_test.go b/net/streampool/streampool_test.go index 75d51059..d4a05de3 100644 --- a/net/streampool/streampool_test.go +++ b/net/streampool/streampool_test.go @@ -18,17 +18,25 @@ import ( var ctx = context.Background() +func makePeerPair(t *testing.T, fx *fixture, peerId string) (pS, pC peer.Peer) { + mcS, mcC := rpctest.MultiConnPair(peerId+"server", peerId) + pS, err := peer.NewPeer(mcS, fx.ts) + require.NoError(t, err) + pC, err = peer.NewPeer(mcC, fx.ts) + require.NoError(t, err) + return +} + func newClientStream(t *testing.T, fx *fixture, peerId string) (st testservice.DRPCTest_TestStreamClient, p peer.Peer) { - p, err := fx.tp.Dial(ctx, peerId) + _, pC := makePeerPair(t, fx, peerId) + drpcConn, err := pC.AcquireDrpcConn(ctx) require.NoError(t, err) - ctx = peer.CtxWithPeerId(ctx, peerId) - s, err := testservice.NewDRPCTestClient(p).TestStream(ctx) + st, err = testservice.NewDRPCTestClient(drpcConn).TestStream(pC.Context()) require.NoError(t, err) - return s, p + return st, pC } func TestStreamPool_AddStream(t *testing.T) { - t.Run("broadcast incoming", func(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) @@ -39,7 +47,7 @@ func TestStreamPool_AddStream(t *testing.T) { require.NoError(t, fx.AddStream(s2, "space2", "common")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space1"}, "space1")) - require.NoError(t, fx.Broadcast(ctx, WithQueueId(&testservice.StreamMessage{ReqData: "space2"}, "q2"), "space2")) + require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "space2"}, "space2")) require.NoError(t, fx.Broadcast(ctx, &testservice.StreamMessage{ReqData: "common"}, "common")) var serverResults []string @@ -85,11 +93,10 @@ func TestStreamPool_Send(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) + pS, _ := makePeerPair(t, fx, "p1") require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { - return []peer.Peer{p}, nil + return []peer.Peer{pS}, nil })) var msg *testservice.StreamMessage @@ -100,12 +107,12 @@ func TestStreamPool_Send(t *testing.T) { } assert.Equal(t, "should open stream", msg.ReqData) }) + t.Run("parallel open stream", func(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) + pS, _ := makePeerPair(t, fx, "p1") fx.th.streamOpenDelay = time.Second / 3 @@ -113,7 +120,7 @@ func TestStreamPool_Send(t *testing.T) { for i := 0; i < numMsgs; i++ { go require.NoError(t, fx.Send(ctx, &testservice.StreamMessage{ReqData: "should open stream"}, func(ctx context.Context) (peers []peer.Peer, err error) { - return []peer.Peer{p}, nil + return []peer.Peer{pS}, nil })) } @@ -134,9 +141,8 @@ func TestStreamPool_Send(t *testing.T) { fx := newFixture(t) defer fx.Finish(t) - p, err := fx.tp.Dial(ctx, "p1") - require.NoError(t, err) - _ = p.Close() + pS, _ := makePeerPair(t, fx, "p1") + _ = pS.Close() fx.th.streamOpenDelay = time.Second / 3 @@ -147,11 +153,12 @@ func TestStreamPool_Send(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, p, &testservice.StreamMessage{ReqData: "should open stream"})) + assert.Error(t, fx.StreamPool.(*streamPool).sendOne(ctx, pS, &testservice.StreamMessage{ReqData: "should open stream"})) }() } wg.Wait() }) + } func TestStreamPool_SendById(t *testing.T) { @@ -196,10 +203,9 @@ func TestStreamPool_Tags(t *testing.T) { func newFixture(t *testing.T) *fixture { fx := &fixture{} - ts := rpctest.NewTestServer() + fx.ts = rpctest.NewTestServer() fx.tsh = &testServerHandler{receiveCh: make(chan *testservice.StreamMessage, 100)} - require.NoError(t, testservice.DRPCRegisterTest(ts, fx.tsh)) - fx.tp = rpctest.NewTestPool().WithServer(ts) + require.NoError(t, testservice.DRPCRegisterTest(fx.ts, fx.tsh)) fx.th = &testHandler{} fx.StreamPool = New().NewStreamPool(fx.th, StreamConfig{ SendQueueSize: 10, @@ -211,14 +217,13 @@ func newFixture(t *testing.T) *fixture { type fixture struct { StreamPool - tp *rpctest.TestPool th *testHandler tsh *testServerHandler + ts *rpctest.TestServer } func (fx *fixture) Finish(t *testing.T) { require.NoError(t, fx.Close()) - require.NoError(t, fx.tp.Close(ctx)) } type testHandler struct { @@ -231,7 +236,11 @@ func (t *testHandler) OpenStream(ctx context.Context, p peer.Peer) (stream drpc. if t.streamOpenDelay > 0 { time.Sleep(t.streamOpenDelay) } - stream, err = testservice.NewDRPCTestClient(p).TestStream(ctx) + conn, err := p.AcquireDrpcConn(ctx) + if err != nil { + return + } + stream, err = testservice.NewDRPCTestClient(conn).TestStream(p.Context()) return } diff --git a/net/streampool/testservice/testservice_drpc.pb.go b/net/streampool/testservice/testservice_drpc.pb.go index f50fdbe7..cfe5bce9 100644 --- a/net/streampool/testservice/testservice_drpc.pb.go +++ b/net/streampool/testservice/testservice_drpc.pb.go @@ -1,5 +1,5 @@ // Code generated by protoc-gen-go-drpc. DO NOT EDIT. -// protoc-gen-go-drpc version: v0.0.32 +// protoc-gen-go-drpc version: v0.0.33 // source: net/streampool/testservice/protos/testservice.proto package testservice @@ -72,6 +72,10 @@ type drpcTest_TestStreamClient struct { drpc.Stream } +func (x *drpcTest_TestStreamClient) GetStream() drpc.Stream { + return x.Stream +} + func (x *drpcTest_TestStreamClient) Send(m *StreamMessage) error { return x.MsgSend(m, drpcEncoding_File_net_streampool_testservice_protos_testservice_proto{}) } diff --git a/net/transport/mock_transport/mock_transport.go b/net/transport/mock_transport/mock_transport.go new file mode 100644 index 00000000..43f5572c --- /dev/null +++ b/net/transport/mock_transport/mock_transport.go @@ -0,0 +1,188 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/anyproto/any-sync/net/transport (interfaces: Transport,MultiConn) + +// Package mock_transport is a generated GoMock package. +package mock_transport + +import ( + context "context" + net "net" + reflect "reflect" + time "time" + + transport "github.com/anyproto/any-sync/net/transport" + gomock "github.com/golang/mock/gomock" +) + +// MockTransport is a mock of Transport interface. +type MockTransport struct { + ctrl *gomock.Controller + recorder *MockTransportMockRecorder +} + +// MockTransportMockRecorder is the mock recorder for MockTransport. +type MockTransportMockRecorder struct { + mock *MockTransport +} + +// NewMockTransport creates a new mock instance. +func NewMockTransport(ctrl *gomock.Controller) *MockTransport { + mock := &MockTransport{ctrl: ctrl} + mock.recorder = &MockTransportMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransport) EXPECT() *MockTransportMockRecorder { + return m.recorder +} + +// Dial mocks base method. +func (m *MockTransport) Dial(arg0 context.Context, arg1 string) (transport.MultiConn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Dial", arg0, arg1) + ret0, _ := ret[0].(transport.MultiConn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Dial indicates an expected call of Dial. +func (mr *MockTransportMockRecorder) Dial(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Dial", reflect.TypeOf((*MockTransport)(nil).Dial), arg0, arg1) +} + +// SetAccepter mocks base method. +func (m *MockTransport) SetAccepter(arg0 transport.Accepter) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetAccepter", arg0) +} + +// SetAccepter indicates an expected call of SetAccepter. +func (mr *MockTransportMockRecorder) SetAccepter(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccepter", reflect.TypeOf((*MockTransport)(nil).SetAccepter), arg0) +} + +// MockMultiConn is a mock of MultiConn interface. +type MockMultiConn struct { + ctrl *gomock.Controller + recorder *MockMultiConnMockRecorder +} + +// MockMultiConnMockRecorder is the mock recorder for MockMultiConn. +type MockMultiConnMockRecorder struct { + mock *MockMultiConn +} + +// NewMockMultiConn creates a new mock instance. +func NewMockMultiConn(ctrl *gomock.Controller) *MockMultiConn { + mock := &MockMultiConn{ctrl: ctrl} + mock.recorder = &MockMultiConnMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockMultiConn) EXPECT() *MockMultiConnMockRecorder { + return m.recorder +} + +// Accept mocks base method. +func (m *MockMultiConn) Accept() (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Accept") + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Accept indicates an expected call of Accept. +func (mr *MockMultiConnMockRecorder) Accept() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Accept", reflect.TypeOf((*MockMultiConn)(nil).Accept)) +} + +// Addr mocks base method. +func (m *MockMultiConn) Addr() string { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Addr") + ret0, _ := ret[0].(string) + return ret0 +} + +// Addr indicates an expected call of Addr. +func (mr *MockMultiConnMockRecorder) Addr() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Addr", reflect.TypeOf((*MockMultiConn)(nil).Addr)) +} + +// Close mocks base method. +func (m *MockMultiConn) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockMultiConnMockRecorder) Close() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockMultiConn)(nil).Close)) +} + +// Context mocks base method. +func (m *MockMultiConn) Context() context.Context { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Context") + ret0, _ := ret[0].(context.Context) + return ret0 +} + +// Context indicates an expected call of Context. +func (mr *MockMultiConnMockRecorder) Context() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Context", reflect.TypeOf((*MockMultiConn)(nil).Context)) +} + +// IsClosed mocks base method. +func (m *MockMultiConn) IsClosed() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "IsClosed") + ret0, _ := ret[0].(bool) + return ret0 +} + +// IsClosed indicates an expected call of IsClosed. +func (mr *MockMultiConnMockRecorder) IsClosed() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsClosed", reflect.TypeOf((*MockMultiConn)(nil).IsClosed)) +} + +// LastUsage mocks base method. +func (m *MockMultiConn) LastUsage() time.Time { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastUsage") + ret0, _ := ret[0].(time.Time) + return ret0 +} + +// LastUsage indicates an expected call of LastUsage. +func (mr *MockMultiConnMockRecorder) LastUsage() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockMultiConn)(nil).LastUsage)) +} + +// Open mocks base method. +func (m *MockMultiConn) Open(arg0 context.Context) (net.Conn, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Open", arg0) + ret0, _ := ret[0].(net.Conn) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Open indicates an expected call of Open. +func (mr *MockMultiConnMockRecorder) Open(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Open", reflect.TypeOf((*MockMultiConn)(nil).Open), arg0) +} diff --git a/net/transport/transport.go b/net/transport/transport.go new file mode 100644 index 00000000..2dab5348 --- /dev/null +++ b/net/transport/transport.go @@ -0,0 +1,44 @@ +//go:generate mockgen -destination mock_transport/mock_transport.go github.com/anyproto/any-sync/net/transport Transport,MultiConn +package transport + +import ( + "context" + "errors" + "net" + "time" +) + +var ( + ErrConnClosed = errors.New("connection closed") +) + +// Transport is a common interface for a network transport +type Transport interface { + // SetAccepter sets accepter that will be called for new connections + // this method should be called before app start + SetAccepter(accepter Accepter) + // Dial creates a new connection by given address + Dial(ctx context.Context, addr string) (mc MultiConn, err error) +} + +// MultiConn is an object of multiplexing connection containing handshake info +type MultiConn interface { + // Context returns the connection context that contains handshake details + Context() context.Context + // Accept accepts new sub connections + Accept() (conn net.Conn, err error) + // Open opens new sub connection + Open(ctx context.Context) (conn net.Conn, err error) + // LastUsage returns the time of the last connection activity + LastUsage() time.Time + // Addr returns remote peer address + Addr() string + // IsClosed returns true when connection is closed + IsClosed() bool + // Close closes the connection and all sub connections + Close() error +} + +type Accepter interface { + Accept(mc MultiConn) (err error) +} diff --git a/net/transport/yamux/config.go b/net/transport/yamux/config.go new file mode 100644 index 00000000..2afce5fd --- /dev/null +++ b/net/transport/yamux/config.go @@ -0,0 +1,11 @@ +package yamux + +type configGetter interface { + GetYamux() Config +} + +type Config struct { + ListenAddrs []string `yaml:"listenAddrs"` + WriteTimeoutSec int `yaml:"writeTimeoutSec"` + DialTimeoutSec int `yaml:"dialTimeoutSec"` +} diff --git a/net/transport/yamux/conn.go b/net/transport/yamux/conn.go new file mode 100644 index 00000000..c752d22d --- /dev/null +++ b/net/transport/yamux/conn.go @@ -0,0 +1,55 @@ +package yamux + +import ( + "context" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/transport" + "github.com/hashicorp/yamux" + "net" + "time" +) + +func NewMultiConn(cctx context.Context, luConn *connutil.LastUsageConn, addr string, sess *yamux.Session) transport.MultiConn { + return &yamuxConn{ + ctx: cctx, + luConn: luConn, + addr: addr, + Session: sess, + } +} + +type yamuxConn struct { + ctx context.Context + luConn *connutil.LastUsageConn + addr string + *yamux.Session +} + +func (y *yamuxConn) Open(ctx context.Context) (conn net.Conn, err error) { + if conn, err = y.Session.Open(); err != nil { + return + } + return +} + +func (y *yamuxConn) LastUsage() time.Time { + return y.luConn.LastUsage() +} + +func (y *yamuxConn) Context() context.Context { + return y.ctx +} + +func (y *yamuxConn) Addr() string { + return y.addr +} + +func (y *yamuxConn) Accept() (conn net.Conn, err error) { + if conn, err = y.Session.Accept(); err != nil { + if err == yamux.ErrSessionShutdown { + err = transport.ErrConnClosed + } + return + } + return +} diff --git a/net/rpc/server/util.go b/net/transport/yamux/util.go similarity index 93% rename from net/rpc/server/util.go rename to net/transport/yamux/util.go index 5852288a..c1299d15 100644 --- a/net/rpc/server/util.go +++ b/net/transport/yamux/util.go @@ -1,6 +1,6 @@ //go:build !windows -package server +package yamux import ( "errors" diff --git a/net/rpc/server/util_windows.go b/net/transport/yamux/util_windows.go similarity index 97% rename from net/rpc/server/util_windows.go rename to net/transport/yamux/util_windows.go index efef2915..390524d5 100644 --- a/net/rpc/server/util_windows.go +++ b/net/transport/yamux/util_windows.go @@ -1,6 +1,6 @@ //go:build windows -package server +package yamux import ( "errors" diff --git a/net/transport/yamux/yamux.go b/net/transport/yamux/yamux.go new file mode 100644 index 00000000..9a0245a9 --- /dev/null +++ b/net/transport/yamux/yamux.go @@ -0,0 +1,175 @@ +package yamux + +import ( + "context" + "fmt" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/net/connutil" + "github.com/anyproto/any-sync/net/secureservice" + "github.com/anyproto/any-sync/net/transport" + "github.com/hashicorp/yamux" + "go.uber.org/zap" + "net" + "sync" + "time" +) + +const CName = "net.transport.yamux" + +var log = logger.NewNamed(CName) + +func New() Yamux { + return new(yamuxTransport) +} + +// Yamux implements transport.Transport with tcp+yamux +type Yamux interface { + transport.Transport + AddListener(lis net.Listener) + app.ComponentRunnable +} + +type yamuxTransport struct { + secure secureservice.SecureService + accepter transport.Accepter + conf Config + + listeners []net.Listener + listCtx context.Context + listCtxCancel context.CancelFunc + yamuxConf *yamux.Config + mu sync.Mutex +} + +func (y *yamuxTransport) Init(a *app.App) (err error) { + y.secure = a.MustComponent(secureservice.CName).(secureservice.SecureService) + y.conf = a.MustComponent("config").(configGetter).GetYamux() + if y.conf.DialTimeoutSec <= 0 { + y.conf.DialTimeoutSec = 10 + } + if y.conf.WriteTimeoutSec <= 0 { + y.conf.WriteTimeoutSec = 10 + } + y.yamuxConf = yamux.DefaultConfig() + y.yamuxConf.EnableKeepAlive = false + y.yamuxConf.StreamOpenTimeout = time.Duration(y.conf.DialTimeoutSec) * time.Second + y.yamuxConf.ConnectionWriteTimeout = time.Duration(y.conf.WriteTimeoutSec) * time.Second + y.listCtx, y.listCtxCancel = context.WithCancel(context.Background()) + return +} + +func (y *yamuxTransport) Name() string { + return CName +} + +func (y *yamuxTransport) Run(ctx context.Context) (err error) { + if y.accepter == nil { + return fmt.Errorf("can't run service without accepter") + } + y.mu.Lock() + defer y.mu.Unlock() + for _, listAddr := range y.conf.ListenAddrs { + list, err := net.Listen("tcp", listAddr) + if err != nil { + return err + } + y.listeners = append(y.listeners, list) + } + for _, list := range y.listeners { + go y.acceptLoop(y.listCtx, list) + } + return +} + +func (y *yamuxTransport) SetAccepter(accepter transport.Accepter) { + y.accepter = accepter +} + +func (y *yamuxTransport) AddListener(lis net.Listener) { + y.mu.Lock() + defer y.mu.Unlock() + y.listeners = append(y.listeners, lis) + go y.acceptLoop(y.listCtx, lis) +} + +func (y *yamuxTransport) Dial(ctx context.Context, addr string) (mc transport.MultiConn, err error) { + dialTimeout := time.Duration(y.conf.DialTimeoutSec) * time.Second + conn, err := net.DialTimeout("tcp", addr, dialTimeout) + if err != nil { + return nil, err + } + ctx, cancel := context.WithTimeout(ctx, dialTimeout) + defer cancel() + cctx, err := y.secure.SecureOutbound(ctx, conn) + if err != nil { + _ = conn.Close() + return nil, err + } + luc := connutil.NewLastUsageConn(conn) + sess, err := yamux.Client(luc, y.yamuxConf) + if err != nil { + return + } + mc = NewMultiConn(cctx, luc, addr, sess) + return +} + +func (y *yamuxTransport) acceptLoop(ctx context.Context, list net.Listener) { + l := log.With(zap.String("localAddr", list.Addr().String())) + l.Info("yamux listener started") + defer func() { + l.Debug("yamux listener stopped") + }() + for { + conn, err := list.Accept() + if err != nil { + if isTemporary(err) { + l.Debug("listener temporary accept error", zap.Error(err)) + select { + case <-time.After(time.Second): + case <-ctx.Done(): + return + } + continue + } + if err != net.ErrClosed { + l.Error("listener closed with error", zap.Error(err)) + } else { + l.Info("listener closed") + } + return + } + go y.accept(conn) + } +} + +func (y *yamuxTransport) accept(conn net.Conn) { + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(y.conf.DialTimeoutSec)*time.Second) + defer cancel() + cctx, err := y.secure.SecureInbound(ctx, conn) + if err != nil { + log.Warn("incoming connection handshake error", zap.Error(err)) + return + } + luc := connutil.NewLastUsageConn(conn) + sess, err := yamux.Server(luc, y.yamuxConf) + if err != nil { + log.Warn("incoming connection yamux session error", zap.Error(err)) + return + } + mc := NewMultiConn(cctx, luc, conn.RemoteAddr().String(), sess) + if err = y.accepter.Accept(mc); err != nil { + log.Warn("connection accept error", zap.Error(err)) + } +} + +func (y *yamuxTransport) Close(ctx context.Context) (err error) { + if y.listCtxCancel != nil { + y.listCtxCancel() + } + for _, l := range y.listeners { + _ = l.Close() + } + return +} diff --git a/net/transport/yamux/yamux_test.go b/net/transport/yamux/yamux_test.go new file mode 100644 index 00000000..02e1c322 --- /dev/null +++ b/net/transport/yamux/yamux_test.go @@ -0,0 +1,194 @@ +package yamux + +import ( + "bytes" + "context" + "github.com/anyproto/any-sync/app" + "github.com/anyproto/any-sync/net/secureservice" + "github.com/anyproto/any-sync/net/transport" + "github.com/anyproto/any-sync/nodeconf" + "github.com/anyproto/any-sync/nodeconf/mock_nodeconf" + "github.com/anyproto/any-sync/testutil/accounttest" + "github.com/anyproto/any-sync/testutil/testnodeconf" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "io" + "net" + "sync" + "testing" + "time" +) + +var ctx = context.Background() + +func TestYamuxTransport_Dial(t *testing.T) { + fxS := newFixture(t) + defer fxS.finish(t) + fxC := newFixture(t) + defer fxC.finish(t) + + mcC, err := fxC.Dial(ctx, fxS.addr) + require.NoError(t, err) + require.Len(t, fxS.accepter.mcs, 1) + mcS := <-fxS.accepter.mcs + + var ( + sData string + acceptErr error + copyErr error + done = make(chan struct{}) + ) + + go func() { + defer close(done) + conn, serr := mcS.Accept() + if serr != nil { + acceptErr = serr + return + } + buf := bytes.NewBuffer(nil) + _, copyErr = io.Copy(buf, conn) + sData = buf.String() + return + }() + + conn, err := mcC.Open(ctx) + require.NoError(t, err) + data := "some data" + _, err = conn.Write([]byte(data)) + require.NoError(t, err) + require.NoError(t, conn.Close()) + <-done + + assert.NoError(t, acceptErr) + assert.Equal(t, data, sData) + assert.NoError(t, copyErr) +} + +// no deadline - 69100 rps +// common write deadline - 66700 rps +// subconn write deadline - 67100 rps +func TestWriteBench(t *testing.T) { + t.Skip() + var ( + numSubConn = 10 + numWrites = 100000 + ) + + fxS := newFixture(t) + defer fxS.finish(t) + fxC := newFixture(t) + defer fxC.finish(t) + + mcC, err := fxC.Dial(ctx, fxS.addr) + require.NoError(t, err) + mcS := <-fxS.accepter.mcs + + go func() { + for i := 0; i < numSubConn; i++ { + conn, err := mcS.Accept() + require.NoError(t, err) + go func(sc net.Conn) { + var b = make([]byte, 1024) + for { + n, _ := sc.Read(b) + if n > 0 { + sc.Write(b[:n]) + } else { + break + } + } + }(conn) + } + }() + + var wg sync.WaitGroup + wg.Add(numSubConn) + st := time.Now() + for i := 0; i < numSubConn; i++ { + conn, err := mcC.Open(ctx) + require.NoError(t, err) + go func(sc net.Conn) { + defer sc.Close() + defer wg.Done() + for j := 0; j < numWrites; j++ { + var b = []byte("some data some data some data some data some data some data some data some data some data") + sc.Write(b) + sc.Read(b) + } + }(conn) + } + wg.Wait() + dur := time.Since(st) + t.Logf("%.2f req per sec", float64(numWrites*numSubConn)/dur.Seconds()) +} + +type fixture struct { + *yamuxTransport + a *app.App + ctrl *gomock.Controller + mockNodeConf *mock_nodeconf.MockService + acc *accounttest.AccountTestService + accepter *testAccepter + addr string +} + +func newFixture(t *testing.T) *fixture { + fx := &fixture{ + yamuxTransport: New().(*yamuxTransport), + ctrl: gomock.NewController(t), + acc: &accounttest.AccountTestService{}, + accepter: &testAccepter{mcs: make(chan transport.MultiConn, 100)}, + a: new(app.App), + } + + fx.mockNodeConf = mock_nodeconf.NewMockService(fx.ctrl) + fx.mockNodeConf.EXPECT().Init(gomock.Any()) + fx.mockNodeConf.EXPECT().Name().Return(nodeconf.CName).AnyTimes() + fx.mockNodeConf.EXPECT().Run(ctx) + fx.mockNodeConf.EXPECT().Close(ctx) + fx.mockNodeConf.EXPECT().NodeTypes(gomock.Any()).Return([]nodeconf.NodeType{nodeconf.NodeTypeTree}).AnyTimes() + fx.a.Register(fx.acc).Register(newTestConf()).Register(fx.mockNodeConf).Register(secureservice.New()).Register(fx.yamuxTransport).Register(fx.accepter) + require.NoError(t, fx.a.Start(ctx)) + fx.addr = fx.listeners[0].Addr().String() + return fx +} + +func (fx *fixture) finish(t *testing.T) { + require.NoError(t, fx.a.Close(ctx)) + fx.ctrl.Finish() +} + +func newTestConf() *testConf { + return &testConf{testnodeconf.GenNodeConfig(1)} +} + +type testConf struct { + *testnodeconf.Config +} + +func (c *testConf) GetYamux() Config { + return Config{ + ListenAddrs: []string{"127.0.0.1:0"}, + WriteTimeoutSec: 10, + DialTimeoutSec: 10, + } +} + +type testAccepter struct { + err error + mcs chan transport.MultiConn +} + +func (t *testAccepter) Accept(mc transport.MultiConn) (err error) { + t.mcs <- mc + return t.err +} + +func (t *testAccepter) Init(a *app.App) (err error) { + a.MustComponent(CName).(transport.Transport).SetAccepter(t) + return nil +} + +func (t *testAccepter) Name() (name string) { return "testAccepter" }