feat: Call caching and remote module metadata loading

This commit is contained in:
Inex Code 2024-12-21 21:01:45 +03:00
parent 4cdc9f3e08
commit b609bbc39d
No known key found for this signature in database
4 changed files with 129 additions and 18 deletions

View file

@ -5,6 +5,19 @@ let
config-id = "default"; config-id = "default";
nixos-rebuild = "${config.system.build.nixos-rebuild}/bin/nixos-rebuild"; nixos-rebuild = "${config.system.build.nixos-rebuild}/bin/nixos-rebuild";
nix = "${config.nix.package.out}/bin/nix"; nix = "${config.nix.package.out}/bin/nix";
sp-fetch-remote-module = pkgs.writeShellApplication {
name = "sp-fetch-remote-module";
runtimeInputs = [ config.nix.package.out ];
text = ''
if [ "$#" -ne 1 ]; then
echo "Usage: $0 <URL>"
exit 1
fi
URL="$1"
nix eval --file /etc/nixos/sp-fetch-remote-module.nix --raw --apply "f: f { flakeURL = \"$URL\"; }" | jq .
'';
};
in in
{ {
options.services.selfprivacy-api = { options.services.selfprivacy-api = {
@ -46,6 +59,7 @@ in
pkgs.util-linux pkgs.util-linux
pkgs.e2fsprogs pkgs.e2fsprogs
pkgs.iproute2 pkgs.iproute2
sp-fetch-remote-module
]; ];
after = [ "network-online.target" ]; after = [ "network-online.target" ];
wants = [ "network-online.target" ]; wants = [ "network-online.target" ];
@ -81,6 +95,7 @@ in
pkgs.util-linux pkgs.util-linux
pkgs.e2fsprogs pkgs.e2fsprogs
pkgs.iproute2 pkgs.iproute2
sp-fetch-remote-module
]; ];
after = [ "network-online.target" ]; after = [ "network-online.target" ];
wants = [ "network-online.target" ]; wants = [ "network-online.target" ];

View file

@ -3,6 +3,7 @@
import logging import logging
import base64 import base64
import typing import typing
import subprocess
from typing import List from typing import List
from os import path, remove from os import path, remove
from os import makedirs from os import makedirs
@ -22,6 +23,7 @@ from selfprivacy_api.services.ocserv import Ocserv
from selfprivacy_api.services.service import Service, ServiceDnsRecord from selfprivacy_api.services.service import Service, ServiceDnsRecord
from selfprivacy_api.services.service import ServiceStatus from selfprivacy_api.services.service import ServiceStatus
from selfprivacy_api.utils.cached_call import redis_cached_call
import selfprivacy_api.utils.network as network_utils import selfprivacy_api.utils.network as network_utils
from selfprivacy_api.services.api_icon import API_ICON from selfprivacy_api.services.api_icon import API_ICON
@ -37,33 +39,37 @@ CONFIG_STASH_DIR = "/etc/selfprivacy/dump"
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
SP_SUGGESTED_MODULES_PATH = "/etc/suggested-sp-modules"
class ServiceManager(Service): class ServiceManager(Service):
folders: List[str] = [CONFIG_STASH_DIR] folders: List[str] = [CONFIG_STASH_DIR]
@staticmethod @staticmethod
def get_all_services() -> list[Service]: def get_all_services() -> list[Service]:
return services return get_services()
@staticmethod @staticmethod
def get_service_by_id(service_id: str) -> typing.Optional[Service]: def get_service_by_id(service_id: str) -> typing.Optional[Service]:
for service in services: for service in get_services():
if service.get_id() == service_id: if service.get_id() == service_id:
return service return service
return None return None
@staticmethod @staticmethod
def get_enabled_services() -> list[Service]: def get_enabled_services() -> list[Service]:
return [service for service in services if service.is_enabled()] return [service for service in get_services() if service.is_enabled()]
# This one is not currently used by any code. # This one is not currently used by any code.
@staticmethod @staticmethod
def get_disabled_services() -> list[Service]: def get_disabled_services() -> list[Service]:
return [service for service in services if not service.is_enabled()] return [service for service in get_services() if not service.is_enabled()]
@staticmethod @staticmethod
def get_services_by_location(location: str) -> list[Service]: def get_services_by_location(location: str) -> list[Service]:
return [service for service in services if service.get_drive() == location] return [
service for service in get_services() if service.get_drive() == location
]
@staticmethod @staticmethod
def get_all_required_dns_records() -> list[ServiceDnsRecord]: def get_all_required_dns_records() -> list[ServiceDnsRecord]:
@ -155,7 +161,7 @@ class ServiceManager(Service):
# For now we will just copy settings EXCEPT the locations of services # For now we will just copy settings EXCEPT the locations of services
# Stash locations as they are set by user right now # Stash locations as they are set by user right now
locations = {} locations = {}
for service in services: for service in get_services():
if service.is_movable(): if service.is_movable():
locations[service.get_id()] = service.get_drive() locations[service.get_id()] = service.get_drive()
@ -164,7 +170,7 @@ class ServiceManager(Service):
cls.retrieve_stashed_path(p) cls.retrieve_stashed_path(p)
# Pop locations # Pop locations
for service in services: for service in get_services():
if service.is_movable(): if service.is_movable():
device = BlockDevices().get_block_device(locations[service.get_id()]) device = BlockDevices().get_block_device(locations[service.get_id()])
if device is not None: if device is not None:
@ -250,6 +256,25 @@ class ServiceManager(Service):
rmtree(cls.dump_dir(), ignore_errors=True) rmtree(cls.dump_dir(), ignore_errors=True)
@redis_cached_call(ttl=30)
def get_templated_service(service_id: str) -> TemplatedService:
return TemplatedService(service_id)
@redis_cached_call(ttl=3600)
def get_remote_service(url: str, id: str) -> TemplatedService:
# Get JSON from calling the sp-fetch-remote-module command with the URL
# Parse the JSON into a TemplatedService object
response = subprocess.run(
["sp-fetch-remote-module", url],
capture_output=True,
text=True,
check=True,
)
return TemplatedService(id, response.stdout)
@redis_cached_call(ttl=5)
def get_services() -> List[Service]: def get_services() -> List[Service]:
hardcoded_services: list[Service] = [ hardcoded_services: list[Service] = [
Bitwarden(), Bitwarden(),
@ -263,20 +288,40 @@ def get_services() -> List[Service]:
ServiceManager(), ServiceManager(),
Prometheus(), Prometheus(),
] ]
hardcoded_services_ids = [service.get_id() for service in hardcoded_services] service_ids = [service.get_id() for service in hardcoded_services]
# Load services from SP_MODULES_DEFENITIONS_PATH # Load services from SP_MODULES_DEFENITIONS_PATH
templated_services: List[Service] = [] templated_services: List[Service] = []
if path.exists(SP_MODULES_DEFENITIONS_PATH): if path.exists(SP_MODULES_DEFENITIONS_PATH):
for module in listdir(SP_MODULES_DEFENITIONS_PATH): for module in listdir(SP_MODULES_DEFENITIONS_PATH):
if module in hardcoded_services_ids: if module in service_ids:
continue continue
try: try:
templated_services.append(TemplatedService(module)) templated_services.append(get_templated_service(module))
service_ids.append(module)
except Exception as e:
logger.error(f"Failed to load service {module}: {e}")
if path.exists(SP_SUGGESTED_MODULES_PATH):
# It is a file with a JSON array
with open(SP_SUGGESTED_MODULES_PATH) as f:
suggested_modules = f.read().splitlines()
for module in suggested_modules:
if module in service_ids:
continue
try:
# TODO: Replace the branch!
templated_services.append(
get_remote_service(
module,
f"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=inex/experimental-templating&dir=sp-modules/{module}",
)
)
service_ids.append(module)
except Exception as e: except Exception as e:
logger.error(f"Failed to load service {module}: {e}") logger.error(f"Failed to load service {module}: {e}")
return hardcoded_services + templated_services return hardcoded_services + templated_services
services = get_services() # services = get_services()

View file

@ -338,7 +338,10 @@ class ServiceMetaData(BaseSchema):
class TemplatedService(Service): class TemplatedService(Service):
"""Class representing a dynamically loaded service.""" """Class representing a dynamically loaded service."""
def __init__(self, service_id: str) -> None: def __init__(self, service_id: str, source_data: Optional[str] = None) -> None:
if source_data:
self.definition_data = json.loads(source_data)
else:
# Check if the service exists # Check if the service exists
if not path.exists(join(SP_MODULES_DEFENITIONS_PATH, service_id)): if not path.exists(join(SP_MODULES_DEFENITIONS_PATH, service_id)):
raise FileNotFoundError(f"Service {service_id} not found") raise FileNotFoundError(f"Service {service_id} not found")

View file

@ -0,0 +1,48 @@
import asyncio
import pickle
from functools import wraps
from typing import Any, Optional, Callable
from selfprivacy_api.utils.redis_pool import RedisPool
CACHE_PREFIX = "exec_cache:"
async def get_redis_object(key: str) -> Optional[Any]:
redis = RedisPool().get_connection_async()
binary_obj = await redis.get(key)
if binary_obj is None:
return None
return pickle.loads(binary_obj)
async def save_redis_object(key: str, obj: Any, expire: Optional[int] = 60) -> None:
redis = RedisPool().get_connection_async()
binary_obj = pickle.dumps(obj)
if expire:
await redis.setex(key, expire, binary_obj)
else:
await redis.set(key, binary_obj)
def redis_cached_call(ttl: Optional[int] = 60) -> Callable[..., Callable]:
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs) -> Any:
key = f"{CACHE_PREFIX}{func.__name__}:{args}:{kwargs}"
cached_value = await get_redis_object(key)
if cached_value is not None:
return cached_value
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
await save_redis_object(key, result, ttl)
return result
return wrapper
return decorator