diff --git a/selfprivacy_api/root_daemon.py b/selfprivacy_api/root_daemon.py index c18c3d4..d32b374 100644 --- a/selfprivacy_api/root_daemon.py +++ b/selfprivacy_api/root_daemon.py @@ -8,6 +8,8 @@ import socket as socket_module import subprocess +from tests.test_common import get_test_mode + SOCKET_PATH = "/tmp/socket_test.s" BUFFER_SIZE = 1024 @@ -76,21 +78,23 @@ def get_available_commands() -> List[str]: return commands -def init() -> socket_module.socket: - if os.path.exists(SOCKET_PATH): - os.remove(SOCKET_PATH) +def init(socket_path=SOCKET_PATH) -> socket_module.socket: + if os.path.exists(socket_path): + os.remove(socket_path) sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM) - sock.bind(SOCKET_PATH) + sock.bind(socket_path) + assert os.path.exists(socket_path) + # raise ValueError(socket_path) return sock -def _spawn_shell(command_string): +def _spawn_shell(command_string: str): # We use sh to refrain from parsing and simplify logic # Our commands are hardcoded so sh does not present # an extra attack surface here # TODO: continuous forwarding of command output - subprocess.check_output("sh", "-c", command_string) + subprocess.check_output(["sh", "-c", command_string]) def _process_request(request: str, allowed_commands: str) -> str: @@ -98,29 +102,42 @@ def _process_request(request: str, allowed_commands: str) -> str: if request == command: # explicitly only calling a _hardcoded_ command # ever - _spawn_shell(command) + # test mode made like this does not make it more dangerous too + raise ValueError("Oh no") + if get_test_mode(): + _spawn_shell(f'echo "{command}"') + else: + _spawn_shell(command) else: return "-1" def _root_loop(socket: socket_module.socket, allowed_commands): - while True: - socket.listen(1) - conn, addr = socket.accept() - datagram = conn.recv(BUFFER_SIZE) - if datagram: - request = datagram.strip().decode("utf-8") - answer = _process_request(request, allowed_commands) - conn.send(answer.encode("utf-8")) - conn.close() + socket.listen(1) -def main(): + # in seconds + socket.settimeout(1.0) # we do it so that we can throw exceptions into the loop + while True: + try: + conn, addr = socket.accept() + except TimeoutError: + continue + pipe = conn.makefile("rw") + # We accept a single line per connection for simplicity and safety + line = pipe.readline() + request = line.strip() + answer = _process_request(request, allowed_commands) + conn.send(answer.encode("utf-8")) + conn.close() + + +def main(socket_path=SOCKET_PATH): allowed_commands = get_available_commands() print("\n".join(allowed_commands)) - sock = init() + sock = init(socket_path) _root_loop(sock, allowed_commands) - + if __name__ == "__main__": main() diff --git a/selfprivacy_api/utils/root_interface.py b/selfprivacy_api/utils/root_interface.py new file mode 100644 index 0000000..2296da8 --- /dev/null +++ b/selfprivacy_api/utils/root_interface.py @@ -0,0 +1,24 @@ +from typing import List +# from subprocess import check_output +from selfprivacy_api.root_daemon import SOCKET_PATH, socket_module +from tests.test_common import get_test_mode + + +def call_root_function(cmd: List[str]) -> str: + if get_test_mode(): + return "done" + else: + return _call_root_daemon(cmd) + +def _call_root_daemon(cmd: List[str]) -> str: + return _write_to_daemon_socket(cmd) + + +def _write_to_daemon_socket(cmd: List[str]) -> str: + sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM) + sock.connect(SOCKET_PATH) + sock.send(" ".join(cmd).encode("utf-8")+b"\n") + pipe = sock.makefile("rw") + line = pipe.readline() + return line + diff --git a/tests/test_common.py b/tests/test_common.py index 7dd3652..5e77d80 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2,6 +2,7 @@ # pylint: disable=unused-argument import os import pytest +from typing import Optional from selfprivacy_api.utils import WriteUserData, ReadUserData @@ -30,9 +31,15 @@ def test_write_invalid_user_data(): pass +def get_test_mode() -> Optional[str]: + return os.environ.get("TEST_MODE") + + +# TODO: Does it make any sense to have such a fixture though? +# If it can only be called from tests then it is always test @pytest.fixture def test_mode(): - return os.environ.get("TEST_MODE") + return get_test_mode() def test_the_test_mode(test_mode): diff --git a/tests/test_root_daemon.py b/tests/test_root_daemon.py index 62ff872..846b1f1 100644 --- a/tests/test_root_daemon.py +++ b/tests/test_root_daemon.py @@ -1,4 +1,8 @@ import pytest +import os +import asyncio +import threading +import subprocess from selfprivacy_api.root_daemon import ( get_available_commands, @@ -7,7 +11,9 @@ from selfprivacy_api.root_daemon import ( service_commands, services, ) -import selfprivacy_api.root_daemon +import selfprivacy_api +import selfprivacy_api.root_daemon as root_daemon +from selfprivacy_api.utils.root_interface import call_root_function from os.path import join, exists from typing import List @@ -35,12 +41,21 @@ def test_available_commands(): assert is_in_strings(commands, service) -def test_init(test_socket): +def test_init(): sock = init() - assert exists(test_socket) + assert exists(root_daemon.SOCKET_PATH) assert sock is not None -def test_main(): - # main() - pass +def test_send_command(): + root_daemon_file = selfprivacy_api.root_daemon.__file__ + # this is a prototype of how we need to run it` + proc = subprocess.Popen(args=["python", root_daemon_file], shell=False) + + # thread = threading.Thread(target=start_root_daemon,args=[]) + # thread.start() + answer = call_root_function("blabla") + assert answer == "done" + + proc.kill() + # thread.join(timeout=1.0)