feature(websocket): add auth to counter too

This commit is contained in:
Houkime 2024-05-27 20:38:51 +00:00
parent 8fd12a1775
commit 950093a3b1
2 changed files with 47 additions and 39 deletions

View file

@ -136,10 +136,15 @@ class Mutation(
# A cruft for Websockets # A cruft for Websockets
def authenticated(info) -> bool: def authenticated(info: strawberry.types.Info) -> bool:
return IsAuthenticated().has_permission(source=None, info=info) return IsAuthenticated().has_permission(source=None, info=info)
def reject_if_unauthenticated(info: strawberry.types.Info):
if not authenticated(info):
raise Exception(IsAuthenticated().message)
@strawberry.type @strawberry.type
class Subscription: class Subscription:
"""Root schema for subscriptions. """Root schema for subscriptions.
@ -151,19 +156,15 @@ class Subscription:
async def job_updates( async def job_updates(
self, info: strawberry.types.Info self, info: strawberry.types.Info
) -> AsyncGenerator[List[ApiJob], None]: ) -> AsyncGenerator[List[ApiJob], None]:
if not authenticated(info): reject_if_unauthenticated(info)
raise Exception(IsAuthenticated().message)
# Send the complete list of jobs every time anything gets updated # Send the complete list of jobs every time anything gets updated
async for notification in job_notifications(): async for notification in job_notifications():
yield get_all_jobs() yield get_all_jobs()
# @strawberry.subscription
# async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]:
# return job_updates()
@strawberry.subscription @strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]: async def count(self, info: strawberry.types.Info) -> AsyncGenerator[int, None]:
reject_if_unauthenticated(info)
for i in range(10): for i in range(10):
yield i yield i
await asyncio.sleep(0.5) await asyncio.sleep(0.5)

View file

@ -106,41 +106,61 @@ def test_websocket_graphql_ping(authorized_client):
assert pong == {"type": "pong"} assert pong == {"type": "pong"}
def api_subscribe(websocket, id, subscription):
websocket.send_json(
{
"id": id,
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + subscription + "}",
},
}
)
def test_websocket_subscription_minimal(authorized_client): def test_websocket_subscription_minimal(authorized_client):
# Test a small endpoint that exists specifically for tests
client = authorized_client client = authorized_client
with client.websocket_connect( with client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws"] "/graphql", subprotocols=["graphql-transport-ws"]
) as websocket: ) as websocket:
init_graphql(websocket) init_graphql(websocket)
websocket.send_json( arbitrary_id = "3aaa2445"
{ api_subscribe(websocket, arbitrary_id, "count")
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {count}",
},
}
)
response = websocket.receive_json() response = websocket.receive_json()
assert response == { assert response == {
"id": "3aaa2445", "id": arbitrary_id,
"payload": {"data": {"count": 0}}, "payload": {"data": {"count": 0}},
"type": "next", "type": "next",
} }
response = websocket.receive_json() response = websocket.receive_json()
assert response == { assert response == {
"id": "3aaa2445", "id": arbitrary_id,
"payload": {"data": {"count": 1}}, "payload": {"data": {"count": 1}},
"type": "next", "type": "next",
} }
response = websocket.receive_json() response = websocket.receive_json()
assert response == { assert response == {
"id": "3aaa2445", "id": arbitrary_id,
"payload": {"data": {"count": 2}}, "payload": {"data": {"count": 2}},
"type": "next", "type": "next",
} }
def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, "count")
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": [{"message": IsAuthenticated.message}],
"type": "error",
}
async def read_one_job(websocket): async def read_one_job(websocket):
# bug? We only get them starting from the second job update # bug? We only get them starting from the second job update
# that's why we receive two jobs in the list them # that's why we receive two jobs in the list them
@ -153,15 +173,9 @@ async def read_one_job(websocket):
async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs): async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
websocket = authenticated_websocket websocket = authenticated_websocket
init_graphql(websocket) init_graphql(websocket)
websocket.send_json( arbitrary_id = "3aaa2445"
{ api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION)
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
},
}
)
future = asyncio.create_task(read_one_job(websocket)) future = asyncio.create_task(read_one_job(websocket))
jobs = [] jobs = []
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works")) jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
@ -197,19 +211,12 @@ async def test_websocket_subscription(authenticated_websocket, event_loop, empty
def test_websocket_subscription_unauthorized(unauthenticated_websocket): def test_websocket_subscription_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket websocket = unauthenticated_websocket
init_graphql(websocket) init_graphql(websocket)
websocket.send_json( id = "3aaa2445"
{ api_subscribe(websocket, id, JOBS_SUBSCRIPTION)
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
},
}
)
response = websocket.receive_json() response = websocket.receive_json()
assert response == { assert response == {
"id": "3aaa2445", "id": id,
"payload": [{"message": IsAuthenticated.message}], "payload": [{"message": IsAuthenticated.message}],
"type": "error", "type": "error",
} }