diff --git a/selfprivacy_api/utils/redis_pool.py b/selfprivacy_api/utils/redis_pool.py index 04ccb51..ea827d1 100644 --- a/selfprivacy_api/utils/redis_pool.py +++ b/selfprivacy_api/utils/redis_pool.py @@ -9,13 +9,15 @@ from selfprivacy_api.utils.singleton_metaclass import SingletonMetaclass REDIS_SOCKET = "/run/redis-sp-api/redis.sock" -class RedisPool(metaclass=SingletonMetaclass): +# class RedisPool(metaclass=SingletonMetaclass): +class RedisPool: """ Redis connection pool singleton. """ def __init__(self): - url = RedisPool.connection_url(dbnumber=0) + self._dbnumber = 0 + url = RedisPool.connection_url(dbnumber=self._dbnumber) # We need a normal sync pool because otherwise # our whole API will need to be async self._pool = redis.ConnectionPool.from_url( @@ -48,3 +50,9 @@ class RedisPool(metaclass=SingletonMetaclass): Async connections allow pubsub. """ return redis_async.Redis(connection_pool=self._async_pool) + + async def subscribe_to_keys(self, pattern: str) -> redis_async.client.PubSub: + async_redis = self.get_connection_async() + pubsub = async_redis.pubsub() + await pubsub.psubscribe(f"__keyspace@{self._dbnumber}__:" + pattern) + return pubsub diff --git a/tests/test_redis.py b/tests/test_redis.py index 48ec56e..2def280 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -1,13 +1,18 @@ import asyncio import pytest +import pytest_asyncio +from asyncio import streams +import redis +from typing import List from selfprivacy_api.utils.redis_pool import RedisPool TEST_KEY = "test:test" +STOPWORD = "STOP" @pytest.fixture() -def empty_redis(): +def empty_redis(event_loop): r = RedisPool().get_connection() r.flushdb() yield r @@ -31,3 +36,60 @@ def test_async_connection(empty_redis): asyncio.run(write_to_test_key()) # Confirming that we can read result from sync connection too assert r.get(TEST_KEY) == "value2" + + +async def channel_reader(channel: redis.client.PubSub) -> List[dict]: + result: List[dict] = [] + while True: + # Mypy cannot correctly detect that it is a coroutine + # But it is + message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None) # type: ignore + if message is not None: + result.append(message) + if message["data"] == STOPWORD: + break + return result + + +@pytest.mark.asyncio +async def test_pubsub(empty_redis, event_loop): + # Adapted from : + # https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html + # Sanity checking because of previous event loop bugs + assert event_loop == asyncio.get_event_loop() + assert event_loop == asyncio.events.get_event_loop() + assert event_loop == asyncio.events._get_event_loop() + assert event_loop == asyncio.events.get_running_loop() + + reader = streams.StreamReader(34) + assert event_loop == reader._loop + f = reader._loop.create_future() + f.set_result(3) + await f + + r = RedisPool().get_connection_async() + async with r.pubsub() as pubsub: + await pubsub.subscribe("channel:1") + future = asyncio.create_task(channel_reader(pubsub)) + + await r.publish("channel:1", "Hello") + # message: dict = await pubsub.get_message(ignore_subscribe_messages=True, timeout=5.0) # type: ignore + # raise ValueError(message) + await r.publish("channel:1", "World") + await r.publish("channel:1", STOPWORD) + + messages = await future + + assert len(messages) == 3 + + message = messages[0] + assert "data" in message.keys() + assert message["data"] == "Hello" + message = messages[1] + assert "data" in message.keys() + assert message["data"] == "World" + message = messages[2] + assert "data" in message.keys() + assert message["data"] == STOPWORD + + await r.close()