refactor(services): PARTIAL migrate get_all_services

This commit is contained in:
Houkime 2024-07-24 11:41:32 +00:00
parent f6151ee451
commit 2ef674a037
4 changed files with 56 additions and 52 deletions

View file

@ -8,10 +8,8 @@ import os
from os import statvfs from os import statvfs
from typing import Callable, List, Optional from typing import Callable, List, Optional
from selfprivacy_api.services import ( from selfprivacy_api.services import ServiceManager
get_service_by_id,
get_all_services,
)
from selfprivacy_api.services.service import ( from selfprivacy_api.services.service import (
Service, Service,
ServiceStatus, ServiceStatus,
@ -376,7 +374,7 @@ class Backups:
@staticmethod @staticmethod
def prune_all_autosnaps() -> None: def prune_all_autosnaps() -> None:
for service in get_all_services(): for service in ServiceManager.get_all_services():
Backups._prune_auto_snaps(service) Backups._prune_auto_snaps(service)
# Restoring # Restoring
@ -431,7 +429,7 @@ class Backups:
snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE
) -> None: ) -> None:
"""Restores a snapshot to its original service using the given strategy""" """Restores a snapshot to its original service using the given strategy"""
service = get_service_by_id(snapshot.service_name) service = ServiceManager.get_service_by_id(snapshot.service_name)
if service is None: if service is None:
raise ValueError( raise ValueError(
f"snapshot has a nonexistent service: {snapshot.service_name}" f"snapshot has a nonexistent service: {snapshot.service_name}"
@ -475,7 +473,7 @@ class Backups:
def _assert_restorable( def _assert_restorable(
snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE snapshot: Snapshot, strategy=RestoreStrategy.DOWNLOAD_VERIFY_OVERWRITE
) -> None: ) -> None:
service = get_service_by_id(snapshot.service_name) service = ServiceManager.get_service_by_id(snapshot.service_name)
if service is None: if service is None:
raise ValueError( raise ValueError(
f"snapshot has a nonexistent service: {snapshot.service_name}" f"snapshot has a nonexistent service: {snapshot.service_name}"
@ -646,7 +644,7 @@ class Backups:
"""Returns a list of services that should be backed up at a given time""" """Returns a list of services that should be backed up at a given time"""
return [ return [
service service
for service in get_all_services() for service in ServiceManager.get_all_services()
if Backups.is_time_to_backup_service(service, time) if Backups.is_time_to_backup_service(service, time)
] ]

View file

@ -8,12 +8,12 @@ from selfprivacy_api.graphql.common_types.service import (
Service, Service,
service_to_graphql_service, service_to_graphql_service,
) )
from selfprivacy_api.services import get_all_services from selfprivacy_api.services import ServiceManager
@strawberry.type @strawberry.type
class Services: class Services:
@strawberry.field @strawberry.field
def all_services(self) -> typing.List[Service]: def all_services(self) -> typing.List[Service]:
services = get_all_services() services = ServiceManager.get_all_services()
return [service_to_graphql_service(service) for service in services] return [service_to_graphql_service(service) for service in services]

View file

@ -26,52 +26,54 @@ services: list[Service] = [
] ]
def get_all_services() -> list[Service]: class ServiceManager(Service):
return services @staticmethod
def get_all_services() -> list[Service]:
return services
@staticmethod
def get_service_by_id(service_id: str) -> typing.Optional[Service]:
for service in services:
if service.get_id() == service_id:
return service
return None
def get_service_by_id(service_id: str) -> typing.Optional[Service]: @staticmethod
for service in services: def get_enabled_services() -> list[Service]:
if service.get_id() == service_id: return [service for service in services if service.is_enabled()]
return service
return None
@staticmethod
def get_disabled_services() -> list[Service]:
return [service for service in services if not service.is_enabled()]
def get_enabled_services() -> list[Service]: @staticmethod
return [service for service in services if service.is_enabled()] def get_services_by_location(location: str) -> list[Service]:
return [service for service in services if service.get_drive() == location]
@staticmethod
def get_disabled_services() -> list[Service]: def get_all_required_dns_records() -> list[ServiceDnsRecord]:
return [service for service in services if not service.is_enabled()] ip4 = network_utils.get_ip4()
ip6 = network_utils.get_ip6()
dns_records: list[ServiceDnsRecord] = [
def get_services_by_location(location: str) -> list[Service]:
return [service for service in services if service.get_drive() == location]
def get_all_required_dns_records() -> list[ServiceDnsRecord]:
ip4 = network_utils.get_ip4()
ip6 = network_utils.get_ip6()
dns_records: list[ServiceDnsRecord] = [
ServiceDnsRecord(
type="A",
name="api",
content=ip4,
ttl=3600,
display_name="SelfPrivacy API",
),
]
if ip6 is not None:
dns_records.append(
ServiceDnsRecord( ServiceDnsRecord(
type="AAAA", type="A",
name="api", name="api",
content=ip6, content=ip4,
ttl=3600, ttl=3600,
display_name="SelfPrivacy API (IPv6)", display_name="SelfPrivacy API",
),
]
if ip6 is not None:
dns_records.append(
ServiceDnsRecord(
type="AAAA",
name="api",
content=ip6,
ttl=3600,
display_name="SelfPrivacy API (IPv6)",
)
) )
) for service in get_enabled_services():
for service in get_enabled_services(): dns_records += service.get_dns_records(ip4, ip6)
dns_records += service.get_dns_records(ip4, ip6) return dns_records
return dns_records

View file

@ -4,7 +4,7 @@ from copy import copy
from datetime import datetime, timezone, timedelta from datetime import datetime, timezone, timedelta
from selfprivacy_api.jobs import Jobs from selfprivacy_api.jobs import Jobs
from selfprivacy_api.services import Service, get_all_services from selfprivacy_api.services import Service, ServiceManager
from selfprivacy_api.graphql.common_types.backup import ( from selfprivacy_api.graphql.common_types.backup import (
BackupReason, BackupReason,
@ -23,7 +23,11 @@ from tests.test_graphql.test_services import only_dummy_service
def backuppable_services() -> list[Service]: def backuppable_services() -> list[Service]:
return [service for service in get_all_services() if service.can_be_backed_up()] return [
service
for service in ServiceManager.get_all_services()
if service.can_be_backed_up()
]
def dummy_snapshot(date: datetime): def dummy_snapshot(date: datetime):