"""Tests for RedisPool: sync/async connections and pub/sub messaging."""
import asyncio
from asyncio import streams
from typing import List

import pytest
import pytest_asyncio
import redis

from selfprivacy_api.utils.redis_pool import RedisPool

TEST_KEY = "test:test"
STOPWORD = "STOP"
# Upper bound (seconds) on how long test_pubsub waits for the reader task.
# channel_reader itself never gives up waiting, so without this bound a lost
# message would hang the test run instead of failing it.
PUBSUB_READ_TIMEOUT = 10.0


@pytest.fixture()
def empty_redis(event_loop):
    """Yield a sync Redis connection, flushing the current DB before and after the test.

    NOTE(review): ``event_loop`` is requested but not used in the body;
    presumably this ties the fixture to pytest-asyncio's loop setup —
    confirm before removing it.
    """
    r = RedisPool().get_connection()
    r.flushdb()
    yield r
    r.flushdb()


async def write_to_test_key():
    """Set TEST_KEY twice in one transaction and check that the second value wins.

    The async connection is closed even if one of the assertions fails.
    """
    r = RedisPool().get_connection_async()
    try:
        async with r.pipeline(transaction=True) as pipe:
            ok1, ok2 = await pipe.set(TEST_KEY, "value1").set(TEST_KEY, "value2").execute()
        assert ok1
        assert ok2
        # Comparing against a str assumes the pool decodes responses —
        # presumably configured inside RedisPool; confirm.
        assert await r.get(TEST_KEY) == "value2"
    finally:
        await r.close()


def test_async_connection(empty_redis):
    """Write through the async connection and read the result back synchronously."""
    r = RedisPool().get_connection()
    assert not r.exists(TEST_KEY)
    # It _will_ report an error if it arises
    asyncio.run(write_to_test_key())
    # Confirming that we can read result from sync connection too
    assert r.get(TEST_KEY) == "value2"


async def channel_reader(channel: redis.client.PubSub) -> List[dict]:
    """Collect messages from ``channel`` until one whose data equals STOPWORD.

    Subscribe/unsubscribe confirmations are skipped. The returned list
    includes the STOPWORD message itself as its last element.
    """
    result: List[dict] = []
    while True:
        # Mypy cannot correctly detect that it is a coroutine
        # But it is
        message: dict = await channel.get_message(ignore_subscribe_messages=True, timeout=None)  # type: ignore
        if message is not None:
            result.append(message)
            if message["data"] == STOPWORD:
                break
    return result


@pytest.mark.asyncio
async def test_pubsub(empty_redis, event_loop):
    """Publish three messages and check that a concurrent reader receives them in order."""
    # Adapted from:
    # https://redis.readthedocs.io/en/stable/examples/asyncio_examples.html

    # Sanity checking because of previous event loop bugs
    assert event_loop == asyncio.get_event_loop()
    assert event_loop == asyncio.events.get_event_loop()
    # NOTE(review): _get_event_loop is a private asyncio helper and may not
    # exist on every Python version — confirm against the project's interpreter.
    assert event_loop == asyncio.events._get_event_loop()
    assert event_loop == asyncio.events.get_running_loop()
    # A StreamReader captures the current loop when constructed; make sure it
    # is the same loop the test runs on, and that a future on it resolves.
    reader = streams.StreamReader(34)
    assert event_loop == reader._loop
    f = reader._loop.create_future()
    f.set_result(3)
    await f

    r = RedisPool().get_connection_async()
    try:
        async with r.pubsub() as pubsub:
            await pubsub.subscribe("channel:1")
            reader_task = asyncio.create_task(channel_reader(pubsub))

            await r.publish("channel:1", "Hello")
            await r.publish("channel:1", "World")
            await r.publish("channel:1", STOPWORD)

            # Bound the wait: on expiry wait_for cancels the reader task and
            # raises a timeout error, so a lost message fails the test
            # instead of hanging it.
            messages = await asyncio.wait_for(reader_task, timeout=PUBSUB_READ_TIMEOUT)
    finally:
        await r.close()

    assert len(messages) == 3
    for message, expected in zip(messages, ["Hello", "World", STOPWORD]):
        assert "data" in message.keys()
        assert message["data"] == expected