diff --git a/common/commonspace/rpchandler.go b/common/commonspace/rpchandler.go index b0d0be69..be83b43c 100644 --- a/common/commonspace/rpchandler.go +++ b/common/commonspace/rpchandler.go @@ -20,11 +20,5 @@ func (r *rpcHandler) HeadSync(ctx context.Context, req *spacesyncproto.HeadSyncR } func (r *rpcHandler) Stream(stream spacesyncproto.DRPCSpace_StreamStream) (err error) { - err = r.s.SyncService().StreamPool().AddAndReadStream(stream) - if err != nil { - return - } - - <-stream.Context().Done() - return + return r.s.SyncService().StreamPool().AddAndReadStream(stream) } diff --git a/common/commonspace/space.go b/common/commonspace/space.go index 757b1f32..69cac657 100644 --- a/common/commonspace/space.go +++ b/common/commonspace/space.go @@ -27,8 +27,8 @@ type Space interface { SpaceSyncRpc() RpcHandler SyncService() syncservice.SyncService - CreateTree(payload tree.ObjectTreeCreatePayload, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) - BuildTree(id string, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) + CreateTree(ctx context.Context, payload tree.ObjectTreeCreatePayload, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) + BuildTree(ctx context.Context, id string, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) Close() error } @@ -47,16 +47,51 @@ type space struct { cache cache.TreeCache } -func (s *space) CreateTree(payload tree.ObjectTreeCreatePayload, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) { +func (s *space) CreateTree(ctx context.Context, payload tree.ObjectTreeCreatePayload, listener tree.ObjectTreeUpdateListener) (tree.ObjectTree, error) { return synctree.CreateSyncTree(payload, s.syncService, listener, nil, s.storage.CreateTreeStorage) } -func (s *space) BuildTree(id string, listener tree.ObjectTreeUpdateListener) (t tree.ObjectTree, err error) { +func (s *space) BuildTree(ctx context.Context, id string, listener tree.ObjectTreeUpdateListener) (t tree.ObjectTree, err error) { + getTreeRemote := func() (*spacesyncproto.ObjectSyncMessage, error) { + // TODO: add empty context handling (when this is not happening due to head update) + peerId, err := syncservice.GetPeerIdFromStreamContext(ctx) + if err != nil { + return nil, err + } + + return s.syncService.StreamPool().SendSync( + peerId, + spacesyncproto.WrapFullRequest(&spacesyncproto.ObjectFullSyncRequest{}, nil, id), + func(syncMessage *spacesyncproto.ObjectSyncMessage) bool { + return syncMessage.GetContent().GetFullSyncResponse() != nil + }, + ) + } + store, err := s.storage.Storage(id) - if err != nil { + if err != nil && err != treestorage.ErrUnknownTreeId { return } + if err == treestorage.ErrUnknownTreeId { + var resp *spacesyncproto.ObjectSyncMessage + resp, err = getTreeRemote() + if err != nil { + return + } + fullSyncResp := resp.GetContent().GetFullSyncResponse() + + payload := treestorage.TreeStorageCreatePayload{ + TreeId: resp.TreeId, + Header: resp.TreeHeader, + Changes: fullSyncResp.Changes, + Heads: fullSyncResp.Heads, + } + store, err = s.storage.CreateTreeStorage(payload) + if err != nil { + return + } + } return synctree.BuildSyncTree(s.syncService, store.(treestorage.TreeStorage), listener, nil) } diff --git a/common/commonspace/syncservice/streampool.go b/common/commonspace/syncservice/streampool.go index 6b402294..0ccd2d6b 100644 --- a/common/commonspace/syncservice/streampool.go +++ b/common/commonspace/syncservice/streampool.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/anytypeio/go-anytype-infrastructure-experiments/common/commonspace/spacesyncproto" "github.com/libp2p/go-libp2p-core/sec" - "storj.io/drpc" "storj.io/drpc/drpcctx" "sync" ) @@ -25,17 +24,27 @@ type StreamPool interface { } type SyncClient interface { + SendSync(peerId string, + message *spacesyncproto.ObjectSyncMessage, + msgCheck func(syncMessage *spacesyncproto.ObjectSyncMessage) bool) (reply *spacesyncproto.ObjectSyncMessage, err error) SendAsync(peerId string, message *spacesyncproto.ObjectSyncMessage) (err error) BroadcastAsync(message *spacesyncproto.ObjectSyncMessage) (err error) } type MessageHandler func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) +type responseWaiter struct { + ch chan *spacesyncproto.ObjectSyncMessage + msgCheck func(message *spacesyncproto.ObjectSyncMessage) bool +} + type streamPool struct { sync.Mutex peerStreams map[string]spacesyncproto.SpaceStream messageHandler MessageHandler wg *sync.WaitGroup + waiters map[string]responseWaiter + waitersMx sync.Mutex } func newStreamPool(messageHandler MessageHandler) StreamPool { @@ -51,6 +60,41 @@ func (s *streamPool) HasStream(peerId string) (res bool) { return err == nil } +func (s *streamPool) SendSync( + peerId string, + message *spacesyncproto.ObjectSyncMessage, + msgCheck func(syncMessage *spacesyncproto.ObjectSyncMessage) bool) (reply *spacesyncproto.ObjectSyncMessage, err error) { + + sendAndWait := func(waiter responseWaiter) (err error) { + err = s.SendAsync(peerId, message) + if err != nil { + return + } + + reply = <-waiter.ch + return + } + + key := fmt.Sprintf("%s.%s", peerId, message.TreeId) + s.waitersMx.Lock() + waiter, exists := s.waiters[key] + if exists { + s.waitersMx.Unlock() + + err = sendAndWait(waiter) + return + } + + waiter = responseWaiter{ + ch: make(chan *spacesyncproto.ObjectSyncMessage), + msgCheck: msgCheck, + } + s.waiters[key] = waiter + s.waitersMx.Unlock() + err = sendAndWait(waiter) + return +} + func (s *streamPool) SendAsync(peerId string, message *spacesyncproto.ObjectSyncMessage) (err error) { stream, err := s.getStream(peerId) if err != nil { @@ -109,7 +153,7 @@ func (s *streamPool) BroadcastAsync(message *spacesyncproto.ObjectSyncMessage) ( func (s *streamPool) AddAndReadStream(stream spacesyncproto.SpaceStream) (err error) { s.Lock() - peerId, err := getPeerIdFromStream(stream) + peerId, err := GetPeerIdFromStreamContext(stream.Context()) if err != nil { s.Unlock() return @@ -119,8 +163,7 @@ func (s *streamPool) AddAndReadStream(stream spacesyncproto.SpaceStream) (err er s.wg.Add(1) s.Unlock() - go s.readPeerLoop(peerId, stream) - return + return s.readPeerLoop(peerId, stream) } func (s *streamPool) Close() (err error) { @@ -140,6 +183,22 @@ func (s *streamPool) readPeerLoop(peerId string, stream spacesyncproto.SpaceStre limiter <- struct{}{} } + process := func(msg *spacesyncproto.ObjectSyncMessage) { + key := fmt.Sprintf("%s.%s", peerId, msg.TreeId) + s.waitersMx.Lock() + waiter, exists := s.waiters[key] + + if !exists || !waiter.msgCheck(msg) { + s.waitersMx.Unlock() + s.messageHandler(stream.Context(), peerId, msg) + return + } + + delete(s.waiters, key) + s.waitersMx.Unlock() + waiter.ch <- msg + } + Loop: for { msg, err := stream.Recv() @@ -155,8 +214,7 @@ Loop: defer func() { limiter <- struct{}{} }() - - s.messageHandler(context.Background(), peerId, msg) + process(msg) }() } return s.removePeer(peerId) @@ -173,8 +231,7 @@ func (s *streamPool) removePeer(peerId string) (err error) { return } -func getPeerIdFromStream(stream drpc.Stream) (string, error) { - ctx := stream.Context() +func GetPeerIdFromStreamContext(ctx context.Context) (string, error) { conn, ok := ctx.Value(drpcctx.TransportKey{}).(sec.SecureConn) if !ok { return "", fmt.Errorf("incorrect connection type in stream") diff --git a/common/commonspace/synctree/synctree.go b/common/commonspace/synctree/synctree.go index 3b84ee77..d9dcb4e4 100644 --- a/common/commonspace/synctree/synctree.go +++ b/common/commonspace/synctree/synctree.go @@ -11,7 +11,7 @@ import ( ) type SyncTree struct { - objTree tree.ObjectTree + tree.ObjectTree syncService syncservice.SyncService } @@ -60,68 +60,8 @@ func buildSyncTree( return } -func (s *SyncTree) Lock() { - s.objTree.Lock() -} - -func (s *SyncTree) Unlock() { - s.objTree.Unlock() -} - -func (s *SyncTree) RLock() { - s.objTree.RLock() -} - -func (s *SyncTree) RUnlock() { - s.objTree.RUnlock() -} - -func (s *SyncTree) ID() string { - return s.objTree.ID() -} - -func (s *SyncTree) Header() *aclpb.TreeHeader { - return s.objTree.Header() -} - -func (s *SyncTree) Heads() []string { - return s.objTree.Heads() -} - -func (s *SyncTree) Root() *tree.Change { - return s.objTree.Root() -} - -func (s *SyncTree) HasChange(id string) bool { - return s.objTree.HasChange(id) -} - -func (s *SyncTree) Iterate(convert tree.ChangeConvertFunc, iterate tree.ChangeIterateFunc) error { - return s.objTree.Iterate(convert, iterate) -} - -func (s *SyncTree) IterateFrom(id string, convert tree.ChangeConvertFunc, iterate tree.ChangeIterateFunc) error { - return s.objTree.IterateFrom(id, convert, iterate) -} - -func (s *SyncTree) SnapshotPath() []string { - return s.objTree.SnapshotPath() -} - -func (s *SyncTree) ChangesAfterCommonSnapshot(snapshotPath, heads []string) ([]*aclpb.RawTreeChangeWithId, error) { - return s.objTree.ChangesAfterCommonSnapshot(snapshotPath, heads) -} - -func (s *SyncTree) Storage() storage.TreeStorage { - return s.objTree.Storage() -} - -func (s *SyncTree) DebugDump() (string, error) { - return s.objTree.DebugDump() -} - func (s *SyncTree) AddContent(ctx context.Context, content tree.SignableChangeContent) (res tree.AddResult, err error) { - res, err = s.objTree.AddContent(ctx, content) + res, err = s.AddContent(ctx, content) if err != nil { return } @@ -134,7 +74,7 @@ func (s *SyncTree) AddContent(ctx context.Context, content tree.SignableChangeCo } func (s *SyncTree) AddRawChanges(ctx context.Context, changes ...*aclpb.RawTreeChangeWithId) (res tree.AddResult, err error) { - res, err = s.objTree.AddRawChanges(ctx, changes...) + res, err = s.AddRawChanges(ctx, changes...) if err != nil || res.Mode == tree.Nothing { return } @@ -145,7 +85,3 @@ func (s *SyncTree) AddRawChanges(ctx context.Context, changes ...*aclpb.RawTreeC }) return } - -func (s *SyncTree) Close() error { - return s.objTree.Close() -}