diff --git a/nixos/module.nix b/nixos/module.nix index 6199597..e7c38d2 100644 --- a/nixos/module.nix +++ b/nixos/module.nix @@ -5,6 +5,19 @@ let config-id = "default"; nixos-rebuild = "${config.system.build.nixos-rebuild}/bin/nixos-rebuild"; nix = "${config.nix.package.out}/bin/nix"; + sp-fetch-remote-module = pkgs.writeShellApplication { + name = "sp-fetch-remote-module"; + runtimeInputs = [ config.nix.package.out ]; + text = '' + if [ "$#" -ne 1 ]; then + echo "Usage: $0 " + exit 1 + fi + + URL="$1" + nix eval --file /etc/nixos/sp-fetch-remote-module.nix --raw --apply "f: f { flakeURL = \"$URL\"; }" | jq . + ''; + }; in { options.services.selfprivacy-api = { @@ -46,6 +59,7 @@ in pkgs.util-linux pkgs.e2fsprogs pkgs.iproute2 + sp-fetch-remote-module ]; after = [ "network-online.target" ]; wants = [ "network-online.target" ]; @@ -81,6 +95,7 @@ in pkgs.util-linux pkgs.e2fsprogs pkgs.iproute2 + sp-fetch-remote-module ]; after = [ "network-online.target" ]; wants = [ "network-online.target" ]; diff --git a/selfprivacy_api/services/__init__.py b/selfprivacy_api/services/__init__.py index a070a0e..09e0ce7 100644 --- a/selfprivacy_api/services/__init__.py +++ b/selfprivacy_api/services/__init__.py @@ -3,6 +3,7 @@ import logging import base64 import typing +import subprocess from typing import List from os import path, remove from os import makedirs @@ -22,6 +23,7 @@ from selfprivacy_api.services.ocserv import Ocserv from selfprivacy_api.services.service import Service, ServiceDnsRecord from selfprivacy_api.services.service import ServiceStatus +from selfprivacy_api.utils.cached_call import redis_cached_call import selfprivacy_api.utils.network as network_utils from selfprivacy_api.services.api_icon import API_ICON @@ -37,33 +39,37 @@ CONFIG_STASH_DIR = "/etc/selfprivacy/dump" logger = logging.getLogger(__name__) +SP_SUGGESTED_MODULES_PATH = "/etc/suggested-sp-modules" + class ServiceManager(Service): folders: List[str] = [CONFIG_STASH_DIR] @staticmethod def get_all_services() -> list[Service]: - return services + return get_services() @staticmethod def get_service_by_id(service_id: str) -> typing.Optional[Service]: - for service in services: + for service in get_services(): if service.get_id() == service_id: return service return None @staticmethod def get_enabled_services() -> list[Service]: - return [service for service in services if service.is_enabled()] + return [service for service in get_services() if service.is_enabled()] # This one is not currently used by any code. @staticmethod def get_disabled_services() -> list[Service]: - return [service for service in services if not service.is_enabled()] + return [service for service in get_services() if not service.is_enabled()] @staticmethod def get_services_by_location(location: str) -> list[Service]: - return [service for service in services if service.get_drive() == location] + return [ + service for service in get_services() if service.get_drive() == location + ] @staticmethod def get_all_required_dns_records() -> list[ServiceDnsRecord]: @@ -155,7 +161,7 @@ class ServiceManager(Service): # For now we will just copy settings EXCEPT the locations of services # Stash locations as they are set by user right now locations = {} - for service in services: + for service in get_services(): if service.is_movable(): locations[service.get_id()] = service.get_drive() @@ -164,7 +170,7 @@ class ServiceManager(Service): cls.retrieve_stashed_path(p) # Pop locations - for service in services: + for service in get_services(): if service.is_movable(): device = BlockDevices().get_block_device(locations[service.get_id()]) if device is not None: @@ -250,6 +256,25 @@ class ServiceManager(Service): rmtree(cls.dump_dir(), ignore_errors=True) +@redis_cached_call(ttl=30) +def get_templated_service(service_id: str) -> TemplatedService: + return TemplatedService(service_id) + + +@redis_cached_call(ttl=3600) +def get_remote_service(url: str, id: str) -> TemplatedService: + # Get JSON from calling the sp-fetch-remote-module command with the URL + # Parse the JSON into a TemplatedService object + response = subprocess.run( + ["sp-fetch-remote-module", url], + capture_output=True, + text=True, + check=True, + ) + return TemplatedService(id, response.stdout) + + +@redis_cached_call(ttl=5) def get_services() -> List[Service]: hardcoded_services: list[Service] = [ Bitwarden(), @@ -263,20 +288,40 @@ def get_services() -> List[Service]: ServiceManager(), Prometheus(), ] - hardcoded_services_ids = [service.get_id() for service in hardcoded_services] + service_ids = [service.get_id() for service in hardcoded_services] # Load services from SP_MODULES_DEFENITIONS_PATH templated_services: List[Service] = [] if path.exists(SP_MODULES_DEFENITIONS_PATH): for module in listdir(SP_MODULES_DEFENITIONS_PATH): - if module in hardcoded_services_ids: + if module in service_ids: continue try: - templated_services.append(TemplatedService(module)) + templated_services.append(get_templated_service(module)) + service_ids.append(module) + except Exception as e: + logger.error(f"Failed to load service {module}: {e}") + + if path.exists(SP_SUGGESTED_MODULES_PATH): + # It is a file with a JSON array + with open(SP_SUGGESTED_MODULES_PATH) as f: + suggested_modules = f.read().splitlines() + for module in suggested_modules: + if module in service_ids: + continue + try: + # TODO: Replace the branch! + templated_services.append( + get_remote_service( + module, + f"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=inex/experimental-templating&dir=sp-modules/{module}", + ) + ) + service_ids.append(module) except Exception as e: logger.error(f"Failed to load service {module}: {e}") return hardcoded_services + templated_services -services = get_services() +# services = get_services() diff --git a/selfprivacy_api/services/templated_service.py b/selfprivacy_api/services/templated_service.py index ad1fc75..eaf0c10 100644 --- a/selfprivacy_api/services/templated_service.py +++ b/selfprivacy_api/services/templated_service.py @@ -338,13 +338,16 @@ class ServiceMetaData(BaseSchema): class TemplatedService(Service): """Class representing a dynamically loaded service.""" - def __init__(self, service_id: str) -> None: - # Check if the service exists - if not path.exists(join(SP_MODULES_DEFENITIONS_PATH, service_id)): - raise FileNotFoundError(f"Service {service_id} not found") - # Load the service - with open(join(SP_MODULES_DEFENITIONS_PATH, service_id)) as file: - self.definition_data = json.load(file) + def __init__(self, service_id: str, source_data: Optional[str] = None) -> None: + if source_data: + self.definition_data = json.loads(source_data) + else: + # Check if the service exists + if not path.exists(join(SP_MODULES_DEFENITIONS_PATH, service_id)): + raise FileNotFoundError(f"Service {service_id} not found") + # Load the service + with open(join(SP_MODULES_DEFENITIONS_PATH, service_id)) as file: + self.definition_data = json.load(file) # Check if required fields are present if "meta" not in self.definition_data: raise ValueError("meta not found in service definition") diff --git a/selfprivacy_api/utils/cached_call.py b/selfprivacy_api/utils/cached_call.py new file mode 100644 index 0000000..6fbdd55 --- /dev/null +++ b/selfprivacy_api/utils/cached_call.py @@ -0,0 +1,48 @@ +import asyncio +import pickle +from functools import wraps +from typing import Any, Optional, Callable + +from selfprivacy_api.utils.redis_pool import RedisPool + +CACHE_PREFIX = "exec_cache:" + + +async def get_redis_object(key: str) -> Optional[Any]: + redis = RedisPool().get_connection_async() + binary_obj = await redis.get(key) + if binary_obj is None: + return None + return pickle.loads(binary_obj) + + +async def save_redis_object(key: str, obj: Any, expire: Optional[int] = 60) -> None: + redis = RedisPool().get_connection_async() + binary_obj = pickle.dumps(obj) + if expire: + await redis.setex(key, expire, binary_obj) + else: + await redis.set(key, binary_obj) + + +def redis_cached_call(ttl: Optional[int] = 60) -> Callable[..., Callable]: + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs) -> Any: + key = f"{CACHE_PREFIX}{func.__name__}:{args}:{kwargs}" + cached_value = await get_redis_object(key) + if cached_value is not None: + return cached_value + + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + + await save_redis_object(key, result, ttl) + + return result + + return wrapper + + return decorator