diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 3d35f01..04ccb51 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -2,6 +2,7 @@ Redis pool module for selfprivacy_api """ import redis +import redis.asyncio as redis_async from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass @@ -14,11 +15,18 @@ class RedisPool(metaclass=SingletonMetaclass): """ def __init__(self): + url = RedisPool.connection_url(dbnumber=0) + # We need a normal sync pool because otherwise + # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( - RedisPool.connection_url(dbnumber=0), + url, + decode_responses=True, + ) + # We need an async pool for pubsub + self._async_pool = redis_async.ConnectionPool.from_url( + url, decode_responses=True, ) - self._pubsub_connection = self.get_connection() @staticmethod def connection_url(dbnumber: int) -> str: @@ -34,8 +42,9 @@ class RedisPool(metaclass=SingletonMetaclass): """ return redis.Redis(connection_pool=self._pool) - def get_pubsub(self): + def get_connection_async(self) -> redis_async.Redis: """ - Get a pubsub connection from the pool. + Get an async connection from the pool. + Async connections allow pubsub. """ - return self._pubsub_connection.pubsub() + return redis_async.Redis(connection_pool=self._async_pool) diff --git a/tests/test_redis.py b/tests/test_redis.py new file mode 100644 index 0000000..48ec56e --- /dev/null +++ b/tests/test_redis.py @@ -0,0 +1,33 @@ +import asyncio +import pytest + +from selfprivacy_api.utils.redis_pool import RedisPool + +TEST_KEY = "test:test" + + +@pytest.fixture() +def empty_redis(): + r = RedisPool().get_connection() + r.flushdb() + yield r + r.flushdb() + + +async def write_to_test_key(): + r = RedisPool().get_connection_async() + async with r.pipeline(transaction=True) as pipe: + ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute() + assert ok1 + assert ok2 + assert await r.get(TEST_KEY) == "value2" + await r.close() + + +def test_async_connection(empty_redis): + r = RedisPool().get_connection() + assert not r.exists(TEST_KEY) + # It _will_ report an error if it arises + asyncio.run(write_to_test_key()) + # Confirming that we can read result from sync connection too + assert r.get(TEST_KEY) == "value2"