diff --git a/tests/unit/core/test_utils.py b/tests/unit/core/test_utils.py index b4db86e8f..fb783628c 100644 --- a/tests/unit/core/test_utils.py +++ b/tests/unit/core/test_utils.py @@ -1,7 +1,8 @@ # -*- coding: utf-8 -*- -from lbrynet import utils - import unittest +import asyncio +from lbrynet import utils +from torba.testcase import AsyncioTestCase class CompareVersionTest(unittest.TestCase): @@ -61,3 +62,76 @@ class SdHashTests(unittest.TestCase): } } self.assertIsNone(utils.get_sd_hash(claim)) + + +class CacheConcurrentDecoratorTests(AsyncioTestCase): + def setUp(self): + self.called = [] + self.finished = [] + self.counter = 0 + + @utils.cache_concurrent + async def foo(self, arg1, arg2=None, delay=1): + self.called.append((arg1, arg2, delay)) + await asyncio.sleep(delay, loop=self.loop) + self.counter += 1 + self.finished.append((arg1, arg2, delay)) + return object() + + async def test_gather_duplicates(self): + result = await asyncio.gather( + self.loop.create_task(self.foo(1)), self.loop.create_task(self.foo(1)), loop=self.loop + ) + self.assertEqual(1, len(self.called)) + self.assertEqual(1, len(self.finished)) + self.assertEqual(1, self.counter) + self.assertIs(result[0], result[1]) + self.assertEqual(2, len(result)) + + async def test_one_cancelled_all_cancel(self): + t1 = self.loop.create_task(self.foo(1)) + self.loop.call_later(0.1, t1.cancel) + + with self.assertRaises(asyncio.CancelledError): + await asyncio.gather( + t1, self.loop.create_task(self.foo(1)), loop=self.loop + ) + self.assertEqual(1, len(self.called)) + self.assertEqual(0, len(self.finished)) + self.assertEqual(0, self.counter) + + async def test_error_after_success(self): + def cause_type_error(): + self.counter = "" + + self.loop.call_later(0.1, cause_type_error) + + t1 = self.loop.create_task(self.foo(1)) + t2 = self.loop.create_task(self.foo(1)) + + with self.assertRaises(TypeError): + await t2 + self.assertEqual(1, len(self.called)) + self.assertEqual(0, len(self.finished)) + self.assertTrue(t1.done()) + self.assertEqual("", self.counter) + + # test that the task is run fresh, it should not error + self.counter = 0 + t3 = self.loop.create_task(self.foo(1)) + self.assertTrue((await t3)) + self.assertEqual(1, self.counter) + + # the previously failed call should still raise if awaited + with self.assertRaises(TypeError): + await t1 + + self.assertEqual(1, self.counter) + + async def test_break_it(self): + t1 = self.loop.create_task(self.foo(1)) + t2 = self.loop.create_task(self.foo(1)) + t3 = self.loop.create_task(self.foo(2, delay=0)) + t3.add_done_callback(lambda _: t2.cancel()) + with self.assertRaises(asyncio.CancelledError): + await asyncio.gather(t1, t2, t3)