feature(websocket): add auth

This commit is contained in:
Houkime 2024-05-27 20:21:11 +00:00 committed by Inex Code
parent 0fda29cdd7
commit cb641e4f37
2 changed files with 131 additions and 56 deletions

View file

@ -4,6 +4,7 @@
import asyncio import asyncio
from typing import AsyncGenerator, List from typing import AsyncGenerator, List
import strawberry import strawberry
from selfprivacy_api.graphql import IsAuthenticated from selfprivacy_api.graphql import IsAuthenticated
from selfprivacy_api.graphql.mutations.deprecated_mutations import ( from selfprivacy_api.graphql.mutations.deprecated_mutations import (
DeprecatedApiMutations, DeprecatedApiMutations,
@ -134,12 +135,25 @@ class Mutation(
) )
# A cruft for Websockets
def authenticated(info) -> bool:
return IsAuthenticated().has_permission(source=None, info=info)
@strawberry.type @strawberry.type
class Subscription: class Subscription:
"""Root schema for subscriptions""" """Root schema for subscriptions.
Every field here should be an AsyncIterator or AsyncGenerator
It is not a part of the spec but graphql-core (dep of strawberryql)
demands it while the spec is vague in this area."""
@strawberry.subscription @strawberry.subscription
async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]: async def job_updates(
self, info: strawberry.types.Info
) -> AsyncGenerator[List[ApiJob], None]:
if not authenticated(info):
raise Exception(IsAuthenticated().message)
# Send the complete list of jobs every time anything gets updated # Send the complete list of jobs every time anything gets updated
async for notification in job_notifications(): async for notification in job_notifications():
yield get_all_jobs() yield get_all_jobs()

View file

@ -1,13 +1,20 @@
from tests.common import generate_jobs_subscription
# from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions # from selfprivacy_api.graphql.subscriptions.jobs import JobSubscriptions
import pytest import pytest
import asyncio import asyncio
from typing import Generator
from selfprivacy_api.jobs import Jobs
from time import sleep from time import sleep
from tests.test_redis import empty_redis from starlette.testclient import WebSocketTestSession
from selfprivacy_api.jobs import Jobs
from selfprivacy_api.actions.api_tokens import TOKEN_REPO
from selfprivacy_api.graphql import IsAuthenticated
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
from tests.test_jobs import jobs as empty_jobs
# We do not iterate through them yet
TESTED_SUBPROTOCOLS = ["graphql-transport-ws"]
JOBS_SUBSCRIPTION = """ JOBS_SUBSCRIPTION = """
jobUpdates { jobUpdates {
@ -27,6 +34,48 @@ jobUpdates {
""" """
def connect_ws_authenticated(authorized_client) -> WebSocketTestSession:
token = "Bearer " + str(DEVICE_WE_AUTH_TESTS_WITH["token"])
return authorized_client.websocket_connect(
"/graphql",
subprotocols=TESTED_SUBPROTOCOLS,
params={"token": token},
)
def connect_ws_not_authenticated(client) -> WebSocketTestSession:
return client.websocket_connect(
"/graphql",
subprotocols=TESTED_SUBPROTOCOLS,
params={"token": "I like vegan icecream but it is not a valid token"},
)
def init_graphql(websocket):
websocket.send_json({"type": "connection_init", "payload": {}})
ack = websocket.receive_json()
assert ack == {"type": "connection_ack"}
@pytest.fixture
def authenticated_websocket(
authorized_client,
) -> Generator[WebSocketTestSession, None, None]:
# We use authorized_client only tohave token in the repo, this client by itself is not enough to authorize websocket
ValueError(TOKEN_REPO.get_tokens())
with connect_ws_authenticated(authorized_client) as websocket:
yield websocket
sleep(1)
@pytest.fixture
def unauthenticated_websocket(client) -> Generator[WebSocketTestSession, None, None]:
with connect_ws_not_authenticated(client) as websocket:
yield websocket
sleep(1)
def test_websocket_connection_bare(authorized_client): def test_websocket_connection_bare(authorized_client):
client = authorized_client client = authorized_client
with client.websocket_connect( with client.websocket_connect(
@ -57,12 +106,6 @@ def test_websocket_graphql_ping(authorized_client):
assert pong == {"type": "pong"} assert pong == {"type": "pong"}
def init_graphql(websocket):
websocket.send_json({"type": "connection_init", "payload": {}})
ack = websocket.receive_json()
assert ack == {"type": "connection_ack"}
def test_websocket_subscription_minimal(authorized_client): def test_websocket_subscription_minimal(authorized_client):
client = authorized_client client = authorized_client
with client.websocket_connect( with client.websocket_connect(
@ -107,48 +150,66 @@ async def read_one_job(websocket):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_websocket_subscription(authorized_client, empty_redis, event_loop): async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
client = authorized_client websocket = authenticated_websocket
with client.websocket_connect( init_graphql(websocket)
"/graphql", subprotocols=["graphql-transport-ws"] websocket.send_json(
) as websocket: {
init_graphql(websocket) "id": "3aaa2445",
websocket.send_json( "type": "subscribe",
{ "payload": {
"id": "3aaa2445", "query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
"type": "subscribe", },
"payload": { }
"query": "subscription TestSubscription {" )
+ JOBS_SUBSCRIPTION future = asyncio.create_task(read_one_job(websocket))
+ "}", jobs = []
}, jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
} sleep(0.5)
) jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works"))
future = asyncio.create_task(read_one_job(websocket))
jobs = []
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
sleep(0.5)
jobs.append(Jobs.add("bogus2", "bogus.bogus", "yyyaaaaayy it works"))
response = await future response = await future
data = response["payload"]["data"] data = response["payload"]["data"]
jobs_received = data["jobUpdates"] jobs_received = data["jobUpdates"]
received_names = [job["name"] for job in jobs_received] received_names = [job["name"] for job in jobs_received]
for job in jobs: for job in jobs:
assert job.name in received_names assert job.name in received_names
for job in jobs: assert len(jobs_received) == 2
for api_job in jobs_received:
if (job.name) == api_job["name"]: for job in jobs:
assert api_job["uid"] == str(job.uid) for api_job in jobs_received:
assert api_job["typeId"] == job.type_id if (job.name) == api_job["name"]:
assert api_job["name"] == job.name assert api_job["uid"] == str(job.uid)
assert api_job["description"] == job.description assert api_job["typeId"] == job.type_id
assert api_job["status"] == job.status assert api_job["name"] == job.name
assert api_job["statusText"] == job.status_text assert api_job["description"] == job.description
assert api_job["progress"] == job.progress assert api_job["status"] == job.status
assert api_job["createdAt"] == job.created_at.isoformat() assert api_job["statusText"] == job.status_text
assert api_job["updatedAt"] == job.updated_at.isoformat() assert api_job["progress"] == job.progress
assert api_job["finishedAt"] == None assert api_job["createdAt"] == job.created_at.isoformat()
assert api_job["error"] == None assert api_job["updatedAt"] == job.updated_at.isoformat()
assert api_job["result"] == None assert api_job["finishedAt"] == None
assert api_job["error"] == None
assert api_job["result"] == None
def test_websocket_subscription_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
websocket.send_json(
{
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
},
}
)
response = websocket.receive_json()
assert response == {
"id": "3aaa2445",
"payload": [{"message": IsAuthenticated.message}],
"type": "error",
}