From 9836caf2cb1127a0e94cb54cd8b9ed155957708d Mon Sep 17 00:00:00 2001 From: mcrakhman Date: Thu, 9 Mar 2023 09:06:02 +0100 Subject: [PATCH] More tests and split entry --- app/ocache/entry.go | 120 ++++++++++++++++++++++++++++++++++++++ app/ocache/ocache.go | 118 ++----------------------------------- app/ocache/ocache_test.go | 51 +++++++++++++--- 3 files changed, 168 insertions(+), 121 deletions(-) create mode 100644 app/ocache/entry.go diff --git a/app/ocache/entry.go b/app/ocache/entry.go new file mode 100644 index 00000000..ad985fa0 --- /dev/null +++ b/app/ocache/entry.go @@ -0,0 +1,120 @@ +package ocache + +import ( + "context" + "go.uber.org/zap" + "sync" + "time" +) + +type entryState int + +const ( + entryStateLoading = iota + entryStateActive + entryStateClosing + entryStateClosed +) + +type entry struct { + id string + state entryState + lastUsage time.Time + load chan struct{} + loadErr error + value Object + close chan struct{} + mx sync.Mutex +} + +func newEntry(id string, value Object, state entryState) *entry { + return &entry{ + id: id, + load: make(chan struct{}), + lastUsage: time.Now(), + state: state, + value: value, + } +} + +func (e *entry) getState() entryState { + e.mx.Lock() + defer e.mx.Unlock() + return e.state +} + +func (e *entry) isClosing() bool { + e.mx.Lock() + defer e.mx.Unlock() + return e.state == entryStateClosed || e.state == entryStateClosing +} + +func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) { + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) + return nil, ctx.Err() + case <-e.load: + return e.value, e.loadErr + } +} + +func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) { + e.mx.Lock() + switch e.state { + case entryStateClosing: + waitCh := e.close + e.mx.Unlock() + select { + case <-ctx.Done(): + log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) + return false, ctx.Err() + case <-waitCh: + return true, nil + } + case entryStateClosed: + e.mx.Unlock() + return true, nil + default: + e.mx.Unlock() + return false, nil + } +} + +func (e *entry) setClosing(wait bool) (prevState, curState entryState) { + e.mx.Lock() + prevState = e.state + curState = e.state + if e.state == entryStateClosing { + waitCh := e.close + e.mx.Unlock() + if !wait { + return + } + <-waitCh + e.mx.Lock() + } + if e.state != entryStateClosed { + e.state = entryStateClosing + e.close = make(chan struct{}) + } + curState = e.state + e.mx.Unlock() + return +} + +func (e *entry) setActive(chClose bool) { + e.mx.Lock() + defer e.mx.Unlock() + if chClose { + close(e.close) + } + e.state = entryStateActive +} + +func (e *entry) setClosed() { + e.mx.Lock() + defer e.mx.Unlock() + close(e.close) + e.state = entryStateClosed +} diff --git a/app/ocache/ocache.go b/app/ocache/ocache.go index eb67f120..4599f789 100644 --- a/app/ocache/ocache.go +++ b/app/ocache/ocache.go @@ -71,116 +71,6 @@ type Object interface { TryClose() (res bool, err error) } -type entryState int - -const ( - entryStateLoading = iota - entryStateActive - entryStateClosing - entryStateClosed -) - -type entry struct { - id string - state entryState - lastUsage time.Time - load chan struct{} - loadErr error - value Object - close chan struct{} - mx sync.Mutex -} - -func newEntry(id string, value Object, state entryState) *entry { - return &entry{ - id: id, - load: make(chan struct{}), - lastUsage: time.Now(), - state: state, - value: value, - } -} - -func (e *entry) getState() entryState { - e.mx.Lock() - defer e.mx.Unlock() - return e.state -} - -func (e *entry) isClosing() bool { - e.mx.Lock() - defer e.mx.Unlock() - return e.state == entryStateClosed || e.state == entryStateClosing -} - -func (e *entry) waitLoad(ctx context.Context, id string) (value Object, err error) { - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object load", zap.String("id", id)) - return nil, ctx.Err() - case <-e.load: - return e.value, e.loadErr - } -} - -func (e *entry) waitClose(ctx context.Context, id string) (res bool, err error) { - e.mx.Lock() - switch e.state { - case entryStateClosing: - waitCh := e.close - e.mx.Unlock() - select { - case <-ctx.Done(): - log.DebugCtx(ctx, "ctx done while waiting on object close", zap.String("id", id)) - return false, ctx.Err() - case <-waitCh: - return true, nil - } - case entryStateClosed: - e.mx.Unlock() - return true, nil - default: - e.mx.Unlock() - return false, nil - } -} - -func (e *entry) setClosing(wait bool) (prevState entryState) { - e.mx.Lock() - prevState = e.state - if e.state == entryStateClosing { - waitCh := e.close - e.mx.Unlock() - if !wait { - return - } - <-waitCh - e.mx.Lock() - } - if e.state != entryStateClosed { - e.state = entryStateClosing - e.close = make(chan struct{}) - } - e.mx.Unlock() - return -} - -func (e *entry) setActive(chClose bool) { - e.mx.Lock() - defer e.mx.Unlock() - if chClose { - close(e.close) - } - e.state = entryStateActive -} - -func (e *entry) setClosed() { - e.mx.Lock() - defer e.mx.Unlock() - close(e.close) - e.state = entryStateClosed -} - type OCache interface { // DoLockedIfNotExists does an action if the object with id is not in cache // under a global lock, this will prevent a race which otherwise occurs @@ -312,8 +202,8 @@ func (c *oCache) remove(e *entry, remData bool) (ok bool, err error) { if e.value == nil { return false, ErrNotExists } - prevState := e.setClosing(true) - if prevState == entryStateActive { + _, curState := e.setClosing(true) + if curState == entryStateClosing { err = e.value.Close() e.setClosed() } @@ -321,7 +211,7 @@ func (c *oCache) remove(e *entry, remData bool) (ok bool, err error) { return } c.mu.Lock() - if prevState == entryStateActive { + if curState == entryStateClosing { delete(c.data, e.id) } c.mu.Unlock() @@ -406,7 +296,7 @@ func (c *oCache) GC() { c.mu.Unlock() for idx, e := range toClose { - prevState := e.setClosing(false) + prevState, _ := e.setClosing(false) if prevState == entryStateClosing || prevState == entryStateClosed { toClose[idx] = nil continue diff --git a/app/ocache/ocache_test.go b/app/ocache/ocache_test.go index 8d963214..43773184 100644 --- a/app/ocache/ocache_test.go +++ b/app/ocache/ocache_test.go @@ -3,6 +3,7 @@ package ocache import ( "context" "errors" + "fmt" "sync" "sync/atomic" "testing" @@ -178,7 +179,7 @@ func TestOCache_GC(t *testing.T) { <-getCh require.Equal(t, []string{"close", "get"}, events) }) - t.Run("test gc tryClose false, many get", func(t *testing.T) { + t.Run("test gc tryClose false, many parallel get", func(t *testing.T) { timesCalled := &atomic.Int32{} obj := NewTestObject("id", false, nil) c := New(func(ctx context.Context, id string) (value Object, err error) { @@ -190,9 +191,7 @@ func TestOCache_GC(t *testing.T) { require.NoError(t, err) require.NotNil(t, val) assert.Equal(t, 1, c.Len()) - // making ttl pass time.Sleep(time.Millisecond * 40) - // first gc will be run after 20 secs, so calling it manually begin := make(chan struct{}) wg := sync.WaitGroup{} once := sync.Once{} @@ -203,15 +202,14 @@ func TestOCache_GC(t *testing.T) { c.GC() wg.Done() }() - - for i := 0; i < 5; i++ { + for i := 0; i < 50; i++ { wg.Add(1) go func(i int) { once.Do(func() { close(begin) }) - if i > 0 { - time.Sleep(time.Duration(i) * time.Millisecond) + if i%2 != 0 { + time.Sleep(time.Millisecond) } _, err := c.Get(context.TODO(), "id") require.NoError(t, err) @@ -223,6 +221,45 @@ func TestOCache_GC(t *testing.T) { require.Equal(t, timesCalled.Load(), int32(1)) require.True(t, obj.tryCloseCalled) }) + t.Run("test gc tryClose different, many objects", func(t *testing.T) { + tryCloseIds := make(map[string]bool) + called := make(map[string]int) + max := 1000 + getId := func(i int) string { + return fmt.Sprintf("id%d", i) + } + for i := 0; i < max; i++ { + if i%2 == 1 { + tryCloseIds[getId(i)] = true + } else { + tryCloseIds[getId(i)] = false + } + } + c := New(func(ctx context.Context, id string) (value Object, err error) { + called[id] = called[id] + 1 + return NewTestObject(id, tryCloseIds[id], nil), nil + }, WithTTL(time.Millisecond*10)) + + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + } + assert.Equal(t, max, c.Len()) + time.Sleep(time.Millisecond * 40) + c.GC() + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + } + for i := 0; i < max; i++ { + val, err := c.Get(context.TODO(), getId(i)) + require.NoError(t, err) + require.NotNil(t, val) + require.Equal(t, called[getId(i)], i%2+1) + } + }) } func Test_OCache_Remove(t *testing.T) {