Compare commits

...

170 Commits
yamux ... main

Author SHA1 Message Date
ff27016534 Merge pull request 'Update README.md' (#1) from force_build into main
Reviewed-on: #1
2023-08-24 20:15:49 -04:00
e0d234213e Merge branch 'main' into force_build
All checks were successful
/ test (pull_request) Successful in 2m0s
2023-07-26 19:04:23 -04:00
3242503554 Update .github/workflows/coverage.yml 2023-07-26 19:04:07 -04:00
8dce67f3d2 Update README.md
Some checks failed
/ test (pull_request) Failing after 18s
2023-07-26 19:01:27 -04:00
a1475ca95b Update README.md 2023-07-26 18:59:59 -04:00
Sergey Cherepanov
9b7d7e11a7
Merge pull request #49 from anyproto/open-22-prepare-any-sync-for-publishing
Add README.md
2023-07-17 19:57:29 +02:00
Sergey Cherepanov
66921158c1
Merge pull request #50 from anyproto/dependabot/go_modules/github.com/libp2p/go-libp2p-0.29.0
Bump github.com/libp2p/go-libp2p from 0.28.1 to 0.29.0
2023-07-17 19:31:42 +02:00
dependabot[bot]
735536068d
Bump github.com/libp2p/go-libp2p from 0.28.1 to 0.29.0
Bumps [github.com/libp2p/go-libp2p](https://github.com/libp2p/go-libp2p) from 0.28.1 to 0.29.0.
- [Release notes](https://github.com/libp2p/go-libp2p/releases)
- [Changelog](https://github.com/libp2p/go-libp2p/blob/master/CHANGELOG.md)
- [Commits](https://github.com/libp2p/go-libp2p/compare/v0.28.1...v0.29.0)

---
updated-dependencies:
- dependency-name: github.com/libp2p/go-libp2p
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-17 15:55:07 +00:00
Sergey Fuksman
63e533efb5
Add README.md 2023-07-17 10:09:29 +03:00
Mikhail Rakhmanov
b1c198df1d
Merge pull request #47 from anyproto/acl-sync-protocol
Acl sync protocol
2023-07-13 11:23:49 +02:00
mcrakhman
e08b3ba659
Add syncaclhandler tests 2023-07-12 15:58:41 +02:00
mcrakhman
22ec754ca7
Add sync protocol tests 2023-07-12 14:12:00 +02:00
mcrakhman
febfb72cec
Add diffsyncer tests 2023-07-12 12:09:55 +02:00
mcrakhman
098120da84
Update headsync tests 2023-07-11 13:58:59 +02:00
mcrakhman
bf7e256065
Merge remote-tracking branch 'origin/consensus-client' into acl-sync-protocol
# Conflicts:
#	consensus/consensusproto/consensus.pb.go
2023-07-11 13:58:45 +02:00
Sergey Cherepanov
ebf4034ec7
consensus: fix race 2023-07-11 12:02:15 +02:00
mcrakhman
b4cc8d0a61
Change head sync update behaviour 2023-07-10 23:47:29 +02:00
mcrakhman
94aea5bafb
Expose Acl in space 2023-07-10 23:17:05 +02:00
Sergey Cherepanov
4bccbf1faf
Merge pull request #45 from anyproto/dependabot/go_modules/golang.org/x/net-0.12.0
Bump golang.org/x/net from 0.11.0 to 0.12.0
2023-07-10 19:13:21 +02:00
Sergey Cherepanov
26b18fba87
Merge pull request #44 from anyproto/dependabot/go_modules/golang.org/x/crypto-0.11.0
Bump golang.org/x/crypto from 0.10.0 to 0.11.0
2023-07-10 19:13:11 +02:00
Sergey Cherepanov
1a23081336
merge 2023-07-10 19:12:58 +02:00
dependabot[bot]
fc0c3d54f1
Bump golang.org/x/net from 0.11.0 to 0.12.0
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.11.0 to 0.12.0.
- [Commits](https://github.com/golang/net/compare/v0.11.0...v0.12.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-10 15:26:07 +00:00
dependabot[bot]
9dde29a280
Bump golang.org/x/crypto from 0.10.0 to 0.11.0
Bumps [golang.org/x/crypto](https://github.com/golang/crypto) from 0.10.0 to 0.11.0.
- [Commits](https://github.com/golang/crypto/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/crypto
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-07-10 15:25:50 +00:00
mcrakhman
3c5e3bed96
Merge branch 'main' into acl-sync-protocol 2023-07-10 15:56:45 +02:00
Sergey Cherepanov
ef128dd33f
switch to uber/gomock 2023-07-10 15:41:22 +02:00
Sergey Cherepanov
648aa15b55
consensus: err invalid payload 2023-07-07 11:22:12 +02:00
Sergey Cherepanov
cd3c6c736a
consensus: remove log payload 2023-07-05 18:50:14 +02:00
Mikhail Rakhmanov
79cc89bec2
Merge pull request #43 from anyproto/remove-object-load-deadline 2023-07-04 16:55:12 +02:00
mcrakhman
fe31afc337
Remove cancel with deadline 2023-07-04 16:51:14 +02:00
Mikhail Rakhmanov
ac8c8e4a31
Merge pull request #42 from anyproto/fix-peer-connection 2023-07-04 11:12:38 +02:00
mcrakhman
ab34ff4bc9
Fix not sending correct connection and incoming count 2023-07-04 08:07:55 +02:00
mcrakhman
bef93d46ad
Implement sync protocol 2023-07-03 18:19:23 +02:00
Sergey Cherepanov
0c1d752acf
consensus: err forbidden 2023-07-03 17:36:31 +02:00
mcrakhman
8aa41da1ff
Merge remote-tracking branch 'origin/consensus-client' into acl-sync-protocol
# Conflicts:
#	consensus/consensusproto/consensus.pb.go
2023-07-03 17:16:43 +02:00
Sergey Cherepanov
b12a056dd9
consensus: use strings for ids 2023-07-03 16:19:24 +02:00
mcrakhman
0d16c5d7e4
WIP sync logic 2023-07-03 15:48:48 +02:00
mcrakhman
145332b0f7
Add headsync acl logic 2023-07-03 13:43:54 +02:00
mcrakhman
51ac955f1c
Add sync protocol interfaces 2023-07-02 15:55:58 +02:00
Mikhail Rakhmanov
b10d72a092
Merge pull request #40 from anyproto/acl-change 2023-07-02 15:54:07 +02:00
mcrakhman
822e7f374d
Change to consensus proto 2023-07-01 13:17:18 +02:00
mcrakhman
e094743fbc
Merge remote-tracking branch 'origin/consensus-client' into acl-change 2023-07-01 12:51:05 +02:00
Sergey Cherepanov
92cbfb1cb3
change consensus proto and client 2023-06-30 19:42:07 +02:00
Sergey Cherepanov
59cf8b46fd
consensus: change err offset 2023-06-29 14:52:47 +02:00
Sergey Cherepanov
50f94e7518
consensus: change err offset 2023-06-29 14:51:42 +02:00
Sergey Cherepanov
cbdbe0c34b
Merge branch 'main' of github.com:anyproto/any-sync into consensus-client 2023-06-29 14:40:07 +02:00
Sergey Cherepanov
02dd4783bc
fix mock 2023-06-29 14:37:52 +02:00
mcrakhman
5ffc175f4f
Remove time from test 2023-06-29 01:05:43 +02:00
mcrakhman
68cda47ede
Update list mock 2023-06-29 01:00:52 +02:00
mcrakhman
f4cbbfa374
Update tests 2023-06-29 00:57:24 +02:00
mcrakhman
53e9c4ab02
Merge branch 'main' into acl-change
# Conflicts:
#	net/peer/peer.go
2023-06-28 23:12:04 +02:00
Mikhail Rakhmanov
8dc0ead8f3
Merge pull request #39 from anyproto/fix-last-iterated-id
Add lastIteratedId when setting merged heads
2023-06-28 22:33:21 +02:00
mcrakhman
02b326cc90
Add lastIteratedId when setting merged heads 2023-06-28 21:34:50 +02:00
mcrakhman
e5b4f62e48
fix nodes online 2023-06-28 17:35:45 +02:00
mcrakhman
3f08fcb555
Add account remove test 2023-06-28 15:43:35 +02:00
mcrakhman
39f41c52d1
Add invite test 2023-06-28 14:55:17 +02:00
mcrakhman
ffd613a5fc
Fix requestmanager test 2023-06-28 11:59:13 +02:00
Sergey Cherepanov
ea6ca799e7
Merge pull request #38 from anyproto/fix-handshake-race
fix race in proto handshake
2023-06-28 11:00:39 +02:00
Sergey Cherepanov
6057fc2c7c
Merge pull request #37 from anyproto/dependabot/go_modules/github.com/libp2p/go-libp2p-0.28.1
Bump github.com/libp2p/go-libp2p from 0.28.0 to 0.28.1
2023-06-28 10:27:34 +02:00
Sergey Cherepanov
8770da4abf
fix race in proto handshake 2023-06-28 10:11:21 +02:00
Sergey Cherepanov
a092f7b4a1
consensus client 2023-06-27 22:09:52 +02:00
mcrakhman
0ffbb6fa5a
Rework ACL structures 2023-06-27 19:44:44 +02:00
mcrakhman
061522eec2
Update protocol 2023-06-26 19:38:54 +02:00
dependabot[bot]
b768dedd56
Bump github.com/libp2p/go-libp2p from 0.28.0 to 0.28.1
Bumps [github.com/libp2p/go-libp2p](https://github.com/libp2p/go-libp2p) from 0.28.0 to 0.28.1.
- [Release notes](https://github.com/libp2p/go-libp2p/releases)
- [Changelog](https://github.com/libp2p/go-libp2p/blob/master/CHANGELOG.md)
- [Commits](https://github.com/libp2p/go-libp2p/compare/v0.28.0...v0.28.1)

---
updated-dependencies:
- dependency-name: github.com/libp2p/go-libp2p
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-26 16:03:40 +00:00
Sergey Cherepanov
5a02d1c338
Merge pull request #36 from anyproto/subconn-limit
peer sub connections limit/throttling
2023-06-26 13:48:50 +02:00
mcrakhman
62f23b7229
Update record builder to build new payloads 2023-06-26 11:43:17 +02:00
mcrakhman
81aadfde7e
Add validate method in list 2023-06-26 10:10:14 +02:00
Sergey Cherepanov
f943991bc0
yamux bench test: open sub conn 2023-06-23 18:32:13 +02:00
Sergey Cherepanov
f0ffc9b7bf
descrease net.pool ttl 2023-06-23 18:27:59 +02:00
Sergey Cherepanov
291b8daf5f
fix blinking test 2023-06-23 17:47:29 +02:00
Sergey Cherepanov
e929d5431d
correct throttle counting 2023-06-23 17:39:13 +02:00
Sergey Cherepanov
49c3178f65
cleanup 2023-06-23 16:57:51 +02:00
Sergey Cherepanov
894f4db1ff
peer sub connections throtling + fixes 2023-06-23 16:54:55 +02:00
mcrakhman
7577c14d5f
Add state apply changes 2023-06-23 16:16:26 +02:00
mcrakhman
f9bab4d51d
Add content validator 2023-06-23 14:50:09 +02:00
mcrakhman
1fada6f336
Merge branch 'main' into acl-change 2023-06-23 13:44:38 +02:00
Mikhail Rakhmanov
0095a34167
Merge pull request #35 from anyproto/fix-raw-loader 2023-06-22 19:31:02 +02:00
mcrakhman
2c573138e6
Correctly removing changes which we don't need to send 2023-06-22 19:22:56 +02:00
Sergey Cherepanov
78a3bc6aeb
Merge pull request #30 from anyproto/dependabot/go_modules/github.com/prometheus/client_golang-1.16.0
Bump github.com/prometheus/client_golang from 1.15.1 to 1.16.0
2023-06-22 15:30:14 +02:00
Sergey Cherepanov
89d32044e3
Merge pull request #29 from anyproto/dependabot/go_modules/golang.org/x/net-0.11.0
Bump golang.org/x/net from 0.10.0 to 0.11.0
2023-06-22 15:30:02 +02:00
Sergey Cherepanov
d1be3c8a43
Merge pull request #28 from anyproto/dependabot/go_modules/github.com/libp2p/go-libp2p-0.28.0
Bump github.com/libp2p/go-libp2p from 0.27.5 to 0.28.0
2023-06-22 15:29:51 +02:00
mcrakhman
718a5b04dc
Update proto 2023-06-22 13:42:38 +02:00
Sergey
9d2691ddfd
Merge pull request #33 from anyproto/add-iterate-components-method
Add app.IterateComponents method
2023-06-21 12:56:56 +00:00
Sergey
4eb2245669
IterateComponents: Add test and comment 2023-06-21 17:54:38 +05:00
Sergey
fa178d7c26
Add app.IterateComponents method. This method helps to create debugging HTTP handlers in Heart 2023-06-21 13:50:50 +05:00
Mikhail Rakhmanov
1b47a54f87
Merge pull request #32 from anyproto/sync-improvements 2023-06-20 13:24:47 +02:00
mcrakhman
8e7df9eae5
Sync updates 2023-06-20 12:02:05 +02:00
dependabot[bot]
d3636604d7
Bump github.com/prometheus/client_golang from 1.15.1 to 1.16.0
Bumps [github.com/prometheus/client_golang](https://github.com/prometheus/client_golang) from 1.15.1 to 1.16.0.
- [Release notes](https://github.com/prometheus/client_golang/releases)
- [Changelog](https://github.com/prometheus/client_golang/blob/main/CHANGELOG.md)
- [Commits](https://github.com/prometheus/client_golang/compare/v1.15.1...v1.16.0)

---
updated-dependencies:
- dependency-name: github.com/prometheus/client_golang
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-19 16:05:27 +00:00
dependabot[bot]
1d4447f126
Bump golang.org/x/net from 0.10.0 to 0.11.0
Bumps [golang.org/x/net](https://github.com/golang/net) from 0.10.0 to 0.11.0.
- [Commits](https://github.com/golang/net/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: golang.org/x/net
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-19 16:05:09 +00:00
dependabot[bot]
e9668c73a8
Bump github.com/libp2p/go-libp2p from 0.27.5 to 0.28.0
Bumps [github.com/libp2p/go-libp2p](https://github.com/libp2p/go-libp2p) from 0.27.5 to 0.28.0.
- [Release notes](https://github.com/libp2p/go-libp2p/releases)
- [Changelog](https://github.com/libp2p/go-libp2p/blob/master/CHANGELOG.md)
- [Commits](https://github.com/libp2p/go-libp2p/compare/v0.27.5...v0.28.0)

---
updated-dependencies:
- dependency-name: github.com/libp2p/go-libp2p
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-19 16:04:53 +00:00
Mikhail Rakhmanov
fe7b97bee9
Merge pull request #27 from anyproto/change-add-to-tryadd-streampool 2023-06-19 11:14:52 +02:00
mcrakhman
a9889a6245
TryAdd correctly 2023-06-19 11:10:56 +02:00
mcrakhman
16be33fc96
Change Add to TryAdd 2023-06-19 10:41:16 +02:00
Mikhail Rakhmanov
2ea446184d
Merge pull request #26 from anyproto/cache-cancel-load-close
Cancel on cache close
2023-06-15 17:38:14 +02:00
mcrakhman
18716eebb4
Cancel on cache close 2023-06-15 13:08:03 +02:00
Sergey Cherepanov
53cbfca3ca
Merge pull request #25 from anyproto/fix-metric-panic
validate rpc method name
2023-06-14 15:27:46 +02:00
Sergey Cherepanov
646f7fbedc
remove mockgen from deps 2023-06-14 12:01:52 +02:00
Sergey Cherepanov
d35ac55ee1
validate rpc method name 2023-06-14 11:55:59 +02:00
Sergey Cherepanov
ff3fc68451
Merge pull request #24 from anyproto/fix-metrics
Peer addr in ctx + fix any-sync version
2023-06-13 19:09:58 +02:00
Mikhail Rakhmanov
8f984443b9
Merge pull request #23 from anyproto/no-retry-on-not-found 2023-06-13 19:09:31 +02:00
Sergey Cherepanov
31f0014783
add peer addr to cctx 2023-06-13 19:06:10 +02:00
Sergey Cherepanov
69f2cb8b1d
app.VersionName fix any-sync version 2023-06-13 19:02:50 +02:00
mcrakhman
cc3da7e66b
Simplify tree remote getter 2023-06-13 19:00:01 +02:00
mcrakhman
060c6d1231
Retry fail on treechangeproto.ErrGetTree 2023-06-13 18:14:13 +02:00
Mikhail Rakhmanov
759c48c6b7
Merge pull request #21 from anyproto/add-treegetter-get-timeout 2023-06-13 15:51:32 +02:00
Sergey Cherepanov
d551201021
Merge pull request #22 from anyproto/GO-1466-client-version
Go 1466 client version
2023-06-13 15:49:00 +02:00
Sergey Cherepanov
be58956bec
provide client version in secureservice 2023-06-13 15:39:22 +02:00
mcrakhman
0ca2fe5e7e
Change naming 2023-06-13 15:36:23 +02:00
Sergey Cherepanov
38090ee68f
provide client version to rpc log 2023-06-13 15:32:46 +02:00
Sergey Cherepanov
c753da8def
client version to handshake 2023-06-13 15:30:14 +02:00
mcrakhman
6eda884686
Remove goto 2023-06-13 15:26:42 +02:00
mcrakhman
5c884afcf1
Make error typed 2023-06-13 15:22:46 +02:00
mcrakhman
2aaa8f4a0c
Change retry logic and add tests 2023-06-13 15:21:11 +02:00
Sergey Cherepanov
ba7cffb51a
Merge branch 'main' of github.com:anyproto/any-sync into GO-1466-client-version 2023-06-13 14:49:14 +02:00
Sergey Cherepanov
3cfa70c291
Merge pull request #20 from anyproto/net-fixes
Net fixes
2023-06-13 14:17:10 +02:00
Sergey Cherepanov
531f61d9d6
Merge pull request #19 from anyproto/dependabot/go_modules/github.com/multiformats/go-multihash-0.2.3
Bump github.com/multiformats/go-multihash from 0.2.2 to 0.2.3
2023-06-13 14:03:01 +02:00
dependabot[bot]
c3ebc8981c
Bump github.com/multiformats/go-multihash from 0.2.2 to 0.2.3
Bumps [github.com/multiformats/go-multihash](https://github.com/multiformats/go-multihash) from 0.2.2 to 0.2.3.
- [Release notes](https://github.com/multiformats/go-multihash/releases)
- [Commits](https://github.com/multiformats/go-multihash/compare/v0.2.2...v0.2.3)

---
updated-dependencies:
- dependency-name: github.com/multiformats/go-multihash
  dependency-type: direct:production
  update-type: version-update:semver-patch
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-13 11:49:19 +00:00
Sergey Cherepanov
1bfb7ff64f
Merge pull request #18 from anyproto/dependabot/go_modules/github.com/ipfs/go-merkledag-0.11.0
Bump github.com/ipfs/go-merkledag from 0.10.0 to 0.11.0
2023-06-13 13:47:43 +02:00
Sergey Cherepanov
733027d798
remove multiconn.LastUsage + fix peer gc 2023-06-13 13:40:13 +02:00
Sergey Cherepanov
05b479e5fa
yamux keep-alive config + remove write deadline for sub conns 2023-06-13 13:28:19 +02:00
Sergey Cherepanov
767f868a36
remove test test 2023-06-12 18:55:40 +02:00
Sergey Cherepanov
ab6ecaa462
write deadline + check conn for close 2023-06-12 18:42:30 +02:00
dependabot[bot]
d6df4b7001
Bump github.com/ipfs/go-merkledag from 0.10.0 to 0.11.0
Bumps [github.com/ipfs/go-merkledag](https://github.com/ipfs/go-merkledag) from 0.10.0 to 0.11.0.
- [Release notes](https://github.com/ipfs/go-merkledag/releases)
- [Commits](https://github.com/ipfs/go-merkledag/compare/v0.10.0...v0.11.0)

---
updated-dependencies:
- dependency-name: github.com/ipfs/go-merkledag
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2023-06-12 16:05:59 +00:00
Sergey Cherepanov
cb0396c40f
comments 2023-06-12 16:01:17 +02:00
Sergey Cherepanov
40cd112a2a
app.VersionName 2023-06-12 15:58:23 +02:00
Sergey Cherepanov
fb1df54941
Merge pull request #16 from anyproto/peer-active-subconn-gc
peer: gc active sub conn
2023-06-09 19:38:36 +02:00
Sergey Cherepanov
89afc03218
fix HandshakeIncompatibleVersion test 2023-06-09 19:28:09 +02:00
Sergey Cherepanov
e97e6b68c6
peer: gc active sub conn 2023-06-09 19:14:23 +02:00
mcrakhman
b0c0e4b26e
Fix proto version 2023-06-09 17:10:25 +02:00
Mikhail Rakhmanov
90c5ef3311
Merge pull request #15 from anyproto/new-sync-protocol 2023-06-09 11:44:35 +02:00
mcrakhman
35c29a4842
Add WithServer for TestPool 2023-06-08 21:07:03 +02:00
Sergey Cherepanov
15326da736
yamux: move listCtx to Init 2023-06-08 14:37:54 +02:00
Sergey Cherepanov
24ce490524
yamux: AddListener method 2023-06-08 14:36:14 +02:00
Sergey Cherepanov
85cf6b8332
debug server fixes 2023-06-08 13:38:00 +02:00
Sergey Cherepanov
f9cb0c2dbb
debug server 2023-06-08 13:27:01 +02:00
Sergey Cherepanov
33cbdd06a6
drpc conn config 2023-06-08 13:10:30 +02:00
Sergey Cherepanov
c7828d0671
move server config to drpc config 2023-06-08 12:35:29 +02:00
mcrakhman
dad76def2f
Merge branch 'new-sync-protocol-tests' into new-sync-protocol 2023-06-07 22:05:43 +02:00
mcrakhman
318a49c526
Change objectsync injection 2023-06-07 22:05:05 +02:00
mcrakhman
c8c0839a57
Add request manager tests 2023-06-07 21:48:13 +02:00
Sergey Cherepanov
5a6661eab1
cleanup net config 2023-06-07 20:52:09 +02:00
Sergey Cherepanov
065ff11983
bump proto version as 1 2023-06-07 20:50:26 +02:00
Sergey Cherepanov
6c9d1b0e84
rpctest pool 2023-06-07 20:45:32 +02:00
mcrakhman
4d1494a17a
Add mocks and some requestmanager tests 2023-06-07 19:31:15 +02:00
mcrakhman
51eb5b1a42
Fix settings and deletion tests 2023-06-07 18:05:13 +02:00
Sergey Cherepanov
485c9dd768
yamux default timeouts 2023-06-07 14:49:44 +02:00
mcrakhman
564c636391
Fix diffsyncer tests 2023-06-07 14:09:29 +02:00
Sergey Cherepanov
5a8c69e557
Merge branch 'yamux' into new-sync-protocol 2023-06-07 13:34:31 +02:00
mcrakhman
4ef617b1f2
More headsync tests 2023-06-07 13:06:37 +02:00
mcrakhman
2f5e0dd6c8
WIP headsync tests revive 2023-06-07 11:30:27 +02:00
mcrakhman
8310cb3c05
Fix sync protocol integration tests 2023-06-06 22:08:06 +02:00
mcrakhman
100e7e04c3
SyncTreeHandler tests 2023-06-06 21:49:23 +02:00
mcrakhman
b18bb02176
TreeSyncProtocol tests 2023-06-06 20:50:53 +02:00
mcrakhman
3a2f9fe6f5
WIP synctree tests rewrite 2023-06-06 20:10:44 +02:00
mcrakhman
67d535362f
Different fixes 2023-06-06 17:18:59 +02:00
mcrakhman
4c45ad3e67
Merge remote-tracking branch 'origin/yamux' into new-sync-protocol 2023-06-06 10:43:52 +02:00
mcrakhman
e96524f702
Yamux setaccepter 2023-06-06 10:43:43 +02:00
mcrakhman
66775873c7
Add syncstatusprovider 2023-06-05 20:44:27 +02:00
mcrakhman
69e607eddb
Expose more methods 2023-06-05 15:16:38 +02:00
mcrakhman
85a093dd4a
Change space methods (handle requests) 2023-06-05 15:09:17 +02:00
mcrakhman
aff2061bd1
WIP request manager 2023-06-04 19:01:33 +02:00
mcrakhman
b85f545fa3
Update connections on space level 2023-06-04 11:17:56 +02:00
mcrakhman
248205cddd
Merge remote-tracking branch 'origin/yamux' into new-sync-protocol 2023-06-04 10:43:11 +02:00
mcrakhman
ce63951ae6
Update proto files 2023-06-03 22:48:16 +02:00
mcrakhman
990cbc58b6
Add sync requests handling 2023-06-03 22:41:03 +02:00
mcrakhman
748681d765
WIP rearrange components 2023-06-03 15:57:55 +02:00
mcrakhman
a89a325d6c
Fix sync client 2023-06-02 01:08:56 +02:00
mcrakhman
815bc7927d
Wire up the stuff 2023-06-02 00:59:33 +02:00
mcrakhman
796b66478b
Further components rearrange 2023-06-01 22:55:37 +02:00
mcrakhman
eeb87dd144
WIP further space refactoring 2023-06-01 14:24:58 +02:00
mcrakhman
b0fa43fb14
WIP work on components 2023-06-01 10:28:32 +02:00
169 changed files with 16008 additions and 6676 deletions

View File

@ -17,20 +17,20 @@ jobs:
- name: git config - name: git config
run: git config --global url.https://${{ secrets.ANYTYPE_PAT }}@github.com/.insteadOf https://github.com/ run: git config --global url.https://${{ secrets.ANYTYPE_PAT }}@github.com/.insteadOf https://github.com/
# cache {{ # # cache {{
- id: go-cache-paths # - id: go-cache-paths
run: | # run: |
echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT # echo "GOCACHE=$(go env GOCACHE)" >> $GITHUB_OUTPUT
echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT # echo "GOMODCACHE=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
- uses: actions/cache@v3 # - uses: actions/cache@v3
with: # with:
path: | # path: |
${{ steps.go-cache-paths.outputs.GOCACHE }} # ${{ steps.go-cache-paths.outputs.GOCACHE }}
${{ steps.go-cache-paths.outputs.GOMODCACHE }} # ${{ steps.go-cache-paths.outputs.GOMODCACHE }}
key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }} # key: ${{ runner.os }}-go-${{ matrix.go-version }}-${{ hashFiles('**/go.sum') }}
restore-keys: | # restore-keys: |
${{ runner.os }}-go-${{ matrix.go-version }}- # ${{ runner.os }}-go-${{ matrix.go-version }}-
# }} # # }}
- name: deps - name: deps
run: make deps run: make deps

View File

@ -20,6 +20,7 @@ proto:
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. net/streampool/testservice/protos/*.proto
protoc --gogofaster_out=:. net/secureservice/handshake/handshakeproto/protos/*.proto protoc --gogofaster_out=:. net/secureservice/handshake/handshakeproto/protos/*.proto
protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. coordinator/coordinatorproto/protos/*.proto protoc --gogofaster_out=$(PKGMAP):. --go-drpc_out=protolib=github.com/gogo/protobuf:. coordinator/coordinatorproto/protos/*.proto
protoc --gogofaster_out=:. --go-drpc_out=protolib=github.com/gogo/protobuf:. consensus/consensusproto/protos/*.proto
deps: deps:
go mod download go mod download

43
README.md Normal file
View File

@ -0,0 +1,43 @@
# Any-Sync
Any-Sync is an open-source protocol designed to create high-performance, local-first, peer-to-peer, end-to-end encrypted applications that facilitate seamless collaboration among multiple users and devices.
By utilizing this protocol, users can rest assured that they retain complete control over their data and digital experience. They are empowered to freely transition between various service providers, or even opt to self-host the applications.
This ensures utmost flexibility and autonomy for users in managing their personal information and digital interactions.
## Introduction
Most existing information management tools are implemented on centralized client-server architecture or designed for an offline-first single-user usage. Either way there are trade-offs for users: they can face restricted freedoms and privacy violations or compromise on the functionality of tools to avoid this.
We believe this goes against fundamental digital freedoms and that a new generation of software is needed that will respect these freedoms, while providing best in-class user experience.
Our goal with `any-sync` is to develop a protocol that will enable the deployment of this software.
Features:
- Conflict-free data replication across multiple devices and agents
- Built-in end-to-end encryption
- Cryptographically verifiable history of changes
- Adoption to frequent operations (high performance)
- Reliable and scalable infrastructure
- Simultaneous support of p2p and remote communication
## Protocol explanation
Plese read the [overview](https://tech.anytype.io/any-sync/overview) of protocol entities and design.
## Implementation
You can find the various parts of the protocol implemented in Go in the following repositories:
- [`any-sync-node`](https://github.com/anyproto/any-sync-node) — implementation of a sync node responsible for storing spaces and objects.
- [`any-sync-filenode`](https://github.com/anyproto/any-sync-filenode) — implementation of a file node responsible for storing files.
- [`any-sync-coordinator`](https://github.com/anyproto/any-sync-coordinator) — implementation of a coordinator node responsible for network configuration management.
## Contribution
Thank you for your desire to develop Anytype together.
Currently, we're not ready to accept PRs, but we will in the nearest future.
Follow us on [Github](https://github.com/anyproto) and join the [Contributors Community](https://github.com/orgs/anyproto/discussions).
---
Made by Any — a Swiss association 🇨🇭
Licensed under [MIT License](./LICENSE).

View File

@ -3,7 +3,7 @@ package mock_accountservice
import ( import (
"github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/golang/mock/gomock" "go.uber.org/mock/gomock"
) )
func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountKeys) *MockService { func NewAccountServiceWithAccount(ctrl *gomock.Controller, acc *accountdata.AccountKeys) *MockService {

View File

@ -9,7 +9,7 @@ import (
app "github.com/anyproto/any-sync/app" app "github.com/anyproto/any-sync/app"
accountdata "github.com/anyproto/any-sync/commonspace/object/accountdata" accountdata "github.com/anyproto/any-sync/commonspace/object/accountdata"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockService is a mock of Service interface. // MockService is a mock of Service interface.

View File

@ -55,11 +55,13 @@ type ComponentStatable interface {
// App is the central part of the application // App is the central part of the application
// It contains and manages all components // It contains and manages all components
type App struct { type App struct {
parent *App
components []Component components []Component
mu sync.RWMutex mu sync.RWMutex
startStat Stat startStat Stat
stopStat Stat stopStat Stat
deviceState int deviceState int
versionName string
anySyncVersion string anySyncVersion string
} }
@ -77,6 +79,19 @@ func (app *App) Version() string {
return GitSummary return GitSummary
} }
// SetVersionName sets the custom application version
func (app *App) SetVersionName(v string) {
app.versionName = v
}
// VersionName returns a string with the settled app version or auto-generated version if it didn't set
func (app *App) VersionName() string {
if app.versionName != "" {
return app.versionName
}
return AppName + ":" + GitSummary + "/any-sync:" + app.AnySyncVersion()
}
type Stat struct { type Stat struct {
SpentMsPerComp map[string]int64 SpentMsPerComp map[string]int64
SpentMsTotal int64 SpentMsTotal int64
@ -109,6 +124,16 @@ func VersionDescription() string {
return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState) return fmt.Sprintf("build on %s from %s at #%s(%s)", BuildDate, GitBranch, GitCommit, GitState)
} }
// ChildApp creates a child container which has access to parent's components
// It doesn't call Start on any of the parent's components
func (app *App) ChildApp() *App {
return &App{
parent: app,
deviceState: app.deviceState,
anySyncVersion: app.AnySyncVersion(),
}
}
// Register adds service to registry // Register adds service to registry
// All components will be started in the order they were registered // All components will be started in the order they were registered
func (app *App) Register(s Component) *App { func (app *App) Register(s Component) *App {
@ -128,10 +153,14 @@ func (app *App) Register(s Component) *App {
func (app *App) Component(name string) Component { func (app *App) Component(name string) Component {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
if s.Name() == name { for current != nil {
return s for _, s := range current.components {
if s.Name() == name {
return s
}
} }
current = current.parent
} }
return nil return nil
} }
@ -149,10 +178,14 @@ func (app *App) MustComponent(name string) Component {
func MustComponent[i any](app *App) i { func MustComponent[i any](app *App) i {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
for _, s := range app.components { current := app
if v, ok := s.(i); ok { for current != nil {
return v for _, s := range current.components {
if v, ok := s.(i); ok {
return v
}
} }
current = current.parent
} }
empty := new(i) empty := new(i)
panic(fmt.Errorf("component with interface %T is not found", empty)) panic(fmt.Errorf("component with interface %T is not found", empty))
@ -162,9 +195,13 @@ func MustComponent[i any](app *App) i {
func (app *App) ComponentNames() (names []string) { func (app *App) ComponentNames() (names []string) {
app.mu.RLock() app.mu.RLock()
defer app.mu.RUnlock() defer app.mu.RUnlock()
names = make([]string, len(app.components)) names = make([]string, 0, len(app.components))
for i, c := range app.components { current := app
names[i] = c.Name() for current != nil {
for _, c := range current.components {
names = append(names, c.Name())
}
current = current.parent
} }
return return
} }
@ -225,6 +262,15 @@ func (app *App) Start(ctx context.Context) (err error) {
return return
} }
// IterateComponents iterates over all registered components. It's safe for concurrent use.
func (app *App) IterateComponents(fn func(Component)) {
app.mu.RLock()
defer app.mu.RUnlock()
for _, s := range app.components {
fn(s)
}
}
func stackAllGoroutines() []byte { func stackAllGoroutines() []byte {
buf := make([]byte, 1024) buf := make([]byte, 1024)
for { for {

View File

@ -34,6 +34,40 @@ func TestAppServiceRegistry(t *testing.T) {
names := app.ComponentNames() names := app.ComponentNames()
assert.Equal(t, names, []string{"c1", "r1", "s1"}) assert.Equal(t, names, []string{"c1", "r1", "s1"})
}) })
t.Run("Child MustComponent", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
for _, name := range []string{"c1", "r1", "s1", "x1"} {
assert.NotPanics(t, func() { app.MustComponent(name) }, name)
}
assert.Panics(t, func() { app.MustComponent("not-registered") })
})
t.Run("Child ComponentNames", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeComponent, "x1", nil, nil))
names := app.ComponentNames()
assert.Equal(t, names, []string{"x1", "c1", "r1", "s1"})
})
t.Run("Child override", func(t *testing.T) {
app := app.ChildApp()
app.Register(newTestService(testTypeRunnable, "s1", nil, nil))
_ = app.MustComponent("s1").(*testRunnable)
})
}
func TestApp_IterateComponents(t *testing.T) {
app := new(App)
app.Register(newTestService(testTypeRunnable, "c1", nil, nil))
app.Register(newTestService(testTypeRunnable, "r1", nil, nil))
app.Register(newTestService(testTypeComponent, "s1", nil, nil))
var got []string
app.IterateComponents(func(s Component) {
got = append(got, s.Name())
})
assert.ElementsMatch(t, []string{"c1", "r1", "s1"}, got)
} }
func TestAppStart(t *testing.T) { func TestAppStart(t *testing.T) {

View File

@ -9,7 +9,7 @@ import (
reflect "reflect" reflect "reflect"
ldiff "github.com/anyproto/any-sync/app/ldiff" ldiff "github.com/anyproto/any-sync/app/ldiff"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockDiff is a mock of Diff interface. // MockDiff is a mock of Diff interface.

View File

@ -2,9 +2,10 @@ package ocache
import ( import (
"context" "context"
"go.uber.org/zap"
"sync" "sync"
"time" "time"
"go.uber.org/zap"
) )
type entryState int type entryState int
@ -25,6 +26,7 @@ type entry struct {
value Object value Object
close chan struct{} close chan struct{}
mx sync.Mutex mx sync.Mutex
cancel context.CancelFunc
} }
func newEntry(id string, value Object, state entryState) *entry { func newEntry(id string, value Object, state entryState) *entry {
@ -49,6 +51,20 @@ func (e *entry) isClosing() bool {
return e.state == entryStateClosed || e.state == entryStateClosing return e.state == entryStateClosed || e.state == entryStateClosing
} }
func (e *entry) setCancel(cancel context.CancelFunc) {
e.mx.Lock()
defer e.mx.Unlock()
e.cancel = cancel
}
func (e *entry) cancelLoad() {
e.mx.Lock()
defer e.mx.Unlock()
if e.cancel != nil {
e.cancel()
}
}
func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) { func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) {
select { select {
case <-ctx.Done(): case <-ctx.Done():

View File

@ -3,10 +3,11 @@ package ocache
import ( import (
"context" "context"
"errors" "errors"
"github.com/anyproto/any-sync/app/logger"
"go.uber.org/zap"
"sync" "sync"
"time" "time"
"github.com/anyproto/any-sync/app/logger"
"go.uber.org/zap"
) )
var ( var (
@ -157,7 +158,10 @@ func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error)
func (c *oCache) load(ctx context.Context, id string, e *entry) { func (c *oCache) load(ctx context.Context, id string, e *entry) {
defer close(e.load) defer close(e.load)
ctx, cancel := context.WithCancel(ctx)
e.setCancel(cancel)
value, err := c.loadFunc(ctx, id) value, err := c.loadFunc(ctx, id)
cancel()
c.mu.Lock() c.mu.Lock()
defer c.mu.Unlock() defer c.mu.Unlock()
@ -315,6 +319,7 @@ func (c *oCache) Close() (err error) {
close(c.closeCh) close(c.closeCh)
var toClose []*entry var toClose []*entry
for _, e := range c.data { for _, e := range c.data {
e.cancelLoad()
toClose = append(toClose, e) toClose = append(toClose, e)
} }
c.mu.Unlock() c.mu.Unlock()

View File

@ -386,6 +386,25 @@ func Test_OCache_Remove(t *testing.T) {
}) })
} }
func TestOCacheCancelWhenRemove(t *testing.T) {
c := New(func(ctx context.Context, id string) (value Object, err error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
}
}, WithTTL(time.Millisecond*10))
stopLoad := make(chan struct{})
var err error
go func() {
_, err = c.Get(context.TODO(), "id")
stopLoad <- struct{}{}
}()
time.Sleep(time.Millisecond * 10)
c.Close()
<-stopLoad
require.Equal(t, context.Canceled, err)
}
func TestOCacheFuzzy(t *testing.T) { func TestOCacheFuzzy(t *testing.T) {
t.Run("test many objects gc, get and remove simultaneously, close after", func(t *testing.T) { t.Run("test many objects gc, get and remove simultaneously, close after", func(t *testing.T) {
tryCloseIds := make(map[string]bool) tryCloseIds := make(map[string]bool)

View File

@ -1,62 +0,0 @@
package commonspace
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"sync/atomic"
)
type commonGetter struct {
treemanager.TreeManager
spaceId string
reservedObjects []syncobjectgetter.SyncObject
spaceIsClosed *atomic.Bool
}
func newCommonGetter(spaceId string, getter treemanager.TreeManager, spaceIsClosed *atomic.Bool) *commonGetter {
return &commonGetter{
TreeManager: getter,
spaceId: spaceId,
spaceIsClosed: spaceIsClosed,
}
}
func (c *commonGetter) AddObject(object syncobjectgetter.SyncObject) {
c.reservedObjects = append(c.reservedObjects, object)
}
func (c *commonGetter) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) {
if c.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := c.getReservedObject(treeId); obj != nil {
return obj.(objecttree.ObjectTree), nil
}
return c.TreeManager.GetTree(ctx, spaceId, treeId)
}
func (c *commonGetter) getReservedObject(id string) syncobjectgetter.SyncObject {
for _, obj := range c.reservedObjects {
if obj != nil && obj.Id() == id {
return obj
}
}
return nil
}
func (c *commonGetter) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
if c.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := c.getReservedObject(objectId); obj != nil {
return obj, nil
}
t, err := c.TreeManager.GetTree(ctx, c.spaceId, objectId)
if err != nil {
return
}
obj = t.(syncobjectgetter.SyncObject)
return
}

View File

@ -1,4 +1,4 @@
package commonspace package config
type ConfigGetter interface { type ConfigGetter interface {
GetSpace() Config GetSpace() Config

View File

@ -3,6 +3,7 @@ package credentialprovider
import ( import (
"context" "context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
) )
@ -13,12 +14,21 @@ func NewNoOp() CredentialProvider {
} }
type CredentialProvider interface { type CredentialProvider interface {
app.Component
GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error)
} }
type noOpProvider struct { type noOpProvider struct {
} }
func (n noOpProvider) Init(a *app.App) (err error) {
return nil
}
func (n noOpProvider) Name() (name string) {
return CName
}
func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) { func (n noOpProvider) GetCredential(ctx context.Context, spaceHeader *spacesyncproto.RawSpaceHeaderWithId) ([]byte, error) {
return nil, nil return nil, nil
} }

View File

@ -8,8 +8,9 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
app "github.com/anyproto/any-sync/app"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockCredentialProvider is a mock of CredentialProvider interface. // MockCredentialProvider is a mock of CredentialProvider interface.
@ -49,3 +50,31 @@ func (mr *MockCredentialProviderMockRecorder) GetCredential(arg0, arg1 interface
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCredential", reflect.TypeOf((*MockCredentialProvider)(nil).GetCredential), arg0, arg1)
} }
// Init mocks base method.
func (m *MockCredentialProvider) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockCredentialProviderMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockCredentialProvider)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockCredentialProvider) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockCredentialProviderMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockCredentialProvider)(nil).Name))
}

View File

@ -73,13 +73,14 @@ func TestSpaceDeleteIds(t *testing.T) {
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
// creating a tree // creating a tree
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) rand.Read(bytes)
doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey, PrivKey: acc.SignKey,
ChangeType: "some", ChangeType: "some",
SpaceId: spc.Id(), SpaceId: spc.Id(),
@ -88,7 +89,7 @@ func TestSpaceDeleteIds(t *testing.T) {
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}) })
require.NoError(t, err) require.NoError(t, err)
tr, err := spc.PutTree(ctx, doc, nil) tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err) require.NoError(t, err)
ids = append(ids, tr.Id()) ids = append(ids, tr.Id())
tr.Close() tr.Close()
@ -106,7 +107,7 @@ func TestSpaceDeleteIds(t *testing.T) {
func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string { func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.AccountKeys) string {
bytes := make([]byte, 32) bytes := make([]byte, 32)
rand.Read(bytes) rand.Read(bytes)
doc, err := spc.CreateTree(ctx, objecttree.ObjectTreeCreatePayload{ doc, err := spc.TreeBuilder().CreateTree(ctx, objecttree.ObjectTreeCreatePayload{
PrivKey: acc.SignKey, PrivKey: acc.SignKey,
ChangeType: "some", ChangeType: "some",
SpaceId: spc.Id(), SpaceId: spc.Id(),
@ -115,7 +116,7 @@ func createTree(t *testing.T, ctx context.Context, spc Space, acc *accountdata.A
Timestamp: time.Now().Unix(), Timestamp: time.Now().Unix(),
}) })
require.NoError(t, err) require.NoError(t, err)
tr, err := spc.PutTree(ctx, doc, nil) tr, err := spc.TreeBuilder().PutTree(ctx, doc, nil)
require.NoError(t, err) require.NoError(t, err)
tr.Close() tr.Close()
return tr.Id() return tr.Id()
@ -147,9 +148,10 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
// adding space to tree manager // adding space to tree manager
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
close(fx.treeManager.waitLoad)
require.NoError(t, err) require.NoError(t, err)
settingsObject := spc.(*space).settingsObject settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc) id := createTree(t, ctx, spc, acc)
@ -183,17 +185,19 @@ func TestSpaceDeleteIdsIncorrectSnapshot(t *testing.T) {
spc, err = fx.spaceService.NewSpace(ctx, sp) spc, err = fx.spaceService.NewSpace(ctx, sp)
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, spc) require.NotNil(t, spc)
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.space = spc fx.treeManager.space = spc
fx.treeManager.deletedIds = nil fx.treeManager.deletedIds = nil
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted // waiting until everything is deleted
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)
require.Equal(t, len(ids), len(fx.treeManager.deletedIds)) require.Equal(t, len(ids), len(fx.treeManager.deletedIds))
// checking that new snapshot will contain all the changes // checking that new snapshot will contain all the changes
settingsObject = spc.(*space).settingsObject settingsObject = spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
settings.DoSnapshot = func(treeLen int) bool { settings.DoSnapshot = func(treeLen int) bool {
return true return true
} }
@ -230,8 +234,9 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
fx.treeManager.space = spc fx.treeManager.space = spc
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
settingsObject := spc.(*space).settingsObject settingsObject := spc.(*space).app.MustComponent(settings.CName).(settings.Settings).SettingsObject()
var ids []string var ids []string
for i := 0; i < totalObjs; i++ { for i := 0; i < totalObjs; i++ {
id := createTree(t, ctx, spc, acc) id := createTree(t, ctx, spc, acc)
@ -259,10 +264,12 @@ func TestSpaceDeleteIdsMarkDeleted(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, spc) require.NotNil(t, spc)
fx.treeManager.space = spc fx.treeManager.space = spc
fx.treeManager.waitLoad = make(chan struct{})
fx.treeManager.deletedIds = nil fx.treeManager.deletedIds = nil
fx.treeManager.markedIds = nil fx.treeManager.markedIds = nil
err = spc.Init(ctx) err = spc.Init(ctx)
require.NoError(t, err) require.NoError(t, err)
close(fx.treeManager.waitLoad)
// waiting until everything is deleted // waiting until everything is deleted
time.Sleep(3 * time.Second) time.Sleep(3 * time.Second)

View File

@ -1,16 +1,22 @@
//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate ObjectDeletionState,StateBuilder,ChangeFactory //go:generate mockgen -destination mock_deletionstate/mock_deletionstate.go github.com/anyproto/any-sync/commonspace/deletionstate ObjectDeletionState
package settingsstate package deletionstate
import ( import (
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"go.uber.org/zap" "go.uber.org/zap"
"sync" "sync"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.deletionstate"
type StateUpdateObserver func(ids []string) type StateUpdateObserver func(ids []string)
type ObjectDeletionState interface { type ObjectDeletionState interface {
app.Component
AddObserver(observer StateUpdateObserver) AddObserver(observer StateUpdateObserver)
Add(ids map[string]struct{}) Add(ids map[string]struct{})
GetQueued() (ids []string) GetQueued() (ids []string)
@ -28,12 +34,20 @@ type objectDeletionState struct {
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
} }
func NewObjectDeletionState(log logger.CtxLogger, storage spacestorage.SpaceStorage) ObjectDeletionState { func (st *objectDeletionState) Init(a *app.App) (err error) {
st.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
return nil
}
func (st *objectDeletionState) Name() (name string) {
return CName
}
func New() ObjectDeletionState {
return &objectDeletionState{ return &objectDeletionState{
log: log, log: log,
queued: map[string]struct{}{}, queued: map[string]struct{}{},
deleted: map[string]struct{}{}, deleted: map[string]struct{}{},
storage: storage,
} }
} }

View File

@ -1,11 +1,10 @@
package settingsstate package deletionstate
import ( import (
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sort" "sort"
"testing" "testing"
) )
@ -19,7 +18,8 @@ type fixture struct {
func newFixture(t *testing.T) *fixture { func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl) spaceStorage := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := NewObjectDeletionState(logger.NewNamed("test"), spaceStorage).(*objectDeletionState) delState := New().(*objectDeletionState)
delState.storage = spaceStorage
return &fixture{ return &fixture{
ctrl: ctrl, ctrl: ctrl,
delState: delState, delState: delState,

View File

@ -0,0 +1,144 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/deletionstate (interfaces: ObjectDeletionState)
// Package mock_deletionstate is a generated GoMock package.
package mock_deletionstate
import (
reflect "reflect"
app "github.com/anyproto/any-sync/app"
deletionstate "github.com/anyproto/any-sync/commonspace/deletionstate"
gomock "go.uber.org/mock/gomock"
)
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
type MockObjectDeletionState struct {
ctrl *gomock.Controller
recorder *MockObjectDeletionStateMockRecorder
}
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
type MockObjectDeletionStateMockRecorder struct {
mock *MockObjectDeletionState
}
// NewMockObjectDeletionState creates a new mock instance.
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
mock := &MockObjectDeletionState{ctrl: ctrl}
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0)
}
// Add indicates an expected call of Add.
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
}
// AddObserver mocks base method.
func (m *MockObjectDeletionState) AddObserver(arg0 deletionstate.StateUpdateObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddObserver", arg0)
}
// AddObserver indicates an expected call of AddObserver.
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
}
// Delete mocks base method.
func (m *MockObjectDeletionState) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
}
// Exists mocks base method.
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exists", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Exists indicates an expected call of Exists.
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
}
// Filter mocks base method.
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Filter", arg0)
ret0, _ := ret[0].([]string)
return ret0
}
// Filter indicates an expected call of Filter.
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
}
// GetQueued mocks base method.
func (m *MockObjectDeletionState) GetQueued() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetQueued")
ret0, _ := ret[0].([]string)
return ret0
}
// GetQueued indicates an expected call of GetQueued.
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
}
// Init mocks base method.
func (m *MockObjectDeletionState) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockObjectDeletionStateMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectDeletionState)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockObjectDeletionState) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockObjectDeletionStateMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectDeletionState)(nil).Name))
}

View File

@ -3,49 +3,45 @@ package headsync
import ( import (
"context" "context"
"fmt" "fmt"
"time"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap" "go.uber.org/zap"
"time"
) )
type DiffSyncer interface { type DiffSyncer interface {
Sync(ctx context.Context) error Sync(ctx context.Context) error
RemoveObjects(ids []string) RemoveObjects(ids []string)
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
Init(deletionState settingsstate.ObjectDeletionState) Init()
Close() error Close() error
} }
func newDiffSyncer( func newDiffSyncer(hs *headSync) DiffSyncer {
spaceId string,
diff ldiff.Diff,
peerManager peermanager.PeerManager,
cache treemanager.TreeManager,
storage spacestorage.SpaceStorage,
clientFactory spacesyncproto.ClientFactory,
syncStatus syncstatus.StatusUpdater,
credentialProvider credentialprovider.CredentialProvider,
log logger.CtxLogger) DiffSyncer {
return &diffSyncer{ return &diffSyncer{
diff: diff, diff: hs.diff,
spaceId: spaceId, spaceId: hs.spaceId,
treeManager: cache, treeManager: hs.treeManager,
storage: storage, storage: hs.storage,
peerManager: peerManager, peerManager: hs.peerManager,
clientFactory: clientFactory, clientFactory: spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient),
credentialProvider: credentialProvider, credentialProvider: hs.credentialProvider,
log: log, log: log,
syncStatus: syncStatus, syncStatus: hs.syncStatus,
deletionState: hs.deletionState,
syncAcl: hs.syncAcl,
} }
} }
@ -57,14 +53,14 @@ type diffSyncer struct {
storage spacestorage.SpaceStorage storage spacestorage.SpaceStorage
clientFactory spacesyncproto.ClientFactory clientFactory spacesyncproto.ClientFactory
log logger.CtxLogger log logger.CtxLogger
deletionState settingsstate.ObjectDeletionState deletionState deletionstate.ObjectDeletionState
credentialProvider credentialprovider.CredentialProvider credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusUpdater syncStatus syncstatus.StatusUpdater
treeSyncer treemanager.TreeSyncer treeSyncer treemanager.TreeSyncer
syncAcl syncacl.SyncAcl
} }
func (d *diffSyncer) Init(deletionState settingsstate.ObjectDeletionState) { func (d *diffSyncer) Init() {
d.deletionState = deletionState
d.deletionState.AddObserver(d.RemoveObjects) d.deletionState.AddObserver(d.RemoveObjects)
d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager) d.treeSyncer = d.treeManager.NewTreeSyncer(d.spaceId, d.treeManager)
} }
@ -115,10 +111,17 @@ func (d *diffSyncer) Sync(ctx context.Context) error {
func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) { func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error) {
ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id())) ctx = logger.CtxWithFields(ctx, zap.String("peerId", p.Id()))
conn, err := p.AcquireDrpcConn(ctx)
if err != nil {
return
}
defer p.ReleaseDrpcConn(conn)
var ( var (
cl = d.clientFactory.Client(p) cl = d.clientFactory.Client(conn)
rdiff = NewRemoteDiff(d.spaceId, cl) rdiff = NewRemoteDiff(d.spaceId, cl)
stateCounter = d.syncStatus.StateCounter() stateCounter = d.syncStatus.StateCounter()
syncAclId = d.syncAcl.Id()
) )
newIds, changedIds, removedIds, err := d.diff.Diff(ctx, rdiff) newIds, changedIds, removedIds, err := d.diff.Diff(ctx, rdiff)
@ -141,17 +144,29 @@ func (d *diffSyncer) syncWithPeer(ctx context.Context, p peer.Peer) (err error)
// not syncing ids which were removed through settings document // not syncing ids which were removed through settings document
missingIds := d.deletionState.Filter(newIds) missingIds := d.deletionState.Filter(newIds)
existingIds := append(d.deletionState.Filter(removedIds), d.deletionState.Filter(changedIds)...) existingIds := append(d.deletionState.Filter(removedIds), d.deletionState.Filter(changedIds)...)
d.syncStatus.RemoveAllExcept(p.Id(), existingIds, stateCounter) d.syncStatus.RemoveAllExcept(p.Id(), existingIds, stateCounter)
prevExistingLen := len(existingIds)
existingIds = slice.DiscardFromSlice(existingIds, func(s string) bool {
return s == syncAclId
})
// if we removed acl head from the list
if len(existingIds) < prevExistingLen {
if syncErr := d.syncAcl.SyncWithPeer(ctx, p.Id()); syncErr != nil {
log.Warn("failed to send acl sync message to peer", zap.String("aclId", syncAclId))
}
}
// treeSyncer should not get acl id, that's why we filter existing ids before
err = d.treeSyncer.SyncAll(ctx, p.Id(), existingIds, missingIds) err = d.treeSyncer.SyncAll(ctx, p.Id(), existingIds, missingIds)
if err != nil { if err != nil {
return err return err
} }
d.log.Info("sync done:", zap.Int("newIds", len(newIds)), d.log.Info("sync done:",
zap.Int("newIds", len(newIds)),
zap.Int("changedIds", len(changedIds)), zap.Int("changedIds", len(changedIds)),
zap.Int("removedIds", len(removedIds)), zap.Int("removedIds", len(removedIds)),
zap.Int("already deleted ids", totalLen-len(existingIds)-len(missingIds)), zap.Int("already deleted ids", totalLen-prevExistingLen-len(missingIds)),
zap.String("peerId", p.Id()), zap.String("peerId", p.Id()),
) )
return return

View File

@ -4,28 +4,19 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
mock_treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer"
"github.com/golang/mock/gomock"
"github.com/libp2p/go-libp2p/core/sec"
"github.com/stretchr/testify/require"
"storj.io/drpc"
"testing" "testing"
"time" "time"
"github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage/mock_liststorage"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"storj.io/drpc"
) )
type pushSpaceRequestMatcher struct { type pushSpaceRequestMatcher struct {
@ -36,60 +27,6 @@ type pushSpaceRequestMatcher struct {
spaceHeader *spacesyncproto.RawSpaceHeaderWithId spaceHeader *spacesyncproto.RawSpaceHeaderWithId
} }
func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
res, ok := x.(*spacesyncproto.SpacePushRequest)
if !ok {
return false
}
return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
}
func (p pushSpaceRequestMatcher) String() string {
return ""
}
type mockPeer struct{}
func (m mockPeer) Addr() string {
return ""
}
func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return true, m.Close()
}
func (m mockPeer) Id() string {
return "mockId"
}
func (m mockPeer) LastUsage() time.Time {
return time.Time{}
}
func (m mockPeer) Secure() sec.SecureConn {
return nil
}
func (m mockPeer) UpdateLastUsage() {
}
func (m mockPeer) Close() error {
return nil
}
func (m mockPeer) Closed() <-chan struct{} {
return make(chan struct{})
}
func (m mockPeer) Invoke(ctx context.Context, rpc string, enc drpc.Encoding, in, out drpc.Message) error {
return nil
}
func (m mockPeer) NewStream(ctx context.Context, rpc string, enc drpc.Encoding) (drpc.Stream, error) {
return nil, nil
}
func newPushSpaceRequestMatcher( func newPushSpaceRequestMatcher(
spaceId string, spaceId string,
aclRootId string, aclRootId string,
@ -105,81 +42,159 @@ func newPushSpaceRequestMatcher(
} }
} }
func TestDiffSyncer_Sync(t *testing.T) { func (p pushSpaceRequestMatcher) Matches(x interface{}) bool {
// setup res, ok := x.(*spacesyncproto.SpacePushRequest)
ctx := context.Background() if !ok {
ctrl := gomock.NewController(t) return false
defer ctrl.Finish() }
diffMock := mock_ldiff.NewMockDiff(ctrl) return res.Payload.AclPayloadId == p.aclRootId && res.Payload.SpaceHeader == p.spaceHeader && res.Payload.SpaceSettingsPayloadId == p.settingsId && bytes.Equal(p.credential, res.Credential)
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl) }
cacheMock := mock_treemanager.NewMockTreeManager(ctrl)
stMock := mock_spacestorage.NewMockSpaceStorage(ctrl) func (p pushSpaceRequestMatcher) String() string {
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl) return ""
factory := spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient { }
return clientMock
type mockPeer struct {
}
func (m mockPeer) Id() string {
return "peerId"
}
func (m mockPeer) Context() context.Context {
return context.Background()
}
func (m mockPeer) AcquireDrpcConn(ctx context.Context) (drpc.Conn, error) {
return nil, nil
}
func (m mockPeer) ReleaseDrpcConn(conn drpc.Conn) {
return
}
func (m mockPeer) DoDrpc(ctx context.Context, do func(conn drpc.Conn) error) error {
return nil
}
func (m mockPeer) IsClosed() bool {
return false
}
func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) {
return false, err
}
func (m mockPeer) Close() (err error) {
return nil
}
func (fx *headSyncFixture) initDiffSyncer(t *testing.T) {
fx.init(t)
fx.diffSyncer = newDiffSyncer(fx.headSync).(*diffSyncer)
fx.diffSyncer.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
return fx.clientMock
}) })
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl) fx.deletionStateMock.EXPECT().AddObserver(gomock.Any())
credentialProvider := mock_credentialprovider.NewMockCredentialProvider(ctrl) fx.treeManagerMock.EXPECT().NewTreeSyncer(fx.spaceState.SpaceId, fx.treeManagerMock).Return(fx.treeSyncerMock)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) fx.diffSyncer.Init()
spaceId := "spaceId" }
aclRootId := "aclRootId"
l := logger.NewNamed(spaceId) func TestDiffSyncer(t *testing.T) {
diffSyncer := newDiffSyncer(spaceId, diffMock, peerManagerMock, cacheMock, stMock, factory, syncstatus.NewNoOpSyncStatus(), credentialProvider, l) ctx := context.Background()
delState.EXPECT().AddObserver(gomock.Any())
cacheMock.EXPECT().NewTreeSyncer(spaceId, gomock.Any()).Return(treeSyncerMock)
diffSyncer.Init(delState)
t.Run("diff syncer sync", func(t *testing.T) { t.Run("diff syncer sync", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{} mPeer := mockPeer{}
peerManagerMock.EXPECT(). fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil) Return([]peer.Peer{mPeer}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return([]string{"new"}, []string{"changed"}, nil, nil) Return([]string{"new"}, []string{"changed"}, nil, nil)
delState.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1) fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
delState.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1) fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed"}).Times(1)
delState.EXPECT().Filter(nil).Return(nil).Times(1) fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil) fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
})
t.Run("diff syncer sync, acl changed", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{}
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil)
fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return([]string{"new"}, []string{"changed"}, nil, nil)
fx.deletionStateMock.EXPECT().Filter([]string{"new"}).Return([]string{"new"}).Times(1)
fx.deletionStateMock.EXPECT().Filter([]string{"changed"}).Return([]string{"changed", "aclId"}).Times(1)
fx.deletionStateMock.EXPECT().Filter(nil).Return(nil).Times(1)
fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"changed"}, []string{"new"}).Return(nil)
fx.aclMock.EXPECT().SyncWithPeer(gomock.Any(), mPeer.Id()).Return(nil)
require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync conf error", func(t *testing.T) { t.Run("diff syncer sync conf error", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
ctx := context.Background()
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return(nil, fmt.Errorf("some error")) Return(nil, fmt.Errorf("some error"))
require.Error(t, diffSyncer.Sync(ctx)) require.Error(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("deletion state remove objects", func(t *testing.T) { t.Run("deletion state remove objects", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
deletedId := "id" deletedId := "id"
delState.EXPECT().Exists(deletedId).Return(true) fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.deletionStateMock.EXPECT().Exists(deletedId).Return(true)
// this should not result in any mock being called // this should not result in any mock being called
diffSyncer.UpdateHeads(deletedId, []string{"someHead"}) fx.diffSyncer.UpdateHeads(deletedId, []string{"someHead"})
}) })
t.Run("update heads updates diff", func(t *testing.T) { t.Run("update heads updates diff", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
newId := "newId" newId := "newId"
newHeads := []string{"h1", "h2"} newHeads := []string{"h1", "h2"}
hash := "hash" hash := "hash"
diffMock.EXPECT().Set(ldiff.Element{ fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.diffMock.EXPECT().Set(ldiff.Element{
Id: newId, Id: newId,
Head: concatStrings(newHeads), Head: concatStrings(newHeads),
}) })
diffMock.EXPECT().Hash().Return(hash) fx.diffMock.EXPECT().Hash().Return(hash)
delState.EXPECT().Exists(newId).Return(false) fx.deletionStateMock.EXPECT().Exists(newId).Return(false)
stMock.EXPECT().WriteSpaceHash(hash) fx.storageMock.EXPECT().WriteSpaceHash(hash)
diffSyncer.UpdateHeads(newId, newHeads) fx.diffSyncer.UpdateHeads(newId, newHeads)
}) })
t.Run("diff syncer sync space missing", func(t *testing.T) { t.Run("diff syncer sync space missing", func(t *testing.T) {
aclStorageMock := mock_liststorage.NewMockListStorage(ctrl) fx := newHeadSyncFixture(t)
settingsStorage := mock_treestorage.NewMockTreeStorage(ctrl) fx.initDiffSyncer(t)
defer fx.stop()
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
aclStorageMock := mock_liststorage.NewMockListStorage(fx.ctrl)
settingsStorage := mock_treestorage.NewMockTreeStorage(fx.ctrl)
settingsId := "settingsId" settingsId := "settingsId"
aclRoot := &aclrecordproto.RawAclRecordWithId{ aclRootId := "aclRootId"
aclRoot := &consensusproto.RawRecordWithId{
Id: aclRootId, Id: aclRootId,
} }
settingsRoot := &treechangeproto.RawTreeChangeWithId{ settingsRoot := &treechangeproto.RawTreeChangeWithId{
@ -189,55 +204,63 @@ func TestDiffSyncer_Sync(t *testing.T) {
spaceSettingsId := "spaceSettingsId" spaceSettingsId := "spaceSettingsId"
credential := []byte("credential") credential := []byte("credential")
peerManagerMock.EXPECT(). fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing) Return(nil, nil, nil, spacesyncproto.ErrSpaceMissing)
stMock.EXPECT().AclStorage().Return(aclStorageMock, nil) fx.storageMock.EXPECT().AclStorage().Return(aclStorageMock, nil)
stMock.EXPECT().SpaceHeader().Return(spaceHeader, nil) fx.storageMock.EXPECT().SpaceHeader().Return(spaceHeader, nil)
stMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId) fx.storageMock.EXPECT().SpaceSettingsId().Return(spaceSettingsId)
stMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil) fx.storageMock.EXPECT().TreeStorage(spaceSettingsId).Return(settingsStorage, nil)
settingsStorage.EXPECT().Root().Return(settingsRoot, nil) settingsStorage.EXPECT().Root().Return(settingsRoot, nil)
aclStorageMock.EXPECT(). aclStorageMock.EXPECT().
Root(). Root().
Return(aclRoot, nil) Return(aclRoot, nil)
credentialProvider.EXPECT(). fx.credentialProviderMock.EXPECT().
GetCredential(gomock.Any(), spaceHeader). GetCredential(gomock.Any(), spaceHeader).
Return(credential, nil) Return(credential, nil)
clientMock.EXPECT(). fx.clientMock.EXPECT().
SpacePush(gomock.Any(), newPushSpaceRequestMatcher(spaceId, aclRootId, settingsId, credential, spaceHeader)). SpacePush(gomock.Any(), newPushSpaceRequestMatcher(fx.spaceState.SpaceId, aclRootId, settingsId, credential, spaceHeader)).
Return(nil, nil) Return(nil, nil)
peerManagerMock.EXPECT().SendPeer(gomock.Any(), "mockId", gomock.Any()) fx.peerManagerMock.EXPECT().SendPeer(gomock.Any(), "peerId", gomock.Any())
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync unexpected", func(t *testing.T) { t.Run("diff syncer sync unexpected", func(t *testing.T) {
peerManagerMock.EXPECT(). fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mockPeer{}}, nil) Return([]peer.Peer{mockPeer{}}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrUnexpected) Return(nil, nil, nil, spacesyncproto.ErrUnexpected)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
t.Run("diff syncer sync space is deleted error", func(t *testing.T) { t.Run("diff syncer sync space is deleted error", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.initDiffSyncer(t)
defer fx.stop()
mPeer := mockPeer{} mPeer := mockPeer{}
peerManagerMock.EXPECT(). fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.peerManagerMock.EXPECT().
GetResponsiblePeers(gomock.Any()). GetResponsiblePeers(gomock.Any()).
Return([]peer.Peer{mPeer}, nil) Return([]peer.Peer{mPeer}, nil)
diffMock.EXPECT(). fx.diffMock.EXPECT().
Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(spaceId, clientMock))). Diff(gomock.Any(), gomock.Eq(NewRemoteDiff(fx.spaceState.SpaceId, fx.clientMock))).
Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted) Return(nil, nil, nil, spacesyncproto.ErrSpaceIsDeleted)
stMock.EXPECT().SpaceSettingsId().Return("settingsId") fx.storageMock.EXPECT().SpaceSettingsId().Return("settingsId")
treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil) fx.treeSyncerMock.EXPECT().SyncAll(gomock.Any(), mPeer.Id(), []string{"settingsId"}, nil).Return(nil)
require.NoError(t, diffSyncer.Sync(ctx)) require.NoError(t, fx.diffSyncer.Sync(ctx))
}) })
} }

View File

@ -3,123 +3,150 @@ package headsync
import ( import (
"context" "context"
"sync/atomic"
"time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
config2 "github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/periodicsync" "github.com/anyproto/any-sync/util/periodicsync"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"strings"
"sync/atomic"
"time"
) )
var log = logger.NewNamed(CName)
const CName = "common.commonspace.headsync"
type TreeHeads struct { type TreeHeads struct {
Id string Id string
Heads []string Heads []string
} }
type HeadSync interface { type HeadSync interface {
Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) app.ComponentRunnable
ExternalIds() []string
DebugAllHeads() (res []TreeHeads)
AllIds() []string
UpdateHeads(id string, heads []string) UpdateHeads(id string, heads []string)
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
RemoveObjects(ids []string) RemoveObjects(ids []string)
AllIds() []string
DebugAllHeads() (res []TreeHeads)
Close() (err error)
} }
type headSync struct { type headSync struct {
spaceId string spaceId string
periodicSync periodicsync.PeriodicSync
storage spacestorage.SpaceStorage
diff ldiff.Diff
log logger.CtxLogger
syncer DiffSyncer
configuration nodeconf.NodeConf
spaceIsDeleted *atomic.Bool spaceIsDeleted *atomic.Bool
syncPeriod int
syncPeriod int periodicSync periodicsync.PeriodicSync
storage spacestorage.SpaceStorage
diff ldiff.Diff
log logger.CtxLogger
syncer DiffSyncer
configuration nodeconf.NodeConf
peerManager peermanager.PeerManager
treeManager treemanager.TreeManager
credentialProvider credentialprovider.CredentialProvider
syncStatus syncstatus.StatusService
deletionState deletionstate.ObjectDeletionState
syncAcl syncacl.SyncAcl
} }
func NewHeadSync( func New() HeadSync {
spaceId string, return &headSync{}
spaceIsDeleted *atomic.Bool, }
syncPeriod int,
configuration nodeconf.NodeConf,
storage spacestorage.SpaceStorage,
peerManager peermanager.PeerManager,
cache treemanager.TreeManager,
syncStatus syncstatus.StatusUpdater,
credentialProvider credentialprovider.CredentialProvider,
log logger.CtxLogger) HeadSync {
diff := ldiff.New(16, 16) var createDiffSyncer = newDiffSyncer
l := log.With(zap.String("spaceId", spaceId))
factory := spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient) func (h *headSync) Init(a *app.App) (err error) {
syncer := newDiffSyncer(spaceId, diff, peerManager, cache, storage, factory, syncStatus, credentialProvider, l) shared := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
cfg := a.MustComponent("config").(config2.ConfigGetter)
h.syncAcl = a.MustComponent(syncacl.CName).(syncacl.SyncAcl)
h.spaceId = shared.SpaceId
h.spaceIsDeleted = shared.SpaceIsDeleted
h.syncPeriod = cfg.GetSpace().SyncPeriod
h.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
h.log = log.With(zap.String("spaceId", h.spaceId))
h.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
h.diff = ldiff.New(16, 16)
h.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
h.credentialProvider = a.MustComponent(credentialprovider.CName).(credentialprovider.CredentialProvider)
h.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
h.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
h.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
h.syncer = createDiffSyncer(h)
sync := func(ctx context.Context) (err error) { sync := func(ctx context.Context) (err error) {
// for clients cancelling the sync process // for clients cancelling the sync process
if spaceIsDeleted.Load() && !configuration.IsResponsible(spaceId) { if h.spaceIsDeleted.Load() && !h.configuration.IsResponsible(h.spaceId) {
return spacesyncproto.ErrSpaceIsDeleted return spacesyncproto.ErrSpaceIsDeleted
} }
return syncer.Sync(ctx) return h.syncer.Sync(ctx)
}
periodicSync := periodicsync.NewPeriodicSync(syncPeriod, time.Minute, sync, l)
return &headSync{
spaceId: spaceId,
storage: storage,
syncer: syncer,
periodicSync: periodicSync,
diff: diff,
log: log,
syncPeriod: syncPeriod,
configuration: configuration,
spaceIsDeleted: spaceIsDeleted,
} }
h.periodicSync = periodicsync.NewPeriodicSync(h.syncPeriod, time.Minute, sync, h.log)
h.syncAcl.SetHeadUpdater(h)
// TODO: move to run?
h.syncer.Init()
return nil
} }
func (d *headSync) Init(objectIds []string, deletionState settingsstate.ObjectDeletionState) { func (h *headSync) Name() (name string) {
d.fillDiff(objectIds) return CName
d.syncer.Init(deletionState)
d.periodicSync.Run()
} }
func (d *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) { func (h *headSync) Run(ctx context.Context) (err error) {
if d.spaceIsDeleted.Load() { initialIds, err := h.storage.StoredIds()
if err != nil {
return
}
h.fillDiff(initialIds)
h.periodicSync.Run()
return
}
func (h *headSync) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
if h.spaceIsDeleted.Load() {
peerId, err := peer.CtxPeerId(ctx) peerId, err := peer.CtxPeerId(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// stop receiving all request for sync from clients // stop receiving all request for sync from clients
if !slices.Contains(d.configuration.NodeIds(d.spaceId), peerId) { if !slices.Contains(h.configuration.NodeIds(h.spaceId), peerId) {
return nil, spacesyncproto.ErrSpaceIsDeleted return nil, spacesyncproto.ErrSpaceIsDeleted
} }
} }
return HandleRangeRequest(ctx, d.diff, req) return HandleRangeRequest(ctx, h.diff, req)
} }
func (d *headSync) UpdateHeads(id string, heads []string) { func (h *headSync) UpdateHeads(id string, heads []string) {
d.syncer.UpdateHeads(id, heads) h.syncer.UpdateHeads(id, heads)
} }
func (d *headSync) AllIds() []string { func (h *headSync) AllIds() []string {
return d.diff.Ids() return h.diff.Ids()
} }
func (d *headSync) DebugAllHeads() (res []TreeHeads) { func (h *headSync) ExternalIds() []string {
els := d.diff.Elements() settingsId := h.storage.SpaceSettingsId()
return slice.DiscardFromSlice(h.AllIds(), func(id string) bool {
return id == settingsId
})
}
func (h *headSync) DebugAllHeads() (res []TreeHeads) {
els := h.diff.Elements()
for _, el := range els { for _, el := range els {
idHead := TreeHeads{ idHead := TreeHeads{
Id: el.Id, Id: el.Id,
@ -130,19 +157,19 @@ func (d *headSync) DebugAllHeads() (res []TreeHeads) {
return return
} }
func (d *headSync) RemoveObjects(ids []string) { func (h *headSync) RemoveObjects(ids []string) {
d.syncer.RemoveObjects(ids) h.syncer.RemoveObjects(ids)
} }
func (d *headSync) Close() (err error) { func (h *headSync) Close(ctx context.Context) (err error) {
d.periodicSync.Close() h.periodicSync.Close()
return d.syncer.Close() return h.syncer.Close()
} }
func (d *headSync) fillDiff(objectIds []string) { func (h *headSync) fillDiff(objectIds []string) {
var els = make([]ldiff.Element, 0, len(objectIds)) var els = make([]ldiff.Element, 0, len(objectIds))
for _, id := range objectIds { for _, id := range objectIds {
st, err := d.storage.TreeStorage(id) st, err := h.storage.TreeStorage(id)
if err != nil { if err != nil {
continue continue
} }
@ -155,32 +182,12 @@ func (d *headSync) fillDiff(objectIds []string) {
Head: concatStrings(heads), Head: concatStrings(heads),
}) })
} }
d.diff.Set(els...) els = append(els, ldiff.Element{
if err := d.storage.WriteSpaceHash(d.diff.Hash()); err != nil { Id: h.syncAcl.Id(),
d.log.Error("can't write space hash", zap.Error(err)) Head: h.syncAcl.Head().Id,
})
h.diff.Set(els...)
if err := h.storage.WriteSpaceHash(h.diff.Hash()); err != nil {
h.log.Error("can't write space hash", zap.Error(err))
} }
} }
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

View File

@ -1,71 +1,190 @@
package headsync package headsync
import ( import (
"context"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/ldiff" "github.com/anyproto/any-sync/app/ldiff"
"github.com/anyproto/any-sync/app/ldiff/mock_ldiff" "github.com/anyproto/any-sync/app/ldiff/mock_ldiff"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/credentialprovider/mock_credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync/mock_headsync" "github.com/anyproto/any-sync/commonspace/headsync/mock_headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage/mock_treestorage"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/anyproto/any-sync/util/periodicsync/mock_periodicsync" "github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/golang/mock/gomock" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/nodeconf/mock_nodeconf"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sync/atomic"
"testing" "testing"
) )
func TestDiffService(t *testing.T) { type mockConfig struct {
ctrl := gomock.NewController(t) }
defer ctrl.Finish()
spaceId := "spaceId" func (m mockConfig) Init(a *app.App) (err error) {
l := logger.NewNamed("sync") return nil
pSyncMock := mock_periodicsync.NewMockPeriodicSync(ctrl) }
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
treeStorageMock := mock_treestorage.NewMockTreeStorage(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
syncer := mock_headsync.NewMockDiffSyncer(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl)
syncPeriod := 1
initId := "initId"
service := &headSync{ func (m mockConfig) Name() (name string) {
spaceId: spaceId, return "config"
storage: storageMock, }
periodicSync: pSyncMock,
syncer: syncer, func (m mockConfig) GetSpace() config.Config {
diff: diffMock, return config.Config{}
log: l, }
syncPeriod: syncPeriod,
type headSyncFixture struct {
spaceState *spacestate.SpaceState
ctrl *gomock.Controller
app *app.App
configurationMock *mock_nodeconf.MockService
storageMock *mock_spacestorage.MockSpaceStorage
peerManagerMock *mock_peermanager.MockPeerManager
credentialProviderMock *mock_credentialprovider.MockCredentialProvider
syncStatus syncstatus.StatusService
treeManagerMock *mock_treemanager.MockTreeManager
deletionStateMock *mock_deletionstate.MockObjectDeletionState
diffSyncerMock *mock_headsync.MockDiffSyncer
treeSyncerMock *mock_treemanager.MockTreeSyncer
diffMock *mock_ldiff.MockDiff
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
aclMock *mock_syncacl.MockSyncAcl
headSync *headSync
diffSyncer *diffSyncer
}
func newHeadSyncFixture(t *testing.T) *headSyncFixture {
spaceState := &spacestate.SpaceState{
SpaceId: "spaceId",
SpaceIsDeleted: &atomic.Bool{},
} }
ctrl := gomock.NewController(t)
configurationMock := mock_nodeconf.NewMockService(ctrl)
configurationMock.EXPECT().Name().AnyTimes().Return(nodeconf.CName)
storageMock := mock_spacestorage.NewMockSpaceStorage(ctrl)
storageMock.EXPECT().Name().AnyTimes().Return(spacestorage.CName)
peerManagerMock := mock_peermanager.NewMockPeerManager(ctrl)
peerManagerMock.EXPECT().Name().AnyTimes().Return(peermanager.CName)
credentialProviderMock := mock_credentialprovider.NewMockCredentialProvider(ctrl)
credentialProviderMock.EXPECT().Name().AnyTimes().Return(credentialprovider.CName)
syncStatus := syncstatus.NewNoOpSyncStatus()
treeManagerMock := mock_treemanager.NewMockTreeManager(ctrl)
treeManagerMock.EXPECT().Name().AnyTimes().Return(treemanager.CName)
deletionStateMock := mock_deletionstate.NewMockObjectDeletionState(ctrl)
deletionStateMock.EXPECT().Name().AnyTimes().Return(deletionstate.CName)
diffSyncerMock := mock_headsync.NewMockDiffSyncer(ctrl)
treeSyncerMock := mock_treemanager.NewMockTreeSyncer(ctrl)
diffMock := mock_ldiff.NewMockDiff(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
aclMock := mock_syncacl.NewMockSyncAcl(ctrl)
aclMock.EXPECT().Name().AnyTimes().Return(syncacl.CName)
aclMock.EXPECT().SetHeadUpdater(gomock.Any()).AnyTimes()
hs := &headSync{}
a := &app.App{}
a.Register(spaceState).
Register(aclMock).
Register(mockConfig{}).
Register(configurationMock).
Register(storageMock).
Register(peerManagerMock).
Register(credentialProviderMock).
Register(syncStatus).
Register(treeManagerMock).
Register(deletionStateMock).
Register(hs)
return &headSyncFixture{
spaceState: spaceState,
ctrl: ctrl,
app: a,
configurationMock: configurationMock,
storageMock: storageMock,
peerManagerMock: peerManagerMock,
credentialProviderMock: credentialProviderMock,
syncStatus: syncStatus,
treeManagerMock: treeManagerMock,
deletionStateMock: deletionStateMock,
headSync: hs,
diffSyncerMock: diffSyncerMock,
treeSyncerMock: treeSyncerMock,
diffMock: diffMock,
clientMock: clientMock,
aclMock: aclMock,
}
}
t.Run("init", func(t *testing.T) { func (fx *headSyncFixture) init(t *testing.T) {
storageMock.EXPECT().TreeStorage(initId).Return(treeStorageMock, nil) createDiffSyncer = func(hs *headSync) DiffSyncer {
treeStorageMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil) return fx.diffSyncerMock
syncer.EXPECT().Init(delState) }
diffMock.EXPECT().Set(ldiff.Element{ fx.diffSyncerMock.EXPECT().Init()
Id: initId, err := fx.headSync.Init(fx.app)
require.NoError(t, err)
fx.headSync.diff = fx.diffMock
}
func (fx *headSyncFixture) stop() {
fx.ctrl.Finish()
}
func TestHeadSync(t *testing.T) {
ctx := context.Background()
t.Run("run close", func(t *testing.T) {
fx := newHeadSyncFixture(t)
fx.init(t)
defer fx.stop()
ids := []string{"id1"}
treeMock := mock_treestorage.NewMockTreeStorage(fx.ctrl)
fx.storageMock.EXPECT().StoredIds().Return(ids, nil)
fx.storageMock.EXPECT().TreeStorage(ids[0]).Return(treeMock, nil)
fx.aclMock.EXPECT().Id().AnyTimes().Return("aclId")
fx.aclMock.EXPECT().Head().AnyTimes().Return(&list.AclRecord{Id: "headId"})
treeMock.EXPECT().Heads().Return([]string{"h1", "h2"}, nil)
fx.diffMock.EXPECT().Set(ldiff.Element{
Id: "id1",
Head: "h1h2", Head: "h1h2",
}) })
hash := "123" fx.diffMock.EXPECT().Hash().Return("hash")
diffMock.EXPECT().Hash().Return(hash) fx.storageMock.EXPECT().WriteSpaceHash("hash").Return(nil)
storageMock.EXPECT().WriteSpaceHash(hash) fx.diffSyncerMock.EXPECT().Sync(gomock.Any()).Return(nil)
pSyncMock.EXPECT().Run() fx.diffSyncerMock.EXPECT().Close().Return(nil)
service.Init([]string{initId}, delState) err := fx.headSync.Run(ctx)
require.NoError(t, err)
err = fx.headSync.Close(ctx)
require.NoError(t, err)
}) })
t.Run("update heads", func(t *testing.T) { t.Run("update heads", func(t *testing.T) {
syncer.EXPECT().UpdateHeads(initId, []string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.UpdateHeads(initId, []string{"h1", "h2"}) fx.init(t)
defer fx.stop()
fx.diffSyncerMock.EXPECT().UpdateHeads("id1", []string{"h1"})
fx.headSync.UpdateHeads("id1", []string{"h1"})
}) })
t.Run("remove objects", func(t *testing.T) { t.Run("remove objects", func(t *testing.T) {
syncer.EXPECT().RemoveObjects([]string{"h1", "h2"}) fx := newHeadSyncFixture(t)
service.RemoveObjects([]string{"h1", "h2"}) fx.init(t)
}) defer fx.stop()
t.Run("close", func(t *testing.T) { fx.diffSyncerMock.EXPECT().RemoveObjects([]string{"id1"})
pSyncMock.EXPECT().Close() fx.headSync.RemoveObjects([]string{"id1"})
syncer.EXPECT().Close()
service.Close()
}) })
} }

View File

@ -8,8 +8,7 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate" gomock "go.uber.org/mock/gomock"
gomock "github.com/golang/mock/gomock"
) )
// MockDiffSyncer is a mock of DiffSyncer interface. // MockDiffSyncer is a mock of DiffSyncer interface.
@ -50,15 +49,15 @@ func (mr *MockDiffSyncerMockRecorder) Close() *gomock.Call {
} }
// Init mocks base method. // Init mocks base method.
func (m *MockDiffSyncer) Init(arg0 settingsstate.ObjectDeletionState) { func (m *MockDiffSyncer) Init() {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Init", arg0) m.ctrl.Call(m, "Init")
} }
// Init indicates an expected call of Init. // Init indicates an expected call of Init.
func (mr *MockDiffSyncerMockRecorder) Init(arg0 interface{}) *gomock.Call { func (mr *MockDiffSyncerMockRecorder) Init() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockDiffSyncer)(nil).Init))
} }
// RemoveObjects mocks base method. // RemoveObjects mocks base method.

View File

@ -0,0 +1,27 @@
package headsync
import "strings"
func concatStrings(strs []string) string {
var (
b strings.Builder
totalLen int
)
for _, s := range strs {
totalLen += len(s)
}
b.Grow(totalLen)
for _, s := range strs {
b.WriteString(s)
}
return b.String()
}
func splitString(str string) (res []string) {
const cidLen = 59
for i := 0; i < len(str); i += cidLen {
res = append(res, str[i:i+cidLen])
}
return
}

File diff suppressed because it is too large Load Diff

View File

@ -2,26 +2,7 @@ syntax = "proto3";
package aclrecord; package aclrecord;
option go_package = "commonspace/object/acl/aclrecordproto"; option go_package = "commonspace/object/acl/aclrecordproto";
message RawAclRecord { // AclRoot is a root of access control list
bytes payload = 1;
bytes signature = 2;
bytes acceptorIdentity = 3;
bytes acceptorSignature = 4;
}
message RawAclRecordWithId {
bytes payload = 1;
string id = 2;
}
message AclRecord {
string prevId = 1;
bytes identity = 2;
bytes data = 3;
string readKeyId = 4;
int64 timestamp = 5;
}
message AclRoot { message AclRoot {
bytes identity = 1; bytes identity = 1;
bytes masterKey = 2; bytes masterKey = 2;
@ -31,82 +12,95 @@ message AclRoot {
bytes identitySignature = 6; bytes identitySignature = 6;
} }
message AclContentValue { // AclAccountInvite contains the public invite key, the private part of which is sent to the user directly
oneof value { message AclAccountInvite {
AclUserAdd userAdd = 1; bytes inviteKey = 1;
AclUserRemove userRemove = 2;
AclUserPermissionChange userPermissionChange = 3;
AclUserInvite userInvite = 4;
AclUserJoin userJoin = 5;
}
} }
message AclData { // AclAccountRequestJoin contains the reference to the invite record and the data of the person who wants to join, confirmed by the private invite key
repeated AclContentValue aclContent = 1; message AclAccountRequestJoin {
bytes inviteIdentity = 1;
string inviteRecordId = 2;
bytes inviteIdentitySignature = 3;
bytes metadata = 4;
} }
message AclState { // AclAccountRequestAccept contains the reference to join record and all read keys, encrypted with the identity of the requestor
repeated string readKeyIds = 1; message AclAccountRequestAccept {
repeated AclUserState userStates = 2;
map<string, AclUserInvite> invites = 3;
}
message AclUserState {
bytes identity = 1; bytes identity = 1;
AclUserPermissions permissions = 2; string requestRecordId = 2;
repeated AclReadKeyWithRecord encryptedReadKeys = 3;
AclUserPermissions permissions = 4;
} }
message AclUserAdd { // AclAccountRequestDecline contains the reference to join record
bytes identity = 1; message AclAccountRequestDecline {
repeated bytes encryptedReadKeys = 2; string requestRecordId = 1;
AclUserPermissions permissions = 3;
} }
message AclUserInvite { // AclAccountInviteRevoke revokes the invite record
bytes acceptPublicKey = 1; message AclAccountInviteRevoke {
repeated bytes encryptedReadKeys = 2; string inviteRecordId = 1;
AclUserPermissions permissions = 3;
} }
message AclUserJoin { // AclReadKeys are a read key with record id
bytes identity = 1; message AclReadKeyWithRecord {
bytes acceptSignature = 2; string recordId = 1;
bytes acceptPubKey = 3; bytes encryptedReadKey = 2;
repeated bytes encryptedReadKeys = 4;
} }
message AclUserRemove { // AclEncryptedReadKeys are new key for specific identity
bytes identity = 1; message AclEncryptedReadKey {
repeated AclReadKeyReplace readKeyReplaces = 2;
}
message AclReadKeyReplace {
bytes identity = 1; bytes identity = 1;
bytes encryptedReadKey = 2; bytes encryptedReadKey = 2;
} }
message AclUserPermissionChange { // AclAccountPermissionChange changes permissions of specific account
message AclAccountPermissionChange {
bytes identity = 1; bytes identity = 1;
AclUserPermissions permissions = 2; AclUserPermissions permissions = 2;
} }
enum AclUserPermissions { // AclReadKeyChange changes the key for a space
Admin = 0; message AclReadKeyChange {
Writer = 1; repeated AclEncryptedReadKey accountKeys = 1;
Reader = 2;
} }
message AclSyncMessage { // AclAccountRemove removes an account and changes read key for space
AclSyncContentValue content = 1; message AclAccountRemove {
repeated bytes identities = 1;
repeated AclEncryptedReadKey accountKeys = 2;
} }
// AclSyncContentValue provides different types for acl sync // AclAccountRequestRemove adds a request to remove an account
message AclSyncContentValue { message AclAccountRequestRemove {
}
// AclContentValue contains possible values for Acl
message AclContentValue {
oneof value { oneof value {
AclAddRecords addRecords = 1; AclAccountInvite invite = 1;
AclAccountInviteRevoke inviteRevoke = 2;
AclAccountRequestJoin requestJoin = 3;
AclAccountRequestAccept requestAccept = 4;
AclAccountPermissionChange permissionChange = 5;
AclAccountRemove accountRemove = 6;
AclReadKeyChange readKeyChange = 7;
AclAccountRequestDecline requestDecline = 8;
AclAccountRequestRemove accountRequestRemove = 9;
} }
} }
message AclAddRecords { // AclData contains different acl content
repeated RawAclRecordWithId records = 1; message AclData {
} repeated AclContentValue aclContent = 1;
}
// AclUserPermissions contains different possible user roles
enum AclUserPermissions {
None = 0;
Owner = 1;
Admin = 2;
Writer = 3;
Reader = 4;
}

View File

@ -1,11 +1,14 @@
package list package list
import ( import (
"time"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil" "github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"time"
) )
type RootContent struct { type RootContent struct {
@ -15,26 +18,387 @@ type RootContent struct {
EncryptedReadKey []byte EncryptedReadKey []byte
} }
type RequestJoinPayload struct {
InviteRecordId string
InviteKey crypto.PrivKey
Metadata []byte
}
type RequestAcceptPayload struct {
RequestRecordId string
Permissions AclPermissions
}
type PermissionChangePayload struct {
Identity crypto.PubKey
Permissions AclPermissions
}
type AccountRemovePayload struct {
Identities []crypto.PubKey
ReadKey crypto.SymKey
}
type InviteResult struct {
InviteRec *consensusproto.RawRecord
InviteKey crypto.PrivKey
}
type AclRecordBuilder interface { type AclRecordBuilder interface {
Unmarshall(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error) UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error)
BuildRoot(content RootContent) (rec *aclrecordproto.RawAclRecordWithId, err error) Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error)
BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error)
BuildInvite() (res InviteResult, err error)
BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error)
BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error)
BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error)
BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error)
BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error)
} }
type aclRecordBuilder struct { type aclRecordBuilder struct {
id string id string
keyStorage crypto.KeyStorage keyStorage crypto.KeyStorage
accountKeys *accountdata.AccountKeys
verifier AcceptorVerifier
state *AclState
} }
func NewAclRecordBuilder(id string, keyStorage crypto.KeyStorage) AclRecordBuilder { func NewAclRecordBuilder(id string, keyStorage crypto.KeyStorage, keys *accountdata.AccountKeys, verifier AcceptorVerifier) AclRecordBuilder {
return &aclRecordBuilder{ return &aclRecordBuilder{
id: id, id: id,
keyStorage: keyStorage, keyStorage: keyStorage,
accountKeys: keys,
verifier: verifier,
} }
} }
func (a *aclRecordBuilder) Unmarshall(rawIdRecord *aclrecordproto.RawAclRecordWithId) (rec *AclRecord, err error) { func (a *aclRecordBuilder) buildRecord(aclContent *aclrecordproto.AclContentValue) (rawRec *consensusproto.RawRecord, err error) {
aclData := &aclrecordproto.AclData{AclContent: []*aclrecordproto.AclContentValue{
aclContent,
}}
marshalledData, err := aclData.Marshal()
if err != nil {
return
}
protoKey, err := a.accountKeys.SignKey.GetPublic().Marshall()
if err != nil {
return
}
rec := &consensusproto.Record{
PrevId: a.state.lastRecordId,
Identity: protoKey,
Data: marshalledData,
Timestamp: time.Now().Unix(),
}
marshalledRec, err := rec.Marshal()
if err != nil {
return
}
signature, err := a.accountKeys.SignKey.Sign(marshalledRec)
if err != nil {
return
}
rawRec = &consensusproto.RawRecord{
Payload: marshalledRec,
Signature: signature,
}
return
}
func (a *aclRecordBuilder) BuildInvite() (res InviteResult, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
privKey, pubKey, err := crypto.GenerateRandomEd25519KeyPair()
if err != nil {
return
}
invitePubKey, err := pubKey.Marshall()
if err != nil {
return
}
inviteRec := &aclrecordproto.AclAccountInvite{InviteKey: invitePubKey}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_Invite{Invite: inviteRec}}
rawRec, err := a.buildRecord(content)
if err != nil {
return
}
res.InviteKey = privKey
res.InviteRec = rawRec
return
}
func (a *aclRecordBuilder) BuildInviteRevoke(inviteRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
_, exists := a.state.inviteKeys[inviteRecordId]
if !exists {
err = ErrNoSuchInvite
return
}
revokeRec := &aclrecordproto.AclAccountInviteRevoke{InviteRecordId: inviteRecordId}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_InviteRevoke{InviteRevoke: revokeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestJoin(payload RequestJoinPayload) (rawRecord *consensusproto.RawRecord, err error) {
key, exists := a.state.inviteKeys[payload.InviteRecordId]
if !exists {
err = ErrNoSuchInvite
return
}
if !payload.InviteKey.GetPublic().Equals(key) {
err = ErrIncorrectInviteKey
}
rawIdentity, err := a.accountKeys.SignKey.GetPublic().Raw()
if err != nil {
return
}
signature, err := payload.InviteKey.Sign(rawIdentity)
if err != nil {
return
}
protoIdentity, err := a.accountKeys.SignKey.GetPublic().Marshall()
if err != nil {
return
}
joinRec := &aclrecordproto.AclAccountRequestJoin{
InviteIdentity: protoIdentity,
InviteRecordId: payload.InviteRecordId,
InviteIdentitySignature: signature,
Metadata: payload.Metadata,
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestJoin{RequestJoin: joinRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestAccept(payload RequestAcceptPayload) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
request, exists := a.state.requestRecords[payload.RequestRecordId]
if !exists {
err = ErrNoSuchRequest
return
}
var encryptedReadKeys []*aclrecordproto.AclReadKeyWithRecord
for keyId, key := range a.state.userReadKeys {
rawKey, err := key.Raw()
if err != nil {
return nil, err
}
enc, err := request.RequestIdentity.Encrypt(rawKey)
if err != nil {
return nil, err
}
encryptedReadKeys = append(encryptedReadKeys, &aclrecordproto.AclReadKeyWithRecord{
RecordId: keyId,
EncryptedReadKey: enc,
})
}
if err != nil {
return
}
requestIdentityProto, err := request.RequestIdentity.Marshall()
if err != nil {
return
}
acceptRec := &aclrecordproto.AclAccountRequestAccept{
Identity: requestIdentityProto,
RequestRecordId: payload.RequestRecordId,
EncryptedReadKeys: encryptedReadKeys,
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestAccept{RequestAccept: acceptRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestDecline(requestRecordId string) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
_, exists := a.state.requestRecords[requestRecordId]
if !exists {
err = ErrNoSuchRequest
return
}
declineRec := &aclrecordproto.AclAccountRequestDecline{RequestRecordId: requestRecordId}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_RequestDecline{RequestDecline: declineRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildPermissionChange(payload PermissionChangePayload) (rawRecord *consensusproto.RawRecord, err error) {
permissions := a.state.Permissions(a.state.pubKey)
if !permissions.CanManageAccounts() || payload.Identity.Equals(a.state.pubKey) {
err = ErrInsufficientPermissions
return
}
if payload.Permissions.IsOwner() {
err = ErrIsOwner
return
}
protoIdentity, err := payload.Identity.Marshall()
if err != nil {
return
}
permissionRec := &aclrecordproto.AclAccountPermissionChange{
Identity: protoIdentity,
Permissions: aclrecordproto.AclUserPermissions(payload.Permissions),
}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_PermissionChange{PermissionChange: permissionRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildReadKeyChange(newKey crypto.SymKey) (rawRecord *consensusproto.RawRecord, err error) {
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
rawKey, err := newKey.Raw()
if err != nil {
return
}
if len(rawKey) != crypto.KeyBytes {
err = ErrIncorrectReadKey
return
}
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
for _, st := range a.state.userStates {
protoIdentity, err := st.PubKey.Marshall()
if err != nil {
return nil, err
}
enc, err := st.PubKey.Encrypt(rawKey)
if err != nil {
return nil, err
}
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
Identity: protoIdentity,
EncryptedReadKey: enc,
})
}
readRec := &aclrecordproto.AclReadKeyChange{AccountKeys: aclReadKeys}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_ReadKeyChange{ReadKeyChange: readRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildAccountRemove(payload AccountRemovePayload) (rawRecord *consensusproto.RawRecord, err error) {
deletedMap := map[string]struct{}{}
for _, key := range payload.Identities {
permissions := a.state.Permissions(key)
if permissions.IsOwner() {
return nil, ErrInsufficientPermissions
}
if permissions.NoPermissions() {
return nil, ErrNoSuchAccount
}
deletedMap[mapKeyFromPubKey(key)] = struct{}{}
}
if !a.state.Permissions(a.state.pubKey).CanManageAccounts() {
err = ErrInsufficientPermissions
return
}
rawKey, err := payload.ReadKey.Raw()
if err != nil {
return
}
if len(rawKey) != crypto.KeyBytes {
err = ErrIncorrectReadKey
return
}
var aclReadKeys []*aclrecordproto.AclEncryptedReadKey
for _, st := range a.state.userStates {
if _, exists := deletedMap[mapKeyFromPubKey(st.PubKey)]; exists {
continue
}
protoIdentity, err := st.PubKey.Marshall()
if err != nil {
return nil, err
}
enc, err := st.PubKey.Encrypt(rawKey)
if err != nil {
return nil, err
}
aclReadKeys = append(aclReadKeys, &aclrecordproto.AclEncryptedReadKey{
Identity: protoIdentity,
EncryptedReadKey: enc,
})
}
var marshalledIdentities [][]byte
for _, key := range payload.Identities {
protoIdentity, err := key.Marshall()
if err != nil {
return nil, err
}
marshalledIdentities = append(marshalledIdentities, protoIdentity)
}
removeRec := &aclrecordproto.AclAccountRemove{AccountKeys: aclReadKeys, Identities: marshalledIdentities}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRemove{AccountRemove: removeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) BuildRequestRemove() (rawRecord *consensusproto.RawRecord, err error) {
permissions := a.state.Permissions(a.state.pubKey)
if permissions.NoPermissions() {
err = ErrNoSuchAccount
return
}
if permissions.IsOwner() {
err = ErrIsOwner
return
}
removeRec := &aclrecordproto.AclAccountRequestRemove{}
content := &aclrecordproto.AclContentValue{Value: &aclrecordproto.AclContentValue_AccountRequestRemove{AccountRequestRemove: removeRec}}
return a.buildRecord(content)
}
func (a *aclRecordBuilder) Unmarshall(rawRecord *consensusproto.RawRecord) (rec *AclRecord, err error) {
aclRecord := &consensusproto.Record{}
err = proto.Unmarshal(rawRecord.Payload, aclRecord)
if err != nil {
return
}
pubKey, err := a.keyStorage.PubKeyFromProto(aclRecord.Identity)
if err != nil {
return
}
aclData := &aclrecordproto.AclData{}
err = proto.Unmarshal(aclRecord.Data, aclData)
if err != nil {
return
}
rec = &AclRecord{
PrevId: aclRecord.PrevId,
Timestamp: aclRecord.Timestamp,
Data: aclRecord.Data,
Signature: rawRecord.Signature,
Identity: pubKey,
Model: aclData,
}
res, err := pubKey.Verify(rawRecord.Payload, rawRecord.Signature)
if err != nil {
return
}
if !res {
err = ErrInvalidSignature
return
}
return
}
func (a *aclRecordBuilder) UnmarshallWithId(rawIdRecord *consensusproto.RawRecordWithId) (rec *AclRecord, err error) {
var ( var (
rawRec = &aclrecordproto.RawAclRecord{} rawRec = &consensusproto.RawRecord{}
pubKey crypto.PubKey pubKey crypto.PubKey
) )
err = proto.Unmarshal(rawIdRecord.Payload, rawRec) err = proto.Unmarshal(rawIdRecord.Payload, rawRec)
@ -53,14 +417,17 @@ func (a *aclRecordBuilder) Unmarshall(rawIdRecord *aclrecordproto.RawAclRecordWi
} }
rec = &AclRecord{ rec = &AclRecord{
Id: rawIdRecord.Id, Id: rawIdRecord.Id,
ReadKeyId: rawIdRecord.Id,
Timestamp: aclRoot.Timestamp, Timestamp: aclRoot.Timestamp,
Signature: rawRec.Signature, Signature: rawRec.Signature,
Identity: pubKey, Identity: pubKey,
Model: aclRoot, Model: aclRoot,
} }
} else { } else {
aclRecord := &aclrecordproto.AclRecord{} err = a.verifier.VerifyAcceptor(rawRec)
if err != nil {
return
}
aclRecord := &consensusproto.Record{}
err = proto.Unmarshal(rawRec.Payload, aclRecord) err = proto.Unmarshal(rawRec.Payload, aclRecord)
if err != nil { if err != nil {
return return
@ -69,14 +436,19 @@ func (a *aclRecordBuilder) Unmarshall(rawIdRecord *aclrecordproto.RawAclRecordWi
if err != nil { if err != nil {
return return
} }
aclData := &aclrecordproto.AclData{}
err = proto.Unmarshal(aclRecord.Data, aclData)
if err != nil {
return
}
rec = &AclRecord{ rec = &AclRecord{
Id: rawIdRecord.Id, Id: rawIdRecord.Id,
PrevId: aclRecord.PrevId, PrevId: aclRecord.PrevId,
ReadKeyId: aclRecord.ReadKeyId,
Timestamp: aclRecord.Timestamp, Timestamp: aclRecord.Timestamp,
Data: aclRecord.Data, Data: aclRecord.Data,
Signature: rawRec.Signature, Signature: rawRec.Signature,
Identity: pubKey, Identity: pubKey,
Model: aclData,
} }
} }
@ -84,7 +456,7 @@ func (a *aclRecordBuilder) Unmarshall(rawIdRecord *aclrecordproto.RawAclRecordWi
return return
} }
func (a *aclRecordBuilder) BuildRoot(content RootContent) (rec *aclrecordproto.RawAclRecordWithId, err error) { func (a *aclRecordBuilder) BuildRoot(content RootContent) (rec *consensusproto.RawRecordWithId, err error) {
rawIdentity, err := content.PrivKey.GetPublic().Raw() rawIdentity, err := content.PrivKey.GetPublic().Raw()
if err != nil { if err != nil {
return return
@ -118,8 +490,8 @@ func (a *aclRecordBuilder) BuildRoot(content RootContent) (rec *aclrecordproto.R
func verifyRaw( func verifyRaw(
pubKey crypto.PubKey, pubKey crypto.PubKey,
rawRec *aclrecordproto.RawAclRecord, rawRec *consensusproto.RawRecord,
recWithId *aclrecordproto.RawAclRecordWithId) (err error) { recWithId *consensusproto.RawRecordWithId) (err error) {
// verifying signature // verifying signature
res, err := pubKey.Verify(rawRec.Payload, rawRec.Signature) res, err := pubKey.Verify(rawRec.Payload, rawRec.Signature)
if err != nil { if err != nil {
@ -137,7 +509,7 @@ func verifyRaw(
return return
} }
func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWithId *aclrecordproto.RawAclRecordWithId, err error) { func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWithId *consensusproto.RawRecordWithId, err error) {
marshalledRoot, err := aclRoot.Marshal() marshalledRoot, err := aclRoot.Marshal()
if err != nil { if err != nil {
return return
@ -146,7 +518,7 @@ func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWit
if err != nil { if err != nil {
return return
} }
raw := &aclrecordproto.RawAclRecord{ raw := &consensusproto.RawRecord{
Payload: marshalledRoot, Payload: marshalledRoot,
Signature: signature, Signature: signature,
} }
@ -158,7 +530,7 @@ func marshalAclRoot(aclRoot *aclrecordproto.AclRoot, key crypto.PrivKey) (rawWit
if err != nil { if err != nil {
return return
} }
rawWithId = &aclrecordproto.RawAclRecordWithId{ rawWithId = &consensusproto.RawRecordWithId{
Payload: marshalledRaw, Payload: marshalledRaw,
Id: aclHeadId, Id: aclHeadId,
} }

View File

@ -1,9 +0,0 @@
package list
import (
"testing"
)
func TestAclRecordBuilder_BuildUserJoin(t *testing.T) {
return
}

View File

@ -2,7 +2,7 @@ package list
import ( import (
"errors" "errors"
"fmt"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
@ -13,19 +13,24 @@ import (
var log = logger.NewNamedSugared("common.commonspace.acllist") var log = logger.NewNamedSugared("common.commonspace.acllist")
var ( var (
ErrNoSuchUser = errors.New("no such user") ErrNoSuchAccount = errors.New("no such account")
ErrFailedToDecrypt = errors.New("failed to decrypt key") ErrPendingRequest = errors.New("already exists pending request")
ErrUserRemoved = errors.New("user was removed from the document") ErrUnexpectedContentType = errors.New("unexpected content type")
ErrDocumentForbidden = errors.New("your user was forbidden access to the document") ErrIncorrectIdentity = errors.New("incorrect identity")
ErrUserAlreadyExists = errors.New("user already exists") ErrIncorrectInviteKey = errors.New("incorrect invite key")
ErrNoSuchRecord = errors.New("no such record") ErrFailedToDecrypt = errors.New("failed to decrypt key")
ErrNoSuchInvite = errors.New("no such invite") ErrNoSuchRecord = errors.New("no such record")
ErrOldInvite = errors.New("invite is too old") ErrNoSuchRequest = errors.New("no such request")
ErrInsufficientPermissions = errors.New("insufficient permissions") ErrNoSuchInvite = errors.New("no such invite")
ErrNoReadKey = errors.New("acl state doesn't have a read key") ErrInsufficientPermissions = errors.New("insufficient permissions")
ErrInvalidSignature = errors.New("signature is invalid") ErrIsOwner = errors.New("can't be made by owner")
ErrIncorrectRoot = errors.New("incorrect root") ErrIncorrectNumberOfAccounts = errors.New("incorrect number of accounts")
ErrIncorrectRecordSequence = errors.New("incorrect prev id of a record") ErrDuplicateAccounts = errors.New("duplicate accounts")
ErrNoReadKey = errors.New("acl state doesn't have a read key")
ErrIncorrectReadKey = errors.New("incorrect read key")
ErrInvalidSignature = errors.New("signature is invalid")
ErrIncorrectRoot = errors.New("incorrect root")
ErrIncorrectRecordSequence = errors.New("incorrect prev id of a record")
) )
type UserPermissionPair struct { type UserPermissionPair struct {
@ -36,37 +41,71 @@ type UserPermissionPair struct {
type AclState struct { type AclState struct {
id string id string
currentReadKeyId string currentReadKeyId string
userReadKeys map[string]crypto.SymKey // userReadKeys is a map recordId -> read key which tells us about every read key
userStates map[string]AclUserState userReadKeys map[string]crypto.SymKey
statesAtRecord map[string][]AclUserState // userStates is a map pubKey -> state which defines current user state
key crypto.PrivKey userStates map[string]AclUserState
pubKey crypto.PubKey // statesAtRecord is a map recordId -> state which define user state at particular record
keyStore crypto.KeyStorage // probably this can grow rather large at some point, so we can maybe optimise later to have:
totalReadKeys int // - map pubKey -> []recordIds (where recordIds is an array where such identity permissions were changed)
statesAtRecord map[string][]AclUserState
// inviteKeys is a map recordId -> invite
inviteKeys map[string]crypto.PubKey
// requestRecords is a map recordId -> RequestRecord
requestRecords map[string]RequestRecord
// pendingRequests is a map pubKey -> recordId
pendingRequests map[string]string
key crypto.PrivKey
pubKey crypto.PubKey
keyStore crypto.KeyStorage
totalReadKeys int
lastRecordId string lastRecordId string
contentValidator ContentValidator
} }
func newAclStateWithKeys( func newAclStateWithKeys(
id string, id string,
key crypto.PrivKey) (*AclState, error) { key crypto.PrivKey) (*AclState, error) {
return &AclState{ st := &AclState{
id: id, id: id,
key: key, key: key,
pubKey: key.GetPublic(), pubKey: key.GetPublic(),
userReadKeys: make(map[string]crypto.SymKey), userReadKeys: make(map[string]crypto.SymKey),
userStates: make(map[string]AclUserState), userStates: make(map[string]AclUserState),
statesAtRecord: make(map[string][]AclUserState), statesAtRecord: make(map[string][]AclUserState),
}, nil inviteKeys: make(map[string]crypto.PubKey),
requestRecords: make(map[string]RequestRecord),
pendingRequests: make(map[string]string),
keyStore: crypto.NewKeyStorage(),
}
st.contentValidator = &contentValidator{
keyStore: st.keyStore,
aclState: st,
}
return st, nil
} }
func newAclState(id string) *AclState { func newAclState(id string) *AclState {
return &AclState{ st := &AclState{
id: id, id: id,
userReadKeys: make(map[string]crypto.SymKey), userReadKeys: make(map[string]crypto.SymKey),
userStates: make(map[string]AclUserState), userStates: make(map[string]AclUserState),
statesAtRecord: make(map[string][]AclUserState), statesAtRecord: make(map[string][]AclUserState),
inviteKeys: make(map[string]crypto.PubKey),
requestRecords: make(map[string]RequestRecord),
pendingRequests: make(map[string]string),
keyStore: crypto.NewKeyStorage(),
} }
st.contentValidator = &contentValidator{
keyStore: st.keyStore,
aclState: st,
}
return st
}
func (st *AclState) Validator() ContentValidator {
return st.contentValidator
} }
func (st *AclState) CurrentReadKeyId() string { func (st *AclState) CurrentReadKeyId() string {
@ -74,7 +113,7 @@ func (st *AclState) CurrentReadKeyId() string {
} }
func (st *AclState) CurrentReadKey() (crypto.SymKey, error) { func (st *AclState) CurrentReadKey() (crypto.SymKey, error) {
key, exists := st.userReadKeys[st.currentReadKeyId] key, exists := st.userReadKeys[st.CurrentReadKeyId()]
if !exists { if !exists {
return nil, ErrNoReadKey return nil, ErrNoReadKey
} }
@ -97,7 +136,7 @@ func (st *AclState) StateAtRecord(id string, pubKey crypto.PubKey) (AclUserState
return perm, nil return perm, nil
} }
} }
return AclUserState{}, ErrNoSuchUser return AclUserState{}, ErrNoSuchAccount
} }
func (st *AclState) applyRecord(record *AclRecord) (err error) { func (st *AclState) applyRecord(record *AclRecord) (err error) {
@ -110,17 +149,18 @@ func (st *AclState) applyRecord(record *AclRecord) (err error) {
err = ErrIncorrectRecordSequence err = ErrIncorrectRecordSequence
return return
} }
// if the record is root record
if record.Id == st.id { if record.Id == st.id {
err = st.applyRoot(record) err = st.applyRoot(record)
if err != nil { if err != nil {
return return
} }
st.statesAtRecord[record.Id] = []AclUserState{ st.statesAtRecord[record.Id] = []AclUserState{
{PubKey: record.Identity, Permissions: aclrecordproto.AclUserPermissions_Admin}, st.userStates[mapKeyFromPubKey(record.Identity)],
} }
return return
} }
// if the model is not cached
if record.Model == nil { if record.Model == nil {
aclData := &aclrecordproto.AclData{} aclData := &aclrecordproto.AclData{}
err = proto.Unmarshal(record.Data, aclData) err = proto.Unmarshal(record.Data, aclData)
@ -129,18 +169,16 @@ func (st *AclState) applyRecord(record *AclRecord) (err error) {
} }
record.Model = aclData record.Model = aclData
} }
// applying records contents
err = st.applyChangeData(record) err = st.applyChangeData(record)
if err != nil { if err != nil {
return return
} }
// getting all states for users at record and saving them
// getting all states for users at record
var states []AclUserState var states []AclUserState
for _, state := range st.userStates { for _, state := range st.userStates {
states = append(states, state) states = append(states, state)
} }
st.statesAtRecord[record.Id] = states st.statesAtRecord[record.Id] = states
return return
} }
@ -156,9 +194,9 @@ func (st *AclState) applyRoot(record *AclRecord) (err error) {
// adding user to the list // adding user to the list
userState := AclUserState{ userState := AclUserState{
PubKey: record.Identity, PubKey: record.Identity,
Permissions: aclrecordproto.AclUserPermissions_Admin, Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Owner),
} }
st.currentReadKeyId = record.ReadKeyId st.currentReadKeyId = record.Id
st.userStates[mapKeyFromPubKey(record.Identity)] = userState st.userStates[mapKeyFromPubKey(record.Identity)] = userState
st.totalReadKeys++ st.totalReadKeys++
return return
@ -181,92 +219,191 @@ func (st *AclState) saveReadKeyFromRoot(record *AclRecord) (err error) {
return return
} }
} }
st.userReadKeys[record.Id] = readKey st.userReadKeys[record.Id] = readKey
return return
} }
func (st *AclState) applyChangeData(record *AclRecord) (err error) { func (st *AclState) applyChangeData(record *AclRecord) (err error) {
defer func() {
if err != nil {
return
}
if record.ReadKeyId != st.currentReadKeyId {
st.totalReadKeys++
st.currentReadKeyId = record.ReadKeyId
}
}()
model := record.Model.(*aclrecordproto.AclData) model := record.Model.(*aclrecordproto.AclData)
if !st.isUserJoin(model) {
// we check signature when we add this to the List, so no need to do it here
if _, exists := st.userStates[mapKeyFromPubKey(record.Identity)]; !exists {
err = ErrNoSuchUser
return
}
// only Admins can do non-user join changes
if !st.HasPermission(record.Identity, aclrecordproto.AclUserPermissions_Admin) {
// TODO: add string encoding
err = fmt.Errorf("user %s must have admin permissions", record.Identity.Account())
return
}
}
for _, ch := range model.GetAclContent() { for _, ch := range model.GetAclContent() {
if err = st.applyChangeContent(ch, record.Id); err != nil { if err = st.applyChangeContent(ch, record.Id, record.Identity); err != nil {
log.Info("error while applying changes: %v; ignore", zap.Error(err)) log.Info("error while applying changes: %v; ignore", zap.Error(err))
return err return err
} }
} }
return nil return nil
} }
func (st *AclState) applyChangeContent(ch *aclrecordproto.AclContentValue, recordId string) error { func (st *AclState) applyChangeContent(ch *aclrecordproto.AclContentValue, recordId string, authorIdentity crypto.PubKey) error {
switch { switch {
case ch.GetUserPermissionChange() != nil: case ch.GetPermissionChange() != nil:
return st.applyUserPermissionChange(ch.GetUserPermissionChange(), recordId) return st.applyPermissionChange(ch.GetPermissionChange(), recordId, authorIdentity)
case ch.GetUserAdd() != nil: case ch.GetInvite() != nil:
return st.applyUserAdd(ch.GetUserAdd(), recordId) return st.applyInvite(ch.GetInvite(), recordId, authorIdentity)
case ch.GetUserRemove() != nil: case ch.GetInviteRevoke() != nil:
return st.applyUserRemove(ch.GetUserRemove(), recordId) return st.applyInviteRevoke(ch.GetInviteRevoke(), recordId, authorIdentity)
case ch.GetUserInvite() != nil: case ch.GetRequestJoin() != nil:
return st.applyUserInvite(ch.GetUserInvite(), recordId) return st.applyRequestJoin(ch.GetRequestJoin(), recordId, authorIdentity)
case ch.GetUserJoin() != nil: case ch.GetRequestAccept() != nil:
return st.applyUserJoin(ch.GetUserJoin(), recordId) return st.applyRequestAccept(ch.GetRequestAccept(), recordId, authorIdentity)
case ch.GetRequestDecline() != nil:
return st.applyRequestDecline(ch.GetRequestDecline(), recordId, authorIdentity)
case ch.GetAccountRemove() != nil:
return st.applyAccountRemove(ch.GetAccountRemove(), recordId, authorIdentity)
case ch.GetReadKeyChange() != nil:
return st.applyReadKeyChange(ch.GetReadKeyChange(), recordId, authorIdentity)
case ch.GetAccountRequestRemove() != nil:
return st.applyRequestRemove(ch.GetAccountRequestRemove(), recordId, authorIdentity)
default: default:
return fmt.Errorf("unexpected change type: %v", ch) return ErrUnexpectedContentType
} }
} }
func (st *AclState) applyUserPermissionChange(ch *aclrecordproto.AclUserPermissionChange, recordId string) error { func (st *AclState) applyPermissionChange(ch *aclrecordproto.AclAccountPermissionChange, recordId string, authorIdentity crypto.PubKey) error {
chIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity) chIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
if err != nil { if err != nil {
return err return err
} }
state, exists := st.userStates[mapKeyFromPubKey(chIdentity)] err = st.contentValidator.ValidatePermissionChange(ch, authorIdentity)
if !exists { if err != nil {
return ErrNoSuchUser return err
} }
stringKey := mapKeyFromPubKey(chIdentity)
state.Permissions = ch.Permissions state, _ := st.userStates[stringKey]
state.Permissions = AclPermissions(ch.Permissions)
st.userStates[stringKey] = state
return nil return nil
} }
func (st *AclState) applyUserInvite(ch *aclrecordproto.AclUserInvite, recordId string) error { func (st *AclState) applyInvite(ch *aclrecordproto.AclAccountInvite, recordId string, authorIdentity crypto.PubKey) error {
// TODO: check old code and bring it back :-) inviteKey, err := st.keyStore.PubKeyFromProto(ch.InviteKey)
if err != nil {
return err
}
err = st.contentValidator.ValidateInvite(ch, authorIdentity)
if err != nil {
return err
}
st.inviteKeys[recordId] = inviteKey
return nil return nil
} }
func (st *AclState) applyUserJoin(ch *aclrecordproto.AclUserJoin, recordId string) error { func (st *AclState) applyInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateInviteRevoke(ch, authorIdentity)
if err != nil {
return err
}
delete(st.inviteKeys, ch.InviteRecordId)
return nil return nil
} }
func (st *AclState) applyUserAdd(ch *aclrecordproto.AclUserAdd, recordId string) error { func (st *AclState) applyRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateRequestJoin(ch, authorIdentity)
if err != nil {
return err
}
st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
st.requestRecords[recordId] = RequestRecord{
RequestIdentity: authorIdentity,
RequestMetadata: ch.Metadata,
Type: RequestTypeJoin,
}
return nil return nil
} }
func (st *AclState) applyUserRemove(ch *aclrecordproto.AclUserRemove, recordId string) error { func (st *AclState) applyRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateRequestAccept(ch, authorIdentity)
if err != nil {
return err
}
acceptIdentity, err := st.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return err
}
record, _ := st.requestRecords[ch.RequestRecordId]
st.userStates[mapKeyFromPubKey(acceptIdentity)] = AclUserState{
PubKey: acceptIdentity,
Permissions: AclPermissions(ch.Permissions),
RequestMetadata: record.RequestMetadata,
}
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
if !st.pubKey.Equals(acceptIdentity) {
return nil
}
for _, key := range ch.EncryptedReadKeys {
decrypted, err := st.key.Decrypt(key.EncryptedReadKey)
if err != nil {
return err
}
sym, err := crypto.UnmarshallAESKey(decrypted)
if err != nil {
return err
}
st.userReadKeys[key.RecordId] = sym
}
return nil
}
func (st *AclState) applyRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateRequestDecline(ch, authorIdentity)
if err != nil {
return err
}
delete(st.pendingRequests, mapKeyFromPubKey(st.requestRecords[ch.RequestRecordId].RequestIdentity))
delete(st.requestRecords, ch.RequestRecordId)
return nil
}
func (st *AclState) applyRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateRequestRemove(ch, authorIdentity)
if err != nil {
return err
}
st.requestRecords[recordId] = RequestRecord{
RequestIdentity: authorIdentity,
Type: RequestTypeRemove,
}
st.pendingRequests[mapKeyFromPubKey(authorIdentity)] = recordId
return nil
}
func (st *AclState) applyAccountRemove(ch *aclrecordproto.AclAccountRemove, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateAccountRemove(ch, authorIdentity)
if err != nil {
return err
}
for _, rawIdentity := range ch.Identities {
identity, err := st.keyStore.PubKeyFromProto(rawIdentity)
if err != nil {
return err
}
idKey := mapKeyFromPubKey(identity)
delete(st.userStates, idKey)
delete(st.pendingRequests, idKey)
}
return st.updateReadKey(ch.AccountKeys, recordId)
}
func (st *AclState) applyReadKeyChange(ch *aclrecordproto.AclReadKeyChange, recordId string, authorIdentity crypto.PubKey) error {
err := st.contentValidator.ValidateReadKeyChange(ch, authorIdentity)
if err != nil {
return err
}
return st.updateReadKey(ch.AccountKeys, recordId)
}
func (st *AclState) updateReadKey(keys []*aclrecordproto.AclEncryptedReadKey, recordId string) error {
for _, accKey := range keys {
identity, _ := st.keyStore.PubKeyFromProto(accKey.Identity)
if st.pubKey.Equals(identity) {
res, err := st.decryptReadKey(accKey.EncryptedReadKey)
if err != nil {
return err
}
st.userReadKeys[recordId] = res
}
}
st.currentReadKeyId = recordId
return nil return nil
} }
@ -275,7 +412,6 @@ func (st *AclState) decryptReadKey(msg []byte) (crypto.SymKey, error) {
if err != nil { if err != nil {
return nil, ErrFailedToDecrypt return nil, ErrFailedToDecrypt
} }
key, err := crypto.UnmarshallAESKey(decrypted) key, err := crypto.UnmarshallAESKey(decrypted)
if err != nil { if err != nil {
return nil, ErrFailedToDecrypt return nil, ErrFailedToDecrypt
@ -283,29 +419,31 @@ func (st *AclState) decryptReadKey(msg []byte) (crypto.SymKey, error) {
return key, nil return key, nil
} }
func (st *AclState) HasPermission(identity crypto.PubKey, permission aclrecordproto.AclUserPermissions) bool { func (st *AclState) Permissions(identity crypto.PubKey) AclPermissions {
state, exists := st.userStates[mapKeyFromPubKey(identity)] state, exists := st.userStates[mapKeyFromPubKey(identity)]
if !exists { if !exists {
return false return AclPermissions(aclrecordproto.AclUserPermissions_None)
} }
return state.Permissions
return state.Permissions == permission
} }
func (st *AclState) isUserJoin(data *aclrecordproto.AclData) bool { func (st *AclState) JoinRecords() (records []RequestRecord) {
// if we have a UserJoin, then it should always be the first one applied for _, recId := range st.pendingRequests {
return data.GetAclContent() != nil && data.GetAclContent()[0].GetUserJoin() != nil rec := st.requestRecords[recId]
if rec.Type == RequestTypeJoin {
records = append(records, rec)
}
}
return
} }
func (st *AclState) isUserAdd(data *aclrecordproto.AclData, identity []byte) bool { func (st *AclState) RemoveRecords() (records []RequestRecord) {
return false for _, recId := range st.pendingRequests {
} rec := st.requestRecords[recId]
if rec.Type == RequestTypeRemove {
func (st *AclState) UserStates() map[string]AclUserState { records = append(records, rec)
return st.userStates }
} }
func (st *AclState) Invite(acceptPubKey []byte) (invite *aclrecordproto.AclUserInvite, err error) {
return return
} }

View File

@ -5,16 +5,21 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/util/crypto"
"sync" "sync"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto"
) )
type IterFunc = func(record *AclRecord) (IsContinue bool) type IterFunc = func(record *AclRecord) (IsContinue bool)
var ErrIncorrectCID = errors.New("incorrect CID") var (
ErrIncorrectCID = errors.New("incorrect CID")
ErrRecordAlreadyExists = errors.New("record already exists")
)
type RWLocker interface { type RWLocker interface {
sync.Locker sync.Locker
@ -22,26 +27,45 @@ type RWLocker interface {
RUnlock() RUnlock()
} }
type AcceptorVerifier interface {
VerifyAcceptor(rec *consensusproto.RawRecord) (err error)
}
type NoOpAcceptorVerifier struct {
}
func (n NoOpAcceptorVerifier) VerifyAcceptor(rec *consensusproto.RawRecord) (err error) {
return nil
}
type AclList interface { type AclList interface {
RWLocker RWLocker
Id() string Id() string
Root() *aclrecordproto.RawAclRecordWithId Root() *consensusproto.RawRecordWithId
Records() []*AclRecord Records() []*AclRecord
AclState() *AclState AclState() *AclState
IsAfter(first string, second string) (bool, error) IsAfter(first string, second string) (bool, error)
HasHead(head string) bool
Head() *AclRecord Head() *AclRecord
RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error)
Get(id string) (*AclRecord, error) Get(id string) (*AclRecord, error)
GetIndex(idx int) (*AclRecord, error)
Iterate(iterFunc IterFunc) Iterate(iterFunc IterFunc)
IterateFrom(startId string, iterFunc IterFunc) IterateFrom(startId string, iterFunc IterFunc)
KeyStorage() crypto.KeyStorage KeyStorage() crypto.KeyStorage
RecordBuilder() AclRecordBuilder
AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error) ValidateRawRecord(record *consensusproto.RawRecord) (err error)
AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error)
AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error)
Close() (err error) Close(ctx context.Context) (err error)
} }
type aclList struct { type aclList struct {
root *aclrecordproto.RawAclRecordWithId root *consensusproto.RawRecordWithId
records []*AclRecord records []*AclRecord
indexes map[string]int indexes map[string]int
id string id string
@ -55,18 +79,45 @@ type aclList struct {
sync.RWMutex sync.RWMutex
} }
func BuildAclListWithIdentity(acc *accountdata.AccountKeys, storage liststorage.ListStorage) (AclList, error) { type internalDeps struct {
builder := newAclStateBuilderWithIdentity(acc) storage liststorage.ListStorage
keyStorage := crypto.NewKeyStorage() keyStorage crypto.KeyStorage
return build(storage.Id(), keyStorage, builder, NewAclRecordBuilder(storage.Id(), keyStorage), storage) stateBuilder *aclStateBuilder
recordBuilder AclRecordBuilder
acceptorVerifier AcceptorVerifier
} }
func BuildAclList(storage liststorage.ListStorage) (AclList, error) { func BuildAclListWithIdentity(acc *accountdata.AccountKeys, storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
keyStorage := crypto.NewKeyStorage() keyStorage := crypto.NewKeyStorage()
return build(storage.Id(), keyStorage, newAclStateBuilder(), NewAclRecordBuilder(storage.Id(), crypto.NewKeyStorage()), storage) deps := internalDeps{
storage: storage,
keyStorage: keyStorage,
stateBuilder: newAclStateBuilderWithIdentity(acc),
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, acc, verifier),
acceptorVerifier: verifier,
}
return build(deps)
} }
func build(id string, keyStorage crypto.KeyStorage, stateBuilder *aclStateBuilder, recBuilder AclRecordBuilder, storage liststorage.ListStorage) (list AclList, err error) { func BuildAclList(storage liststorage.ListStorage, verifier AcceptorVerifier) (AclList, error) {
keyStorage := crypto.NewKeyStorage()
deps := internalDeps{
storage: storage,
keyStorage: keyStorage,
stateBuilder: newAclStateBuilder(),
recordBuilder: NewAclRecordBuilder(storage.Id(), keyStorage, nil, verifier),
acceptorVerifier: verifier,
}
return build(deps)
}
func build(deps internalDeps) (list AclList, err error) {
var (
storage = deps.storage
id = deps.storage.Id()
recBuilder = deps.recordBuilder
stateBuilder = deps.stateBuilder
)
head, err := storage.Head() head, err := storage.Head()
if err != nil { if err != nil {
return return
@ -77,7 +128,7 @@ func build(id string, keyStorage crypto.KeyStorage, stateBuilder *aclStateBuilde
return return
} }
record, err := recBuilder.Unmarshall(rawRecordWithId) record, err := recBuilder.UnmarshallWithId(rawRecordWithId)
if err != nil { if err != nil {
return return
} }
@ -89,7 +140,7 @@ func build(id string, keyStorage crypto.KeyStorage, stateBuilder *aclStateBuilde
return return
} }
record, err = recBuilder.Unmarshall(rawRecordWithId) record, err = recBuilder.UnmarshallWithId(rawRecordWithId)
if err != nil { if err != nil {
return return
} }
@ -119,6 +170,7 @@ func build(id string, keyStorage crypto.KeyStorage, stateBuilder *aclStateBuilde
return return
} }
recBuilder.(*aclRecordBuilder).state = state
list = &aclList{ list = &aclList{
root: rootWithId, root: rootWithId,
records: records, records: records,
@ -132,15 +184,37 @@ func build(id string, keyStorage crypto.KeyStorage, stateBuilder *aclStateBuilde
return return
} }
func (a *aclList) RecordBuilder() AclRecordBuilder {
return a.recordBuilder
}
func (a *aclList) Records() []*AclRecord { func (a *aclList) Records() []*AclRecord {
return a.records return a.records
} }
func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added bool, err error) { func (a *aclList) ValidateRawRecord(rawRec *consensusproto.RawRecord) (err error) {
if _, ok := a.indexes[rawRec.Id]; ok { record, err := a.recordBuilder.Unmarshall(rawRec)
if err != nil {
return return
} }
record, err := a.recordBuilder.Unmarshall(rawRec) return a.aclState.Validator().ValidateAclRecordContents(record)
}
func (a *aclList) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
for _, rec := range rawRecords {
err = a.AddRawRecord(rec)
if err != nil && err != ErrRecordAlreadyExists {
return
}
}
return
}
func (a *aclList) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
if _, ok := a.indexes[rawRec.Id]; ok {
return ErrRecordAlreadyExists
}
record, err := a.recordBuilder.UnmarshallWithId(rawRec)
if err != nil { if err != nil {
return return
} }
@ -155,15 +229,6 @@ func (a *aclList) AddRawRecord(rawRec *aclrecordproto.RawAclRecordWithId) (added
if err = a.storage.SetHead(rawRec.Id); err != nil { if err = a.storage.SetHead(rawRec.Id); err != nil {
return return
} }
return true, nil
}
func (a *aclList) IsValidNext(rawRec *aclrecordproto.RawAclRecordWithId) (err error) {
_, err = a.recordBuilder.Unmarshall(rawRec)
if err != nil {
return
}
// TODO: change state and add "check" method for records
return return
} }
@ -171,7 +236,7 @@ func (a *aclList) Id() string {
return a.id return a.id
} }
func (a *aclList) Root() *aclrecordproto.RawAclRecordWithId { func (a *aclList) Root() *consensusproto.RawRecordWithId {
return a.root return a.root
} }
@ -196,14 +261,27 @@ func (a *aclList) Head() *AclRecord {
return a.records[len(a.records)-1] return a.records[len(a.records)-1]
} }
func (a *aclList) HasHead(head string) bool {
_, exists := a.indexes[head]
return exists
}
func (a *aclList) Get(id string) (*AclRecord, error) { func (a *aclList) Get(id string) (*AclRecord, error) {
recIdx, ok := a.indexes[id] recIdx, ok := a.indexes[id]
if !ok { if !ok {
return nil, fmt.Errorf("no such record") return nil, ErrNoSuchRecord
} }
return a.records[recIdx], nil return a.records[recIdx], nil
} }
func (a *aclList) GetIndex(idx int) (*AclRecord, error) {
// TODO: when we add snapshots we will have to monitor record num in snapshots
if idx < 0 || idx >= len(a.records) {
return nil, ErrNoSuchRecord
}
return a.records[idx], nil
}
func (a *aclList) Iterate(iterFunc IterFunc) { func (a *aclList) Iterate(iterFunc IterFunc) {
for _, rec := range a.records { for _, rec := range a.records {
if !iterFunc(rec) { if !iterFunc(rec) {
@ -212,6 +290,21 @@ func (a *aclList) Iterate(iterFunc IterFunc) {
} }
} }
func (a *aclList) RecordsAfter(ctx context.Context, id string) (records []*consensusproto.RawRecordWithId, err error) {
recIdx, ok := a.indexes[id]
if !ok {
return nil, ErrNoSuchRecord
}
for i := recIdx + 1; i < len(a.records); i++ {
rawRec, err := a.storage.GetRawRecord(ctx, a.records[i].Id)
if err != nil {
return nil, err
}
records = append(records, rawRec)
}
return
}
func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) { func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
recIdx, ok := a.indexes[startId] recIdx, ok := a.indexes[startId]
if !ok { if !ok {
@ -224,6 +317,21 @@ func (a *aclList) IterateFrom(startId string, iterFunc IterFunc) {
} }
} }
func (a *aclList) Close() (err error) { func (a *aclList) Close(ctx context.Context) (err error) {
return nil return nil
} }
func WrapAclRecord(rawRec *consensusproto.RawRecord) *consensusproto.RawRecordWithId {
payload, err := rawRec.Marshal()
if err != nil {
panic(err)
}
id, err := cidutil.NewCidFromBytes(payload)
if err != nil {
panic(err)
}
return &consensusproto.RawRecordWithId{
Payload: payload,
Id: id,
}
}

View File

@ -2,11 +2,98 @@ package list
import ( import (
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/stretchr/testify/require"
"testing" "testing"
"github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/crypto"
"github.com/stretchr/testify/require"
) )
type aclFixture struct {
ownerKeys *accountdata.AccountKeys
accountKeys *accountdata.AccountKeys
ownerAcl *aclList
accountAcl *aclList
spaceId string
}
func newFixture(t *testing.T) *aclFixture {
ownerKeys, err := accountdata.NewRandom()
require.NoError(t, err)
accountKeys, err := accountdata.NewRandom()
require.NoError(t, err)
spaceId := "spaceId"
ownerAcl, err := NewTestDerivedAcl(spaceId, ownerKeys)
require.NoError(t, err)
accountAcl, err := NewTestAclWithRoot(accountKeys, ownerAcl.Root())
require.NoError(t, err)
return &aclFixture{
ownerKeys: ownerKeys,
accountKeys: accountKeys,
ownerAcl: ownerAcl.(*aclList),
accountAcl: accountAcl.(*aclList),
spaceId: spaceId,
}
}
func (fx *aclFixture) addRec(t *testing.T, rec *consensusproto.RawRecordWithId) {
err := fx.ownerAcl.AddRawRecord(rec)
require.NoError(t, err)
err = fx.accountAcl.AddRawRecord(rec)
require.NoError(t, err)
}
func (fx *aclFixture) inviteAccount(t *testing.T, perms AclPermissions) {
var (
ownerAcl = fx.ownerAcl
ownerState = fx.ownerAcl.aclState
accountAcl = fx.accountAcl
accountState = fx.accountAcl.aclState
)
// building invite
inv, err := ownerAcl.RecordBuilder().BuildInvite()
require.NoError(t, err)
inviteRec := WrapAclRecord(inv.InviteRec)
fx.addRec(t, inviteRec)
// building request join
requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
InviteRecordId: inviteRec.Id,
InviteKey: inv.InviteKey,
})
require.NoError(t, err)
requestJoinRec := WrapAclRecord(requestJoin)
fx.addRec(t, requestJoinRec)
// building request accept
requestAccept, err := ownerAcl.RecordBuilder().BuildRequestAccept(RequestAcceptPayload{
RequestRecordId: requestJoinRec.Id,
Permissions: perms,
})
require.NoError(t, err)
// validate
err = ownerAcl.ValidateRawRecord(requestAccept)
require.NoError(t, err)
requestAcceptRec := WrapAclRecord(requestAccept)
fx.addRec(t, requestAcceptRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).CanWrite())
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).CanWrite())
_, err = ownerState.StateAtRecord(requestJoinRec.Id, accountState.pubKey)
require.Equal(t, ErrNoSuchAccount, err)
stateAtRec, err := ownerState.StateAtRecord(requestAcceptRec.Id, accountState.pubKey)
require.NoError(t, err)
require.True(t, stateAtRec.Permissions == perms)
}
func TestAclList_BuildRoot(t *testing.T) { func TestAclList_BuildRoot(t *testing.T) {
randomKeys, err := accountdata.NewRandom() randomKeys, err := accountdata.NewRandom()
require.NoError(t, err) require.NoError(t, err)
@ -14,3 +101,193 @@ func TestAclList_BuildRoot(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
fmt.Println(randomAcl.Id()) fmt.Println(randomAcl.Id())
} }
func TestAclList_InvitePipeline(t *testing.T) {
fx := newFixture(t)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
}
func TestAclList_InviteRevoke(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
// building invite
inv, err := fx.ownerAcl.RecordBuilder().BuildInvite()
require.NoError(t, err)
inviteRec := WrapAclRecord(inv.InviteRec)
fx.addRec(t, inviteRec)
// building invite revoke
inviteRevoke, err := fx.ownerAcl.RecordBuilder().BuildInviteRevoke(ownerState.lastRecordId)
require.NoError(t, err)
inviteRevokeRec := WrapAclRecord(inviteRevoke)
fx.addRec(t, inviteRevokeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.Empty(t, ownerState.inviteKeys)
require.Empty(t, accountState.inviteKeys)
}
func TestAclList_RequestDecline(t *testing.T) {
fx := newFixture(t)
var (
ownerAcl = fx.ownerAcl
ownerState = fx.ownerAcl.aclState
accountAcl = fx.accountAcl
accountState = fx.accountAcl.aclState
)
// building invite
inv, err := ownerAcl.RecordBuilder().BuildInvite()
require.NoError(t, err)
inviteRec := WrapAclRecord(inv.InviteRec)
fx.addRec(t, inviteRec)
// building request join
requestJoin, err := accountAcl.RecordBuilder().BuildRequestJoin(RequestJoinPayload{
InviteRecordId: inviteRec.Id,
InviteKey: inv.InviteKey,
})
require.NoError(t, err)
requestJoinRec := WrapAclRecord(requestJoin)
fx.addRec(t, requestJoinRec)
// building request decline
requestDecline, err := ownerAcl.RecordBuilder().BuildRequestDecline(ownerState.lastRecordId)
require.NoError(t, err)
requestDeclineRec := WrapAclRecord(requestDecline)
fx.addRec(t, requestDeclineRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.Empty(t, ownerState.pendingRequests)
require.Empty(t, accountState.pendingRequests)
}
func TestAclList_Remove(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
newReadKey := crypto.NewAES()
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
Identities: []crypto.PubKey{fx.accountKeys.SignKey.GetPublic()},
ReadKey: newReadKey,
})
require.NoError(t, err)
removeRec := WrapAclRecord(remove)
fx.addRec(t, removeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
require.Nil(t, accountState.userReadKeys[removeRec.Id])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
}
func TestAclList_ReadKeyChange(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
newReadKey := crypto.NewAES()
readKeyChange, err := fx.ownerAcl.RecordBuilder().BuildReadKeyChange(newReadKey)
require.NoError(t, err)
readKeyRec := WrapAclRecord(readKeyChange)
fx.addRec(t, readKeyRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).CanManageAccounts())
require.True(t, ownerState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
require.True(t, accountState.userReadKeys[readKeyRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
readKey, err := ownerState.CurrentReadKey()
require.NoError(t, err)
require.True(t, newReadKey.Equals(readKey))
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
}
func TestAclList_PermissionChange(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Admin))
permissionChange, err := fx.ownerAcl.RecordBuilder().BuildPermissionChange(PermissionChangePayload{
Identity: fx.accountKeys.SignKey.GetPublic(),
Permissions: AclPermissions(aclrecordproto.AclUserPermissions_Writer),
})
require.NoError(t, err)
permissionChangeRec := WrapAclRecord(permissionChange)
fx.addRec(t, permissionChangeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey) == AclPermissions(aclrecordproto.AclUserPermissions_Writer))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
}
func TestAclList_RequestRemove(t *testing.T) {
fx := newFixture(t)
var (
ownerState = fx.ownerAcl.aclState
accountState = fx.accountAcl.aclState
)
fx.inviteAccount(t, AclPermissions(aclrecordproto.AclUserPermissions_Writer))
removeRequest, err := fx.accountAcl.RecordBuilder().BuildRequestRemove()
require.NoError(t, err)
removeRequestRec := WrapAclRecord(removeRequest)
fx.addRec(t, removeRequestRec)
recs := fx.accountAcl.AclState().RemoveRecords()
require.Len(t, recs, 1)
require.True(t, accountState.pubKey.Equals(recs[0].RequestIdentity))
newReadKey := crypto.NewAES()
remove, err := fx.ownerAcl.RecordBuilder().BuildAccountRemove(AccountRemovePayload{
Identities: []crypto.PubKey{recs[0].RequestIdentity},
ReadKey: newReadKey,
})
require.NoError(t, err)
removeRec := WrapAclRecord(remove)
fx.addRec(t, removeRec)
// checking acl state
require.True(t, ownerState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, ownerState.Permissions(accountState.pubKey).NoPermissions())
require.True(t, ownerState.userReadKeys[removeRec.Id].Equals(newReadKey))
require.NotNil(t, ownerState.userReadKeys[fx.ownerAcl.Id()])
require.Equal(t, 0, len(ownerState.pendingRequests))
require.Equal(t, 0, len(accountState.pendingRequests))
require.True(t, accountState.Permissions(ownerState.pubKey).IsOwner())
require.True(t, accountState.Permissions(accountState.pubKey).NoPermissions())
require.Nil(t, accountState.userReadKeys[removeRec.Id])
require.NotNil(t, accountState.userReadKeys[fx.ownerAcl.Id()])
}

View File

@ -2,13 +2,13 @@ package list
import ( import (
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/liststorage" "github.com/anyproto/any-sync/commonspace/object/acl/liststorage"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
) )
func NewTestDerivedAcl(spaceId string, keys *accountdata.AccountKeys) (AclList, error) { func NewTestDerivedAcl(spaceId string, keys *accountdata.AccountKeys) (AclList, error) {
builder := NewAclRecordBuilder("", crypto.NewKeyStorage()) builder := NewAclRecordBuilder("", crypto.NewKeyStorage(), keys, NoOpAcceptorVerifier{})
masterKey, _, err := crypto.GenerateRandomEd25519KeyPair() masterKey, _, err := crypto.GenerateRandomEd25519KeyPair()
if err != nil { if err != nil {
return nil, err return nil, err
@ -21,11 +21,21 @@ func NewTestDerivedAcl(spaceId string, keys *accountdata.AccountKeys) (AclList,
if err != nil { if err != nil {
return nil, err return nil, err
} }
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*aclrecordproto.RawAclRecordWithId{ st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
root, root,
}) })
if err != nil { if err != nil {
return nil, err return nil, err
} }
return BuildAclListWithIdentity(keys, st) return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
}
func NewTestAclWithRoot(keys *accountdata.AccountKeys, root *consensusproto.RawRecordWithId) (AclList, error) {
st, err := liststorage.NewInMemoryAclListStorage(root.Id, []*consensusproto.RawRecordWithId{
root,
})
if err != nil {
return nil, err
}
return BuildAclListWithIdentity(keys, st, NoOpAcceptorVerifier{})
} }

View File

@ -5,12 +5,13 @@
package mock_list package mock_list
import ( import (
context "context"
reflect "reflect" reflect "reflect"
aclrecordproto "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
list "github.com/anyproto/any-sync/commonspace/object/acl/list" list "github.com/anyproto/any-sync/commonspace/object/acl/list"
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
crypto "github.com/anyproto/any-sync/util/crypto" crypto "github.com/anyproto/any-sync/util/crypto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockAclList is a mock of AclList interface. // MockAclList is a mock of AclList interface.
@ -51,12 +52,11 @@ func (mr *MockAclListMockRecorder) AclState() *gomock.Call {
} }
// AddRawRecord mocks base method. // AddRawRecord mocks base method.
func (m *MockAclList) AddRawRecord(arg0 *aclrecordproto.RawAclRecordWithId) (bool, error) { func (m *MockAclList) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0) ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(bool) ret0, _ := ret[0].(error)
ret1, _ := ret[1].(error) return ret0
return ret0, ret1
} }
// AddRawRecord indicates an expected call of AddRawRecord. // AddRawRecord indicates an expected call of AddRawRecord.
@ -65,18 +65,32 @@ func (mr *MockAclListMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockAclList)(nil).AddRawRecord), arg0)
} }
// Close mocks base method. // AddRawRecords mocks base method.
func (m *MockAclList) Close() error { func (m *MockAclList) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close") ret := m.ctrl.Call(m, "AddRawRecords", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecords indicates an expected call of AddRawRecords.
func (mr *MockAclListMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockAclList)(nil).AddRawRecords), arg0)
}
// Close mocks base method.
func (m *MockAclList) Close(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// Close indicates an expected call of Close. // Close indicates an expected call of Close.
func (mr *MockAclListMockRecorder) Close() *gomock.Call { func (mr *MockAclListMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockAclList)(nil).Close), arg0)
} }
// Get mocks base method. // Get mocks base method.
@ -94,6 +108,35 @@ func (mr *MockAclListMockRecorder) Get(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAclList)(nil).Get), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockAclList)(nil).Get), arg0)
} }
// GetIndex mocks base method.
func (m *MockAclList) GetIndex(arg0 int) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIndex", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetIndex indicates an expected call of GetIndex.
func (mr *MockAclListMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockAclList)(nil).GetIndex), arg0)
}
// HasHead mocks base method.
func (m *MockAclList) HasHead(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasHead", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// HasHead indicates an expected call of HasHead.
func (mr *MockAclListMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockAclList)(nil).HasHead), arg0)
}
// Head mocks base method. // Head mocks base method.
func (m *MockAclList) Head() *list.AclRecord { func (m *MockAclList) Head() *list.AclRecord {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -211,6 +254,20 @@ func (mr *MockAclListMockRecorder) RUnlock() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockAclList)(nil).RUnlock)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockAclList)(nil).RUnlock))
} }
// RecordBuilder mocks base method.
func (m *MockAclList) RecordBuilder() list.AclRecordBuilder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordBuilder")
ret0, _ := ret[0].(list.AclRecordBuilder)
return ret0
}
// RecordBuilder indicates an expected call of RecordBuilder.
func (mr *MockAclListMockRecorder) RecordBuilder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockAclList)(nil).RecordBuilder))
}
// Records mocks base method. // Records mocks base method.
func (m *MockAclList) Records() []*list.AclRecord { func (m *MockAclList) Records() []*list.AclRecord {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -225,11 +282,26 @@ func (mr *MockAclListMockRecorder) Records() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockAclList)(nil).Records))
} }
// RecordsAfter mocks base method.
func (m *MockAclList) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RecordsAfter indicates an expected call of RecordsAfter.
func (mr *MockAclListMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockAclList)(nil).RecordsAfter), arg0, arg1)
}
// Root mocks base method. // Root mocks base method.
func (m *MockAclList) Root() *aclrecordproto.RawAclRecordWithId { func (m *MockAclList) Root() *consensusproto.RawRecordWithId {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root") ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
return ret0 return ret0
} }
@ -250,3 +322,17 @@ func (mr *MockAclListMockRecorder) Unlock() *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockAclList)(nil).Unlock)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockAclList)(nil).Unlock))
} }
// ValidateRawRecord mocks base method.
func (m *MockAclList) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
func (mr *MockAclListMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockAclList)(nil).ValidateRawRecord), arg0)
}

View File

@ -8,7 +8,6 @@ import (
type AclRecord struct { type AclRecord struct {
Id string Id string
PrevId string PrevId string
ReadKeyId string
Timestamp int64 Timestamp int64
Data []byte Data []byte
Identity crypto.PubKey Identity crypto.PubKey
@ -16,7 +15,55 @@ type AclRecord struct {
Signature []byte Signature []byte
} }
type AclUserState struct { type RequestRecord struct {
PubKey crypto.PubKey RequestIdentity crypto.PubKey
Permissions aclrecordproto.AclUserPermissions RequestMetadata []byte
Type RequestType
}
type AclUserState struct {
PubKey crypto.PubKey
Permissions AclPermissions
RequestMetadata []byte
}
type RequestType int
const (
RequestTypeRemove RequestType = iota
RequestTypeJoin
)
type AclPermissions aclrecordproto.AclUserPermissions
func (p AclPermissions) NoPermissions() bool {
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_None
}
func (p AclPermissions) IsOwner() bool {
return aclrecordproto.AclUserPermissions(p) == aclrecordproto.AclUserPermissions_Owner
}
func (p AclPermissions) CanWrite() bool {
switch aclrecordproto.AclUserPermissions(p) {
case aclrecordproto.AclUserPermissions_Admin:
return true
case aclrecordproto.AclUserPermissions_Writer:
return true
case aclrecordproto.AclUserPermissions_Owner:
return true
default:
return false
}
}
func (p AclPermissions) CanManageAccounts() bool {
switch aclrecordproto.AclUserPermissions(p) {
case aclrecordproto.AclUserPermissions_Admin:
return true
case aclrecordproto.AclUserPermissions_Owner:
return true
default:
return false
}
} }

View File

@ -0,0 +1,218 @@
package list
import (
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/util/crypto"
)
type ContentValidator interface {
ValidateAclRecordContents(ch *AclRecord) (err error)
ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error)
ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error)
ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error)
ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error)
ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error)
ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error)
ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error)
ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error)
ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error)
}
type contentValidator struct {
keyStore crypto.KeyStorage
aclState *AclState
}
func (c *contentValidator) ValidateAclRecordContents(ch *AclRecord) (err error) {
if ch.PrevId != c.aclState.lastRecordId {
return ErrIncorrectRecordSequence
}
aclData := ch.Model.(*aclrecordproto.AclData)
for _, content := range aclData.AclContent {
err = c.validateAclRecordContent(content, ch.Identity)
if err != nil {
return
}
}
return
}
func (c *contentValidator) validateAclRecordContent(ch *aclrecordproto.AclContentValue, authorIdentity crypto.PubKey) (err error) {
switch {
case ch.GetPermissionChange() != nil:
return c.ValidatePermissionChange(ch.GetPermissionChange(), authorIdentity)
case ch.GetInvite() != nil:
return c.ValidateInvite(ch.GetInvite(), authorIdentity)
case ch.GetInviteRevoke() != nil:
return c.ValidateInviteRevoke(ch.GetInviteRevoke(), authorIdentity)
case ch.GetRequestJoin() != nil:
return c.ValidateRequestJoin(ch.GetRequestJoin(), authorIdentity)
case ch.GetRequestAccept() != nil:
return c.ValidateRequestAccept(ch.GetRequestAccept(), authorIdentity)
case ch.GetRequestDecline() != nil:
return c.ValidateRequestDecline(ch.GetRequestDecline(), authorIdentity)
case ch.GetAccountRemove() != nil:
return c.ValidateAccountRemove(ch.GetAccountRemove(), authorIdentity)
case ch.GetAccountRequestRemove() != nil:
return c.ValidateRequestRemove(ch.GetAccountRequestRemove(), authorIdentity)
case ch.GetReadKeyChange() != nil:
return c.ValidateReadKeyChange(ch.GetReadKeyChange(), authorIdentity)
default:
return ErrUnexpectedContentType
}
}
func (c *contentValidator) ValidatePermissionChange(ch *aclrecordproto.AclAccountPermissionChange, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
chIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return err
}
_, exists := c.aclState.userStates[mapKeyFromPubKey(chIdentity)]
if !exists {
return ErrNoSuchAccount
}
return
}
func (c *contentValidator) ValidateInvite(ch *aclrecordproto.AclAccountInvite, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, err = c.keyStore.PubKeyFromProto(ch.InviteKey)
return
}
func (c *contentValidator) ValidateInviteRevoke(ch *aclrecordproto.AclAccountInviteRevoke, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, exists := c.aclState.inviteKeys[ch.InviteRecordId]
if !exists {
return ErrNoSuchInvite
}
return
}
func (c *contentValidator) ValidateRequestJoin(ch *aclrecordproto.AclAccountRequestJoin, authorIdentity crypto.PubKey) (err error) {
inviteKey, exists := c.aclState.inviteKeys[ch.InviteRecordId]
if !exists {
return ErrNoSuchInvite
}
inviteIdentity, err := c.keyStore.PubKeyFromProto(ch.InviteIdentity)
if err != nil {
return
}
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(inviteIdentity)]; exists {
return ErrPendingRequest
}
if !authorIdentity.Equals(inviteIdentity) {
return ErrIncorrectIdentity
}
rawInviteIdentity, err := inviteIdentity.Raw()
if err != nil {
return err
}
ok, err := inviteKey.Verify(rawInviteIdentity, ch.InviteIdentitySignature)
if err != nil {
return ErrInvalidSignature
}
if !ok {
return ErrInvalidSignature
}
return
}
func (c *contentValidator) ValidateRequestAccept(ch *aclrecordproto.AclAccountRequestAccept, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
record, exists := c.aclState.requestRecords[ch.RequestRecordId]
if !exists {
return ErrNoSuchRequest
}
acceptIdentity, err := c.keyStore.PubKeyFromProto(ch.Identity)
if err != nil {
return
}
if !acceptIdentity.Equals(record.RequestIdentity) {
return ErrIncorrectIdentity
}
if ch.Permissions == aclrecordproto.AclUserPermissions_Owner {
return ErrInsufficientPermissions
}
return
}
func (c *contentValidator) ValidateRequestDecline(ch *aclrecordproto.AclAccountRequestDecline, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
_, exists := c.aclState.requestRecords[ch.RequestRecordId]
if !exists {
return ErrNoSuchRequest
}
return
}
func (c *contentValidator) ValidateAccountRemove(ch *aclrecordproto.AclAccountRemove, authorIdentity crypto.PubKey) (err error) {
if !c.aclState.Permissions(authorIdentity).CanManageAccounts() {
return ErrInsufficientPermissions
}
seenIdentities := map[string]struct{}{}
for _, rawIdentity := range ch.Identities {
identity, err := c.keyStore.PubKeyFromProto(rawIdentity)
if err != nil {
return err
}
if identity.Equals(authorIdentity) {
return ErrInsufficientPermissions
}
permissions := c.aclState.Permissions(identity)
if permissions.NoPermissions() {
return ErrNoSuchAccount
}
if permissions.IsOwner() {
return ErrInsufficientPermissions
}
idKey := mapKeyFromPubKey(identity)
if _, exists := seenIdentities[idKey]; exists {
return ErrDuplicateAccounts
}
seenIdentities[mapKeyFromPubKey(identity)] = struct{}{}
}
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates)-len(ch.Identities))
}
func (c *contentValidator) ValidateRequestRemove(ch *aclrecordproto.AclAccountRequestRemove, authorIdentity crypto.PubKey) (err error) {
if c.aclState.Permissions(authorIdentity).NoPermissions() {
return ErrInsufficientPermissions
}
if _, exists := c.aclState.pendingRequests[mapKeyFromPubKey(authorIdentity)]; exists {
return ErrPendingRequest
}
return
}
func (c *contentValidator) ValidateReadKeyChange(ch *aclrecordproto.AclReadKeyChange, authorIdentity crypto.PubKey) (err error) {
return c.validateAccountReadKeys(ch.AccountKeys, len(c.aclState.userStates))
}
func (c *contentValidator) validateAccountReadKeys(accountKeys []*aclrecordproto.AclEncryptedReadKey, usersNum int) (err error) {
if len(accountKeys) != usersNum {
return ErrIncorrectNumberOfAccounts
}
for _, encKeys := range accountKeys {
identity, err := c.keyStore.PubKeyFromProto(encKeys.Identity)
if err != nil {
return err
}
_, exists := c.aclState.userStates[mapKeyFromPubKey(identity)]
if !exists {
return ErrNoSuchAccount
}
}
return
}

View File

@ -3,24 +3,26 @@ package liststorage
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"sync" "sync"
) )
type inMemoryAclListStorage struct { type inMemoryAclListStorage struct {
id string id string
root *aclrecordproto.RawAclRecordWithId root *consensusproto.RawRecordWithId
head string head string
records map[string]*aclrecordproto.RawAclRecordWithId records map[string]*consensusproto.RawRecordWithId
sync.RWMutex sync.RWMutex
} }
func NewInMemoryAclListStorage( func NewInMemoryAclListStorage(
id string, id string,
records []*aclrecordproto.RawAclRecordWithId) (ListStorage, error) { records []*consensusproto.RawRecordWithId) (ListStorage, error) {
allRecords := make(map[string]*aclrecordproto.RawAclRecordWithId) allRecords := make(map[string]*consensusproto.RawRecordWithId)
for _, ch := range records { for _, ch := range records {
allRecords[ch.Id] = ch allRecords[ch.Id] = ch
} }
@ -41,7 +43,7 @@ func (t *inMemoryAclListStorage) Id() string {
return t.id return t.id
} }
func (t *inMemoryAclListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) { func (t *inMemoryAclListStorage) Root() (*consensusproto.RawRecordWithId, error) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
return t.root, nil return t.root, nil
@ -60,7 +62,7 @@ func (t *inMemoryAclListStorage) SetHead(head string) error {
return nil return nil
} }
func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclrecordproto.RawAclRecordWithId) error { func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *consensusproto.RawRecordWithId) error {
t.Lock() t.Lock()
defer t.Unlock() defer t.Unlock()
// TODO: better to do deep copy // TODO: better to do deep copy
@ -68,7 +70,7 @@ func (t *inMemoryAclListStorage) AddRawRecord(ctx context.Context, record *aclre
return nil return nil
} }
func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*aclrecordproto.RawAclRecordWithId, error) { func (t *inMemoryAclListStorage) GetRawRecord(ctx context.Context, recordId string) (*consensusproto.RawRecordWithId, error) {
t.RLock() t.RLock()
defer t.RUnlock() defer t.RUnlock()
if res, exists := t.records[recordId]; exists { if res, exists := t.records[recordId]; exists {

View File

@ -4,7 +4,8 @@ package liststorage
import ( import (
"context" "context"
"errors" "errors"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
) )
var ( var (
@ -14,15 +15,15 @@ var (
) )
type Exporter interface { type Exporter interface {
ListStorage(root *aclrecordproto.RawAclRecordWithId) (ListStorage, error) ListStorage(root *consensusproto.RawRecordWithId) (ListStorage, error)
} }
type ListStorage interface { type ListStorage interface {
Id() string Id() string
Root() (*aclrecordproto.RawAclRecordWithId, error) Root() (*consensusproto.RawRecordWithId, error)
Head() (string, error) Head() (string, error)
SetHead(headId string) error SetHead(headId string) error
GetRawRecord(ctx context.Context, id string) (*aclrecordproto.RawAclRecordWithId, error) GetRawRecord(ctx context.Context, id string) (*consensusproto.RawRecordWithId, error)
AddRawRecord(ctx context.Context, rec *aclrecordproto.RawAclRecordWithId) error AddRawRecord(ctx context.Context, rec *consensusproto.RawRecordWithId) error
} }

View File

@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
aclrecordproto "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockListStorage is a mock of ListStorage interface. // MockListStorage is a mock of ListStorage interface.
@ -36,7 +36,7 @@ func (m *MockListStorage) EXPECT() *MockListStorageMockRecorder {
} }
// AddRawRecord mocks base method. // AddRawRecord mocks base method.
func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *aclrecordproto.RawAclRecordWithId) error { func (m *MockListStorage) AddRawRecord(arg0 context.Context, arg1 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0, arg1) ret := m.ctrl.Call(m, "AddRawRecord", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
@ -50,10 +50,10 @@ func (mr *MockListStorageMockRecorder) AddRawRecord(arg0, arg1 interface{}) *gom
} }
// GetRawRecord mocks base method. // GetRawRecord mocks base method.
func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*aclrecordproto.RawAclRecordWithId, error) { func (m *MockListStorage) GetRawRecord(arg0 context.Context, arg1 string) (*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetRawRecord", arg0, arg1) ret := m.ctrl.Call(m, "GetRawRecord", arg0, arg1)
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
@ -94,10 +94,10 @@ func (mr *MockListStorageMockRecorder) Id() *gomock.Call {
} }
// Root mocks base method. // Root mocks base method.
func (m *MockListStorage) Root() (*aclrecordproto.RawAclRecordWithId, error) { func (m *MockListStorage) Root() (*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root") ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*aclrecordproto.RawAclRecordWithId) ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }

View File

@ -0,0 +1,120 @@
//go:generate mockgen -destination mock_syncacl/mock_syncacl.go github.com/anyproto/any-sync/commonspace/object/acl/syncacl SyncAcl,SyncClient,RequestFactory,AclSyncProtocol
package syncacl
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/consensus/consensusproto"
"go.uber.org/zap"
)
type AclSyncProtocol interface {
HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error)
FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error)
FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error)
}
type aclSyncProtocol struct {
log logger.CtxLogger
spaceId string
aclList list.AclList
reqFactory RequestFactory
}
func (a *aclSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *consensusproto.LogHeadUpdate) (request *consensusproto.LogSyncMessage, err error) {
isEmptyUpdate := len(update.Records) == 0
log := a.log.With(
zap.String("senderId", senderId),
zap.String("update head", update.Head),
zap.Int("len(update records)", len(update.Records)))
log.DebugCtx(ctx, "received acl head update message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl head update finished with error", zap.Error(err))
} else if request != nil {
cnt := request.Content.GetFullSyncRequest()
log.DebugCtx(ctx, "returning acl full sync request", zap.String("request head", cnt.Head))
} else {
if !isEmptyUpdate {
log.DebugCtx(ctx, "acl head update finished correctly")
}
}
}()
if isEmptyUpdate {
headEquals := a.aclList.Head().Id == update.Head
log.DebugCtx(ctx, "is empty acl head update", zap.Bool("headEquals", headEquals))
if headEquals {
return
}
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
}
if a.aclList.HasHead(update.Head) {
return
}
err = a.aclList.AddRawRecords(update.Records)
if err == list.ErrIncorrectRecordSequence {
return a.reqFactory.CreateFullSyncRequest(a.aclList, update.Head)
}
return
}
func (a *aclSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *consensusproto.LogFullSyncRequest) (response *consensusproto.LogSyncMessage, err error) {
log := a.log.With(
zap.String("senderId", senderId),
zap.String("request head", request.Head),
zap.Int("len(request records)", len(request.Records)))
log.DebugCtx(ctx, "received acl full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl full sync request finished with error", zap.Error(err))
} else if response != nil {
cnt := response.Content.GetFullSyncResponse()
log.DebugCtx(ctx, "acl full sync response sent", zap.String("response head", cnt.Head), zap.Int("len(response records)", len(cnt.Records)))
}
}()
if !a.aclList.HasHead(request.Head) {
if len(request.Records) > 0 {
// in this case we can try to add some records
err = a.aclList.AddRawRecords(request.Records)
if err != nil {
return
}
} else {
// here it is impossible for us to do anything, we can't return records after head as defined in request, because we don't have it
return nil, list.ErrIncorrectRecordSequence
}
}
return a.reqFactory.CreateFullSyncResponse(a.aclList, request.Head)
}
func (a *aclSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *consensusproto.LogFullSyncResponse) (err error) {
log := a.log.With(
zap.String("senderId", senderId),
zap.String("response head", response.Head),
zap.Int("len(response records)", len(response.Records)))
log.DebugCtx(ctx, "received acl full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "acl full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "acl full sync response succeeded")
}
}()
if a.aclList.HasHead(response.Head) {
return
}
return a.aclList.AddRawRecords(response.Records)
}
func newAclSyncProtocol(spaceId string, aclList list.AclList, reqFactory RequestFactory) *aclSyncProtocol {
return &aclSyncProtocol{
log: log.With(zap.String("spaceId", spaceId), zap.String("aclId", aclList.Id())),
spaceId: spaceId,
aclList: aclList,
reqFactory: reqFactory,
}
}

View File

@ -0,0 +1,213 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing"
)
type aclSyncProtocolFixture struct {
log logger.CtxLogger
spaceId string
senderId string
aclId string
aclMock *mock_list.MockAclList
reqFactory *mock_syncacl.MockRequestFactory
ctrl *gomock.Controller
syncProtocol AclSyncProtocol
}
func newSyncProtocolFixture(t *testing.T) *aclSyncProtocolFixture {
ctrl := gomock.NewController(t)
aclList := mock_list.NewMockAclList(ctrl)
spaceId := "spaceId"
reqFactory := mock_syncacl.NewMockRequestFactory(ctrl)
aclList.EXPECT().Id().Return("aclId")
syncProtocol := newAclSyncProtocol(spaceId, aclList, reqFactory)
return &aclSyncProtocolFixture{
log: log,
spaceId: spaceId,
senderId: "senderId",
aclId: "aclId",
aclMock: aclList,
reqFactory: reqFactory,
ctrl: ctrl,
syncProtocol: syncProtocol,
}
}
func (fx *aclSyncProtocolFixture) stop() {
fx.ctrl.Finish()
}
func TestHeadUpdate(t *testing.T) {
ctx := context.Background()
fullRequest := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncRequest{
FullSyncRequest: &consensusproto.LogFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update results in full request", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(headUpdate.Records).Return(list.ErrIncorrectRecordSequence)
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Equal(t, fullRequest, req)
require.NoError(t, err)
})
t.Run("head update old heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update empty equals", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h1"})
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Nil(t, req)
require.NoError(t, err)
})
t.Run("head update empty results in full request", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().Head().Return(&list.AclRecord{Id: "h2"})
fx.reqFactory.EXPECT().CreateFullSyncRequest(fx.aclMock, headUpdate.Head).Return(fullRequest, nil)
req, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.Equal(t, fullRequest, req)
require.NoError(t, err)
})
}
func TestFullSyncRequest(t *testing.T) {
ctx := context.Background()
fullResponse := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
FullSyncResponse: &consensusproto.LogFullSyncResponse{},
},
},
}
t.Run("full sync request non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(fullRequest.Records).Return(nil)
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Equal(t, fullResponse, resp)
require.NoError(t, err)
})
t.Run("full sync request non empty head exists", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
fx.reqFactory.EXPECT().CreateFullSyncResponse(fx.aclMock, fullRequest.Head).Return(fullResponse, nil)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Equal(t, fullResponse, resp)
require.NoError(t, err)
})
t.Run("full sync request empty head not exists", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
resp, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullRequest)
require.Nil(t, resp)
require.Error(t, list.ErrIncorrectRecordSequence, err)
})
}
func TestFullSyncResponse(t *testing.T) {
ctx := context.Background()
t.Run("full sync response no heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(false)
fx.aclMock.EXPECT().AddRawRecords(fullResponse.Records).Return(nil)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
require.NoError(t, err)
})
t.Run("full sync response has heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.aclMock.EXPECT().HasHead("h1").Return(true)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullResponse)
require.NoError(t, err)
})
}

View File

@ -0,0 +1,5 @@
package headupdater
type HeadUpdater interface {
UpdateHeads(id string, heads []string)
}

View File

@ -0,0 +1,694 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/object/acl/syncacl (interfaces: SyncAcl,SyncClient,RequestFactory,AclSyncProtocol)
// Package mock_syncacl is a generated GoMock package.
package mock_syncacl
import (
context "context"
reflect "reflect"
app "github.com/anyproto/any-sync/app"
list "github.com/anyproto/any-sync/commonspace/object/acl/list"
headupdater "github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
consensusproto "github.com/anyproto/any-sync/consensus/consensusproto"
crypto "github.com/anyproto/any-sync/util/crypto"
gomock "go.uber.org/mock/gomock"
)
// MockSyncAcl is a mock of SyncAcl interface.
type MockSyncAcl struct {
ctrl *gomock.Controller
recorder *MockSyncAclMockRecorder
}
// MockSyncAclMockRecorder is the mock recorder for MockSyncAcl.
type MockSyncAclMockRecorder struct {
mock *MockSyncAcl
}
// NewMockSyncAcl creates a new mock instance.
func NewMockSyncAcl(ctrl *gomock.Controller) *MockSyncAcl {
mock := &MockSyncAcl{ctrl: ctrl}
mock.recorder = &MockSyncAclMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncAcl) EXPECT() *MockSyncAclMockRecorder {
return m.recorder
}
// AclState mocks base method.
func (m *MockSyncAcl) AclState() *list.AclState {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AclState")
ret0, _ := ret[0].(*list.AclState)
return ret0
}
// AclState indicates an expected call of AclState.
func (mr *MockSyncAclMockRecorder) AclState() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AclState", reflect.TypeOf((*MockSyncAcl)(nil).AclState))
}
// AddRawRecord mocks base method.
func (m *MockSyncAcl) AddRawRecord(arg0 *consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecord indicates an expected call of AddRawRecord.
func (mr *MockSyncAclMockRecorder) AddRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecord), arg0)
}
// AddRawRecords mocks base method.
func (m *MockSyncAcl) AddRawRecords(arg0 []*consensusproto.RawRecordWithId) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddRawRecords", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// AddRawRecords indicates an expected call of AddRawRecords.
func (mr *MockSyncAclMockRecorder) AddRawRecords(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddRawRecords", reflect.TypeOf((*MockSyncAcl)(nil).AddRawRecords), arg0)
}
// Close mocks base method.
func (m *MockSyncAcl) Close(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Close", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Close indicates an expected call of Close.
func (mr *MockSyncAclMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockSyncAcl)(nil).Close), arg0)
}
// Get mocks base method.
func (m *MockSyncAcl) Get(arg0 string) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Get", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Get indicates an expected call of Get.
func (mr *MockSyncAclMockRecorder) Get(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Get", reflect.TypeOf((*MockSyncAcl)(nil).Get), arg0)
}
// GetIndex mocks base method.
func (m *MockSyncAcl) GetIndex(arg0 int) (*list.AclRecord, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetIndex", arg0)
ret0, _ := ret[0].(*list.AclRecord)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetIndex indicates an expected call of GetIndex.
func (mr *MockSyncAclMockRecorder) GetIndex(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetIndex", reflect.TypeOf((*MockSyncAcl)(nil).GetIndex), arg0)
}
// HandleMessage mocks base method.
func (m *MockSyncAcl) HandleMessage(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HandleMessage indicates an expected call of HandleMessage.
func (mr *MockSyncAclMockRecorder) HandleMessage(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncAcl)(nil).HandleMessage), arg0, arg1, arg2)
}
// HandleRequest mocks base method.
func (m *MockSyncAcl) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncAclMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncAcl)(nil).HandleRequest), arg0, arg1, arg2)
}
// HasHead mocks base method.
func (m *MockSyncAcl) HasHead(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HasHead", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// HasHead indicates an expected call of HasHead.
func (mr *MockSyncAclMockRecorder) HasHead(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HasHead", reflect.TypeOf((*MockSyncAcl)(nil).HasHead), arg0)
}
// Head mocks base method.
func (m *MockSyncAcl) Head() *list.AclRecord {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Head")
ret0, _ := ret[0].(*list.AclRecord)
return ret0
}
// Head indicates an expected call of Head.
func (mr *MockSyncAclMockRecorder) Head() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Head", reflect.TypeOf((*MockSyncAcl)(nil).Head))
}
// Id mocks base method.
func (m *MockSyncAcl) Id() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Id")
ret0, _ := ret[0].(string)
return ret0
}
// Id indicates an expected call of Id.
func (mr *MockSyncAclMockRecorder) Id() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Id", reflect.TypeOf((*MockSyncAcl)(nil).Id))
}
// Init mocks base method.
func (m *MockSyncAcl) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockSyncAclMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockSyncAcl)(nil).Init), arg0)
}
// IsAfter mocks base method.
func (m *MockSyncAcl) IsAfter(arg0, arg1 string) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsAfter", arg0, arg1)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsAfter indicates an expected call of IsAfter.
func (mr *MockSyncAclMockRecorder) IsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsAfter", reflect.TypeOf((*MockSyncAcl)(nil).IsAfter), arg0, arg1)
}
// Iterate mocks base method.
func (m *MockSyncAcl) Iterate(arg0 func(*list.AclRecord) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Iterate", arg0)
}
// Iterate indicates an expected call of Iterate.
func (mr *MockSyncAclMockRecorder) Iterate(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Iterate", reflect.TypeOf((*MockSyncAcl)(nil).Iterate), arg0)
}
// IterateFrom mocks base method.
func (m *MockSyncAcl) IterateFrom(arg0 string, arg1 func(*list.AclRecord) bool) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "IterateFrom", arg0, arg1)
}
// IterateFrom indicates an expected call of IterateFrom.
func (mr *MockSyncAclMockRecorder) IterateFrom(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IterateFrom", reflect.TypeOf((*MockSyncAcl)(nil).IterateFrom), arg0, arg1)
}
// KeyStorage mocks base method.
func (m *MockSyncAcl) KeyStorage() crypto.KeyStorage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "KeyStorage")
ret0, _ := ret[0].(crypto.KeyStorage)
return ret0
}
// KeyStorage indicates an expected call of KeyStorage.
func (mr *MockSyncAclMockRecorder) KeyStorage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "KeyStorage", reflect.TypeOf((*MockSyncAcl)(nil).KeyStorage))
}
// Lock mocks base method.
func (m *MockSyncAcl) Lock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Lock")
}
// Lock indicates an expected call of Lock.
func (mr *MockSyncAclMockRecorder) Lock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Lock", reflect.TypeOf((*MockSyncAcl)(nil).Lock))
}
// Name mocks base method.
func (m *MockSyncAcl) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockSyncAclMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockSyncAcl)(nil).Name))
}
// RLock mocks base method.
func (m *MockSyncAcl) RLock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RLock")
}
// RLock indicates an expected call of RLock.
func (mr *MockSyncAclMockRecorder) RLock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RLock", reflect.TypeOf((*MockSyncAcl)(nil).RLock))
}
// RUnlock mocks base method.
func (m *MockSyncAcl) RUnlock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "RUnlock")
}
// RUnlock indicates an expected call of RUnlock.
func (mr *MockSyncAclMockRecorder) RUnlock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RUnlock", reflect.TypeOf((*MockSyncAcl)(nil).RUnlock))
}
// RecordBuilder mocks base method.
func (m *MockSyncAcl) RecordBuilder() list.AclRecordBuilder {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordBuilder")
ret0, _ := ret[0].(list.AclRecordBuilder)
return ret0
}
// RecordBuilder indicates an expected call of RecordBuilder.
func (mr *MockSyncAclMockRecorder) RecordBuilder() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordBuilder", reflect.TypeOf((*MockSyncAcl)(nil).RecordBuilder))
}
// Records mocks base method.
func (m *MockSyncAcl) Records() []*list.AclRecord {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Records")
ret0, _ := ret[0].([]*list.AclRecord)
return ret0
}
// Records indicates an expected call of Records.
func (mr *MockSyncAclMockRecorder) Records() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Records", reflect.TypeOf((*MockSyncAcl)(nil).Records))
}
// RecordsAfter mocks base method.
func (m *MockSyncAcl) RecordsAfter(arg0 context.Context, arg1 string) ([]*consensusproto.RawRecordWithId, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RecordsAfter", arg0, arg1)
ret0, _ := ret[0].([]*consensusproto.RawRecordWithId)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RecordsAfter indicates an expected call of RecordsAfter.
func (mr *MockSyncAclMockRecorder) RecordsAfter(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RecordsAfter", reflect.TypeOf((*MockSyncAcl)(nil).RecordsAfter), arg0, arg1)
}
// Root mocks base method.
func (m *MockSyncAcl) Root() *consensusproto.RawRecordWithId {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Root")
ret0, _ := ret[0].(*consensusproto.RawRecordWithId)
return ret0
}
// Root indicates an expected call of Root.
func (mr *MockSyncAclMockRecorder) Root() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Root", reflect.TypeOf((*MockSyncAcl)(nil).Root))
}
// Run mocks base method.
func (m *MockSyncAcl) Run(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
func (mr *MockSyncAclMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockSyncAcl)(nil).Run), arg0)
}
// SetHeadUpdater mocks base method.
func (m *MockSyncAcl) SetHeadUpdater(arg0 headupdater.HeadUpdater) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "SetHeadUpdater", arg0)
}
// SetHeadUpdater indicates an expected call of SetHeadUpdater.
func (mr *MockSyncAclMockRecorder) SetHeadUpdater(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetHeadUpdater", reflect.TypeOf((*MockSyncAcl)(nil).SetHeadUpdater), arg0)
}
// SyncWithPeer mocks base method.
func (m *MockSyncAcl) SyncWithPeer(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncWithPeer", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SyncWithPeer indicates an expected call of SyncWithPeer.
func (mr *MockSyncAclMockRecorder) SyncWithPeer(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncWithPeer", reflect.TypeOf((*MockSyncAcl)(nil).SyncWithPeer), arg0, arg1)
}
// Unlock mocks base method.
func (m *MockSyncAcl) Unlock() {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Unlock")
}
// Unlock indicates an expected call of Unlock.
func (mr *MockSyncAclMockRecorder) Unlock() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Unlock", reflect.TypeOf((*MockSyncAcl)(nil).Unlock))
}
// ValidateRawRecord mocks base method.
func (m *MockSyncAcl) ValidateRawRecord(arg0 *consensusproto.RawRecord) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ValidateRawRecord", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// ValidateRawRecord indicates an expected call of ValidateRawRecord.
func (mr *MockSyncAclMockRecorder) ValidateRawRecord(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateRawRecord", reflect.TypeOf((*MockSyncAcl)(nil).ValidateRawRecord), arg0)
}
// MockSyncClient is a mock of SyncClient interface.
type MockSyncClient struct {
ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder
}
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient.
type MockSyncClientMockRecorder struct {
mock *MockSyncClient
}
// NewMockSyncClient creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient {
mock := &MockSyncClient{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder {
return m.recorder
}
// Broadcast mocks base method.
func (m *MockSyncClient) Broadcast(arg0 *consensusproto.LogSyncMessage) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0)
}
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1)
}
// QueueRequest mocks base method.
func (m *MockSyncClient) QueueRequest(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// QueueRequest indicates an expected call of QueueRequest.
func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1)
}
// SendRequest mocks base method.
func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendRequest indicates an expected call of SendRequest.
func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2)
}
// SendUpdate mocks base method.
func (m *MockSyncClient) SendUpdate(arg0 string, arg1 *consensusproto.LogSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// SendUpdate indicates an expected call of SendUpdate.
func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1)
}
// MockRequestFactory is a mock of RequestFactory interface.
type MockRequestFactory struct {
ctrl *gomock.Controller
recorder *MockRequestFactoryMockRecorder
}
// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory.
type MockRequestFactoryMockRecorder struct {
mock *MockRequestFactory
}
// NewMockRequestFactory creates a new mock instance.
func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory {
mock := &MockRequestFactory{ctrl: ctrl}
mock.recorder = &MockRequestFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder {
return m.recorder
}
// CreateFullSyncRequest mocks base method.
func (m *MockRequestFactory) CreateFullSyncRequest(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1)
}
// CreateFullSyncResponse mocks base method.
func (m *MockRequestFactory) CreateFullSyncResponse(arg0 list.AclList, arg1 string) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1)
}
// CreateHeadUpdate mocks base method.
func (m *MockRequestFactory) CreateHeadUpdate(arg0 list.AclList, arg1 []*consensusproto.RawRecordWithId) *consensusproto.LogSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1)
}
// MockAclSyncProtocol is a mock of AclSyncProtocol interface.
type MockAclSyncProtocol struct {
ctrl *gomock.Controller
recorder *MockAclSyncProtocolMockRecorder
}
// MockAclSyncProtocolMockRecorder is the mock recorder for MockAclSyncProtocol.
type MockAclSyncProtocolMockRecorder struct {
mock *MockAclSyncProtocol
}
// NewMockAclSyncProtocol creates a new mock instance.
func NewMockAclSyncProtocol(ctrl *gomock.Controller) *MockAclSyncProtocol {
mock := &MockAclSyncProtocol{ctrl: ctrl}
mock.recorder = &MockAclSyncProtocolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockAclSyncProtocol) EXPECT() *MockAclSyncProtocolMockRecorder {
return m.recorder
}
// FullSyncRequest mocks base method.
func (m *MockAclSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncRequest) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FullSyncRequest indicates an expected call of FullSyncRequest.
func (mr *MockAclSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2)
}
// FullSyncResponse mocks base method.
func (m *MockAclSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *consensusproto.LogFullSyncResponse) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// FullSyncResponse indicates an expected call of FullSyncResponse.
func (mr *MockAclSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockAclSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2)
}
// HeadUpdate mocks base method.
func (m *MockAclSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *consensusproto.LogHeadUpdate) (*consensusproto.LogSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(*consensusproto.LogSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HeadUpdate indicates an expected call of HeadUpdate.
func (mr *MockAclSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockAclSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2)
}

View File

@ -0,0 +1,54 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/consensus/consensusproto"
)
type RequestFactory interface {
CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage)
CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error)
CreateFullSyncResponse(l list.AclList, theirHead string) (*consensusproto.LogSyncMessage, error)
}
type requestFactory struct{}
func NewRequestFactory() RequestFactory {
return &requestFactory{}
}
func (r *requestFactory) CreateHeadUpdate(l list.AclList, added []*consensusproto.RawRecordWithId) (msg *consensusproto.LogSyncMessage) {
return consensusproto.WrapHeadUpdate(&consensusproto.LogHeadUpdate{
Head: l.Head().Id,
Records: added,
}, l.Root())
}
func (r *requestFactory) CreateFullSyncRequest(l list.AclList, theirHead string) (req *consensusproto.LogSyncMessage, err error) {
if !l.HasHead(theirHead) {
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
Head: l.Head().Id,
}, l.Root()), nil
}
records, err := l.RecordsAfter(context.Background(), theirHead)
if err != nil {
return
}
return consensusproto.WrapFullRequest(&consensusproto.LogFullSyncRequest{
Head: l.Head().Id,
Records: records,
}, l.Root()), nil
}
func (r *requestFactory) CreateFullSyncResponse(l list.AclList, theirHead string) (resp *consensusproto.LogSyncMessage, err error) {
records, err := l.RecordsAfter(context.Background(), theirHead)
if err != nil {
return
}
return consensusproto.WrapFullResponse(&consensusproto.LogFullSyncResponse{
Head: l.Head().Id,
Records: records,
}, l.Root()), nil
}

View File

@ -1,21 +1,130 @@
package syncacl package syncacl
import ( import (
"context"
"errors"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/headupdater"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
) )
type SyncAcl struct { const CName = "common.acl.syncacl"
var (
log = logger.NewNamed(CName)
ErrSyncAclClosed = errors.New("sync acl is closed")
)
type SyncAcl interface {
app.ComponentRunnable
list.AclList list.AclList
synchandler.SyncHandler syncobjectgetter.SyncObject
messagePool objectsync.MessagePool SetHeadUpdater(updater headupdater.HeadUpdater)
SyncWithPeer(ctx context.Context, peerId string) (err error)
} }
func NewSyncAcl(aclList list.AclList, messagePool objectsync.MessagePool) *SyncAcl { func New() SyncAcl {
return &SyncAcl{ return &syncAcl{}
AclList: aclList, }
SyncHandler: nil,
messagePool: messagePool, type syncAcl struct {
} list.AclList
syncClient SyncClient
syncHandler synchandler.SyncHandler
headUpdater headupdater.HeadUpdater
isClosed bool
}
func (s *syncAcl) Run(ctx context.Context) (err error) {
return
}
func (s *syncAcl) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
return s.syncHandler.HandleRequest(ctx, senderId, request)
}
func (s *syncAcl) SetHeadUpdater(updater headupdater.HeadUpdater) {
s.headUpdater = updater
}
func (s *syncAcl) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
return s.syncHandler.HandleMessage(ctx, senderId, request)
}
func (s *syncAcl) Init(a *app.App) (err error) {
storage := a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
aclStorage, err := storage.AclStorage()
if err != nil {
return err
}
acc := a.MustComponent(accountservice.CName).(accountservice.Service)
s.AclList, err = list.BuildAclListWithIdentity(acc.Account(), aclStorage, list.NoOpAcceptorVerifier{})
if err != nil {
return
}
spaceId := storage.Id()
requestManager := a.MustComponent(requestmanager.CName).(requestmanager.RequestManager)
peerManager := a.MustComponent(peermanager.CName).(peermanager.PeerManager)
syncStatus := a.MustComponent(syncstatus.CName).(syncstatus.StatusService)
s.syncClient = NewSyncClient(spaceId, requestManager, peerManager)
s.syncHandler = newSyncAclHandler(storage.Id(), s, s.syncClient, syncStatus)
return err
}
func (s *syncAcl) AddRawRecord(rawRec *consensusproto.RawRecordWithId) (err error) {
if s.isClosed {
return ErrSyncAclClosed
}
err = s.AclList.AddRawRecord(rawRec)
if err != nil {
return
}
headUpdate := s.syncClient.CreateHeadUpdate(s, []*consensusproto.RawRecordWithId{rawRec})
s.headUpdater.UpdateHeads(s.Id(), []string{rawRec.Id})
s.syncClient.Broadcast(headUpdate)
return
}
func (s *syncAcl) AddRawRecords(rawRecords []*consensusproto.RawRecordWithId) (err error) {
if s.isClosed {
return ErrSyncAclClosed
}
err = s.AclList.AddRawRecords(rawRecords)
if err != nil {
return
}
headUpdate := s.syncClient.CreateHeadUpdate(s, rawRecords)
s.headUpdater.UpdateHeads(s.Id(), []string{rawRecords[len(rawRecords)-1].Id})
s.syncClient.Broadcast(headUpdate)
return
}
func (s *syncAcl) SyncWithPeer(ctx context.Context, peerId string) (err error) {
s.Lock()
defer s.Unlock()
headUpdate := s.syncClient.CreateHeadUpdate(s, nil)
return s.syncClient.SendUpdate(peerId, headUpdate)
}
func (s *syncAcl) Close(ctx context.Context) (err error) {
s.Lock()
defer s.Unlock()
s.isClosed = true
return
}
func (s *syncAcl) Name() (name string) {
return CName
} }

View File

@ -2,30 +2,81 @@ package syncacl
import ( import (
"context" "context"
"fmt" "errors"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/gogo/protobuf/proto"
)
var (
ErrMessageIsRequest = errors.New("message is request")
ErrMessageIsNotRequest = errors.New("message is not request")
) )
type syncAclHandler struct { type syncAclHandler struct {
acl list.AclList aclList list.AclList
syncClient SyncClient
syncProtocol AclSyncProtocol
syncStatus syncstatus.StatusUpdater
spaceId string
} }
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, req *spacesyncproto.ObjectSyncMessage) (err error) { func newSyncAclHandler(spaceId string, aclList list.AclList, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler {
aclMsg := &aclrecordproto.AclSyncMessage{} return &syncAclHandler{
if err = aclMsg.Unmarshal(req.Payload); err != nil { aclList: aclList,
syncClient: syncClient,
syncProtocol: newAclSyncProtocol(spaceId, aclList, syncClient),
syncStatus: syncStatus,
spaceId: spaceId,
}
}
func (s *syncAclHandler) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) {
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(message.Payload, unmarshalled)
if err != nil {
return return
} }
content := aclMsg.GetContent() content := unmarshalled.GetContent()
head := consensusproto.GetHead(unmarshalled)
s.syncStatus.HeadsReceive(senderId, s.aclList.Id(), []string{head})
s.aclList.Lock()
defer s.aclList.Unlock()
switch { switch {
case content.GetAddRecords() != nil: case content.GetHeadUpdate() != nil:
return s.handleAddRecords(ctx, senderId, content.GetAddRecords()) var syncReq *consensusproto.LogSyncMessage
default: syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate())
return fmt.Errorf("unexpected aclSync message: %T", content.Value) if err != nil || syncReq == nil {
return
}
return s.syncClient.QueueRequest(senderId, syncReq)
case content.GetFullSyncRequest() != nil:
return ErrMessageIsRequest
case content.GetFullSyncResponse() != nil:
return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
} }
}
func (s *syncAclHandler) handleAddRecords(ctx context.Context, senderId string, addRecord *aclrecordproto.AclAddRecords) (err error) {
return return
} }
func (s *syncAclHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(request.Payload, unmarshalled)
if err != nil {
return
}
fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest()
if fullSyncRequest == nil {
return nil, ErrMessageIsNotRequest
}
s.aclList.Lock()
defer s.aclList.Unlock()
aclResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest)
if err != nil {
return
}
return spacesyncproto.MarshallSyncMessage(aclResp, s.spaceId, s.aclList.Id())
}

View File

@ -0,0 +1,233 @@
package syncacl
import (
"context"
"fmt"
"github.com/anyproto/any-sync/commonspace/object/acl/list/mock_list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl/mock_syncacl"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sync"
"testing"
)
type testAclMock struct {
*mock_list.MockAclList
m sync.RWMutex
}
func newTestAclMock(mockAcl *mock_list.MockAclList) *testAclMock {
return &testAclMock{
MockAclList: mockAcl,
}
}
func (t *testAclMock) Lock() {
t.m.Lock()
}
func (t *testAclMock) RLock() {
t.m.RLock()
}
func (t *testAclMock) Unlock() {
t.m.Unlock()
}
func (t *testAclMock) RUnlock() {
t.m.RUnlock()
}
func (t *testAclMock) TryLock() bool {
return t.m.TryLock()
}
func (t *testAclMock) TryRLock() bool {
return t.m.TryRLock()
}
type syncHandlerFixture struct {
ctrl *gomock.Controller
syncClientMock *mock_syncacl.MockSyncClient
aclMock *testAclMock
syncProtocolMock *mock_syncacl.MockAclSyncProtocol
spaceId string
senderId string
aclId string
syncHandler *syncAclHandler
}
func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
ctrl := gomock.NewController(t)
aclMock := newTestAclMock(mock_list.NewMockAclList(ctrl))
syncClientMock := mock_syncacl.NewMockSyncClient(ctrl)
syncProtocolMock := mock_syncacl.NewMockAclSyncProtocol(ctrl)
spaceId := "spaceId"
syncHandler := &syncAclHandler{
aclList: aclMock,
syncClient: syncClientMock,
syncProtocol: syncProtocolMock,
syncStatus: syncstatus.NewNoOpSyncStatus(),
spaceId: spaceId,
}
return &syncHandlerFixture{
ctrl: ctrl,
syncClientMock: syncClientMock,
aclMock: aclMock,
syncProtocolMock: syncProtocolMock,
spaceId: spaceId,
senderId: "senderId",
aclId: "aclId",
syncHandler: syncHandler,
}
}
func (fx *syncHandlerFixture) stop() {
fx.ctrl.Finish()
}
func TestSyncAclHandler_HandleMessage(t *testing.T) {
ctx := context.Background()
t.Run("handle head update, request returned", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
syncReq := &consensusproto.LogSyncMessage{}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil)
fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, syncReq).Return(nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("handle head update, no request", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("handle head update, returned error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
expectedErr := fmt.Errorf("some error")
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, expectedErr)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.Error(t, expectedErr, err)
})
t.Run("handle full sync request is forbidden", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.Error(t, ErrMessageIsRequest, err)
})
t.Run("handle full sync response, no error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullResponse := &consensusproto.LogFullSyncResponse{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullResponse(fullResponse, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil)
err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
}
func TestSyncAclHandler_HandleRequest(t *testing.T) {
ctx := context.Background()
t.Run("handle full sync request, no error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
fullRequest := &consensusproto.LogFullSyncRequest{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fullResp := &consensusproto.LogSyncMessage{
Content: &consensusproto.LogSyncContentValue{
Value: &consensusproto.LogSyncContentValue_FullSyncResponse{
FullSyncResponse: &consensusproto.LogFullSyncResponse{
Head: "returnedHead",
},
},
},
}
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(fullResp, nil)
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
unmarshalled := &consensusproto.LogSyncMessage{}
err = proto.Unmarshal(res.Payload, unmarshalled)
if err != nil {
return
}
require.Equal(t, "returnedHead", consensusproto.GetHead(unmarshalled))
})
t.Run("handle other message returns error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
chWithId := &consensusproto.RawRecordWithId{}
headUpdate := &consensusproto.LogHeadUpdate{
Head: "h1",
Records: []*consensusproto.RawRecordWithId{chWithId},
}
logMessage := consensusproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(logMessage, fx.spaceId, fx.aclId)
fx.aclMock.EXPECT().Id().AnyTimes().Return(fx.aclId)
_, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.Error(t, ErrMessageIsNotRequest, err)
})
}

View File

@ -0,0 +1,70 @@
package syncacl
import (
"context"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(msg *consensusproto.LogSyncMessage)
SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error)
QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error)
SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type syncClient struct {
RequestFactory
spaceId string
requestManager requestmanager.RequestManager
peerManager peermanager.PeerManager
}
func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient {
return &syncClient{
RequestFactory: &requestFactory{},
spaceId: spaceId,
requestManager: requestManager,
peerManager: peerManager,
}
}
func (s *syncClient) Broadcast(msg *consensusproto.LogSyncMessage) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
err = s.peerManager.Broadcast(context.Background(), objMsg)
if err != nil {
log.Debug("broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendUpdate(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.peerManager.SendPeer(context.Background(), peerId, objMsg)
}
func (s *syncClient) SendRequest(ctx context.Context, peerId string, msg *consensusproto.LogSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.requestManager.SendRequest(ctx, peerId, objMsg)
}
func (s *syncClient) QueueRequest(peerId string, msg *consensusproto.LogSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.Id)
if err != nil {
return
}
return s.requestManager.QueueRequest(peerId, objMsg)
}

View File

@ -15,7 +15,7 @@ type TreeImportParams struct {
} }
func ImportHistoryTree(params TreeImportParams) (tree objecttree.ReadableObjectTree, err error) { func ImportHistoryTree(params TreeImportParams) (tree objecttree.ReadableObjectTree, err error) {
aclList, err := list.BuildAclList(params.ListStorage) aclList, err := list.BuildAclList(params.ListStorage, list.NoOpAcceptorVerifier{})
if err != nil { if err != nil {
return return
} }

View File

@ -13,7 +13,7 @@ import (
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockObjectTree is a mock of ObjectTree interface. // MockObjectTree is a mock of ObjectTree interface.

View File

@ -4,11 +4,11 @@ package objecttree
import ( import (
"context" "context"
"errors" "errors"
"github.com/anyproto/any-sync/util/crypto"
"sync" "sync"
"time" "time"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
@ -248,9 +248,7 @@ func (ot *objectTree) prepareBuilderContent(content SignableChangeContent) (cnt
pubKey = content.Key.GetPublic() pubKey = content.Key.GetPublic()
readKeyId string readKeyId string
) )
canWrite := state.HasPermission(pubKey, aclrecordproto.AclUserPermissions_Writer) || if !state.Permissions(pubKey).CanWrite() {
state.HasPermission(pubKey, aclrecordproto.AclUserPermissions_Admin)
if !canWrite {
err = list.ErrInsufficientPermissions err = list.ErrInsufficientPermissions
return return
} }

View File

@ -3,6 +3,9 @@ package objecttree
import ( import (
"context" "context"
"fmt" "fmt"
"testing"
"time"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
@ -10,8 +13,6 @@ import (
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing"
"time"
) )
type testTreeContext struct { type testTreeContext struct {
@ -123,6 +124,7 @@ func TestObjectTree(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.GreaterOrEqual(t, start.Unix(), ch.Timestamp) require.GreaterOrEqual(t, start.Unix(), ch.Timestamp)
require.LessOrEqual(t, end.Unix(), ch.Timestamp) require.LessOrEqual(t, end.Unix(), ch.Timestamp)
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
}) })
t.Run("timestamp is set correctly", func(t *testing.T) { t.Run("timestamp is set correctly", func(t *testing.T) {
someTs := time.Now().Add(time.Hour).Unix() someTs := time.Now().Add(time.Hour).Unix()
@ -139,6 +141,7 @@ func TestObjectTree(t *testing.T) {
ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true) ch, err := oTree.(*objectTree).changeBuilder.Unmarshall(res.Added[0], true)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, ch.Timestamp, someTs) require.Equal(t, ch.Timestamp, someTs)
require.Equal(t, res.Added[0].Id, oTree.(*objectTree).tree.lastIteratedHeadId)
}) })
}) })

View File

@ -3,7 +3,7 @@ package objecttree
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/slice" "github.com/anyproto/any-sync/util/slice"
@ -52,20 +52,18 @@ func (v *objectTreeValidator) ValidateNewChanges(tree *Tree, aclList list.AclLis
func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c *Change) (err error) { func (v *objectTreeValidator) validateChange(tree *Tree, aclList list.AclList, c *Change) (err error) {
var ( var (
perm list.AclUserState userState list.AclUserState
state = aclList.AclState() state = aclList.AclState()
) )
// checking if the user could write // checking if the user could write
perm, err = state.StateAtRecord(c.AclHeadId, c.Identity) userState, err = state.StateAtRecord(c.AclHeadId, c.Identity)
if err != nil { if err != nil {
return return
} }
if !userState.Permissions.CanWrite() {
if perm.Permissions != aclrecordproto.AclUserPermissions_Writer && perm.Permissions != aclrecordproto.AclUserPermissions_Admin {
err = list.ErrInsufficientPermissions err = list.ErrInsufficientPermissions
return return
} }
if c.Id == tree.RootId() { if c.Id == tree.RootId() {
return return
} }

View File

@ -2,10 +2,11 @@ package objecttree
import ( import (
"context" "context"
"time"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/util/slice" "github.com/anyproto/any-sync/util/slice"
"time"
) )
type rawChangeLoader struct { type rawChangeLoader struct {
@ -22,6 +23,7 @@ type rawCacheEntry struct {
change *Change change *Change
rawChange *treechangeproto.RawTreeChangeWithId rawChange *treechangeproto.RawTreeChangeWithId
position int position int
removed bool
} }
func newStorageLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader { func newStorageLoader(treeStorage treestorage.TreeStorage, changeBuilder ChangeBuilder) *rawChangeLoader {
@ -126,7 +128,6 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
if err != nil { if err != nil {
continue continue
} }
entry.position = -1
r.cache[b] = entry r.cache[b] = entry
existingBreakpoints = append(existingBreakpoints, b) existingBreakpoints = append(existingBreakpoints, b)
} }
@ -135,8 +136,7 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
dfs := func( dfs := func(
commonSnapshot string, commonSnapshot string,
heads []string, heads []string,
startCounter int, shouldVisit func(entry rawCacheEntry, mapExists bool) bool,
shouldVisit func(counter int, mapExists bool) bool,
visit func(entry rawCacheEntry) rawCacheEntry) bool { visit func(entry rawCacheEntry) rawCacheEntry) bool {
// resetting stack // resetting stack
@ -150,7 +150,7 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
r.idStack = r.idStack[:len(r.idStack)-1] r.idStack = r.idStack[:len(r.idStack)-1]
entry, exists := r.cache[id] entry, exists := r.cache[id]
if !shouldVisit(entry.position, exists) { if !shouldVisit(entry, exists) {
continue continue
} }
if id == commonSnapshot { if id == commonSnapshot {
@ -159,7 +159,6 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
} }
if !exists { if !exists {
entry, err = r.loadEntry(id) entry, err = r.loadEntry(id)
entry.position = -1
if err != nil { if err != nil {
continue continue
} }
@ -174,7 +173,7 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
break break
} }
prevEntry, exists := r.cache[prev] prevEntry, exists := r.cache[prev]
if !shouldVisit(prevEntry.position, exists) { if !shouldVisit(prevEntry, exists) {
continue continue
} }
r.idStack = append(r.idStack, prev) r.idStack = append(r.idStack, prev)
@ -187,8 +186,8 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
r.idStack = append(r.idStack, heads...) r.idStack = append(r.idStack, heads...)
var buffer []*treechangeproto.RawTreeChangeWithId var buffer []*treechangeproto.RawTreeChangeWithId
rootVisited := dfs(commonSnapshot, heads, 0, rootVisited := dfs(commonSnapshot, heads,
func(counter int, mapExists bool) bool { func(_ rawCacheEntry, mapExists bool) bool {
return !mapExists return !mapExists
}, },
func(entry rawCacheEntry) rawCacheEntry { func(entry rawCacheEntry) rawCacheEntry {
@ -213,11 +212,13 @@ func (r *rawChangeLoader) loadFromStorage(commonSnapshot string, heads, breakpoi
} }
// marking all visited as nil // marking all visited as nil
dfs(commonSnapshot, existingBreakpoints, len(buffer), dfs(commonSnapshot, existingBreakpoints,
func(counter int, mapExists bool) bool { func(entry rawCacheEntry, mapExists bool) bool {
return !mapExists || counter < len(buffer) // only going through already loaded changes
return mapExists && !entry.removed
}, },
func(entry rawCacheEntry) rawCacheEntry { func(entry rawCacheEntry) rawCacheEntry {
entry.removed = true
if entry.position != -1 { if entry.position != -1 {
buffer[entry.position] = nil buffer[entry.position] = nil
} }
@ -248,6 +249,7 @@ func (r *rawChangeLoader) loadEntry(id string) (entry rawCacheEntry, err error)
entry = rawCacheEntry{ entry = rawCacheEntry{
change: change, change: change,
rawChange: rawChange, rawChange: rawChange,
position: -1,
} }
return return
} }

View File

@ -82,6 +82,7 @@ func (t *Tree) AddMergedHead(c *Change) error {
} }
} }
t.headIds = []string{c.Id} t.headIds = []string{c.Id}
t.lastIteratedHeadId = c.Id
return nil return nil
} }

View File

@ -2,10 +2,12 @@ package objecttree
import ( import (
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"math/rand" "math/rand"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func newChange(id string, snapshotId string, prevIds ...string) *Change { func newChange(id string, snapshotId string, prevIds ...string) *Change {
@ -26,6 +28,17 @@ func newSnapshot(id, snapshotId string, prevIds ...string) *Change {
} }
} }
func TestTree_AddMergedHead(t *testing.T) {
tr := new(Tree)
_, _ = tr.Add(
newSnapshot("root", ""),
newChange("one", "root", "root"),
)
require.Equal(t, tr.lastIteratedHeadId, "one")
tr.AddMergedHead(newChange("two", "root", "one"))
require.Equal(t, tr.lastIteratedHeadId, "two")
}
func TestTree_Add(t *testing.T) { func TestTree_Add(t *testing.T) {
t.Run("add first el", func(t *testing.T) { t.Run("add first el", func(t *testing.T) {
tr := new(Tree) tr := new(Tree)

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable) // Source: github.com/anyproto/any-sync/commonspace/object/tree/synctree (interfaces: SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol)
// Package mock_synctree is a generated GoMock package. // Package mock_synctree is a generated GoMock package.
package mock_synctree package mock_synctree
@ -15,7 +15,7 @@ import (
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage" treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockSyncTree is a mock of SyncTree interface. // MockSyncTree is a mock of SyncTree interface.
@ -186,6 +186,21 @@ func (mr *MockSyncTreeMockRecorder) HandleMessage(arg0, arg1, arg2 interface{})
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncTree)(nil).HandleMessage), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockSyncTree)(nil).HandleMessage), arg0, arg1, arg2)
} }
// HandleRequest mocks base method.
func (m *MockSyncTree) HandleRequest(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncTreeMockRecorder) HandleRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockSyncTree)(nil).HandleRequest), arg0, arg1, arg2)
}
// HasChanges mocks base method. // HasChanges mocks base method.
func (m *MockSyncTree) HasChanges(arg0 ...string) bool { func (m *MockSyncTree) HasChanges(arg0 ...string) bool {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -590,3 +605,287 @@ func (mr *MockHeadNotifiableMockRecorder) UpdateHeads(arg0, arg1 interface{}) *g
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHeads", reflect.TypeOf((*MockHeadNotifiable)(nil).UpdateHeads), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateHeads", reflect.TypeOf((*MockHeadNotifiable)(nil).UpdateHeads), arg0, arg1)
} }
// MockSyncClient is a mock of SyncClient interface.
type MockSyncClient struct {
ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder
}
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient.
type MockSyncClientMockRecorder struct {
mock *MockSyncClient
}
// NewMockSyncClient creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient {
mock := &MockSyncClient{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder {
return m.recorder
}
// Broadcast mocks base method.
func (m *MockSyncClient) Broadcast(arg0 *treechangeproto.TreeSyncMessage) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0)
}
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1)
}
// CreateNewTreeRequest mocks base method.
func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest")
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest.
func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest))
}
// QueueRequest mocks base method.
func (m *MockSyncClient) QueueRequest(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "QueueRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// QueueRequest indicates an expected call of QueueRequest.
func (mr *MockSyncClientMockRecorder) QueueRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueueRequest", reflect.TypeOf((*MockSyncClient)(nil).QueueRequest), arg0, arg1, arg2)
}
// SendRequest mocks base method.
func (m *MockSyncClient) SendRequest(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendRequest", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SendRequest indicates an expected call of SendRequest.
func (mr *MockSyncClientMockRecorder) SendRequest(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendRequest", reflect.TypeOf((*MockSyncClient)(nil).SendRequest), arg0, arg1, arg2, arg3)
}
// SendUpdate mocks base method.
func (m *MockSyncClient) SendUpdate(arg0, arg1 string, arg2 *treechangeproto.TreeSyncMessage) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// SendUpdate indicates an expected call of SendUpdate.
func (mr *MockSyncClientMockRecorder) SendUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendUpdate", reflect.TypeOf((*MockSyncClient)(nil).SendUpdate), arg0, arg1, arg2)
}
// MockRequestFactory is a mock of RequestFactory interface.
type MockRequestFactory struct {
ctrl *gomock.Controller
recorder *MockRequestFactoryMockRecorder
}
// MockRequestFactoryMockRecorder is the mock recorder for MockRequestFactory.
type MockRequestFactoryMockRecorder struct {
mock *MockRequestFactory
}
// NewMockRequestFactory creates a new mock instance.
func NewMockRequestFactory(ctrl *gomock.Controller) *MockRequestFactory {
mock := &MockRequestFactory{ctrl: ctrl}
mock.recorder = &MockRequestFactoryMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRequestFactory) EXPECT() *MockRequestFactoryMockRecorder {
return m.recorder
}
// CreateFullSyncRequest mocks base method.
func (m *MockRequestFactory) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockRequestFactory) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockRequestFactoryMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockRequestFactory)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockRequestFactory) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate.
func (mr *MockRequestFactoryMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockRequestFactory)(nil).CreateHeadUpdate), arg0, arg1)
}
// CreateNewTreeRequest mocks base method.
func (m *MockRequestFactory) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest")
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0
}
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest.
func (mr *MockRequestFactoryMockRecorder) CreateNewTreeRequest() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockRequestFactory)(nil).CreateNewTreeRequest))
}
// MockTreeSyncProtocol is a mock of TreeSyncProtocol interface.
type MockTreeSyncProtocol struct {
ctrl *gomock.Controller
recorder *MockTreeSyncProtocolMockRecorder
}
// MockTreeSyncProtocolMockRecorder is the mock recorder for MockTreeSyncProtocol.
type MockTreeSyncProtocolMockRecorder struct {
mock *MockTreeSyncProtocol
}
// NewMockTreeSyncProtocol creates a new mock instance.
func NewMockTreeSyncProtocol(ctrl *gomock.Controller) *MockTreeSyncProtocol {
mock := &MockTreeSyncProtocol{ctrl: ctrl}
mock.recorder = &MockTreeSyncProtocolMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTreeSyncProtocol) EXPECT() *MockTreeSyncProtocolMockRecorder {
return m.recorder
}
// FullSyncRequest mocks base method.
func (m *MockTreeSyncProtocol) FullSyncRequest(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncRequest) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// FullSyncRequest indicates an expected call of FullSyncRequest.
func (mr *MockTreeSyncProtocolMockRecorder) FullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncRequest", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncRequest), arg0, arg1, arg2)
}
// FullSyncResponse mocks base method.
func (m *MockTreeSyncProtocol) FullSyncResponse(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeFullSyncResponse) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// FullSyncResponse indicates an expected call of FullSyncResponse.
func (mr *MockTreeSyncProtocolMockRecorder) FullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FullSyncResponse", reflect.TypeOf((*MockTreeSyncProtocol)(nil).FullSyncResponse), arg0, arg1, arg2)
}
// HeadUpdate mocks base method.
func (m *MockTreeSyncProtocol) HeadUpdate(arg0 context.Context, arg1 string, arg2 *treechangeproto.TreeHeadUpdate) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HeadUpdate", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// HeadUpdate indicates an expected call of HeadUpdate.
func (mr *MockTreeSyncProtocolMockRecorder) HeadUpdate(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HeadUpdate", reflect.TypeOf((*MockTreeSyncProtocol)(nil).HeadUpdate), arg0, arg1, arg2)
}

View File

@ -2,6 +2,10 @@ package synctree
import ( import (
"context" "context"
"math/rand"
"testing"
"time"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
@ -10,9 +14,6 @@ import (
"github.com/anyproto/any-sync/util/slice" "github.com/anyproto/any-sync/util/slice"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand"
"testing"
"time"
) )
func TestEmptyClientGetsFullHistory(t *testing.T) { func TestEmptyClientGetsFullHistory(t *testing.T) {

View File

@ -1,4 +1,4 @@
package objectsync package synctree
import ( import (
"fmt" "fmt"

View File

@ -0,0 +1,70 @@
package synctree
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(msg *treechangeproto.TreeSyncMessage)
SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error)
QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error)
SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type syncClient struct {
RequestFactory
spaceId string
requestManager requestmanager.RequestManager
peerManager peermanager.PeerManager
}
func NewSyncClient(spaceId string, requestManager requestmanager.RequestManager, peerManager peermanager.PeerManager) SyncClient {
return &syncClient{
RequestFactory: &requestFactory{},
spaceId: spaceId,
requestManager: requestManager,
peerManager: peerManager,
}
}
func (s *syncClient) Broadcast(msg *treechangeproto.TreeSyncMessage) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, msg.RootChange.Id)
if err != nil {
return
}
err = s.peerManager.Broadcast(context.Background(), objMsg)
if err != nil {
log.Debug("broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendUpdate(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, objectId)
if err != nil {
return
}
return s.peerManager.SendPeer(context.Background(), peerId, objMsg)
}
func (s *syncClient) SendRequest(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, objectId)
if err != nil {
return
}
return s.requestManager.SendRequest(ctx, peerId, objMsg)
}
func (s *syncClient) QueueRequest(peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (err error) {
objMsg, err := spacesyncproto.MarshallSyncMessage(msg, s.spaceId, objectId)
if err != nil {
return
}
return s.requestManager.QueueRequest(peerId, objMsg)
}

View File

@ -1,4 +1,4 @@
//go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable //go:generate mockgen -destination mock_synctree/mock_synctree.go github.com/anyproto/any-sync/commonspace/object/tree/synctree SyncTree,ReceiveQueue,HeadNotifiable,SyncClient,RequestFactory,TreeSyncProtocol
package synctree package synctree
import ( import (
@ -11,7 +11,6 @@ import (
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -44,7 +43,7 @@ type SyncTree interface {
type syncTree struct { type syncTree struct {
objecttree.ObjectTree objecttree.ObjectTree
synchandler.SyncHandler synchandler.SyncHandler
syncClient objectsync.SyncClient syncClient SyncClient
syncStatus syncstatus.StatusUpdater syncStatus syncstatus.StatusUpdater
notifiable HeadNotifiable notifiable HeadNotifiable
listener updatelistener.UpdateListener listener updatelistener.UpdateListener
@ -60,19 +59,18 @@ type ResponsiblePeersGetter interface {
} }
type BuildDeps struct { type BuildDeps struct {
SpaceId string SpaceId string
SyncClient objectsync.SyncClient SyncClient SyncClient
Configuration nodeconf.NodeConf Configuration nodeconf.NodeConf
HeadNotifiable HeadNotifiable HeadNotifiable HeadNotifiable
Listener updatelistener.UpdateListener Listener updatelistener.UpdateListener
AclList list.AclList AclList list.AclList
SpaceStorage spacestorage.SpaceStorage SpaceStorage spacestorage.SpaceStorage
TreeStorage treestorage.TreeStorage TreeStorage treestorage.TreeStorage
OnClose func(id string) OnClose func(id string)
SyncStatus syncstatus.StatusUpdater SyncStatus syncstatus.StatusUpdater
PeerGetter ResponsiblePeersGetter PeerGetter ResponsiblePeersGetter
BuildObjectTree objecttree.BuildObjectTreeFunc BuildObjectTree objecttree.BuildObjectTreeFunc
WaitTreeRemoteSync bool
} }
func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t SyncTree, err error) { func BuildSyncTreeOrGetRemote(ctx context.Context, id string, deps BuildDeps) (t SyncTree, err error) {
@ -119,7 +117,7 @@ func buildSyncTree(ctx context.Context, sendUpdate bool, deps BuildDeps) (t Sync
if sendUpdate { if sendUpdate {
headUpdate := syncTree.syncClient.CreateHeadUpdate(t, nil) headUpdate := syncTree.syncClient.CreateHeadUpdate(t, nil)
// send to everybody, because everybody should know that the node or client got new tree // send to everybody, because everybody should know that the node or client got new tree
syncTree.syncClient.Broadcast(ctx, headUpdate) syncTree.syncClient.Broadcast(headUpdate)
} }
return return
} }
@ -156,7 +154,7 @@ func (s *syncTree) AddContent(ctx context.Context, content objecttree.SignableCh
} }
s.syncStatus.HeadsChange(s.Id(), res.Heads) s.syncStatus.HeadsChange(s.Id(), res.Heads)
headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added)
s.syncClient.Broadcast(ctx, headUpdate) s.syncClient.Broadcast(headUpdate)
return return
} }
@ -183,7 +181,7 @@ func (s *syncTree) AddRawChanges(ctx context.Context, changesPayload objecttree.
s.notifiable.UpdateHeads(s.Id(), res.Heads) s.notifiable.UpdateHeads(s.Id(), res.Heads)
} }
headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added) headUpdate := s.syncClient.CreateHeadUpdate(s, res.Added)
s.syncClient.Broadcast(ctx, headUpdate) s.syncClient.Broadcast(headUpdate)
} }
return return
} }
@ -207,18 +205,27 @@ func (s *syncTree) Delete() (err error) {
} }
func (s *syncTree) TryClose(objectTTL time.Duration) (bool, error) { func (s *syncTree) TryClose(objectTTL time.Duration) (bool, error) {
return true, s.Close() if !s.TryLock() {
return false, nil
}
log.Debug("closing sync tree", zap.String("id", s.Id()))
return true, s.close()
} }
func (s *syncTree) Close() (err error) { func (s *syncTree) Close() (err error) {
log.Debug("closing sync tree", zap.String("id", s.Id())) log.Debug("closing sync tree", zap.String("id", s.Id()))
s.Lock()
return s.close()
}
func (s *syncTree) close() (err error) {
defer s.Unlock()
defer func() { defer func() {
log.Debug("closed sync tree", zap.Error(err), zap.String("id", s.Id())) log.Debug("closed sync tree", zap.Error(err), zap.String("id", s.Id()))
}() }()
s.Lock()
defer s.Unlock()
if s.isClosed { if s.isClosed {
return ErrSyncTreeClosed err = ErrSyncTreeClosed
return
} }
s.onClose(s.Id()) s.onClose(s.Id())
s.isClosed = true s.isClosed = true
@ -239,7 +246,7 @@ func (s *syncTree) SyncWithPeer(ctx context.Context, peerId string) (err error)
s.Lock() s.Lock()
defer s.Unlock() defer s.Unlock()
headUpdate := s.syncClient.CreateHeadUpdate(s, nil) headUpdate := s.syncClient.CreateHeadUpdate(s, nil)
return s.syncClient.SendWithReply(ctx, peerId, headUpdate.RootChange.Id, headUpdate, "") return s.syncClient.SendUpdate(peerId, headUpdate.RootChange.Id, headUpdate)
} }
func (s *syncTree) afterBuild() { func (s *syncTree) afterBuild() {

View File

@ -4,21 +4,21 @@ import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener/mock_updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener/mock_updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing" "testing"
) )
type syncTreeMatcher struct { type syncTreeMatcher struct {
objTree objecttree.ObjectTree objTree objecttree.ObjectTree
client objectsync.SyncClient client SyncClient
listener updatelistener.UpdateListener listener updatelistener.UpdateListener
} }
@ -34,8 +34,8 @@ func (s syncTreeMatcher) String() string {
return "" return ""
} }
func syncClientFuncCreator(client objectsync.SyncClient) func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { func syncClientFuncCreator(client SyncClient) func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient {
return func(spaceId string, factory objectsync.RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) objectsync.SyncClient { return func(spaceId string, factory RequestFactory, objectSync objectsync.ObjectSync, configuration nodeconf.NodeConf) SyncClient {
return client return client
} }
} }
@ -46,7 +46,7 @@ func Test_BuildSyncTree(t *testing.T) {
defer ctrl.Finish() defer ctrl.Finish()
updateListenerMock := mock_updatelistener.NewMockUpdateListener(ctrl) updateListenerMock := mock_updatelistener.NewMockUpdateListener(ctrl)
syncClientMock := mock_objectsync.NewMockSyncClient(ctrl) syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
objTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) objTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
tr := &syncTree{ tr := &syncTree{
ObjectTree: objTreeMock, ObjectTree: objTreeMock,
@ -73,7 +73,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Update(tr) updateListenerMock.EXPECT().Update(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddRawChanges(ctx, payload) res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)
@ -95,7 +95,7 @@ func Test_BuildSyncTree(t *testing.T) {
updateListenerMock.EXPECT().Rebuild(tr) updateListenerMock.EXPECT().Rebuild(tr)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddRawChanges(ctx, payload) res, err := tr.AddRawChanges(ctx, payload)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)
@ -133,7 +133,7 @@ func Test_BuildSyncTree(t *testing.T) {
Return(expectedRes, nil) Return(expectedRes, nil)
syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate) syncClientMock.EXPECT().CreateHeadUpdate(gomock.Eq(tr), gomock.Eq(changes)).Return(headUpdate)
syncClientMock.EXPECT().Broadcast(gomock.Any(), gomock.Eq(headUpdate)) syncClientMock.EXPECT().Broadcast(gomock.Eq(headUpdate))
res, err := tr.AddContent(ctx, content) res, err := tr.AddContent(ctx, content)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, expectedRes, res) require.Equal(t, expectedRes, res)

View File

@ -2,233 +2,142 @@ package synctree
import ( import (
"context" "context"
"errors"
"sync"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/util/slice" "github.com/anyproto/any-sync/util/slice"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"go.uber.org/zap" )
"sync"
var (
ErrMessageIsRequest = errors.New("message is request")
ErrMessageIsNotRequest = errors.New("message is not request")
ErrMoreThanOneRequest = errors.New("more than one request for same peer")
) )
type syncTreeHandler struct { type syncTreeHandler struct {
objTree objecttree.ObjectTree objTree objecttree.ObjectTree
syncClient objectsync.SyncClient syncClient SyncClient
syncStatus syncstatus.StatusUpdater syncProtocol TreeSyncProtocol
handlerLock sync.Mutex syncStatus syncstatus.StatusUpdater
spaceId string spaceId string
queue ReceiveQueue
handlerLock sync.Mutex
pendingRequests map[string]struct{}
heads []string
} }
const maxQueueSize = 5 const maxQueueSize = 5
func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient objectsync.SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler { func newSyncTreeHandler(spaceId string, objTree objecttree.ObjectTree, syncClient SyncClient, syncStatus syncstatus.StatusUpdater) synchandler.SyncHandler {
return &syncTreeHandler{ return &syncTreeHandler{
objTree: objTree, objTree: objTree,
syncClient: syncClient, syncProtocol: newTreeSyncProtocol(spaceId, objTree, syncClient),
syncStatus: syncStatus, syncClient: syncClient,
spaceId: spaceId, syncStatus: syncStatus,
queue: newReceiveQueue(maxQueueSize), spaceId: spaceId,
pendingRequests: make(map[string]struct{}),
} }
} }
func (s *syncTreeHandler) HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
unmarshalled := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(request.Payload, unmarshalled)
if err != nil {
return
}
fullSyncRequest := unmarshalled.GetContent().GetFullSyncRequest()
if fullSyncRequest == nil {
return nil, ErrMessageIsNotRequest
}
// setting pending requests
s.handlerLock.Lock()
_, exists := s.pendingRequests[senderId]
if exists {
s.handlerLock.Unlock()
return nil, ErrMoreThanOneRequest
}
s.pendingRequests[senderId] = struct{}{}
s.handlerLock.Unlock()
response, err = s.handleRequest(ctx, senderId, fullSyncRequest)
// removing pending requests
s.handlerLock.Lock()
delete(s.pendingRequests, senderId)
s.handlerLock.Unlock()
return
}
func (s *syncTreeHandler) handleRequest(ctx context.Context, senderId string, fullSyncRequest *treechangeproto.TreeFullSyncRequest) (response *spacesyncproto.ObjectSyncMessage, err error) {
s.objTree.Lock()
defer s.objTree.Unlock()
treeResp, err := s.syncProtocol.FullSyncRequest(ctx, senderId, fullSyncRequest)
if err != nil {
return
}
response, err = spacesyncproto.MarshallSyncMessage(treeResp, s.spaceId, s.objTree.Id())
return
}
func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { func (s *syncTreeHandler) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
unmarshalled := &treechangeproto.TreeSyncMessage{} unmarshalled := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(msg.Payload, unmarshalled) err = proto.Unmarshal(msg.Payload, unmarshalled)
if err != nil { if err != nil {
return return
} }
s.syncStatus.HeadsReceive(senderId, msg.ObjectId, treechangeproto.GetHeads(unmarshalled)) heads := treechangeproto.GetHeads(unmarshalled)
s.syncStatus.HeadsReceive(senderId, msg.ObjectId, heads)
queueFull := s.queue.AddMessage(senderId, unmarshalled, msg.RequestId) s.handlerLock.Lock()
if queueFull { // if the update has same heads then returning not to hang on a lock
if unmarshalled.GetContent().GetHeadUpdate() != nil && slice.UnsortedEquals(heads, s.heads) {
s.handlerLock.Unlock()
return return
} }
s.handlerLock.Unlock()
return s.handleMessage(ctx, senderId) return s.handleMessage(ctx, unmarshalled, senderId)
} }
func (s *syncTreeHandler) handleMessage(ctx context.Context, senderId string) (err error) { func (s *syncTreeHandler) handleMessage(ctx context.Context, msg *treechangeproto.TreeSyncMessage, senderId string) (err error) {
s.objTree.Lock() s.objTree.Lock()
defer s.objTree.Unlock() defer s.objTree.Unlock()
msg, replyId, err := s.queue.GetMessage(senderId)
if err != nil {
return
}
defer s.queue.ClearQueue(senderId)
content := msg.GetContent()
switch {
case content.GetHeadUpdate() != nil:
return s.handleHeadUpdate(ctx, senderId, content.GetHeadUpdate(), replyId)
case content.GetFullSyncRequest() != nil:
return s.handleFullSyncRequest(ctx, senderId, content.GetFullSyncRequest(), replyId)
case content.GetFullSyncResponse() != nil:
return s.handleFullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
}
return
}
func (s *syncTreeHandler) handleHeadUpdate(
ctx context.Context,
senderId string,
update *treechangeproto.TreeHeadUpdate,
replyId string) (err error) {
var ( var (
fullRequest *treechangeproto.TreeSyncMessage copyHeads = make([]string, 0, len(s.objTree.Heads()))
isEmptyUpdate = len(update.Changes) == 0 treeId = s.objTree.Id()
objTree = s.objTree content = msg.GetContent()
treeId = objTree.Id()
) )
log := log.With(
zap.Strings("update heads", update.Heads),
zap.String("treeId", treeId),
zap.String("spaceId", s.spaceId),
zap.Int("len(update changes)", len(update.Changes)))
log.DebugCtx(ctx, "received head update message")
// getting old heads
copyHeads = append(copyHeads, s.objTree.Heads()...)
defer func() { defer func() {
if err != nil { // checking if something changed
log.ErrorCtx(ctx, "head update finished with error", zap.Error(err)) if !slice.UnsortedEquals(copyHeads, s.objTree.Heads()) {
} else if fullRequest != nil { s.handlerLock.Lock()
cnt := fullRequest.Content.GetFullSyncRequest() defer s.handlerLock.Unlock()
log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes))) s.heads = s.heads[:0]
log.DebugCtx(ctx, "sending full sync request") for _, h := range s.objTree.Heads() {
} else { s.heads = append(s.heads, h)
if !isEmptyUpdate {
log.DebugCtx(ctx, "head update finished correctly")
} }
} }
}() }()
// isEmptyUpdate is sent when the tree is brought up from cache switch {
if isEmptyUpdate { case content.GetHeadUpdate() != nil:
headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads) var syncReq *treechangeproto.TreeSyncMessage
log.DebugCtx(ctx, "is empty update", zap.String("treeId", objTree.Id()), zap.Bool("headEquals", headEquals)) syncReq, err = s.syncProtocol.HeadUpdate(ctx, senderId, content.GetHeadUpdate())
if headEquals { if err != nil || syncReq == nil {
return return
} }
return s.syncClient.QueueRequest(senderId, treeId, syncReq)
// we need to sync in any case case content.GetFullSyncRequest() != nil:
fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath) return ErrMessageIsRequest
if err != nil { case content.GetFullSyncResponse() != nil:
return return s.syncProtocol.FullSyncResponse(ctx, senderId, content.GetFullSyncResponse())
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId)
} }
if s.alreadyHasHeads(objTree, update.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: update.Heads,
RawChanges: update.Changes,
})
if err != nil {
return
}
if s.alreadyHasHeads(objTree, update.Heads) {
return
}
fullRequest, err = s.syncClient.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
if err != nil {
return
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullRequest, replyId)
}
func (s *syncTreeHandler) handleFullSyncRequest(
ctx context.Context,
senderId string,
request *treechangeproto.TreeFullSyncRequest,
replyId string) (err error) {
var (
fullResponse *treechangeproto.TreeSyncMessage
header = s.objTree.Header()
objTree = s.objTree
treeId = s.objTree.Id()
)
log := log.With(zap.String("senderId", senderId),
zap.Strings("request heads", request.Heads),
zap.String("treeId", treeId),
zap.String("replyId", replyId),
zap.String("spaceId", s.spaceId),
zap.Int("len(request changes)", len(request.Changes)))
log.DebugCtx(ctx, "received full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err))
s.syncClient.SendWithReply(ctx, senderId, treeId, treechangeproto.WrapError(treechangeproto.ErrFullSync, header), replyId)
return
} else if fullResponse != nil {
cnt := fullResponse.Content.GetFullSyncResponse()
log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "full sync response sent")
}
}()
if len(request.Changes) != 0 && !s.alreadyHasHeads(objTree, request.Heads) {
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: request.Heads,
RawChanges: request.Changes,
})
if err != nil {
return
}
}
fullResponse, err = s.syncClient.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath)
if err != nil {
return
}
return s.syncClient.SendWithReply(ctx, senderId, treeId, fullResponse, replyId)
}
func (s *syncTreeHandler) handleFullSyncResponse(
ctx context.Context,
senderId string,
response *treechangeproto.TreeFullSyncResponse) (err error) {
var (
objTree = s.objTree
treeId = s.objTree.Id()
)
log := log.With(
zap.Strings("heads", response.Heads),
zap.String("treeId", treeId),
zap.String("spaceId", s.spaceId),
zap.Int("len(changes)", len(response.Changes)))
log.DebugCtx(ctx, "received full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "full sync response succeeded")
}
}()
if s.alreadyHasHeads(objTree, response.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: response.Heads,
RawChanges: response.Changes,
})
return return
} }
func (s *syncTreeHandler) alreadyHasHeads(t objecttree.ObjectTree, heads []string) bool {
return slice.UnsortedEquals(t.Heads(), heads) || t.HasChanges(heads...)
}

View File

@ -2,20 +2,16 @@ package synctree
import ( import (
"context" "context"
"fmt"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync"
"sync" "sync"
"testing" "testing"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/zap" "go.uber.org/mock/gomock"
) )
type testObjTreeMock struct { type testObjTreeMock struct {
@ -55,31 +51,40 @@ func (t *testObjTreeMock) TryRLock() bool {
type syncHandlerFixture struct { type syncHandlerFixture struct {
ctrl *gomock.Controller ctrl *gomock.Controller
syncClientMock *mock_objectsync.MockSyncClient syncClientMock *mock_synctree.MockSyncClient
objectTreeMock *testObjTreeMock objectTreeMock *testObjTreeMock
receiveQueueMock ReceiveQueue syncProtocolMock *mock_synctree.MockTreeSyncProtocol
spaceId string
senderId string
treeId string
syncHandler *syncTreeHandler syncHandler *syncTreeHandler
} }
func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture { func newSyncHandlerFixture(t *testing.T) *syncHandlerFixture {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
syncClientMock := mock_objectsync.NewMockSyncClient(ctrl)
objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl)) objectTreeMock := newTestObjMock(mock_objecttree.NewMockObjectTree(ctrl))
receiveQueue := newReceiveQueue(5) syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
syncProtocolMock := mock_synctree.NewMockTreeSyncProtocol(ctrl)
spaceId := "spaceId"
syncHandler := &syncTreeHandler{ syncHandler := &syncTreeHandler{
objTree: objectTreeMock, objTree: objectTreeMock,
syncClient: syncClientMock, syncClient: syncClientMock,
queue: receiveQueue, syncProtocol: syncProtocolMock,
syncStatus: syncstatus.NewNoOpSyncStatus(), spaceId: spaceId,
syncStatus: syncstatus.NewNoOpSyncStatus(),
pendingRequests: map[string]struct{}{},
} }
return &syncHandlerFixture{ return &syncHandlerFixture{
ctrl: ctrl, ctrl: ctrl,
syncClientMock: syncClientMock,
objectTreeMock: objectTreeMock, objectTreeMock: objectTreeMock,
receiveQueueMock: receiveQueue, syncProtocolMock: syncProtocolMock,
syncClientMock: syncClientMock,
syncHandler: syncHandler, syncHandler: syncHandler,
spaceId: spaceId,
senderId: "senderId",
treeId: "treeId",
} }
} }
@ -87,341 +92,149 @@ func (fx *syncHandlerFixture) stop() {
fx.ctrl.Finish() fx.ctrl.Finish()
} }
func TestSyncHandler_HandleHeadUpdate(t *testing.T) { func TestSyncTreeHandler_HandleMessage(t *testing.T) {
ctx := context.Background() ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
fullRequest := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{
FullSyncRequest: &treechangeproto.TreeFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) { t.Run("handle head update message, heads not equal, request returned", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"}, Heads: []string{"h3"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
} }
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) syncReq := &treechangeproto.TreeSyncMessage{}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2) fx.syncHandler.heads = []string{"h2"}
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false) fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT(). fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h2"})
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{ fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h3"})
NewHeads: []string{"h1"}, fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(syncReq, nil)
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId}, fx.syncClientMock.EXPECT().QueueRequest(fx.senderId, fx.treeId, syncReq).Return(nil)
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []string{"h3"}, fx.syncHandler.heads)
}) })
t.Run("head update non empty heads not added", func(t *testing.T) { t.Run("handle head update message, heads equal", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"}, Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
} }
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.syncHandler.heads = []string{"h1"}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.syncClientMock.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
}) })
t.Run("head update non empty equal heads", func(t *testing.T) { t.Run("handle head update message, no sync request returned", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"}, Heads: []string{"h3"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
} }
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.syncHandler.heads = []string{"h2"}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes() fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h2"})
fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h3"})
fx.syncProtocolMock.EXPECT().HeadUpdate(ctx, fx.senderId, gomock.Any()).Return(nil, nil)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, []string{"h3"}, fx.syncHandler.heads)
}) })
t.Run("head update empty", func(t *testing.T) { t.Run("handle full sync request returns error", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{ fullRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"}, Heads: []string{"h3"},
Changes: nil,
SnapshotPath: []string{"h1"},
} }
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId) treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "") objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.syncHandler.heads = []string{"h2"}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes() fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.syncClientMock.EXPECT(). fx.objectTreeMock.EXPECT().Heads().Times(3).Return([]string{"h2"})
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullRequest), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err) require.Equal(t, err, ErrMessageIsRequest)
}) })
t.Run("head update empty equal heads", func(t *testing.T) { t.Run("handle full sync response", func(t *testing.T) {
fx := newSyncHandlerFixture(t) fx := newSyncHandlerFixture(t)
defer fx.stop() defer fx.stop()
treeId := "treeId" treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
}
func TestSyncHandler_HandleFullSyncRequest(t *testing.T) {
ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
fullResponse := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{
FullSyncResponse: &treechangeproto.TreeFullSyncResponse{},
},
},
}
t.Run("full sync request with change", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request with change same heads", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request without change but with reply id", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
objectMsg.RequestId = replyId
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.syncClientMock.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Eq(fullResponse), gomock.Eq(replyId))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync request with add raw changes error", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullRequest(fullSyncRequest, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, "")
fx.objectTreeMock.EXPECT().
Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().Header().Return(nil)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h2"})
fx.objectTreeMock.EXPECT().
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, fmt.Errorf(""))
fx.syncClientMock.EXPECT().SendWithReply(gomock.Any(), gomock.Eq(senderId), gomock.Eq(treeId), gomock.Any(), gomock.Eq(""))
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.Error(t, err)
})
}
func TestSyncHandler_HandleFullSyncResponse(t *testing.T) {
ctx := context.Background()
log = logger.CtxLogger{Logger: zap.NewNop()}
t.Run("full sync response with change", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{} chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{ fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"}, Heads: []string{"h3"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
} }
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId) treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId) objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId) fx.syncHandler.heads = []string{"h2"}
fx.objectTreeMock.EXPECT(). fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
Heads(). fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h2"})
Return([]string{"h2"}).AnyTimes() fx.objectTreeMock.EXPECT().Heads().Times(2).Return([]string{"h3"})
fx.objectTreeMock.EXPECT(). fx.syncProtocolMock.EXPECT().FullSyncResponse(ctx, fx.senderId, gomock.Any()).Return(nil)
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg) err := fx.syncHandler.HandleMessage(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
})
t.Run("full sync response with same heads", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
senderId := "senderId"
replyId := "replyId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
treeMsg := treechangeproto.WrapFullResponse(fullSyncResponse, chWithId)
objectMsg, _ := objectsync.MarshallTreeMessage(treeMsg, "spaceId", treeId, replyId)
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
err := fx.syncHandler.HandleMessage(ctx, senderId, objectMsg)
require.NoError(t, err) require.NoError(t, err)
}) })
} }
func TestSyncTreeHandler_HandleRequest(t *testing.T) {
ctx := context.Background()
t.Run("handle request", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullRequest := &treechangeproto.TreeFullSyncRequest{}
treeMsg := treechangeproto.WrapFullRequest(fullRequest, chWithId)
objectMsg, _ := spacesyncproto.MarshallSyncMessage(treeMsg, "spaceId", treeId)
syncResp := &treechangeproto.TreeSyncMessage{}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.syncProtocolMock.EXPECT().FullSyncRequest(ctx, fx.senderId, gomock.Any()).Return(syncResp, nil)
res, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.NoError(t, err)
require.NotNil(t, res)
})
t.Run("handle other message", func(t *testing.T) {
fx := newSyncHandlerFixture(t)
defer fx.stop()
treeId := "treeId"
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullResponse := &treechangeproto.TreeFullSyncResponse{}
responseMsg := treechangeproto.WrapFullResponse(fullResponse, chWithId)
headUpdate := &treechangeproto.TreeHeadUpdate{}
headUpdateMsg := treechangeproto.WrapHeadUpdate(headUpdate, chWithId)
for _, msg := range []*treechangeproto.TreeSyncMessage{responseMsg, headUpdateMsg} {
objectMsg, _ := spacesyncproto.MarshallSyncMessage(msg, "spaceId", treeId)
_, err := fx.syncHandler.HandleRequest(ctx, fx.senderId, objectMsg)
require.Equal(t, err, ErrMessageIsNotRequest)
}
})
}

View File

@ -2,16 +2,19 @@ package synctree
import ( import (
"context" "context"
"fmt" "errors"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
"time" )
var (
ErrNoResponsiblePeers = errors.New("no responsible peers")
) )
type treeRemoteGetter struct { type treeRemoteGetter struct {
@ -36,7 +39,7 @@ func (t treeRemoteGetter) getPeers(ctx context.Context) (peerIds []string, err e
return return
} }
if len(respPeers) == 0 { if len(respPeers) == 0 {
err = fmt.Errorf("no responsible peers") err = ErrNoResponsiblePeers
return return
} }
for _, p := range respPeers { for _, p := range respPeers {
@ -47,7 +50,7 @@ func (t treeRemoteGetter) getPeers(ctx context.Context) (peerIds []string, err e
func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *treechangeproto.TreeSyncMessage, err error) { func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *treechangeproto.TreeSyncMessage, err error) {
newTreeRequest := t.deps.SyncClient.CreateNewTreeRequest() newTreeRequest := t.deps.SyncClient.CreateNewTreeRequest()
resp, err := t.deps.SyncClient.SendSync(ctx, peerId, t.treeId, newTreeRequest) resp, err := t.deps.SyncClient.SendRequest(ctx, peerId, t.treeId, newTreeRequest)
if err != nil { if err != nil {
return return
} }
@ -57,37 +60,13 @@ func (t treeRemoteGetter) treeRequest(ctx context.Context, peerId string) (msg *
return return
} }
func (t treeRemoteGetter) treeRequestLoop(ctx context.Context, wait bool) (msg *treechangeproto.TreeSyncMessage, err error) { func (t treeRemoteGetter) treeRequestLoop(ctx context.Context) (msg *treechangeproto.TreeSyncMessage, err error) {
peerIdx := 0 availablePeers, err := t.getPeers(ctx)
Loop: if err != nil {
for { return
select {
case <-ctx.Done():
return nil, fmt.Errorf("waiting for object %s interrupted, context closed", t.treeId)
default:
break
}
availablePeers, err := t.getPeers(ctx)
if err != nil {
if !wait {
return nil, err
}
select {
// wait for peers to connect
case <-time.After(1 * time.Second):
continue Loop
case <-ctx.Done():
return nil, fmt.Errorf("waiting for object %s interrupted, context closed", t.treeId)
}
}
peerIdx = peerIdx % len(availablePeers)
msg, err = t.treeRequest(ctx, availablePeers[peerIdx])
if err == nil || !wait {
return msg, err
}
peerIdx++
} }
// in future we will try to load from different peers
return t.treeRequest(ctx, availablePeers[0])
} }
func (t treeRemoteGetter) getTree(ctx context.Context) (treeStorage treestorage.TreeStorage, isRemote bool, err error) { func (t treeRemoteGetter) getTree(ctx context.Context) (treeStorage treestorage.TreeStorage, isRemote bool, err error) {
@ -109,22 +88,15 @@ func (t treeRemoteGetter) getTree(ctx context.Context) (treeStorage treestorage.
} }
isRemote = true isRemote = true
resp, err := t.treeRequestLoop(ctx, t.deps.WaitTreeRemoteSync) resp, err := t.treeRequestLoop(ctx)
if err != nil { if err != nil {
return return
} }
switch { fullSyncResp := resp.GetContent().GetFullSyncResponse()
case resp.GetContent().GetErrorResponse() != nil: if fullSyncResp == nil {
errResp := resp.GetContent().GetErrorResponse()
err = rpcerr.Err(errResp.ErrCode)
return
case resp.GetContent().GetFullSyncResponse() == nil:
err = treechangeproto.ErrUnexpected err = treechangeproto.ErrUnexpected
return return
default:
break
} }
fullSyncResp := resp.GetContent().GetFullSyncResponse()
payload := treestorage.TreeStorageCreatePayload{ payload := treestorage.TreeStorageCreatePayload{
RootRawChange: resp.RootChange, RootRawChange: resp.RootChange,

View File

@ -0,0 +1,87 @@
package synctree
import (
"context"
"fmt"
"testing"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/peermanager/mock_peermanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/peer/mock_peer"
"github.com/gogo/protobuf/proto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
)
type treeRemoteGetterFixture struct {
ctrl *gomock.Controller
treeGetter treeRemoteGetter
syncClientMock *mock_synctree.MockSyncClient
peerGetterMock *mock_peermanager.MockPeerManager
}
func newTreeRemoteGetterFixture(t *testing.T) *treeRemoteGetterFixture {
ctrl := gomock.NewController(t)
syncClientMock := mock_synctree.NewMockSyncClient(ctrl)
peerGetterMock := mock_peermanager.NewMockPeerManager(ctrl)
treeGetter := treeRemoteGetter{
deps: BuildDeps{
SyncClient: syncClientMock,
PeerGetter: peerGetterMock,
},
treeId: "treeId",
}
return &treeRemoteGetterFixture{
ctrl: ctrl,
treeGetter: treeGetter,
syncClientMock: syncClientMock,
peerGetterMock: peerGetterMock,
}
}
func (fx *treeRemoteGetterFixture) stop() {
fx.ctrl.Finish()
}
func TestTreeRemoteGetter(t *testing.T) {
ctx := context.Background()
peerId := "peerId"
treeRequest := &treechangeproto.TreeSyncMessage{}
treeResponse := &treechangeproto.TreeSyncMessage{
RootChange: &treechangeproto.RawTreeChangeWithId{Id: "id"},
}
marshalled, _ := proto.Marshal(treeResponse)
objectResponse := &spacesyncproto.ObjectSyncMessage{
Payload: marshalled,
}
t.Run("request works", func(t *testing.T) {
fx := newTreeRemoteGetterFixture(t)
defer fx.stop()
mockPeer := mock_peer.NewMockPeer(fx.ctrl)
mockPeer.EXPECT().Id().AnyTimes().Return(peerId)
fx.peerGetterMock.EXPECT().GetResponsiblePeers(ctx).Return([]peer.Peer{mockPeer}, nil)
fx.syncClientMock.EXPECT().CreateNewTreeRequest().Return(treeRequest)
fx.syncClientMock.EXPECT().SendRequest(ctx, peerId, fx.treeGetter.treeId, treeRequest).Return(objectResponse, nil)
resp, err := fx.treeGetter.treeRequestLoop(ctx)
require.NoError(t, err)
require.Equal(t, "id", resp.RootChange.Id)
})
t.Run("request fails", func(t *testing.T) {
fx := newTreeRemoteGetterFixture(t)
defer fx.stop()
treeRequest := &treechangeproto.TreeSyncMessage{}
mockPeer := mock_peer.NewMockPeer(fx.ctrl)
mockPeer.EXPECT().Id().AnyTimes().Return(peerId)
fx.peerGetterMock.EXPECT().GetResponsiblePeers(ctx).Return([]peer.Peer{mockPeer}, nil)
fx.syncClientMock.EXPECT().CreateNewTreeRequest().Return(treeRequest)
fx.syncClientMock.EXPECT().SendRequest(ctx, peerId, fx.treeGetter.treeId, treeRequest).AnyTimes().Return(nil, fmt.Errorf("some"))
_, err := fx.treeGetter.treeRequestLoop(ctx)
require.Error(t, err)
})
}

View File

@ -0,0 +1,153 @@
package synctree
import (
"context"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/util/slice"
"go.uber.org/zap"
)
type TreeSyncProtocol interface {
HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (request *treechangeproto.TreeSyncMessage, err error)
FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (response *treechangeproto.TreeSyncMessage, err error)
FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error)
}
type treeSyncProtocol struct {
log logger.CtxLogger
spaceId string
objTree objecttree.ObjectTree
reqFactory RequestFactory
}
func newTreeSyncProtocol(spaceId string, objTree objecttree.ObjectTree, reqFactory RequestFactory) *treeSyncProtocol {
return &treeSyncProtocol{
log: log.With(zap.String("spaceId", spaceId), zap.String("treeId", objTree.Id())),
spaceId: spaceId,
objTree: objTree,
reqFactory: reqFactory,
}
}
func (t *treeSyncProtocol) HeadUpdate(ctx context.Context, senderId string, update *treechangeproto.TreeHeadUpdate) (fullRequest *treechangeproto.TreeSyncMessage, err error) {
var (
isEmptyUpdate = len(update.Changes) == 0
objTree = t.objTree
)
log := t.log.With(
zap.String("senderId", senderId),
zap.Strings("update heads", update.Heads),
zap.Int("len(update changes)", len(update.Changes)))
log.DebugCtx(ctx, "received head update message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "head update finished with error", zap.Error(err))
} else if fullRequest != nil {
cnt := fullRequest.Content.GetFullSyncRequest()
log = log.With(zap.Strings("request heads", cnt.Heads), zap.Int("len(request changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "returning full sync request")
} else {
if !isEmptyUpdate {
log.DebugCtx(ctx, "head update finished correctly")
}
}
}()
// isEmptyUpdate is sent when the tree is brought up from cache
if isEmptyUpdate {
headEquals := slice.UnsortedEquals(objTree.Heads(), update.Heads)
log.DebugCtx(ctx, "is empty update", zap.Bool("headEquals", headEquals))
if headEquals {
return
}
// we need to sync in any case
fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
return
}
if t.hasHeads(objTree, update.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: update.Heads,
RawChanges: update.Changes,
})
if err != nil {
return
}
if t.hasHeads(objTree, update.Heads) {
return
}
fullRequest, err = t.reqFactory.CreateFullSyncRequest(objTree, update.Heads, update.SnapshotPath)
return
}
func (t *treeSyncProtocol) FullSyncRequest(ctx context.Context, senderId string, request *treechangeproto.TreeFullSyncRequest) (fullResponse *treechangeproto.TreeSyncMessage, err error) {
var (
objTree = t.objTree
)
log := t.log.With(zap.String("senderId", senderId),
zap.Strings("request heads", request.Heads),
zap.Int("len(request changes)", len(request.Changes)))
log.DebugCtx(ctx, "received full sync request message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync request finished with error", zap.Error(err))
} else if fullResponse != nil {
cnt := fullResponse.Content.GetFullSyncResponse()
log = log.With(zap.Strings("response heads", cnt.Heads), zap.Int("len(response changes)", len(cnt.Changes)))
log.DebugCtx(ctx, "full sync response sent")
}
}()
if len(request.Changes) != 0 && !t.hasHeads(objTree, request.Heads) {
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: request.Heads,
RawChanges: request.Changes,
})
if err != nil {
return
}
}
fullResponse, err = t.reqFactory.CreateFullSyncResponse(objTree, request.Heads, request.SnapshotPath)
return
}
func (t *treeSyncProtocol) FullSyncResponse(ctx context.Context, senderId string, response *treechangeproto.TreeFullSyncResponse) (err error) {
var (
objTree = t.objTree
)
log := log.With(
zap.Strings("heads", response.Heads),
zap.Int("len(changes)", len(response.Changes)))
log.DebugCtx(ctx, "received full sync response message")
defer func() {
if err != nil {
log.ErrorCtx(ctx, "full sync response failed", zap.Error(err))
} else {
log.DebugCtx(ctx, "full sync response succeeded")
}
}()
if t.hasHeads(objTree, response.Heads) {
return
}
_, err = objTree.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: response.Heads,
RawChanges: response.Changes,
})
return
}
func (t *treeSyncProtocol) hasHeads(ot objecttree.ObjectTree, heads []string) bool {
return slice.UnsortedEquals(ot.Heads(), heads) || ot.HasChanges(heads...)
}

View File

@ -0,0 +1,293 @@
package synctree
import (
"context"
"fmt"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/mock_synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing"
)
type treeSyncProtocolFixture struct {
log logger.CtxLogger
spaceId string
senderId string
treeId string
objectTreeMock *testObjTreeMock
reqFactory *mock_synctree.MockRequestFactory
ctrl *gomock.Controller
syncProtocol TreeSyncProtocol
}
func newSyncProtocolFixture(t *testing.T) *treeSyncProtocolFixture {
ctrl := gomock.NewController(t)
objTree := &testObjTreeMock{
MockObjectTree: mock_objecttree.NewMockObjectTree(ctrl),
}
spaceId := "spaceId"
reqFactory := mock_synctree.NewMockRequestFactory(ctrl)
objTree.EXPECT().Id().Return("treeId")
syncProtocol := newTreeSyncProtocol(spaceId, objTree, reqFactory)
return &treeSyncProtocolFixture{
log: log,
spaceId: spaceId,
senderId: "senderId",
treeId: "treeId",
objectTreeMock: objTree,
reqFactory: reqFactory,
ctrl: ctrl,
syncProtocol: syncProtocol,
}
}
func (fx *treeSyncProtocolFixture) stop() {
fx.ctrl.Finish()
}
func TestTreeSyncProtocol_HeadUpdate(t *testing.T) {
ctx := context.Background()
fullRequest := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncRequest{
FullSyncRequest: &treechangeproto.TreeFullSyncRequest{},
},
},
}
t.Run("head update non empty all heads added", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).Times(2)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(true)
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
t.Run("head update non empty equal heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
t.Run("head update empty", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.reqFactory.EXPECT().
CreateFullSyncRequest(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullRequest, nil)
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Equal(t, fullRequest, res)
})
t.Run("head update empty equal heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
headUpdate := &treechangeproto.TreeHeadUpdate{
Heads: []string{"h1"},
Changes: nil,
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h1"}).AnyTimes()
res, err := fx.syncProtocol.HeadUpdate(ctx, fx.senderId, headUpdate)
require.NoError(t, err)
require.Nil(t, res)
})
}
func TestTreeSyncProtocol_FullSyncRequest(t *testing.T) {
ctx := context.Background()
fullResponse := &treechangeproto.TreeSyncMessage{
Content: &treechangeproto.TreeSyncContentValue{
Value: &treechangeproto.TreeSyncContentValue_FullSyncResponse{
FullSyncResponse: &treechangeproto.TreeFullSyncResponse{},
},
},
}
t.Run("full sync request with change", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request with change same heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request without changes", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.reqFactory.EXPECT().
CreateFullSyncResponse(gomock.Eq(fx.objectTreeMock), gomock.Eq([]string{"h1"}), gomock.Eq([]string{"h1"})).
Return(fullResponse, nil)
res, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.NoError(t, err)
require.Equal(t, fullResponse, res)
})
t.Run("full sync request with change, raw changes error", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncRequest := &treechangeproto.TreeFullSyncRequest{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Heads().Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().HasChanges(gomock.Eq([]string{"h1"})).Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, fmt.Errorf("addRawChanges error"))
_, err := fx.syncProtocol.FullSyncRequest(ctx, fx.senderId, fullSyncRequest)
require.Error(t, err)
})
}
func TestTreeSyncProtocol_FullSyncResponse(t *testing.T) {
ctx := context.Background()
t.Run("full sync response with change", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h2"}).AnyTimes()
fx.objectTreeMock.EXPECT().
HasChanges(gomock.Eq([]string{"h1"})).
Return(false)
fx.objectTreeMock.EXPECT().
AddRawChanges(gomock.Any(), gomock.Eq(objecttree.RawChangesPayload{
NewHeads: []string{"h1"},
RawChanges: []*treechangeproto.RawTreeChangeWithId{chWithId},
})).
Return(objecttree.AddResult{}, nil)
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse)
require.NoError(t, err)
})
t.Run("full sync response with same heads", func(t *testing.T) {
fx := newSyncProtocolFixture(t)
defer fx.stop()
chWithId := &treechangeproto.RawTreeChangeWithId{}
fullSyncResponse := &treechangeproto.TreeFullSyncResponse{
Heads: []string{"h1"},
Changes: []*treechangeproto.RawTreeChangeWithId{chWithId},
SnapshotPath: []string{"h1"},
}
fx.objectTreeMock.EXPECT().Id().AnyTimes().Return(fx.treeId)
fx.objectTreeMock.EXPECT().
Heads().
Return([]string{"h1"}).AnyTimes()
err := fx.syncProtocol.FullSyncResponse(ctx, fx.senderId, fullSyncResponse)
require.NoError(t, err)
})
}

View File

@ -8,7 +8,7 @@ import (
reflect "reflect" reflect "reflect"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockUpdateListener is a mock of UpdateListener interface. // MockUpdateListener is a mock of UpdateListener interface.

View File

@ -3,11 +3,11 @@ package synctree
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler" "github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
@ -82,51 +82,124 @@ func (m *messageLog) addMessage(msg protocolMsg) {
m.batcher.Add(context.Background(), msg) m.batcher.Add(context.Background(), msg)
} }
type requestPeerManager struct {
peerId string
handlers map[string]*testSyncHandler
log *messageLog
}
func newRequestPeerManager(peerId string, log *messageLog) *requestPeerManager {
return &requestPeerManager{
peerId: peerId,
handlers: map[string]*testSyncHandler{},
log: log,
}
}
func (r *requestPeerManager) addHandler(peerId string, handler *testSyncHandler) {
r.handlers[peerId] = handler
}
func (r *requestPeerManager) Run(ctx context.Context) (err error) {
return nil
}
func (r *requestPeerManager) Close(ctx context.Context) (err error) {
return nil
}
func (r *requestPeerManager) SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
panic("should not be called")
}
func (r *requestPeerManager) QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: peerId,
}
r.log.addMessage(pMsg)
return r.handlers[peerId].send(context.Background(), pMsg)
}
func (r *requestPeerManager) Init(a *app.App) (err error) {
return
}
func (r *requestPeerManager) Name() (name string) {
return
}
func (r *requestPeerManager) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: peerId,
}
r.log.addMessage(pMsg)
return r.handlers[peerId].send(context.Background(), pMsg)
}
func (r *requestPeerManager) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
for _, handler := range r.handlers {
pMsg := protocolMsg{
msg: msg,
senderId: r.peerId,
receiverId: handler.peerId,
}
r.log.addMessage(pMsg)
handler.send(context.Background(), pMsg)
}
return
}
func (r *requestPeerManager) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) {
return nil, nil
}
// testSyncHandler is the wrapper around individual tree to test sync protocol // testSyncHandler is the wrapper around individual tree to test sync protocol
type testSyncHandler struct { type testSyncHandler struct {
synchandler.SyncHandler synchandler.SyncHandler
batcher *mb.MB[protocolMsg] batcher *mb.MB[protocolMsg]
peerId string peerId string
aclList list.AclList aclList list.AclList
log *messageLog log *messageLog
syncClient objectsync.SyncClient syncClient SyncClient
builder objecttree.BuildObjectTreeFunc builder objecttree.BuildObjectTreeFunc
peerManager *requestPeerManager
} }
// createSyncHandler creates a sync handler when a tree is already created // createSyncHandler creates a sync handler when a tree is already created
func createSyncHandler(peerId, spaceId string, objTree objecttree.ObjectTree, log *messageLog) *testSyncHandler { func createSyncHandler(peerId, spaceId string, objTree objecttree.ObjectTree, log *messageLog) *testSyncHandler {
factory := objectsync.NewRequestFactory() peerManager := newRequestPeerManager(peerId, log)
syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) syncClient := NewSyncClient(spaceId, peerManager, peerManager)
netTree := &broadcastTree{ netTree := &broadcastTree{
ObjectTree: objTree, ObjectTree: objTree,
SyncClient: syncClient, SyncClient: syncClient,
} }
handler := newSyncTreeHandler(spaceId, netTree, syncClient, syncstatus.NewNoOpSyncStatus()) handler := newSyncTreeHandler(spaceId, netTree, syncClient, syncstatus.NewNoOpSyncStatus())
return newTestSyncHandler(peerId, handler) return &testSyncHandler{
SyncHandler: handler,
batcher: mb.New[protocolMsg](0),
peerId: peerId,
peerManager: peerManager,
}
} }
// createEmptySyncHandler creates a sync handler when the tree will be provided later (this emulates the situation when we have no tree) // createEmptySyncHandler creates a sync handler when the tree will be provided later (this emulates the situation when we have no tree)
func createEmptySyncHandler(peerId, spaceId string, builder objecttree.BuildObjectTreeFunc, aclList list.AclList, log *messageLog) *testSyncHandler { func createEmptySyncHandler(peerId, spaceId string, builder objecttree.BuildObjectTreeFunc, aclList list.AclList, log *messageLog) *testSyncHandler {
factory := objectsync.NewRequestFactory() peerManager := newRequestPeerManager(peerId, log)
syncClient := objectsync.NewSyncClient(spaceId, newTestMessagePool(peerId, log), factory) syncClient := NewSyncClient(spaceId, peerManager, peerManager)
batcher := mb.New[protocolMsg](0) batcher := mb.New[protocolMsg](0)
return &testSyncHandler{ return &testSyncHandler{
batcher: batcher,
peerId: peerId,
aclList: aclList,
log: log,
syncClient: syncClient,
builder: builder,
}
}
func newTestSyncHandler(peerId string, syncHandler synchandler.SyncHandler) *testSyncHandler {
batcher := mb.New[protocolMsg](0)
return &testSyncHandler{
SyncHandler: syncHandler,
batcher: batcher, batcher: batcher,
peerId: peerId, peerId: peerId,
aclList: aclList,
log: log,
syncClient: syncClient,
builder: builder,
peerManager: peerManager,
} }
} }
@ -140,13 +213,8 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re
return return
} }
if unmarshalled.Content.GetFullSyncResponse() == nil { if unmarshalled.Content.GetFullSyncResponse() == nil {
newTreeRequest := objectsync.NewRequestFactory().CreateNewTreeRequest() newTreeRequest := NewRequestFactory().CreateNewTreeRequest()
var objMsg *spacesyncproto.ObjectSyncMessage return h.syncClient.QueueRequest(senderId, request.ObjectId, newTreeRequest)
objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "")
if err != nil {
return
}
return h.manager().SendPeer(context.Background(), senderId, objMsg)
} }
fullSyncResponse := unmarshalled.Content.GetFullSyncResponse() fullSyncResponse := unmarshalled.Content.GetFullSyncResponse()
treeStorage, _ := treestorage.NewInMemoryTreeStorage(unmarshalled.RootChange, []string{unmarshalled.RootChange.Id}, nil) treeStorage, _ := treestorage.NewInMemoryTreeStorage(unmarshalled.RootChange, []string{unmarshalled.RootChange.Id}, nil)
@ -166,20 +234,13 @@ func (h *testSyncHandler) HandleMessage(ctx context.Context, senderId string, re
return return
} }
h.SyncHandler = newSyncTreeHandler(request.SpaceId, netTree, h.syncClient, syncstatus.NewNoOpSyncStatus()) h.SyncHandler = newSyncTreeHandler(request.SpaceId, netTree, h.syncClient, syncstatus.NewNoOpSyncStatus())
var objMsg *spacesyncproto.ObjectSyncMessage headUpdate := NewRequestFactory().CreateHeadUpdate(netTree, res.Added)
newTreeRequest := objectsync.NewRequestFactory().CreateHeadUpdate(netTree, res.Added) h.syncClient.Broadcast(headUpdate)
objMsg, err = objectsync.MarshallTreeMessage(newTreeRequest, request.SpaceId, request.ObjectId, "") return nil
if err != nil {
return
}
return h.manager().Broadcast(context.Background(), objMsg)
} }
func (h *testSyncHandler) manager() *testMessagePool { func (h *testSyncHandler) manager() *requestPeerManager {
if h.SyncHandler != nil { return h.peerManager
return h.SyncHandler.(*syncTreeHandler).syncClient.MessagePool().(*testMessagePool)
}
return h.syncClient.MessagePool().(*testMessagePool)
} }
func (h *testSyncHandler) tree() *broadcastTree { func (h *testSyncHandler) tree() *broadcastTree {
@ -211,74 +272,28 @@ func (h *testSyncHandler) run(ctx context.Context, t *testing.T, wg *sync.WaitGr
h.tree().Unlock() h.tree().Unlock()
continue continue
} }
err = h.HandleMessage(ctx, res.senderId, res.msg) if res.description().name == "FullSyncRequest" {
if err != nil { resp, err := h.HandleRequest(ctx, res.senderId, res.msg)
fmt.Println("error handling message", err.Error()) if err != nil {
continue fmt.Println("error handling request", err.Error())
continue
}
h.peerManager.SendPeer(ctx, res.senderId, resp)
} else {
err = h.HandleMessage(ctx, res.senderId, res.msg)
if err != nil {
fmt.Println("error handling message", err.Error())
}
} }
} }
}() }()
} }
// testMessagePool captures all other handlers and sends messages to them
type testMessagePool struct {
peerId string
handlers map[string]*testSyncHandler
log *messageLog
}
func newTestMessagePool(peerId string, log *messageLog) *testMessagePool {
return &testMessagePool{handlers: map[string]*testSyncHandler{}, peerId: peerId, log: log}
}
func (m *testMessagePool) addHandler(peerId string, handler *testSyncHandler) {
m.handlers[peerId] = handler
}
func (m *testMessagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
pMsg := protocolMsg{
msg: msg,
senderId: m.peerId,
receiverId: peerId,
}
m.log.addMessage(pMsg)
return m.handlers[peerId].send(context.Background(), pMsg)
}
func (m *testMessagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
for _, handler := range m.handlers {
pMsg := protocolMsg{
msg: msg,
senderId: m.peerId,
receiverId: handler.peerId,
}
m.log.addMessage(pMsg)
handler.send(context.Background(), pMsg)
}
return
}
func (m *testMessagePool) GetResponsiblePeers(ctx context.Context) (peers []peer.Peer, err error) {
panic("should not be called")
}
func (m *testMessagePool) LastUsage() time.Time {
panic("should not be called")
}
func (m *testMessagePool) HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) {
panic("should not be called")
}
func (m *testMessagePool) SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
panic("should not be called")
}
// broadcastTree is the tree that broadcasts changes to everyone when changes are added // broadcastTree is the tree that broadcasts changes to everyone when changes are added
// it is a simplified version of SyncTree which is easier to use in the test environment // it is a simplified version of SyncTree which is easier to use in the test environment
type broadcastTree struct { type broadcastTree struct {
objecttree.ObjectTree objecttree.ObjectTree
objectsync.SyncClient SyncClient
} }
func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.RawChangesPayload) (objecttree.AddResult, error) { func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.RawChangesPayload) (objecttree.AddResult, error) {
@ -287,7 +302,7 @@ func (b *broadcastTree) AddRawChanges(ctx context.Context, changes objecttree.Ra
return objecttree.AddResult{}, err return objecttree.AddResult{}, err
} }
upd := b.SyncClient.CreateHeadUpdate(b.ObjectTree, res.Added) upd := b.SyncClient.CreateHeadUpdate(b.ObjectTree, res.Added)
b.SyncClient.Broadcast(ctx, upd) b.SyncClient.Broadcast(upd)
return res, nil return res, nil
} }

View File

@ -9,7 +9,7 @@ import (
reflect "reflect" reflect "reflect"
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockTreeStorage is a mock of TreeStorage interface. // MockTreeStorage is a mock of TreeStorage interface.

View File

@ -11,7 +11,7 @@ import (
app "github.com/anyproto/any-sync/app" app "github.com/anyproto/any-sync/app"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
treemanager "github.com/anyproto/any-sync/commonspace/object/treemanager" treemanager "github.com/anyproto/any-sync/commonspace/object/treemanager"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockTreeManager is a mock of TreeManager interface. // MockTreeManager is a mock of TreeManager interface.

View File

@ -0,0 +1,98 @@
package objectmanager
import (
"context"
"errors"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/spacestate"
"sync/atomic"
)
var (
ErrSpaceClosed = errors.New("space is closed")
)
type ObjectManager interface {
treemanager.TreeManager
AddObject(object syncobjectgetter.SyncObject)
GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error)
}
type objectManager struct {
treemanager.TreeManager
spaceId string
reservedObjects []syncobjectgetter.SyncObject
spaceIsClosed *atomic.Bool
}
func New(manager treemanager.TreeManager) ObjectManager {
return &objectManager{
TreeManager: manager,
}
}
func (o *objectManager) Init(a *app.App) (err error) {
state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
o.spaceId = state.SpaceId
o.spaceIsClosed = state.SpaceIsClosed
settingsObject := a.MustComponent(settings.CName).(settings.Settings).SettingsObject()
acl := a.MustComponent(syncacl.CName).(syncacl.SyncAcl)
o.AddObject(settingsObject)
o.AddObject(acl)
return nil
}
func (o *objectManager) Run(ctx context.Context) (err error) {
return nil
}
func (o *objectManager) Close(ctx context.Context) (err error) {
return nil
}
func (o *objectManager) AddObject(object syncobjectgetter.SyncObject) {
o.reservedObjects = append(o.reservedObjects, object)
}
func (o *objectManager) Name() string {
return treemanager.CName
}
func (o *objectManager) GetTree(ctx context.Context, spaceId, treeId string) (objecttree.ObjectTree, error) {
if o.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := o.getReservedObject(treeId); obj != nil {
return obj.(objecttree.ObjectTree), nil
}
return o.TreeManager.GetTree(ctx, spaceId, treeId)
}
func (o *objectManager) getReservedObject(id string) syncobjectgetter.SyncObject {
for _, obj := range o.reservedObjects {
if obj != nil && obj.Id() == id {
return obj
}
}
return nil
}
func (o *objectManager) GetObject(ctx context.Context, objectId string) (obj syncobjectgetter.SyncObject, err error) {
if o.spaceIsClosed.Load() {
return nil, ErrSpaceClosed
}
if obj := o.getReservedObject(objectId); obj != nil {
return obj, nil
}
t, err := o.TreeManager.GetTree(ctx, o.spaceId, objectId)
if err != nil {
return
}
obj = t.(syncobjectgetter.SyncObject)
return
}

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: SyncClient) // Source: github.com/anyproto/any-sync/commonspace/objectsync (interfaces: ObjectSync)
// Package mock_objectsync is a generated GoMock package. // Package mock_objectsync is a generated GoMock package.
package mock_objectsync package mock_objectsync
@ -7,146 +7,146 @@ package mock_objectsync
import ( import (
context "context" context "context"
reflect "reflect" reflect "reflect"
time "time"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" app "github.com/anyproto/any-sync/app"
treechangeproto "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
objectsync "github.com/anyproto/any-sync/commonspace/objectsync" objectsync "github.com/anyproto/any-sync/commonspace/objectsync"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockSyncClient is a mock of SyncClient interface. // MockObjectSync is a mock of ObjectSync interface.
type MockSyncClient struct { type MockObjectSync struct {
ctrl *gomock.Controller ctrl *gomock.Controller
recorder *MockSyncClientMockRecorder recorder *MockObjectSyncMockRecorder
} }
// MockSyncClientMockRecorder is the mock recorder for MockSyncClient. // MockObjectSyncMockRecorder is the mock recorder for MockObjectSync.
type MockSyncClientMockRecorder struct { type MockObjectSyncMockRecorder struct {
mock *MockSyncClient mock *MockObjectSync
} }
// NewMockSyncClient creates a new mock instance. // NewMockObjectSync creates a new mock instance.
func NewMockSyncClient(ctrl *gomock.Controller) *MockSyncClient { func NewMockObjectSync(ctrl *gomock.Controller) *MockObjectSync {
mock := &MockSyncClient{ctrl: ctrl} mock := &MockObjectSync{ctrl: ctrl}
mock.recorder = &MockSyncClientMockRecorder{mock} mock.recorder = &MockObjectSyncMockRecorder{mock}
return mock return mock
} }
// EXPECT returns an object that allows the caller to indicate expected use. // EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockSyncClient) EXPECT() *MockSyncClientMockRecorder { func (m *MockObjectSync) EXPECT() *MockObjectSyncMockRecorder {
return m.recorder return m.recorder
} }
// Broadcast mocks base method. // Close mocks base method.
func (m *MockSyncClient) Broadcast(arg0 context.Context, arg1 *treechangeproto.TreeSyncMessage) { func (m *MockObjectSync) Close(arg0 context.Context) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
m.ctrl.Call(m, "Broadcast", arg0, arg1) ret := m.ctrl.Call(m, "Close", arg0)
} ret0, _ := ret[0].(error)
// Broadcast indicates an expected call of Broadcast.
func (mr *MockSyncClientMockRecorder) Broadcast(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Broadcast", reflect.TypeOf((*MockSyncClient)(nil).Broadcast), arg0, arg1)
}
// CreateFullSyncRequest mocks base method.
func (m *MockSyncClient) CreateFullSyncRequest(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncRequest", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncRequest indicates an expected call of CreateFullSyncRequest.
func (mr *MockSyncClientMockRecorder) CreateFullSyncRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncRequest), arg0, arg1, arg2)
}
// CreateFullSyncResponse mocks base method.
func (m *MockSyncClient) CreateFullSyncResponse(arg0 objecttree.ObjectTree, arg1, arg2 []string) (*treechangeproto.TreeSyncMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateFullSyncResponse", arg0, arg1, arg2)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateFullSyncResponse indicates an expected call of CreateFullSyncResponse.
func (mr *MockSyncClientMockRecorder) CreateFullSyncResponse(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateFullSyncResponse", reflect.TypeOf((*MockSyncClient)(nil).CreateFullSyncResponse), arg0, arg1, arg2)
}
// CreateHeadUpdate mocks base method.
func (m *MockSyncClient) CreateHeadUpdate(arg0 objecttree.ObjectTree, arg1 []*treechangeproto.RawTreeChangeWithId) *treechangeproto.TreeSyncMessage {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateHeadUpdate", arg0, arg1)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage)
return ret0 return ret0
} }
// CreateHeadUpdate indicates an expected call of CreateHeadUpdate. // Close indicates an expected call of Close.
func (mr *MockSyncClientMockRecorder) CreateHeadUpdate(arg0, arg1 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) Close(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateHeadUpdate", reflect.TypeOf((*MockSyncClient)(nil).CreateHeadUpdate), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockObjectSync)(nil).Close), arg0)
} }
// CreateNewTreeRequest mocks base method. // CloseThread mocks base method.
func (m *MockSyncClient) CreateNewTreeRequest() *treechangeproto.TreeSyncMessage { func (m *MockObjectSync) CloseThread(arg0 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateNewTreeRequest") ret := m.ctrl.Call(m, "CloseThread", arg0)
ret0, _ := ret[0].(*treechangeproto.TreeSyncMessage) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// CreateNewTreeRequest indicates an expected call of CreateNewTreeRequest. // CloseThread indicates an expected call of CloseThread.
func (mr *MockSyncClientMockRecorder) CreateNewTreeRequest() *gomock.Call { func (mr *MockObjectSyncMockRecorder) CloseThread(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateNewTreeRequest", reflect.TypeOf((*MockSyncClient)(nil).CreateNewTreeRequest)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CloseThread", reflect.TypeOf((*MockObjectSync)(nil).CloseThread), arg0)
} }
// MessagePool mocks base method. // HandleMessage mocks base method.
func (m *MockSyncClient) MessagePool() objectsync.MessagePool { func (m *MockObjectSync) HandleMessage(arg0 context.Context, arg1 objectsync.HandleMessage) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MessagePool") ret := m.ctrl.Call(m, "HandleMessage", arg0, arg1)
ret0, _ := ret[0].(objectsync.MessagePool) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// MessagePool indicates an expected call of MessagePool. // HandleMessage indicates an expected call of HandleMessage.
func (mr *MockSyncClientMockRecorder) MessagePool() *gomock.Call { func (mr *MockObjectSyncMockRecorder) HandleMessage(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MessagePool", reflect.TypeOf((*MockSyncClient)(nil).MessagePool)) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleMessage", reflect.TypeOf((*MockObjectSync)(nil).HandleMessage), arg0, arg1)
} }
// SendSync mocks base method. // HandleRequest mocks base method.
func (m *MockSyncClient) SendSync(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) { func (m *MockObjectSync) HandleRequest(arg0 context.Context, arg1 *spacesyncproto.ObjectSyncMessage) (*spacesyncproto.ObjectSyncMessage, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendSync", arg0, arg1, arg2, arg3) ret := m.ctrl.Call(m, "HandleRequest", arg0, arg1)
ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage) ret0, _ := ret[0].(*spacesyncproto.ObjectSyncMessage)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// SendSync indicates an expected call of SendSync. // HandleRequest indicates an expected call of HandleRequest.
func (mr *MockSyncClientMockRecorder) SendSync(arg0, arg1, arg2, arg3 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) HandleRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendSync", reflect.TypeOf((*MockSyncClient)(nil).SendSync), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HandleRequest", reflect.TypeOf((*MockObjectSync)(nil).HandleRequest), arg0, arg1)
} }
// SendWithReply mocks base method. // Init mocks base method.
func (m *MockSyncClient) SendWithReply(arg0 context.Context, arg1, arg2 string, arg3 *treechangeproto.TreeSyncMessage, arg4 string) error { func (m *MockObjectSync) Init(arg0 *app.App) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SendWithReply", arg0, arg1, arg2, arg3, arg4) ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// SendWithReply indicates an expected call of SendWithReply. // Init indicates an expected call of Init.
func (mr *MockSyncClientMockRecorder) SendWithReply(arg0, arg1, arg2, arg3, arg4 interface{}) *gomock.Call { func (mr *MockObjectSyncMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SendWithReply", reflect.TypeOf((*MockSyncClient)(nil).SendWithReply), arg0, arg1, arg2, arg3, arg4) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockObjectSync)(nil).Init), arg0)
}
// LastUsage mocks base method.
func (m *MockObjectSync) LastUsage() time.Time {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LastUsage")
ret0, _ := ret[0].(time.Time)
return ret0
}
// LastUsage indicates an expected call of LastUsage.
func (mr *MockObjectSyncMockRecorder) LastUsage() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastUsage", reflect.TypeOf((*MockObjectSync)(nil).LastUsage))
}
// Name mocks base method.
func (m *MockObjectSync) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockObjectSyncMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockObjectSync)(nil).Name))
}
// Run mocks base method.
func (m *MockObjectSync) Run(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Run", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Run indicates an expected call of Run.
func (mr *MockObjectSyncMockRecorder) Run(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockObjectSync)(nil).Run), arg0)
} }

View File

@ -1,142 +0,0 @@
package objectsync
import (
"context"
"fmt"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
)
type LastUsage interface {
LastUsage() time.Time
}
// MessagePool can be made generic to work with different streams
type MessagePool interface {
LastUsage
synchandler.SyncHandler
peermanager.PeerManager
SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
}
type MessageHandler func(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error)
type responseWaiter struct {
ch chan *spacesyncproto.ObjectSyncMessage
}
type messagePool struct {
sync.Mutex
peermanager.PeerManager
messageHandler MessageHandler
waiters map[string]responseWaiter
waitersMx sync.Mutex
counter atomic.Uint64
lastUsage atomic.Int64
}
func newMessagePool(peerManager peermanager.PeerManager, messageHandler MessageHandler) MessagePool {
s := &messagePool{
PeerManager: peerManager,
messageHandler: messageHandler,
waiters: make(map[string]responseWaiter),
}
return s
}
func (s *messagePool) SendSync(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
s.updateLastUsage()
if _, ok := ctx.Deadline(); !ok {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Minute)
defer cancel()
}
newCounter := s.counter.Add(1)
msg.RequestId = genReplyKey(peerId, msg.ObjectId, newCounter)
log.InfoCtx(ctx, "mpool sendSync", zap.String("requestId", msg.RequestId))
s.waitersMx.Lock()
waiter := responseWaiter{
ch: make(chan *spacesyncproto.ObjectSyncMessage, 1),
}
s.waiters[msg.RequestId] = waiter
s.waitersMx.Unlock()
err = s.SendPeer(ctx, peerId, msg)
if err != nil {
return
}
select {
case <-ctx.Done():
s.waitersMx.Lock()
delete(s.waiters, msg.RequestId)
s.waitersMx.Unlock()
log.With(zap.String("requestId", msg.RequestId)).DebugCtx(ctx, "time elapsed when waiting")
err = fmt.Errorf("sendSync context error: %v", ctx.Err())
case reply = <-waiter.ch:
// success
}
return
}
func (s *messagePool) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
return s.PeerManager.SendPeer(ctx, peerId, msg)
}
func (s *messagePool) Broadcast(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
return s.PeerManager.Broadcast(ctx, msg)
}
func (s *messagePool) HandleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
s.updateLastUsage()
if msg.ReplyId != "" {
log.InfoCtx(ctx, "mpool receive reply", zap.String("replyId", msg.ReplyId))
// we got reply, send it to waiter
if s.stopWaiter(msg) {
return
}
log.WarnCtx(ctx, "reply id does not exist", zap.String("replyId", msg.ReplyId))
return
}
return s.messageHandler(ctx, senderId, msg)
}
func (s *messagePool) LastUsage() time.Time {
return time.Unix(s.lastUsage.Load(), 0)
}
func (s *messagePool) updateLastUsage() {
s.lastUsage.Store(time.Now().Unix())
}
func (s *messagePool) stopWaiter(msg *spacesyncproto.ObjectSyncMessage) bool {
s.waitersMx.Lock()
waiter, exists := s.waiters[msg.ReplyId]
if exists {
delete(s.waiters, msg.ReplyId)
s.waitersMx.Unlock()
waiter.ch <- msg
return true
}
s.waitersMx.Unlock()
return false
}
func genReplyKey(peerId, treeId string, counter uint64) string {
b := &strings.Builder{}
b.WriteString(peerId)
b.WriteString(".")
b.WriteString(treeId)
b.WriteString(".")
b.WriteString(strconv.FormatUint(counter, 36))
return b.String()
}

View File

@ -1,18 +1,24 @@
//go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync SyncClient //go:generate mockgen -destination mock_objectsync/mock_objectsync.go github.com/anyproto/any-sync/commonspace/objectsync ObjectSync
package objectsync package objectsync
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/gogo/protobuf/proto"
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/util/multiqueue"
"github.com/cheggaaa/mb/v3"
"github.com/gogo/protobuf/proto"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/syncobjectgetter" "github.com/anyproto/any-sync/commonspace/object/syncobjectgetter"
"github.com/anyproto/any-sync/commonspace/objectsync/synchandler"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
@ -20,138 +26,208 @@ import (
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
) )
var log = logger.NewNamed("common.commonspace.objectsync") const CName = "common.commonspace.objectsync"
var log = logger.NewNamed(CName)
type ObjectSync interface { type ObjectSync interface {
LastUsage LastUsage() time.Time
synchandler.SyncHandler HandleMessage(ctx context.Context, hm HandleMessage) (err error)
SyncClient() SyncClient HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error)
CloseThread(id string) (err error)
app.ComponentRunnable
}
Close() (err error) type HandleMessage struct {
Id uint64
ReceiveTime time.Time
StartHandlingTime time.Time
Deadline time.Time
SenderId string
Message *spacesyncproto.ObjectSyncMessage
PeerCtx context.Context
}
func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field {
return append(fields,
metric.SpaceId(m.Message.SpaceId),
metric.ObjectId(m.Message.ObjectId),
metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)),
metric.TotalDur(time.Since(m.ReceiveTime)),
)
} }
type objectSync struct { type objectSync struct {
spaceId string spaceId string
messagePool MessagePool
syncClient SyncClient
objectGetter syncobjectgetter.SyncObjectGetter objectGetter syncobjectgetter.SyncObjectGetter
configuration nodeconf.NodeConf configuration nodeconf.NodeConf
spaceStorage spacestorage.SpaceStorage spaceStorage spacestorage.SpaceStorage
metric metric.Metric
syncCtx context.Context
cancelSync context.CancelFunc
spaceIsDeleted *atomic.Bool spaceIsDeleted *atomic.Bool
handleQueue multiqueue.MultiQueue[HandleMessage]
} }
func NewObjectSync( func (s *objectSync) Init(a *app.App) (err error) {
spaceId string, s.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
spaceIsDeleted *atomic.Bool, s.objectGetter = a.MustComponent(treemanager.CName).(syncobjectgetter.SyncObjectGetter)
configuration nodeconf.NodeConf, s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
peerManager peermanager.PeerManager, sharedData := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
objectGetter syncobjectgetter.SyncObjectGetter, mc := a.Component(metric.CName)
storage spacestorage.SpaceStorage) ObjectSync { if mc != nil {
syncCtx, cancel := context.WithCancel(context.Background()) s.metric = mc.(metric.Metric)
os := &objectSync{
objectGetter: objectGetter,
spaceStorage: storage,
spaceId: spaceId,
syncCtx: syncCtx,
cancelSync: cancel,
spaceIsDeleted: spaceIsDeleted,
configuration: configuration,
} }
os.messagePool = newMessagePool(peerManager, os.handleMessage) s.spaceIsDeleted = sharedData.SpaceIsDeleted
os.syncClient = NewSyncClient(spaceId, os.messagePool, NewRequestFactory()) s.spaceId = sharedData.SpaceId
return os s.handleQueue = multiqueue.New[HandleMessage](s.processHandleMessage, 30)
return nil
} }
func (s *objectSync) Close() (err error) { func (s *objectSync) Name() (name string) {
s.cancelSync() return CName
return }
func (s *objectSync) Run(ctx context.Context) (err error) {
return nil
}
func (s *objectSync) Close(ctx context.Context) (err error) {
return s.handleQueue.Close()
}
func New() ObjectSync {
return &objectSync{}
} }
func (s *objectSync) LastUsage() time.Time { func (s *objectSync) LastUsage() time.Time {
return s.messagePool.LastUsage() // TODO: add time
return time.Time{}
} }
func (s *objectSync) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error) { func (s *objectSync) HandleRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
return s.messagePool.HandleMessage(ctx, senderId, message) peerId, err := peer.CtxPeerId(ctx)
if err != nil {
return nil, err
}
return s.handleRequest(ctx, peerId, req)
} }
func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) { func (s *objectSync) HandleMessage(ctx context.Context, hm HandleMessage) (err error) {
log := log.With( threadId := hm.Message.ObjectId
zap.String("objectId", msg.ObjectId), hm.ReceiveTime = time.Now()
zap.String("requestId", msg.RequestId), if hm.PeerCtx == nil {
zap.String("replyId", msg.ReplyId)) hm.PeerCtx = ctx
}
err = s.handleQueue.Add(ctx, threadId, hm)
if err == mb.ErrOverflowed {
log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.spaceId), zap.String("objectId", threadId))
// skip overflowed error
return nil
}
return
}
func (s *objectSync) processHandleMessage(msg HandleMessage) {
var err error
msg.StartHandlingTime = time.Now()
ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId)
ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId))
defer func() {
if s.metric == nil {
return
}
s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields(
zap.Error(err),
)...)
}()
if !msg.Deadline.IsZero() {
now := time.Now()
if now.After(msg.Deadline) {
log.InfoCtx(ctx, "skip message: deadline exceed")
err = context.DeadlineExceeded
return
}
}
if err = s.handleMessage(ctx, msg.SenderId, msg.Message); err != nil {
if msg.Message.ObjectId != "" {
// cleanup thread on error
_ = s.handleQueue.CloseThread(msg.Message.ObjectId)
}
log.InfoCtx(ctx, "handleMessage error", zap.Error(err))
}
}
func (s *objectSync) handleRequest(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error) {
log := log.With(zap.String("objectId", msg.ObjectId))
if s.spaceIsDeleted.Load() { if s.spaceIsDeleted.Load() {
log = log.With(zap.Bool("isDeleted", true)) log = log.With(zap.Bool("isDeleted", true))
// preventing sync with other clients if they are not just syncing the settings tree // preventing sync with other clients if they are not just syncing the settings tree
if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() { if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) return nil, spacesyncproto.ErrSpaceIsDeleted
return fmt.Errorf("can't perform operation with object, space is deleted")
} }
} }
log.DebugCtx(ctx, "handling message") err = s.checkEmptyFullSync(log, msg)
if err != nil {
return nil, err
}
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
return nil, treechangeproto.ErrGetTree
}
return obj.HandleRequest(ctx, senderId, msg)
}
func (s *objectSync) handleMessage(ctx context.Context, senderId string, msg *spacesyncproto.ObjectSyncMessage) (err error) {
log := log.With(zap.String("objectId", msg.ObjectId))
if s.spaceIsDeleted.Load() {
log = log.With(zap.Bool("isDeleted", true))
// preventing sync with other clients if they are not just syncing the settings tree
if !slices.Contains(s.configuration.NodeIds(s.spaceId), senderId) && msg.ObjectId != s.spaceStorage.SpaceSettingsId() {
return spacesyncproto.ErrSpaceIsDeleted
}
}
err = s.checkEmptyFullSync(log, msg)
if err != nil {
return err
}
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
return fmt.Errorf("failed to get object from cache: %w", err)
}
err = obj.HandleMessage(ctx, senderId, msg)
if err != nil {
return fmt.Errorf("failed to handle message: %w", err)
}
return
}
func (s *objectSync) CloseThread(id string) (err error) {
return s.handleQueue.CloseThread(id)
}
func (s *objectSync) checkEmptyFullSync(log logger.CtxLogger, msg *spacesyncproto.ObjectSyncMessage) (err error) {
hasTree, err := s.spaceStorage.HasTree(msg.ObjectId) hasTree, err := s.spaceStorage.HasTree(msg.ObjectId)
if err != nil { if err != nil {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId) log.Warn("failed to execute get operation on storage has tree", zap.Error(err))
return fmt.Errorf("falied to execute get operation on storage has tree: %w", err) return spacesyncproto.ErrUnexpected
} }
// in this case we will try to get it from remote, unless the sender also sent us the same request :-) // in this case we will try to get it from remote, unless the sender also sent us the same request :-)
if !hasTree { if !hasTree {
treeMsg := &treechangeproto.TreeSyncMessage{} treeMsg := &treechangeproto.TreeSyncMessage{}
err = proto.Unmarshal(msg.Payload, treeMsg) err = proto.Unmarshal(msg.Payload, treeMsg)
if err != nil { if err != nil {
s.sendError(ctx, nil, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId, msg.RequestId) return nil
return fmt.Errorf("failed to unmarshall tree sync message: %w", err)
} }
// this means that we don't have the tree locally and therefore can't return it // this means that we don't have the tree locally and therefore can't return it
if s.isEmptyFullSyncRequest(treeMsg) { if s.isEmptyFullSyncRequest(treeMsg) {
err = treechangeproto.ErrGetTree return treechangeproto.ErrGetTree
s.sendError(ctx, nil, treechangeproto.ErrGetTree, senderId, msg.ObjectId, msg.RequestId)
return fmt.Errorf("failed to get tree from storage on full sync: %w", err)
} }
} }
obj, err := s.objectGetter.GetObject(ctx, msg.ObjectId)
if err != nil {
// TODO: write tests for object sync https://linear.app/anytype/issue/GO-1299/write-tests-for-commonspaceobjectsync
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId)
return fmt.Errorf("failed to get object from cache: %w", err)
}
// TODO: unmarshall earlier
err = obj.HandleMessage(ctx, senderId, msg)
if err != nil {
s.unmarshallSendError(ctx, msg, spacesyncproto.ErrUnexpected, senderId, msg.ObjectId)
return fmt.Errorf("failed to handle message: %w", err)
}
return return
} }
func (s *objectSync) SyncClient() SyncClient {
return s.syncClient
}
func (s *objectSync) unmarshallSendError(ctx context.Context, msg *spacesyncproto.ObjectSyncMessage, respErr error, senderId, objectId string) {
unmarshalled := &treechangeproto.TreeSyncMessage{}
err := proto.Unmarshal(msg.Payload, unmarshalled)
if err != nil {
return
}
s.sendError(ctx, unmarshalled.RootChange, respErr, senderId, objectId, msg.RequestId)
}
func (s *objectSync) sendError(ctx context.Context, root *treechangeproto.RawTreeChangeWithId, respErr error, senderId, objectId, replyId string) {
// we don't send errors if have no reply id, this can lead to bugs and also nobody needs this error
if replyId == "" {
return
}
resp := treechangeproto.WrapError(respErr, root)
if err := s.syncClient.SendWithReply(ctx, senderId, objectId, resp, replyId); err != nil {
log.InfoCtx(ctx, "failed to send error to client")
}
}
func (s *objectSync) isEmptyFullSyncRequest(msg *treechangeproto.TreeSyncMessage) bool { func (s *objectSync) isEmptyFullSyncRequest(msg *treechangeproto.TreeSyncMessage) bool {
return msg.GetContent().GetFullSyncRequest() != nil && len(msg.GetContent().GetFullSyncRequest().GetHeads()) == 0 return msg.GetContent().GetFullSyncRequest() != nil && len(msg.GetContent().GetFullSyncRequest().GetHeads()) == 0
} }

View File

@ -1,78 +0,0 @@
package objectsync
import (
"context"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"go.uber.org/zap"
)
type SyncClient interface {
RequestFactory
Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage)
SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error)
SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
MessagePool() MessagePool
}
type syncClient struct {
RequestFactory
spaceId string
messagePool MessagePool
}
func NewSyncClient(
spaceId string,
messagePool MessagePool,
factory RequestFactory) SyncClient {
return &syncClient{
messagePool: messagePool,
RequestFactory: factory,
spaceId: spaceId,
}
}
func (s *syncClient) Broadcast(ctx context.Context, msg *treechangeproto.TreeSyncMessage) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, msg.RootChange.Id, "")
if err != nil {
return
}
err = s.messagePool.Broadcast(ctx, objMsg)
if err != nil {
log.DebugCtx(ctx, "broadcast error", zap.Error(err))
}
}
func (s *syncClient) SendSync(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, "")
if err != nil {
return
}
return s.messagePool.SendSync(ctx, peerId, objMsg)
}
func (s *syncClient) SendWithReply(ctx context.Context, peerId, objectId string, msg *treechangeproto.TreeSyncMessage, replyId string) (err error) {
objMsg, err := MarshallTreeMessage(msg, s.spaceId, objectId, replyId)
if err != nil {
return
}
return s.messagePool.SendPeer(ctx, peerId, objMsg)
}
func (s *syncClient) MessagePool() MessagePool {
return s.messagePool
}
func MarshallTreeMessage(message *treechangeproto.TreeSyncMessage, spaceId, objectId, replyId string) (objMsg *spacesyncproto.ObjectSyncMessage, err error) {
payload, err := message.Marshal()
if err != nil {
return
}
objMsg = &spacesyncproto.ObjectSyncMessage{
ReplyId: replyId,
Payload: payload,
ObjectId: objectId,
SpaceId: spaceId,
}
return
}

View File

@ -6,5 +6,6 @@ import (
) )
type SyncHandler interface { type SyncHandler interface {
HandleMessage(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (err error) HandleMessage(ctx context.Context, senderId string, message *spacesyncproto.ObjectSyncMessage) (err error)
HandleRequest(ctx context.Context, senderId string, request *spacesyncproto.ObjectSyncMessage) (response *spacesyncproto.ObjectSyncMessage, err error)
} }

View File

@ -0,0 +1,99 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/objecttreebuilder (interfaces: TreeBuilder)
// Package mock_objecttreebuilder is a generated GoMock package.
package mock_objecttreebuilder
import (
context "context"
reflect "reflect"
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
updatelistener "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
treestorage "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
objecttreebuilder "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
gomock "go.uber.org/mock/gomock"
)
// MockTreeBuilder is a mock of TreeBuilder interface.
type MockTreeBuilder struct {
ctrl *gomock.Controller
recorder *MockTreeBuilderMockRecorder
}
// MockTreeBuilderMockRecorder is the mock recorder for MockTreeBuilder.
type MockTreeBuilderMockRecorder struct {
mock *MockTreeBuilder
}
// NewMockTreeBuilder creates a new mock instance.
func NewMockTreeBuilder(ctrl *gomock.Controller) *MockTreeBuilder {
mock := &MockTreeBuilder{ctrl: ctrl}
mock.recorder = &MockTreeBuilderMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockTreeBuilder) EXPECT() *MockTreeBuilderMockRecorder {
return m.recorder
}
// BuildHistoryTree mocks base method.
func (m *MockTreeBuilder) BuildHistoryTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.HistoryTreeOpts) (objecttree.HistoryTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildHistoryTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.HistoryTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BuildHistoryTree indicates an expected call of BuildHistoryTree.
func (mr *MockTreeBuilderMockRecorder) BuildHistoryTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildHistoryTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildHistoryTree), arg0, arg1, arg2)
}
// BuildTree mocks base method.
func (m *MockTreeBuilder) BuildTree(arg0 context.Context, arg1 string, arg2 objecttreebuilder.BuildTreeOpts) (objecttree.ObjectTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BuildTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.ObjectTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// BuildTree indicates an expected call of BuildTree.
func (mr *MockTreeBuilderMockRecorder) BuildTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BuildTree", reflect.TypeOf((*MockTreeBuilder)(nil).BuildTree), arg0, arg1, arg2)
}
// CreateTree mocks base method.
func (m *MockTreeBuilder) CreateTree(arg0 context.Context, arg1 objecttree.ObjectTreeCreatePayload) (treestorage.TreeStorageCreatePayload, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CreateTree", arg0, arg1)
ret0, _ := ret[0].(treestorage.TreeStorageCreatePayload)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CreateTree indicates an expected call of CreateTree.
func (mr *MockTreeBuilderMockRecorder) CreateTree(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateTree", reflect.TypeOf((*MockTreeBuilder)(nil).CreateTree), arg0, arg1)
}
// PutTree mocks base method.
func (m *MockTreeBuilder) PutTree(arg0 context.Context, arg1 treestorage.TreeStorageCreatePayload, arg2 updatelistener.UpdateListener) (objecttree.ObjectTree, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PutTree", arg0, arg1, arg2)
ret0, _ := ret[0].(objecttree.ObjectTree)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// PutTree indicates an expected call of PutTree.
func (mr *MockTreeBuilderMockRecorder) PutTree(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PutTree", reflect.TypeOf((*MockTreeBuilder)(nil).PutTree), arg0, arg1, arg2)
}

View File

@ -0,0 +1,204 @@
//go:generate mockgen -destination mock_objecttreebuilder/mock_objecttreebuilder.go github.com/anyproto/any-sync/commonspace/objecttreebuilder TreeBuilder
package objecttreebuilder
import (
"context"
"errors"
"sync/atomic"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/nodeconf"
"go.uber.org/zap"
)
type BuildTreeOpts struct {
Listener updatelistener.UpdateListener
TreeBuilder objecttree.BuildObjectTreeFunc
}
const CName = "common.commonspace.objecttreebuilder"
var log = logger.NewNamed(CName)
var ErrSpaceClosed = errors.New("space is closed")
type HistoryTreeOpts struct {
BeforeId string
Include bool
BuildFullTree bool
}
type TreeBuilder interface {
BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error)
BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error)
CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error)
PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error)
}
type TreeBuilderComponent interface {
app.Component
TreeBuilder
}
func New() TreeBuilderComponent {
return &treeBuilder{}
}
type treeBuilder struct {
syncClient synctree.SyncClient
configuration nodeconf.NodeConf
headsNotifiable synctree.HeadNotifiable
peerManager peermanager.PeerManager
requestManager requestmanager.RequestManager
spaceStorage spacestorage.SpaceStorage
syncStatus syncstatus.StatusUpdater
objectSync objectsync.ObjectSync
log logger.CtxLogger
builder objecttree.BuildObjectTreeFunc
spaceId string
aclList list.AclList
treesUsed *atomic.Int32
isClosed *atomic.Bool
}
func (t *treeBuilder) Init(a *app.App) (err error) {
state := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
t.spaceId = state.SpaceId
t.isClosed = state.SpaceIsClosed
t.treesUsed = state.TreesUsed
t.builder = state.TreeBuilderFunc
t.aclList = a.MustComponent(syncacl.CName).(syncacl.SyncAcl)
t.spaceStorage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
t.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
t.headsNotifiable = a.MustComponent(headsync.CName).(headsync.HeadSync)
t.syncStatus = a.MustComponent(syncstatus.CName).(syncstatus.StatusUpdater)
t.peerManager = a.MustComponent(peermanager.CName).(peermanager.PeerManager)
t.requestManager = a.MustComponent(requestmanager.CName).(requestmanager.RequestManager)
t.objectSync = a.MustComponent(objectsync.CName).(objectsync.ObjectSync)
t.log = log.With(zap.String("spaceId", t.spaceId))
t.syncClient = synctree.NewSyncClient(t.spaceId, t.requestManager, t.peerManager)
return nil
}
func (t *treeBuilder) Name() (name string) {
return CName
}
func (t *treeBuilder) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (ot objecttree.ObjectTree, err error) {
if t.isClosed.Load() {
// TODO: change to real error
err = ErrSpaceClosed
return
}
treeBuilder := opts.TreeBuilder
if treeBuilder == nil {
treeBuilder = t.builder
}
deps := synctree.BuildDeps{
SpaceId: t.spaceId,
SyncClient: t.syncClient,
Configuration: t.configuration,
HeadNotifiable: t.headsNotifiable,
Listener: opts.Listener,
AclList: t.aclList,
SpaceStorage: t.spaceStorage,
OnClose: t.onClose,
SyncStatus: t.syncStatus,
PeerGetter: t.peerManager,
BuildObjectTree: treeBuilder,
}
t.treesUsed.Add(1)
t.log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()))
if ot, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil {
t.treesUsed.Add(-1)
t.log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.Error(err))
return nil, err
}
return
}
func (t *treeBuilder) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (ot objecttree.HistoryTree, err error) {
if t.isClosed.Load() {
// TODO: change to real error
err = ErrSpaceClosed
return
}
params := objecttree.HistoryTreeParams{
AclList: t.aclList,
BeforeId: opts.BeforeId,
IncludeBeforeId: opts.Include,
BuildFullTree: opts.BuildFullTree,
}
params.TreeStorage, err = t.spaceStorage.TreeStorage(id)
if err != nil {
return
}
return objecttree.BuildHistoryTree(params)
}
func (t *treeBuilder) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) {
if t.isClosed.Load() {
err = ErrSpaceClosed
return
}
root, err := objecttree.CreateObjectTreeRoot(payload, t.aclList)
if err != nil {
return
}
res = treestorage.TreeStorageCreatePayload{
RootRawChange: root,
Changes: []*treechangeproto.RawTreeChangeWithId{root},
Heads: []string{root.Id},
}
return
}
func (t *treeBuilder) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (ot objecttree.ObjectTree, err error) {
if t.isClosed.Load() {
err = ErrSpaceClosed
return
}
deps := synctree.BuildDeps{
SpaceId: t.spaceId,
SyncClient: t.syncClient,
Configuration: t.configuration,
HeadNotifiable: t.headsNotifiable,
Listener: listener,
AclList: t.aclList,
SpaceStorage: t.spaceStorage,
OnClose: t.onClose,
SyncStatus: t.syncStatus,
PeerGetter: t.peerManager,
BuildObjectTree: t.builder,
}
ot, err = synctree.PutSyncTree(ctx, payload, deps)
if err != nil {
return
}
t.treesUsed.Add(1)
t.log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", t.treesUsed.Load()))
return
}
func (t *treeBuilder) onClose(id string) {
t.treesUsed.Add(-1)
log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", t.treesUsed.Load()), zap.String("spaceId", t.spaceId))
_ = t.objectSync.CloseThread(id)
}

View File

@ -2,20 +2,22 @@ package commonspace
import ( import (
"errors" "errors"
"hash/fnv"
"math/rand"
"strconv"
"strings"
"time"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil" "github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/gogo/protobuf/proto" "github.com/gogo/protobuf/proto"
"hash/fnv"
"math/rand"
"strconv"
"strings"
"time"
) )
const ( const (
@ -71,7 +73,7 @@ func storagePayloadForSpaceCreate(payload SpaceCreatePayload) (storagePayload sp
// building acl root // building acl root
keyStorage := crypto.NewKeyStorage() keyStorage := crypto.NewKeyStorage()
aclBuilder := list.NewAclRecordBuilder("", keyStorage) aclBuilder := list.NewAclRecordBuilder("", keyStorage, nil, list.NoOpAcceptorVerifier{})
aclRoot, err := aclBuilder.BuildRoot(list.RootContent{ aclRoot, err := aclBuilder.BuildRoot(list.RootContent{
PrivKey: payload.SigningKey, PrivKey: payload.SigningKey,
MasterKey: payload.MasterKey, MasterKey: payload.MasterKey,
@ -158,7 +160,7 @@ func storagePayloadForSpaceDerive(payload SpaceDerivePayload) (storagePayload sp
// building acl root // building acl root
keyStorage := crypto.NewKeyStorage() keyStorage := crypto.NewKeyStorage()
aclBuilder := list.NewAclRecordBuilder("", keyStorage) aclBuilder := list.NewAclRecordBuilder("", keyStorage, nil, list.NoOpAcceptorVerifier{})
aclRoot, err := aclBuilder.BuildRoot(list.RootContent{ aclRoot, err := aclBuilder.BuildRoot(list.RootContent{
PrivKey: payload.SigningKey, PrivKey: payload.SigningKey,
MasterKey: payload.MasterKey, MasterKey: payload.MasterKey,
@ -214,14 +216,15 @@ func validateSpaceStorageCreatePayload(payload spacestorage.SpaceStorageCreatePa
} }
func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, identity crypto.PubKey) (err error) { func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, identity crypto.PubKey) (err error) {
if rawHeaderWithId == nil {
return spacestorage.ErrIncorrectSpaceHeader
}
sepIdx := strings.Index(rawHeaderWithId.Id, ".") sepIdx := strings.Index(rawHeaderWithId.Id, ".")
if sepIdx == -1 { if sepIdx == -1 {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if !cidutil.VerifyCid(rawHeaderWithId.RawHeader, rawHeaderWithId.Id[:sepIdx]) { if !cidutil.VerifyCid(rawHeaderWithId.RawHeader, rawHeaderWithId.Id[:sepIdx]) {
err = objecttree.ErrIncorrectCid return objecttree.ErrIncorrectCid
return
} }
var rawSpaceHeader spacesyncproto.RawSpaceHeader var rawSpaceHeader spacesyncproto.RawSpaceHeader
err = proto.Unmarshal(rawHeaderWithId.RawHeader, &rawSpaceHeader) err = proto.Unmarshal(rawHeaderWithId.RawHeader, &rawSpaceHeader)
@ -239,29 +242,26 @@ func ValidateSpaceHeader(rawHeaderWithId *spacesyncproto.RawSpaceHeaderWithId, i
} }
res, err := payloadIdentity.Verify(rawSpaceHeader.SpaceHeader, rawSpaceHeader.Signature) res, err := payloadIdentity.Verify(rawSpaceHeader.SpaceHeader, rawSpaceHeader.Signature)
if err != nil || !res { if err != nil || !res {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if rawHeaderWithId.Id[sepIdx+1:] != strconv.FormatUint(header.ReplicationKey, 36) { if rawHeaderWithId.Id[sepIdx+1:] != strconv.FormatUint(header.ReplicationKey, 36) {
err = spacestorage.ErrIncorrectSpaceHeader return spacestorage.ErrIncorrectSpaceHeader
return
} }
if identity == nil { if identity == nil {
return return
} }
if !payloadIdentity.Equals(identity) { if !payloadIdentity.Equals(identity) {
err = ErrIncorrectIdentity return ErrIncorrectIdentity
return
} }
return return
} }
func validateCreateSpaceAclPayload(rawWithId *aclrecordproto.RawAclRecordWithId) (spaceId string, err error) { func validateCreateSpaceAclPayload(rawWithId *consensusproto.RawRecordWithId) (spaceId string, err error) {
if !cidutil.VerifyCid(rawWithId.Payload, rawWithId.Id) { if !cidutil.VerifyCid(rawWithId.Payload, rawWithId.Id) {
err = objecttree.ErrIncorrectCid err = objecttree.ErrIncorrectCid
return return
} }
var rawAcl aclrecordproto.RawAclRecord var rawAcl consensusproto.RawRecord
err = proto.Unmarshal(rawWithId.Payload, &rawAcl) err = proto.Unmarshal(rawWithId.Payload, &rawAcl)
if err != nil { if err != nil {
return return

View File

@ -2,20 +2,22 @@ package commonspace
import ( import (
"fmt" "fmt"
"math/rand"
"strconv"
"testing"
"time"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/util/cidutil" "github.com/anyproto/any-sync/util/cidutil"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand"
"strconv"
"testing"
"time"
) )
func TestSuccessHeaderPayloadForSpaceCreate(t *testing.T) { func TestSuccessHeaderPayloadForSpaceCreate(t *testing.T) {
@ -188,14 +190,14 @@ func TestFailAclPayloadSpace_IncorrectCid(t *testing.T) {
marshalled, err := aclRoot.Marshal() marshalled, err := aclRoot.Marshal()
require.NoError(t, err) require.NoError(t, err)
signature, err := accountKeys.SignKey.Sign(marshalled) signature, err := accountKeys.SignKey.Sign(marshalled)
rawAclRecord := &aclrecordproto.RawAclRecord{ rawAclRecord := &consensusproto.RawRecord{
Payload: marshalled, Payload: marshalled,
Signature: signature, Signature: signature,
} }
marshalledRaw, err := rawAclRecord.Marshal() marshalledRaw, err := rawAclRecord.Marshal()
require.NoError(t, err) require.NoError(t, err)
aclHeadId := "rand" aclHeadId := "rand"
rawWithId := &aclrecordproto.RawAclRecordWithId{ rawWithId := &consensusproto.RawRecordWithId{
Payload: marshalledRaw, Payload: marshalledRaw,
Id: aclHeadId, Id: aclHeadId,
} }
@ -230,7 +232,7 @@ func TestFailedAclPayloadSpace_IncorrectSignature(t *testing.T) {
} }
marshalled, err := aclRoot.Marshal() marshalled, err := aclRoot.Marshal()
require.NoError(t, err) require.NoError(t, err)
rawAclRecord := &aclrecordproto.RawAclRecord{ rawAclRecord := &consensusproto.RawRecord{
Payload: marshalled, Payload: marshalled,
Signature: marshalled, Signature: marshalled,
} }
@ -238,7 +240,7 @@ func TestFailedAclPayloadSpace_IncorrectSignature(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
aclHeadId, err := cidutil.NewCidFromBytes(marshalledRaw) aclHeadId, err := cidutil.NewCidFromBytes(marshalledRaw)
require.NoError(t, err) require.NoError(t, err)
rawWithId := &aclrecordproto.RawAclRecordWithId{ rawWithId := &consensusproto.RawRecordWithId{
Payload: marshalledRaw, Payload: marshalledRaw,
Id: aclHeadId, Id: aclHeadId,
} }
@ -286,7 +288,7 @@ func TestFailedAclPayloadSpace_IncorrectIdentitySignature(t *testing.T) {
return return
} }
signature, err := accountKeys.SignKey.Sign(marshalled) signature, err := accountKeys.SignKey.Sign(marshalled)
rawAclRecord := &aclrecordproto.RawAclRecord{ rawAclRecord := &consensusproto.RawRecord{
Payload: marshalled, Payload: marshalled,
Signature: signature, Signature: signature,
} }
@ -298,7 +300,7 @@ func TestFailedAclPayloadSpace_IncorrectIdentitySignature(t *testing.T) {
if err != nil { if err != nil {
return return
} }
rawWithId := &aclrecordproto.RawAclRecordWithId{ rawWithId := &consensusproto.RawRecordWithId{
Payload: marshalledRaw, Payload: marshalledRaw,
Id: aclHeadId, Id: aclHeadId,
} }
@ -540,7 +542,7 @@ func rawSettingsPayload(accountKeys *accountdata.AccountKeys, spaceId, aclHeadId
return return
} }
func rawAclWithId(accountKeys *accountdata.AccountKeys, spaceId string) (aclHeadId string, rawWithId *aclrecordproto.RawAclRecordWithId, err error) { func rawAclWithId(accountKeys *accountdata.AccountKeys, spaceId string) (aclHeadId string, rawWithId *consensusproto.RawRecordWithId, err error) {
// TODO: use same storage creation methods as we use in spaces // TODO: use same storage creation methods as we use in spaces
readKeyBytes := make([]byte, 32) readKeyBytes := make([]byte, 32)
_, err = rand.Read(readKeyBytes) _, err = rand.Read(readKeyBytes)
@ -582,7 +584,7 @@ func rawAclWithId(accountKeys *accountdata.AccountKeys, spaceId string) (aclHead
return return
} }
signature, err := accountKeys.SignKey.Sign(marshalled) signature, err := accountKeys.SignKey.Sign(marshalled)
rawAclRecord := &aclrecordproto.RawAclRecord{ rawAclRecord := &consensusproto.RawRecord{
Payload: marshalled, Payload: marshalled,
Signature: signature, Signature: signature,
} }
@ -594,7 +596,7 @@ func rawAclWithId(accountKeys *accountdata.AccountKeys, spaceId string) (aclHead
if err != nil { if err != nil {
return return
} }
rawWithId = &aclrecordproto.RawAclRecordWithId{ rawWithId = &consensusproto.RawRecordWithId{
Payload: marshalledRaw, Payload: marshalledRaw,
Id: aclHeadId, Id: aclHeadId,
} }

View File

@ -8,9 +8,10 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
app "github.com/anyproto/any-sync/app"
spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto" spacesyncproto "github.com/anyproto/any-sync/commonspace/spacesyncproto"
peer "github.com/anyproto/any-sync/net/peer" peer "github.com/anyproto/any-sync/net/peer"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockPeerManager is a mock of PeerManager interface. // MockPeerManager is a mock of PeerManager interface.
@ -65,6 +66,34 @@ func (mr *MockPeerManagerMockRecorder) GetResponsiblePeers(arg0 interface{}) *go
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockPeerManager)(nil).GetResponsiblePeers), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponsiblePeers", reflect.TypeOf((*MockPeerManager)(nil).GetResponsiblePeers), arg0)
} }
// Init mocks base method.
func (m *MockPeerManager) Init(arg0 *app.App) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Init", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Init indicates an expected call of Init.
func (mr *MockPeerManagerMockRecorder) Init(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Init", reflect.TypeOf((*MockPeerManager)(nil).Init), arg0)
}
// Name mocks base method.
func (m *MockPeerManager) Name() string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Name")
ret0, _ := ret[0].(string)
return ret0
}
// Name indicates an expected call of Name.
func (mr *MockPeerManagerMockRecorder) Name() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Name", reflect.TypeOf((*MockPeerManager)(nil).Name))
}
// SendPeer mocks base method. // SendPeer mocks base method.
func (m *MockPeerManager) SendPeer(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error { func (m *MockPeerManager) SendPeer(arg0 context.Context, arg1 string, arg2 *spacesyncproto.ObjectSyncMessage) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -8,9 +8,12 @@ import (
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
) )
const CName = "common.commonspace.peermanager" const (
CName = "common.commonspace.peermanager"
)
type PeerManager interface { type PeerManager interface {
app.Component
// SendPeer sends a message to a stream by peerId // SendPeer sends a message to a stream by peerId
SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error) SendPeer(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error)
// Broadcast sends a message to all subscribed peers // Broadcast sends a message to all subscribed peers

View File

@ -0,0 +1,129 @@
package requestmanager
import (
"context"
"sync"
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/net/streampool"
"go.uber.org/zap"
"storj.io/drpc"
)
const CName = "common.commonspace.requestmanager"
var log = logger.NewNamed(CName)
type RequestManager interface {
app.ComponentRunnable
SendRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error)
QueueRequest(peerId string, msg *spacesyncproto.ObjectSyncMessage) (err error)
}
func New() RequestManager {
return &requestManager{
workers: 10,
queueSize: 300,
pools: map[string]*streampool.ExecPool{},
}
}
type MessageHandler interface {
HandleMessage(ctx context.Context, hm objectsync.HandleMessage) (err error)
}
type requestManager struct {
sync.Mutex
pools map[string]*streampool.ExecPool
peerPool pool.Pool
workers int
queueSize int
handler MessageHandler
ctx context.Context
cancel context.CancelFunc
clientFactory spacesyncproto.ClientFactory
}
func (r *requestManager) Init(a *app.App) (err error) {
r.ctx, r.cancel = context.WithCancel(context.Background())
r.handler = a.MustComponent(objectsync.CName).(MessageHandler)
r.peerPool = a.MustComponent(pool.CName).(pool.Pool)
r.clientFactory = spacesyncproto.ClientFactoryFunc(spacesyncproto.NewDRPCSpaceSyncClient)
return
}
func (r *requestManager) Name() (name string) {
return CName
}
func (r *requestManager) Run(ctx context.Context) (err error) {
return nil
}
func (r *requestManager) Close(ctx context.Context) (err error) {
r.Lock()
defer r.Unlock()
r.cancel()
for _, p := range r.pools {
_ = p.Close()
}
return nil
}
func (r *requestManager) SendRequest(ctx context.Context, peerId string, req *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) {
// TODO: limit concurrent sends?
return r.doRequest(ctx, peerId, req)
}
func (r *requestManager) QueueRequest(peerId string, req *spacesyncproto.ObjectSyncMessage) (err error) {
r.Lock()
defer r.Unlock()
pl, exists := r.pools[peerId]
if !exists {
pl = streampool.NewExecPool(r.workers, r.queueSize)
r.pools[peerId] = pl
pl.Run()
}
// TODO: for later think when many clients are there,
// we need to close pools for inactive clients
return pl.TryAdd(func() {
doRequestAndHandle(r, peerId, req)
})
}
var doRequestAndHandle = (*requestManager).requestAndHandle
func (r *requestManager) requestAndHandle(peerId string, req *spacesyncproto.ObjectSyncMessage) {
ctx := r.ctx
resp, err := r.doRequest(ctx, peerId, req)
if err != nil {
log.Warn("failed to send request", zap.Error(err))
return
}
ctx = peer.CtxWithPeerId(ctx, peerId)
_ = r.handler.HandleMessage(ctx, objectsync.HandleMessage{
SenderId: peerId,
Message: resp,
PeerCtx: ctx,
})
}
func (r *requestManager) doRequest(ctx context.Context, peerId string, msg *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
pr, err := r.peerPool.Get(ctx, peerId)
if err != nil {
return
}
err = pr.DoDrpc(ctx, func(conn drpc.Conn) error {
cl := r.clientFactory.Client(conn)
resp, err = cl.ObjectSync(ctx, msg)
return err
})
err = rpcerr.Unwrap(err)
return
}

View File

@ -0,0 +1,186 @@
package requestmanager
import (
"context"
"sync"
"testing"
"github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objectsync/mock_objectsync"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/spacesyncproto/mock_spacesyncproto"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/peer/mock_peer"
"github.com/anyproto/any-sync/net/pool/mock_pool"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"storj.io/drpc"
"storj.io/drpc/drpcconn"
)
type fixture struct {
requestManager *requestManager
messageHandlerMock *mock_objectsync.MockObjectSync
peerPoolMock *mock_pool.MockPool
clientMock *mock_spacesyncproto.MockDRPCSpaceSyncClient
ctrl *gomock.Controller
}
func newFixture(t *testing.T) *fixture {
ctrl := gomock.NewController(t)
manager := New().(*requestManager)
peerPoolMock := mock_pool.NewMockPool(ctrl)
messageHandlerMock := mock_objectsync.NewMockObjectSync(ctrl)
clientMock := mock_spacesyncproto.NewMockDRPCSpaceSyncClient(ctrl)
manager.peerPool = peerPoolMock
manager.handler = messageHandlerMock
manager.clientFactory = spacesyncproto.ClientFactoryFunc(func(cc drpc.Conn) spacesyncproto.DRPCSpaceSyncClient {
return clientMock
})
manager.ctx, manager.cancel = context.WithCancel(context.Background())
return &fixture{
requestManager: manager,
messageHandlerMock: messageHandlerMock,
peerPoolMock: peerPoolMock,
clientMock: clientMock,
ctrl: ctrl,
}
}
func (fx *fixture) stop() {
fx.ctrl.Finish()
}
func TestRequestManager_SyncRequest(t *testing.T) {
ctx := context.Background()
t.Run("send request", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
peerId := "peerId"
peerMock := mock_peer.NewMockPeer(fx.ctrl)
conn := &drpcconn.Conn{}
msg := &spacesyncproto.ObjectSyncMessage{}
resp := &spacesyncproto.ObjectSyncMessage{}
fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil)
fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil)
peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) {
drpcHandler(conn)
}).Return(nil)
res, err := fx.requestManager.SendRequest(ctx, peerId, msg)
require.NoError(t, err)
require.Equal(t, resp, res)
})
t.Run("request and handle", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
ctx = fx.requestManager.ctx
peerId := "peerId"
peerMock := mock_peer.NewMockPeer(fx.ctrl)
conn := &drpcconn.Conn{}
msg := &spacesyncproto.ObjectSyncMessage{}
resp := &spacesyncproto.ObjectSyncMessage{}
fx.peerPoolMock.EXPECT().Get(ctx, peerId).Return(peerMock, nil)
fx.clientMock.EXPECT().ObjectSync(ctx, msg).Return(resp, nil)
peerMock.EXPECT().DoDrpc(ctx, gomock.Any()).DoAndReturn(func(ctx context.Context, drpcHandler func(conn drpc.Conn) error) {
drpcHandler(conn)
}).Return(nil)
fx.messageHandlerMock.EXPECT().HandleMessage(gomock.Any(), gomock.Any()).DoAndReturn(func(ctx context.Context, msg objectsync.HandleMessage) {
require.Equal(t, peerId, msg.SenderId)
require.Equal(t, resp, msg.Message)
pId, _ := peer.CtxPeerId(msg.PeerCtx)
require.Equal(t, peerId, pId)
}).Return(nil)
fx.requestManager.requestAndHandle(peerId, msg)
})
}
func TestRequestManager_QueueRequest(t *testing.T) {
t.Run("max concurrent reqs for peer, independent reqs for other peer", func(t *testing.T) {
// testing 2 concurrent requests to one peer and simultaneous to another peer
fx := newFixture(t)
defer fx.stop()
fx.requestManager.workers = 2
msgRelease := make(chan struct{})
msgWait := make(chan struct{})
msgs := sync.Map{}
doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) {
msgs.Store(req.ObjectId, struct{}{})
<-msgWait
<-msgRelease
}
otherPeer := "otherPeer"
msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"}
msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"}
msg3 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id3"}
otherMsg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "otherId1"}
// sending requests to first peer
peerId := "peerId"
err := fx.requestManager.QueueRequest(peerId, msg1)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg2)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg3)
require.NoError(t, err)
// waiting until all the messages are loaded
msgWait <- struct{}{}
msgWait <- struct{}{}
_, ok := msgs.Load("id1")
require.True(t, ok)
_, ok = msgs.Load("id2")
require.True(t, ok)
// third message should not be read
_, ok = msgs.Load("id3")
require.False(t, ok)
// request for other peer should pass
err = fx.requestManager.QueueRequest(otherPeer, otherMsg1)
require.NoError(t, err)
msgWait <- struct{}{}
_, ok = msgs.Load("otherId1")
require.True(t, ok)
close(msgRelease)
fx.requestManager.Close(context.Background())
})
t.Run("no requests after close", func(t *testing.T) {
fx := newFixture(t)
defer fx.stop()
fx.requestManager.workers = 1
msgRelease := make(chan struct{})
msgWait := make(chan struct{})
msgs := sync.Map{}
doRequestAndHandle = func(manager *requestManager, peerId string, req *spacesyncproto.ObjectSyncMessage) {
msgs.Store(req.ObjectId, struct{}{})
<-msgWait
<-msgRelease
}
msg1 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id1"}
msg2 := &spacesyncproto.ObjectSyncMessage{ObjectId: "id2"}
// sending requests to first peer
peerId := "peerId"
err := fx.requestManager.QueueRequest(peerId, msg1)
require.NoError(t, err)
err = fx.requestManager.QueueRequest(peerId, msg2)
require.NoError(t, err)
// waiting until all the message is loaded
msgWait <- struct{}{}
_, ok := msgs.Load("id1")
require.True(t, ok)
_, ok = msgs.Load("id2")
require.False(t, ok)
fx.requestManager.Close(context.Background())
close(msgRelease)
_, ok = msgs.Load("id2")
require.False(t, ok)
})
}

View File

@ -2,9 +2,9 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"go.uber.org/zap" "go.uber.org/zap"
) )
@ -15,11 +15,11 @@ type Deleter interface {
type deleter struct { type deleter struct {
st spacestorage.SpaceStorage st spacestorage.SpaceStorage
state settingsstate.ObjectDeletionState state deletionstate.ObjectDeletionState
getter treemanager.TreeManager getter treemanager.TreeManager
} }
func newDeleter(st spacestorage.SpaceStorage, state settingsstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter { func newDeleter(st spacestorage.SpaceStorage, state deletionstate.ObjectDeletionState, getter treemanager.TreeManager) Deleter {
return &deleter{st, state, getter} return &deleter{st, state, getter}
} }

View File

@ -2,11 +2,11 @@ package settings
import ( import (
"fmt" "fmt"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage" "github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock" "go.uber.org/mock/gomock"
"testing" "testing"
) )
@ -14,7 +14,7 @@ func TestDeleter_Delete(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
st := mock_spacestorage.NewMockSpaceStorage(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
deleter := newDeleter(st, delState, treeManager) deleter := newDeleter(st, delState, treeManager)

View File

@ -2,6 +2,7 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"go.uber.org/zap" "go.uber.org/zap"
@ -20,7 +21,7 @@ func newDeletionManager(
settingsId string, settingsId string,
isResponsible bool, isResponsible bool,
treeManager treemanager.TreeManager, treeManager treemanager.TreeManager,
deletionState settingsstate.ObjectDeletionState, deletionState deletionstate.ObjectDeletionState,
provider SpaceIdsProvider, provider SpaceIdsProvider,
onSpaceDelete func()) DeletionManager { onSpaceDelete func()) DeletionManager {
return &deletionManager{ return &deletionManager{
@ -35,7 +36,7 @@ func newDeletionManager(
} }
type deletionManager struct { type deletionManager struct {
deletionState settingsstate.ObjectDeletionState deletionState deletionstate.ObjectDeletionState
provider SpaceIdsProvider provider SpaceIdsProvider
treeManager treemanager.TreeManager treeManager treemanager.TreeManager
spaceId string spaceId string

View File

@ -2,12 +2,12 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager/mock_treemanager"
"github.com/anyproto/any-sync/commonspace/settings/mock_settings" "github.com/anyproto/any-sync/commonspace/settings/mock_settings"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing" "testing"
) )
@ -26,7 +26,7 @@ func TestDeletionManager_UpdateState_NotResponsible(t *testing.T) {
onDeleted := func() { onDeleted := func() {
deleted = true deleted = true
} }
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
delState.EXPECT().Add(state.DeletedIds) delState.EXPECT().Add(state.DeletedIds)
@ -58,7 +58,7 @@ func TestDeletionManager_UpdateState_Responsible(t *testing.T) {
onDeleted := func() { onDeleted := func() {
deleted = true deleted = true
} }
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
provider := mock_settings.NewMockSpaceIdsProvider(ctrl) provider := mock_settings.NewMockSpaceIdsProvider(ctrl)

View File

@ -9,7 +9,7 @@ import (
reflect "reflect" reflect "reflect"
settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate" settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockDeletionManager is a mock of DeletionManager interface. // MockDeletionManager is a mock of DeletionManager interface.

View File

@ -1,328 +1,122 @@
//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider
package settings package settings
import ( import (
"context" "context"
"errors" "sync/atomic"
"fmt"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree" "github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener" "github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"github.com/gogo/protobuf/proto"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/exp/slices"
) )
var log = logger.NewNamed("common.commonspace.settings") const CName = "common.commonspace.settings"
type SettingsObject interface { type Settings interface {
synctree.SyncTree DeleteTree(ctx context.Context, id string) (err error)
Init(ctx context.Context) (err error) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteObject(id string) (err error) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) SettingsObject() SettingsObject
SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) app.ComponentRunnable
} }
var ( func New() Settings {
ErrDeleteSelf = errors.New("cannot delete self") return &settings{}
ErrAlreadyDeleted = errors.New("the object is already deleted")
ErrObjDoesNotExist = errors.New("the object does not exist")
ErrCantDeleteSpace = errors.New("not able to delete space")
)
var (
DoSnapshot = objecttree.DoSnapshot
buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) {
return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{
TreeStorage: objTree.Storage(),
AclList: objTree.AclList(),
BuildFullTree: true,
})
}
)
type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error)
type Deps struct {
BuildFunc BuildTreeFunc
Account accountservice.Service
TreeManager treemanager.TreeManager
Store spacestorage.SpaceStorage
Configuration nodeconf.NodeConf
DeletionState settingsstate.ObjectDeletionState
Provider SpaceIdsProvider
OnSpaceDelete func()
// testing dependencies
builder settingsstate.StateBuilder
del Deleter
delManager DeletionManager
changeFactory settingsstate.ChangeFactory
} }
type settingsObject struct { type settings struct {
synctree.SyncTree account accountservice.Service
account accountservice.Service treeManager treemanager.TreeManager
spaceId string storage spacestorage.SpaceStorage
treeManager treemanager.TreeManager configuration nodeconf.NodeConf
store spacestorage.SpaceStorage deletionState deletionstate.ObjectDeletionState
builder settingsstate.StateBuilder headsync headsync.HeadSync
buildFunc BuildTreeFunc treeBuilder objecttreebuilder.TreeBuilderComponent
loop *deleteLoop spaceIsDeleted *atomic.Bool
state *settingsstate.State settingsObject SettingsObject
deletionState settingsstate.ObjectDeletionState
deletionManager DeletionManager
changeFactory settingsstate.ChangeFactory
} }
func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) { func (s *settings) Init(a *app.App) (err error) {
var ( s.account = a.MustComponent(accountservice.CName).(accountservice.Service)
deleter Deleter s.treeManager = app.MustComponent[treemanager.TreeManager](a)
deletionManager DeletionManager s.headsync = a.MustComponent(headsync.CName).(headsync.HeadSync)
builder settingsstate.StateBuilder s.configuration = a.MustComponent(nodeconf.CName).(nodeconf.NodeConf)
changeFactory settingsstate.ChangeFactory s.deletionState = a.MustComponent(deletionstate.CName).(deletionstate.ObjectDeletionState)
) s.treeBuilder = a.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent)
if deps.del == nil {
deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager)
} else {
deleter = deps.del
}
if deps.delManager == nil {
deletionManager = newDeletionManager(
spaceId,
deps.Store.SpaceSettingsId(),
deps.Configuration.IsResponsible(spaceId),
deps.TreeManager,
deps.DeletionState,
deps.Provider,
deps.OnSpaceDelete)
} else {
deletionManager = deps.delManager
}
if deps.builder == nil {
builder = settingsstate.NewStateBuilder()
} else {
builder = deps.builder
}
if deps.changeFactory == nil {
changeFactory = settingsstate.NewChangeFactory()
} else {
changeFactory = deps.changeFactory
}
loop := newDeleteLoop(func() { sharedState := a.MustComponent(spacestate.CName).(*spacestate.SpaceState)
deleter.Delete() s.storage = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
}) s.spaceIsDeleted = sharedState.SpaceIsDeleted
deps.DeletionState.AddObserver(func(ids []string) {
loop.notify()
})
s := &settingsObject{ deps := Deps{
loop: loop, BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) {
spaceId: spaceId, res, err := s.treeBuilder.BuildTree(ctx, id, objecttreebuilder.BuildTreeOpts{
account: deps.Account, Listener: listener,
deletionState: deps.DeletionState, // space settings document should not have empty data
treeManager: deps.TreeManager, TreeBuilder: objecttree.BuildObjectTree,
store: deps.Store, })
buildFunc: deps.BuildFunc, log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", sharedState.SpaceId))
builder: builder, if err != nil {
deletionManager: deletionManager, return
changeFactory: changeFactory, }
} t = res.(synctree.SyncTree)
obj = s
return
}
func (s *settingsObject) updateIds(tr objecttree.ObjectTree) {
var err error
s.state, err = s.builder.Build(tr, s.state)
if err != nil {
log.Error("failed to build state", zap.Error(err))
return
}
log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId))
if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil {
log.Error("failed to update state", zap.Error(err))
}
}
// Update is called as part of UpdateListener interface
func (s *settingsObject) Update(tr objecttree.ObjectTree) {
s.updateIds(tr)
}
// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call)
func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) {
// at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter
s.state = nil
s.updateIds(tr)
}
func (s *settingsObject) Init(ctx context.Context) (err error) {
settingsId := s.store.SpaceSettingsId()
log.Debug("space settings id", zap.String("id", settingsId))
s.SyncTree, err = s.buildFunc(ctx, settingsId, s)
if err != nil {
return
}
// TODO: remove this check when everybody updates
if err = s.checkHistoryState(ctx); err != nil {
return
}
s.loop.Run()
return
}
func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) {
historyTree, err := buildHistoryTree(s.SyncTree)
if err != nil {
return
}
fullState, err := s.builder.Build(historyTree, nil)
if err != nil {
return
}
if len(fullState.DeletedIds) != len(s.state.DeletedIds) {
log.WarnCtx(ctx, "state does not have all deleted ids",
zap.Int("fullstate ids", len(fullState.DeletedIds)),
zap.Int("state ids", len(fullState.DeletedIds)))
s.state = fullState
err = s.deletionManager.UpdateState(context.Background(), s.state)
if err != nil {
return return
} },
Account: s.account,
TreeManager: s.treeManager,
Store: s.storage,
Configuration: s.configuration,
DeletionState: s.deletionState,
Provider: s.headsync,
OnSpaceDelete: s.onSpaceDelete,
} }
return s.settingsObject = NewSettingsObject(deps, sharedState.SpaceId)
return nil
} }
func (s *settingsObject) Close() error { func (s *settings) Name() (name string) {
s.loop.Close() return CName
return s.SyncTree.Close()
} }
func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) { func (s *settings) Run(ctx context.Context) (err error) {
s.Lock() return s.settingsObject.Init(ctx)
defer s.Unlock()
defer func() {
log.Debug("finished adding delete change", zap.Error(err))
}()
err = s.verifyDeleteSpace(raw)
if err != nil {
return
}
res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: []string{raw.Id},
RawChanges: []*treechangeproto.RawTreeChangeWithId{raw},
})
if err != nil {
return
}
if !slices.Contains(res.Heads, raw.Id) {
err = ErrCantDeleteSpace
return
}
return
} }
func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) { func (s *settings) Close(ctx context.Context) (err error) {
accountData := s.account.Account() return s.settingsObject.Close()
data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false)
if err != nil {
return
}
return s.PrepareChange(objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: false,
IsEncrypted: false,
})
} }
func (s *settingsObject) DeleteObject(id string) (err error) { func (s *settings) DeleteTree(ctx context.Context, id string) (err error) {
s.Lock() return s.settingsObject.DeleteObject(id)
defer s.Unlock()
if s.Id() == id {
err = ErrDeleteSelf
return
}
if s.state.Exists(id) {
err = ErrAlreadyDeleted
return nil
}
_, err = s.store.TreeStorage(id)
if err != nil {
err = ErrObjDoesNotExist
return
}
isSnapshot := DoSnapshot(s.Len())
res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot)
if err != nil {
return
}
return s.addContent(res, isSnapshot)
} }
func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) { func (s *settings) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
data, err := s.UnpackChange(raw) return s.settingsObject.SpaceDeleteRawChange()
if err != nil {
return
}
return verifyDeleteContent(data, "")
} }
func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) { func (s *settings) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
accountData := s.account.Account() return s.settingsObject.DeleteSpace(ctx, deleteChange)
res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: isSnapshot,
IsEncrypted: false,
})
if err != nil {
return
}
if res.Mode == objecttree.Rebuild {
s.Rebuild(s)
} else {
s.Update(s)
}
return
} }
func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) { func (s *settings) onSpaceDelete() {
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil) err := s.storage.SetSpaceDeleted()
res, err := changeBuilder.Unmarshall(raw, true)
if err != nil { if err != nil {
return log.Warn("failed to set space deleted")
} }
if !res.Identity.Equals(identity) { s.spaceIsDeleted.Swap(true)
return fmt.Errorf("incorrect identity")
}
return verifyDeleteContent(res.Data, peerId)
} }
func verifyDeleteContent(data []byte, peerId string) (err error) { func (s *settings) SettingsObject() SettingsObject {
content := &spacesyncproto.SettingsData{} return s.settingsObject
err = proto.Unmarshal(data, content)
if err != nil {
return
}
if len(content.GetContent()) != 1 ||
content.GetContent()[0].GetSpaceDelete() == nil ||
(peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") ||
(peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) {
return fmt.Errorf("incorrect delete change payload")
}
return
} }

View File

@ -0,0 +1,329 @@
//go:generate mockgen -destination mock_settings/mock_settings.go github.com/anyproto/any-sync/commonspace/settings DeletionManager,Deleter,SpaceIdsProvider
package settings
import (
"context"
"errors"
"fmt"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/nodeconf"
"github.com/gogo/protobuf/proto"
"go.uber.org/zap"
"golang.org/x/exp/slices"
)
var log = logger.NewNamed("common.commonspace.settings")
type SettingsObject interface {
synctree.SyncTree
Init(ctx context.Context) (err error)
DeleteObject(id string) (err error)
DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error)
SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error)
}
var (
ErrDeleteSelf = errors.New("cannot delete self")
ErrAlreadyDeleted = errors.New("the object is already deleted")
ErrObjDoesNotExist = errors.New("the object does not exist")
ErrCantDeleteSpace = errors.New("not able to delete space")
)
var (
DoSnapshot = objecttree.DoSnapshot
buildHistoryTree = func(objTree objecttree.ObjectTree) (objecttree.ReadableObjectTree, error) {
return objecttree.BuildHistoryTree(objecttree.HistoryTreeParams{
TreeStorage: objTree.Storage(),
AclList: objTree.AclList(),
BuildFullTree: true,
})
}
)
type BuildTreeFunc func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error)
type Deps struct {
BuildFunc BuildTreeFunc
Account accountservice.Service
TreeManager treemanager.TreeManager
Store spacestorage.SpaceStorage
Configuration nodeconf.NodeConf
DeletionState deletionstate.ObjectDeletionState
Provider SpaceIdsProvider
OnSpaceDelete func()
// testing dependencies
builder settingsstate.StateBuilder
del Deleter
delManager DeletionManager
changeFactory settingsstate.ChangeFactory
}
type settingsObject struct {
synctree.SyncTree
account accountservice.Service
spaceId string
treeManager treemanager.TreeManager
store spacestorage.SpaceStorage
builder settingsstate.StateBuilder
buildFunc BuildTreeFunc
loop *deleteLoop
state *settingsstate.State
deletionState deletionstate.ObjectDeletionState
deletionManager DeletionManager
changeFactory settingsstate.ChangeFactory
}
func NewSettingsObject(deps Deps, spaceId string) (obj SettingsObject) {
var (
deleter Deleter
deletionManager DeletionManager
builder settingsstate.StateBuilder
changeFactory settingsstate.ChangeFactory
)
if deps.del == nil {
deleter = newDeleter(deps.Store, deps.DeletionState, deps.TreeManager)
} else {
deleter = deps.del
}
if deps.delManager == nil {
deletionManager = newDeletionManager(
spaceId,
deps.Store.SpaceSettingsId(),
deps.Configuration.IsResponsible(spaceId),
deps.TreeManager,
deps.DeletionState,
deps.Provider,
deps.OnSpaceDelete)
} else {
deletionManager = deps.delManager
}
if deps.builder == nil {
builder = settingsstate.NewStateBuilder()
} else {
builder = deps.builder
}
if deps.changeFactory == nil {
changeFactory = settingsstate.NewChangeFactory()
} else {
changeFactory = deps.changeFactory
}
loop := newDeleteLoop(func() {
deleter.Delete()
})
deps.DeletionState.AddObserver(func(ids []string) {
loop.notify()
})
s := &settingsObject{
loop: loop,
spaceId: spaceId,
account: deps.Account,
deletionState: deps.DeletionState,
treeManager: deps.TreeManager,
store: deps.Store,
buildFunc: deps.BuildFunc,
builder: builder,
deletionManager: deletionManager,
changeFactory: changeFactory,
}
obj = s
return
}
func (s *settingsObject) updateIds(tr objecttree.ObjectTree) {
var err error
s.state, err = s.builder.Build(tr, s.state)
if err != nil {
log.Error("failed to build state", zap.Error(err))
return
}
log.Debug("updating object state", zap.String("deleted by", s.state.DeleterId))
if err = s.deletionManager.UpdateState(context.Background(), s.state); err != nil {
log.Error("failed to update state", zap.Error(err))
}
}
// Update is called as part of UpdateListener interface
func (s *settingsObject) Update(tr objecttree.ObjectTree) {
s.updateIds(tr)
}
// Rebuild is called as part of UpdateListener interface (including when the object is built for the first time, e.g. on Init call)
func (s *settingsObject) Rebuild(tr objecttree.ObjectTree) {
// at initial build "s" may not contain the object tree, so it is safer to provide it from the function parameter
s.state = nil
s.updateIds(tr)
}
func (s *settingsObject) Init(ctx context.Context) (err error) {
settingsId := s.store.SpaceSettingsId()
log.Debug("space settings id", zap.String("id", settingsId))
s.SyncTree, err = s.buildFunc(ctx, settingsId, s)
if err != nil {
return
}
// TODO: remove this check when everybody updates
if err = s.checkHistoryState(ctx); err != nil {
return
}
s.loop.Run()
return
}
func (s *settingsObject) checkHistoryState(ctx context.Context) (err error) {
historyTree, err := buildHistoryTree(s.SyncTree)
if err != nil {
return
}
fullState, err := s.builder.Build(historyTree, nil)
if err != nil {
return
}
if len(fullState.DeletedIds) != len(s.state.DeletedIds) {
log.WarnCtx(ctx, "state does not have all deleted ids",
zap.Int("fullstate ids", len(fullState.DeletedIds)),
zap.Int("state ids", len(fullState.DeletedIds)))
s.state = fullState
err = s.deletionManager.UpdateState(context.Background(), s.state)
if err != nil {
return
}
}
return
}
func (s *settingsObject) Close() error {
s.loop.Close()
return s.SyncTree.Close()
}
func (s *settingsObject) DeleteSpace(ctx context.Context, raw *treechangeproto.RawTreeChangeWithId) (err error) {
s.Lock()
defer s.Unlock()
defer func() {
log.Debug("finished adding delete change", zap.Error(err))
}()
err = s.verifyDeleteSpace(raw)
if err != nil {
return
}
res, err := s.AddRawChanges(ctx, objecttree.RawChangesPayload{
NewHeads: []string{raw.Id},
RawChanges: []*treechangeproto.RawTreeChangeWithId{raw},
})
if err != nil {
return
}
if !slices.Contains(res.Heads, raw.Id) {
err = ErrCantDeleteSpace
return
}
return
}
func (s *settingsObject) SpaceDeleteRawChange() (raw *treechangeproto.RawTreeChangeWithId, err error) {
accountData := s.account.Account()
data, err := s.changeFactory.CreateSpaceDeleteChange(accountData.PeerId, s.state, false)
if err != nil {
return
}
return s.PrepareChange(objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: false,
IsEncrypted: false,
})
}
func (s *settingsObject) DeleteObject(id string) (err error) {
s.Lock()
defer s.Unlock()
if s.Id() == id {
err = ErrDeleteSelf
return
}
if s.state.Exists(id) {
err = ErrAlreadyDeleted
return nil
}
_, err = s.store.TreeStorage(id)
if err != nil {
err = ErrObjDoesNotExist
return
}
isSnapshot := DoSnapshot(s.Len())
res, err := s.changeFactory.CreateObjectDeleteChange(id, s.state, isSnapshot)
if err != nil {
return
}
return s.addContent(res, isSnapshot)
}
func (s *settingsObject) verifyDeleteSpace(raw *treechangeproto.RawTreeChangeWithId) (err error) {
data, err := s.UnpackChange(raw)
if err != nil {
return
}
return verifyDeleteContent(data, "")
}
func (s *settingsObject) addContent(data []byte, isSnapshot bool) (err error) {
accountData := s.account.Account()
res, err := s.AddContent(context.Background(), objecttree.SignableChangeContent{
Data: data,
Key: accountData.SignKey,
IsSnapshot: isSnapshot,
IsEncrypted: false,
})
if err != nil {
return
}
if res.Mode == objecttree.Rebuild {
s.Rebuild(s)
} else {
s.Update(s)
}
return
}
func VerifyDeleteChange(raw *treechangeproto.RawTreeChangeWithId, identity crypto.PubKey, peerId string) (err error) {
changeBuilder := objecttree.NewChangeBuilder(crypto.NewKeyStorage(), nil)
res, err := changeBuilder.Unmarshall(raw, true)
if err != nil {
return
}
if !res.Identity.Equals(identity) {
return fmt.Errorf("incorrect identity")
}
return verifyDeleteContent(res.Data, peerId)
}
func verifyDeleteContent(data []byte, peerId string) (err error) {
content := &spacesyncproto.SettingsData{}
err = proto.Unmarshal(data, content)
if err != nil {
return
}
if len(content.GetContent()) != 1 ||
content.GetContent()[0].GetSpaceDelete() == nil ||
(peerId == "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() == "") ||
(peerId != "" && content.GetContent()[0].GetSpaceDelete().GetDeleterPeerId() != peerId) {
return fmt.Errorf("incorrect delete change payload")
}
return
}

View File

@ -3,6 +3,7 @@ package settings
import ( import (
"context" "context"
"github.com/anyproto/any-sync/accountservice/mock_accountservice" "github.com/anyproto/any-sync/accountservice/mock_accountservice"
"github.com/anyproto/any-sync/commonspace/deletionstate/mock_deletionstate"
"github.com/anyproto/any-sync/commonspace/object/accountdata" "github.com/anyproto/any-sync/commonspace/object/accountdata"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
@ -15,8 +16,8 @@ import (
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate" "github.com/anyproto/any-sync/commonspace/settings/settingsstate/mock_settingsstate"
"github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage/mock_spacestorage"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -54,7 +55,7 @@ type settingsFixture struct {
deleter *mock_settings.MockDeleter deleter *mock_settings.MockDeleter
syncTree *mock_synctree.MockSyncTree syncTree *mock_synctree.MockSyncTree
historyTree *mock_objecttree.MockObjectTree historyTree *mock_objecttree.MockObjectTree
delState *mock_settingsstate.MockObjectDeletionState delState *mock_deletionstate.MockObjectDeletionState
account *mock_accountservice.MockService account *mock_accountservice.MockService
} }
@ -66,7 +67,7 @@ func newSettingsFixture(t *testing.T) *settingsFixture {
acc := mock_accountservice.NewMockService(ctrl) acc := mock_accountservice.NewMockService(ctrl)
treeManager := mock_treemanager.NewMockTreeManager(ctrl) treeManager := mock_treemanager.NewMockTreeManager(ctrl)
st := mock_spacestorage.NewMockSpaceStorage(ctrl) st := mock_spacestorage.NewMockSpaceStorage(ctrl)
delState := mock_settingsstate.NewMockObjectDeletionState(ctrl) delState := mock_deletionstate.NewMockObjectDeletionState(ctrl)
delManager := mock_settings.NewMockDeletionManager(ctrl) delManager := mock_settings.NewMockDeletionManager(ctrl)
stateBuilder := mock_settingsstate.NewMockStateBuilder(ctrl) stateBuilder := mock_settingsstate.NewMockStateBuilder(ctrl)
changeFactory := mock_settingsstate.NewMockChangeFactory(ctrl) changeFactory := mock_settingsstate.NewMockChangeFactory(ctrl)

View File

@ -1,5 +1,5 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: ObjectDeletionState,StateBuilder,ChangeFactory) // Source: github.com/anyproto/any-sync/commonspace/settings/settingsstate (interfaces: StateBuilder,ChangeFactory)
// Package mock_settingsstate is a generated GoMock package. // Package mock_settingsstate is a generated GoMock package.
package mock_settingsstate package mock_settingsstate
@ -9,112 +9,9 @@ import (
objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree" objecttree "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate" settingsstate "github.com/anyproto/any-sync/commonspace/settings/settingsstate"
gomock "github.com/golang/mock/gomock" gomock "go.uber.org/mock/gomock"
) )
// MockObjectDeletionState is a mock of ObjectDeletionState interface.
type MockObjectDeletionState struct {
ctrl *gomock.Controller
recorder *MockObjectDeletionStateMockRecorder
}
// MockObjectDeletionStateMockRecorder is the mock recorder for MockObjectDeletionState.
type MockObjectDeletionStateMockRecorder struct {
mock *MockObjectDeletionState
}
// NewMockObjectDeletionState creates a new mock instance.
func NewMockObjectDeletionState(ctrl *gomock.Controller) *MockObjectDeletionState {
mock := &MockObjectDeletionState{ctrl: ctrl}
mock.recorder = &MockObjectDeletionStateMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockObjectDeletionState) EXPECT() *MockObjectDeletionStateMockRecorder {
return m.recorder
}
// Add mocks base method.
func (m *MockObjectDeletionState) Add(arg0 map[string]struct{}) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "Add", arg0)
}
// Add indicates an expected call of Add.
func (mr *MockObjectDeletionStateMockRecorder) Add(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Add", reflect.TypeOf((*MockObjectDeletionState)(nil).Add), arg0)
}
// AddObserver mocks base method.
func (m *MockObjectDeletionState) AddObserver(arg0 settingsstate.StateUpdateObserver) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "AddObserver", arg0)
}
// AddObserver indicates an expected call of AddObserver.
func (mr *MockObjectDeletionStateMockRecorder) AddObserver(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddObserver", reflect.TypeOf((*MockObjectDeletionState)(nil).AddObserver), arg0)
}
// Delete mocks base method.
func (m *MockObjectDeletionState) Delete(arg0 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Delete", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Delete indicates an expected call of Delete.
func (mr *MockObjectDeletionStateMockRecorder) Delete(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Delete", reflect.TypeOf((*MockObjectDeletionState)(nil).Delete), arg0)
}
// Exists mocks base method.
func (m *MockObjectDeletionState) Exists(arg0 string) bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Exists", arg0)
ret0, _ := ret[0].(bool)
return ret0
}
// Exists indicates an expected call of Exists.
func (mr *MockObjectDeletionStateMockRecorder) Exists(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exists", reflect.TypeOf((*MockObjectDeletionState)(nil).Exists), arg0)
}
// Filter mocks base method.
func (m *MockObjectDeletionState) Filter(arg0 []string) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Filter", arg0)
ret0, _ := ret[0].([]string)
return ret0
}
// Filter indicates an expected call of Filter.
func (mr *MockObjectDeletionStateMockRecorder) Filter(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockObjectDeletionState)(nil).Filter), arg0)
}
// GetQueued mocks base method.
func (m *MockObjectDeletionState) GetQueued() []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetQueued")
ret0, _ := ret[0].([]string)
return ret0
}
// GetQueued indicates an expected call of GetQueued.
func (mr *MockObjectDeletionStateMockRecorder) GetQueued() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQueued", reflect.TypeOf((*MockObjectDeletionState)(nil).GetQueued))
}
// MockStateBuilder is a mock of StateBuilder interface. // MockStateBuilder is a mock of StateBuilder interface.
type MockStateBuilder struct { type MockStateBuilder struct {
ctrl *gomock.Controller ctrl *gomock.Controller

View File

@ -1,3 +1,4 @@
//go:generate mockgen -destination mock_settingsstate/mock_settingsstate.go github.com/anyproto/any-sync/commonspace/settings/settingsstate StateBuilder,ChangeFactory
package settingsstate package settingsstate
import "github.com/anyproto/any-sync/commonspace/spacesyncproto" import "github.com/anyproto/any-sync/commonspace/spacesyncproto"

View File

@ -4,8 +4,8 @@ import (
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree/mock_objecttree"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"testing" "testing"
) )

View File

@ -2,44 +2,26 @@ package commonspace
import ( import (
"context" "context"
"errors" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/list" "github.com/anyproto/any-sync/commonspace/object/acl/list"
"github.com/anyproto/any-sync/commonspace/object/acl/syncacl" "github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree"
"github.com/anyproto/any-sync/commonspace/object/tree/synctree/updatelistener"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/tree/treestorage"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/settings" "github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/settings/settingsstate" "github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/nodeconf"
"github.com/anyproto/any-sync/util/crypto" "github.com/anyproto/any-sync/util/crypto"
"github.com/anyproto/any-sync/util/multiqueue"
"github.com/anyproto/any-sync/util/slice"
"github.com/cheggaaa/mb/v3"
"github.com/zeebo/errs"
"go.uber.org/zap" "go.uber.org/zap"
"strconv" "strconv"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
var (
ErrSpaceClosed = errors.New("space is closed")
)
type SpaceCreatePayload struct { type SpaceCreatePayload struct {
// SigningKey is the signing key of the owner // SigningKey is the signing key of the owner
SigningKey crypto.PrivKey SigningKey crypto.PrivKey
@ -55,25 +37,6 @@ type SpaceCreatePayload struct {
MasterKey crypto.PrivKey MasterKey crypto.PrivKey
} }
type HandleMessage struct {
Id uint64
ReceiveTime time.Time
StartHandlingTime time.Time
Deadline time.Time
SenderId string
Message *spacesyncproto.ObjectSyncMessage
PeerCtx context.Context
}
func (m HandleMessage) LogFields(fields ...zap.Field) []zap.Field {
return append(fields,
metric.SpaceId(m.Message.SpaceId),
metric.ObjectId(m.Message.ObjectId),
metric.QueueDur(m.StartHandlingTime.Sub(m.ReceiveTime)),
metric.TotalDur(time.Since(m.ReceiveTime)),
)
}
type SpaceDerivePayload struct { type SpaceDerivePayload struct {
SigningKey crypto.PrivKey SigningKey crypto.PrivKey
MasterKey crypto.PrivKey MasterKey crypto.PrivKey
@ -96,58 +59,42 @@ func NewSpaceId(id string, repKey uint64) string {
type Space interface { type Space interface {
Id() string Id() string
Init(ctx context.Context) error Init(ctx context.Context) error
Acl() list.AclList
StoredIds() []string StoredIds() []string
DebugAllHeads() []headsync.TreeHeads DebugAllHeads() []headsync.TreeHeads
Description() (SpaceDescription, error) Description() (desc SpaceDescription, err error)
CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) TreeBuilder() objecttreebuilder.TreeBuilder
PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error)
BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error)
DeleteTree(ctx context.Context, id string) (err error)
BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error)
SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
HeadSync() headsync.HeadSync
ObjectSync() objectsync.ObjectSync
SyncStatus() syncstatus.StatusUpdater SyncStatus() syncstatus.StatusUpdater
Storage() spacestorage.SpaceStorage Storage() spacestorage.SpaceStorage
HandleMessage(ctx context.Context, msg HandleMessage) (err error) DeleteTree(ctx context.Context, id string) (err error)
SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error)
DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error)
HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error)
HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error)
HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error)
TryClose(objectTTL time.Duration) (close bool, err error) TryClose(objectTTL time.Duration) (close bool, err error)
Close() error Close() error
} }
type space struct { type space struct {
id string
mu sync.RWMutex mu sync.RWMutex
header *spacesyncproto.RawSpaceHeaderWithId header *spacesyncproto.RawSpaceHeaderWithId
objectSync objectsync.ObjectSync state *spacestate.SpaceState
headSync headsync.HeadSync app *app.App
syncStatus syncstatus.StatusUpdater
storage spacestorage.SpaceStorage
treeManager *commonGetter
account accountservice.Service
aclList *syncacl.SyncAcl
configuration nodeconf.NodeConf
settingsObject settings.SettingsObject
peerManager peermanager.PeerManager
treeBuilder objecttree.BuildObjectTreeFunc
metric metric.Metric
handleQueue multiqueue.MultiQueue[HandleMessage] treeBuilder objecttreebuilder.TreeBuilderComponent
headSync headsync.HeadSync
isClosed *atomic.Bool objectSync objectsync.ObjectSync
isDeleted *atomic.Bool syncStatus syncstatus.StatusService
treesUsed *atomic.Int32 settings settings.Settings
} storage spacestorage.SpaceStorage
aclList list.AclList
func (s *space) Id() string {
return s.id
} }
func (s *space) Description() (desc SpaceDescription, err error) { func (s *space) Description() (desc SpaceDescription, err error) {
@ -171,72 +118,64 @@ func (s *space) Description() (desc SpaceDescription, err error) {
return return
} }
func (s *space) StoredIds() []string {
return s.headSync.ExternalIds()
}
func (s *space) DebugAllHeads() []headsync.TreeHeads {
return s.headSync.DebugAllHeads()
}
func (s *space) DeleteTree(ctx context.Context, id string) (err error) {
return s.settings.DeleteTree(ctx, id)
}
func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return s.settings.SpaceDeleteRawChange(ctx)
}
func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
return s.settings.DeleteSpace(ctx, deleteChange)
}
func (s *space) HandleMessage(ctx context.Context, msg objectsync.HandleMessage) (err error) {
return s.objectSync.HandleMessage(ctx, msg)
}
func (s *space) HandleSyncRequest(ctx context.Context, req *spacesyncproto.ObjectSyncMessage) (resp *spacesyncproto.ObjectSyncMessage, err error) {
return s.objectSync.HandleRequest(ctx, req)
}
func (s *space) HandleRangeRequest(ctx context.Context, req *spacesyncproto.HeadSyncRequest) (resp *spacesyncproto.HeadSyncResponse, err error) {
return s.headSync.HandleRangeRequest(ctx, req)
}
func (s *space) TreeBuilder() objecttreebuilder.TreeBuilder {
return s.treeBuilder
}
func (s *space) Acl() list.AclList {
return s.aclList
}
func (s *space) Id() string {
return s.state.SpaceId
}
func (s *space) Init(ctx context.Context) (err error) { func (s *space) Init(ctx context.Context) (err error) {
log.With(zap.String("spaceId", s.id)).Debug("initializing space") err = s.app.Start(ctx)
s.storage = newCommonStorage(s.storage)
header, err := s.storage.SpaceHeader()
if err != nil { if err != nil {
return return
} }
s.header = header s.treeBuilder = s.app.MustComponent(objecttreebuilder.CName).(objecttreebuilder.TreeBuilderComponent)
initialIds, err := s.storage.StoredIds() s.headSync = s.app.MustComponent(headsync.CName).(headsync.HeadSync)
if err != nil { s.syncStatus = s.app.MustComponent(syncstatus.CName).(syncstatus.StatusService)
return s.settings = s.app.MustComponent(settings.CName).(settings.Settings)
} s.objectSync = s.app.MustComponent(objectsync.CName).(objectsync.ObjectSync)
aclStorage, err := s.storage.AclStorage() s.storage = s.app.MustComponent(spacestorage.CName).(spacestorage.SpaceStorage)
if err != nil { s.aclList = s.app.MustComponent(syncacl.CName).(list.AclList)
return s.header, err = s.storage.SpaceHeader()
} return
aclList, err := list.BuildAclListWithIdentity(s.account.Account(), aclStorage)
if err != nil {
return
}
s.aclList = syncacl.NewSyncAcl(aclList, s.objectSync.SyncClient().MessagePool())
s.treeManager.AddObject(s.aclList)
deletionState := settingsstate.NewObjectDeletionState(log, s.storage)
deps := settings.Deps{
BuildFunc: func(ctx context.Context, id string, listener updatelistener.UpdateListener) (t synctree.SyncTree, err error) {
res, err := s.BuildTree(ctx, id, BuildTreeOpts{
Listener: listener,
WaitTreeRemoteSync: false,
// space settings document should not have empty data
treeBuilder: objecttree.BuildObjectTree,
})
log.Debug("building settings tree", zap.String("id", id), zap.String("spaceId", s.id))
if err != nil {
return
}
t = res.(synctree.SyncTree)
return
},
Account: s.account,
TreeManager: s.treeManager,
Store: s.storage,
DeletionState: deletionState,
Provider: s.headSync,
Configuration: s.configuration,
OnSpaceDelete: s.onSpaceDelete,
}
s.settingsObject = settings.NewSettingsObject(deps, s.id)
s.headSync.Init(initialIds, deletionState)
err = s.settingsObject.Init(ctx)
if err != nil {
return
}
s.treeManager.AddObject(s.settingsObject)
s.syncStatus.Run()
s.handleQueue = multiqueue.New[HandleMessage](s.handleMessage, 100)
return nil
}
func (s *space) ObjectSync() objectsync.ObjectSync {
return s.objectSync
}
func (s *space) HeadSync() headsync.HeadSync {
return s.headSync
} }
func (s *space) SyncStatus() syncstatus.StatusUpdater { func (s *space) SyncStatus() syncstatus.StatusUpdater {
@ -247,246 +186,25 @@ func (s *space) Storage() spacestorage.SpaceStorage {
return s.storage return s.storage
} }
func (s *space) StoredIds() []string {
return slice.DiscardFromSlice(s.headSync.AllIds(), func(id string) bool {
return id == s.settingsObject.Id()
})
}
func (s *space) DebugAllHeads() []headsync.TreeHeads {
return s.headSync.DebugAllHeads()
}
func (s *space) CreateTree(ctx context.Context, payload objecttree.ObjectTreeCreatePayload) (res treestorage.TreeStorageCreatePayload, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
root, err := objecttree.CreateObjectTreeRoot(payload, s.aclList)
if err != nil {
return
}
res = treestorage.TreeStorageCreatePayload{
RootRawChange: root,
Changes: []*treechangeproto.RawTreeChangeWithId{root},
Heads: []string{root.Id},
}
return
}
func (s *space) PutTree(ctx context.Context, payload treestorage.TreeStorageCreatePayload, listener updatelistener.UpdateListener) (t objecttree.ObjectTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
deps := synctree.BuildDeps{
SpaceId: s.id,
SyncClient: s.objectSync.SyncClient(),
Configuration: s.configuration,
HeadNotifiable: s.headSync,
Listener: listener,
AclList: s.aclList,
SpaceStorage: s.storage,
OnClose: s.onObjectClose,
SyncStatus: s.syncStatus,
PeerGetter: s.peerManager,
BuildObjectTree: s.treeBuilder,
}
t, err = synctree.PutSyncTree(ctx, payload, deps)
if err != nil {
return
}
s.treesUsed.Add(1)
log.Debug("incrementing counter", zap.String("id", payload.RootRawChange.Id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
return
}
type BuildTreeOpts struct {
Listener updatelistener.UpdateListener
WaitTreeRemoteSync bool
treeBuilder objecttree.BuildObjectTreeFunc
}
type HistoryTreeOpts struct {
BeforeId string
Include bool
BuildFullTree bool
}
func (s *space) BuildTree(ctx context.Context, id string, opts BuildTreeOpts) (t objecttree.ObjectTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
treeBuilder := opts.treeBuilder
if treeBuilder == nil {
treeBuilder = s.treeBuilder
}
deps := synctree.BuildDeps{
SpaceId: s.id,
SyncClient: s.objectSync.SyncClient(),
Configuration: s.configuration,
HeadNotifiable: s.headSync,
Listener: opts.Listener,
AclList: s.aclList,
SpaceStorage: s.storage,
OnClose: s.onObjectClose,
SyncStatus: s.syncStatus,
WaitTreeRemoteSync: opts.WaitTreeRemoteSync,
PeerGetter: s.peerManager,
BuildObjectTree: treeBuilder,
}
s.treesUsed.Add(1)
log.Debug("incrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
if t, err = synctree.BuildSyncTreeOrGetRemote(ctx, id, deps); err != nil {
s.treesUsed.Add(-1)
log.Debug("decrementing counter, load failed", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id), zap.Error(err))
return nil, err
}
return
}
func (s *space) BuildHistoryTree(ctx context.Context, id string, opts HistoryTreeOpts) (t objecttree.HistoryTree, err error) {
if s.isClosed.Load() {
err = ErrSpaceClosed
return
}
params := objecttree.HistoryTreeParams{
AclList: s.aclList,
BeforeId: opts.BeforeId,
IncludeBeforeId: opts.Include,
BuildFullTree: opts.BuildFullTree,
}
params.TreeStorage, err = s.storage.TreeStorage(id)
if err != nil {
return
}
return objecttree.BuildHistoryTree(params)
}
func (s *space) DeleteTree(ctx context.Context, id string) (err error) {
return s.settingsObject.DeleteObject(id)
}
func (s *space) SpaceDeleteRawChange(ctx context.Context) (raw *treechangeproto.RawTreeChangeWithId, err error) {
return s.settingsObject.SpaceDeleteRawChange()
}
func (s *space) DeleteSpace(ctx context.Context, deleteChange *treechangeproto.RawTreeChangeWithId) (err error) {
return s.settingsObject.DeleteSpace(ctx, deleteChange)
}
func (s *space) HandleMessage(ctx context.Context, hm HandleMessage) (err error) {
threadId := hm.Message.ObjectId
hm.ReceiveTime = time.Now()
if hm.Message.ReplyId != "" {
threadId += hm.Message.ReplyId
defer func() {
_ = s.handleQueue.CloseThread(threadId)
}()
}
if hm.PeerCtx == nil {
hm.PeerCtx = ctx
}
err = s.handleQueue.Add(ctx, threadId, hm)
if err == mb.ErrOverflowed {
log.InfoCtx(ctx, "queue overflowed", zap.String("spaceId", s.id), zap.String("objectId", threadId))
// skip overflowed error
return nil
}
return
}
func (s *space) handleMessage(msg HandleMessage) {
var err error
msg.StartHandlingTime = time.Now()
ctx := peer.CtxWithPeerId(context.Background(), msg.SenderId)
ctx = logger.CtxWithFields(ctx, zap.Uint64("msgId", msg.Id), zap.String("senderId", msg.SenderId))
defer func() {
if s.metric == nil {
return
}
s.metric.RequestLog(msg.PeerCtx, "space.streamOp", msg.LogFields(
zap.Error(err),
)...)
}()
if !msg.Deadline.IsZero() {
now := time.Now()
if now.After(msg.Deadline) {
log.InfoCtx(ctx, "skip message: deadline exceed")
err = context.DeadlineExceeded
return
}
var cancel context.CancelFunc
ctx, cancel = context.WithDeadline(ctx, msg.Deadline)
defer cancel()
}
if err = s.objectSync.HandleMessage(ctx, msg.SenderId, msg.Message); err != nil {
if msg.Message.ObjectId != "" {
// cleanup thread on error
_ = s.handleQueue.CloseThread(msg.Message.ObjectId)
}
log.InfoCtx(ctx, "handleMessage error", zap.Error(err))
}
}
func (s *space) onObjectClose(id string) {
s.treesUsed.Add(-1)
log.Debug("decrementing counter", zap.String("id", id), zap.Int32("trees", s.treesUsed.Load()), zap.String("spaceId", s.id))
_ = s.handleQueue.CloseThread(id)
}
func (s *space) onSpaceDelete() {
err := s.storage.SetSpaceDeleted()
if err != nil {
log.Debug("failed to set space deleted")
}
s.isDeleted.Swap(true)
}
func (s *space) Close() error { func (s *space) Close() error {
if s.isClosed.Swap(true) { if s.state.SpaceIsClosed.Swap(true) {
log.Warn("call space.Close on closed space", zap.String("id", s.id)) log.Warn("call space.Close on closed space", zap.String("id", s.state.SpaceId))
return nil return nil
} }
log.With(zap.String("id", s.id)).Debug("space is closing") log := log.With(zap.String("spaceId", s.state.SpaceId))
log.Debug("space is closing")
var mError errs.Group err := s.app.Close(context.Background())
if err := s.handleQueue.Close(); err != nil { log.Debug("space closed")
mError.Add(err) return err
}
if err := s.headSync.Close(); err != nil {
mError.Add(err)
}
if err := s.objectSync.Close(); err != nil {
mError.Add(err)
}
if err := s.settingsObject.Close(); err != nil {
mError.Add(err)
}
if err := s.aclList.Close(); err != nil {
mError.Add(err)
}
if err := s.storage.Close(); err != nil {
mError.Add(err)
}
if err := s.syncStatus.Close(); err != nil {
mError.Add(err)
}
log.With(zap.String("id", s.id)).Debug("space closed")
return mError.Err()
} }
func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) { func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) {
if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL { if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL {
return false, nil return false, nil
} }
locked := s.treesUsed.Load() > 1 locked := s.state.TreesUsed.Load() > 1
log.With(zap.Int32("trees used", s.treesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.id)).Debug("space lock status check") log.With(zap.Int32("trees used", s.state.TreesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.state.SpaceId)).Debug("space lock status check")
if locked { if locked {
return false, nil return false, nil
} }

View File

@ -2,26 +2,36 @@ package commonspace
import ( import (
"context" "context"
"sync/atomic"
"github.com/anyproto/any-sync/accountservice" "github.com/anyproto/any-sync/accountservice"
"github.com/anyproto/any-sync/app" "github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/app/logger" "github.com/anyproto/any-sync/app/logger"
"github.com/anyproto/any-sync/commonspace/config"
"github.com/anyproto/any-sync/commonspace/credentialprovider" "github.com/anyproto/any-sync/commonspace/credentialprovider"
"github.com/anyproto/any-sync/commonspace/deletionstate"
"github.com/anyproto/any-sync/commonspace/headsync" "github.com/anyproto/any-sync/commonspace/headsync"
"github.com/anyproto/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anyproto/any-sync/commonspace/object/acl/syncacl"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree" "github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto" "github.com/anyproto/any-sync/commonspace/object/tree/treechangeproto"
"github.com/anyproto/any-sync/commonspace/object/treemanager" "github.com/anyproto/any-sync/commonspace/object/treemanager"
"github.com/anyproto/any-sync/commonspace/objectmanager"
"github.com/anyproto/any-sync/commonspace/objectsync" "github.com/anyproto/any-sync/commonspace/objectsync"
"github.com/anyproto/any-sync/commonspace/objecttreebuilder"
"github.com/anyproto/any-sync/commonspace/peermanager" "github.com/anyproto/any-sync/commonspace/peermanager"
"github.com/anyproto/any-sync/commonspace/requestmanager"
"github.com/anyproto/any-sync/commonspace/settings"
"github.com/anyproto/any-sync/commonspace/spacestate"
"github.com/anyproto/any-sync/commonspace/spacestorage" "github.com/anyproto/any-sync/commonspace/spacestorage"
"github.com/anyproto/any-sync/commonspace/spacesyncproto" "github.com/anyproto/any-sync/commonspace/spacesyncproto"
"github.com/anyproto/any-sync/commonspace/syncstatus" "github.com/anyproto/any-sync/commonspace/syncstatus"
"github.com/anyproto/any-sync/consensus/consensusproto"
"github.com/anyproto/any-sync/metric" "github.com/anyproto/any-sync/metric"
"github.com/anyproto/any-sync/net/peer" "github.com/anyproto/any-sync/net/peer"
"github.com/anyproto/any-sync/net/pool" "github.com/anyproto/any-sync/net/pool"
"github.com/anyproto/any-sync/net/rpc/rpcerr" "github.com/anyproto/any-sync/net/rpc/rpcerr"
"github.com/anyproto/any-sync/nodeconf" "github.com/anyproto/any-sync/nodeconf"
"sync/atomic" "storj.io/drpc"
) )
const CName = "common.commonspace" const CName = "common.commonspace"
@ -45,32 +55,30 @@ type SpaceService interface {
} }
type spaceService struct { type spaceService struct {
config Config config config.Config
account accountservice.Service account accountservice.Service
configurationService nodeconf.Service configurationService nodeconf.Service
storageProvider spacestorage.SpaceStorageProvider storageProvider spacestorage.SpaceStorageProvider
peermanagerProvider peermanager.PeerManagerProvider peerManagerProvider peermanager.PeerManagerProvider
credentialProvider credentialprovider.CredentialProvider credentialProvider credentialprovider.CredentialProvider
treeManager treemanager.TreeManager statusServiceProvider syncstatus.StatusServiceProvider
pool pool.Pool treeManager treemanager.TreeManager
metric metric.Metric pool pool.Pool
metric metric.Metric
app *app.App
} }
func (s *spaceService) Init(a *app.App) (err error) { func (s *spaceService) Init(a *app.App) (err error) {
s.config = a.MustComponent("config").(ConfigGetter).GetSpace() s.config = a.MustComponent("config").(config.ConfigGetter).GetSpace()
s.account = a.MustComponent(accountservice.CName).(accountservice.Service) s.account = a.MustComponent(accountservice.CName).(accountservice.Service)
s.storageProvider = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorageProvider) s.storageProvider = a.MustComponent(spacestorage.CName).(spacestorage.SpaceStorageProvider)
s.configurationService = a.MustComponent(nodeconf.CName).(nodeconf.Service) s.configurationService = a.MustComponent(nodeconf.CName).(nodeconf.Service)
s.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager) s.treeManager = a.MustComponent(treemanager.CName).(treemanager.TreeManager)
s.peermanagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider) s.peerManagerProvider = a.MustComponent(peermanager.CName).(peermanager.PeerManagerProvider)
credProvider := a.Component(credentialprovider.CName) s.statusServiceProvider = a.MustComponent(syncstatus.CName).(syncstatus.StatusServiceProvider)
if credProvider != nil {
s.credentialProvider = credProvider.(credentialprovider.CredentialProvider)
} else {
s.credentialProvider = credentialprovider.NewNoOp()
}
s.pool = a.MustComponent(pool.CName).(pool.Pool) s.pool = a.MustComponent(pool.CName).(pool.Pool)
s.metric, _ = a.Component(metric.CName).(metric.Metric) s.metric, _ = a.Component(metric.CName).(metric.Metric)
s.app = a
return nil return nil
} }
@ -138,8 +146,6 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) {
} }
} }
} }
lastConfiguration := s.configurationService
var ( var (
spaceIsClosed = &atomic.Bool{} spaceIsClosed = &atomic.Bool{}
spaceIsDeleted = &atomic.Bool{} spaceIsDeleted = &atomic.Bool{}
@ -149,49 +155,46 @@ func (s *spaceService) NewSpace(ctx context.Context, id string) (Space, error) {
return nil, err return nil, err
} }
spaceIsDeleted.Swap(isDeleted) spaceIsDeleted.Swap(isDeleted)
getter := newCommonGetter(st.Id(), s.treeManager, spaceIsClosed) state := &spacestate.SpaceState{
syncStatus := syncstatus.NewNoOpSyncStatus() SpaceId: st.Id(),
// this will work only for clients, not the best solution, but... SpaceIsDeleted: spaceIsDeleted,
if !lastConfiguration.IsResponsible(st.Id()) { SpaceIsClosed: spaceIsClosed,
// TODO: move it to the client package and add possibility to inject StatusProvider from the client TreesUsed: &atomic.Int32{},
syncStatus = syncstatus.NewSyncStatusProvider(st.Id(), syncstatus.DefaultDeps(lastConfiguration, st))
} }
var builder objecttree.BuildObjectTreeFunc
if s.config.KeepTreeDataInMemory { if s.config.KeepTreeDataInMemory {
builder = objecttree.BuildObjectTree state.TreeBuilderFunc = objecttree.BuildObjectTree
} else { } else {
builder = objecttree.BuildEmptyDataObjectTree state.TreeBuilderFunc = objecttree.BuildEmptyDataObjectTree
} }
peerManager, err := s.peerManagerProvider.NewPeerManager(ctx, id)
peerManager, err := s.peermanagerProvider.NewPeerManager(ctx, id)
if err != nil { if err != nil {
return nil, err return nil, err
} }
statusService := s.statusServiceProvider.NewStatusService()
spaceApp := s.app.ChildApp()
spaceApp.Register(state).
Register(peerManager).
Register(newCommonStorage(st)).
Register(statusService).
Register(syncacl.New()).
Register(requestmanager.New()).
Register(deletionstate.New()).
Register(settings.New()).
Register(objectmanager.New(s.treeManager)).
Register(objecttreebuilder.New()).
Register(objectsync.New()).
Register(headsync.New())
headSync := headsync.NewHeadSync(id, spaceIsDeleted, s.config.SyncPeriod, lastConfiguration, st, peerManager, getter, syncStatus, s.credentialProvider, log)
objectSync := objectsync.NewObjectSync(id, spaceIsDeleted, lastConfiguration, peerManager, getter, st)
sp := &space{ sp := &space{
id: id, state: state,
objectSync: objectSync, app: spaceApp,
headSync: headSync,
syncStatus: syncStatus,
treeManager: getter,
account: s.account,
configuration: lastConfiguration,
peerManager: peerManager,
storage: st,
treesUsed: &atomic.Int32{},
treeBuilder: builder,
isClosed: spaceIsClosed,
isDeleted: spaceIsDeleted,
metric: s.metric,
} }
return sp, nil return sp, nil
} }
func (s *spaceService) addSpaceStorage(ctx context.Context, spaceDescription SpaceDescription) (st spacestorage.SpaceStorage, err error) { func (s *spaceService) addSpaceStorage(ctx context.Context, spaceDescription SpaceDescription) (st spacestorage.SpaceStorage, err error) {
payload := spacestorage.SpaceStorageCreatePayload{ payload := spacestorage.SpaceStorageCreatePayload{
AclWithId: &aclrecordproto.RawAclRecordWithId{ AclWithId: &consensusproto.RawRecordWithId{
Payload: spaceDescription.AclPayload, Payload: spaceDescription.AclPayload,
Id: spaceDescription.AclId, Id: spaceDescription.AclId,
}, },
@ -226,15 +229,19 @@ func (s *spaceService) getSpaceStorageFromRemote(ctx context.Context, id string)
return return
} }
cl := spacesyncproto.NewDRPCSpaceSyncClient(p) var res *spacesyncproto.SpacePullResponse
res, err := cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id}) err = p.DoDrpc(ctx, func(conn drpc.Conn) error {
cl := spacesyncproto.NewDRPCSpaceSyncClient(conn)
res, err = cl.SpacePull(ctx, &spacesyncproto.SpacePullRequest{Id: id})
return err
})
if err != nil { if err != nil {
err = rpcerr.Unwrap(err) err = rpcerr.Unwrap(err)
return return
} }
st, err = s.createSpaceStorage(spacestorage.SpaceStorageCreatePayload{ st, err = s.createSpaceStorage(spacestorage.SpaceStorageCreatePayload{
AclWithId: &aclrecordproto.RawAclRecordWithId{ AclWithId: &consensusproto.RawRecordWithId{
Payload: res.Payload.AclPayload, Payload: res.Payload.AclPayload,
Id: res.Payload.AclPayloadId, Id: res.Payload.AclPayloadId,
}, },

View File

@ -0,0 +1,25 @@
package spacestate
import (
"github.com/anyproto/any-sync/app"
"github.com/anyproto/any-sync/commonspace/object/tree/objecttree"
"sync/atomic"
)
const CName = "common.commonspace.spacestate"
type SpaceState struct {
SpaceId string
SpaceIsDeleted *atomic.Bool
SpaceIsClosed *atomic.Bool
TreesUsed *atomic.Int32
TreeBuilderFunc objecttree.BuildObjectTreeFunc
}
func (s *SpaceState) Init(a *app.App) (err error) {
return nil
}
func (s *SpaceState) Name() (name string) {
return CName
}

Some files were not shown because too many files have changed in this diff Show More