diff --git a/pkg/ocache/ocache.go b/pkg/ocache/ocache.go index 998faf6d..c19c398c 100644 --- a/pkg/ocache/ocache.go +++ b/pkg/ocache/ocache.go @@ -86,10 +86,10 @@ type entry struct { id string lastUsage time.Time refCount uint32 + isClosing bool load chan struct{} loadErr error value Object - isClosing bool close chan struct{} } diff --git a/pkg/ocache/ocache_test.go b/pkg/ocache/ocache_test.go index 55bf2e95..cebed09a 100644 --- a/pkg/ocache/ocache_test.go +++ b/pkg/ocache/ocache_test.go @@ -14,16 +14,23 @@ import ( type testObject struct { name string closeErr error + closeCh chan struct{} +} + +func NewTestObject(name string, closeCh chan struct{}) *testObject { + return &testObject{ + name: name, + closeCh: closeCh, + } } func (t *testObject) Close() (err error) { + if t.closeCh != nil { + <-t.closeCh + } return t.closeErr } -func (t *testObject) ShouldClose() bool { - return true -} - func TestOCache_Get(t *testing.T) { t.Run("successful", func(t *testing.T) { c := New(func(ctx context.Context, id string) (value Object, err error) { @@ -109,20 +116,91 @@ func TestOCache_Get(t *testing.T) { } func TestOCache_GC(t *testing.T) { + t.Run("test without close wait", func(t *testing.T) { + c := New(func(ctx context.Context, id string) (value Object, err error) { + return &testObject{name: id}, nil + }, WithTTL(time.Millisecond*10)) + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + c.GC() + assert.Equal(t, 1, c.Len()) + time.Sleep(time.Millisecond * 30) + c.GC() + assert.Equal(t, 1, c.Len()) + assert.True(t, c.Release("id")) + c.GC() + assert.Equal(t, 0, c.Len()) + assert.False(t, c.Release("id")) + }) + t.Run("test with close wait", func(t *testing.T) { + closeCh := make(chan struct{}) + getCh := make(chan struct{}) + + c := New(func(ctx context.Context, id string) (value Object, err error) { + return NewTestObject(id, closeCh), nil + }, WithTTL(time.Millisecond*10)) + val, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + assert.Equal(t, 1, c.Len()) + assert.True(t, c.Release("id")) + // making ttl pass + time.Sleep(time.Millisecond * 40) + // first gc will be run after 20 secs, so calling it manually + go c.GC() + // waiting until all objects are marked as closing + time.Sleep(time.Millisecond * 40) + var events []string + go func() { + _, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + events = append(events, "get") + close(getCh) + }() + events = append(events, "close") + // sleeping to make sure that Get is called + time.Sleep(time.Millisecond * 40) + close(closeCh) + + <-getCh + require.Equal(t, []string{"close", "get"}, events) + }) +} + +func Test_OCache_Remove(t *testing.T) { + closeCh := make(chan struct{}) + getCh := make(chan struct{}) + c := New(func(ctx context.Context, id string) (value Object, err error) { - return &testObject{name: id}, nil + return NewTestObject(id, closeCh), nil }, WithTTL(time.Millisecond*10)) val, err := c.Get(context.TODO(), "id") require.NoError(t, err) require.NotNil(t, val) assert.Equal(t, 1, c.Len()) - c.GC() - assert.Equal(t, 1, c.Len()) - time.Sleep(time.Millisecond * 30) - c.GC() - assert.Equal(t, 1, c.Len()) - assert.True(t, c.Release("id")) - c.GC() - assert.Equal(t, 0, c.Len()) - assert.False(t, c.Release("id")) + // removing the object, so we will wait on closing + go func() { + _, err := c.Remove("id") + require.NoError(t, err) + }() + time.Sleep(time.Millisecond * 40) + + var events []string + go func() { + _, err := c.Get(context.TODO(), "id") + require.NoError(t, err) + require.NotNil(t, val) + events = append(events, "get") + close(getCh) + }() + events = append(events, "close") + // sleeping to make sure that Get is called + time.Sleep(time.Millisecond * 40) + close(closeCh) + + <-getCh + require.Equal(t, []string{"close", "get"}, events) }