fix(security): Do not use Pickle for caching

This commit is contained in:
Inex Code 2024-12-24 16:13:55 +03:00
parent 1d7879acc1
commit 1f343815eb
No known key found for this signature in database
3 changed files with 45 additions and 58 deletions

View file

@ -16,5 +16,8 @@ from selfprivacy_api.services import ServiceManager
class Services: class Services:
@strawberry.field @strawberry.field
def all_services(self) -> typing.List[Service]: def all_services(self) -> typing.List[Service]:
services = [service_to_graphql_service(service) for service in ServiceManager.get_all_services()] services = [
service_to_graphql_service(service)
for service in ServiceManager.get_all_services()
]
return sorted(services, key=lambda service: service.display_name) return sorted(services, key=lambda service: service.display_name)

View file

@ -10,6 +10,8 @@ from os import path
from os import makedirs from os import makedirs
from os import listdir from os import listdir
from os.path import join from os.path import join
from functools import lru_cache
from shutil import copyfile, copytree, rmtree from shutil import copyfile, copytree, rmtree
from selfprivacy_api.services.prometheus import Prometheus from selfprivacy_api.services.prometheus import Prometheus
@ -17,7 +19,7 @@ from selfprivacy_api.services.mailserver import MailServer
from selfprivacy_api.services.service import Service, ServiceDnsRecord from selfprivacy_api.services.service import Service, ServiceDnsRecord
from selfprivacy_api.services.service import ServiceStatus from selfprivacy_api.services.service import ServiceStatus
from selfprivacy_api.utils.cached_call import redis_cached_call from selfprivacy_api.utils.cached_call import get_ttl_hash
import selfprivacy_api.utils.network as network_utils import selfprivacy_api.utils.network as network_utils
from selfprivacy_api.services.api_icon import API_ICON from selfprivacy_api.services.api_icon import API_ICON
@ -40,28 +42,38 @@ class ServiceManager(Service):
@staticmethod @staticmethod
def get_all_services() -> list[Service]: def get_all_services() -> list[Service]:
return get_services() return get_services(ttl_hash=get_ttl_hash(5))
@staticmethod @staticmethod
def get_service_by_id(service_id: str) -> typing.Optional[Service]: def get_service_by_id(service_id: str) -> typing.Optional[Service]:
for service in get_services(): for service in get_services(ttl_hash=get_ttl_hash(5)):
if service.get_id() == service_id: if service.get_id() == service_id:
return service return service
return None return None
@staticmethod @staticmethod
def get_enabled_services() -> list[Service]: def get_enabled_services() -> list[Service]:
return [service for service in get_services() if service.is_enabled()] return [
service
for service in get_services(ttl_hash=get_ttl_hash(5))
if service.is_enabled()
]
# This one is not currently used by any code. # This one is not currently used by any code.
@staticmethod @staticmethod
def get_disabled_services() -> list[Service]: def get_disabled_services() -> list[Service]:
return [service for service in get_services() if not service.is_enabled()] return [
service
for service in get_services(ttl_hash=get_ttl_hash(5))
if not service.is_enabled()
]
@staticmethod @staticmethod
def get_services_by_location(location: str) -> list[Service]: def get_services_by_location(location: str) -> list[Service]:
return [ return [
service for service in get_services() if service.get_drive() == location service
for service in get_services(ttl_hash=get_ttl_hash(5))
if service.get_drive() == location
] ]
@staticmethod @staticmethod
@ -158,7 +170,7 @@ class ServiceManager(Service):
# For now we will just copy settings EXCEPT the locations of services # For now we will just copy settings EXCEPT the locations of services
# Stash locations as they are set by user right now # Stash locations as they are set by user right now
locations = {} locations = {}
for service in get_services(): for service in get_services(ttl_hash=get_ttl_hash(5)):
if service.is_movable(): if service.is_movable():
locations[service.get_id()] = service.get_drive() locations[service.get_id()] = service.get_drive()
@ -167,7 +179,7 @@ class ServiceManager(Service):
cls.retrieve_stashed_path(p) cls.retrieve_stashed_path(p)
# Pop locations # Pop locations
for service in get_services(): for service in get_services(ttl_hash=get_ttl_hash(5)):
if service.is_movable(): if service.is_movable():
device = BlockDevices().get_block_device(locations[service.get_id()]) device = BlockDevices().get_block_device(locations[service.get_id()])
if device is not None: if device is not None:
@ -253,13 +265,17 @@ class ServiceManager(Service):
rmtree(cls.dump_dir(), ignore_errors=True) rmtree(cls.dump_dir(), ignore_errors=True)
@redis_cached_call(ttl=30) # @redis_cached_call(ttl=30)
def get_templated_service(service_id: str) -> TemplatedService: @lru_cache()
def get_templated_service(service_id: str, ttl_hash=None) -> TemplatedService:
del ttl_hash
return TemplatedService(service_id) return TemplatedService(service_id)
@redis_cached_call(ttl=3600) # @redis_cached_call(ttl=3600)
def get_remote_service(id: str, url: str) -> TemplatedService: @lru_cache()
def get_remote_service(id: str, url: str, ttl_hash=None) -> TemplatedService:
del ttl_hash
response = subprocess.run( response = subprocess.run(
["sp-fetch-remote-module", url], ["sp-fetch-remote-module", url],
capture_output=True, capture_output=True,
@ -273,8 +289,11 @@ DUMMY_SERVICES = []
TEST_FLAGS: list[str] = [] TEST_FLAGS: list[str] = []
@redis_cached_call(ttl=5) # @redis_cached_call(ttl=5)
def get_services() -> List[Service]: @lru_cache(maxsize=1)
def get_services(ttl_hash=None) -> List[Service]:
del ttl_hash
if "ONLY_DUMMY_SERVICE" in TEST_FLAGS: if "ONLY_DUMMY_SERVICE" in TEST_FLAGS:
return DUMMY_SERVICES return DUMMY_SERVICES
if "DUMMY_SERVICE_AND_API" in TEST_FLAGS: if "DUMMY_SERVICE_AND_API" in TEST_FLAGS:
@ -295,7 +314,9 @@ def get_services() -> List[Service]:
if module in service_ids: if module in service_ids:
continue continue
try: try:
templated_services.append(get_templated_service(module)) templated_services.append(
get_templated_service(module, ttl_hash=get_ttl_hash(30))
)
service_ids.append(module) service_ids.append(module)
except Exception as e: except Exception as e:
logger.error(f"Failed to load service {module}: {e}") logger.error(f"Failed to load service {module}: {e}")
@ -313,6 +334,7 @@ def get_services() -> List[Service]:
get_remote_service( get_remote_service(
module, module,
f"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=inex/experimental-templating&dir=sp-modules/{module}", f"git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=inex/experimental-templating&dir=sp-modules/{module}",
ttl_hash=get_ttl_hash(3600),
) )
) )
service_ids.append(module) service_ids.append(module)

View file

@ -1,44 +1,6 @@
import pickle import time
from functools import wraps
from typing import Any, Optional, Callable
from selfprivacy_api.utils.redis_pool import RedisPool
CACHE_PREFIX = "exec_cache:"
def get_redis_object(key: str) -> Optional[Any]: def get_ttl_hash(seconds=3600):
redis = RedisPool().get_raw_connection() """Return the same value withing `seconds` time period"""
binary_obj = redis.get(key) return round(time.time() / seconds)
if binary_obj is None:
return None
return pickle.loads(binary_obj)
def save_redis_object(key: str, obj: Any, expire: Optional[int] = 60) -> None:
redis = RedisPool().get_raw_connection()
binary_obj = pickle.dumps(obj)
if expire:
redis.setex(key, expire, binary_obj)
else:
redis.set(key, binary_obj)
def redis_cached_call(ttl: Optional[int] = 60) -> Callable[..., Callable]:
def decorator(func: Callable) -> Callable:
@wraps(func)
def wrapper(*args, **kwargs) -> Any:
key = f"{CACHE_PREFIX}{func.__name__}:{args}:{kwargs}"
cached_value = get_redis_object(key)
if cached_value is not None:
return cached_value
result = func(*args, **kwargs)
save_redis_object(key, result, ttl)
return result
return wrapper
return decorator