diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index c6cf46b..3280396 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -136,10 +136,15 @@ class Mutation( # A cruft for Websockets -def authenticated(info) -> bool: +def authenticated(info: strawberry.types.Info) -> bool: return IsAuthenticated().has_permission(source=None, info=info) +def reject_if_unauthenticated(info: strawberry.types.Info): + if not authenticated(info): + raise Exception(IsAuthenticated().message) + + @strawberry.type class Subscription: """Root schema for subscriptions. @@ -151,19 +156,15 @@ class Subscription: async def job_updates( self, info: strawberry.types.Info ) -> AsyncGenerator[List[ApiJob], None]: - if not authenticated(info): - raise Exception(IsAuthenticated().message) + reject_if_unauthenticated(info) # Send the complete list of jobs every time anything gets updated async for notification in job_notifications(): yield get_all_jobs() - # @strawberry.subscription - # async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: - # return job_updates() - @strawberry.subscription - async def count(self) -> AsyncGenerator[int, None]: + async def count(self, info: strawberry.types.Info) -> AsyncGenerator[int, None]: + reject_if_unauthenticated(info) for i in range(10): yield i await asyncio.sleep(0.5) diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index 5a92416..49cc944 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -106,41 +106,61 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} +def api_subscribe(websocket, id, subscription): + websocket.send_json( + { + "id": id, + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + subscription + "}", + }, + } + ) + + def test_websocket_subscription_minimal(authorized_client): + # Test a small endpoint that exists specifically for tests client = authorized_client with client.websocket_connect( "/graphql", subprotocols=["graphql-transport-ws"] ) as websocket: init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {count}", - }, - } - ) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 0}}, "type": "next", } response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 1}}, "type": "next", } response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": arbitrary_id, "payload": {"data": {"count": 2}}, "type": "next", } +def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, "count") + + response = websocket.receive_json() + assert response == { + "id": arbitrary_id, + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + } + + async def read_one_job(websocket): # bug? We only get them starting from the second job update # that's why we receive two jobs in the list them @@ -153,15 +173,9 @@ async def read_one_job(websocket): async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): websocket = authenticated_websocket init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", - }, - } - ) + arbitrary_id = "3aaa2445" + api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION) + future = asyncio.create_task(read_one_job(websocket)) jobs = [] jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) @@ -197,19 +211,12 @@ async def test_websocket_subscription(authenticated_websocket, event_loop, empty def test_websocket_subscription_unauthorized(unauthenticated_websocket): websocket = unauthenticated_websocket init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", - }, - } - ) + id = "3aaa2445" + api_subscribe(websocket, id, JOBS_SUBSCRIPTION) response = websocket.receive_json() assert response == { - "id": "3aaa2445", + "id": id, "payload": [{"message": IsAuthenticated.message}], "type": "error", }