diff --git a/app/ocache/ocache.go b/app/ocache/ocache.go index df19b220..eb67f120 100644 --- a/app/ocache/ocache.go +++ b/app/ocache/ocache.go @@ -4,6 +4,7 @@ import ( "context" "errors" "github.com/anytypeio/any-sync/app/logger" + "github.com/anytypeio/any-sync/util/slice" "go.uber.org/zap" "sync" "time" @@ -44,12 +45,6 @@ var WithGCPeriod = func(gcPeriod time.Duration) Option { } } -var WithRefCounter = func(enable bool) Option { - return func(cache *oCache) { - cache.refCounter = enable - } -} - func New(loadFunc LoadFunc, opts ...Option) OCache { c := &oCache{ data: make(map[string]*entry), @@ -73,33 +68,117 @@ func New(loadFunc LoadFunc, opts ...Option) OCache { type Object interface { Close() (err error) + TryClose() (res bool, err error) } -type ObjectLocker interface { - Object - Locked() bool -} +type entryState int -type ObjectLastUsage interface { - LastUsage() time.Time -} +const ( + entryStateLoading = iota + entryStateActive + entryStateClosing + entryStateClosed +) type entry struct { id string + state entryState lastUsage time.Time - refCount uint32 - isClosing bool load chan struct{} loadErr error value Object close chan struct{} + mx sync.Mutex } -func (e *entry) locked() bool { - if locker, ok := e.value.(ObjectLocker); ok { - return locker.Locked() +func newEntry(id string, value Object, state entryState) *entry { + return &entry{ + id: id, + load: make(chan struct{}), + lastUsage: time.Now(), + state: state, + value: value, } - return false +} + +func (e *entry) getState() entryState { + e.mx.Lock() + defer e.mx.Unlock() + return e.state +} + +func (e *entry) isClosing() bool { + e.mx.Lock() + defer e.mx.Unlock() + return e.state == entryStateClosed || e.state == entryStateClosing +} + +func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) { + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) + return nil, ctx.Err() + case <-e.load: + return e.value, e.loadErr + } +} + +func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) { + e.mx.Lock() + switch e.state { + case entryStateClosing: + waitCh := e.close + e.mx.Unlock() + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) + return false, ctx.Err() + case <-waitCh: + return true, nil + } + case entryStateClosed: + e.mx.Unlock() + return true, nil + default: + e.mx.Unlock() + return false, nil + } +} + +func (e *entry) setClosing(wait bool) (prevState entryState) { + e.mx.Lock() + prevState = e.state + if e.state == entryStateClosing { + waitCh := e.close + e.mx.Unlock() + if !wait { + return + } + <-waitCh + e.mx.Lock() + } + if e.state != entryStateClosed { + e.state = entryStateClosing + e.close = make(chan struct{}) + } + e.mx.Unlock() + return +} + +func (e *entry) setActive(chClose bool) { + e.mx.Lock() + defer e.mx.Unlock() + if chClose { + close(e.close) + } + e.state = entryStateActive +} + +func (e *entry) setClosed() { + e.mx.Lock() + defer e.mx.Unlock() + close(e.close) + e.state = entryStateClosed } type OCache interface { @@ -116,10 +195,6 @@ type OCache interface { // Add adds new object to cache // Returns error when object exists Add(id string, value Object) (err error) - // Release decreases the refs counter - Release(id string) bool - // Reset sets refs counter to 0 - Reset(id string) bool // Remove closes and removes object Remove(id string) (ok bool, err error) // ForEach iterates over all loaded objects, breaks when callback returns false @@ -134,17 +209,16 @@ type OCache interface { } type oCache struct { - mu sync.Mutex - data map[string]*entry - loadFunc LoadFunc - timeNow func() time.Time - ttl time.Duration - gc time.Duration - closed bool - closeCh chan struct{} - log *zap.SugaredLogger - metrics *metrics - refCounter bool + mu sync.Mutex + data map[string]*entry + loadFunc LoadFunc + timeNow func() time.Time + ttl time.Duration + gc time.Duration + closed bool + closeCh chan struct{} + log *zap.SugaredLogger + metrics *metrics } func (c *oCache) Get(ctx context.Context, id string) (value Object, err error) { @@ -160,69 +234,46 @@ Load: return nil, ErrClosed } if e, ok = c.data[id]; !ok { + e = newEntry(id, nil, entryStateLoading) load = true - e = &entry{ - id: id, - load: make(chan struct{}), - } c.data[id] = e } - closing := e.isClosing - if !e.isClosing { - e.lastUsage = c.timeNow() - if c.refCounter { - e.refCount++ - } - } c.mu.Unlock() - if closing { - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) - return nil, ctx.Err() - case <-e.close: - goto Load - } + reload, err := e.waitClose(ctx, id) + if err != nil { + return nil, err + } + if reload { + goto Load } - if load { go c.load(ctx, id, e) } - if c.metrics != nil { - if load { - c.metrics.miss.Inc() - } else { - c.metrics.hit.Inc() - } + c.metricsGet(!load) + return e.waitLoad(ctx, id) +} + +func (c *oCache) metricsGet(hit bool) { + if c.metrics == nil { + return } - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) - return nil, ctx.Err() - case <-e.load: + if hit { + c.metrics.hit.Inc() + } else { + c.metrics.miss.Inc() } - return e.value, e.loadErr } func (c *oCache) Pick(ctx context.Context, id string) (value Object, err error) { c.mu.Lock() val, ok := c.data[id] - if !ok || val.isClosing { + if !ok || val.isClosing() { c.mu.Unlock() return nil, ErrNotExists } c.mu.Unlock() - - if c.metrics != nil { - c.metrics.hit.Inc() - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-val.load: - return val.value, val.loadErr - } + c.metricsGet(true) + return val.waitLoad(ctx, id) } func (c *oCache) load(ctx context.Context, id string, e *entry) { @@ -236,37 +287,10 @@ func (c *oCache) load(ctx context.Context, id string, e *entry) { delete(c.data, id) } else { e.value = value + e.setActive(false) } } -func (c *oCache) Release(id string) bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return false - } - if e, ok := c.data[id]; ok { - if c.refCounter && e.refCount > 0 { - e.refCount-- - return true - } - } - return false -} - -func (c *oCache) Reset(id string) bool { - c.mu.Lock() - defer c.mu.Unlock() - if c.closed { - return false - } - if e, ok := c.data[id]; ok { - e.refCount = 0 - return true - } - return false -} - func (c *oCache) Remove(id string) (ok bool, err error) { c.mu.Lock() if c.closed { @@ -274,25 +298,33 @@ func (c *oCache) Remove(id string) (ok bool, err error) { err = ErrClosed return } - var e *entry - e, ok = c.data[id] - if !ok || e.isClosing { + e, ok := c.data[id] + if !ok { c.mu.Unlock() return } - e.isClosing = true - e.close = make(chan struct{}) c.mu.Unlock() + return c.remove(e, true) +} +func (c *oCache) remove(e *entry, remData bool) (ok bool, err error) { <-e.load - if e.value != nil { + if e.value == nil { + return false, ErrNotExists + } + prevState := e.setClosing(true) + if prevState == entryStateActive { err = e.value.Close() + e.setClosed() + } + if !remData { + return } c.mu.Lock() - close(e.close) - delete(c.data, e.id) + if prevState == entryStateActive { + delete(c.data, e.id) + } c.mu.Unlock() - return } @@ -314,13 +346,7 @@ func (c *oCache) Add(id string, value Object) (err error) { if _, ok := c.data[id]; ok { return ErrExists } - e := &entry{ - id: id, - lastUsage: time.Now(), - refCount: 0, - load: make(chan struct{}), - value: value, - } + e := newEntry(id, value, entryStateActive) close(e.load) c.data[id] = e return @@ -332,7 +358,7 @@ func (c *oCache) ForEach(f func(obj Object) (isContinue bool)) { for _, v := range c.data { select { case <-v.load: - if v.value != nil && !v.isClosing { + if v.value != nil && !v.isClosing() { objects = append(objects, v.value) } default: @@ -368,15 +394,10 @@ func (c *oCache) GC() { deadline := c.timeNow().Add(-c.ttl) var toClose []*entry for _, e := range c.data { - if e.isClosing { + if e.getState() != entryStateActive { continue } - lu := e.lastUsage - if lug, ok := e.value.(ObjectLastUsage); ok { - lu = lug.LastUsage() - } - if !e.locked() && e.refCount <= 0 && lu.Before(deadline) { - e.isClosing = true + if e.lastUsage.Before(deadline) { e.close = make(chan struct{}) toClose = append(toClose, e) } @@ -384,21 +405,33 @@ func (c *oCache) GC() { size := len(c.data) c.mu.Unlock() - for _, e := range toClose { - <-e.load - if e.value != nil { - if err := e.value.Close(); err != nil { - c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err) - } + for idx, e := range toClose { + prevState := e.setClosing(false) + if prevState == entryStateClosing || prevState == entryStateClosed { + toClose[idx] = nil + continue + } + ok, err := e.value.TryClose() + if !ok { + e.setActive(true) + toClose[idx] = nil + continue + } else { + e.setClosed() + } + if err != nil { + c.log.With("object_id", e.id).Warnf("GC: object close error: %v", err) } } + toClose = slice.DiscardFromSlice(toClose, func(e *entry) bool { + return e == nil + }) c.log.Infof("GC: removed %d; cache size: %d", len(toClose), size) if len(toClose) > 0 && c.metrics != nil { c.metrics.gc.Add(float64(len(toClose))) } c.mu.Lock() for _, e := range toClose { - close(e.close) delete(c.data, e.id) } c.mu.Unlock() @@ -418,25 +451,15 @@ func (c *oCache) Close() (err error) { } c.closed = true close(c.closeCh) - var toClose, alreadyClosing []*entry + var toClose []*entry for _, e := range c.data { - if e.isClosing { - alreadyClosing = append(alreadyClosing, e) - } else { - toClose = append(toClose, e) - } + toClose = append(toClose, e) } c.mu.Unlock() for _, e := range toClose { - <-e.load - if e.value != nil { - if clErr := e.value.Close(); clErr != nil { - c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", clErr) - } + if _, err := c.remove(e, false); err != ErrNotExists { + c.log.With("object_id", e.id).Warnf("cache close: object close error: %v", err) } } - for _, e := range alreadyClosing { - <-e.close - } return nil } diff --git a/app/ocache/ocache_test.go b/app/ocache/ocache_test.go index 54034141..d32b01f7 100644 --- a/app/ocache/ocache_test.go +++ b/app/ocache/ocache_test.go @@ -12,15 +12,17 @@ import ( ) type testObject struct { - name string - closeErr error - closeCh chan struct{} + name string + closeErr error + closeCh chan struct{} + tryReturn bool } -func NewTestObject(name string, closeCh chan struct{}) *testObject { +func NewTestObject(name string, tryReturn bool, closeCh chan struct{}) *testObject { return &testObject{ - name: name, - closeCh: closeCh, + name: name, + closeCh: closeCh, + tryReturn: tryReturn, } } @@ -31,6 +33,14 @@ func (t *testObject) Close() (err error) { return t.closeErr } +func (t *testObject) TryClose() (res bool, err error) { + if t.closeCh != nil { + <-t.closeCh + return true, t.closeErr + } + return t.tryReturn, nil +} + func TestOCache_Get(t *testing.T) { t.Run("successful", func(t *testing.T) { c := New(func(ctx context.Context, id string) (value Object, err error) { @@ -118,8 +128,8 @@ func TestOCache_Get(t *testing.T) { func TestOCache_GC(t *testing.T) { t.Run("test without close wait", func(t *testing.T) { c := New(func(ctx context.Context, id string) (value Object, err error) { - return &testObject{name: id}, nil - }, WithTTL(time.Millisecond*10), WithRefCounter(true)) + return NewTestObject(id, true, nil), nil + }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) @@ -128,24 +138,19 @@ func TestOCache_GC(t *testing.T) { assert.Equal(t, 1, c.Len()) time.Sleep(time.Millisecond * 30) c.GC() - assert.Equal(t, 1, c.Len()) - assert.True(t, c.Release("id")) - c.GC() assert.Equal(t, 0, c.Len()) - assert.False(t, c.Release("id")) }) t.Run("test with close wait", func(t *testing.T) { closeCh := make(chan struct{}) getCh := make(chan struct{}) c := New(func(ctx context.Context, id string) (value Object, err error) { - return NewTestObject(id, closeCh), nil - }, WithTTL(time.Millisecond*10), WithRefCounter(true)) + return NewTestObject(id, true, closeCh), nil + }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) assert.Equal(t, 1, c.Len()) - assert.True(t, c.Release("id")) // making ttl pass time.Sleep(time.Millisecond * 40) // first gc will be run after 20 secs, so calling it manually @@ -160,9 +165,9 @@ func TestOCache_GC(t *testing.T) { events = append(events, "get") close(getCh) }() - events = append(events, "close") // sleeping to make sure that Get is called time.Sleep(time.Millisecond * 40) + events = append(events, "close") close(closeCh) <-getCh @@ -175,7 +180,7 @@ func Test_OCache_Remove(t *testing.T) { getCh := make(chan struct{}) c := New(func(ctx context.Context, id string) (value Object, err error) { - return NewTestObject(id, closeCh), nil + return NewTestObject(id, false, closeCh), nil }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) @@ -196,9 +201,9 @@ func Test_OCache_Remove(t *testing.T) { events = append(events, "get") close(getCh) }() - events = append(events, "close") // sleeping to make sure that Get is called time.Sleep(time.Millisecond * 40) + events = append(events, "close") close(closeCh) <-getCh