diff --git a/commonspace/object/acl/list/mock_list/mock_list.go b/commonspace/object/acl/list/mock_list/mock_list.go index 3e5b50be..6cdccac9 100644 --- a/commonspace/object/acl/list/mock_list/mock_list.go +++ b/commonspace/object/acl/list/mock_list/mock_list.go @@ -5,6 +5,7 @@ package mock_list import ( + context "context" reflect "reflect" list "github.com/anyproto/any-sync/commonspace/object/acl/list" @@ -64,18 +65,32 @@ func (mr *MockAclListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0) } -// Close mocks base method. -func (m *MockAclList) Close() error { +// AddRawRecords mocks base method. +func (m *MockAclList) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "Close") + ret := m.ctrl.Call(m, "AddRawRecords", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// AddRawRecords indicates an expected call of AddRawRecords. +func (mr *MockAclListMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockAclList)(nil).AddRawRecords), arg0) +} + +// Close mocks base method. +func (m *MockAclList) Close(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", arg0) ret0, _ := ret[0].(error) return ret0 } // Close indicates an expected call of Close. -func (mr *MockAclListMockRecorder) Close() *gomock.Call { +func (mr *MockAclListMockRecorder) Close(arg0 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close), arg0) } // Get mocks base method. @@ -108,6 +123,20 @@ func (mr *MockAclListMockRecorder) GetIndex(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockAclList)(nil).GetIndex), arg0) } +// HasHead mocks base method. +func (m *MockAclList) HasHead(arg0 string) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "HasHead", arg0) + ret0, _ := ret[0].(bool) + return ret0 +} + +// HasHead indicates an expected call of HasHead. +func (mr *MockAclListMockRecorder) HasHead(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockAclList)(nil).HasHead), arg0) +} + // Head mocks base method. func (m *MockAclList) Head() *list.AclRecord { m.ctrl.T.Helper() @@ -253,6 +282,21 @@ func (mr *MockAclListMockRecorder) Records() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records)) } +// RecordsAfter mocks base method. +func (m *MockAclList) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1) + ret0, _ := ret[0].([]*consensusproto.RawRecordWithId) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// RecordsAfter indicates an expected call of RecordsAfter. +func (mr *MockAclListMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockAclList)(nil).RecordsAfter), arg0, arg1) +} + // Root mocks base method. func (m *MockAclList) Root() *consensusproto.RawRecordWithId { m.ctrl.T.Helper() diff --git a/commonspace/object/acl/syncacl/aclsyncprotocol.go b/commonspace/object/acl/syncacl/aclsyncprotocol.go index 0b856b49..22fd3a0c 100644 --- a/commonspace/object/acl/syncacl/aclsyncprotocol.go +++ b/commonspace/object/acl/syncacl/aclsyncprotocol.go @@ -76,10 +76,16 @@ func (a *aclSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, log.DebugCtx(ctx, "acl full sync response sent", zap.String("response head", cnt.Head), zap.Int("len(response records)", len(cnt.Records))) } }() - if len(request.Records) > 0 && !a.aclList.HasHead(request.Head) { - err = a.aclList.AddRawRecords(request.Records) - if err != nil { - return + if !a.aclList.HasHead(request.Head) { + if len(request.Records) > 0 { + // in this case we can try to add some records + err = a.aclList.AddRawRecords(request.Records) + if err != nil { + return + } + } else { + // here it is impossible for us to do anything, we can't return records after head as defined in request, because we don't have it + return nil, list.ErrIncorrectRecordSequence } } return a.reqFactory.CreateFullSyncResponse(a.aclList, request.Head) diff --git a/commonspace/object/acl/syncacl/aclsyncprotocol_test.go b/commonspace/object/acl/syncacl/aclsyncprotocol_test.go new file mode 100644 index 00000000..1f861335 --- /dev/null +++ b/commonspace/object/acl/syncacl/aclsyncprotocol_test.go @@ -0,0 +1,213 @@ +package syncacl + +import ( + "context" + "github.com/anyproto/any-sync/app/logger" + "github.com/anyproto/any-sync/commonspace/object/acl/list" + "github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list" + "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl" + "github.com/anyproto/any-sync/consensus/consensusproto" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "testing" +) + +type aclSyncProtocolFixture struct { + log logger.CtxLogger + spaceId string + senderId string + aclId string + aclMock *mock_list.MockAclList + reqFactory *mock_syncacl.MockRequestFactory + ctrl *gomock.Controller + syncProtocol AclSyncProtocol +} + +func newSyncProtocolFixture(t *testing.T) *aclSyncProtocolFixture { + ctrl := gomock.NewController(t) + aclList := mock_list.NewMockAclList(ctrl) + spaceId := "spaceId" + reqFactory := mock_syncacl.NewMockRequestFactory(ctrl) + aclList.EXPECT().Id().Return("aclId") + syncProtocol := newAclSyncProtocol(spaceId, aclList, reqFactory) + return &aclSyncProtocolFixture{ + log: log, + spaceId: spaceId, + senderId: "senderId", + aclId: "aclId", + aclMock: aclList, + reqFactory: reqFactory, + ctrl: ctrl, + syncProtocol: syncProtocol, + } +} + +func (fx *aclSyncProtocolFixture) stop() { + fx.ctrl.Finish() +} + +func TestHeadUpdate(t *testing.T) { + ctx := context.Background() + fullRequest := &consensusproto.LogSyncMessage{ + Content: &consensusproto.LogSyncContentValue{ + Value: &consensusproto.LogSyncContentValue_FullSyncRequest{ + FullSyncRequest: &consensusproto.LogFullSyncRequest{}, + }, + }, + } + t.Run("head update non empty all heads added", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + headUpdate := &consensusproto.LogHeadUpdate{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(false) + fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(nil) + req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.Nil(t, req) + require.NoError(t, err) + }) + t.Run("head update results in full request", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + headUpdate := &consensusproto.LogHeadUpdate{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(false) + fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(list.ErrIncorrectRecordSequence) + fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil) + req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.Equal(t, fullRequest, req) + require.NoError(t, err) + }) + t.Run("head update old heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + headUpdate := &consensusproto.LogHeadUpdate{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(true) + req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.Nil(t, req) + require.NoError(t, err) + }) + t.Run("head update empty equals", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + headUpdate := &consensusproto.LogHeadUpdate{ + Head: "h1", + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h1"}) + req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.Nil(t, req) + require.NoError(t, err) + }) + t.Run("head update empty results in full request", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + headUpdate := &consensusproto.LogHeadUpdate{ + Head: "h1", + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h2"}) + fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil) + req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate) + require.Equal(t, fullRequest, req) + require.NoError(t, err) + }) +} + +func TestFullSyncRequest(t *testing.T) { + ctx := context.Background() + fullResponse := &consensusproto.LogSyncMessage{ + Content: &consensusproto.LogSyncContentValue{ + Value: &consensusproto.LogSyncContentValue_FullSyncResponse{ + FullSyncResponse: &consensusproto.LogFullSyncResponse{}, + }, + }, + } + t.Run("full sync request non empty all heads added", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + fullRequest := &consensusproto.LogFullSyncRequest{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(false) + fx.aclMock.EXPECT().AddRawRecords(fullRequest.Records).Return(nil) + fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil) + resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest) + require.Equal(t, fullResponse, resp) + require.NoError(t, err) + }) + t.Run("full sync request non empty head exists", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + fullRequest := &consensusproto.LogFullSyncRequest{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(true) + fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil) + resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest) + require.Equal(t, fullResponse, resp) + require.NoError(t, err) + }) + t.Run("full sync request empty head not exists", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + fullRequest := &consensusproto.LogFullSyncRequest{ + Head: "h1", + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(false) + resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest) + require.Nil(t, resp) + require.Error(t, list.ErrIncorrectRecordSequence, err) + }) +} + +func TestFullSyncResponse(t *testing.T) { + ctx := context.Background() + t.Run("full sync response no heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + fullResponse := &consensusproto.LogFullSyncResponse{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(false) + fx.aclMock.EXPECT().AddRawRecords(fullResponse.Records).Return(nil) + err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse) + require.NoError(t, err) + }) + t.Run("full sync response has heads", func(t *testing.T) { + fx := newSyncProtocolFixture(t) + defer fx.stop() + chWithId := &consensusproto.RawRecordWithId{} + fullResponse := &consensusproto.LogFullSyncResponse{ + Head: "h1", + Records: []*consensusproto.RawRecordWithId{chWithId}, + } + fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId) + fx.aclMock.EXPECT().HasHead("h1").Return(true) + err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse) + require.NoError(t, err) + }) +} diff --git a/commonspace/object/acl/syncacl/headupdater/headupdater.go b/commonspace/object/acl/syncacl/headupdater/headupdater.go new file mode 100644 index 00000000..64c41f41 --- /dev/null +++ b/commonspace/object/acl/syncacl/headupdater/headupdater.go @@ -0,0 +1,5 @@ +package headupdater + +type HeadUpdater interface { + UpdateHeads(id string, heads []string) +} diff --git a/commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go b/commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go index a7ca877c..66d16b9d 100644 --- a/commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go +++ b/commonspace/object/acl/syncacl/mock_syncacl/mock_syncacl.go @@ -10,7 +10,7 @@ import ( app "github.com/anyproto/any-sync/app" list "github.com/anyproto/any-sync/commonspace/object/acl/list" - syncacl "github.com/anyproto/any-sync/commonspace/object/acl/syncacl" + headupdater "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" consensusproto "github.com/anyproto/any-sync/consensus/consensusproto" crypto "github.com/anyproto/any-sync/util/crypto" @@ -386,7 +386,7 @@ func (mr *MockSyncAclMockRecorder) Run(arg0 interface{}) *gomock.Call { } // SetHeadUpdater mocks base method. -func (m *MockSyncAcl) SetHeadUpdater(arg0 syncacl.HeadUpdater) { +func (m *MockSyncAcl) SetHeadUpdater(arg0 headupdater.HeadUpdater) { m.ctrl.T.Helper() m.ctrl.Call(m, "SetHeadUpdater", arg0) } diff --git a/commonspace/object/acl/syncacl/syncacl.go b/commonspace/object/acl/syncacl/syncacl.go index a612e5c5..23592490 100644 --- a/commonspace/object/acl/syncacl/syncacl.go +++ b/commonspace/object/acl/syncacl/syncacl.go @@ -3,6 +3,7 @@ package syncacl import ( "context" "errors" + "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater" "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" "github.com/anyproto/any-sync/accountservice" @@ -30,7 +31,7 @@ type SyncAcl interface { app.ComponentRunnable list.AclList syncobjectgetter.SyncObject - SetHeadUpdater(updater HeadUpdater) + SetHeadUpdater(updater headupdater.HeadUpdater) SyncWithPeer(ctx context.Context, peerId string) (err error) } @@ -38,15 +39,11 @@ func New() SyncAcl { return &syncAcl{} } -type HeadUpdater interface { - UpdateHeads(id string, heads []string) -} - type syncAcl struct { list.AclList syncClient SyncClient syncHandler synchandler.SyncHandler - headUpdater HeadUpdater + headUpdater headupdater.HeadUpdater isClosed bool } @@ -58,7 +55,7 @@ func (s *syncAcl) HandleRequest(ctx context.Context, senderId string, request *s return s.syncHandler.HandleRequest(ctx, senderId, request) } -func (s *syncAcl) SetHeadUpdater(updater HeadUpdater) { +func (s *syncAcl) SetHeadUpdater(updater headupdater.HeadUpdater) { s.headUpdater = updater }