test(root_daemon): threading

This commit is contained in:
Houkime 2024-12-11 13:36:32 +00:00
parent 644f0783ee
commit 16300cd899
4 changed files with 89 additions and 26 deletions

View file

@ -8,6 +8,8 @@ import socket as socket_module
import subprocess
from tests.test_common import get_test_mode
SOCKET_PATH = "/tmp/socket_test.s"
BUFFER_SIZE = 1024
@ -76,21 +78,23 @@ def get_available_commands() -> List[str]:
return commands
def init() -> socket_module.socket:
if os.path.exists(SOCKET_PATH):
os.remove(SOCKET_PATH)
def init(socket_path=SOCKET_PATH) -> socket_module.socket:
if os.path.exists(socket_path):
os.remove(socket_path)
sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM)
sock.bind(SOCKET_PATH)
sock.bind(socket_path)
assert os.path.exists(socket_path)
# raise ValueError(socket_path)
return sock
def _spawn_shell(command_string):
def _spawn_shell(command_string: str):
# We use sh to refrain from parsing and simplify logic
# Our commands are hardcoded so sh does not present
# an extra attack surface here
# TODO: continuous forwarding of command output
subprocess.check_output("sh", "-c", command_string)
subprocess.check_output(["sh", "-c", command_string])
def _process_request(request: str, allowed_commands: str) -> str:
@ -98,27 +102,40 @@ def _process_request(request: str, allowed_commands: str) -> str:
if request == command:
# explicitly only calling a _hardcoded_ command
# ever
_spawn_shell(command)
# test mode made like this does not make it more dangerous too
raise ValueError("Oh no")
if get_test_mode():
_spawn_shell(f'echo "{command}"')
else:
_spawn_shell(command)
else:
return "-1"
def _root_loop(socket: socket_module.socket, allowed_commands):
while True:
socket.listen(1)
conn, addr = socket.accept()
datagram = conn.recv(BUFFER_SIZE)
if datagram:
request = datagram.strip().decode("utf-8")
answer = _process_request(request, allowed_commands)
conn.send(answer.encode("utf-8"))
conn.close()
socket.listen(1)
def main():
# in seconds
socket.settimeout(1.0) # we do it so that we can throw exceptions into the loop
while True:
try:
conn, addr = socket.accept()
except TimeoutError:
continue
pipe = conn.makefile("rw")
# We accept a single line per connection for simplicity and safety
line = pipe.readline()
request = line.strip()
answer = _process_request(request, allowed_commands)
conn.send(answer.encode("utf-8"))
conn.close()
def main(socket_path=SOCKET_PATH):
allowed_commands = get_available_commands()
print("\n".join(allowed_commands))
sock = init()
sock = init(socket_path)
_root_loop(sock, allowed_commands)

View file

@ -0,0 +1,24 @@
from typing import List
# from subprocess import check_output
from selfprivacy_api.root_daemon import SOCKET_PATH, socket_module
from tests.test_common import get_test_mode
def call_root_function(cmd: List[str]) -> str:
if get_test_mode():
return "done"
else:
return _call_root_daemon(cmd)
def _call_root_daemon(cmd: List[str]) -> str:
return _write_to_daemon_socket(cmd)
def _write_to_daemon_socket(cmd: List[str]) -> str:
sock = socket_module.socket(socket_module.AF_UNIX, socket_module.SOCK_STREAM)
sock.connect(SOCKET_PATH)
sock.send(" ".join(cmd).encode("utf-8")+b"\n")
pipe = sock.makefile("rw")
line = pipe.readline()
return line

View file

@ -2,6 +2,7 @@
# pylint: disable=unused-argument
import os
import pytest
from typing import Optional
from selfprivacy_api.utils import WriteUserData, ReadUserData
@ -30,9 +31,15 @@ def test_write_invalid_user_data():
pass
def get_test_mode() -> Optional[str]:
return os.environ.get("TEST_MODE")
# TODO: Does it make any sense to have such a fixture though?
# If it can only be called from tests then it is always test
@pytest.fixture
def test_mode():
return os.environ.get("TEST_MODE")
return get_test_mode()
def test_the_test_mode(test_mode):

View file

@ -1,4 +1,8 @@
import pytest
import os
import asyncio
import threading
import subprocess
from selfprivacy_api.root_daemon import (
get_available_commands,
@ -7,7 +11,9 @@ from selfprivacy_api.root_daemon import (
service_commands,
services,
)
import selfprivacy_api.root_daemon
import selfprivacy_api
import selfprivacy_api.root_daemon as root_daemon
from selfprivacy_api.utils.root_interface import call_root_function
from os.path import join, exists
from typing import List
@ -35,12 +41,21 @@ def test_available_commands():
assert is_in_strings(commands, service)
def test_init(test_socket):
def test_init():
sock = init()
assert exists(test_socket)
assert exists(root_daemon.SOCKET_PATH)
assert sock is not None
def test_main():
# main()
pass
def test_send_command():
root_daemon_file = selfprivacy_api.root_daemon.__file__
# this is a prototype of how we need to run it`
proc = subprocess.Popen(args=["python", root_daemon_file], shell=False)
# thread = threading.Thread(target=start_root_daemon,args=[])
# thread.start()
answer = call_root_function("blabla")
assert answer == "done"
proc.kill()
# thread.join(timeout=1.0)