diff --git a/selfprivacy_api/graphql/queries/services.py b/selfprivacy_api/graphql/queries/services.py index 2012303..420ade0 100644 --- a/selfprivacy_api/graphql/queries/services.py +++ b/selfprivacy_api/graphql/queries/services.py @@ -16,5 +16,8 @@ from selfprivacy_api.services import ServiceManager class Services: @strawberry.field def all_services(self) -> typing.List[Service]: - services = [service_to_graphql_service(service) for service in ServiceManager.get_all_services()] + services = [ + service_to_graphql_service(service) + for service in ServiceManager.get_all_services() + ] return sorted(services, key=lambda service: service.display_name) diff --git a/selfprivacy_api/services/__init__.py b/selfprivacy_api/services/__init__.py index a37ec7c..31db233 100644 --- a/selfprivacy_api/services/__init__.py +++ b/selfprivacy_api/services/__init__.py @@ -10,6 +10,8 @@ from os import path from os import makedirs from os import listdir from os.path import join +from functools import lru_cache + from shutil import copyfile, copytree, rmtree from selfprivacy_api.services.prometheus import Prometheus @@ -17,7 +19,7 @@ from selfprivacy_api.services.mailserver import MailServer from selfprivacy_api.services.service import Service, ServiceDnsRecord from selfprivacy_api.services.service import ServiceStatus -from selfprivacy_api.utils.cached_call import redis_cached_call +from selfprivacy_api.utils.cached_call import get_ttl_hash import selfprivacy_api.utils.network as network_utils from selfprivacy_api.services.api_icon import API_ICON @@ -40,28 +42,38 @@ class ServiceManager(Service): @staticmethod def get_all_services() -> list[Service]: - return get_services() + return get_services(ttl_hash=get_ttl_hash(5)) @staticmethod def get_service_by_id(service_id: str) -> typing.Optional[Service]: - for service in get_services(): + for service in get_services(ttl_hash=get_ttl_hash(5)): if service.get_id() == service_id: return service return None @staticmethod def get_enabled_services() -> list[Service]: - return [service for service in get_services() if service.is_enabled()] + return [ + service + for service in get_services(ttl_hash=get_ttl_hash(5)) + if service.is_enabled() + ] # This one is not currently used by any code. @staticmethod def get_disabled_services() -> list[Service]: - return [service for service in get_services() if not service.is_enabled()] + return [ + service + for service in get_services(ttl_hash=get_ttl_hash(5)) + if not service.is_enabled() + ] @staticmethod def get_services_by_location(location: str) -> list[Service]: return [ - service for service in get_services() if service.get_drive() == location + service + for service in get_services(ttl_hash=get_ttl_hash(5)) + if service.get_drive() == location ] @staticmethod @@ -158,7 +170,7 @@ class ServiceManager(Service): # For now we will just copy settings EXCEPT the locations of services # Stash locations as they are set by user right now locations = {} - for service in get_services(): + for service in get_services(ttl_hash=get_ttl_hash(5)): if service.is_movable(): locations[service.get_id()] = service.get_drive() @@ -167,7 +179,7 @@ class ServiceManager(Service): cls.retrieve_stashed_path(p) # Pop locations - for service in get_services(): + for service in get_services(ttl_hash=get_ttl_hash(5)): if service.is_movable(): device = BlockDevices().get_block_device(locations[service.get_id()]) if device is not None: @@ -253,13 +265,17 @@ class ServiceManager(Service): rmtree(cls.dump_dir(), ignore_errors=True) -@redis_cached_call(ttl=30) -def get_templated_service(service_id: str) -> TemplatedService: +# @redis_cached_call(ttl=30) +@lru_cache() +def get_templated_service(service_id: str, ttl_hash=None) -> TemplatedService: + del ttl_hash return TemplatedService(service_id) -@redis_cached_call(ttl=3600) -def get_remote_service(id: str, url: str) -> TemplatedService: +# @redis_cached_call(ttl=3600) +@lru_cache() +def get_remote_service(id: str, url: str, ttl_hash=None) -> TemplatedService: + del ttl_hash response = subprocess.run( ["sp-fetch-remote-module", url], capture_output=True, @@ -273,8 +289,11 @@ DUMMY_SERVICES = [] TEST_FLAGS: list[str] = [] -@redis_cached_call(ttl=5) -def get_services() -> List[Service]: +# @redis_cached_call(ttl=5) +@lru_cache(maxsize=1) +def get_services(ttl_hash=None) -> List[Service]: + del ttl_hash + if "ONLY_DUMMY_SERVICE" in TEST_FLAGS: return DUMMY_SERVICES if "DUMMY_SERVICE_AND_API" in TEST_FLAGS: @@ -295,7 +314,9 @@ def get_services() -> List[Service]: if module in service_ids: continue try: - templated_services.append(get_templated_service(module)) + templated_services.append( + get_templated_service(module, ttl_hash=get_ttl_hash(30)) + ) service_ids.append(module) except Exception as e: logger.error(f"Failed to load service {module}: {e}") @@ -313,6 +334,7 @@ def get_services() -> List[Service]: get_remote_service( module, f"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=inex/experimental-templating&dir=sp-modules/{module}", + ttl_hash=get_ttl_hash(3600), ) ) service_ids.append(module) diff --git a/selfprivacy_api/utils/cached_call.py b/selfprivacy_api/utils/cached_call.py index c35d593..14195fb 100644 --- a/selfprivacy_api/utils/cached_call.py +++ b/selfprivacy_api/utils/cached_call.py @@ -1,44 +1,6 @@ -import pickle -from functools import wraps -from typing import Any, Optional, Callable - -from selfprivacy_api.utils.redis_pool import RedisPool - -CACHE_PREFIX = "exec_cache:" +import time -def get_redis_object(key: str) -> Optional[Any]: - redis = RedisPool().get_raw_connection() - binary_obj = redis.get(key) - if binary_obj is None: - return None - return pickle.loads(binary_obj) - - -def save_redis_object(key: str, obj: Any, expire: Optional[int] = 60) -> None: - redis = RedisPool().get_raw_connection() - binary_obj = pickle.dumps(obj) - if expire: - redis.setex(key, expire, binary_obj) - else: - redis.set(key, binary_obj) - - -def redis_cached_call(ttl: Optional[int] = 60) -> Callable[..., Callable]: - def decorator(func: Callable) -> Callable: - @wraps(func) - def wrapper(*args, **kwargs) -> Any: - key = f"{CACHE_PREFIX}{func.__name__}:{args}:{kwargs}" - cached_value = get_redis_object(key) - if cached_value is not None: - return cached_value - - result = func(*args, **kwargs) - - save_redis_object(key, result, ttl) - - return result - - return wrapper - - return decorator +def get_ttl_hash(seconds=3600): + """Return the same value withing `seconds` time period""" + return round(time.time() / seconds)