mirror of
https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api.git
synced 2025-02-18 07:14:40 +00:00
Fix ws auth
This commit is contained in:
parent
ab9e8d81e5
commit
28c6d983b9
|
@ -1,13 +1,12 @@
|
|||
#!/usr/bin/env python3
|
||||
"""SelfPrivacy server management API"""
|
||||
import os
|
||||
from fastapi import FastAPI, Depends, Request, WebSocket, BackgroundTasks
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from strawberry.fastapi import BaseContext, GraphQLRouter
|
||||
from strawberry.fastapi import GraphQLRouter
|
||||
|
||||
import uvicorn
|
||||
|
||||
from selfprivacy_api.dependencies import get_api_version, get_graphql_context
|
||||
from selfprivacy_api.dependencies import get_api_version
|
||||
from selfprivacy_api.graphql.schema import schema
|
||||
from selfprivacy_api.migrations import run_migrations
|
||||
from selfprivacy_api.restic_controller.tasks import init_restic
|
||||
|
@ -20,9 +19,9 @@ from selfprivacy_api.rest import (
|
|||
)
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
graphql_app = GraphQLRouter(
|
||||
schema,
|
||||
context_getter=get_graphql_context,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from typing import Optional
|
||||
from strawberry.fastapi import BaseContext
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
@ -27,29 +25,6 @@ async def get_token_header(
|
|||
return TokenHeader(token=token)
|
||||
|
||||
|
||||
class GraphQlContext(BaseContext):
|
||||
def __init__(self, auth_token: Optional[str] = None):
|
||||
self.auth_token = auth_token
|
||||
self.is_authenticated = auth_token is not None
|
||||
|
||||
|
||||
async def get_graphql_context(
|
||||
token: str = Depends(
|
||||
APIKeyHeader(
|
||||
name="Authorization",
|
||||
auto_error=False,
|
||||
)
|
||||
)
|
||||
) -> GraphQlContext:
|
||||
if token is None:
|
||||
return GraphQlContext()
|
||||
else:
|
||||
token = token.replace("Bearer ", "")
|
||||
if not is_token_valid(token):
|
||||
return GraphQlContext()
|
||||
return GraphQlContext(auth_token=token)
|
||||
|
||||
|
||||
def get_api_version() -> str:
|
||||
"""Get API version"""
|
||||
return "2.0.0"
|
||||
|
|
|
@ -13,4 +13,8 @@ class IsAuthenticated(BasePermission):
|
|||
message = "You must be authenticated to access this resource."
|
||||
|
||||
def has_permission(self, source: typing.Any, info: Info, **kwargs) -> bool:
|
||||
return info.context.is_authenticated
|
||||
return is_token_valid(
|
||||
info.context["request"]
|
||||
.headers.get("Authorization", "")
|
||||
.replace("Bearer ", "")
|
||||
)
|
||||
|
|
|
@ -116,7 +116,11 @@ class ApiMutations:
|
|||
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||
def refresh_device_api_token(self, info: Info) -> DeviceApiTokenMutationReturn:
|
||||
"""Refresh device api token"""
|
||||
token = info.context.auth_token
|
||||
token = (
|
||||
info.context["request"]
|
||||
.headers.get("Authorization", "")
|
||||
.replace("Bearer ", "")
|
||||
)
|
||||
if token is None:
|
||||
return DeviceApiTokenMutationReturn(
|
||||
success=False,
|
||||
|
@ -142,7 +146,11 @@ class ApiMutations:
|
|||
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||
def delete_device_api_token(self, device: str, info: Info) -> GenericMutationReturn:
|
||||
"""Delete device api token"""
|
||||
self_token = info.context.auth_token
|
||||
self_token = (
|
||||
info.context["request"]
|
||||
.headers.get("Authorization", "")
|
||||
.replace("Bearer ", "")
|
||||
)
|
||||
try:
|
||||
delete_api_token(self_token, device)
|
||||
except NotFoundException:
|
||||
|
|
|
@ -85,7 +85,11 @@ class Api:
|
|||
creation_date=device.date,
|
||||
is_caller=device.is_caller,
|
||||
)
|
||||
for device in get_api_tokens_with_caller_flag(info.context.auth_token)
|
||||
for device in get_api_tokens_with_caller_flag(
|
||||
info.context["request"]
|
||||
.headers.get("Authorization", "")
|
||||
.replace("Bearer ", "")
|
||||
)
|
||||
]
|
||||
|
||||
recovery_key: ApiRecoveryKeyStatus = strawberry.field(
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""GraphQL API for SelfPrivacy."""
|
||||
# pylint: disable=too-few-public-methods
|
||||
|
||||
import asyncio
|
||||
from typing import AsyncGenerator
|
||||
import strawberry
|
||||
from selfprivacy_api.graphql import IsAuthenticated
|
||||
from selfprivacy_api.graphql.mutations.api_mutations import ApiMutations
|
||||
|
@ -69,7 +71,7 @@ class Mutation(
|
|||
):
|
||||
"""Root schema for mutations"""
|
||||
|
||||
@strawberry.mutation
|
||||
@strawberry.mutation(permission_classes=[IsAuthenticated])
|
||||
def test_mutation(self) -> GenericMutationReturn:
|
||||
"""Test mutation"""
|
||||
test_job()
|
||||
|
@ -82,4 +84,15 @@ class Mutation(
|
|||
pass
|
||||
|
||||
|
||||
schema = strawberry.Schema(query=Query, mutation=Mutation)
|
||||
@strawberry.type
|
||||
class Subscription:
|
||||
"""Root schema for subscriptions"""
|
||||
|
||||
@strawberry.subscription(permission_classes=[IsAuthenticated])
|
||||
async def count(self, target: int = 100) -> AsyncGenerator[int, None]:
|
||||
for i in range(target):
|
||||
yield i
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
schema = strawberry.Schema(query=Query, mutation=Mutation, subscription=Subscription)
|
||||
|
|
Loading…
Reference in a new issue