diff --git a/app/ocache/entry.go b/app/ocache/entry.go new file mode 100644 index 00000000..fbe60e26 --- /dev/null +++ b/app/ocache/entry.go @@ -0,0 +1,120 @@ +package ocache + +import ( + "context" + "go.uber.org/zap" + "sync" + "time" +) + +type entryState int + +const ( + entryStateLoading = iota + entryStateActive + entryStateClosing + entryStateClosed +) + +type entry struct { + id string + state entryState + lastUsage time.Time + load chan struct{} + loadErr error + value Object + close chan struct{} + mx sync.Mutex +} + +func newEntry(id string, value Object, state entryState) *entry { + return &entry{ + id: id, + load: make(chan struct{}), + lastUsage: time.Now(), + state: state, + value: value, + } +} + +func (e *entry) isActive() bool { + e.mx.Lock() + defer e.mx.Unlock() + return e.state == entryStateActive +} + +func (e *entry) isClosing() bool { + e.mx.Lock() + defer e.mx.Unlock() + return e.state == entryStateClosed || e.state == entryStateClosing +} + +func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) { + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) + return nil, ctx.Err() + case <-e.load: + return e.value, e.loadErr + } +} + +func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) { + e.mx.Lock() + switch e.state { + case entryStateClosing: + waitCh := e.close + e.mx.Unlock() + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) + return false, ctx.Err() + case <-waitCh: + return true, nil + } + case entryStateClosed: + e.mx.Unlock() + return true, nil + default: + e.mx.Unlock() + return false, nil + } +} + +func (e *entry) setClosing(wait bool) (prevState, curState entryState) { + e.mx.Lock() + prevState = e.state + curState = e.state + if e.state == entryStateClosing { + waitCh := e.close + e.mx.Unlock() + if !wait { + return + } + <-waitCh + e.mx.Lock() + } + if e.state != entryStateClosed { + e.state = entryStateClosing + e.close = make(chan struct{}) + } + curState = e.state + e.mx.Unlock() + return +} + +func (e *entry) setActive(chClose bool) { + e.mx.Lock() + defer e.mx.Unlock() + if chClose { + close(e.close) + } + e.state = entryStateActive +} + +func (e *entry) setClosed() { + e.mx.Lock() + defer e.mx.Unlock() + close(e.close) + e.state = entryStateClosed +} diff --git a/app/ocache/ocache.go b/app/ocache/ocache.go index df19b220..7375c643 100644 --- a/app/ocache/ocache.go +++ b/app/ocache/ocache.go @@ -44,12 +44,6 @@ var WithGCPeriod = func(gcPeriod time.Duration) Option { } } -var WithRefCounter = func(enable bool) Option { - return func(cache *oCache) { - cache.refCounter = enable - } -} - func New(loadFunc LoadFunc, opts ...Option) OCache { c := &oCache{ data: make(map[string]*entry), @@ -73,33 +67,7 @@ func New(loadFunc LoadFunc, opts ...Option) OCache { type Object interface { Close() (err error) -} - -type ObjectLocker interface { - Object - Locked() bool -} - -type ObjectLastUsage interface { - LastUsage() time.Time -} - -type entry struct { - id string - lastUsage time.Time - refCount uint32 - isClosing bool - load chan struct{} - loadErr error - value Object - close chan struct{} -} - -func (e *entry) locked() bool { - if locker, ok := e.value.(ObjectLocker); ok { - return locker.Locked() - } - return false + TryClose(objectTTL time.Duration) (res bool, err error) } type OCache interface { @@ -116,10 +84,6 @@ type OCache interface { // Add adds new object to cache // Returns error when object exists Add(id string, value Object) (err error) - // Release decreases the refs counter - Release(id string) bool - // Reset sets refs counter to 0 - Reset(id string) bool // Remove closes and removes object Remove(id string) (ok bool, err error) // ForEach iterates over all loaded objects, breaks when callback returns false @@ -134,17 +98,16 @@ type OCache interface { } type oCache struct { - mu sync.Mutex - data map[string]*entry - loadFunc LoadFunc - timeNow func() time.Time - ttl time.Duration - gc time.Duration - closed bool - closeCh chan struct{} - log *zap.SugaredLogger - metrics *metrics - refCounter bool + mu sync.Mutex + data map[string]*entry + loadFunc LoadFunc + timeNow func() time.Time + ttl time.Duration + gc time.Duration + closed bool + closeCh chan struct{} + log *zap.SugaredLogger + metrics *metrics } func (c *oCache) Get(ctx context.Context, id string) (value Object, err error) { @@ -160,69 +123,36 @@ Load: return nil, ErrClosed } if e, ok = c.data[id]; !ok { + e = newEntry(id, nil, entryStateLoading) load = true - e = &entry{ - id: id, - load: make(chan struct{}), - } c.data[id] = e } - closing := e.isClosing - if !e.isClosing { - e.lastUsage = c.timeNow() - if c.refCounter { - e.refCount++ - } - } + e.lastUsage = time.Now() c.mu.Unlock() - if closing { - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) - return nil, ctx.Err() - case <-e.close: - goto Load - } + reload, err := e.waitClose(ctx, id) + if err != nil { + return nil, err + } + if reload { + goto Load } - if load { go c.load(ctx, id, e) } - if c.metrics != nil { - if load { - c.metrics.miss.Inc() - } else { - c.metrics.hit.Inc() - } - } - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) - return nil, ctx.Err() - case <-e.load: - } - return e.value, e.loadErr + c.metricsGet(!load) + return e.waitLoad(ctx, id) } func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error) { c.mu.Lock() val, ok := c.data[id] - if !ok || val.isClosing { + if !ok || val.isClosing() { c.mu.Unlock() return nil, ErrNotExists } c.mu.Unlock() - - if c.metrics != nil { - c.metrics.hit.Inc() - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-val.load: - return val.value, val.loadErr - } + c.metricsGet(true) + return val.waitLoad(ctx, id) } func (c *oCache) load(ctx context.Context, id string, e *entry) { @@ -236,37 +166,10 @@ func (c *oCache) load(ctx context.Context, id string, e *entry) { delete(c.data, id) } else { e.value = value + e.setActive(false) } } -func (c *oCache) Release(id string) bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return false - } - if e, ok := c.data[id]; ok { - if c.refCounter && e.refCount > 0 { - e.refCount-- - return true - } - } - return false -} - -func (c *oCache) Reset(id string) bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return false - } - if e, ok := c.data[id]; ok { - e.refCount = 0 - return true - } - return false -} - func (c *oCache) Remove(id string) (ok bool, err error) { c.mu.Lock() if c.closed { @@ -274,25 +177,29 @@ func (c *oCache) Remove(id string) (ok bool, err error) { err = ErrClosed return } - var e *entry - e, ok = c.data[id] - if !ok || e.isClosing { + e, ok := c.data[id] + if !ok { c.mu.Unlock() - return + return false, ErrNotExists } - e.isClosing = true - e.close = make(chan struct{}) c.mu.Unlock() + return c.remove(e) +} +func (c *oCache) remove(e *entry) (ok bool, err error) { <-e.load - if e.value != nil { - err = e.value.Close() + if e.value == nil { + return false, ErrNotExists + } + _, curState := e.setClosing(true) + if curState == entryStateClosing { + ok = true + err = e.value.Close() + c.mu.Lock() + e.setClosed() + delete(c.data, e.id) + c.mu.Unlock() } - c.mu.Lock() - close(e.close) - delete(c.data, e.id) - c.mu.Unlock() - return } @@ -314,13 +221,7 @@ func (c *oCache) Add(id string, value Object) (err error) { if _, ok := c.data[id]; ok { return ErrExists } - e := &entry{ - id: id, - lastUsage: time.Now(), - refCount: 0, - load: make(chan struct{}), - value: value, - } + e := newEntry(id, value, entryStateActive) close(e.load) c.data[id] = e return @@ -332,7 +233,7 @@ func (c *oCache) ForEach(f func(obj Object) (isContinue bool)) { for _, v := range c.data { select { case <-v.load: - if v.value != nil && !v.isClosing { + if v.value != nil && !v.isClosing() { objects = append(objects, v.value) } default: @@ -368,40 +269,35 @@ func (c *oCache) GC() { deadline := c.timeNow().Add(-c.ttl) var toClose []*entry for _, e := range c.data { - if e.isClosing { - continue - } - lu := e.lastUsage - if lug, ok := e.value.(ObjectLastUsage); ok { - lu = lug.LastUsage() - } - if !e.locked() && e.refCount <= 0 && lu.Before(deadline) { - e.isClosing = true + if e.isActive() && e.lastUsage.Before(deadline) { e.close = make(chan struct{}) toClose = append(toClose, e) } } size := len(c.data) c.mu.Unlock() - + closedNum := 0 for _, e := range toClose { - <-e.load - if e.value != nil { - if err := e.value.Close(); err != nil { - c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err) - } + prevState, _ := e.setClosing(false) + if prevState == entryStateClosing || prevState == entryStateClosed { + continue + } + closed, err := e.value.TryClose(c.ttl) + if err != nil { + c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err) + } + if !closed { + e.setActive(true) + continue + } else { + closedNum++ + c.mu.Lock() + e.setClosed() + delete(c.data, e.id) + c.mu.Unlock() } } - c.log.Infof("GC: removed %d; cache size: %d", len(toClose), size) - if len(toClose) > 0 && c.metrics != nil { - c.metrics.gc.Add(float64(len(toClose))) - } - c.mu.Lock() - for _, e := range toClose { - close(e.close) - delete(c.data, e.id) - } - c.mu.Unlock() + c.metricsClosed(closedNum, size) } func (c *oCache) Len() int { @@ -418,25 +314,34 @@ func (c *oCache) Close() (err error) { } c.closed = true close(c.closeCh) - var toClose, alreadyClosing []*entry + var toClose []*entry for _, e := range c.data { - if e.isClosing { - alreadyClosing = append(alreadyClosing, e) - } else { - toClose = append(toClose, e) - } + toClose = append(toClose, e) } c.mu.Unlock() for _, e := range toClose { - <-e.load - if e.value != nil { - if clErr := e.value.Close(); clErr != nil { - c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", clErr) - } + if _, err := c.remove(e); err != nil && err != ErrNotExists { + c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", err) } } - for _, e := range alreadyClosing { - <-e.close - } return nil } + +func (c *oCache) metricsGet(hit bool) { + if c.metrics == nil { + return + } + if hit { + c.metrics.hit.Inc() + } else { + c.metrics.miss.Inc() + } +} + +func (c *oCache) metricsClosed(closedLen, size int) { + c.log.Infof("GC: removed %d; cache size: %d", closedLen, size) + if c.metrics == nil || closedLen == 0 { + return + } + c.metrics.gc.Add(float64(closedLen)) +} diff --git a/app/ocache/ocache_test.go b/app/ocache/ocache_test.go index 54034141..14a59b1a 100644 --- a/app/ocache/ocache_test.go +++ b/app/ocache/ocache_test.go @@ -3,6 +3,8 @@ package ocache import ( "context" "errors" + "fmt" + "sync" "sync/atomic" "testing" "time" @@ -12,25 +14,45 @@ import ( ) type testObject struct { - name string - closeErr error - closeCh chan struct{} + name string + closeErr error + closeCh chan struct{} + tryReturn bool + closeCalled bool + tryCloseCalled bool } -func NewTestObject(name string, closeCh chan struct{}) *testObject { +func NewTestObject(name string, tryReturn bool, closeCh chan struct{}) *testObject { return &testObject{ - name: name, - closeCh: closeCh, + name: name, + closeCh: closeCh, + tryReturn: tryReturn, } } func (t *testObject) Close() (err error) { + if t.closeCalled || (t.tryCloseCalled && t.tryReturn) { + panic("close called twice") + } + t.closeCalled = true if t.closeCh != nil { <-t.closeCh } return t.closeErr } +func (t *testObject) TryClose(objectTTL time.Duration) (res bool, err error) { + if t.closeCalled || (t.tryCloseCalled && t.tryReturn) { + panic("close called twice") + } + t.tryCloseCalled = true + if t.closeCh != nil { + <-t.closeCh + return t.tryReturn, t.closeErr + } + return t.tryReturn, nil +} + func TestOCache_Get(t *testing.T) { t.Run("successful", func(t *testing.T) { c := New(func(ctx context.Context, id string) (value Object, err error) { @@ -116,42 +138,37 @@ func TestOCache_Get(t *testing.T) { } func TestOCache_GC(t *testing.T) { - t.Run("test without close wait", func(t *testing.T) { + t.Run("test gc expired object", func(t *testing.T) { c := New(func(ctx context.Context, id string) (value Object, err error) { - return &testObject{name: id}, nil - }, WithTTL(time.Millisecond*10), WithRefCounter(true)) + return NewTestObject(id, true, nil), nil + }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) assert.Equal(t, 1, c.Len()) c.GC() assert.Equal(t, 1, c.Len()) - time.Sleep(time.Millisecond * 30) - c.GC() - assert.Equal(t, 1, c.Len()) - assert.True(t, c.Release("id")) + time.Sleep(time.Millisecond * 20) c.GC() assert.Equal(t, 0, c.Len()) - assert.False(t, c.Release("id")) }) - t.Run("test with close wait", func(t *testing.T) { + t.Run("test gc tryClose true, close before get", func(t *testing.T) { closeCh := make(chan struct{}) getCh := make(chan struct{}) c := New(func(ctx context.Context, id string) (value Object, err error) { - return NewTestObject(id, closeCh), nil - }, WithTTL(time.Millisecond*10), WithRefCounter(true)) + return NewTestObject(id, true, closeCh), nil + }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) assert.Equal(t, 1, c.Len()) - assert.True(t, c.Release("id")) // making ttl pass - time.Sleep(time.Millisecond * 40) + time.Sleep(time.Millisecond * 20) // first gc will be run after 20 secs, so calling it manually go c.GC() // waiting until all objects are marked as closing - time.Sleep(time.Millisecond * 40) + time.Sleep(time.Millisecond * 20) var events []string go func() { _, err := c.Get(context.TODO(), "id") @@ -160,47 +177,314 @@ func TestOCache_GC(t *testing.T) { events = append(events, "get") close(getCh) }() - events = append(events, "close") // sleeping to make sure that Get is called - time.Sleep(time.Millisecond * 40) + time.Sleep(time.Millisecond * 20) + events = append(events, "close") close(closeCh) <-getCh require.Equal(t, []string{"close", "get"}, events) }) + t.Run("test gc tryClose false, many parallel get", func(t *testing.T) { + timesCalled := &atomic.Int32{} + obj := NewTestObject("id", false, nil) + c := New(func(ctx context.Context, id string) (value Object, err error) { + timesCalled.Add(1) + return obj, nil + }, WithTTL(time.Millisecond*10)) + + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + time.Sleep(time.Millisecond * 20) + begin := make(chan struct{}) + wg := sync.WaitGroup{} + once := sync.Once{} + + wg.Add(1) + go func() { + <-begin + c.GC() + wg.Done() + }() + for i := 0; i < 50; i++ { + wg.Add(1) + go func(i int) { + once.Do(func() { + close(begin) + }) + if i%2 != 0 { + time.Sleep(time.Millisecond) + } + _, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + wg.Done() + }(i) + } + require.NoError(t, err) + wg.Wait() + require.Equal(t, timesCalled.Load(), int32(1)) + require.True(t, obj.tryCloseCalled) + }) + t.Run("test gc tryClose different, many objects", func(t *testing.T) { + tryCloseIds := make(map[string]bool) + called := make(map[string]int) + max := 1000 + getId := func(i int) string { + return fmt.Sprintf("id%d", i) + } + for i := 0; i < max; i++ { + if i%2 == 1 { + tryCloseIds[getId(i)] = true + } else { + tryCloseIds[getId(i)] = false + } + } + c := New(func(ctx context.Context, id string) (value Object, err error) { + called[id] = called[id] + 1 + return NewTestObject(id, tryCloseIds[id], nil), nil + }, WithTTL(time.Millisecond*10)) + + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + } + assert.Equal(t, max, c.Len()) + time.Sleep(time.Millisecond * 20) + c.GC() + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + } + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + require.Equal(t, called[getId(i)], i%2+1) + } + }) } func Test_OCache_Remove(t *testing.T) { - closeCh := make(chan struct{}) - getCh := make(chan struct{}) + t.Run("remove simple", func(t *testing.T) { + closeCh := make(chan struct{}) + getCh := make(chan struct{}) + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, false, closeCh), nil + }, WithTTL(time.Millisecond*10)) - c := New(func(ctx context.Context, id string) (value Object, err error) { - return NewTestObject(id, closeCh), nil - }, WithTTL(time.Millisecond*10)) - val, err := c.Get(context.TODO(), "id") - require.NoError(t, err) - require.NotNil(t, val) - assert.Equal(t, 1, c.Len()) - // removing the object, so we will wait on closing - go func() { - _, err := c.Remove("id") - require.NoError(t, err) - }() - time.Sleep(time.Millisecond * 40) - - var events []string - go func() { - _, err := c.Get(context.TODO(), "id") + val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) - events = append(events, "get") - close(getCh) - }() - events = append(events, "close") - // sleeping to make sure that Get is called - time.Sleep(time.Millisecond * 40) - close(closeCh) + assert.Equal(t, 1, c.Len()) + // removing the object, so we will wait on closing + go func() { + _, err := c.Remove("id") + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 20) - <-getCh - require.Equal(t, []string{"close", "get"}, events) + var events []string + go func() { + _, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + events = append(events, "get") + close(getCh) + }() + // sleeping to make sure that Get is called + time.Sleep(time.Millisecond * 20) + events = append(events, "close") + close(closeCh) + + <-getCh + require.Equal(t, []string{"close", "get"}, events) + }) + t.Run("test remove while gc, tryClose false", func(t *testing.T) { + closeCh := make(chan struct{}) + removeCh := make(chan struct{}) + + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, false, closeCh), nil + }, WithTTL(time.Millisecond*10)) + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + time.Sleep(time.Millisecond * 20) + go c.GC() + time.Sleep(time.Millisecond * 20) + var events []string + go func() { + ok, err := c.Remove("id") + require.NoError(t, err) + require.True(t, ok) + events = append(events, "remove") + close(removeCh) + }() + time.Sleep(time.Millisecond * 20) + events = append(events, "close") + close(closeCh) + + <-removeCh + require.Equal(t, []string{"close", "remove"}, events) + }) + t.Run("test remove while gc, tryClose true", func(t *testing.T) { + closeCh := make(chan struct{}) + removeCh := make(chan struct{}) + + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, true, closeCh), nil + }, WithTTL(time.Millisecond*10)) + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + time.Sleep(time.Millisecond * 20) + go c.GC() + time.Sleep(time.Millisecond * 20) + var events []string + go func() { + ok, err := c.Remove("id") + require.NoError(t, err) + require.False(t, ok) + events = append(events, "remove") + close(removeCh) + }() + time.Sleep(time.Millisecond * 20) + events = append(events, "close") + close(closeCh) + + <-removeCh + require.Equal(t, []string{"close", "remove"}, events) + }) + t.Run("test gc while remove, tryClose true", func(t *testing.T) { + closeCh := make(chan struct{}) + removeCh := make(chan struct{}) + + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, true, closeCh), nil + }, WithTTL(time.Millisecond*10)) + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + go func() { + ok, err := c.Remove("id") + require.NoError(t, err) + require.True(t, ok) + close(removeCh) + }() + time.Sleep(20 * time.Millisecond) + c.GC() + close(closeCh) + <-removeCh + }) +} + +func TestOCacheFuzzy(t *testing.T) { + t.Run("test many objects gc, get and remove simultaneously, close after", func(t *testing.T) { + tryCloseIds := make(map[string]bool) + max := 2000 + getId := func(i int) string { + return fmt.Sprintf("id%d", i) + } + for i := 0; i < max; i++ { + if i%2 == 1 { + tryCloseIds[getId(i)] = true + } else { + tryCloseIds[getId(i)] = false + } + } + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, tryCloseIds[id], nil), nil + }, WithTTL(time.Nanosecond)) + + stopGC := make(chan struct{}) + wg := sync.WaitGroup{} + go func() { + for { + select { + case <-stopGC: + return + default: + c.GC() + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + } + } + }() + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 10; j++ { + for i := 0; i < max; i++ { + c.Remove(getId(i)) + } + } + }() + wg.Wait() + close(stopGC) + err := c.Close() + require.NoError(t, err) + require.Equal(t, 0, c.Len()) + }) + t.Run("test many objects gc, get, remove and close simultaneously", func(t *testing.T) { + tryCloseIds := make(map[string]bool) + max := 2000 + getId := func(i int) string { + return fmt.Sprintf("id%d", i) + } + for i := 0; i < max; i++ { + if i%2 == 1 { + tryCloseIds[getId(i)] = true + } else { + tryCloseIds[getId(i)] = false + } + } + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, tryCloseIds[id], nil), nil + }, WithTTL(time.Nanosecond)) + + go func() { + for { + c.GC() + } + }() + go func() { + for j := 0; j < 10; j++ { + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + if err == ErrClosed { + return + } + require.NoError(t, err) + require.NotNil(t, val) + } + } + }() + go func() { + for j := 0; j < 10; j++ { + for i := 0; i < max; i++ { + c.Remove(getId(i)) + } + } + }() + time.Sleep(time.Millisecond) + err := c.Close() + require.NoError(t, err) + require.Equal(t, 0, c.Len()) + }) } diff --git a/commonspace/headsync/diffsyncer_test.go b/commonspace/headsync/diffsyncer_test.go index f9622ca6..ae40e5a7 100644 --- a/commonspace/headsync/diffsyncer_test.go +++ b/commonspace/headsync/diffsyncer_test.go @@ -51,6 +51,10 @@ func (p pushSpaceRequestMatcher) String() string { type mockPeer struct{} +func (m mockPeer) TryClose(objectTTL time.Duration) (res bool, err error) { + return true, m.Close() +} + func (m mockPeer) Id() string { return "mockId" } diff --git a/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go b/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go index 12dd4a95..fb5ea6b1 100644 --- a/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go +++ b/commonspace/object/tree/objecttree/mock_objecttree/mock_objecttree.go @@ -7,6 +7,7 @@ package mock_objecttree import ( context "context" reflect "reflect" + time "time" list "github.com/anytypeio/any-sync/commonspace/object/acl/list" objecttree "github.com/anytypeio/any-sync/commonspace/object/tree/objecttree" @@ -350,6 +351,21 @@ func (mr *MockObjectTreeMockRecorder) Storage() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Storage", reflect.TypeOf((*MockObjectTree)(nil).Storage)) } +// TryClose mocks base method. +func (m *MockObjectTree) TryClose(arg0 time.Duration) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryClose", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryClose indicates an expected call of TryClose. +func (mr *MockObjectTreeMockRecorder) TryClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockObjectTree)(nil).TryClose), arg0) +} + // TryLock mocks base method. func (m *MockObjectTree) TryLock() bool { m.ctrl.T.Helper() diff --git a/commonspace/object/tree/objecttree/objecttree.go b/commonspace/object/tree/objecttree/objecttree.go index 170f183e..af169ad2 100644 --- a/commonspace/object/tree/objecttree/objecttree.go +++ b/commonspace/object/tree/objecttree/objecttree.go @@ -5,6 +5,7 @@ import ( "context" "errors" "sync" + "time" "github.com/anytypeio/any-sync/commonspace/object/acl/aclrecordproto" "github.com/anytypeio/any-sync/commonspace/object/acl/list" @@ -82,6 +83,7 @@ type ObjectTree interface { Delete() error Close() error + TryClose(objectTTL time.Duration) (bool, error) } type objectTree struct { @@ -560,6 +562,10 @@ func (ot *objectTree) Root() *Change { return ot.tree.Root() } +func (ot *objectTree) TryClose(objectTTL time.Duration) (bool, error) { + return true, ot.Close() +} + func (ot *objectTree) Close() error { return nil } diff --git a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go index e5ee2d8a..1dee0ce3 100644 --- a/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go +++ b/commonspace/object/tree/synctree/mock_synctree/mock_synctree.go @@ -7,6 +7,7 @@ package mock_synctree import ( context "context" reflect "reflect" + time "time" list "github.com/anytypeio/any-sync/commonspace/object/acl/list" objecttree "github.com/anytypeio/any-sync/commonspace/object/tree/objecttree" @@ -501,6 +502,21 @@ func (mr *MockSyncTreeMockRecorder) SyncWithPeer(arg0, arg1 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncWithPeer", reflect.TypeOf((*MockSyncTree)(nil).SyncWithPeer), arg0, arg1) } +// TryClose mocks base method. +func (m *MockSyncTree) TryClose(arg0 time.Duration) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TryClose", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// TryClose indicates an expected call of TryClose. +func (mr *MockSyncTreeMockRecorder) TryClose(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryClose", reflect.TypeOf((*MockSyncTree)(nil).TryClose), arg0) +} + // TryLock mocks base method. func (m *MockSyncTree) TryLock() bool { m.ctrl.T.Helper() diff --git a/commonspace/object/tree/synctree/synctree.go b/commonspace/object/tree/synctree/synctree.go index f25e14d0..56d07c1c 100644 --- a/commonspace/object/tree/synctree/synctree.go +++ b/commonspace/object/tree/synctree/synctree.go @@ -3,6 +3,7 @@ package synctree import ( "context" "errors" + "time" "github.com/anytypeio/any-sync/app/logger" "github.com/anytypeio/any-sync/commonspace/object/acl/list" @@ -209,6 +210,10 @@ func (s *syncTree) Delete() (err error) { return } +func (s *syncTree) TryClose(objectTTL time.Duration) (bool, error) { + return true, s.Close() +} + func (s *syncTree) Close() (err error) { log.Debug("closing sync tree", zap.String("id", s.Id())) defer func() { diff --git a/commonspace/objectsync/msgpool.go b/commonspace/objectsync/msgpool.go index 3dd14a2e..533efc7e 100644 --- a/commonspace/objectsync/msgpool.go +++ b/commonspace/objectsync/msgpool.go @@ -3,7 +3,6 @@ package objectsync import ( "context" "fmt" - "github.com/anytypeio/any-sync/app/ocache" "github.com/anytypeio/any-sync/commonspace/objectsync/synchandler" "github.com/anytypeio/any-sync/commonspace/peermanager" "github.com/anytypeio/any-sync/commonspace/spacesyncproto" @@ -15,9 +14,13 @@ import ( "time" ) +type LastUsage interface { + LastUsage() time.Time +} + // MessagePool can be made generic to work with different streams type MessagePool interface { - ocache.ObjectLastUsage + LastUsage synchandler.SyncHandler peermanager.PeerManager SendSync(ctx context.Context, peerId string, message *spacesyncproto.ObjectSyncMessage) (reply *spacesyncproto.ObjectSyncMessage, err error) diff --git a/commonspace/objectsync/objectsync.go b/commonspace/objectsync/objectsync.go index 85c587d9..74b3f7fa 100644 --- a/commonspace/objectsync/objectsync.go +++ b/commonspace/objectsync/objectsync.go @@ -6,7 +6,6 @@ import ( "time" "github.com/anytypeio/any-sync/app/logger" - "github.com/anytypeio/any-sync/app/ocache" "github.com/anytypeio/any-sync/commonspace/object/syncobjectgetter" "github.com/anytypeio/any-sync/commonspace/objectsync/synchandler" "github.com/anytypeio/any-sync/commonspace/peermanager" @@ -20,7 +19,7 @@ import ( var log = logger.NewNamed("common.commonspace.objectsync") type ObjectSync interface { - ocache.ObjectLastUsage + LastUsage synchandler.SyncHandler MessagePool() MessagePool diff --git a/commonspace/space.go b/commonspace/space.go index ba31af7e..ece7437d 100644 --- a/commonspace/space.go +++ b/commonspace/space.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/anytypeio/any-sync/accountservice" "github.com/anytypeio/any-sync/app/logger" - "github.com/anytypeio/any-sync/app/ocache" "github.com/anytypeio/any-sync/commonspace/headsync" "github.com/anytypeio/any-sync/commonspace/object/acl/list" "github.com/anytypeio/any-sync/commonspace/object/acl/syncacl" @@ -83,9 +82,6 @@ func NewSpaceId(id string, repKey uint64) string { } type Space interface { - ocache.ObjectLocker - ocache.ObjectLastUsage - Id() string Init(ctx context.Context) error @@ -110,6 +106,7 @@ type Space interface { HandleMessage(ctx context.Context, msg HandleMessage) (err error) + TryClose(objectTTL time.Duration) (close bool, err error) Close() error } @@ -136,16 +133,6 @@ type space struct { treesUsed *atomic.Int32 } -func (s *space) LastUsage() time.Time { - return s.objectSync.LastUsage() -} - -func (s *space) Locked() bool { - locked := s.treesUsed.Load() > 1 - log.With(zap.Int32("trees used", s.treesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.id)).Debug("space lock status check") - return locked -} - func (s *space) Id() string { return s.id } @@ -464,3 +451,15 @@ func (s *space) Close() error { log.With(zap.String("id", s.id)).Debug("space closed") return mError.Err() } + +func (s *space) TryClose(objectTTL time.Duration) (close bool, err error) { + if time.Now().Sub(s.objectSync.LastUsage()) < objectTTL { + return false, nil + } + locked := s.treesUsed.Load() > 1 + log.With(zap.Int32("trees used", s.treesUsed.Load()), zap.Bool("locked", locked), zap.String("spaceId", s.id)).Debug("space lock status check") + if locked { + return false, nil + } + return true, s.Close() +} diff --git a/net/peer/peer.go b/net/peer/peer.go index 9c9f547d..c96cd4ba 100644 --- a/net/peer/peer.go +++ b/net/peer/peer.go @@ -25,11 +25,13 @@ type Peer interface { Id() string LastUsage() time.Time UpdateLastUsage() + TryClose(objectTTL time.Duration) (res bool, err error) drpc.Conn } type peer struct { id string + ttl time.Duration lastUsage int64 sc sec.SecureConn drpc.Conn @@ -76,6 +78,13 @@ func (p *peer) UpdateLastUsage() { atomic.StoreInt64(&p.lastUsage, time.Now().Unix()) } +func (p *peer) TryClose(objectTTL time.Duration) (res bool, err error) { + if time.Now().Sub(p.LastUsage()) < objectTTL { + return false, nil + } + return true, p.Close() +} + func (p *peer) Close() (err error) { log.Debug("peer close", zap.String("peerId", p.id)) return p.Conn.Close() diff --git a/net/pool/pool_test.go b/net/pool/pool_test.go index f913333c..ce3876e0 100644 --- a/net/pool/pool_test.go +++ b/net/pool/pool_test.go @@ -194,6 +194,10 @@ func (t *testPeer) LastUsage() time.Time { func (t *testPeer) UpdateLastUsage() {} +func (t *testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { + return true, t.Close() +} + func (t *testPeer) Close() error { select { case <-t.closed: diff --git a/net/rpc/rpctest/pool.go b/net/rpc/rpctest/pool.go index 7fdbdda4..630cbb6a 100644 --- a/net/rpc/rpctest/pool.go +++ b/net/rpc/rpctest/pool.go @@ -103,6 +103,10 @@ type testPeer struct { drpc.Conn } +func (t testPeer) TryClose(objectTTL time.Duration) (res bool, err error) { + return true, t.Close() +} + func (t testPeer) Id() string { return t.id }