feature(websocket): add auth to counter too

This commit is contained in:
Houkime 2024-05-27 20:38:51 +00:00
parent 8fd12a1775
commit 950093a3b1
2 changed files with 47 additions and 39 deletions

View file

@ -136,10 +136,15 @@ class Mutation(
# A cruft for Websockets
def authenticated(info) -> bool:
def authenticated(info: strawberry.types.Info) -> bool:
return IsAuthenticated().has_permission(source=None, info=info)
def reject_if_unauthenticated(info: strawberry.types.Info):
if not authenticated(info):
raise Exception(IsAuthenticated().message)
@strawberry.type
class Subscription:
"""Root schema for subscriptions.
@ -151,19 +156,15 @@ class Subscription:
async def job_updates(
self, info: strawberry.types.Info
) -> AsyncGenerator[List[ApiJob], None]:
if not authenticated(info):
raise Exception(IsAuthenticated().message)
reject_if_unauthenticated(info)
# Send the complete list of jobs every time anything gets updated
async for notification in job_notifications():
yield get_all_jobs()
# @strawberry.subscription
# async def job_updates(self) -> AsyncGenerator[List[ApiJob], None]:
# return job_updates()
@strawberry.subscription
async def count(self) -> AsyncGenerator[int, None]:
async def count(self, info: strawberry.types.Info) -> AsyncGenerator[int, None]:
reject_if_unauthenticated(info)
for i in range(10):
yield i
await asyncio.sleep(0.5)

View file

@ -106,41 +106,61 @@ def test_websocket_graphql_ping(authorized_client):
assert pong == {"type": "pong"}
def api_subscribe(websocket, id, subscription):
websocket.send_json(
{
"id": id,
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + subscription + "}",
},
}
)
def test_websocket_subscription_minimal(authorized_client):
# Test a small endpoint that exists specifically for tests
client = authorized_client
with client.websocket_connect(
"/graphql", subprotocols=["graphql-transport-ws"]
) as websocket:
init_graphql(websocket)
websocket.send_json(
{
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {count}",
},
}
)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, "count")
response = websocket.receive_json()
assert response == {
"id": "3aaa2445",
"id": arbitrary_id,
"payload": {"data": {"count": 0}},
"type": "next",
}
response = websocket.receive_json()
assert response == {
"id": "3aaa2445",
"id": arbitrary_id,
"payload": {"data": {"count": 1}},
"type": "next",
}
response = websocket.receive_json()
assert response == {
"id": "3aaa2445",
"id": arbitrary_id,
"payload": {"data": {"count": 2}},
"type": "next",
}
def test_websocket_subscription_minimal_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, "count")
response = websocket.receive_json()
assert response == {
"id": arbitrary_id,
"payload": [{"message": IsAuthenticated.message}],
"type": "error",
}
async def read_one_job(websocket):
# bug? We only get them starting from the second job update
# that's why we receive two jobs in the list them
@ -153,15 +173,9 @@ async def read_one_job(websocket):
async def test_websocket_subscription(authenticated_websocket, event_loop, empty_jobs):
websocket = authenticated_websocket
init_graphql(websocket)
websocket.send_json(
{
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
},
}
)
arbitrary_id = "3aaa2445"
api_subscribe(websocket, arbitrary_id, JOBS_SUBSCRIPTION)
future = asyncio.create_task(read_one_job(websocket))
jobs = []
jobs.append(Jobs.add("bogus", "bogus.bogus", "yyyaaaaayy it works"))
@ -197,19 +211,12 @@ async def test_websocket_subscription(authenticated_websocket, event_loop, empty
def test_websocket_subscription_unauthorized(unauthenticated_websocket):
websocket = unauthenticated_websocket
init_graphql(websocket)
websocket.send_json(
{
"id": "3aaa2445",
"type": "subscribe",
"payload": {
"query": "subscription TestSubscription {" + JOBS_SUBSCRIPTION + "}",
},
}
)
id = "3aaa2445"
api_subscribe(websocket, id, JOBS_SUBSCRIPTION)
response = websocket.receive_json()
assert response == {
"id": "3aaa2445",
"id": id,
"payload": [{"message": IsAuthenticated.message}],
"type": "error",
}