From cb641e4f37d1ec73595b058d49b93721adc2555d Mon Sep 17 00:00:00 2001 From: Houkime <> Date: Mon, 27 May 2024 20:21:11 +0000 Subject: [PATCH] feature(websocket): add auth --- selfprivacy_api/graphql/schema.py | 18 ++- tests/test_graphql/test_websocket.py | 169 ++++++++++++++++++--------- 2 files changed, 131 insertions(+), 56 deletions(-) diff --git a/selfprivacy_api/graphql/schema.py b/selfprivacy_api/graphql/schema.py index b8ed4e2..c6cf46b 100644 --- a/selfprivacy_api/graphql/schema.py +++ b/selfprivacy_api/graphql/schema.py @@ -4,6 +4,7 @@ import asyncio from typing import AsyncGenerator, List import strawberry + from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql.mutations.deprecated_mutations import ( DeprecatedApiMutations, @@ -134,12 +135,25 @@ class Mutation( ) +# A cruft for Websockets +def authenticated(info) -> bool: + return IsAuthenticated().has_permission(source=None, info=info) + + @strawberry.type class Subscription: - """Root schema for subscriptions""" + """Root schema for subscriptions. + Every field here should be an AsyncIterator or AsyncGenerator + It is not a part of the spec but graphql-core (dep of strawberryql) + demands it while the spec is vague in this area.""" @strawberry.subscription - async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: + async def job_updates( + self, info: strawberry.types.Info + ) -> AsyncGenerator[List[ApiJob], None]: + if not authenticated(info): + raise Exception(IsAuthenticated().message) + # Send the complete list of jobs every time anything gets updated async for notification in job_notifications(): yield get_all_jobs() diff --git a/tests/test_graphql/test_websocket.py b/tests/test_graphql/test_websocket.py index ee33262..5a92416 100644 --- a/tests/test_graphql/test_websocket.py +++ b/tests/test_graphql/test_websocket.py @@ -1,13 +1,20 @@ -from tests.common import generate_jobs_subscription - # from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions import pytest import asyncio - -from selfprivacy_api.jobs import Jobs +from typing import Generator from time import sleep -from tests.test_redis import empty_redis +from starlette.testclient import WebSocketTestSession + +from selfprivacy_api.jobs import Jobs +from selfprivacy_api.actions.api_tokens import TOKEN_REPO +from selfprivacy_api.graphql import IsAuthenticated + +from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH +from tests.test_jobs import jobs as empty_jobs + +# We do not iterate through them yet +TESTED_SUBPROTOCOLS = ["graphql-transport-ws"] JOBS_SUBSCRIPTION = """ jobUpdates { @@ -27,6 +34,48 @@ jobUpdates { """ +def connect_ws_authenticated(authorized_client) -> WebSocketTestSession: + token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"]) + return authorized_client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": token}, + ) + + +def connect_ws_not_authenticated(client) -> WebSocketTestSession: + return client.websocket_connect( + "/graphql", + subprotocols=TESTED_SUBPROTOCOLS, + params={"token": "I like vegan icecream but it is not a valid token"}, + ) + + +def init_graphql(websocket): + websocket.send_json({"type": "connection_init", "payload": {}}) + ack = websocket.receive_json() + assert ack == {"type": "connection_ack"} + + +@pytest.fixture +def authenticated_websocket( + authorized_client, +) -> Generator[WebSocketTestSession, None, None]: + # We use authorized_client only tohave token in the repo, this client by itself is not enough to authorize websocket + + ValueError(TOKEN_REPO.get_tokens()) + with connect_ws_authenticated(authorized_client) as websocket: + yield websocket + sleep(1) + + +@pytest.fixture +def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]: + with connect_ws_not_authenticated(client) as websocket: + yield websocket + sleep(1) + + def test_websocket_connection_bare(authorized_client): client = authorized_client with client.websocket_connect( @@ -57,12 +106,6 @@ def test_websocket_graphql_ping(authorized_client): assert pong == {"type": "pong"} -def init_graphql(websocket): - websocket.send_json({"type": "connection_init", "payload": {}}) - ack = websocket.receive_json() - assert ack == {"type": "connection_ack"} - - def test_websocket_subscription_minimal(authorized_client): client = authorized_client with client.websocket_connect( @@ -107,48 +150,66 @@ async def read_one_job(websocket): @pytest.mark.asyncio -async def test_websocket_subscription(authorized_client, empty_redis, event_loop): - client = authorized_client - with client.websocket_connect( - "/graphql", subprotocols=["graphql-transport-ws"] - ) as websocket: - init_graphql(websocket) - websocket.send_json( - { - "id": "3aaa2445", - "type": "subscribe", - "payload": { - "query": "subscription TestSubscription {" - + JOBS_SUBSCRIPTION - + "}", - }, - } - ) - future = asyncio.create_task(read_one_job(websocket)) - jobs = [] - jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) - sleep(0.5) - jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) +async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): + websocket = authenticated_websocket + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", + }, + } + ) + future = asyncio.create_task(read_one_job(websocket)) + jobs = [] + jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) + sleep(0.5) + jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works")) - response = await future - data = response["payload"]["data"] - jobs_received = data["jobUpdates"] - received_names = [job["name"] for job in jobs_received] - for job in jobs: - assert job.name in received_names + response = await future + data = response["payload"]["data"] + jobs_received = data["jobUpdates"] + received_names = [job["name"] for job in jobs_received] + for job in jobs: + assert job.name in received_names - for job in jobs: - for api_job in jobs_received: - if (job.name) == api_job["name"]: - assert api_job["uid"] == str(job.uid) - assert api_job["typeId"] == job.type_id - assert api_job["name"] == job.name - assert api_job["description"] == job.description - assert api_job["status"] == job.status - assert api_job["statusText"] == job.status_text - assert api_job["progress"] == job.progress - assert api_job["createdAt"] == job.created_at.isoformat() - assert api_job["updatedAt"] == job.updated_at.isoformat() - assert api_job["finishedAt"] == None - assert api_job["error"] == None - assert api_job["result"] == None + assert len(jobs_received) == 2 + + for job in jobs: + for api_job in jobs_received: + if (job.name) == api_job["name"]: + assert api_job["uid"] == str(job.uid) + assert api_job["typeId"] == job.type_id + assert api_job["name"] == job.name + assert api_job["description"] == job.description + assert api_job["status"] == job.status + assert api_job["statusText"] == job.status_text + assert api_job["progress"] == job.progress + assert api_job["createdAt"] == job.created_at.isoformat() + assert api_job["updatedAt"] == job.updated_at.isoformat() + assert api_job["finishedAt"] == None + assert api_job["error"] == None + assert api_job["result"] == None + + +def test_websocket_subscription_unauthorized(unauthenticated_websocket): + websocket = unauthenticated_websocket + init_graphql(websocket) + websocket.send_json( + { + "id": "3aaa2445", + "type": "subscribe", + "payload": { + "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}", + }, + } + ) + + response = websocket.receive_json() + assert response == { + "id": "3aaa2445", + "payload": [{"message": IsAuthenticated.message}], + "type": "error", + }