Merge branch 'master' into def/nix-collect-garbage-endpoint

dettlaff 2024-01-23 21:01:08 +04:00
commit d923b04aef
179 changed files with 5404 additions and 9182 deletions

.gitignore (2 changes)

@@ -148,3 +148,5 @@ cython_debug/
 *.db
 *.rdb
+
+/result

README.md (new file, 73 lines)

@@ -0,0 +1,73 @@
# SelfPrivacy GraphQL API, which allows the app to control your server
## build
```console
$ nix build
```
As a result, you should get a `./result` symlink to a folder (in `/nix/store`) with the build contents.
## develop & test
```console
$ nix develop
$ [SP devshell] pytest .
=================================== test session starts =====================================
platform linux -- Python 3.10.11, pytest-7.1.3, pluggy-1.0.0
rootdir: /data/selfprivacy/selfprivacy-rest-api
plugins: anyio-3.5.0, datadir-1.4.1, mock-3.8.2
collected 692 items
tests/test_block_device_utils.py ................. [ 2%]
tests/test_common.py ..... [ 3%]
tests/test_jobs.py ........ [ 4%]
tests/test_model_storage.py .. [ 4%]
tests/test_models.py .. [ 4%]
tests/test_network_utils.py ...... [ 5%]
tests/test_services.py ...... [ 6%]
tests/test_graphql/test_api.py . [ 6%]
tests/test_graphql/test_api_backup.py ............... [ 8%]
tests/test_graphql/test_api_devices.py ................. [ 11%]
tests/test_graphql/test_api_recovery.py ......... [ 12%]
tests/test_graphql/test_api_version.py .. [ 13%]
tests/test_graphql/test_backup.py ............................... [ 21%]
tests/test_graphql/test_localsecret.py ... [ 22%]
tests/test_graphql/test_ssh.py ............ [ 23%]
tests/test_graphql/test_system.py ............................. [ 28%]
tests/test_graphql/test_system_nixos_tasks.py ........ [ 29%]
tests/test_graphql/test_users.py .................................. [ 42%]
tests/test_graphql/test_repository/test_json_tokens_repository.py [ 44%]
tests/test_graphql/test_repository/test_tokens_repository.py .... [ 53%]
tests/test_rest_endpoints/test_auth.py .......................... [ 58%]
tests/test_rest_endpoints/test_system.py ........................ [ 63%]
tests/test_rest_endpoints/test_users.py ................................ [ 76%]
tests/test_rest_endpoints/services/test_bitwarden.py ............ [ 78%]
tests/test_rest_endpoints/services/test_gitea.py .............. [ 80%]
tests/test_rest_endpoints/services/test_mailserver.py ..... [ 81%]
tests/test_rest_endpoints/services/test_nextcloud.py ............ [ 83%]
tests/test_rest_endpoints/services/test_ocserv.py .............. [ 85%]
tests/test_rest_endpoints/services/test_pleroma.py .............. [ 87%]
tests/test_rest_endpoints/services/test_services.py .... [ 88%]
tests/test_rest_endpoints/services/test_ssh.py ..................... [100%]
============================== 692 passed in 352.76s (0:05:52) ===============================
```
If you don't have experimental flakes enabled, you can use the following command:
```console
nix --extra-experimental-features nix-command --extra-experimental-features flakes develop
```
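Once the API is running, a quick smoke test is to post a GraphQL query to its `/graphql` endpoint. A minimal sketch (the address, the omitted authentication, and the exact query shape are assumptions, not taken from this repository):
```python
# Hypothetical smoke test; adjust the host/port and add authentication
# headers as required by your deployment.
import requests

response = requests.post(
    "http://localhost:5050/graphql",  # assumed address of a local instance
    json={"query": "{ api { version } }"},  # assumed query shape
)
print(response.json())
```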
## dependencies and dependent modules
The current flake inherits nixpkgs from the NixOS configuration flake, so there is no need to pull in an extra nixpkgs dependency if you want to stay aligned with the exact NixOS configuration.
![diagram](http://www.plantuml.com/plantuml/proxy?src=https://git.selfprivacy.org/SelfPrivacy/selfprivacy-rest-api/raw/branch/master/nix-dependencies-diagram.puml)
The Nix code for the API's NixOS service module lives in the NixOS configuration repository.
## current issues
- It is not clear how to store information in this repository about several compatible NixOS configuration commits for which the API application tests pass. Currently there is only a single `flake.lock`.

default.nix (new file, 33 lines)

@@ -0,0 +1,33 @@
{ pythonPackages, rev ? "local" }:
pythonPackages.buildPythonPackage rec {
pname = "selfprivacy-graphql-api";
version = rev;
src = builtins.filterSource (p: t: p != ".git" && t != "symlink") ./.;
nativeCheckInputs = [ pythonPackages.pytestCheckHook ];
propagatedBuildInputs = with pythonPackages; [
fastapi
gevent
huey
mnemonic
portalocker
psutil
pydantic
pytest
pytest-datadir
pytest-mock
pytz
redis
setuptools
strawberry-graphql
typing-extensions
uvicorn
];
pythonImportsCheck = [ "selfprivacy_api" ];
doCheck = false;
meta = {
description = ''
SelfPrivacy Server Management API
'';
};
}

flake.lock (new file, 26 lines)

@@ -0,0 +1,26 @@
{
  "nodes": {
    "nixpkgs": {
      "locked": {
        "lastModified": 1702780907,
        "narHash": "sha256-blbrBBXjjZt6OKTcYX1jpe9SRof2P9ZYWPzq22tzXAA=",
        "owner": "nixos",
        "repo": "nixpkgs",
        "rev": "1e2e384c5b7c50dbf8e9c441a9e58d85f408b01f",
        "type": "github"
      },
      "original": {
        "owner": "nixos",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "nixpkgs": "nixpkgs"
      }
    }
  },
  "root": "root",
  "version": 7
}

flake.nix (new file, 50 lines)

@@ -0,0 +1,50 @@
{
  description = "SelfPrivacy API flake";

  inputs.nixpkgs.url = "github:nixos/nixpkgs";

  outputs = { self, nixpkgs, ... }:
    let
      system = "x86_64-linux";
      pkgs = nixpkgs.legacyPackages.${system};
      selfprivacy-graphql-api = pkgs.callPackage ./default.nix {
        pythonPackages = pkgs.python310Packages;
        rev = self.shortRev or self.dirtyShortRev or "dirty";
      };
    in
    {
      packages.${system}.default = selfprivacy-graphql-api;
      nixosModules.default =
        import ./nixos/module.nix self.packages.${system}.default;
      devShells.${system}.default = pkgs.mkShell {
        packages =
          let
            # TODO is there a better way to get the environment for VS Code?
            python3 =
              nixpkgs.lib.findFirst (p: p.pname == "python3") (abort "wtf")
                self.packages.${system}.default.propagatedBuildInputs;
            python-env =
              python3.withPackages
                (_: self.packages.${system}.default.propagatedBuildInputs);
          in
          with pkgs; [
            python-env
            black
            rclone
            redis
            restic
          ];
        shellHook = ''
          # Envs set with export and as attributes are treated differently;
          # for example, printenv <name> will not fetch the value of an attribute.
          export USE_REDIS_PORT=6379
          export TEST_MODE=true
          pkill redis-server
          sleep 2
          setsid redis-server --bind 127.0.0.1 --port $USE_REDIS_PORT >/dev/null 2>/dev/null &
          # maybe set more env-vars
        '';
      };
    };
  nixConfig.bash-prompt = ''\n\[\e[1;32m\][\[\e[0m\]\[\e[1;34m\]SP devshell\[\e[0m\]\[\e[1;32m\]:\w]\$\[\[\e[0m\] '';
}

nix-dependencies-diagram.puml (new file, 22 lines)

@@ -0,0 +1,22 @@
@startuml
left to right direction
title repositories and flake inputs relations diagram
cloud nixpkgs as nixpkgs_transit
control "<font:monospaced><size:15>nixos-rebuild" as nixos_rebuild
component "SelfPrivacy\nAPI app" as selfprivacy_app
component "SelfPrivacy\nNixOS configuration" as nixos_configuration
note top of nixos_configuration : SelfPrivacy\nAPI service module
nixos_configuration ).. nixpkgs_transit
nixpkgs_transit ..> selfprivacy_app
selfprivacy_app --> nixos_configuration
[nixpkgs] --> nixos_configuration
nixos_configuration -> nixos_rebuild
footer %date("yyyy-MM-dd'T'HH:mmZ")
@enduml

nixos/module.nix (new file, 166 lines)

@@ -0,0 +1,166 @@
selfprivacy-graphql-api: { config, lib, pkgs, ... }:
let
  cfg = config.services.selfprivacy-api;
  config-id = "default";
  nixos-rebuild = "${config.system.build.nixos-rebuild}/bin/nixos-rebuild";
  nix = "${config.nix.package.out}/bin/nix";
in
{
  options.services.selfprivacy-api = {
    enable = lib.mkOption {
      default = true;
      type = lib.types.bool;
      description = ''
        Enable SelfPrivacy API service
      '';
    };
  };
  config = lib.mkIf cfg.enable {
    users.users."selfprivacy-api" = {
      isNormalUser = false;
      isSystemUser = true;
      extraGroups = [ "opendkim" ];
      group = "selfprivacy-api";
    };
    users.groups."selfprivacy-api".members = [ "selfprivacy-api" ];
    systemd.services.selfprivacy-api = {
      description = "API Server used to control system from the mobile application";
      environment = config.nix.envVars // {
        HOME = "/root";
        PYTHONUNBUFFERED = "1";
      } // config.networking.proxy.envVars;
      path = [
        "/var/"
        "/var/dkim/"
        pkgs.coreutils
        pkgs.gnutar
        pkgs.xz.bin
        pkgs.gzip
        pkgs.gitMinimal
        config.nix.package.out
        pkgs.restic
        pkgs.mkpasswd
        pkgs.util-linux
        pkgs.e2fsprogs
        pkgs.iproute2
      ];
      after = [ "network-online.target" ];
      wantedBy = [ "network-online.target" ];
      serviceConfig = {
        User = "root";
        ExecStart = "${selfprivacy-graphql-api}/bin/app.py";
        Restart = "always";
        RestartSec = "5";
      };
    };
    systemd.services.selfprivacy-api-worker = {
      description = "Task worker for SelfPrivacy API";
      environment = config.nix.envVars // {
        HOME = "/root";
        PYTHONUNBUFFERED = "1";
        PYTHONPATH =
          pkgs.python310Packages.makePythonPath [ selfprivacy-graphql-api ];
      } // config.networking.proxy.envVars;
      path = [
        "/var/"
        "/var/dkim/"
        pkgs.coreutils
        pkgs.gnutar
        pkgs.xz.bin
        pkgs.gzip
        pkgs.gitMinimal
        config.nix.package.out
        pkgs.restic
        pkgs.mkpasswd
        pkgs.util-linux
        pkgs.e2fsprogs
        pkgs.iproute2
      ];
      after = [ "network-online.target" ];
      wantedBy = [ "network-online.target" ];
      serviceConfig = {
        User = "root";
        ExecStart = "${pkgs.python310Packages.huey}/bin/huey_consumer.py selfprivacy_api.task_registry.huey";
        Restart = "always";
        RestartSec = "5";
      };
    };
    # One shot systemd service to rebuild NixOS using nixos-rebuild
    systemd.services.sp-nixos-rebuild = {
      description = "nixos-rebuild switch";
      environment = config.nix.envVars // {
        HOME = "/root";
      } // config.networking.proxy.envVars;
      # TODO figure out how to get dependencies list reliably
      path = [ pkgs.coreutils pkgs.gnutar pkgs.xz.bin pkgs.gzip pkgs.gitMinimal config.nix.package.out ];
      # TODO set proper timeout for reboot instead of service restart
      serviceConfig = {
        User = "root";
        WorkingDirectory = "/etc/nixos";
        # sync top-level flake with sp-modules sub-flake
        # (https://github.com/NixOS/nix/issues/9339)
        ExecStartPre = ''
          ${nix} flake lock --override-input sp-modules path:./sp-modules
        '';
        ExecStart = ''
          ${nixos-rebuild} switch --flake .#${config-id}
        '';
        KillMode = "none";
        SendSIGKILL = "no";
      };
      restartIfChanged = false;
      unitConfig.X-StopOnRemoval = false;
    };
    # One shot systemd service to upgrade NixOS using nixos-rebuild
    systemd.services.sp-nixos-upgrade = {
      # protection against simultaneous runs
      after = [ "sp-nixos-rebuild.service" ];
      description = "Upgrade NixOS and SP modules to latest versions";
      environment = config.nix.envVars // {
        HOME = "/root";
      } // config.networking.proxy.envVars;
      # TODO figure out how to get dependencies list reliably
      path = [ pkgs.coreutils pkgs.gnutar pkgs.xz.bin pkgs.gzip pkgs.gitMinimal config.nix.package.out ];
      serviceConfig = {
        User = "root";
        WorkingDirectory = "/etc/nixos";
        # TODO get URL from systemd template parameter?
        ExecStartPre = ''
          ${nix} flake update \
            --override-input selfprivacy-nixos-config git+https://git.selfprivacy.org/SelfPrivacy/selfprivacy-nixos-config.git?ref=flakes
        '';
        ExecStart = ''
          ${nixos-rebuild} switch --flake .#${config-id}
        '';
        KillMode = "none";
        SendSIGKILL = "no";
      };
      restartIfChanged = false;
      unitConfig.X-StopOnRemoval = false;
    };
    # One shot systemd service to rollback NixOS using nixos-rebuild
    systemd.services.sp-nixos-rollback = {
      # protection against simultaneous runs
      after = [ "sp-nixos-rebuild.service" "sp-nixos-upgrade.service" ];
      description = "Rollback NixOS using nixos-rebuild";
      environment = config.nix.envVars // {
        HOME = "/root";
      } // config.networking.proxy.envVars;
      # TODO figure out how to get dependencies list reliably
      path = [ pkgs.coreutils pkgs.gnutar pkgs.xz.bin pkgs.gzip pkgs.gitMinimal config.nix.package.out ];
      serviceConfig = {
        User = "root";
        WorkingDirectory = "/etc/nixos";
        ExecStart = ''
          ${nixos-rebuild} switch --rollback --flake .#${config-id}
        '';
        KillMode = "none";
        SendSIGKILL = "no";
      };
      restartIfChanged = false;
      unitConfig.X-StopOnRemoval = false;
    };
  };
}

@@ -1,11 +1,15 @@
-"""App tokens actions"""
-from datetime import datetime
+"""
+App tokens actions.
+The only actions on tokens that are accessible from APIs
+"""
+from datetime import datetime, timezone
 from typing import Optional
 from pydantic import BaseModel
 from mnemonic import Mnemonic

-from selfprivacy_api.repositories.tokens.json_tokens_repository import (
-    JsonTokensRepository,
+from selfprivacy_api.utils.timeutils import ensure_tz_aware, ensure_tz_aware_strict
+from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
+    RedisTokensRepository,
 )
 from selfprivacy_api.repositories.tokens.exceptions import (
     TokenNotFound,
@@ -14,7 +18,7 @@ from selfprivacy_api.repositories.tokens.exceptions import (
     NewDeviceKeyNotFound,
 )

-TOKEN_REPO = JsonTokensRepository()
+TOKEN_REPO = RedisTokensRepository()


 class TokenInfoWithIsCaller(BaseModel):
@@ -25,6 +29,14 @@ class TokenInfoWithIsCaller(BaseModel):
     is_caller: bool


+def _naive(date_time: datetime) -> datetime:
+    if date_time is None:
+        return None
+    if date_time.tzinfo is not None:
+        date_time.astimezone(timezone.utc)
+    return date_time.replace(tzinfo=None)
+
+
 def get_api_tokens_with_caller_flag(caller_token: str) -> list[TokenInfoWithIsCaller]:
     """Get the tokens info"""
     caller_name = TOKEN_REPO.get_token_by_token_string(caller_token).device_name
@@ -83,16 +95,22 @@ class RecoveryTokenStatus(BaseModel):

 def get_api_recovery_token_status() -> RecoveryTokenStatus:
-    """Get the recovery token status"""
+    """Get the recovery token status, timezone-aware"""
     token = TOKEN_REPO.get_recovery_key()
     if token is None:
         return RecoveryTokenStatus(exists=False, valid=False)
     is_valid = TOKEN_REPO.is_recovery_key_valid()
+
+    # New tokens are tz-aware, but older ones might not be
+    expiry_date = token.expires_at
+    if expiry_date is not None:
+        expiry_date = ensure_tz_aware_strict(expiry_date)
+
     return RecoveryTokenStatus(
         exists=True,
         valid=is_valid,
-        date=token.created_at,
-        expiration=token.expires_at,
+        date=ensure_tz_aware_strict(token.created_at),
+        expiration=expiry_date,
         uses_left=token.uses_left,
     )
@@ -110,8 +128,9 @@ def get_new_api_recovery_key(
 ) -> str:
     """Get new recovery key"""
     if expiration_date is not None:
-        current_time = datetime.now().timestamp()
-        if expiration_date.timestamp() < current_time:
+        expiration_date = ensure_tz_aware(expiration_date)
+        current_time = datetime.now(timezone.utc)
+        if expiration_date < current_time:
             raise InvalidExpirationDate("Expiration date is in the past")
     if uses_left is not None:
         if uses_left <= 0:
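The `ensure_tz_aware` helpers used above live in `selfprivacy_api.utils.timeutils`, which is not part of this diff. A minimal sketch of what they are assumed to do:
```python
# Assumed behaviour of the timeutils helpers referenced above; the real
# implementations are in selfprivacy_api.utils.timeutils and may differ.
from datetime import datetime, timezone


def ensure_tz_aware(dt: datetime) -> datetime:
    # Assumption: a naive datetime is interpreted as local time.
    if dt.tzinfo is None:
        dt = dt.astimezone()
    return dt


def ensure_tz_aware_strict(dt: datetime) -> datetime:
    # Assumption: a naive datetime is interpreted as UTC.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
```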

@@ -31,7 +31,7 @@ def get_ssh_settings() -> UserdataSshSettings:
         if "enable" not in data["ssh"]:
             data["ssh"]["enable"] = True
         if "passwordAuthentication" not in data["ssh"]:
-            data["ssh"]["passwordAuthentication"] = True
+            data["ssh"]["passwordAuthentication"] = False
         if "rootKeys" not in data["ssh"]:
             data["ssh"]["rootKeys"] = []
         return UserdataSshSettings(**data["ssh"])
@@ -49,19 +49,6 @@ def set_ssh_settings(
             data["ssh"]["passwordAuthentication"] = password_authentication


-def add_root_ssh_key(public_key: str):
-    with WriteUserData() as data:
-        if "ssh" not in data:
-            data["ssh"] = {}
-        if "rootKeys" not in data["ssh"]:
-            data["ssh"]["rootKeys"] = []
-        # Return 409 if key already in array
-        for key in data["ssh"]["rootKeys"]:
-            if key == public_key:
-                raise KeyAlreadyExists()
-        data["ssh"]["rootKeys"].append(public_key)
-
-
 class KeyAlreadyExists(Exception):
     """Key already exists"""

@@ -2,7 +2,7 @@
 import os
 import subprocess
 import pytz
-from typing import Optional
+from typing import Optional, List
 from pydantic import BaseModel

 from selfprivacy_api.utils import WriteUserData, ReadUserData
@@ -13,7 +13,7 @@ def get_timezone() -> str:
     with ReadUserData() as user_data:
         if "timezone" in user_data:
             return user_data["timezone"]
-        return "Europe/Uzhgorod"
+        return "Etc/UTC"


 class InvalidTimezone(Exception):
@@ -58,36 +58,56 @@ def set_auto_upgrade_settings(
         user_data["autoUpgrade"]["allowReboot"] = allowReboot


+class ShellException(Exception):
+    """Something went wrong when calling another process"""
+
+    pass
+
+
+def run_blocking(cmd: List[str], new_session: bool = False) -> str:
+    """Run a process, block until done, return output, complain if failed"""
+    process_handle = subprocess.Popen(
+        cmd,
+        shell=False,
+        start_new_session=new_session,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    stdout_raw, stderr_raw = process_handle.communicate()
+    stdout = stdout_raw.decode("utf-8")
+    if stderr_raw is not None:
+        stderr = stderr_raw.decode("utf-8")
+    else:
+        stderr = ""
+    output = stdout + "\n" + stderr
+    if process_handle.returncode != 0:
+        raise ShellException(
+            f"Shell command failed, command array: {cmd}, output: {output}"
+        )
+    return stdout
+
+
 def rebuild_system() -> int:
     """Rebuild the system"""
-    rebuild_result = subprocess.Popen(
-        ["systemctl", "start", "sp-nixos-rebuild.service"], start_new_session=True
-    )
-    rebuild_result.communicate()[0]
-    return rebuild_result.returncode
+    run_blocking(["systemctl", "start", "sp-nixos-rebuild.service"], new_session=True)
+    return 0


 def rollback_system() -> int:
     """Rollback the system"""
-    rollback_result = subprocess.Popen(
-        ["systemctl", "start", "sp-nixos-rollback.service"], start_new_session=True
-    )
-    rollback_result.communicate()[0]
-    return rollback_result.returncode
+    run_blocking(["systemctl", "start", "sp-nixos-rollback.service"], new_session=True)
+    return 0


 def upgrade_system() -> int:
     """Upgrade the system"""
-    upgrade_result = subprocess.Popen(
-        ["systemctl", "start", "sp-nixos-upgrade.service"], start_new_session=True
-    )
-    upgrade_result.communicate()[0]
-    return upgrade_result.returncode
+    run_blocking(["systemctl", "start", "sp-nixos-upgrade.service"], new_session=True)
+    return 0


 def reboot_system() -> None:
     """Reboot the system"""
-    subprocess.Popen(["reboot"], start_new_session=True)
+    run_blocking(["reboot"], new_session=True)


 def get_system_version() -> str:
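The new `run_blocking` helper centralizes the Popen/communicate/returncode dance: it returns stdout on success and raises `ShellException` carrying both streams on failure. A quick illustration of that contract (hypothetical call, assuming the module path is `selfprivacy_api.actions.system`):
```python
# Hypothetical usage of the run_blocking helper introduced above.
from selfprivacy_api.actions.system import ShellException, run_blocking

try:
    kernel = run_blocking(["uname", "-r"])
    print(kernel.strip())
except ShellException as error:
    # The exception message carries the command array and combined output.
    print(f"command failed: {error}")
```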

@@ -58,7 +58,7 @@ def get_users(
         )
         for user in user_data["users"]
     ]
-    if not exclude_primary:
+    if not exclude_primary and "username" in user_data.keys():
         users.append(
             UserDataUser(
                 username=user_data["username"],
@@ -107,6 +107,12 @@ class PasswordIsEmpty(Exception):
     pass


+class InvalidConfiguration(Exception):
+    """The userdata is broken"""
+
+    pass
+
+
 def create_user(username: str, password: str):
     if password == "":
         raise PasswordIsEmpty("Password is empty")
@@ -124,6 +130,10 @@ def create_user(username: str, password: str):
     with ReadUserData() as user_data:
         ensure_ssh_and_users_fields_exist(user_data)
+        if "username" not in user_data.keys():
+            raise InvalidConfiguration(
+                "Broken config: Admin name is not defined. Consider recovery or add it manually"
+            )
         if username == user_data["username"]:
             raise UserAlreadyExists("User already exists")
         if username in [user["username"] for user in user_data["users"]]:

@@ -10,12 +10,6 @@ from selfprivacy_api.dependencies import get_api_version
 from selfprivacy_api.graphql.schema import schema
 from selfprivacy_api.migrations import run_migrations

-from selfprivacy_api.rest import (
-    system,
-    users,
-    api_auth,
-    services,
-)

 app = FastAPI()
@@ -32,10 +26,6 @@
 )

-app.include_router(system.router)
-app.include_router(users.router)
-app.include_router(api_auth.router)
-app.include_router(services.router)
-
 app.include_router(graphql_app, prefix="/graphql")

@@ -1,10 +1,11 @@
 """
 This module contains the controller class for backups.
 """
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
+import time
 import os
 from os import statvfs
-from typing import List, Optional
+from typing import Callable, List, Optional

 from selfprivacy_api.utils import ReadUserData, WriteUserData
@@ -23,7 +24,12 @@ from selfprivacy_api.jobs import Jobs, JobStatus, Job
 from selfprivacy_api.graphql.queries.providers import (
     BackupProvider as BackupProviderEnum,
 )
-from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
+from selfprivacy_api.graphql.common_types.backup import (
+    RestoreStrategy,
+    BackupReason,
+    AutobackupQuotas,
+)

 from selfprivacy_api.models.backup.snapshot import Snapshot
@@ -32,6 +38,7 @@ from selfprivacy_api.backup.providers import get_provider
 from selfprivacy_api.backup.storage import Storage
 from selfprivacy_api.backup.jobs import (
     get_backup_job,
+    get_backup_fail,
     add_backup_job,
     get_restore_job,
     add_restore_job,
@@ -51,6 +58,8 @@ BACKUP_PROVIDER_ENVS = {
     "location": "BACKUP_LOCATION",
 }

+AUTOBACKUP_JOB_EXPIRATION_SECONDS = 60 * 60  # one hour
+

 class NotDeadError(AssertionError):
     """
@@ -70,6 +79,24 @@ class NotDeadError(AssertionError):
     """


+class RotationBucket:
+    """
+    Bucket object used for rotation.
+    Has the following mutable fields:
+    - the counter, int
+    - the lambda function which takes datetime and the int and returns the int
+    - the last, int
+    """
+
+    def __init__(self, counter: int, last: int, rotation_lambda):
+        self.counter: int = counter
+        self.last: int = last
+        self.rotation_lambda: Callable[[datetime, int], int] = rotation_lambda
+
+    def __str__(self) -> str:
+        return f"Bucket(counter={self.counter}, last={self.last})"
+
+
 class Backups:
     """A stateless controller class for backups"""
@@ -264,10 +291,12 @@
     # Backup

     @staticmethod
-    def back_up(service: Service) -> Snapshot:
-        """The top-level function to back up a service"""
-        folders = service.get_folders()
-        tag = service.get_id()
+    def back_up(
+        service: Service, reason: BackupReason = BackupReason.EXPLICIT
+    ) -> Snapshot:
+        """The top-level function to back up a service
+        If it fails for any reason at all, it should both mark job as
+        errored and re-raise an error"""

         job = get_backup_job(service)
         if job is None:
@@ -275,20 +304,132 @@
         Jobs.update(job, status=JobStatus.RUNNING)

         try:
+            if service.can_be_backed_up() is False:
+                raise ValueError("cannot backup a non-backuppable service")
+            folders = service.get_folders()
+            service_name = service.get_id()
             service.pre_backup()
             snapshot = Backups.provider().backupper.start_backup(
                 folders,
-                tag,
+                service_name,
+                reason=reason,
             )
-            Backups._store_last_snapshot(tag, snapshot)
+
+            Backups._store_last_snapshot(service_name, snapshot)
+            if reason == BackupReason.AUTO:
+                Backups._prune_auto_snaps(service)
             service.post_restore()
         except Exception as error:
             Jobs.update(job, status=JobStatus.ERROR, status_text=str(error))
             raise error

         Jobs.update(job, status=JobStatus.FINISHED)
+        if reason in [BackupReason.AUTO, BackupReason.PRE_RESTORE]:
+            Jobs.set_expiration(job, AUTOBACKUP_JOB_EXPIRATION_SECONDS)
         return snapshot

+    @staticmethod
+    def _auto_snaps(service):
+        return [
+            snap
+            for snap in Backups.get_snapshots(service)
+            if snap.reason == BackupReason.AUTO
+        ]
+
+    @staticmethod
+    def _prune_snaps_with_quotas(snapshots: List[Snapshot]) -> List[Snapshot]:
+        # Function broken out for testability
+        # Sorting newest first
+        sorted_snaps = sorted(snapshots, key=lambda s: s.created_at, reverse=True)
+
+        quotas: AutobackupQuotas = Backups.autobackup_quotas()
+
+        buckets: list[RotationBucket] = [
+            RotationBucket(
+                quotas.last,  # type: ignore
+                -1,
+                lambda _, index: index,
+            ),
+            RotationBucket(
+                quotas.daily,  # type: ignore
+                -1,
+                lambda date, _: date.year * 10000 + date.month * 100 + date.day,
+            ),
+            RotationBucket(
+                quotas.weekly,  # type: ignore
+                -1,
+                lambda date, _: date.year * 100 + date.isocalendar()[1],
+            ),
+            RotationBucket(
+                quotas.monthly,  # type: ignore
+                -1,
+                lambda date, _: date.year * 100 + date.month,
+            ),
+            RotationBucket(
+                quotas.yearly,  # type: ignore
+                -1,
+                lambda date, _: date.year,
+            ),
+        ]
+
+        new_snaplist: List[Snapshot] = []
+        for i, snap in enumerate(sorted_snaps):
+            keep_snap = False
+            for bucket in buckets:
+                if (bucket.counter > 0) or (bucket.counter == -1):
+                    val = bucket.rotation_lambda(snap.created_at, i)
+                    if (val != bucket.last) or (i == len(sorted_snaps) - 1):
+                        bucket.last = val
+                        if bucket.counter > 0:
+                            bucket.counter -= 1
+                        if not keep_snap:
+                            new_snaplist.append(snap)
+                        keep_snap = True
+
+        return new_snaplist
+
+    @staticmethod
+    def _prune_auto_snaps(service) -> None:
+        # Not very testable by itself, so most testing is going on Backups._prune_snaps_with_quotas
+        # We can still test total limits and, say, daily limits
+        auto_snaps = Backups._auto_snaps(service)
+        new_snaplist = Backups._prune_snaps_with_quotas(auto_snaps)
+
+        deletable_snaps = [snap for snap in auto_snaps if snap not in new_snaplist]
+        Backups.forget_snapshots(deletable_snaps)
+
+    @staticmethod
+    def _standardize_quotas(i: int) -> int:
+        if i <= -1:
+            i = -1
+        return i
+
+    @staticmethod
+    def autobackup_quotas() -> AutobackupQuotas:
+        """0 means do not keep, -1 means unlimited"""
+        return Storage.autobackup_quotas()
+
+    @staticmethod
+    def set_autobackup_quotas(quotas: AutobackupQuotas) -> None:
+        """0 means do not keep, -1 means unlimited"""
+        Storage.set_autobackup_quotas(
+            AutobackupQuotas(
+                last=Backups._standardize_quotas(quotas.last),  # type: ignore
+                daily=Backups._standardize_quotas(quotas.daily),  # type: ignore
+                weekly=Backups._standardize_quotas(quotas.weekly),  # type: ignore
+                monthly=Backups._standardize_quotas(quotas.monthly),  # type: ignore
+                yearly=Backups._standardize_quotas(quotas.yearly),  # type: ignore
+            )
+        )
+        # do not prune all autosnaps right away, this will be done by an async task
+
+    @staticmethod
+    def prune_all_autosnaps() -> None:
+        for service in get_all_services():
+            Backups._prune_auto_snaps(service)
+
     # Restoring

     @staticmethod
@@ -307,9 +448,9 @@
         job: Job,
     ) -> None:
         Jobs.update(
-            job, status=JobStatus.CREATED, status_text=f"Waiting for pre-restore backup"
+            job, status=JobStatus.CREATED, status_text="Waiting for pre-restore backup"
         )
-        failsafe_snapshot = Backups.back_up(service)
+        failsafe_snapshot = Backups.back_up(service, BackupReason.PRE_RESTORE)

         Jobs.update(
             job, status=JobStatus.RUNNING, status_text=f"Restoring from {snapshot.id}"
@@ -465,6 +606,19 @@
         return snap

+    @staticmethod
+    def forget_snapshots(snapshots: List[Snapshot]) -> None:
+        """
+        Deletes a batch of snapshots from the repo and from cache
+        Optimized
+        """
+        ids = [snapshot.id for snapshot in snapshots]
+        Backups.provider().backupper.forget_snapshots(ids)
+
+        # less critical
+        for snapshot in snapshots:
+            Storage.delete_cached_snapshot(snapshot)
+
     @staticmethod
     def forget_snapshot(snapshot: Snapshot) -> None:
         """Deletes a snapshot from the repo and from cache"""
@@ -473,11 +627,11 @@
     @staticmethod
     def forget_all_snapshots():
-        """deliberately erase all snapshots we made"""
-        # there is no dedicated optimized command for this,
-        # but maybe we can have a multi-erase
-        for snapshot in Backups.get_all_snapshots():
-            Backups.forget_snapshot(snapshot)
+        """
+        Mark all snapshots we have made for deletion and make them inaccessible
+        (this is done by cloud, we only issue a command)
+        """
+        Backups.forget_snapshots(Backups.get_all_snapshots())

     @staticmethod
     def force_snapshot_cache_reload() -> None:
@@ -557,23 +711,49 @@
         """Get a timezone-aware time of the last backup of a service"""
         return Storage.get_last_backup_time(service.get_id())

+    @staticmethod
+    def get_last_backup_error_time(service: Service) -> Optional[datetime]:
+        """Get a timezone-aware time of the last backup of a service"""
+        job = get_backup_fail(service)
+        if job is not None:
+            datetime_created = job.created_at
+            if datetime_created.tzinfo is None:
+                # assume it is in localtime
+                offset = timedelta(seconds=time.localtime().tm_gmtoff)
+                datetime_created = datetime_created - offset
+                return datetime.combine(
+                    datetime_created.date(), datetime_created.time(), timezone.utc
+                )
+            return datetime_created
+        return None
+
     @staticmethod
     def is_time_to_backup_service(service: Service, time: datetime):
         """Returns True if it is time to back up a service"""
         period = Backups.autobackup_period_minutes()
-        service_id = service.get_id()
-        if not service.can_be_backed_up():
-            return False
         if period is None:
             return False

-        last_backup = Storage.get_last_backup_time(service_id)
+        if not service.is_enabled():
+            return False
+        if not service.can_be_backed_up():
+            return False
+
+        last_error = Backups.get_last_backup_error_time(service)
+
+        if last_error is not None:
+            if time < last_error + timedelta(seconds=AUTOBACKUP_JOB_EXPIRATION_SECONDS):
+                return False
+
+        last_backup = Backups.get_last_backed_up(service)
+
+        # Queue a backup immediately if there are no previous backups
         if last_backup is None:
-            # queue a backup immediately if there are no previous backups
             return True

         if time > last_backup + timedelta(minutes=period):
             return True
         return False

     # Helpers
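The `RotationBucket` loop in `_prune_snaps_with_quotas` is compact but non-obvious. Here is a self-contained sketch of the idea for a single daily bucket, with plain datetimes standing in for `Snapshot` objects (illustrative only, not repository code):
```python
# One "daily" bucket: walk snapshots newest first and keep the newest
# snapshot of each distinct day until the quota runs out. The real code
# runs five such buckets (last/daily/weekly/monthly/yearly) side by side.
from datetime import datetime

snapshots = sorted(
    (datetime(2024, 1, day, hour) for day in (1, 2, 3) for hour in (6, 18)),
    reverse=True,  # newest first, as in _prune_snaps_with_quotas
)

daily_quota = 2  # keep snapshots from the two most recent days
kept, last_key, counter = [], None, daily_quota
for snap in snapshots:
    key = (snap.year, snap.month, snap.day)  # same key as the daily lambda
    if counter > 0 and key != last_key:
        last_key = key
        counter -= 1
        kept.append(snap)

print(kept)  # the 18:00 snapshots of Jan 3 and Jan 2
```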

@@ -2,6 +2,7 @@ from abc import ABC, abstractmethod
 from typing import List

 from selfprivacy_api.models.backup.snapshot import Snapshot
+from selfprivacy_api.graphql.common_types.backup import BackupReason


 class AbstractBackupper(ABC):
@@ -22,7 +23,12 @@ class AbstractBackupper(ABC):
         raise NotImplementedError

     @abstractmethod
-    def start_backup(self, folders: List[str], tag: str) -> Snapshot:
+    def start_backup(
+        self,
+        folders: List[str],
+        service_name: str,
+        reason: BackupReason = BackupReason.EXPLICIT,
+    ) -> Snapshot:
         """Start a backup of the given folders"""
         raise NotImplementedError

@@ -60,3 +66,8 @@ class AbstractBackupper(ABC):
     def forget_snapshot(self, snapshot_id) -> None:
         """Forget a snapshot"""
         raise NotImplementedError
+
+    @abstractmethod
+    def forget_snapshots(self, snapshot_ids: List[str]) -> None:
+        """Maybe optimized deletion of a batch of snapshots, just cycling if unsupported"""
+        raise NotImplementedError

@@ -2,6 +2,7 @@ from typing import List

 from selfprivacy_api.models.backup.snapshot import Snapshot
 from selfprivacy_api.backup.backuppers import AbstractBackupper
+from selfprivacy_api.graphql.common_types.backup import BackupReason


 class NoneBackupper(AbstractBackupper):
@@ -13,7 +14,9 @@ class NoneBackupper(AbstractBackupper):
     def set_creds(self, account: str, key: str, repo: str):
         pass

-    def start_backup(self, folders: List[str], tag: str):
+    def start_backup(
+        self, folders: List[str], tag: str, reason: BackupReason = BackupReason.EXPLICIT
+    ):
         raise NotImplementedError

     def get_snapshots(self) -> List[Snapshot]:
@@ -36,4 +39,7 @@ class NoneBackupper(AbstractBackupper):
         raise NotImplementedError

     def forget_snapshot(self, snapshot_id):
-        raise NotImplementedError
+        raise NotImplementedError("forget_snapshot")
+
+    def forget_snapshots(self, snapshots):
+        raise NotImplementedError("forget_snapshots")

@@ -5,13 +5,14 @@ import json
 import datetime
 import tempfile

-from typing import List, TypeVar, Callable
+from typing import List, Optional, TypeVar, Callable
 from collections.abc import Iterable
 from json.decoder import JSONDecodeError
 from os.path import exists, join
 from os import mkdir
 from shutil import rmtree

+from selfprivacy_api.graphql.common_types.backup import BackupReason
 from selfprivacy_api.backup.util import output_yielder, sync
 from selfprivacy_api.backup.backuppers import AbstractBackupper
 from selfprivacy_api.models.backup.snapshot import Snapshot
@@ -84,7 +85,14 @@
     def _password_command(self):
         return f"echo {LocalBackupSecret.get()}"

-    def restic_command(self, *args, tag: str = "") -> List[str]:
+    def restic_command(self, *args, tags: Optional[List[str]] = None) -> List[str]:
+        """
+        Construct a restic command against the currently configured repo
+        Can support [nested] arrays as arguments, will flatten them into the final command
+        """
+        if tags is None:
+            tags = []
         command = [
             "restic",
             "-o",
@@ -94,13 +102,14 @@
             "--password-command",
             self._password_command(),
         ]
-        if tag != "":
-            command.extend(
-                [
-                    "--tag",
-                    tag,
-                ]
-            )
+        if tags != []:
+            for tag in tags:
+                command.extend(
+                    [
+                        "--tag",
+                        tag,
+                    ]
+                )
         if args:
             command.extend(ResticBackupper.__flatten_list(args))
         return command
@@ -138,7 +147,12 @@
         return result

     @unlocked_repo
-    def start_backup(self, folders: List[str], tag: str) -> Snapshot:
+    def start_backup(
+        self,
+        folders: List[str],
+        service_name: str,
+        reason: BackupReason = BackupReason.EXPLICIT,
+    ) -> Snapshot:
         """
         Start backup with restic
         """
@@ -147,33 +161,35 @@
         # of a string and an array of strings
         assert not isinstance(folders, str)

+        tags = [service_name, reason.value]
+
         backup_command = self.restic_command(
             "backup",
             "--json",
             folders,
-            tag=tag,
+            tags=tags,
         )

-        messages = []
-        service = get_service_by_id(tag)
+        service = get_service_by_id(service_name)
         if service is None:
-            raise ValueError("No service with id ", tag)
+            raise ValueError("No service with id ", service_name)
         job = get_backup_job(service)

+        messages = []
         output = []
         try:
             for raw_message in output_yielder(backup_command):
                 output.append(raw_message)
-                message = self.parse_message(
-                    raw_message,
-                    job,
-                )
+                message = self.parse_message(raw_message, job)
                 messages.append(message)
-            return ResticBackupper._snapshot_from_backup_messages(
-                messages,
-                tag,
-            )
+            id = ResticBackupper._snapshot_id_from_backup_messages(messages)
+            return Snapshot(
+                created_at=datetime.datetime.now(datetime.timezone.utc),
+                id=id,
+                service_name=service_name,
+                reason=reason,
+            )
         except ValueError as error:
             raise ValueError(
                 "Could not create a snapshot: ",
@@ -181,16 +197,18 @@
                 output,
                 "parsed messages:",
                 messages,
+                "command: ",
+                backup_command,
             ) from error

     @staticmethod
-    def _snapshot_from_backup_messages(messages, repo_name) -> Snapshot:
+    def _snapshot_id_from_backup_messages(messages) -> str:
         for message in messages:
             if message["message_type"] == "summary":
-                return ResticBackupper._snapshot_from_fresh_summary(
-                    message,
-                    repo_name,
-                )
+                # There is a discrepancy between versions of restic/rclone
+                # Some report short_id in this field and some full
+                return message["snapshot_id"][0:SHORT_ID_LEN]
+
         raise ValueError("no summary message in restic json output")

     def parse_message(self, raw_message_line: str, job=None) -> dict:
@@ -206,16 +224,6 @@
         )
         return message

-    @staticmethod
-    def _snapshot_from_fresh_summary(message: dict, repo_name) -> Snapshot:
-        return Snapshot(
-            # There is a discrepancy between versions of restic/rclone
-            # Some report short_id in this field and some full
-            id=message["snapshot_id"][0:SHORT_ID_LEN],
-            created_at=datetime.datetime.now(datetime.timezone.utc),
-            service_name=repo_name,
-        )
-
     def init(self) -> None:
         init_command = self.restic_command(
             "init",
@@ -364,7 +372,6 @@
             stderr=subprocess.STDOUT,
             shell=False,
         ) as handle:
-
             # for some reason restore does not support
             # nice reporting of progress via json
             output = handle.communicate()[0].decode("utf-8")
@@ -382,15 +389,17 @@
                 output,
             )

+    def forget_snapshot(self, snapshot_id: str) -> None:
+        self.forget_snapshots([snapshot_id])
+
     @unlocked_repo
-    def forget_snapshot(self, snapshot_id) -> None:
-        """
-        Either removes snapshot or marks it for deletion later,
-        depending on server settings
-        """
+    def forget_snapshots(self, snapshot_ids: List[str]) -> None:
+        # in case the backupper program supports batching, otherwise implement it by cycling
         forget_command = self.restic_command(
             "forget",
-            snapshot_id,
+            [snapshot_ids],
+            # TODO: prune should be done in a separate process
+            "--prune",
         )

         with subprocess.Popen(
@@ -410,7 +419,7 @@
         if "no matching ID found" in err:
             raise ValueError(
-                "trying to delete, but no such snapshot: ", snapshot_id
+                "trying to delete, but no such snapshot(s): ", snapshot_ids
             )

         assert (
@@ -450,11 +459,19 @@
     def get_snapshots(self) -> List[Snapshot]:
         """Get all snapshots from the repo"""
         snapshots = []
+
         for restic_snapshot in self._load_snapshots():
+            # Compatibility with previous snaps:
+            if len(restic_snapshot["tags"]) == 1:
+                reason = BackupReason.EXPLICIT
+            else:
+                reason = restic_snapshot["tags"][1]
+
             snapshot = Snapshot(
                 id=restic_snapshot["short_id"],
                 created_at=restic_snapshot["time"],
                 service_name=restic_snapshot["tags"][0],
+                reason=reason,
             )

             snapshots.append(snapshot)
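Snapshots are now tagged with the service name first and the backup reason second, which is exactly what the compatibility branch in `get_snapshots` reads back. Roughly, the new `tags` parameter expands like this (repository and password flags elided; illustrative, not repository code):
```python
# Rough expansion of the tags parameter in restic_command: every tag
# becomes its own --tag argument on the final restic invocation.
tags = ["nextcloud", "AUTO"]  # [service_name, reason.value]

command = ["restic"]
for tag in tags:
    command.extend(["--tag", tag])

print(command)  # ['restic', '--tag', 'nextcloud', '--tag', 'AUTO']
```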

@@ -80,9 +80,19 @@ def get_job_by_type(type_id: str) -> Optional[Job]:
             return job


+def get_failed_job_by_type(type_id: str) -> Optional[Job]:
+    for job in Jobs.get_jobs():
+        if job.type_id == type_id and job.status == JobStatus.ERROR:
+            return job
+
+
 def get_backup_job(service: Service) -> Optional[Job]:
     return get_job_by_type(backup_job_type(service))


+def get_backup_fail(service: Service) -> Optional[Job]:
+    return get_failed_job_by_type(backup_job_type(service))
+
+
 def get_restore_job(service: Service) -> Optional[Job]:
     return get_job_by_type(restore_job_type(service))

@@ -6,6 +6,10 @@ from datetime import datetime

 from selfprivacy_api.models.backup.snapshot import Snapshot
 from selfprivacy_api.models.backup.provider import BackupProviderModel
+from selfprivacy_api.graphql.common_types.backup import (
+    AutobackupQuotas,
+    _AutobackupQuotas,
+)

 from selfprivacy_api.utils.redis_pool import RedisPool
 from selfprivacy_api.utils.redis_model_storage import (
@@ -23,6 +27,8 @@ REDIS_INITTED_CACHE = "backups:repo_initted"
 REDIS_PROVIDER_KEY = "backups:provider"
 REDIS_AUTOBACKUP_PERIOD_KEY = "backups:autobackup_period"

+REDIS_AUTOBACKUP_QUOTAS_KEY = "backups:autobackup_quotas_key"
+

 redis = RedisPool().get_connection()

@@ -35,6 +41,7 @@ class Storage:
         redis.delete(REDIS_PROVIDER_KEY)
         redis.delete(REDIS_AUTOBACKUP_PERIOD_KEY)
         redis.delete(REDIS_INITTED_CACHE)
+        redis.delete(REDIS_AUTOBACKUP_QUOTAS_KEY)

         prefixes_to_clean = [
             REDIS_SNAPSHOTS_PREFIX,
@@ -170,3 +177,23 @@
     def mark_as_uninitted():
         """Marks the repository as initialized"""
         redis.delete(REDIS_INITTED_CACHE)
+
+    @staticmethod
+    def set_autobackup_quotas(quotas: AutobackupQuotas) -> None:
+        store_model_as_hash(redis, REDIS_AUTOBACKUP_QUOTAS_KEY, quotas.to_pydantic())
+
+    @staticmethod
+    def autobackup_quotas() -> AutobackupQuotas:
+        quotas_model = hash_as_model(
+            redis, REDIS_AUTOBACKUP_QUOTAS_KEY, _AutobackupQuotas
+        )
+        if quotas_model is None:
+            unlimited_quotas = AutobackupQuotas(
+                last=-1,
+                daily=-1,
+                weekly=-1,
+                monthly=-1,
+                yearly=-1,
+            )
+            return unlimited_quotas
+        return AutobackupQuotas.from_pydantic(quotas_model)  # pylint: disable=no-member

@@ -3,13 +3,20 @@ The tasks module contains the worker tasks that are used to back up and restore
 """
 from datetime import datetime, timezone

-from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
+from selfprivacy_api.graphql.common_types.backup import (
+    RestoreStrategy,
+    BackupReason,
+)
 from selfprivacy_api.models.backup.snapshot import Snapshot
 from selfprivacy_api.utils.huey import huey
 from huey import crontab

 from selfprivacy_api.services.service import Service
+from selfprivacy_api.services import get_service_by_id
 from selfprivacy_api.backup import Backups
+from selfprivacy_api.jobs import Jobs, JobStatus, Job

 SNAPSHOT_CACHE_TTL_HOURS = 6

@@ -26,11 +33,30 @@ def validate_datetime(dt: datetime) -> bool:
 # huey tasks need to return something
 @huey.task()
-def start_backup(service: Service) -> bool:
+def start_backup(service_id: str, reason: BackupReason = BackupReason.EXPLICIT) -> bool:
     """
     The worker task that starts the backup process.
     """
-    Backups.back_up(service)
+    service = get_service_by_id(service_id)
+    if service is None:
+        raise ValueError(f"No such service: {service_id}")
+    Backups.back_up(service, reason)
+    return True
+
+
+@huey.task()
+def prune_autobackup_snapshots(job: Job) -> bool:
+    """
+    Remove all autobackup snapshots that do not fit into quotas set
+    """
+    Jobs.update(job, JobStatus.RUNNING)
+    try:
+        Backups.prune_all_autosnaps()
+    except Exception as e:
+        Jobs.update(job, JobStatus.ERROR, error=type(e).__name__ + ":" + str(e))
+        return False
+
+    Jobs.update(job, JobStatus.FINISHED)
     return True

@@ -53,7 +79,7 @@ def automatic_backup():
     """
     time = datetime.utcnow().replace(tzinfo=timezone.utc)
     for service in Backups.services_to_back_up(time):
-        start_backup(service)
+        start_backup(service, BackupReason.AUTO)


 @huey.periodic_task(crontab(hour=SNAPSHOT_CACHE_TTL_HOURS))
View file

@ -27,4 +27,4 @@ async def get_token_header(
def get_api_version() -> str: def get_api_version() -> str:
"""Get API version""" """Get API version"""
return "2.3.1" return "3.0.0"

@@ -1,10 +1,36 @@
 """Backup"""
 # pylint: disable=too-few-public-methods
-import strawberry
 from enum import Enum
+import strawberry
+from pydantic import BaseModel


 @strawberry.enum
 class RestoreStrategy(Enum):
     INPLACE = "INPLACE"
     DOWNLOAD_VERIFY_OVERWRITE = "DOWNLOAD_VERIFY_OVERWRITE"
+
+
+@strawberry.enum
+class BackupReason(Enum):
+    EXPLICIT = "EXPLICIT"
+    AUTO = "AUTO"
+    PRE_RESTORE = "PRE_RESTORE"
+
+
+class _AutobackupQuotas(BaseModel):
+    last: int
+    daily: int
+    weekly: int
+    monthly: int
+    yearly: int
+
+
+@strawberry.experimental.pydantic.type(model=_AutobackupQuotas, all_fields=True)
+class AutobackupQuotas:
+    pass
+
+
+@strawberry.experimental.pydantic.input(model=_AutobackupQuotas, all_fields=True)
+class AutobackupQuotasInput:
+    pass

@@ -11,3 +11,4 @@ class DnsRecord:
     content: str
     ttl: int
     priority: typing.Optional[int]
+    display_name: str

@@ -2,6 +2,7 @@ from enum import Enum
 import typing
 import strawberry
 import datetime
+from selfprivacy_api.graphql.common_types.backup import BackupReason

 from selfprivacy_api.graphql.common_types.dns import DnsRecord
 from selfprivacy_api.services import get_service_by_id, get_services_by_location
@@ -114,6 +115,7 @@ class SnapshotInfo:
     id: str
     service: Service
     created_at: datetime.datetime
+    reason: BackupReason


 def service_to_graphql_service(service: ServiceInterface) -> Service:
@@ -137,6 +139,7 @@ def service_to_graphql_service(service: ServiceInterface) -> Service:
                 content=record.content,
                 ttl=record.ttl,
                 priority=record.priority,
+                display_name=record.display_name,
             )
             for record in service.get_dns_records()
         ],

@@ -17,7 +17,6 @@ class UserType(Enum):

 @strawberry.type
 class User:
     user_type: UserType
-
     username: str
     # userHomeFolderspace: UserHomeFolderUsage
@@ -32,7 +31,6 @@ class UserMutationReturn(MutationReturnInterface):

 def get_user_by_username(username: str) -> typing.Optional[User]:
     user = users_actions.get_user_by_username(username)
-
     if user is None:
         return None

@@ -1,6 +1,8 @@
 import typing
 import strawberry

+from selfprivacy_api.jobs import Jobs
+
 from selfprivacy_api.graphql import IsAuthenticated
 from selfprivacy_api.graphql.mutations.mutation_interface import (
     GenericMutationReturn,
@@ -11,11 +13,18 @@ from selfprivacy_api.graphql.queries.backup import BackupConfiguration
 from selfprivacy_api.graphql.queries.backup import Backup
 from selfprivacy_api.graphql.queries.providers import BackupProvider
 from selfprivacy_api.graphql.common_types.jobs import job_to_api_job
-from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
+from selfprivacy_api.graphql.common_types.backup import (
+    AutobackupQuotasInput,
+    RestoreStrategy,
+)

 from selfprivacy_api.backup import Backups
 from selfprivacy_api.services import get_service_by_id
-from selfprivacy_api.backup.tasks import start_backup, restore_snapshot
+from selfprivacy_api.backup.tasks import (
+    start_backup,
+    restore_snapshot,
+    prune_autobackup_snapshots,
+)
 from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job

@@ -90,6 +99,41 @@ class BackupMutations:
             configuration=Backup().configuration(),
         )

+    @strawberry.mutation(permission_classes=[IsAuthenticated])
+    def set_autobackup_quotas(
+        self, quotas: AutobackupQuotasInput
+    ) -> GenericBackupConfigReturn:
+        """
+        Set autobackup quotas.
+        Values <=0 for any timeframe mean no limits for that timeframe.
+        To disable autobackup use autobackup period setting, not this mutation.
+        """
+
+        job = Jobs.add(
+            name="Trimming autobackup snapshots",
+            type_id="backups.autobackup_trimming",
+            description="Pruning the excessive snapshots after the new autobackup quotas are set",
+        )
+
+        try:
+            Backups.set_autobackup_quotas(quotas)
+            # this task is async and can fail with only a job to report the error
+            prune_autobackup_snapshots(job)
+            return GenericBackupConfigReturn(
+                success=True,
+                message="",
+                code=200,
+                configuration=Backup().configuration(),
+            )
+        except Exception as e:
+            return GenericBackupConfigReturn(
+                success=False,
+                message=type(e).__name__ + ":" + str(e),
+                code=400,
+                configuration=Backup().configuration(),
+            )
+
     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def start_backup(self, service_id: str) -> GenericJobMutationReturn:
         """Start backup"""
@@ -104,7 +148,7 @@ class BackupMutations:
         )

         job = add_backup_job(service)
-        start_backup(service)
+        start_backup(service_id)

         return GenericJobMutationReturn(
             success=True,
View file

@ -20,6 +20,7 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
GenericMutationReturn, GenericMutationReturn,
) )
from selfprivacy_api.graphql.mutations.services_mutations import ( from selfprivacy_api.graphql.mutations.services_mutations import (
ServiceJobMutationReturn,
ServiceMutationReturn, ServiceMutationReturn,
ServicesMutations, ServicesMutations,
) )
@ -201,7 +202,7 @@ class DeprecatedServicesMutations:
"services", "services",
) )
move_service: ServiceMutationReturn = deprecated_mutation( move_service: ServiceJobMutationReturn = deprecated_mutation(
ServicesMutations.move_service, ServicesMutations.move_service,
"services", "services",
) )

View file

@@ -4,6 +4,7 @@ import typing
 import strawberry
 from selfprivacy_api.graphql import IsAuthenticated
 from selfprivacy_api.graphql.common_types.jobs import job_to_api_job
+from selfprivacy_api.jobs import JobStatus

 from selfprivacy_api.graphql.common_types.service import (
     Service,
@@ -47,14 +48,22 @@ class ServicesMutations:
     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def enable_service(self, service_id: str) -> ServiceMutationReturn:
         """Enable service."""
-        service = get_service_by_id(service_id)
-        if service is None:
+        try:
+            service = get_service_by_id(service_id)
+            if service is None:
+                return ServiceMutationReturn(
+                    success=False,
+                    message="Service not found.",
+                    code=404,
+                )
+            service.enable()
+        except Exception as e:
             return ServiceMutationReturn(
                 success=False,
-                message="Service not found.",
-                code=404,
+                message=format_error(e),
+                code=400,
             )
-        service.enable()
         return ServiceMutationReturn(
             success=True,
             message="Service enabled.",
@@ -65,14 +74,21 @@ class ServicesMutations:
     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def disable_service(self, service_id: str) -> ServiceMutationReturn:
         """Disable service."""
-        service = get_service_by_id(service_id)
-        if service is None:
+        try:
+            service = get_service_by_id(service_id)
+            if service is None:
+                return ServiceMutationReturn(
+                    success=False,
+                    message="Service not found.",
+                    code=404,
+                )
+            service.disable()
+        except Exception as e:
             return ServiceMutationReturn(
                 success=False,
-                message="Service not found.",
-                code=404,
+                message=format_error(e),
+                code=400,
             )
-        service.disable()
         return ServiceMutationReturn(
             success=True,
             message="Service disabled.",
@@ -144,6 +160,8 @@ class ServicesMutations:
                 message="Service not found.",
                 code=404,
             )
+        # TODO: make serviceImmovable and BlockdeviceNotFound exceptions
+        # in the move_to_volume() function and handle them here
         if not service.is_movable():
             return ServiceJobMutationReturn(
                 success=False,
@@ -160,10 +178,31 @@ class ServicesMutations:
                 service=service_to_graphql_service(service),
             )
         job = service.move_to_volume(volume)
-        return ServiceJobMutationReturn(
-            success=True,
-            message="Service moved.",
-            code=200,
-            service=service_to_graphql_service(service),
-            job=job_to_api_job(job),
-        )
+        if job.status in [JobStatus.CREATED, JobStatus.RUNNING]:
+            return ServiceJobMutationReturn(
+                success=True,
+                message="Started moving the service.",
+                code=200,
+                service=service_to_graphql_service(service),
+                job=job_to_api_job(job),
+            )
+        elif job.status == JobStatus.FINISHED:
+            return ServiceJobMutationReturn(
+                success=True,
+                message="Service moved.",
+                code=200,
+                service=service_to_graphql_service(service),
+                job=job_to_api_job(job),
+            )
+        else:
+            return ServiceJobMutationReturn(
+                success=False,
+                message=f"Service move failure: {job.status_text}",
+                code=400,
+                service=service_to_graphql_service(service),
+                job=job_to_api_job(job),
+            )
+
+
+def format_error(e: Exception) -> str:
+    return type(e).__name__ + ": " + str(e)
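For illustration (not part of the diff): `format_error()` renders any exception as `TypeName: message`, which is the string the `code=400` branches above put into `message`. A minimal sketch, assuming the API package is importable:

```python
from selfprivacy_api.graphql.mutations.services_mutations import format_error

try:
    raise TimeoutError("timed out waiting for systemd unit")
except Exception as e:
    # Same string the enable/disable mutations return in `message`
    assert format_error(e) == "TimeoutError: timed out waiting for systemd unit"
```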


@@ -12,6 +12,7 @@ from selfprivacy_api.graphql.mutations.mutation_interface import (
 import selfprivacy_api.actions.system as system_actions
 from selfprivacy_api.graphql.common_types.jobs import job_to_api_job
 from selfprivacy_api.jobs.nix_collect_garbage import start_nix_collect_garbage
+import selfprivacy_api.actions.ssh as ssh_actions


 @strawberry.type
@@ -29,6 +30,22 @@ class AutoUpgradeSettingsMutationReturn(MutationReturnInterface):
     allowReboot: bool


+@strawberry.type
+class SSHSettingsMutationReturn(MutationReturnInterface):
+    """A return type for after changing SSH settings"""
+
+    enable: bool
+    password_authentication: bool
+
+
+@strawberry.input
+class SSHSettingsInput:
+    """Input type for SSH settings"""
+
+    enable: bool
+    password_authentication: bool
+
+
 @strawberry.input
 class AutoUpgradeSettingsInput:
     """Input type for auto upgrade settings"""
@@ -80,40 +97,88 @@ class SystemMutations:
     )

     @strawberry.mutation(permission_classes=[IsAuthenticated])
-    def run_system_rebuild(self) -> GenericMutationReturn:
-        system_actions.rebuild_system()
-        return GenericMutationReturn(
-            success=True,
-            message="Starting rebuild system",
-            code=200,
+    def change_ssh_settings(
+        self, settings: SSHSettingsInput
+    ) -> SSHSettingsMutationReturn:
+        """Change ssh settings of the server."""
+        ssh_actions.set_ssh_settings(
+            enable=settings.enable,
+            password_authentication=settings.password_authentication,
         )
+        new_settings = ssh_actions.get_ssh_settings()
+        return SSHSettingsMutationReturn(
+            success=True,
+            message="SSH settings changed",
+            code=200,
+            enable=new_settings.enable,
+            password_authentication=new_settings.passwordAuthentication,
+        )
+
+    @strawberry.mutation(permission_classes=[IsAuthenticated])
+    def run_system_rebuild(self) -> GenericMutationReturn:
+        try:
+            system_actions.rebuild_system()
+            return GenericMutationReturn(
+                success=True,
+                message="Starting rebuild system",
+                code=200,
+            )
+        except system_actions.ShellException as e:
+            return GenericMutationReturn(
+                success=False,
+                message=str(e),
+                code=500,
+            )

     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def run_system_rollback(self) -> GenericMutationReturn:
         system_actions.rollback_system()
-        return GenericMutationReturn(
-            success=True,
-            message="Starting rebuild system",
-            code=200,
-        )
+        try:
+            return GenericMutationReturn(
+                success=True,
+                message="Starting rebuild system",
+                code=200,
+            )
+        except system_actions.ShellException as e:
+            return GenericMutationReturn(
+                success=False,
+                message=str(e),
+                code=500,
+            )

     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def run_system_upgrade(self) -> GenericMutationReturn:
         system_actions.upgrade_system()
-        return GenericMutationReturn(
-            success=True,
-            message="Starting rebuild system",
-            code=200,
-        )
+        try:
+            return GenericMutationReturn(
+                success=True,
+                message="Starting rebuild system",
+                code=200,
+            )
+        except system_actions.ShellException as e:
+            return GenericMutationReturn(
+                success=False,
+                message=str(e),
+                code=500,
+            )

     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def reboot_system(self) -> GenericMutationReturn:
         system_actions.reboot_system()
-        return GenericMutationReturn(
-            success=True,
-            message="System reboot has started",
-            code=200,
-        )
+        try:
+            return GenericMutationReturn(
+                success=True,
+                message="System reboot has started",
+                code=200,
+            )
+        except system_actions.ShellException as e:
+            return GenericMutationReturn(
+                success=False,
+                message=str(e),
+                code=500,
+            )

     @strawberry.mutation(permission_classes=[IsAuthenticated])
     def pull_repository_changes(self) -> GenericMutationReturn:
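A hedged sketch of the action-layer calls the new `change_ssh_settings` mutation wraps (values illustrative, package assumed importable):

```python
import selfprivacy_api.actions.ssh as ssh_actions

# Keep SSH on but disable password login, then read the settings back
ssh_actions.set_ssh_settings(enable=True, password_authentication=False)
new_settings = ssh_actions.get_ssh_settings()
print(new_settings.enable, new_settings.passwordAuthentication)
```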


@@ -69,6 +69,12 @@ class UsersMutations:
                 message=str(e),
                 code=400,
             )
+        except users_actions.InvalidConfiguration as e:
+            return UserMutationReturn(
+                success=False,
+                message=str(e),
+                code=400,
+            )
         except users_actions.UserAlreadyExists as e:
             return UserMutationReturn(
                 success=False,
@@ -147,7 +153,7 @@ class UsersMutations:
         except InvalidPublicKey:
             return UserMutationReturn(
                 success=False,
-                message="Invalid key type. Only ssh-ed25519 and ssh-rsa are supported",
+                message="Invalid key type. Only ssh-ed25519, ssh-rsa and ecdsa are supported",
                 code=400,
             )
         except UserNotFound:
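For context, the action layer behind these mutations raises `InvalidPublicKey` for unsupported key types; a hedged sketch (the key string is an illustration, not a real key):

```python
from selfprivacy_api.actions.ssh import InvalidPublicKey, create_ssh_key

try:
    create_ssh_key("alice", "ssh-dss AAAAB3...")  # unsupported key type
except InvalidPublicKey:
    print("Only ssh-ed25519, ssh-rsa and ecdsa are supported")
```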


@@ -38,7 +38,7 @@ class ApiRecoveryKeyStatus:
 def get_recovery_key_status() -> ApiRecoveryKeyStatus:
-    """Get recovery key status"""
+    """Get recovery key status, times are timezone-aware"""
     status = get_api_recovery_token_status()
     if status is None or not status.exists:
         return ApiRecoveryKeyStatus(


@@ -13,6 +13,7 @@ from selfprivacy_api.graphql.common_types.service import (
     SnapshotInfo,
     service_to_graphql_service,
 )
+from selfprivacy_api.graphql.common_types.backup import AutobackupQuotas

 from selfprivacy_api.services import get_service_by_id
@@ -26,6 +27,8 @@ class BackupConfiguration:
     is_initialized: bool
     # If none, autobackups are disabled
     autobackup_period: typing.Optional[int]
+    # None is equal to all quotas being unlimited (-1). Optional for compatibility reasons.
+    autobackup_quotas: AutobackupQuotas
     # Bucket name for Backblaze, path for some other providers
     location_name: typing.Optional[str]
     location_id: typing.Optional[str]
@@ -42,6 +45,7 @@ class Backup:
             autobackup_period=Backups.autobackup_period_minutes(),
             location_name=Backups.provider().location,
             location_id=Backups.provider().repo_id,
+            autobackup_quotas=Backups.autobackup_quotas(),
         )

     @strawberry.field
@@ -73,6 +77,7 @@ class Backup:
                 id=snap.id,
                 service=service,
                 created_at=snap.created_at,
+                reason=snap.reason,
             )
             result.append(graphql_snap)
         return result
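A hedged sketch of reading the same configuration through the `Backups` facade used above (assuming it is importable from `selfprivacy_api.backup`, which this diff does not show):

```python
from selfprivacy_api.backup import Backups

period = Backups.autobackup_period_minutes()  # None means autobackups are off
quotas = Backups.autobackup_quotas()          # unset quotas mean unlimited (-1)
location = Backups.provider().location        # bucket name, or path for some providers
```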


@@ -15,7 +15,6 @@ from selfprivacy_api.jobs import Jobs
 class Job:
     @strawberry.field
     def get_jobs(self) -> typing.List[ApiJob]:
-        Jobs.get_jobs()
         return [job_to_api_job(job) for job in Jobs.get_jobs()]


@@ -33,6 +33,7 @@ class SystemDomainInfo:
             content=record.content,
             ttl=record.ttl,
             priority=record.priority,
+            display_name=record.display_name,
         )
         for record in get_all_required_dns_records()
     ]


@@ -8,8 +8,8 @@ A job is a dictionary with the following keys:
 - name: name of the job
 - description: description of the job
 - status: status of the job
-- created_at: date of creation of the job
-- updated_at: date of last update of the job
+- created_at: date of creation of the job, naive localtime
+- updated_at: date of last update of the job, naive localtime
 - finished_at: date of finish of the job
 - error: error message if the job failed
 - result: result of the job
@@ -224,6 +224,14 @@ class Jobs:

         return job

+    @staticmethod
+    def set_expiration(job: Job, expiration_seconds: int) -> Job:
+        redis = RedisPool().get_connection()
+        key = _redis_key_from_uuid(job.uid)
+        if redis.exists(key):
+            redis.expire(key, expiration_seconds)
+        return job
+
     @staticmethod
     def get_job(uid: str) -> typing.Optional[Job]:
         """


@@ -8,33 +8,12 @@ at api.skippedMigrations in userdata.json and populating it
 with IDs of the migrations to skip.
 Adding DISABLE_ALL to that array disables the migrations module entirely.
 """
-from selfprivacy_api.migrations.check_for_failed_binds_migration import (
-    CheckForFailedBindsMigration,
-)
-from selfprivacy_api.utils import ReadUserData
-from selfprivacy_api.migrations.fix_nixos_config_branch import FixNixosConfigBranch
-from selfprivacy_api.migrations.create_tokens_json import CreateTokensJson
-from selfprivacy_api.migrations.migrate_to_selfprivacy_channel import (
-    MigrateToSelfprivacyChannel,
-)
-from selfprivacy_api.migrations.mount_volume import MountVolume
-from selfprivacy_api.migrations.providers import CreateProviderFields
-from selfprivacy_api.migrations.prepare_for_nixos_2211 import (
-    MigrateToSelfprivacyChannelFrom2205,
-)
-from selfprivacy_api.migrations.prepare_for_nixos_2305 import (
-    MigrateToSelfprivacyChannelFrom2211,
-)
+
+from selfprivacy_api.utils import ReadUserData, UserDataFiles
+from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis

 migrations = [
-    FixNixosConfigBranch(),
-    CreateTokensJson(),
-    MigrateToSelfprivacyChannel(),
-    MountVolume(),
-    CheckForFailedBindsMigration(),
-    CreateProviderFields(),
-    MigrateToSelfprivacyChannelFrom2205(),
-    MigrateToSelfprivacyChannelFrom2211(),
+    WriteTokenToRedis(),
 ]
@@ -43,7 +22,7 @@ def run_migrations():
     Go over all migrations. If they are not skipped in userdata file, run them
     if the migration needed.
     """
-    with ReadUserData() as data:
+    with ReadUserData(UserDataFiles.SECRETS) as data:
         if "api" not in data:
             skipped_migrations = []
         elif "skippedMigrations" not in data["api"]:
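For reference, the shape of the `api.skippedMigrations` setting this module reads, sketched as Python dicts (file contents are illustrative):

```python
# Skip the single remaining migration:
skip_one = {"api": {"skippedMigrations": ["write_token_to_redis"]}}
# Or disable the migrations module entirely:
skip_all = {"api": {"skippedMigrations": ["DISABLE_ALL"]}}
```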


@@ -1,48 +0,0 @@
from selfprivacy_api.jobs import JobStatus, Jobs
from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.utils import WriteUserData


class CheckForFailedBindsMigration(Migration):
    """Mount volume."""

    def get_migration_name(self):
        return "check_for_failed_binds_migration"

    def get_migration_description(self):
        return "If binds migration failed, try again."

    def is_migration_needed(self):
        try:
            jobs = Jobs.get_jobs()
            # If there is a job with type_id "migrations.migrate_to_binds" and status is not "FINISHED",
            # then migration is needed and job is deleted
            for job in jobs:
                if (
                    job.type_id == "migrations.migrate_to_binds"
                    and job.status != JobStatus.FINISHED
                ):
                    return True
            return False
        except Exception as error:
            print(error)
            return False

    def migrate(self):
        # Get info about existing volumes
        # Write info about volumes to userdata.json
        try:
            jobs = Jobs.get_jobs()
            for job in jobs:
                if (
                    job.type_id == "migrations.migrate_to_binds"
                    and job.status != JobStatus.FINISHED
                ):
                    Jobs.remove(job)
            with WriteUserData() as userdata:
                userdata["useBinds"] = False
            print("Done")
        except Exception as error:
            print(error)
            print("Error mounting volume")


@@ -1,58 +0,0 @@
from datetime import datetime
import os
import json
from pathlib import Path

from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.utils import TOKENS_FILE, ReadUserData


class CreateTokensJson(Migration):
    def get_migration_name(self):
        return "create_tokens_json"

    def get_migration_description(self):
        return """Selfprivacy API used a single token in userdata.json for authentication.
        This migration creates a new tokens.json file with the old token in it.
        This migration runs if the tokens.json file does not exist.
        Old token is located at ["api"]["token"] in userdata.json.
        tokens.json path is declared in TOKENS_FILE imported from utils.py
        tokens.json must have the following format:
        {
            "tokens": [
                {
                    "token": "token_string",
                    "name": "Master Token",
                    "date": "current date from str(datetime.now())",
                }
            ]
        }
        tokens.json must have 0600 permissions.
        """

    def is_migration_needed(self):
        return not os.path.exists(TOKENS_FILE)

    def migrate(self):
        try:
            print(f"Creating tokens.json file at {TOKENS_FILE}")
            with ReadUserData() as userdata:
                token = userdata["api"]["token"]
            # Touch tokens.json with 0600 permissions
            Path(TOKENS_FILE).touch(mode=0o600)
            # Write token to tokens.json
            structure = {
                "tokens": [
                    {
                        "token": token,
                        "name": "primary_token",
                        "date": str(datetime.now()),
                    }
                ]
            }
            with open(TOKENS_FILE, "w", encoding="utf-8") as tokens:
                json.dump(structure, tokens, indent=4)
            print("Done")
        except Exception as e:
            print(e)
            print("Error creating tokens.json")


@@ -1,57 +0,0 @@
import os
import subprocess

from selfprivacy_api.migrations.migration import Migration


class FixNixosConfigBranch(Migration):
    def get_migration_name(self):
        return "fix_nixos_config_branch"

    def get_migration_description(self):
        return """Mobile SelfPrivacy app introduced a bug in version 0.4.0.
        New servers were initialized with a rolling-testing nixos config branch.
        This was fixed in app version 0.4.2, but existing servers were not updated.
        This migration fixes this by changing the nixos config branch to master.
        """

    def is_migration_needed(self):
        """Check the current branch of /etc/nixos and return True if it is rolling-testing"""
        current_working_directory = os.getcwd()
        try:
            os.chdir("/etc/nixos")
            nixos_config_branch = subprocess.check_output(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"], start_new_session=True
            )
            os.chdir(current_working_directory)
            return nixos_config_branch.decode("utf-8").strip() == "rolling-testing"
        except subprocess.CalledProcessError:
            os.chdir(current_working_directory)
            return False

    def migrate(self):
        """Affected server pulled the config with the --single-branch flag.
        Git config remote.origin.fetch has to be changed, so all branches will be fetched.
        Then, fetch all branches, pull and switch to master branch.
        """
        print("Fixing Nixos config branch")
        current_working_directory = os.getcwd()
        try:
            os.chdir("/etc/nixos")
            subprocess.check_output(
                [
                    "git",
                    "config",
                    "remote.origin.fetch",
                    "+refs/heads/*:refs/remotes/origin/*",
                ]
            )
            subprocess.check_output(["git", "fetch", "--all"])
            subprocess.check_output(["git", "pull"])
            subprocess.check_output(["git", "checkout", "master"])
            os.chdir(current_working_directory)
            print("Done")
        except subprocess.CalledProcessError:
            os.chdir(current_working_directory)
            print("Error")


@@ -1,49 +0,0 @@
import os
import subprocess

from selfprivacy_api.migrations.migration import Migration


class MigrateToSelfprivacyChannel(Migration):
    """Migrate to selfprivacy Nix channel."""

    def get_migration_name(self):
        return "migrate_to_selfprivacy_channel"

    def get_migration_description(self):
        return "Migrate to selfprivacy Nix channel."

    def is_migration_needed(self):
        try:
            output = subprocess.check_output(
                ["nix-channel", "--list"], start_new_session=True
            )
            output = output.decode("utf-8")
            first_line = output.split("\n", maxsplit=1)[0]
            return first_line.startswith("nixos") and (
                first_line.endswith("nixos-21.11") or first_line.endswith("nixos-21.05")
            )
        except subprocess.CalledProcessError:
            return False

    def migrate(self):
        # Change the channel and update them.
        # Also, go to /etc/nixos directory and make a git pull
        current_working_directory = os.getcwd()
        try:
            print("Changing channel")
            os.chdir("/etc/nixos")
            subprocess.check_output(
                [
                    "nix-channel",
                    "--add",
                    "https://channel.selfprivacy.org/nixos-selfpricacy",
                    "nixos",
                ]
            )
            subprocess.check_output(["nix-channel", "--update"])
            subprocess.check_output(["git", "pull"])
            os.chdir(current_working_directory)
        except subprocess.CalledProcessError:
            os.chdir(current_working_directory)
            print("Error")


@@ -1,51 +0,0 @@
import os
import subprocess

from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.utils import ReadUserData, WriteUserData
from selfprivacy_api.utils.block_devices import BlockDevices


class MountVolume(Migration):
    """Mount volume."""

    def get_migration_name(self):
        return "mount_volume"

    def get_migration_description(self):
        return "Mount volume if it is not mounted."

    def is_migration_needed(self):
        try:
            with ReadUserData() as userdata:
                return "volumes" not in userdata
        except Exception as e:
            print(e)
            return False

    def migrate(self):
        # Get info about existing volumes
        # Write info about volumes to userdata.json
        try:
            volumes = BlockDevices().get_block_devices()
            # If there is an unmounted volume sdb,
            # Write it to userdata.json
            is_there_a_volume = False
            for volume in volumes:
                if volume.name == "sdb":
                    is_there_a_volume = True
                    break
            with WriteUserData() as userdata:
                userdata["volumes"] = []
                if is_there_a_volume:
                    userdata["volumes"].append(
                        {
                            "device": "/dev/sdb",
                            "mountPoint": "/volumes/sdb",
                            "fsType": "ext4",
                        }
                    )
            print("Done")
        except Exception as e:
            print(e)
            print("Error mounting volume")


@@ -1,58 +0,0 @@
import os
import subprocess

from selfprivacy_api.migrations.migration import Migration


class MigrateToSelfprivacyChannelFrom2205(Migration):
    """Migrate to selfprivacy Nix channel.
    For some reason NixOS 22.05 servers initialized with the nixos channel instead of selfprivacy.
    This stops us from upgrading to NixOS 22.11
    """

    def get_migration_name(self):
        return "migrate_to_selfprivacy_channel_from_2205"

    def get_migration_description(self):
        return "Migrate to selfprivacy Nix channel from NixOS 22.05."

    def is_migration_needed(self):
        try:
            output = subprocess.check_output(
                ["nix-channel", "--list"], start_new_session=True
            )
            output = output.decode("utf-8")
            first_line = output.split("\n", maxsplit=1)[0]
            return first_line.startswith("nixos") and (
                first_line.endswith("nixos-22.05")
            )
        except subprocess.CalledProcessError:
            return False

    def migrate(self):
        # Change the channel and update them.
        # Also, go to /etc/nixos directory and make a git pull
        current_working_directory = os.getcwd()
        try:
            print("Changing channel")
            os.chdir("/etc/nixos")
            subprocess.check_output(
                [
                    "nix-channel",
                    "--add",
                    "https://channel.selfprivacy.org/nixos-selfpricacy",
                    "nixos",
                ]
            )
            subprocess.check_output(["nix-channel", "--update"])
            nixos_config_branch = subprocess.check_output(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"], start_new_session=True
            )
            if nixos_config_branch.decode("utf-8").strip() == "api-redis":
                print("Also changing nixos-config branch from api-redis to master")
                subprocess.check_output(["git", "checkout", "master"])
            subprocess.check_output(["git", "pull"])
            os.chdir(current_working_directory)
        except subprocess.CalledProcessError:
            os.chdir(current_working_directory)
            print("Error")


@@ -1,58 +0,0 @@
import os
import subprocess

from selfprivacy_api.migrations.migration import Migration


class MigrateToSelfprivacyChannelFrom2211(Migration):
    """Migrate to selfprivacy Nix channel.
    For some reason NixOS 22.11 servers initialized with the nixos channel instead of selfprivacy.
    This stops us from upgrading to NixOS 23.05
    """

    def get_migration_name(self):
        return "migrate_to_selfprivacy_channel_from_2211"

    def get_migration_description(self):
        return "Migrate to selfprivacy Nix channel from NixOS 22.11."

    def is_migration_needed(self):
        try:
            output = subprocess.check_output(
                ["nix-channel", "--list"], start_new_session=True
            )
            output = output.decode("utf-8")
            first_line = output.split("\n", maxsplit=1)[0]
            return first_line.startswith("nixos") and (
                first_line.endswith("nixos-22.11")
            )
        except subprocess.CalledProcessError:
            return False

    def migrate(self):
        # Change the channel and update them.
        # Also, go to /etc/nixos directory and make a git pull
        current_working_directory = os.getcwd()
        try:
            print("Changing channel")
            os.chdir("/etc/nixos")
            subprocess.check_output(
                [
                    "nix-channel",
                    "--add",
                    "https://channel.selfprivacy.org/nixos-selfpricacy",
                    "nixos",
                ]
            )
            subprocess.check_output(["nix-channel", "--update"])
            nixos_config_branch = subprocess.check_output(
                ["git", "rev-parse", "--abbrev-ref", "HEAD"], start_new_session=True
            )
            if nixos_config_branch.decode("utf-8").strip() == "api-redis":
                print("Also changing nixos-config branch from api-redis to master")
                subprocess.check_output(["git", "checkout", "master"])
            subprocess.check_output(["git", "pull"])
            os.chdir(current_working_directory)
        except subprocess.CalledProcessError:
            os.chdir(current_working_directory)
            print("Error")


@@ -1,43 +0,0 @@
from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.utils import ReadUserData, WriteUserData


class CreateProviderFields(Migration):
    """Unhardcode providers"""

    def get_migration_name(self):
        return "create_provider_fields"

    def get_migration_description(self):
        return "Add DNS, backup and server provider fields to enable user to choose between different clouds and to make the deployment adapt to these preferences."

    def is_migration_needed(self):
        try:
            with ReadUserData() as userdata:
                return "dns" not in userdata
        except Exception as e:
            print(e)
            return False

    def migrate(self):
        # Write info about providers to userdata.json
        try:
            with WriteUserData() as userdata:
                userdata["dns"] = {
                    "provider": "CLOUDFLARE",
                    "apiKey": userdata["cloudflare"]["apiKey"],
                }
                userdata["server"] = {
                    "provider": "HETZNER",
                }
                userdata["backup"] = {
                    "provider": "BACKBLAZE",
                    "accountId": userdata["backblaze"]["accountId"],
                    "accountKey": userdata["backblaze"]["accountKey"],
                    "bucket": userdata["backblaze"]["bucket"],
                }
            print("Done")
        except Exception as e:
            print(e)
            print("Error migrating provider fields")


@@ -0,0 +1,63 @@
from datetime import datetime
from typing import Optional

from selfprivacy_api.migrations.migration import Migration
from selfprivacy_api.models.tokens.token import Token

from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
    RedisTokensRepository,
)
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
    AbstractTokensRepository,
)

from selfprivacy_api.utils import ReadUserData, UserDataFiles


class WriteTokenToRedis(Migration):
    """Load Json tokens into Redis"""

    def get_migration_name(self):
        return "write_token_to_redis"

    def get_migration_description(self):
        return "Loads the initial token into redis token storage"

    def is_repo_empty(self, repo: AbstractTokensRepository) -> bool:
        if repo.get_tokens() != []:
            return False
        return True

    def get_token_from_json(self) -> Optional[Token]:
        try:
            with ReadUserData(UserDataFiles.SECRETS) as userdata:
                return Token(
                    token=userdata["api"]["token"],
                    device_name="Initial device",
                    created_at=datetime.now(),
                )
        except Exception as e:
            print(e)
            return None

    def is_migration_needed(self):
        try:
            if self.get_token_from_json() is not None and self.is_repo_empty(
                RedisTokensRepository()
            ):
                return True
        except Exception as e:
            print(e)
            return False

    def migrate(self):
        # Write info about providers to userdata.json
        try:
            token = self.get_token_from_json()
            if token is None:
                print("No token found in secrets.json")
                return
            RedisTokensRepository()._store_token(token)
            print("Done")
        except Exception as e:
            print(e)
            print("Error migrating access tokens from json to redis")
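A minimal sketch of exercising the new migration by hand (normally it is invoked via `run_migrations()`; Redis and secrets.json are assumed to be in place):

```python
from selfprivacy_api.migrations.write_token_to_redis import WriteTokenToRedis

migration = WriteTokenToRedis()
# True only when a token exists in secrets.json and the Redis repo is empty
if migration.is_migration_needed():
    migration.migrate()
```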


@@ -1,8 +1,11 @@
 import datetime

 from pydantic import BaseModel

+from selfprivacy_api.graphql.common_types.backup import BackupReason
+

 class Snapshot(BaseModel):
     id: str
     service_name: str
     created_at: datetime.datetime
+    reason: BackupReason = BackupReason.EXPLICIT
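Sketch of the default: snapshots created before this field existed deserialize as `EXPLICIT` (the model's import path is assumed from the repository layout; field values are illustrative):

```python
from datetime import datetime, timezone

from selfprivacy_api.graphql.common_types.backup import BackupReason
from selfprivacy_api.models.backup.snapshot import Snapshot

snap = Snapshot(
    id="deadbeef",
    service_name="nextcloud",
    created_at=datetime.now(timezone.utc),
)
assert snap.reason == BackupReason.EXPLICIT  # default when `reason` is absent
```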


@@ -1,11 +1,13 @@
 """
 New device key used to obtain access token.
 """
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 import secrets

 from pydantic import BaseModel
 from mnemonic import Mnemonic

+from selfprivacy_api.models.tokens.time import is_past
+

 class NewDeviceKey(BaseModel):
     """
@@ -20,15 +22,15 @@ class NewDeviceKey(BaseModel):
     def is_valid(self) -> bool:
         """
-        Check if the recovery key is valid.
+        Check if key is valid.
         """
-        if self.expires_at < datetime.now():
+        if is_past(self.expires_at):
             return False
         return True

     def as_mnemonic(self) -> str:
         """
-        Get the recovery key as a mnemonic.
+        Get the key as a mnemonic.
         """
         return Mnemonic(language="english").to_mnemonic(bytes.fromhex(self.key))

@@ -37,10 +39,10 @@ class NewDeviceKey(BaseModel):
         """
         Factory to generate a random token.
         """
-        creation_date = datetime.now()
+        creation_date = datetime.now(timezone.utc)
         key = secrets.token_bytes(16).hex()
         return NewDeviceKey(
             key=key,
             created_at=creation_date,
-            expires_at=datetime.now() + timedelta(minutes=10),
+            expires_at=creation_date + timedelta(minutes=10),
         )
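Usage sketch for the now timezone-aware key:

```python
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey

key = NewDeviceKey.generate()   # created_at/expires_at are UTC-aware
assert key.is_valid()           # valid for 10 minutes after creation
phrase = key.as_mnemonic()      # English mnemonic of the random 16-byte key
```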


@@ -3,12 +3,14 @@ Recovery key used to obtain access token.
 Recovery key has a token string, date of creation, optional date of expiration and optional count of uses left.
 """
-from datetime import datetime
+from datetime import datetime, timezone
 import secrets

 from typing import Optional
 from pydantic import BaseModel
 from mnemonic import Mnemonic

+from selfprivacy_api.models.tokens.time import is_past, ensure_timezone
+

 class RecoveryKey(BaseModel):
     """
@@ -26,7 +28,7 @@ class RecoveryKey(BaseModel):
         """
         Check if the recovery key is valid.
         """
-        if self.expires_at is not None and self.expires_at < datetime.now():
+        if self.expires_at is not None and is_past(self.expires_at):
             return False
         if self.uses_left is not None and self.uses_left <= 0:
             return False
@@ -45,8 +47,11 @@ class RecoveryKey(BaseModel):
     ) -> "RecoveryKey":
         """
         Factory to generate a random token.
+        If passed naive time as expiration, assumes utc
         """
-        creation_date = datetime.now()
+        creation_date = datetime.now(timezone.utc)
+        if expiration is not None:
+            expiration = ensure_timezone(expiration)
         key = secrets.token_bytes(24).hex()
         return RecoveryKey(
             key=key,
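Sketch of the naive-expiration behaviour documented above:

```python
from datetime import datetime, timedelta

from selfprivacy_api.models.tokens.recovery_key import RecoveryKey

# A naive expiration is pinned to UTC by ensure_timezone() inside generate()
key = RecoveryKey.generate(
    expiration=datetime.utcnow() + timedelta(days=1),
    uses_left=3,
)
assert key.is_valid()
```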


@@ -0,0 +1,14 @@
from datetime import datetime, timezone


def is_past(dt: datetime) -> bool:
    # we cannot compare a naive now()
    # to dt which might be tz-aware or unaware
    dt = ensure_timezone(dt)
    return dt < datetime.now(timezone.utc)


def ensure_timezone(dt: datetime) -> datetime:
    if dt.tzinfo is None or dt.tzinfo.utcoffset(None) is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
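Usage sketch for the two helpers:

```python
from datetime import datetime, timedelta, timezone

from selfprivacy_api.models.tokens.time import ensure_timezone, is_past

naive = datetime(2020, 1, 1)  # no tzinfo
assert ensure_timezone(naive).tzinfo == timezone.utc
assert is_past(naive)         # naive input is treated as UTC
assert not is_past(datetime.now(timezone.utc) + timedelta(minutes=5))
```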


@@ -1,8 +0,0 @@
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
    AbstractTokensRepository,
)
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
    JsonTokensRepository,
)

repository = JsonTokensRepository()


@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Optional
@@ -86,13 +88,15 @@ class AbstractTokensRepository(ABC):
     def get_recovery_key(self) -> Optional[RecoveryKey]:
         """Get the recovery key"""

-    @abstractmethod
     def create_recovery_key(
         self,
         expiration: Optional[datetime],
         uses_left: Optional[int],
     ) -> RecoveryKey:
         """Create the recovery key"""
+        recovery_key = RecoveryKey.generate(expiration, uses_left)
+        self._store_recovery_key(recovery_key)
+        return recovery_key

     def use_mnemonic_recovery_key(
         self, mnemonic_phrase: str, device_name: str
@@ -123,6 +127,14 @@ class AbstractTokensRepository(ABC):
             return False
         return recovery_key.is_valid()

+    @abstractmethod
+    def _store_recovery_key(self, recovery_key: RecoveryKey) -> None:
+        """Store recovery key directly"""
+
+    @abstractmethod
+    def _delete_recovery_key(self) -> None:
+        """Delete the recovery key"""
+
     def get_new_device_key(self) -> NewDeviceKey:
         """Creates and returns the new device key"""
         new_device_key = NewDeviceKey.generate()
@@ -156,6 +168,26 @@ class AbstractTokensRepository(ABC):

         return new_token

+    def reset(self):
+        for token in self.get_tokens():
+            self.delete_token(token)
+        self.delete_new_device_key()
+        self._delete_recovery_key()
+
+    def clone(self, source: AbstractTokensRepository) -> None:
+        """Clone the state of another repository to this one"""
+        self.reset()
+        for token in source.get_tokens():
+            self._store_token(token)
+
+        recovery_key = source.get_recovery_key()
+        if recovery_key is not None:
+            self._store_recovery_key(recovery_key)
+
+        new_device_key = source._get_stored_new_device_key()
+        if new_device_key is not None:
+            self._store_new_device_key(new_device_key)
+
     @abstractmethod
     def _store_token(self, new_token: Token):
         """Store a token directly"""


@@ -1,133 +0,0 @@
"""
temporary legacy
"""
from typing import Optional
from datetime import datetime

from selfprivacy_api.utils import UserDataFiles, WriteUserData, ReadUserData
from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.models.tokens.recovery_key import RecoveryKey
from selfprivacy_api.models.tokens.new_device_key import NewDeviceKey
from selfprivacy_api.repositories.tokens.exceptions import (
    TokenNotFound,
)
from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
    AbstractTokensRepository,
)

DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S.%f"


class JsonTokensRepository(AbstractTokensRepository):
    def get_tokens(self) -> list[Token]:
        """Get the tokens"""
        tokens_list = []

        with ReadUserData(UserDataFiles.TOKENS) as tokens_file:
            for userdata_token in tokens_file["tokens"]:
                tokens_list.append(
                    Token(
                        token=userdata_token["token"],
                        device_name=userdata_token["name"],
                        created_at=userdata_token["date"],
                    )
                )

        return tokens_list

    def _store_token(self, new_token: Token):
        """Store a token directly"""
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            tokens_file["tokens"].append(
                {
                    "token": new_token.token,
                    "name": new_token.device_name,
                    "date": new_token.created_at.strftime(DATETIME_FORMAT),
                }
            )

    def delete_token(self, input_token: Token) -> None:
        """Delete the token"""
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            for userdata_token in tokens_file["tokens"]:
                if userdata_token["token"] == input_token.token:
                    tokens_file["tokens"].remove(userdata_token)
                    return

        raise TokenNotFound("Token not found!")

    def get_recovery_key(self) -> Optional[RecoveryKey]:
        """Get the recovery key"""
        with ReadUserData(UserDataFiles.TOKENS) as tokens_file:
            if (
                "recovery_token" not in tokens_file
                or tokens_file["recovery_token"] is None
            ):
                return

            recovery_key = RecoveryKey(
                key=tokens_file["recovery_token"].get("token"),
                created_at=tokens_file["recovery_token"].get("date"),
                expires_at=tokens_file["recovery_token"].get("expiration"),
                uses_left=tokens_file["recovery_token"].get("uses_left"),
            )

            return recovery_key

    def create_recovery_key(
        self,
        expiration: Optional[datetime],
        uses_left: Optional[int],
    ) -> RecoveryKey:
        """Create the recovery key"""
        recovery_key = RecoveryKey.generate(expiration, uses_left)

        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            key_expiration: Optional[str] = None
            if recovery_key.expires_at is not None:
                key_expiration = recovery_key.expires_at.strftime(DATETIME_FORMAT)
            tokens_file["recovery_token"] = {
                "token": recovery_key.key,
                "date": recovery_key.created_at.strftime(DATETIME_FORMAT),
                "expiration": key_expiration,
                "uses_left": recovery_key.uses_left,
            }

        return recovery_key

    def _decrement_recovery_token(self):
        """Decrement recovery key use count by one"""
        if self.is_recovery_key_valid():
            with WriteUserData(UserDataFiles.TOKENS) as tokens:
                if tokens["recovery_token"]["uses_left"] is not None:
                    tokens["recovery_token"]["uses_left"] -= 1

    def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            tokens_file["new_device"] = {
                "token": new_device_key.key,
                "date": new_device_key.created_at.strftime(DATETIME_FORMAT),
                "expiration": new_device_key.expires_at.strftime(DATETIME_FORMAT),
            }

    def delete_new_device_key(self) -> None:
        """Delete the new device key"""
        with WriteUserData(UserDataFiles.TOKENS) as tokens_file:
            if "new_device" in tokens_file:
                del tokens_file["new_device"]
                return

    def _get_stored_new_device_key(self) -> Optional[NewDeviceKey]:
        """Retrieves new device key that is already stored."""
        with ReadUserData(UserDataFiles.TOKENS) as tokens_file:
            if "new_device" not in tokens_file or tokens_file["new_device"] is None:
                return

            new_device_key = NewDeviceKey(
                key=tokens_file["new_device"]["token"],
                created_at=tokens_file["new_device"]["date"],
                expires_at=tokens_file["new_device"]["expiration"],
            )

            return new_device_key


@@ -4,6 +4,7 @@ Token repository using Redis as backend.
 from typing import Any, Optional
 from datetime import datetime
 from hashlib import md5
+from datetime import timezone

 from selfprivacy_api.repositories.tokens.abstract_tokens_repository import (
     AbstractTokensRepository,
@@ -53,6 +54,7 @@ class RedisTokensRepository(AbstractTokensRepository):
             token = self._token_from_hash(key)
             if token == input_token:
                 return key
+        return None

     def delete_token(self, input_token: Token) -> None:
         """Delete the token"""
@@ -62,13 +64,6 @@ class RedisTokensRepository(AbstractTokensRepository):
             raise TokenNotFound
         redis.delete(key)

-    def reset(self):
-        for token in self.get_tokens():
-            self.delete_token(token)
-        self.delete_new_device_key()
-        redis = self.connection
-        redis.delete(RECOVERY_KEY_REDIS_KEY)
-
     def get_recovery_key(self) -> Optional[RecoveryKey]:
         """Get the recovery key"""
         redis = self.connection
@@ -76,15 +71,13 @@ class RedisTokensRepository(AbstractTokensRepository):
             return self._recovery_key_from_hash(RECOVERY_KEY_REDIS_KEY)
         return None

-    def create_recovery_key(
-        self,
-        expiration: Optional[datetime],
-        uses_left: Optional[int],
-    ) -> RecoveryKey:
-        """Create the recovery key"""
-        recovery_key = RecoveryKey.generate(expiration=expiration, uses_left=uses_left)
+    def _store_recovery_key(self, recovery_key: RecoveryKey) -> None:
         self._store_model_as_hash(RECOVERY_KEY_REDIS_KEY, recovery_key)
-        return recovery_key
+
+    def _delete_recovery_key(self) -> None:
+        """Delete the recovery key"""
+        redis = self.connection
+        redis.delete(RECOVERY_KEY_REDIS_KEY)

     def _store_new_device_key(self, new_device_key: NewDeviceKey) -> None:
         """Store new device key directly"""
@@ -157,6 +150,7 @@ class RedisTokensRepository(AbstractTokensRepository):
         if token is not None:
             token.created_at = token.created_at.replace(tzinfo=None)
             return token
+        return None

     def _recovery_key_from_hash(self, redis_key: str) -> Optional[RecoveryKey]:
         return self._hash_as_model(redis_key, RecoveryKey)
@@ -168,5 +162,7 @@ class RedisTokensRepository(AbstractTokensRepository):
         redis = self.connection
         for key, value in model.dict().items():
             if isinstance(value, datetime):
+                if value.tzinfo is None:
+                    value = value.replace(tzinfo=timezone.utc)
                 value = value.isoformat()
             redis.hset(redis_key, key, str(value))
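The same normalization in isolation: naive datetimes are pinned to UTC before serialization, so stored values round-trip with an explicit offset:

```python
from datetime import datetime, timezone

value = datetime(2024, 1, 1)                    # naive
if value.tzinfo is None:
    value = value.replace(tzinfo=timezone.utc)  # same guard as above
print(value.isoformat())                        # 2024-01-01T00:00:00+00:00
```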


@@ -1,125 +0,0 @@
from datetime import datetime
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from selfprivacy_api.actions.api_tokens import (
    CannotDeleteCallerException,
    InvalidExpirationDate,
    InvalidUsesLeft,
    NotFoundException,
    delete_api_token,
    refresh_api_token,
    get_api_recovery_token_status,
    get_api_tokens_with_caller_flag,
    get_new_api_recovery_key,
    use_mnemonic_recovery_token,
    delete_new_device_auth_token,
    get_new_device_auth_token,
    use_new_device_auth_token,
)

from selfprivacy_api.dependencies import TokenHeader, get_token_header

router = APIRouter(
    prefix="/auth",
    tags=["auth"],
    responses={404: {"description": "Not found"}},
)


@router.get("/tokens")
async def rest_get_tokens(auth_token: TokenHeader = Depends(get_token_header)):
    """Get the tokens info"""
    return get_api_tokens_with_caller_flag(auth_token.token)


class DeleteTokenInput(BaseModel):
    """Delete token input"""

    token_name: str


@router.delete("/tokens")
async def rest_delete_tokens(
    token: DeleteTokenInput, auth_token: TokenHeader = Depends(get_token_header)
):
    """Delete the tokens"""
    try:
        delete_api_token(auth_token.token, token.token_name)
    except NotFoundException:
        raise HTTPException(status_code=404, detail="Token not found")
    except CannotDeleteCallerException:
        raise HTTPException(status_code=400, detail="Cannot delete caller's token")
    return {"message": "Token deleted"}


@router.post("/tokens")
async def rest_refresh_token(auth_token: TokenHeader = Depends(get_token_header)):
    """Refresh the token"""
    try:
        new_token = refresh_api_token(auth_token.token)
    except NotFoundException:
        raise HTTPException(status_code=404, detail="Token not found")
    return {"token": new_token}


@router.get("/recovery_token")
async def rest_get_recovery_token_status(
    auth_token: TokenHeader = Depends(get_token_header),
):
    return get_api_recovery_token_status()


class CreateRecoveryTokenInput(BaseModel):
    expiration: Optional[datetime] = None
    uses: Optional[int] = None


@router.post("/recovery_token")
async def rest_create_recovery_token(
    limits: CreateRecoveryTokenInput = CreateRecoveryTokenInput(),
    auth_token: TokenHeader = Depends(get_token_header),
):
    try:
        token = get_new_api_recovery_key(limits.expiration, limits.uses)
    except InvalidExpirationDate as e:
        raise HTTPException(status_code=400, detail=str(e))
    except InvalidUsesLeft as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"token": token}


class UseTokenInput(BaseModel):
    token: str
    device: str


@router.post("/recovery_token/use")
async def rest_use_recovery_token(input: UseTokenInput):
    token = use_mnemonic_recovery_token(input.token, input.device)
    if token is None:
        raise HTTPException(status_code=404, detail="Token not found")
    return {"token": token}


@router.post("/new_device")
async def rest_new_device(auth_token: TokenHeader = Depends(get_token_header)):
    token = get_new_device_auth_token()
    return {"token": token}


@router.delete("/new_device")
async def rest_delete_new_device_token(
    auth_token: TokenHeader = Depends(get_token_header),
):
    delete_new_device_auth_token()
    return {"token": None}


@router.post("/new_device/authorize")
async def rest_new_device_authorize(input: UseTokenInput):
    token = use_new_device_auth_token(input.token, input.device)
    if token is None:
        raise HTTPException(status_code=404, detail="Token not found")
    return {"message": "Device authorized", "token": token}


@@ -1,336 +0,0 @@
"""Basic services legacy api"""
import base64
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel

from selfprivacy_api.actions.ssh import (
    InvalidPublicKey,
    KeyAlreadyExists,
    KeyNotFound,
    create_ssh_key,
    enable_ssh,
    get_ssh_settings,
    remove_ssh_key,
    set_ssh_settings,
)
from selfprivacy_api.actions.users import UserNotFound, get_user_by_username

from selfprivacy_api.dependencies import get_token_header

from selfprivacy_api.services.bitwarden import Bitwarden
from selfprivacy_api.services.gitea import Gitea
from selfprivacy_api.services.mailserver import MailServer
from selfprivacy_api.services.nextcloud import Nextcloud
from selfprivacy_api.services.ocserv import Ocserv
from selfprivacy_api.services.pleroma import Pleroma
from selfprivacy_api.services.service import ServiceStatus
from selfprivacy_api.utils import get_dkim_key, get_domain

router = APIRouter(
    prefix="/services",
    tags=["services"],
    dependencies=[Depends(get_token_header)],
    responses={404: {"description": "Not found"}},
)


def service_status_to_return_code(status: ServiceStatus):
    """Converts service status object to return code for
    compatibility with legacy api"""
    if status == ServiceStatus.ACTIVE:
        return 0
    elif status == ServiceStatus.FAILED:
        return 1
    elif status == ServiceStatus.INACTIVE:
        return 3
    elif status == ServiceStatus.OFF:
        return 4
    else:
        return 2


@router.get("/status")
async def get_status():
    """Get the status of the services"""
    mail_status = MailServer.get_status()
    bitwarden_status = Bitwarden.get_status()
    gitea_status = Gitea.get_status()
    nextcloud_status = Nextcloud.get_status()
    ocserv_stauts = Ocserv.get_status()
    pleroma_status = Pleroma.get_status()
    return {
        "imap": service_status_to_return_code(mail_status),
        "smtp": service_status_to_return_code(mail_status),
        "http": 0,
        "bitwarden": service_status_to_return_code(bitwarden_status),
        "gitea": service_status_to_return_code(gitea_status),
        "nextcloud": service_status_to_return_code(nextcloud_status),
        "ocserv": service_status_to_return_code(ocserv_stauts),
        "pleroma": service_status_to_return_code(pleroma_status),
    }


@router.post("/bitwarden/enable")
async def enable_bitwarden():
    """Enable Bitwarden"""
    Bitwarden.enable()
    return {
        "status": 0,
        "message": "Bitwarden enabled",
    }


@router.post("/bitwarden/disable")
async def disable_bitwarden():
    """Disable Bitwarden"""
    Bitwarden.disable()
    return {
        "status": 0,
        "message": "Bitwarden disabled",
    }


@router.post("/gitea/enable")
async def enable_gitea():
    """Enable Gitea"""
    Gitea.enable()
    return {
        "status": 0,
        "message": "Gitea enabled",
    }


@router.post("/gitea/disable")
async def disable_gitea():
    """Disable Gitea"""
    Gitea.disable()
    return {
        "status": 0,
        "message": "Gitea disabled",
    }


@router.get("/mailserver/dkim")
async def get_mailserver_dkim():
    """Get the DKIM record for the mailserver"""
    domain = get_domain()
    dkim = get_dkim_key(domain, parse=False)
    if dkim is None:
        raise HTTPException(status_code=404, detail="DKIM record not found")
    dkim = base64.b64encode(dkim.encode("utf-8")).decode("utf-8")
    return dkim


@router.post("/nextcloud/enable")
async def enable_nextcloud():
    """Enable Nextcloud"""
    Nextcloud.enable()
    return {
        "status": 0,
        "message": "Nextcloud enabled",
    }


@router.post("/nextcloud/disable")
async def disable_nextcloud():
    """Disable Nextcloud"""
    Nextcloud.disable()
    return {
        "status": 0,
        "message": "Nextcloud disabled",
    }


@router.post("/ocserv/enable")
async def enable_ocserv():
    """Enable Ocserv"""
    Ocserv.enable()
    return {
        "status": 0,
        "message": "Ocserv enabled",
    }


@router.post("/ocserv/disable")
async def disable_ocserv():
    """Disable Ocserv"""
    Ocserv.disable()
    return {
        "status": 0,
        "message": "Ocserv disabled",
    }


@router.post("/pleroma/enable")
async def enable_pleroma():
    """Enable Pleroma"""
    Pleroma.enable()
    return {
        "status": 0,
        "message": "Pleroma enabled",
    }


@router.post("/pleroma/disable")
async def disable_pleroma():
    """Disable Pleroma"""
    Pleroma.disable()
    return {
        "status": 0,
        "message": "Pleroma disabled",
    }


@router.get("/restic/backup/list")
async def get_restic_backup_list():
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


@router.put("/restic/backup/create")
async def create_restic_backup():
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


@router.get("/restic/backup/status")
async def get_restic_backup_status():
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


@router.get("/restic/backup/reload")
async def reload_restic_backup():
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


class BackupRestoreInput(BaseModel):
    backupId: str


@router.put("/restic/backup/restore")
async def restore_restic_backup(backup: BackupRestoreInput):
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


class BackupConfigInput(BaseModel):
    accountId: str
    accountKey: str
    bucket: str


@router.put("/restic/backblaze/config")
async def set_backblaze_config(backup_config: BackupConfigInput):
    raise HTTPException(
        status_code=410,
        detail="This endpoint is deprecated, please use GraphQL API",
    )


@router.post("/ssh/enable")
async def rest_enable_ssh():
    """Enable SSH"""
    enable_ssh()
    return {
        "status": 0,
        "message": "SSH enabled",
    }


@router.get("/ssh")
async def rest_get_ssh():
    """Get the SSH configuration"""
    settings = get_ssh_settings()
    return {
        "enable": settings.enable,
        "passwordAuthentication": settings.passwordAuthentication,
    }


class SshConfigInput(BaseModel):
    enable: Optional[bool] = None
    passwordAuthentication: Optional[bool] = None


@router.put("/ssh")
async def rest_set_ssh(ssh_config: SshConfigInput):
    """Set the SSH configuration"""
    set_ssh_settings(ssh_config.enable, ssh_config.passwordAuthentication)
    return "SSH settings changed"


class SshKeyInput(BaseModel):
    public_key: str


@router.put("/ssh/key/send", status_code=201)
async def rest_send_ssh_key(input: SshKeyInput):
    """Send the SSH key"""
    try:
        create_ssh_key("root", input.public_key)
    except KeyAlreadyExists as error:
        raise HTTPException(status_code=409, detail="Key already exists") from error
    except InvalidPublicKey as error:
        raise HTTPException(
            status_code=400,
            detail="Invalid key type. Only ssh-ed25519 and ssh-rsa are supported",
        ) from error
    return {
        "status": 0,
        "message": "SSH key sent",
    }


@router.get("/ssh/keys/{username}")
async def rest_get_ssh_keys(username: str):
    """Get the SSH keys for a user"""
    user = get_user_by_username(username)
    if user is None:
        raise HTTPException(status_code=404, detail="User not found")
    return user.ssh_keys


@router.post("/ssh/keys/{username}", status_code=201)
async def rest_add_ssh_key(username: str, input: SshKeyInput):
    try:
        create_ssh_key(username, input.public_key)
    except KeyAlreadyExists as error:
        raise HTTPException(status_code=409, detail="Key already exists") from error
    except InvalidPublicKey as error:
        raise HTTPException(
            status_code=400,
            detail="Invalid key type. Only ssh-ed25519 and ssh-rsa are supported",
        ) from error
    except UserNotFound as error:
        raise HTTPException(status_code=404, detail="User not found") from error
    return {
        "message": "New SSH key successfully written",
    }


@router.delete("/ssh/keys/{username}")
async def rest_delete_ssh_key(username: str, input: SshKeyInput):
    try:
        remove_ssh_key(username, input.public_key)
    except KeyNotFound as error:
        raise HTTPException(status_code=404, detail="Key not found") from error
    except UserNotFound as error:
        raise HTTPException(status_code=404, detail="User not found") from error
    return {"message": "SSH key deleted"}


@@ -1,105 +0,0 @@
from typing import Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel

from selfprivacy_api.dependencies import get_token_header

import selfprivacy_api.actions.system as system_actions

router = APIRouter(
    prefix="/system",
    tags=["system"],
    dependencies=[Depends(get_token_header)],
    responses={404: {"description": "Not found"}},
)


@router.get("/configuration/timezone")
async def get_timezone():
    """Get the timezone of the server"""
    return system_actions.get_timezone()


class ChangeTimezoneRequestBody(BaseModel):
    """Change the timezone of the server"""

    timezone: str


@router.put("/configuration/timezone")
async def change_timezone(timezone: ChangeTimezoneRequestBody):
    """Change the timezone of the server"""
    try:
        system_actions.change_timezone(timezone.timezone)
    except system_actions.InvalidTimezone as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"timezone": timezone.timezone}


@router.get("/configuration/autoUpgrade")
async def get_auto_upgrade_settings():
    """Get the auto-upgrade settings"""
    return system_actions.get_auto_upgrade_settings().dict()


class AutoUpgradeSettings(BaseModel):
    """Settings for auto-upgrading user data"""

    enable: Optional[bool] = None
    allowReboot: Optional[bool] = None


@router.put("/configuration/autoUpgrade")
async def set_auto_upgrade_settings(settings: AutoUpgradeSettings):
    """Set the auto-upgrade settings"""
    system_actions.set_auto_upgrade_settings(settings.enable, settings.allowReboot)
    return "Auto-upgrade settings changed"


@router.get("/configuration/apply")
async def apply_configuration():
    """Apply the configuration"""
    return_code = system_actions.rebuild_system()
    return return_code


@router.get("/configuration/rollback")
async def rollback_configuration():
    """Rollback the configuration"""
    return_code = system_actions.rollback_system()
    return return_code


@router.get("/configuration/upgrade")
async def upgrade_configuration():
    """Upgrade the configuration"""
    return_code = system_actions.upgrade_system()
    return return_code


@router.get("/reboot")
async def reboot_system():
    """Reboot the system"""
    system_actions.reboot_system()
    return "System reboot has started"


@router.get("/version")
async def get_system_version():
    """Get the system version"""
    return {"system_version": system_actions.get_system_version()}


@router.get("/pythonVersion")
async def get_python_version():
    """Get the Python version"""
    return system_actions.get_python_version()


@router.get("/configuration/pull")
async def pull_configuration():
    """Pull the configuration"""
    action_result = system_actions.pull_repository_changes()
    if action_result.status == 0:
        return action_result.dict()
    raise HTTPException(status_code=500, detail=action_result.dict())


@@ -1,62 +0,0 @@
"""Users management module"""
from typing import Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel

import selfprivacy_api.actions.users as users_actions

from selfprivacy_api.dependencies import get_token_header

router = APIRouter(
    prefix="/users",
    tags=["users"],
    dependencies=[Depends(get_token_header)],
    responses={404: {"description": "Not found"}},
)


@router.get("")
async def get_users(withMainUser: bool = False):
    """Get the list of users"""
    users: list[users_actions.UserDataUser] = users_actions.get_users(
        exclude_primary=not withMainUser, exclude_root=True
    )

    return [user.username for user in users]


class UserInput(BaseModel):
    """User input"""

    username: str
    password: str


@router.post("", status_code=201)
async def create_user(user: UserInput):
    try:
        users_actions.create_user(user.username, user.password)
    except users_actions.PasswordIsEmpty as e:
        raise HTTPException(status_code=400, detail=str(e))
    except users_actions.UsernameForbidden as e:
        raise HTTPException(status_code=409, detail=str(e))
    except users_actions.UsernameNotAlphanumeric as e:
        raise HTTPException(status_code=400, detail=str(e))
    except users_actions.UsernameTooLong as e:
        raise HTTPException(status_code=400, detail=str(e))
    except users_actions.UserAlreadyExists as e:
        raise HTTPException(status_code=409, detail=str(e))

    return {"result": 0, "username": user.username}


@router.delete("/{username}")
async def delete_user(username: str):
    try:
        users_actions.delete_user(username)
    except users_actions.UserNotFound as e:
        raise HTTPException(status_code=404, detail=str(e))
    except users_actions.UserIsProtected as e:
        raise HTTPException(status_code=400, detail=str(e))

    return {"result": 0, "username": username}


@@ -3,7 +3,7 @@
 import typing
 from selfprivacy_api.services.bitwarden import Bitwarden
 from selfprivacy_api.services.gitea import Gitea
-from selfprivacy_api.services.jitsi import Jitsi
+from selfprivacy_api.services.jitsimeet import JitsiMeet
 from selfprivacy_api.services.mailserver import MailServer
 from selfprivacy_api.services.nextcloud import Nextcloud
 from selfprivacy_api.services.pleroma import Pleroma
@@ -18,7 +18,7 @@ services: list[Service] = [
     Nextcloud(),
     Pleroma(),
     Ocserv(),
-    Jitsi(),
+    JitsiMeet(),
 ]
@@ -54,12 +54,14 @@ def get_all_required_dns_records() -> list[ServiceDnsRecord]:
             name="api",
             content=ip4,
             ttl=3600,
+            display_name="SelfPrivacy API",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="api",
             content=ip6,
             ttl=3600,
+            display_name="SelfPrivacy API (IPv6)",
         ),
     ]
     for service in get_enabled_services():
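Illustrative construction of a record with the new field (values are placeholders; the fields are the ones used above):

```python
from selfprivacy_api.services.service import ServiceDnsRecord

record = ServiceDnsRecord(
    type="A",
    name="api",
    content="203.0.113.1",
    ttl=3600,
    priority=None,
    display_name="SelfPrivacy API",  # new human-readable label
)
```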


@@ -58,11 +58,6 @@ class Bitwarden(Service):
     def get_backup_description() -> str:
         return "Password database, encryption certificate and attachments."

-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("bitwarden", {}).get("enable", False)
-
     @staticmethod
     def get_status() -> ServiceStatus:
         """
@@ -76,22 +71,6 @@ class Bitwarden(Service):
         """
         return get_service_status("vaultwarden.service")

-    @staticmethod
-    def enable():
-        """Enable Bitwarden service."""
-        with WriteUserData() as user_data:
-            if "bitwarden" not in user_data:
-                user_data["bitwarden"] = {}
-            user_data["bitwarden"]["enable"] = True
-
-    @staticmethod
-    def disable():
-        """Disable Bitwarden service."""
-        with WriteUserData() as user_data:
-            if "bitwarden" not in user_data:
-                user_data["bitwarden"] = {}
-            user_data["bitwarden"]["enable"] = False
-
     @staticmethod
     def stop():
         subprocess.run(["systemctl", "stop", "vaultwarden.service"])
@@ -129,12 +108,14 @@ class Bitwarden(Service):
             name="password",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="Bitwarden",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="password",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="Bitwarden (IPv6)",
         ),
     ]


@@ -244,9 +244,11 @@ def move_service(
         progress=95,
     )
     with WriteUserData() as user_data:
-        if userdata_location not in user_data:
-            user_data[userdata_location] = {}
-        user_data[userdata_location]["location"] = volume.name
+        if "modules" not in user_data:
+            user_data["modules"] = {}
+        if userdata_location not in user_data["modules"]:
+            user_data["modules"][userdata_location] = {}
+        user_data["modules"][userdata_location]["location"] = volume.name
     # Start service
     service.start()
     Jobs.update(

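Service settings now live under a shared `modules` key rather than at the top level of userdata. A small sketch of the resulting shape; `setdefault` is the idiomatic equivalent of the explicit `if` checks in the diff, and the names are illustrative:

```python
# Before: {"bitwarden": {"location": "sdb"}}
# After:  {"modules": {"bitwarden": {"location": "sdb"}}}
user_data: dict = {}
userdata_location = "bitwarden"  # illustrative service key
volume_name = "sdb"              # illustrative volume

user_data.setdefault("modules", {}).setdefault(userdata_location, {})[
    "location"
] = volume_name
assert user_data == {"modules": {"bitwarden": {"location": "sdb"}}}
```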

@@ -54,11 +54,6 @@ class Gitea(Service):
     def get_backup_description() -> str:
         return "Git repositories, database and user data."
-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("gitea", {}).get("enable", False)
     @staticmethod
     def get_status() -> ServiceStatus:
         """
@@ -71,22 +66,6 @@ class Gitea(Service):
         """
         return get_service_status("gitea.service")
-    @staticmethod
-    def enable():
-        """Enable Gitea service."""
-        with WriteUserData() as user_data:
-            if "gitea" not in user_data:
-                user_data["gitea"] = {}
-            user_data["gitea"]["enable"] = True
-    @staticmethod
-    def disable():
-        """Disable Gitea service."""
-        with WriteUserData() as user_data:
-            if "gitea" not in user_data:
-                user_data["gitea"] = {}
-            user_data["gitea"]["enable"] = False
     @staticmethod
     def stop():
         subprocess.run(["systemctl", "stop", "gitea.service"])
@@ -123,12 +102,14 @@ class Gitea(Service):
             name="git",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="Gitea",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="git",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="Gitea (IPv6)",
         ),
     ]


@@ -1,4 +1,4 @@
-"""Class representing Jitsi service"""
+"""Class representing Jitsi Meet service"""
 import base64
 import subprocess
 import typing
@@ -11,26 +11,26 @@ from selfprivacy_api.services.service import Service, ServiceDnsRecord, ServiceS
 from selfprivacy_api.utils import ReadUserData, WriteUserData, get_domain
 from selfprivacy_api.utils.block_devices import BlockDevice
 import selfprivacy_api.utils.network as network_utils
-from selfprivacy_api.services.jitsi.icon import JITSI_ICON
+from selfprivacy_api.services.jitsimeet.icon import JITSI_ICON
-class Jitsi(Service):
+class JitsiMeet(Service):
     """Class representing Jitsi service"""
     @staticmethod
     def get_id() -> str:
         """Return service id."""
-        return "jitsi"
+        return "jitsi-meet"
     @staticmethod
     def get_display_name() -> str:
         """Return service display name."""
-        return "Jitsi"
+        return "JitsiMeet"
     @staticmethod
     def get_description() -> str:
         """Return service description."""
-        return "Jitsi is a free and open-source video conferencing solution."
+        return "Jitsi Meet is a free and open-source video conferencing solution."
     @staticmethod
     def get_svg_icon() -> str:
@@ -55,33 +55,12 @@ class Jitsi(Service):
     def get_backup_description() -> str:
         return "Secrets that are used to encrypt the communication."
-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("jitsi", {}).get("enable", False)
     @staticmethod
     def get_status() -> ServiceStatus:
         return get_service_status_from_several_units(
             ["jitsi-videobridge.service", "jicofo.service"]
         )
-    @staticmethod
-    def enable():
-        """Enable Jitsi service."""
-        with WriteUserData() as user_data:
-            if "jitsi" not in user_data:
-                user_data["jitsi"] = {}
-            user_data["jitsi"]["enable"] = True
-    @staticmethod
-    def disable():
-        """Disable Gitea service."""
-        with WriteUserData() as user_data:
-            if "jitsi" not in user_data:
-                user_data["jitsi"] = {}
-            user_data["jitsi"]["enable"] = False
     @staticmethod
     def stop():
         subprocess.run(
@@ -132,14 +111,16 @@ class Jitsi(Service):
             name="meet",
             content=ip4,
             ttl=3600,
+            display_name="Jitsi",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="meet",
             content=ip6,
             ttl=3600,
+            display_name="Jitsi (IPv6)",
         ),
     ]
     def move_to_volume(self, volume: BlockDevice) -> Job:
-        raise NotImplementedError("jitsi service is not movable")
+        raise NotImplementedError("jitsi-meet service is not movable")


@@ -21,7 +21,7 @@ class MailServer(Service):
     @staticmethod
     def get_id() -> str:
-        return "email"
+        return "simple-nixos-mailserver"
     @staticmethod
     def get_display_name() -> str:
@@ -121,27 +121,43 @@ class MailServer(Service):
             name=domain,
             content=ip4,
             ttl=3600,
+            display_name="Root Domain",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name=domain,
             content=ip6,
             ttl=3600,
+            display_name="Root Domain (IPv6)",
         ),
         ServiceDnsRecord(
-            type="MX", name=domain, content=domain, ttl=3600, priority=10
+            type="MX",
+            name=domain,
+            content=domain,
+            ttl=3600,
+            priority=10,
+            display_name="Mail server record",
         ),
         ServiceDnsRecord(
-            type="TXT", name="_dmarc", content="v=DMARC1; p=none", ttl=18000
+            type="TXT",
+            name="_dmarc",
+            content="v=DMARC1; p=none",
+            ttl=18000,
+            display_name="DMARC record",
         ),
         ServiceDnsRecord(
             type="TXT",
             name=domain,
             content=f"v=spf1 a mx ip4:{ip4} -all",
             ttl=18000,
+            display_name="SPF record",
         ),
         ServiceDnsRecord(
-            type="TXT", name="selector._domainkey", content=dkim_record, ttl=18000
+            type="TXT",
+            name="selector._domainkey",
+            content=dkim_record,
+            ttl=18000,
+            display_name="DKIM key",
         ),
     ]
@@ -157,7 +173,7 @@ class MailServer(Service):
             volume,
             job,
             FolderMoveNames.default_foldermoves(self),
-            "email",
+            "simple-nixos-mailserver",
         )
         return job

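For a placeholder domain and IPv4 address, the newly labelled mail records carry content along these lines (a sketch; only the strings are shown, not the full record objects):

```python
# Illustrative content strings for the labelled records above.
domain = "example.org"   # placeholder
ip4 = "203.0.113.1"      # placeholder

mx_content = domain                              # "Mail server record", priority 10
dmarc_content = "v=DMARC1; p=none"               # "DMARC record"
spf_content = f"v=spf1 a mx ip4:{ip4} -all"      # "SPF record"
assert spf_content == "v=spf1 a mx ip4:203.0.113.1 -all"
```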

@@ -53,11 +53,6 @@ class Nextcloud(Service):
     def get_backup_description() -> str:
         return "All the files and other data stored in Nextcloud."
-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("nextcloud", {}).get("enable", False)
     @staticmethod
     def get_status() -> ServiceStatus:
         """
@@ -71,22 +66,6 @@ class Nextcloud(Service):
         """
         return get_service_status("phpfpm-nextcloud.service")
-    @staticmethod
-    def enable():
-        """Enable Nextcloud service."""
-        with WriteUserData() as user_data:
-            if "nextcloud" not in user_data:
-                user_data["nextcloud"] = {}
-            user_data["nextcloud"]["enable"] = True
-    @staticmethod
-    def disable():
-        """Disable Nextcloud service."""
-        with WriteUserData() as user_data:
-            if "nextcloud" not in user_data:
-                user_data["nextcloud"] = {}
-            user_data["nextcloud"]["enable"] = False
     @staticmethod
     def stop():
         """Stop Nextcloud service."""
@@ -128,12 +107,14 @@ class Nextcloud(Service):
             name="cloud",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="Nextcloud",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="cloud",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="Nextcloud (IPv6)",
         ),
     ]


@@ -51,29 +51,10 @@ class Ocserv(Service):
     def get_backup_description() -> str:
         return "Nothing to backup."
-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("ocserv", {}).get("enable", False)
     @staticmethod
     def get_status() -> ServiceStatus:
         return get_service_status("ocserv.service")
-    @staticmethod
-    def enable():
-        with WriteUserData() as user_data:
-            if "ocserv" not in user_data:
-                user_data["ocserv"] = {}
-            user_data["ocserv"]["enable"] = True
-    @staticmethod
-    def disable():
-        with WriteUserData() as user_data:
-            if "ocserv" not in user_data:
-                user_data["ocserv"] = {}
-            user_data["ocserv"]["enable"] = False
     @staticmethod
     def stop():
         subprocess.run(["systemctl", "stop", "ocserv.service"], check=False)
@@ -106,12 +87,14 @@ class Ocserv(Service):
             name="vpn",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="OpenConnect VPN",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="vpn",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="OpenConnect VPN (IPv6)",
         ),
     ]


@@ -50,29 +50,10 @@ class Pleroma(Service):
     def get_backup_description() -> str:
         return "Your Pleroma accounts, posts and media."
-    @staticmethod
-    def is_enabled() -> bool:
-        with ReadUserData() as user_data:
-            return user_data.get("pleroma", {}).get("enable", False)
     @staticmethod
     def get_status() -> ServiceStatus:
         return get_service_status("pleroma.service")
-    @staticmethod
-    def enable():
-        with WriteUserData() as user_data:
-            if "pleroma" not in user_data:
-                user_data["pleroma"] = {}
-            user_data["pleroma"]["enable"] = True
-    @staticmethod
-    def disable():
-        with WriteUserData() as user_data:
-            if "pleroma" not in user_data:
-                user_data["pleroma"] = {}
-            user_data["pleroma"]["enable"] = False
     @staticmethod
     def stop():
         subprocess.run(["systemctl", "stop", "pleroma.service"])
@@ -127,12 +108,14 @@ class Pleroma(Service):
             name="social",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="Pleroma",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="social",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="Pleroma (IPv6)",
         ),
     ]


@@ -12,6 +12,7 @@ from selfprivacy_api.services.generic_size_counter import get_storage_usage
 from selfprivacy_api.services.owned_path import OwnedPath
 from selfprivacy_api import utils
 from selfprivacy_api.utils.waitloop import wait_until_true
+from selfprivacy_api.utils import ReadUserData, WriteUserData, get_domain
 DEFAULT_START_STOP_TIMEOUT = 5 * 60
@@ -33,6 +34,7 @@ class ServiceDnsRecord(BaseModel):
     name: str
     content: str
     ttl: int
+    display_name: str
     priority: typing.Optional[int] = None
@@ -124,11 +126,17 @@ class Service(ABC):
         """
         pass
-    @staticmethod
-    @abstractmethod
-    def is_enabled() -> bool:
-        """`True` if the service is enabled."""
-        pass
+    @classmethod
+    def is_enabled(cls) -> bool:
+        """
+        `True` if the service is enabled.
+        `False` if it is not enabled or not defined in file
+        If there is nothing in the file, this is equivalent to False
+        because NixOS won't enable it then.
+        """
+        name = cls.get_id()
+        with ReadUserData() as user_data:
+            return user_data.get("modules", {}).get(name, {}).get("enable", False)
     @staticmethod
     @abstractmethod
@@ -136,17 +144,25 @@ class Service(ABC):
         """The status of the service, reported by systemd."""
         pass
-    @staticmethod
-    @abstractmethod
-    def enable():
-        """Enable the service. Usually this means enabling systemd unit."""
-        pass
-    @staticmethod
-    @abstractmethod
-    def disable():
+    @classmethod
+    def _set_enable(cls, enable: bool):
+        name = cls.get_id()
+        with WriteUserData() as user_data:
+            if "modules" not in user_data:
+                user_data["modules"] = {}
+            if name not in user_data["modules"]:
+                user_data["modules"][name] = {}
+            user_data["modules"][name]["enable"] = enable
+    @classmethod
+    def enable(cls):
+        """Enable the service. Usually this means enabling systemd unit."""
+        cls._set_enable(True)
+    @classmethod
+    def disable(cls):
         """Disable the service. Usually this means disabling systemd unit."""
-        pass
+        cls._set_enable(False)
     @staticmethod
     @abstractmethod
@@ -209,9 +225,13 @@ class Service(ABC):
             return root_device
         with utils.ReadUserData() as userdata:
             if userdata.get("useBinds", False):
-                return userdata.get(cls.get_id(), {}).get(
-                    "location",
-                    root_device,
+                return (
+                    userdata.get("modules", {})
+                    .get(cls.get_id(), {})
+                    .get(
+                        "location",
+                        root_device,
+                    )
                 )
             else:
                 return root_device
@@ -246,6 +266,8 @@ class Service(ABC):
     @abstractmethod
     def move_to_volume(self, volume: BlockDevice) -> Job:
+        """Cannot raise errors.
+        Returns errors as an errored out Job instead."""
         pass
     @classmethod

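The per-service `is_enabled`/`enable`/`disable` boilerplate removed above collapses into this one generic implementation keyed by `get_id()`. A standalone sketch of the same logic, with a plain dict standing in for the file-backed ReadUserData/WriteUserData context managers:

```python
# Standalone sketch of the generic enable/disable logic; a dict stands in
# for the userdata file behind ReadUserData/WriteUserData.
user_data: dict = {}


def set_enable(service_id: str, enable: bool) -> None:
    modules = user_data.setdefault("modules", {})
    modules.setdefault(service_id, {})["enable"] = enable


def is_enabled(service_id: str) -> bool:
    # Absent keys mean "disabled": NixOS will not enable the module either.
    return user_data.get("modules", {}).get(service_id, {}).get("enable", False)


assert is_enabled("gitea") is False  # nothing written yet
set_enable("gitea", True)
assert is_enabled("gitea") is True
```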

@@ -8,9 +8,10 @@ from os import path
 # from enum import Enum
-from selfprivacy_api.jobs import Job
+from selfprivacy_api.jobs import Job, Jobs, JobStatus
 from selfprivacy_api.services.service import Service, ServiceDnsRecord, ServiceStatus
 from selfprivacy_api.utils.block_devices import BlockDevice
+from selfprivacy_api.services.generic_service_mover import move_service, FolderMoveNames
 import selfprivacy_api.utils.network as network_utils
 from selfprivacy_api.services.test_service.icon import BITWARDEN_ICON
@@ -22,16 +23,19 @@ class DummyService(Service):
     """A test service"""
     folders: List[str] = []
-    startstop_delay = 0
+    startstop_delay = 0.0
     backuppable = True
+    movable = True
+    # if False, we try to actually move
+    simulate_moving = True
+    drive = "sda1"
     def __init_subclass__(cls, folders: List[str]):
         cls.folders = folders
     def __init__(self):
         super().__init__()
-        status_file = self.status_file()
-        with open(status_file, "w") as file:
+        with open(self.status_file(), "w") as file:
             file.write(ServiceStatus.ACTIVE.value)
     @staticmethod
@@ -61,9 +65,9 @@ class DummyService(Service):
         domain = "test.com"
         return f"https://password.{domain}"
-    @staticmethod
-    def is_movable() -> bool:
-        return True
+    @classmethod
+    def is_movable(cls) -> bool:
+        return cls.movable
     @staticmethod
     def is_required() -> bool:
@@ -73,10 +77,6 @@ class DummyService(Service):
     def get_backup_description() -> str:
         return "How did we get here?"
-    @staticmethod
-    def is_enabled() -> bool:
-        return True
     @classmethod
     def status_file(cls) -> str:
         dir = cls.folders[0]
@@ -116,22 +116,30 @@ class DummyService(Service):
         we can only set it up dynamically for tests via a classmethod"""
         cls.backuppable = new_value
+    @classmethod
+    def set_movable(cls, new_value: bool) -> None:
+        """For tests: because is_movale is static,
+        we can only set it up dynamically for tests via a classmethod"""
+        cls.movable = new_value
     @classmethod
     def can_be_backed_up(cls) -> bool:
         """`True` if the service can be backed up."""
         return cls.backuppable
     @classmethod
-    def enable(cls):
-        pass
+    def set_delay(cls, new_delay_sec: float) -> None:
+        cls.startstop_delay = new_delay_sec
     @classmethod
-    def disable(cls, delay):
-        pass
+    def set_drive(cls, new_drive: str) -> None:
+        cls.drive = new_drive
     @classmethod
-    def set_delay(cls, new_delay):
-        cls.startstop_delay = new_delay
+    def set_simulated_moves(cls, enabled: bool) -> None:
+        """If True, this service will not actually call moving code
+        when moved"""
+        cls.simulate_moving = enabled
     @classmethod
     def stop(cls):
@@ -169,9 +177,9 @@ class DummyService(Service):
         storage_usage = 0
         return storage_usage
-    @staticmethod
-    def get_drive() -> str:
-        return "sda1"
+    @classmethod
+    def get_drive(cls) -> str:
+        return cls.drive
     @classmethod
     def get_folders(cls) -> List[str]:
@@ -186,14 +194,34 @@ class DummyService(Service):
             name="password",
             content=network_utils.get_ip4(),
             ttl=3600,
+            display_name="Test Service",
         ),
         ServiceDnsRecord(
             type="AAAA",
             name="password",
             content=network_utils.get_ip6(),
             ttl=3600,
+            display_name="Test Service (IPv6)",
         ),
     ]
     def move_to_volume(self, volume: BlockDevice) -> Job:
-        pass
+        job = Jobs.add(
+            type_id=f"services.{self.get_id()}.move",
+            name=f"Move {self.get_display_name()}",
+            description=f"Moving {self.get_display_name()} data to {volume.name}",
+        )
+        if self.simulate_moving is False:
+            # completely generic code, TODO: make it the default impl.
+            move_service(
+                self,
+                volume,
+                job,
+                FolderMoveNames.default_foldermoves(self),
+                self.get_id(),
+            )
+        else:
+            Jobs.update(job, status=JobStatus.FINISHED)
+        self.set_drive(volume.name)
+        return job


@@ -6,27 +6,25 @@ import json
 import os
 import subprocess
 import portalocker
+import typing
-USERDATA_FILE = "/etc/nixos/userdata/userdata.json"
-TOKENS_FILE = "/etc/nixos/userdata/tokens.json"
-JOBS_FILE = "/etc/nixos/userdata/jobs.json"
-DOMAIN_FILE = "/var/domain"
+USERDATA_FILE = "/etc/nixos/userdata.json"
+SECRETS_FILE = "/etc/selfprivacy/secrets.json"
+DKIM_DIR = "/var/dkim/"
 class UserDataFiles(Enum):
     """Enum for userdata files"""
     USERDATA = 0
-    TOKENS = 1
-    JOBS = 2
+    SECRETS = 3
 def get_domain():
-    """Get domain from /var/domain without trailing new line"""
-    with open(DOMAIN_FILE, "r", encoding="utf-8") as domain_file:
-        domain = domain_file.readline().rstrip()
-    return domain
+    """Get domain from userdata.json"""
+    with ReadUserData() as user_data:
+        return user_data["domain"]
 class WriteUserData(object):
@@ -35,14 +33,12 @@ class WriteUserData(object):
     def __init__(self, file_type=UserDataFiles.USERDATA):
         if file_type == UserDataFiles.USERDATA:
             self.userdata_file = open(USERDATA_FILE, "r+", encoding="utf-8")
-        elif file_type == UserDataFiles.TOKENS:
-            self.userdata_file = open(TOKENS_FILE, "r+", encoding="utf-8")
-        elif file_type == UserDataFiles.JOBS:
+        elif file_type == UserDataFiles.SECRETS:
             # Make sure file exists
-            if not os.path.exists(JOBS_FILE):
-                with open(JOBS_FILE, "w", encoding="utf-8") as jobs_file:
-                    jobs_file.write("{}")
-            self.userdata_file = open(JOBS_FILE, "r+", encoding="utf-8")
+            if not os.path.exists(SECRETS_FILE):
+                with open(SECRETS_FILE, "w", encoding="utf-8") as secrets_file:
+                    secrets_file.write("{}")
+            self.userdata_file = open(SECRETS_FILE, "r+", encoding="utf-8")
         else:
             raise ValueError("Unknown file type")
         portalocker.lock(self.userdata_file, portalocker.LOCK_EX)
@@ -66,14 +62,11 @@ class ReadUserData(object):
     def __init__(self, file_type=UserDataFiles.USERDATA):
         if file_type == UserDataFiles.USERDATA:
             self.userdata_file = open(USERDATA_FILE, "r", encoding="utf-8")
-        elif file_type == UserDataFiles.TOKENS:
-            self.userdata_file = open(TOKENS_FILE, "r", encoding="utf-8")
-        elif file_type == UserDataFiles.JOBS:
-            # Make sure file exists
-            if not os.path.exists(JOBS_FILE):
-                with open(JOBS_FILE, "w", encoding="utf-8") as jobs_file:
-                    jobs_file.write("{}")
-            self.userdata_file = open(JOBS_FILE, "r", encoding="utf-8")
+        elif file_type == UserDataFiles.SECRETS:
+            if not os.path.exists(SECRETS_FILE):
+                with open(SECRETS_FILE, "w", encoding="utf-8") as secrets_file:
+                    secrets_file.write("{}")
+            self.userdata_file = open(SECRETS_FILE, "r", encoding="utf-8")
         else:
             raise ValueError("Unknown file type")
         portalocker.lock(self.userdata_file, portalocker.LOCK_SH)
@@ -88,10 +81,12 @@
 def validate_ssh_public_key(key):
-    """Validate SSH public key. It may be ssh-ed25519 or ssh-rsa."""
+    """Validate SSH public key.
+    It may be ssh-ed25519, ssh-rsa or ecdsa-sha2-nistp256."""
     if not key.startswith("ssh-ed25519"):
         if not key.startswith("ssh-rsa"):
-            return False
+            if not key.startswith("ecdsa-sha2-nistp256"):
+                return False
     return True
@@ -164,26 +159,31 @@ def parse_date(date_str: str) -> datetime.datetime:
     raise ValueError("Invalid date string")
-def get_dkim_key(domain, parse=True):
+def parse_dkim(dkim: str) -> str:
+    # extract key from file
+    dkim = dkim.split("(")[1]
+    dkim = dkim.split(")")[0]
+    # replace all quotes with nothing
+    dkim = dkim.replace('"', "")
+    # trim whitespace, remove newlines and tabs
+    dkim = dkim.strip()
+    dkim = dkim.replace("\n", "")
+    dkim = dkim.replace("\t", "")
+    # remove all redundant spaces
+    dkim = " ".join(dkim.split())
+    return dkim
+def get_dkim_key(domain: str, parse: bool = True) -> typing.Optional[str]:
     """Get DKIM key from /var/dkim/<domain>.selector.txt"""
-    if os.path.exists("/var/dkim/" + domain + ".selector.txt"):
-        cat_process = subprocess.Popen(
-            ["cat", "/var/dkim/" + domain + ".selector.txt"], stdout=subprocess.PIPE
-        )
-        dkim = cat_process.communicate()[0]
+    dkim_path = os.path.join(DKIM_DIR, domain + ".selector.txt")
+    if os.path.exists(dkim_path):
+        with open(dkim_path, encoding="utf-8") as dkim_file:
+            dkim = dkim_file.read()
         if parse:
-            # Extract key from file
-            dkim = dkim.split(b"(")[1]
-            dkim = dkim.split(b")")[0]
-            # Replace all quotes with nothing
-            dkim = dkim.replace(b'"', b"")
-            # Trim whitespace, remove newlines and tabs
-            dkim = dkim.strip()
-            dkim = dkim.replace(b"\n", b"")
-            dkim = dkim.replace(b"\t", b"")
-            # Remove all redundant spaces
-            dkim = b" ".join(dkim.split())
-            return str(dkim, "utf-8")
+            dkim = parse_dkim(dkim)
+        return dkim
     return None

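`parse_dkim` flattens the multi-line, BIND-style TXT record that DKIM tooling writes into a single-line value. A quick illustration with placeholder key material; the helper body is copied from the diff above so the snippet runs standalone:

```python
# Placeholder DKIM record in the parenthesised, multi-line BIND format.
raw = '''selector._domainkey IN TXT ( "v=DKIM1; k=rsa; "
\t"p=MIGfMA0GCSqGSIb3" )'''


def parse_dkim(dkim: str) -> str:  # copy of the helper in the diff
    dkim = dkim.split("(")[1]
    dkim = dkim.split(")")[0]
    dkim = dkim.replace('"', "")
    dkim = dkim.strip()
    dkim = dkim.replace("\n", "")
    dkim = dkim.replace("\t", "")
    return " ".join(dkim.split())


assert parse_dkim(raw) == "v=DKIM1; k=rsa; p=MIGfMA0GCSqGSIb3"
```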

@@ -1,4 +1,5 @@
-"""Wrapper for block device functions."""
+"""A block device API wrapping lsblk"""
+from __future__ import annotations
 import subprocess
 import json
 import typing
@@ -11,6 +12,7 @@ def get_block_device(device_name):
     """
     Return a block device by name.
     """
+    # TODO: remove the function and related tests: dublicated by singleton
     lsblk_output = subprocess.check_output(
         [
             "lsblk",
@@ -43,22 +45,37 @@ class BlockDevice:
     A block device.
     """
-    def __init__(self, block_device):
-        self.name = block_device["name"]
-        self.path = block_device["path"]
-        self.fsavail = str(block_device["fsavail"])
-        self.fssize = str(block_device["fssize"])
-        self.fstype = block_device["fstype"]
-        self.fsused = str(block_device["fsused"])
-        self.mountpoints = block_device["mountpoints"]
-        self.label = block_device["label"]
-        self.uuid = block_device["uuid"]
-        self.size = str(block_device["size"])
-        self.model = block_device["model"]
-        self.serial = block_device["serial"]
-        self.type = block_device["type"]
+    def __init__(self, device_dict: dict):
+        self.update_from_dict(device_dict)
+    def update_from_dict(self, device_dict: dict):
+        self.name = device_dict["name"]
+        self.path = device_dict["path"]
+        self.fsavail = str(device_dict["fsavail"])
+        self.fssize = str(device_dict["fssize"])
+        self.fstype = device_dict["fstype"]
+        self.fsused = str(device_dict["fsused"])
+        self.mountpoints = device_dict["mountpoints"]
+        self.label = device_dict["label"]
+        self.uuid = device_dict["uuid"]
+        self.size = str(device_dict["size"])
+        self.model = device_dict["model"]
+        self.serial = device_dict["serial"]
+        self.type = device_dict["type"]
         self.locked = False
+        self.children: typing.List[BlockDevice] = []
+        if "children" in device_dict.keys():
+            for child in device_dict["children"]:
+                self.children.append(BlockDevice(child))
+    def all_children(self) -> typing.List[BlockDevice]:
+        result = []
+        for child in self.children:
+            result.extend(child.all_children())
+            result.append(child)
+        return result
     def __str__(self):
         return self.name
@@ -82,17 +99,7 @@ class BlockDevice:
         Update current data and return a dictionary of stats.
         """
         device = get_block_device(self.name)
-        self.fsavail = str(device["fsavail"])
-        self.fssize = str(device["fssize"])
-        self.fstype = device["fstype"]
-        self.fsused = str(device["fsused"])
-        self.mountpoints = device["mountpoints"]
-        self.label = device["label"]
-        self.uuid = device["uuid"]
-        self.size = str(device["size"])
-        self.model = device["model"]
-        self.serial = device["serial"]
-        self.type = device["type"]
+        self.update_from_dict(device)
         return {
             "name": self.name,
@@ -110,6 +117,14 @@ class BlockDevice:
             "type": self.type,
         }
+    def is_usable_partition(self):
+        # Ignore devices with type "rom"
+        if self.type == "rom":
+            return False
+        if self.fstype == "ext4":
+            return True
+        return False
     def resize(self):
         """
         Resize the block device.
@@ -165,41 +180,16 @@ class BlockDevices(metaclass=SingletonMetaclass):
         """
        Update the list of block devices.
         """
-        devices = []
-        lsblk_output = subprocess.check_output(
-            [
-                "lsblk",
-                "-J",
-                "-b",
-                "-o",
-                "NAME,PATH,FSAVAIL,FSSIZE,FSTYPE,FSUSED,MOUNTPOINTS,LABEL,UUID,SIZE,MODEL,SERIAL,TYPE",
-            ]
-        )
-        lsblk_output = lsblk_output.decode("utf-8")
-        lsblk_output = json.loads(lsblk_output)
-        for device in lsblk_output["blockdevices"]:
-            # Ignore devices with type "rom"
-            if device["type"] == "rom":
-                continue
-            # Ignore iso9660 devices
-            if device["fstype"] == "iso9660":
-                continue
-            if device["fstype"] is None:
-                if "children" in device:
-                    for child in device["children"]:
-                        if child["fstype"] == "ext4":
-                            device = child
-                            break
-            devices.append(device)
-        # Add new devices and delete non-existent devices
-        for device in devices:
-            if device["name"] not in [
-                block_device.name for block_device in self.block_devices
-            ]:
-                self.block_devices.append(BlockDevice(device))
-        for block_device in self.block_devices:
-            if block_device.name not in [device["name"] for device in devices]:
-                self.block_devices.remove(block_device)
+        devices = BlockDevices.lsblk_devices()
+        children = []
+        for device in devices:
+            children.extend(device.all_children())
+        devices.extend(children)
+        valid_devices = [device for device in devices if device.is_usable_partition()]
+        self.block_devices = valid_devices
     def get_block_device(self, name: str) -> typing.Optional[BlockDevice]:
         """
@@ -236,3 +226,25 @@ class BlockDevices(metaclass=SingletonMetaclass):
             if "/" in block_device.mountpoints:
                 return block_device
         raise RuntimeError("No root block device found")
+    @staticmethod
+    def lsblk_device_dicts() -> typing.List[dict]:
+        lsblk_output_bytes = subprocess.check_output(
+            [
+                "lsblk",
+                "-J",
+                "-b",
+                "-o",
+                "NAME,PATH,FSAVAIL,FSSIZE,FSTYPE,FSUSED,MOUNTPOINTS,LABEL,UUID,SIZE,MODEL,SERIAL,TYPE",
+            ]
+        )
+        lsblk_output = lsblk_output_bytes.decode("utf-8")
+        return json.loads(lsblk_output)["blockdevices"]
+    @staticmethod
+    def lsblk_devices() -> typing.List[BlockDevice]:
+        devices = []
+        for device in BlockDevices.lsblk_device_dicts():
+            devices.append(device)
+        return [BlockDevice(device) for device in devices]

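The singleton now flattens the lsblk tree and keeps every usable (ext4) partition instead of patching parent devices in place. A simplified sketch of the traversal and filter over abbreviated lsblk-style dicts; the real code builds `BlockDevice` objects from the full lsblk column set:

```python
# Fabricated, abbreviated lsblk-style data; only the fields relevant to
# the new flattening and filtering logic are included.
disk = {
    "name": "sda", "type": "disk", "fstype": None,
    "children": [
        {"name": "sda1", "type": "part", "fstype": "ext4", "children": []},
        {"name": "sda2", "type": "part", "fstype": "swap", "children": []},
    ],
}


def all_children(device: dict) -> list:
    result = []
    for child in device.get("children", []):
        result.extend(all_children(child))
        result.append(child)
    return result


def is_usable_partition(device: dict) -> bool:
    if device["type"] == "rom":
        return False
    return device["fstype"] == "ext4"


flat = [disk] + all_children(disk)
usable = [d["name"] for d in flat if is_usable_partition(d)]
assert usable == ["sda1"]  # only the ext4 partition survives the filter
```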

@@ -2,14 +2,15 @@
 import os
 from huey import SqliteHuey
-HUEY_DATABASE = "/etc/nixos/userdata/tasks.db"
+HUEY_DATABASE = "/etc/selfprivacy/tasks.db"
 # Singleton instance containing the huey database.
 test_mode = os.environ.get("TEST_MODE")
 huey = SqliteHuey(
-    HUEY_DATABASE,
+    "selfprivacy-api",
+    filename=HUEY_DATABASE if not test_mode else None,
     immediate=test_mode == "true",
     utc=True,
 )


@@ -1,11 +1,14 @@
 from datetime import datetime
 from typing import Optional
+from enum import Enum
 def store_model_as_hash(redis, redis_key, model):
     for key, value in model.dict().items():
         if isinstance(value, datetime):
             value = value.isoformat()
+        if isinstance(value, Enum):
+            value = value.value
         redis.hset(redis_key, key, str(value))

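With the new branch, enum fields are stored by value rather than by their repr. A sketch with an invented model and a stub standing in for the redis client:

```python
# Illustration with invented model/field names and a stub for redis.hset.
from datetime import datetime
from enum import Enum

from pydantic import BaseModel


class Color(Enum):
    RED = "red"


class Sample(BaseModel):
    created_at: datetime
    color: Color


class FakeRedis:
    def __init__(self):
        self.hashes = {}

    def hset(self, key, field, value):
        self.hashes.setdefault(key, {})[field] = value


def store_model_as_hash(redis, redis_key, model):  # as in the diff
    for key, value in model.dict().items():
        if isinstance(value, datetime):
            value = value.isoformat()
        if isinstance(value, Enum):
            value = value.value
        redis.hset(redis_key, key, str(value))


r = FakeRedis()
store_model_as_hash(r, "sample:1", Sample(created_at=datetime(2024, 1, 1), color=Color.RED))
assert r.hashes["sample:1"]["color"] == "red"  # previously "Color.RED"
```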

@@ -0,0 +1,52 @@
from datetime import datetime, timezone


def ensure_tz_aware(dt: datetime) -> datetime:
    """
    returns timezone-aware datetime
    assumes utc on naive datetime input
    """
    if dt.tzinfo is None:
        # astimezone() is dangerous, it makes an implicit assumption that
        # the time is localtime
        dt = dt.replace(tzinfo=timezone.utc)
    return dt


def ensure_tz_aware_strict(dt: datetime) -> datetime:
    """
    returns timezone-aware datetime
    raises error if input is a naive datetime
    """
    if dt.tzinfo is None:
        raise ValueError(
            "no timezone in datetime (tz-aware datetime is required for this operation)",
            dt,
        )
    return dt


def tzaware_parse_time(iso_timestamp: str) -> datetime:
    """
    parse an iso8601 timestamp into timezone-aware datetime
    assume utc if no timezone in stamp
    example of timestamp:
    2023-11-10T12:07:47.868788+00:00
    """
    dt = datetime.fromisoformat(iso_timestamp)
    dt = ensure_tz_aware(dt)
    return dt


def tzaware_parse_time_strict(iso_timestamp: str) -> datetime:
    """
    parse an iso8601 timestamp into timezone-aware datetime
    raise an error if no timezone in stamp
    example of timestamp:
    2023-11-10T12:07:47.868788+00:00
    """
    dt = datetime.fromisoformat(iso_timestamp)
    dt = ensure_tz_aware_strict(dt)
    return dt

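A behaviour sketch for the lenient helper; the function is inlined so the snippet runs standalone, since the new module's import path is not visible in this diff:

```python
# The lenient parse: naive timestamps are interpreted as UTC rather than
# localtime (inlined copy of ensure_tz_aware from the new module).
from datetime import datetime, timezone


def ensure_tz_aware(dt: datetime) -> datetime:
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)  # assume UTC, never localtime
    return dt


naive = datetime.fromisoformat("2023-11-10T12:07:47.868788")
aware = ensure_tz_aware(naive)
assert aware.tzinfo == timezone.utc

# The *_strict variants instead raise ValueError on naive input, which is
# what callers want when a timestamp must already carry its timezone.
```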

@@ -2,7 +2,7 @@ from setuptools import setup, find_packages
 setup(
     name="selfprivacy_api",
-    version="2.3.1",
+    version="3.0.0",
     packages=find_packages(),
     scripts=[
         "selfprivacy_api/app.py",


@@ -1,49 +0,0 @@
{ pkgs ? import <nixos-22.11> { } }:
let
  sp-python = pkgs.python310.withPackages (p: with p; [
    setuptools
    portalocker
    pytz
    pytest
    pytest-asyncio
    pytest-mock
    pytest-datadir
    huey
    gevent
    mnemonic
    coverage
    pylint
    rope
    mypy
    pylsp-mypy
    pydantic
    typing-extensions
    psutil
    black
    fastapi
    uvicorn
    redis
    strawberry-graphql
    flake8-bugbear
    flake8
  ]);
in
pkgs.mkShell {
  buildInputs = [
    sp-python
    pkgs.black
    pkgs.redis
    pkgs.restic
    pkgs.rclone
  ];
  shellHook = ''
    PYTHONPATH=${sp-python}/${sp-python.sitePackages}
    # envs set with export and as attributes are treated differently.
    # for example. printenv <Name> will not fetch the value of an attribute.
    export USE_REDIS_PORT=6379
    pkill redis-server
    sleep 2
    setsid redis-server --bind 127.0.0.1 --port $USE_REDIS_PORT >/dev/null 2>/dev/null &
    # maybe set more env-vars
  '';
}


@@ -1,6 +1,45 @@
 import json
+from datetime import datetime, timezone, timedelta
 from mnemonic import Mnemonic
+# for expiration tests. If headache, consider freezegun
+RECOVERY_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.time.datetime"
+DEVICE_KEY_VALIDATION_DATETIME = RECOVERY_KEY_VALIDATION_DATETIME
+def ten_minutes_into_future_naive():
+    return datetime.now() + timedelta(minutes=10)
+def ten_minutes_into_future_naive_utc():
+    return datetime.utcnow() + timedelta(minutes=10)
+def ten_minutes_into_future():
+    return datetime.now(timezone.utc) + timedelta(minutes=10)
+def ten_minutes_into_past_naive():
+    return datetime.now() - timedelta(minutes=10)
+def ten_minutes_into_past_naive_utc():
+    return datetime.utcnow() - timedelta(minutes=10)
+def ten_minutes_into_past():
+    return datetime.now(timezone.utc) - timedelta(minutes=10)
+class NearFuture(datetime):
+    @classmethod
+    def now(cls, tz=None):
+        return datetime.now(tz) + timedelta(minutes=13)
+    @classmethod
+    def utcnow(cls):
+        return datetime.utcnow() + timedelta(minutes=13)
 def read_json(file_path):
     with open(file_path, "r", encoding="utf-8") as file:
@@ -28,5 +67,15 @@ def generate_backup_query(query_array):
     return "query TestBackup {\n backup {" + "\n".join(query_array) + "}\n}"
+def generate_service_query(query_array):
+    return "query TestService {\n services {" + "\n".join(query_array) + "}\n}"
 def mnemonic_to_hex(mnemonic):
     return Mnemonic(language="english").to_entropy(mnemonic).hex()
+def assert_recovery_recent(time_generated: str):
+    assert datetime.fromisoformat(time_generated) - timedelta(seconds=5) < datetime.now(
+        timezone.utc
+    )

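`NearFuture` is meant to be swapped in for `datetime` at the patch target named by the constants above, so that "now" lands past a key's expiration window. A standalone check of what the substitution changes:

```python
# Standalone check of NearFuture's behaviour (class copied from above).
from datetime import datetime, timedelta


class NearFuture(datetime):
    @classmethod
    def now(cls, tz=None):
        return datetime.now(tz) + timedelta(minutes=13)


delta = NearFuture.now() - datetime.now()
assert timedelta(minutes=12) < delta < timedelta(minutes=14)
# In the tests this class replaces `datetime` via
# mocker.patch(RECOVERY_KEY_VALIDATION_DATETIME, NearFuture), so a key
# valid for 10 minutes is already expired from the validator's view.
```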

@@ -3,9 +3,57 @@
 # pylint: disable=unused-argument
 import os
 import pytest
-from os import path
+import datetime
+from os import path
+from os import makedirs
+from typing import Generator
 from fastapi.testclient import TestClient
+from selfprivacy_api.models.tokens.token import Token
+from selfprivacy_api.utils.huey import huey
+import selfprivacy_api.services as services
+from selfprivacy_api.services import get_service_by_id, Service
+from selfprivacy_api.services.test_service import DummyService
+from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
+    RedisTokensRepository,
+)
+TESTFILE_BODY = "testytest!"
+TESTFILE_2_BODY = "testissimo!"
+TOKENS_FILE_CONTENTS = {
+    "tokens": [
+        {
+            "token": "TEST_TOKEN",
+            "name": "test_token",
+            "date": datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
+        },
+        {
+            "token": "TEST_TOKEN2",
+            "name": "test_token2",
+            "date": datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
+        },
+    ]
+}
+TOKENS = [
+    Token(
+        token="TEST_TOKEN",
+        device_name="test_token",
+        created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
+    ),
+    Token(
+        token="TEST_TOKEN2",
+        device_name="test_token2",
+        created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
+    ),
+]
+DEVICE_WE_AUTH_TESTS_WITH = TOKENS_FILE_CONTENTS["tokens"][0]
 def pytest_generate_tests(metafunc):
@@ -17,19 +65,22 @@ def global_data_dir():
 @pytest.fixture
-def tokens_file(mocker, shared_datadir):
-    """Mock tokens file."""
-    mock = mocker.patch(
-        "selfprivacy_api.utils.TOKENS_FILE", shared_datadir / "tokens.json"
-    )
-    return mock
+def empty_redis_repo():
+    repo = RedisTokensRepository()
+    repo.reset()
+    assert repo.get_tokens() == []
+    return repo
 @pytest.fixture
-def jobs_file(mocker, shared_datadir):
-    """Mock tokens file."""
-    mock = mocker.patch("selfprivacy_api.utils.JOBS_FILE", shared_datadir / "jobs.json")
-    return mock
+def redis_repo_with_tokens():
+    repo = RedisTokensRepository()
+    repo.reset()
+    for token in TOKENS:
+        repo._store_token(token)
+    assert sorted(repo.get_tokens(), key=lambda x: x.token) == sorted(
+        TOKENS, key=lambda x: x.token
+    )
 @pytest.fixture
@@ -56,27 +107,75 @@ def huey_database(mocker, shared_datadir):
 @pytest.fixture
-def client(tokens_file, huey_database, jobs_file):
+def client(huey_database, redis_repo_with_tokens):
     from selfprivacy_api.app import app
     return TestClient(app)
 @pytest.fixture
-def authorized_client(tokens_file, huey_database, jobs_file):
+def authorized_client(huey_database, redis_repo_with_tokens):
     """Authorized test client fixture."""
     from selfprivacy_api.app import app
     client = TestClient(app)
-    client.headers.update({"Authorization": "Bearer TEST_TOKEN"})
+    client.headers.update(
+        {"Authorization": "Bearer " + DEVICE_WE_AUTH_TESTS_WITH["token"]}
+    )
     return client
 @pytest.fixture
-def wrong_auth_client(tokens_file, huey_database, jobs_file):
+def wrong_auth_client(huey_database, redis_repo_with_tokens):
     """Wrong token test client fixture."""
     from selfprivacy_api.app import app
     client = TestClient(app)
     client.headers.update({"Authorization": "Bearer WRONG_TOKEN"})
     return client
+@pytest.fixture()
+def raw_dummy_service(tmpdir):
+    dirnames = ["test_service", "also_test_service"]
+    service_dirs = []
+    for d in dirnames:
+        service_dir = path.join(tmpdir, d)
+        makedirs(service_dir)
+        service_dirs.append(service_dir)
+    testfile_path_1 = path.join(service_dirs[0], "testfile.txt")
+    with open(testfile_path_1, "w") as file:
+        file.write(TESTFILE_BODY)
+    testfile_path_2 = path.join(service_dirs[1], "testfile2.txt")
+    with open(testfile_path_2, "w") as file:
+        file.write(TESTFILE_2_BODY)
+    # we need this to not change get_folders() much
+    class TestDummyService(DummyService, folders=service_dirs):
+        pass
+    service = TestDummyService()
+    # assert pickle.dumps(service) is not None
+    return service
+@pytest.fixture()
+def dummy_service(
+    tmpdir, raw_dummy_service, generic_userdata
+) -> Generator[Service, None, None]:
+    service = raw_dummy_service
+    # register our service
+    services.services.append(service)
+    huey.immediate = True
+    assert huey.immediate is True
+    assert get_service_by_id(service.get_id()) is not None
+    service.enable()
+    yield service
+    # cleanup because apparently it matters wrt tasks
+    services.services.remove(service)

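A sketch of how the new fixtures compose in a test; the test name is illustrative, and the assertions mirror what the `dummy_service` fixture itself establishes:

```python
# Illustrative test built on the fixtures above: while it runs, the dummy
# service is registered, enabled, and discoverable by id.
from selfprivacy_api.services import get_service_by_id


def test_dummy_service_is_registered(dummy_service):
    found = get_service_by_id(dummy_service.get_id())
    assert found is not None
    assert found.is_enabled()
```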

@@ -1 +0,0 @@
{}


@@ -1,14 +0,0 @@
{
    "tokens": [
        {
            "token": "TEST_TOKEN",
            "name": "test_token",
            "date": "2022-01-14 08:31:10.789314"
        },
        {
            "token": "TEST_TOKEN2",
            "name": "test_token2",
            "date": "2022-01-14 08:31:10.789314"
        }
    ]
}


@ -1,21 +1,76 @@
{ {
"api": { "dns": {
"token": "TEST_TOKEN", "provider": "CLOUDFLARE",
"enableSwagger": false "useStagingACME": false
}, },
"bitwarden": { "server": {
"enable": true "provider": "HETZNER"
}, },
"databasePassword": "PASSWORD", "domain": "test-domain.tld",
"domain": "test.tld",
"hashedMasterPassword": "HASHED_PASSWORD", "hashedMasterPassword": "HASHED_PASSWORD",
"hostname": "test-instance", "hostname": "test-instance",
"nextcloud": { "timezone": "Etc/UTC",
"adminPassword": "ADMIN", "username": "tester",
"databasePassword": "ADMIN", "useBinds": true,
"enable": true "sshKeys": [
"ssh-rsa KEY test@pc"
],
"users": [
{
"username": "user1",
"hashedPassword": "HASHED_PASSWORD_1",
"sshKeys": ["ssh-rsa KEY user1@pc"]
},
{
"username": "user2",
"hashedPassword": "HASHED_PASSWORD_2",
"sshKeys": ["ssh-rsa KEY user2@pc"]
},
{
"username": "user3",
"hashedPassword": "HASHED_PASSWORD_3",
"sshKeys": ["ssh-rsa KEY user3@pc"]
}
],
"autoUpgrade": {
"enable": true,
"allowReboot": true
}, },
"resticPassword": "PASS", "modules": {
"bitwarden": {
"enable": true,
"location": "sdb"
},
"gitea": {
"enable": true,
"location": "sdb"
},
"jitsi-meet": {
"enable": true
},
"nextcloud": {
"enable": true,
"location": "sdb"
},
"ocserv": {
"enable": true
},
"pleroma": {
"enable": true,
"location": "sdb"
},
"simple-nixos-mailserver": {
"enable": true,
"location": "sdb"
}
},
"volumes": [
{
"device": "/dev/sdb",
"mountPoint": "/volumes/sdb",
"fsType": "ext4"
}
],
"ssh": { "ssh": {
"enable": true, "enable": true,
"passwordAuthentication": true, "passwordAuthentication": true,
@ -23,34 +78,6 @@
"ssh-ed25519 KEY test@pc" "ssh-ed25519 KEY test@pc"
] ]
}, },
"username": "tester",
"gitea": {
"enable": true
},
"ocserv": {
"enable": true
},
"pleroma": {
"enable": true
},
"jitsi": {
"enable": true
},
"autoUpgrade": {
"enable": true,
"allowReboot": true
},
"timezone": "Europe/Moscow",
"sshKeys": [
"ssh-rsa KEY test@pc"
],
"dns": {
"provider": "CLOUDFLARE",
"apiKey": "TOKEN"
},
"server": {
"provider": "HETZNER"
},
"backup": { "backup": {
"provider": "BACKBLAZE", "provider": "BACKBLAZE",
"accountId": "ID", "accountId": "ID",

tests/test_autobackup.py (new file, 538 additions)

@@ -0,0 +1,538 @@
import pytest
from copy import copy
from datetime import datetime, timezone, timedelta
from selfprivacy_api.jobs import Jobs
from selfprivacy_api.services import Service, get_all_services
from selfprivacy_api.graphql.common_types.backup import (
BackupReason,
AutobackupQuotas,
)
from selfprivacy_api.backup import Backups, Snapshot
from selfprivacy_api.backup.tasks import (
prune_autobackup_snapshots,
)
from tests.test_backup import backups
def backuppable_services() -> list[Service]:
return [service for service in get_all_services() if service.can_be_backed_up()]
def dummy_snapshot(date: datetime):
return Snapshot(
id=str(hash(date)),
service_name="someservice",
created_at=date,
reason=BackupReason.EXPLICIT,
)
def test_no_default_autobackup(backups, dummy_service):
now = datetime.now(timezone.utc)
assert not Backups.is_time_to_backup_service(dummy_service, now)
assert not Backups.is_time_to_backup(now)
# --------------------- Timing -------------------------
def test_set_autobackup_period(backups):
assert Backups.autobackup_period_minutes() is None
Backups.set_autobackup_period_minutes(2)
assert Backups.autobackup_period_minutes() == 2
Backups.disable_all_autobackup()
assert Backups.autobackup_period_minutes() is None
Backups.set_autobackup_period_minutes(3)
assert Backups.autobackup_period_minutes() == 3
Backups.set_autobackup_period_minutes(0)
assert Backups.autobackup_period_minutes() is None
Backups.set_autobackup_period_minutes(3)
assert Backups.autobackup_period_minutes() == 3
Backups.set_autobackup_period_minutes(-1)
assert Backups.autobackup_period_minutes() is None
def test_autobackup_timer_periods(backups, dummy_service):
now = datetime.now(timezone.utc)
backup_period = 13 # minutes
assert not Backups.is_time_to_backup_service(dummy_service, now)
assert not Backups.is_time_to_backup(now)
Backups.set_autobackup_period_minutes(backup_period)
assert Backups.is_time_to_backup_service(dummy_service, now)
assert Backups.is_time_to_backup(now)
Backups.set_autobackup_period_minutes(0)
assert not Backups.is_time_to_backup_service(dummy_service, now)
assert not Backups.is_time_to_backup(now)
def test_autobackup_timer_enabling(backups, dummy_service):
now = datetime.now(timezone.utc)
backup_period = 13 # minutes
dummy_service.set_backuppable(False)
Backups.set_autobackup_period_minutes(backup_period)
assert Backups.is_time_to_backup(
now
) # there are other services too, not just our dummy
# not backuppable service is not backuppable even if period is set
assert not Backups.is_time_to_backup_service(dummy_service, now)
dummy_service.set_backuppable(True)
assert dummy_service.can_be_backed_up()
assert Backups.is_time_to_backup_service(dummy_service, now)
Backups.disable_all_autobackup()
assert not Backups.is_time_to_backup_service(dummy_service, now)
assert not Backups.is_time_to_backup(now)
def test_autobackup_timing(backups, dummy_service):
backup_period = 13 # minutes
now = datetime.now(timezone.utc)
Backups.set_autobackup_period_minutes(backup_period)
assert Backups.is_time_to_backup_service(dummy_service, now)
assert Backups.is_time_to_backup(now)
Backups.back_up(dummy_service)
now = datetime.now(timezone.utc)
assert not Backups.is_time_to_backup_service(dummy_service, now)
past = datetime.now(timezone.utc) - timedelta(minutes=1)
assert not Backups.is_time_to_backup_service(dummy_service, past)
future = datetime.now(timezone.utc) + timedelta(minutes=backup_period + 2)
assert Backups.is_time_to_backup_service(dummy_service, future)
# --------------------- What to autobackup and what not to --------------------
def test_services_to_autobackup(backups, dummy_service):
backup_period = 13 # minutes
now = datetime.now(timezone.utc)
dummy_service.set_backuppable(False)
services = Backups.services_to_back_up(now)
assert len(services) == 0
dummy_service.set_backuppable(True)
services = Backups.services_to_back_up(now)
assert len(services) == 0
Backups.set_autobackup_period_minutes(backup_period)
services = Backups.services_to_back_up(now)
assert len(services) == len(backuppable_services())
assert dummy_service.get_id() in [
service.get_id() for service in backuppable_services()
]
def test_do_not_autobackup_disabled_services(backups, dummy_service):
now = datetime.now(timezone.utc)
Backups.set_autobackup_period_minutes(3)
assert Backups.is_time_to_backup_service(dummy_service, now) is True
dummy_service.disable()
assert Backups.is_time_to_backup_service(dummy_service, now) is False
def test_failed_autoback_prevents_more_autobackup(backups, dummy_service):
backup_period = 13 # minutes
now = datetime.now(timezone.utc)
Backups.set_autobackup_period_minutes(backup_period)
assert Backups.is_time_to_backup_service(dummy_service, now)
# artificially making an errored out backup job
dummy_service.set_backuppable(False)
with pytest.raises(ValueError):
Backups.back_up(dummy_service)
dummy_service.set_backuppable(True)
assert Backups.get_last_backed_up(dummy_service) is None
assert Backups.get_last_backup_error_time(dummy_service) is not None
assert Backups.is_time_to_backup_service(dummy_service, now) is False
# --------------------- Quotas and Pruning -------------------------
unlimited_quotas = AutobackupQuotas(
last=-1,
daily=-1,
weekly=-1,
monthly=-1,
yearly=-1,
)
zero_quotas = AutobackupQuotas(
last=0,
daily=0,
weekly=0,
monthly=0,
yearly=0,
)
def test_get_empty_quotas(backups):
quotas = Backups.autobackup_quotas()
assert quotas is not None
assert quotas == unlimited_quotas
def test_set_quotas(backups):
quotas = AutobackupQuotas(
last=3,
daily=2343,
weekly=343,
monthly=0,
yearly=-34556,
)
Backups.set_autobackup_quotas(quotas)
assert Backups.autobackup_quotas() == AutobackupQuotas(
last=3,
daily=2343,
weekly=343,
monthly=0,
yearly=-1,
)
def test_set_zero_quotas(backups):
quotas = AutobackupQuotas(
last=0,
daily=0,
weekly=0,
monthly=0,
yearly=0,
)
Backups.set_autobackup_quotas(quotas)
assert Backups.autobackup_quotas() == zero_quotas
def test_set_unlimited_quotas(backups):
quotas = AutobackupQuotas(
last=-1,
daily=-1,
weekly=-1,
monthly=-1,
yearly=-1,
)
Backups.set_autobackup_quotas(quotas)
assert Backups.autobackup_quotas() == unlimited_quotas
def test_set_zero_quotas_after_unlimited(backups):
quotas = AutobackupQuotas(
last=-1,
daily=-1,
weekly=-1,
monthly=-1,
yearly=-1,
)
Backups.set_autobackup_quotas(quotas)
assert Backups.autobackup_quotas() == unlimited_quotas
quotas = AutobackupQuotas(
last=0,
daily=0,
weekly=0,
monthly=0,
yearly=0,
)
Backups.set_autobackup_quotas(quotas)
assert Backups.autobackup_quotas() == zero_quotas
def test_autobackup_snapshots_pruning(backups):
# Wednesday, fourth week
now = datetime(year=2023, month=1, day=25, hour=10)
snaps = [
dummy_snapshot(now),
dummy_snapshot(now - timedelta(minutes=5)),
dummy_snapshot(now - timedelta(hours=2)),
dummy_snapshot(now - timedelta(hours=5)),
dummy_snapshot(now - timedelta(days=1)),
dummy_snapshot(now - timedelta(days=1, hours=2)),
dummy_snapshot(now - timedelta(days=1, hours=3)),
dummy_snapshot(now - timedelta(days=2)),
dummy_snapshot(now - timedelta(days=7)),
dummy_snapshot(now - timedelta(days=12)),
dummy_snapshot(now - timedelta(days=23)),
dummy_snapshot(now - timedelta(days=28)),
dummy_snapshot(now - timedelta(days=32)),
dummy_snapshot(now - timedelta(days=47)),
dummy_snapshot(now - timedelta(days=64)),
dummy_snapshot(now - timedelta(days=84)),
dummy_snapshot(now - timedelta(days=104)),
dummy_snapshot(now - timedelta(days=365 * 2)),
]
old_len = len(snaps)
quotas = copy(unlimited_quotas)
Backups.set_autobackup_quotas(quotas)
assert Backups._prune_snaps_with_quotas(snaps) == snaps
quotas = copy(zero_quotas)
quotas.last = 2
quotas.daily = 2
Backups.set_autobackup_quotas(quotas)
snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
assert snaps_to_keep == [
dummy_snapshot(now),
dummy_snapshot(now - timedelta(minutes=5)),
# dummy_snapshot(now - timedelta(hours=2)),
# dummy_snapshot(now - timedelta(hours=5)),
dummy_snapshot(now - timedelta(days=1)),
# dummy_snapshot(now - timedelta(days=1, hours=2)),
# dummy_snapshot(now - timedelta(days=1, hours=3)),
# dummy_snapshot(now - timedelta(days=2)),
# dummy_snapshot(now - timedelta(days=7)),
# dummy_snapshot(now - timedelta(days=12)),
# dummy_snapshot(now - timedelta(days=23)),
# dummy_snapshot(now - timedelta(days=28)),
# dummy_snapshot(now - timedelta(days=32)),
# dummy_snapshot(now - timedelta(days=47)),
# dummy_snapshot(now - timedelta(days=64)),
# dummy_snapshot(now - timedelta(days=84)),
# dummy_snapshot(now - timedelta(days=104)),
# dummy_snapshot(now - timedelta(days=365 * 2)),
]
# checking that this function does not mutate the argument
assert snaps != snaps_to_keep
assert len(snaps) == old_len
quotas = copy(zero_quotas)
quotas.weekly = 4
Backups.set_autobackup_quotas(quotas)
snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
assert snaps_to_keep == [
dummy_snapshot(now),
# dummy_snapshot(now - timedelta(minutes=5)),
# dummy_snapshot(now - timedelta(hours=2)),
# dummy_snapshot(now - timedelta(hours=5)),
# dummy_snapshot(now - timedelta(days=1)),
# dummy_snapshot(now - timedelta(days=1, hours=2)),
# dummy_snapshot(now - timedelta(days=1, hours=3)),
# dummy_snapshot(now - timedelta(days=2)),
dummy_snapshot(now - timedelta(days=7)),
dummy_snapshot(now - timedelta(days=12)),
dummy_snapshot(now - timedelta(days=23)),
# dummy_snapshot(now - timedelta(days=28)),
# dummy_snapshot(now - timedelta(days=32)),
# dummy_snapshot(now - timedelta(days=47)),
# dummy_snapshot(now - timedelta(days=64)),
# dummy_snapshot(now - timedelta(days=84)),
        # dummy_snapshot(now - timedelta(days=104)),
        # dummy_snapshot(now - timedelta(days=365 * 2)),
    ]
    quotas = copy(zero_quotas)
    quotas.monthly = 7
    Backups.set_autobackup_quotas(quotas)

    snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
    assert snaps_to_keep == [
        dummy_snapshot(now),
        # dummy_snapshot(now - timedelta(minutes=5)),
        # dummy_snapshot(now - timedelta(hours=2)),
        # dummy_snapshot(now - timedelta(hours=5)),
        # dummy_snapshot(now - timedelta(days=1)),
        # dummy_snapshot(now - timedelta(days=1, hours=2)),
        # dummy_snapshot(now - timedelta(days=1, hours=3)),
        # dummy_snapshot(now - timedelta(days=2)),
        # dummy_snapshot(now - timedelta(days=7)),
        # dummy_snapshot(now - timedelta(days=12)),
        # dummy_snapshot(now - timedelta(days=23)),
        dummy_snapshot(now - timedelta(days=28)),
        # dummy_snapshot(now - timedelta(days=32)),
        # dummy_snapshot(now - timedelta(days=47)),
        dummy_snapshot(now - timedelta(days=64)),
        # dummy_snapshot(now - timedelta(days=84)),
        dummy_snapshot(now - timedelta(days=104)),
        dummy_snapshot(now - timedelta(days=365 * 2)),
    ]


def test_autobackup_snapshots_pruning_yearly(backups):
    snaps = [
        dummy_snapshot(datetime(year=2055, month=3, day=1)),
        dummy_snapshot(datetime(year=2055, month=2, day=1)),
        dummy_snapshot(datetime(year=2023, month=4, day=1)),
        dummy_snapshot(datetime(year=2023, month=3, day=1)),
        dummy_snapshot(datetime(year=2023, month=2, day=1)),
        dummy_snapshot(datetime(year=2021, month=2, day=1)),
    ]
    quotas = copy(zero_quotas)
    quotas.yearly = 2
    Backups.set_autobackup_quotas(quotas)

    snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
    assert snaps_to_keep == [
        dummy_snapshot(datetime(year=2055, month=3, day=1)),
        dummy_snapshot(datetime(year=2023, month=4, day=1)),
    ]


def test_autobackup_snapshots_pruning_bottleneck(backups):
    now = datetime(year=2023, month=1, day=25, hour=10)
    snaps = [
        dummy_snapshot(now),
        dummy_snapshot(now - timedelta(minutes=5)),
        dummy_snapshot(now - timedelta(hours=2)),
        dummy_snapshot(now - timedelta(hours=3)),
        dummy_snapshot(now - timedelta(hours=4)),
    ]

    yearly_quota = copy(zero_quotas)
    yearly_quota.yearly = 2

    monthly_quota = copy(zero_quotas)
    monthly_quota.monthly = 2

    weekly_quota = copy(zero_quotas)
    weekly_quota.weekly = 2

    daily_quota = copy(zero_quotas)
    daily_quota.daily = 2

    last_quota = copy(zero_quotas)
    last_quota.last = 1
    last_quota.yearly = 2

    for quota in [last_quota, yearly_quota, monthly_quota, weekly_quota, daily_quota]:
        print(quota)
        Backups.set_autobackup_quotas(quota)
        snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
        assert snaps_to_keep == [
            dummy_snapshot(now),
            # If there is a vacant quota, we should keep the last snapshot even if it doesn't fit
            dummy_snapshot(now - timedelta(hours=4)),
        ]


def test_autobackup_snapshots_pruning_edgeweek(backups):
    # jan 1 2023 is Sunday
    snaps = [
        dummy_snapshot(datetime(year=2023, month=1, day=6)),
        dummy_snapshot(datetime(year=2023, month=1, day=1)),
        dummy_snapshot(datetime(year=2022, month=12, day=31)),
        dummy_snapshot(datetime(year=2022, month=12, day=30)),
    ]
    quotas = copy(zero_quotas)
    quotas.weekly = 2
    Backups.set_autobackup_quotas(quotas)

    snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
    assert snaps_to_keep == [
        dummy_snapshot(datetime(year=2023, month=1, day=6)),
        dummy_snapshot(datetime(year=2023, month=1, day=1)),
    ]


def test_autobackup_snapshots_pruning_big_gap(backups):
    snaps = [
        dummy_snapshot(datetime(year=2023, month=1, day=6)),
        dummy_snapshot(datetime(year=2023, month=1, day=2)),
        dummy_snapshot(datetime(year=2022, month=10, day=31)),
        dummy_snapshot(datetime(year=2022, month=10, day=30)),
    ]
    quotas = copy(zero_quotas)
    quotas.weekly = 2
    Backups.set_autobackup_quotas(quotas)

    snaps_to_keep = Backups._prune_snaps_with_quotas(snaps)
    assert snaps_to_keep == [
        dummy_snapshot(datetime(year=2023, month=1, day=6)),
        dummy_snapshot(datetime(year=2022, month=10, day=31)),
    ]


def test_quotas_exceeded_with_too_many_autobackups(backups, dummy_service):
    assert Backups.autobackup_quotas()
    quota = copy(zero_quotas)
    quota.last = 2
    Backups.set_autobackup_quotas(quota)
    assert Backups.autobackup_quotas().last == 2

    snap = Backups.back_up(dummy_service, BackupReason.AUTO)
    assert len(Backups.get_snapshots(dummy_service)) == 1
    snap2 = Backups.back_up(dummy_service, BackupReason.AUTO)
    assert len(Backups.get_snapshots(dummy_service)) == 2
    snap3 = Backups.back_up(dummy_service, BackupReason.AUTO)
    assert len(Backups.get_snapshots(dummy_service)) == 2

    snaps = Backups.get_snapshots(dummy_service)
    assert snap2 in snaps
    assert snap3 in snaps
    assert snap not in snaps

    quota.last = -1
    Backups.set_autobackup_quotas(quota)
    snap4 = Backups.back_up(dummy_service, BackupReason.AUTO)

    snaps = Backups.get_snapshots(dummy_service)
    assert len(snaps) == 3
    assert snap4 in snaps

    # Retroactivity
    quota.last = 1
    Backups.set_autobackup_quotas(quota)
    job = Jobs.add("trimming", "test.autobackup_trimming", "trimming the snaps!")
    handle = prune_autobackup_snapshots(job)
    handle(blocking=True)
    snaps = Backups.get_snapshots(dummy_service)
    assert len(snaps) == 1

    snap5 = Backups.back_up(dummy_service, BackupReason.AUTO)
    snaps = Backups.get_snapshots(dummy_service)
    assert len(snaps) == 1
    assert snap5 in snaps

    # Explicit snaps are not affected
    snap6 = Backups.back_up(dummy_service, BackupReason.EXPLICIT)

    snaps = Backups.get_snapshots(dummy_service)
    assert len(snaps) == 2
    assert snap5 in snaps
    assert snap6 in snaps
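
# The tests above pin down the quota semantics used by the autobackup
# pruning code: `last` keeps the N most recent automatic snapshots, a
# negative value means "no limit", and shrinking a quota trims retroactively.
# A minimal sketch of the "keep last N" rule under those assumptions
# (illustration only, not the actual Backups._prune_snaps_with_quotas
# implementation):


def keep_last_n_sketch(snapshots: list, n: int) -> list:
    """Keep the newest n snapshots; n < 0 disables the limit."""
    # Snapshots are assumed sorted newest-first, as in the asserts above.
    if n < 0:
        return snapshots
    return snapshots[:n]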


@@ -1,23 +1,25 @@
 import pytest
 import os
 import os.path as path
-from os import makedirs
 from os import remove
 from os import listdir
 from os import urandom
-from datetime import datetime, timedelta, timezone
-from subprocess import Popen
+from datetime import datetime, timedelta, timezone

 import tempfile

-import selfprivacy_api.services as services
+from selfprivacy_api.utils.huey import huey
+
+from selfprivacy_api.services import Service, get_all_services
 from selfprivacy_api.services.service import ServiceStatus
+from selfprivacy_api.services import get_service_by_id
+from selfprivacy_api.services.test_service import DummyService

 from selfprivacy_api.graphql.queries.providers import BackupProvider
-from selfprivacy_api.graphql.common_types.backup import RestoreStrategy
+from selfprivacy_api.graphql.common_types.backup import (
+    RestoreStrategy,
+    BackupReason,
+)

 from selfprivacy_api.jobs import Jobs, JobStatus
 from selfprivacy_api.models.backup.snapshot import Snapshot

@@ -28,9 +30,6 @@ from selfprivacy_api.backup.providers import AbstractBackupProvider
 from selfprivacy_api.backup.providers.backblaze import Backblaze
 from selfprivacy_api.backup.providers.none import NoBackups
 from selfprivacy_api.backup.util import sync
-from selfprivacy_api.backup.backuppers.restic_backupper import ResticBackupper
-from selfprivacy_api.backup.jobs import add_backup_job, add_restore_job

 from selfprivacy_api.backup.tasks import (
     start_backup,

@@ -38,16 +37,15 @@ from selfprivacy_api.backup.tasks import (
     reload_snapshot_cache,
 )
 from selfprivacy_api.backup.storage import Storage
+from selfprivacy_api.backup.jobs import get_backup_job

-TESTFILE_BODY = "testytest!"
-TESTFILE_2_BODY = "testissimo!"
 REPO_NAME = "test_backup"
+REPOFILE_NAME = "totallyunrelated"


 def prepare_localfile_backups(temp_dir):
-    test_repo_path = path.join(temp_dir, "totallyunrelated")
+    test_repo_path = path.join(temp_dir, REPOFILE_NAME)
     assert not path.exists(test_repo_path)
     Backups.set_localfile_repo(test_repo_path)
@@ -62,16 +60,24 @@ def backups_local(tmpdir):
 @pytest.fixture(scope="function")
 def backups(tmpdir):
-    # for those tests that are supposed to pass with any repo
+    """
+    For those tests that are supposed to pass with
+    both local and cloud repos
+    """
+    # Sometimes this is false. Idk why.
+    huey.immediate = True
+    assert huey.immediate is True
+
     Backups.reset()
     if BACKUP_PROVIDER_ENVS["kind"] in os.environ.keys():
         Backups.set_provider_from_envs()
     else:
         prepare_localfile_backups(tmpdir)
     Jobs.reset()
-    # assert not repo_path

     Backups.init_repo()
+    assert Backups.provider().location == str(tmpdir) + "/" + REPOFILE_NAME
+
     yield
     Backups.erase_repo()
@@ -81,45 +87,6 @@ def backups_backblaze(generic_userdata):
     Backups.reset(reset_json=False)

-@pytest.fixture()
-def raw_dummy_service(tmpdir):
-    dirnames = ["test_service", "also_test_service"]
-    service_dirs = []
-    for d in dirnames:
-        service_dir = path.join(tmpdir, d)
-        makedirs(service_dir)
-        service_dirs.append(service_dir)
-
-    testfile_path_1 = path.join(service_dirs[0], "testfile.txt")
-    with open(testfile_path_1, "w") as file:
-        file.write(TESTFILE_BODY)
-
-    testfile_path_2 = path.join(service_dirs[1], "testfile2.txt")
-    with open(testfile_path_2, "w") as file:
-        file.write(TESTFILE_2_BODY)
-
-    # we need this to not change get_folders() much
-    class TestDummyService(DummyService, folders=service_dirs):
-        pass
-
-    service = TestDummyService()
-    return service
-
-
-@pytest.fixture()
-def dummy_service(tmpdir, backups, raw_dummy_service) -> Service:
-    service = raw_dummy_service
-
-    # register our service
-    services.services.append(service)
-
-    assert get_service_by_id(service.get_id()) is not None
-    yield service
-
-    # cleanup because apparently it matters wrt tasks
-    services.services.remove(service)

 @pytest.fixture()
 def memory_backup() -> AbstractBackupProvider:
     ProviderClass = providers.get_provider(BackupProvider.MEMORY)
@@ -242,16 +209,6 @@ def test_reinit_after_purge(backups):
     assert len(Backups.get_all_snapshots()) == 0

-def test_backup_simple_file(raw_dummy_service, file_backup):
-    # temporarily incomplete
-    service = raw_dummy_service
-    assert service is not None
-    assert file_backup is not None
-
-    name = service.get_id()
-    file_backup.backupper.init()

 def test_backup_service(dummy_service, backups):
     id = dummy_service.get_id()
     assert_job_finished(f"services.{id}.backup", count=0)
@@ -293,6 +250,16 @@ def test_backup_returns_snapshot(backups, dummy_service):
     assert Backups.get_snapshot_by_id(snapshot.id) is not None
     assert snapshot.service_name == name
     assert snapshot.created_at is not None
+    assert snapshot.reason == BackupReason.EXPLICIT
+
+
+def test_backup_reasons(backups, dummy_service):
+    snap = Backups.back_up(dummy_service, BackupReason.AUTO)
+    assert snap.reason == BackupReason.AUTO
+
+    Backups.force_snapshot_cache_reload()
+    snaps = Backups.get_snapshots(dummy_service)
+    assert snaps[0].reason == BackupReason.AUTO

 def folder_files(folder):

@@ -404,7 +371,7 @@ def simulated_service_stopping_delay(request) -> float:
 def test_backup_service_task(backups, dummy_service, simulated_service_stopping_delay):
     dummy_service.set_delay(simulated_service_stopping_delay)

-    handle = start_backup(dummy_service)
+    handle = start_backup(dummy_service.get_id())
     handle(blocking=True)

     snaps = Backups.get_snapshots(dummy_service)
@@ -435,7 +402,10 @@ def test_forget_snapshot(backups, dummy_service):
 def test_forget_nonexistent_snapshot(backups, dummy_service):
     bogus = Snapshot(
-        id="gibberjibber", service_name="nohoho", created_at=datetime.now(timezone.utc)
+        id="gibberjibber",
+        service_name="nohoho",
+        created_at=datetime.now(timezone.utc),
+        reason=BackupReason.EXPLICIT,
     )
     with pytest.raises(ValueError):
         Backups.forget_snapshot(bogus)

@@ -446,7 +416,7 @@ def test_backup_larger_file(backups, dummy_service):
     mega = 2**20
     make_large_file(dir, 100 * mega)

-    handle = start_backup(dummy_service)
+    handle = start_backup(dummy_service.get_id())
     handle(blocking=True)

     # results will be slightly different on different machines. if someone has troubles with it on their machine, consider dropping this test.
@@ -508,120 +478,17 @@ def test_restore_snapshot_task(
     snaps = Backups.get_snapshots(dummy_service)
     if restore_strategy == RestoreStrategy.INPLACE:
         assert len(snaps) == 2
+        reasons = [snap.reason for snap in snaps]
+        assert BackupReason.PRE_RESTORE in reasons
     else:
         assert len(snaps) == 1

-def test_set_autobackup_period(backups):
-    assert Backups.autobackup_period_minutes() is None
-
-    Backups.set_autobackup_period_minutes(2)
-    assert Backups.autobackup_period_minutes() == 2
-
-    Backups.disable_all_autobackup()
-    assert Backups.autobackup_period_minutes() is None
-
-    Backups.set_autobackup_period_minutes(3)
-    assert Backups.autobackup_period_minutes() == 3
-
-    Backups.set_autobackup_period_minutes(0)
-    assert Backups.autobackup_period_minutes() is None
-
-    Backups.set_autobackup_period_minutes(3)
-    assert Backups.autobackup_period_minutes() == 3
-
-    Backups.set_autobackup_period_minutes(-1)
-    assert Backups.autobackup_period_minutes() is None
-
-
-def test_no_default_autobackup(backups, dummy_service):
-    now = datetime.now(timezone.utc)
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-    assert not Backups.is_time_to_backup(now)
-
-
-def backuppable_services() -> list[Service]:
-    return [service for service in get_all_services() if service.can_be_backed_up()]
-
-
-def test_services_to_back_up(backups, dummy_service):
-    backup_period = 13  # minutes
-    now = datetime.now(timezone.utc)
-
+def test_backup_unbackuppable(backups, dummy_service):
     dummy_service.set_backuppable(False)
-    services = Backups.services_to_back_up(now)
-    assert len(services) == 0
+    assert dummy_service.can_be_backed_up() is False
+    with pytest.raises(ValueError):
+        Backups.back_up(dummy_service)

-    dummy_service.set_backuppable(True)
-    services = Backups.services_to_back_up(now)
-    assert len(services) == 0
-
-    Backups.set_autobackup_period_minutes(backup_period)
-    services = Backups.services_to_back_up(now)
-    assert len(services) == len(backuppable_services())
-    assert dummy_service.get_id() in [
-        service.get_id() for service in backuppable_services()
-    ]
-
-
-def test_autobackup_timer_periods(backups, dummy_service):
-    now = datetime.now(timezone.utc)
-    backup_period = 13  # minutes
-
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-    assert not Backups.is_time_to_backup(now)
-
-    Backups.set_autobackup_period_minutes(backup_period)
-    assert Backups.is_time_to_backup_service(dummy_service, now)
-    assert Backups.is_time_to_backup(now)
-
-    Backups.set_autobackup_period_minutes(0)
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-    assert not Backups.is_time_to_backup(now)
-
-
-def test_autobackup_timer_enabling(backups, dummy_service):
-    now = datetime.now(timezone.utc)
-    backup_period = 13  # minutes
-    dummy_service.set_backuppable(False)
-
-    Backups.set_autobackup_period_minutes(backup_period)
-    assert Backups.is_time_to_backup(
-        now
-    )  # there are other services too, not just our dummy
-
-    # not backuppable service is not backuppable even if period is set
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-
-    dummy_service.set_backuppable(True)
-    assert dummy_service.can_be_backed_up()
-    assert Backups.is_time_to_backup_service(dummy_service, now)
-
-    Backups.disable_all_autobackup()
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-    assert not Backups.is_time_to_backup(now)
-
-
-def test_autobackup_timing(backups, dummy_service):
-    backup_period = 13  # minutes
-    now = datetime.now(timezone.utc)
-
-    Backups.set_autobackup_period_minutes(backup_period)
-    assert Backups.is_time_to_backup_service(dummy_service, now)
-    assert Backups.is_time_to_backup(now)
-
-    Backups.back_up(dummy_service)
-
-    now = datetime.now(timezone.utc)
-    assert not Backups.is_time_to_backup_service(dummy_service, now)
-
-    past = datetime.now(timezone.utc) - timedelta(minutes=1)
-    assert not Backups.is_time_to_backup_service(dummy_service, past)
-
-    future = datetime.now(timezone.utc) + timedelta(minutes=backup_period + 2)
-    assert Backups.is_time_to_backup_service(dummy_service, future)

 # Storage


@@ -67,7 +67,7 @@ def only_root_in_userdata(mocker, datadir):
         read_json(datadir / "only_root.json")["volumes"][0]["mountPoint"]
         == "/volumes/sda1"
     )
-    assert read_json(datadir / "only_root.json")["volumes"][0]["filesystem"] == "ext4"
+    assert read_json(datadir / "only_root.json")["volumes"][0]["fsType"] == "ext4"
     return datadir
@@ -416,32 +416,37 @@ def lsblk_full_mock(mocker):
 def test_get_block_devices(lsblk_full_mock, authorized_client):
     block_devices = BlockDevices().get_block_devices()
     assert len(block_devices) == 2
-    assert block_devices[0].name == "sda1"
-    assert block_devices[0].path == "/dev/sda1"
-    assert block_devices[0].fsavail == "4605702144"
-    assert block_devices[0].fssize == "19814920192"
-    assert block_devices[0].fstype == "ext4"
-    assert block_devices[0].fsused == "14353719296"
-    assert block_devices[0].mountpoints == ["/nix/store", "/"]
-    assert block_devices[0].label is None
-    assert block_devices[0].uuid == "ec80c004-baec-4a2c-851d-0e1807135511"
-    assert block_devices[0].size == "20210236928"
-    assert block_devices[0].model is None
-    assert block_devices[0].serial is None
-    assert block_devices[0].type == "part"
-    assert block_devices[1].name == "sdb"
-    assert block_devices[1].path == "/dev/sdb"
-    assert block_devices[1].fsavail == "11888545792"
-    assert block_devices[1].fssize == "12573614080"
-    assert block_devices[1].fstype == "ext4"
-    assert block_devices[1].fsused == "24047616"
-    assert block_devices[1].mountpoints == ["/volumes/sdb"]
-    assert block_devices[1].label is None
-    assert block_devices[1].uuid == "fa9d0026-ee23-4047-b8b1-297ae16fa751"
-    assert block_devices[1].size == "12884901888"
-    assert block_devices[1].model == "Volume"
-    assert block_devices[1].serial == "21378102"
-    assert block_devices[1].type == "disk"
+    devices_by_name = {device.name: device for device in block_devices}
+    sda1 = devices_by_name["sda1"]
+    sdb = devices_by_name["sdb"]
+
+    assert sda1.name == "sda1"
+    assert sda1.path == "/dev/sda1"
+    assert sda1.fsavail == "4605702144"
+    assert sda1.fssize == "19814920192"
+    assert sda1.fstype == "ext4"
+    assert sda1.fsused == "14353719296"
+    assert sda1.mountpoints == ["/nix/store", "/"]
+    assert sda1.label is None
+    assert sda1.uuid == "ec80c004-baec-4a2c-851d-0e1807135511"
+    assert sda1.size == "20210236928"
+    assert sda1.model is None
+    assert sda1.serial is None
+    assert sda1.type == "part"
+
+    assert sdb.name == "sdb"
+    assert sdb.path == "/dev/sdb"
+    assert sdb.fsavail == "11888545792"
+    assert sdb.fssize == "12573614080"
+    assert sdb.fstype == "ext4"
+    assert sdb.fsused == "24047616"
+    assert sdb.mountpoints == ["/volumes/sdb"]
+    assert sdb.label is None
+    assert sdb.uuid == "fa9d0026-ee23-4047-b8b1-297ae16fa751"
+    assert sdb.size == "12884901888"
+    assert sdb.model == "Volume"
+    assert sdb.serial == "21378102"
+    assert sdb.type == "disk"

 def test_get_block_device(lsblk_full_mock, authorized_client):
@@ -506,3 +511,30 @@ def test_get_root_block_device(lsblk_full_mock, authorized_client):
     assert block_device.model is None
     assert block_device.serial is None
     assert block_device.type == "part"
+
+
+# Unassuming sanity check, yes this did fail
+def test_get_real_devices():
+    block_devices = BlockDevices().get_block_devices()
+
+    assert block_devices is not None
+    assert len(block_devices) > 0
+
+
+# Unassuming sanity check
+def test_get_real_root_device():
+    devices = BlockDevices().get_block_devices()
+    try:
+        block_device = BlockDevices().get_root_block_device()
+    except Exception as e:
+        raise Exception("cannot get root device:", e, "devices found:", devices)
+
+    assert block_device is not None
+    assert block_device.name is not None
+    assert block_device.name != ""
+
+
+def test_get_real_root_device_raw(authorized_client):
+    block_device = BlockDevices().get_root_block_device()
+    assert block_device is not None
+    assert block_device.name is not None
+    assert block_device.name != ""
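
# The sanity checks above only assert that some root device is found.
# Judging by the mock data earlier (sda1 mounts both "/nix/store" and "/"),
# root detection presumably picks the device whose mountpoints include "/".
# A hypothetical sketch of that rule (not the actual BlockDevices
# implementation):


def find_root_device_sketch(devices: list):
    # Return the first device mounted at "/", if any.
    for device in devices:
        if "/" in device.mountpoints:
            return device
    return None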


@@ -1,59 +1,59 @@
 {
-    "api": {
-        "token": "TEST_TOKEN",
-        "enableSwagger": false
+    "dns": {
+        "provider": "CLOUDFLARE",
+        "useStagingACME": false
     },
-    "bitwarden": {
-        "enable": true
+    "server": {
+        "provider": "HETZNER"
     },
-    "databasePassword": "PASSWORD",
-    "domain": "test.tld",
+    "domain": "test-domain.tld",
     "hashedMasterPassword": "HASHED_PASSWORD",
     "hostname": "test-instance",
-    "nextcloud": {
-        "adminPassword": "ADMIN",
-        "databasePassword": "ADMIN",
-        "enable": true
+    "timezone": "Etc/UTC",
+    "username": "tester",
+    "useBinds": true,
+    "sshKeys": [
+        "ssh-rsa KEY test@pc"
+    ],
+    "users": [],
+    "autoUpgrade": {
+        "enable": true,
+        "allowReboot": true
     },
-    "resticPassword": "PASS",
+    "modules": {
+        "bitwarden": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "gitea": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "jitsi-meet": {
+            "enable": true
+        },
+        "nextcloud": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "ocserv": {
+            "enable": true
+        },
+        "pleroma": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "simple-nixos-mailserver": {
+            "enable": true,
+            "location": "sdb"
+        }
+    },
+    "volumes": [],
     "ssh": {
         "enable": true,
         "passwordAuthentication": true,
         "rootKeys": [
             "ssh-ed25519 KEY test@pc"
         ]
-    },
-    "username": "tester",
-    "gitea": {
-        "enable": false
-    },
-    "ocserv": {
-        "enable": true
-    },
-    "pleroma": {
-        "enable": true
-    },
-    "autoUpgrade": {
-        "enable": true,
-        "allowReboot": true
-    },
-    "timezone": "Europe/Moscow",
-    "sshKeys": [
-        "ssh-rsa KEY test@pc"
-    ],
-    "dns": {
-        "provider": "CLOUDFLARE",
-        "apiKey": "TOKEN"
-    },
-    "server": {
-        "provider": "HETZNER"
-    },
-    "backup": {
-        "provider": "BACKBLAZE",
-        "accountId": "ID",
-        "accountKey": "KEY",
-        "bucket": "selfprivacy"
-    },
-    "volumes": [
-    ]
+    }
 }


@@ -1,64 +1,65 @@
 {
-    "api": {
-        "token": "TEST_TOKEN",
-        "enableSwagger": false
+    "dns": {
+        "provider": "CLOUDFLARE",
+        "useStagingACME": false
     },
-    "bitwarden": {
-        "enable": true
+    "server": {
+        "provider": "HETZNER"
     },
-    "databasePassword": "PASSWORD",
-    "domain": "test.tld",
+    "domain": "test-domain.tld",
     "hashedMasterPassword": "HASHED_PASSWORD",
     "hostname": "test-instance",
-    "nextcloud": {
-        "adminPassword": "ADMIN",
-        "databasePassword": "ADMIN",
-        "enable": true
+    "timezone": "Etc/UTC",
+    "username": "tester",
+    "useBinds": true,
+    "sshKeys": [
+        "ssh-rsa KEY test@pc"
+    ],
+    "users": [],
+    "autoUpgrade": {
+        "enable": true,
+        "allowReboot": true
     },
-    "resticPassword": "PASS",
+    "modules": {
+        "bitwarden": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "gitea": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "jitsi-meet": {
+            "enable": true
+        },
+        "nextcloud": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "ocserv": {
+            "enable": true
+        },
+        "pleroma": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "simple-nixos-mailserver": {
+            "enable": true,
+            "location": "sdb"
+        }
+    },
+    "volumes": [
+        {
+            "device": "/dev/sda1",
+            "mountPoint": "/volumes/sda1",
+            "fsType": "ext4"
+        }
+    ],
     "ssh": {
         "enable": true,
         "passwordAuthentication": true,
         "rootKeys": [
             "ssh-ed25519 KEY test@pc"
         ]
-    },
-    "username": "tester",
-    "gitea": {
-        "enable": false
-    },
-    "ocserv": {
-        "enable": true
-    },
-    "pleroma": {
-        "enable": true
-    },
-    "autoUpgrade": {
-        "enable": true,
-        "allowReboot": true
-    },
-    "timezone": "Europe/Moscow",
-    "sshKeys": [
-        "ssh-rsa KEY test@pc"
-    ],
-    "volumes": [
-        {
-            "device": "/dev/sda1",
-            "mountPoint": "/volumes/sda1",
-            "filesystem": "ext4"
-        }
-    ],
-    "dns": {
-        "provider": "CLOUDFLARE",
-        "apiKey": "TOKEN"
-    },
-    "server": {
-        "provider": "HETZNER"
-    },
-    "backup": {
-        "provider": "BACKBLAZE",
-        "accountId": "ID",
-        "accountKey": "KEY",
-        "bucket": "selfprivacy"
     }
 }


@@ -1,57 +1,58 @@
 {
-    "api": {
-        "token": "TEST_TOKEN",
-        "enableSwagger": false
+    "dns": {
+        "provider": "CLOUDFLARE",
+        "useStagingACME": false
     },
-    "bitwarden": {
-        "enable": true
+    "server": {
+        "provider": "HETZNER"
     },
-    "databasePassword": "PASSWORD",
-    "domain": "test.tld",
+    "domain": "test-domain.tld",
     "hashedMasterPassword": "HASHED_PASSWORD",
     "hostname": "test-instance",
-    "nextcloud": {
-        "adminPassword": "ADMIN",
-        "databasePassword": "ADMIN",
-        "enable": true
+    "timezone": "Etc/UTC",
+    "username": "tester",
+    "useBinds": true,
+    "sshKeys": [
+        "ssh-rsa KEY test@pc"
+    ],
+    "users": [],
+    "autoUpgrade": {
+        "enable": true,
+        "allowReboot": true
+    },
+    "modules": {
+        "bitwarden": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "gitea": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "jitsi-meet": {
+            "enable": true
+        },
+        "nextcloud": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "ocserv": {
+            "enable": true
+        },
+        "pleroma": {
+            "enable": true,
+            "location": "sdb"
+        },
+        "simple-nixos-mailserver": {
+            "enable": true,
+            "location": "sdb"
+        }
     },
-    "resticPassword": "PASS",
     "ssh": {
         "enable": true,
         "passwordAuthentication": true,
         "rootKeys": [
             "ssh-ed25519 KEY test@pc"
         ]
-    },
-    "username": "tester",
-    "gitea": {
-        "enable": false
-    },
-    "ocserv": {
-        "enable": true
-    },
-    "pleroma": {
-        "enable": true
-    },
-    "autoUpgrade": {
-        "enable": true,
-        "allowReboot": true
-    },
-    "timezone": "Europe/Moscow",
-    "sshKeys": [
-        "ssh-rsa KEY test@pc"
-    ],
-    "dns": {
-        "provider": "CLOUDFLARE",
-        "apiKey": "TOKEN"
-    },
-    "server": {
-        "provider": "HETZNER"
-    },
-    "backup": {
-        "provider": "BACKBLAZE",
-        "accountId": "ID",
-        "accountKey": "KEY",
-        "bucket": "selfprivacy"
     }
 }


@@ -1,6 +1,5 @@
 # pylint: disable=redefined-outer-name
 # pylint: disable=unused-argument
-import json
 import os

 import pytest

tests/test_dkim.py Normal file

@@ -0,0 +1,52 @@
import pytest

import os
from os import path
from tests.conftest import global_data_dir

from selfprivacy_api.utils import get_dkim_key, get_domain

###############################################################################

DKIM_FILE_CONTENT = b'selector._domainkey\tIN\tTXT\t( "v=DKIM1; k=rsa; "\n\t "p=MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDNn/IhEz1SxgHxxxI8vlPYC2dNueiLe1GC4SYz8uHimC8SDkMvAwm7rqi2SimbFgGB5nccCNOqCkrIqJTCB9vufqBnVKAjshHqpOr5hk4JJ1T/AGQKWinstmDbfTLPYTbU8ijZrwwGeqQLlnXR5nSN0GB9GazheA9zaPsT6PV+aQIDAQAB" ) ; ----- DKIM key selector for test-domain.tld\n'


@pytest.fixture
def dkim_file(mocker, tmpdir, generic_userdata):
    domain = get_domain()
    assert domain is not None
    assert domain != ""

    filename = domain + ".selector.txt"
    dkim_path = path.join(tmpdir, filename)

    with open(dkim_path, "wb") as file:
        file.write(DKIM_FILE_CONTENT)

    mocker.patch("selfprivacy_api.utils.DKIM_DIR", tmpdir)
    return dkim_path


@pytest.fixture
def no_dkim_file(dkim_file):
    os.remove(dkim_file)
    assert path.exists(dkim_file) is False
    return dkim_file


###############################################################################


def test_get_dkim_key(dkim_file):
    """Test DKIM key"""
    dkim_key = get_dkim_key("test-domain.tld")
    assert (
        dkim_key
        == "v=DKIM1; k=rsa; p=MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDNn/IhEz1SxgHxxxI8vlPYC2dNueiLe1GC4SYz8uHimC8SDkMvAwm7rqi2SimbFgGB5nccCNOqCkrIqJTCB9vufqBnVKAjshHqpOr5hk4JJ1T/AGQKWinstmDbfTLPYTbU8ijZrwwGeqQLlnXR5nSN0GB9GazheA9zaPsT6PV+aQIDAQAB"
    )


def test_no_dkim_key(no_dkim_file):
    """Test no DKIM key"""
    dkim_key = get_dkim_key("test-domain.tld")
    assert dkim_key is None
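
# The expected value above implies that get_dkim_key joins the quoted chunks
# of a BIND-style TXT record. A minimal sketch of such parsing (hypothetical
# helper, assuming the record format in DKIM_FILE_CONTENT; not necessarily
# the actual selfprivacy_api.utils implementation):
import re


def parse_dkim_record_sketch(file_content: bytes) -> str:
    # Collect the "..."-quoted chunks and concatenate them.
    chunks = re.findall(r'"([^"]*)"', file_content.decode("utf-8"))
    return "".join(chunks).strip()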


@@ -0,0 +1,96 @@
from tests.common import generate_api_query
from tests.conftest import TOKENS_FILE_CONTENTS, DEVICE_WE_AUTH_TESTS_WITH

ORIGINAL_DEVICES = TOKENS_FILE_CONTENTS["tokens"]


def assert_ok(output: dict, code=200) -> None:
    if output["success"] is False:
        # convenience for debugging, this should display error
        # if message is empty, consider adding helpful messages
        raise ValueError(output["code"], output["message"])
    assert output["success"] is True
    assert output["message"] is not None
    assert output["code"] == code


def assert_errorcode(output: dict, code) -> None:
    assert output["success"] is False
    assert output["message"] is not None
    assert output["code"] == code


def assert_empty(response):
    assert response.status_code == 200
    assert response.json().get("data") is None


def get_data(response):
    assert response.status_code == 200
    response = response.json()
    if (
        "errors" in response.keys()
    ):  # convenience for debugging, this will display error
        raise ValueError(response["errors"])
    data = response.get("data")
    assert data is not None
    return data


API_DEVICES_QUERY = """
devices {
    creationDate
    isCaller
    name
}
"""


def request_devices(client):
    return client.post(
        "/graphql",
        json={"query": generate_api_query([API_DEVICES_QUERY])},
    )


def graphql_get_devices(client):
    response = request_devices(client)
    data = get_data(response)
    devices = data["api"]["devices"]
    assert devices is not None
    return devices


def set_client_token(client, token):
    client.headers.update({"Authorization": "Bearer " + token})


def assert_token_valid(client, token):
    set_client_token(client, token)
    assert graphql_get_devices(client) is not None


def assert_same(graphql_devices, abstract_devices):
    """Orderless comparison"""
    assert len(graphql_devices) == len(abstract_devices)
    for original_device in abstract_devices:
        assert original_device["name"] in [device["name"] for device in graphql_devices]
        for device in graphql_devices:
            if device["name"] == original_device["name"]:
                assert device["creationDate"] == original_device["date"].isoformat()


def assert_original(client):
    devices = graphql_get_devices(client)
    assert_original_devices(devices)


def assert_original_devices(devices):
    assert_same(devices, ORIGINAL_DEVICES)

    for device in devices:
        if device["name"] == DEVICE_WE_AUTH_TESTS_WITH["name"]:
            assert device["isCaller"] is True
        else:
            assert device["isCaller"] is False


@ -1,14 +0,0 @@
{
"tokens": [
{
"token": "TEST_TOKEN",
"name": "test_token",
"date": "2022-01-14 08:31:10.789314"
},
{
"token": "TEST_TOKEN2",
"name": "test_token2",
"date": "2022-01-14 08:31:10.789314"
}
]
}


@@ -3,27 +3,13 @@
 # pylint: disable=missing-function-docstring

 from tests.common import generate_api_query
+from tests.test_graphql.common import assert_original_devices
 from tests.test_graphql.test_api_devices import API_DEVICES_QUERY
 from tests.test_graphql.test_api_recovery import API_RECOVERY_QUERY
 from tests.test_graphql.test_api_version import API_VERSION_QUERY

-TOKENS_FILE_CONTETS = {
-    "tokens": [
-        {
-            "token": "TEST_TOKEN",
-            "name": "test_token",
-            "date": "2022-01-14 08:31:10.789314",
-        },
-        {
-            "token": "TEST_TOKEN2",
-            "name": "test_token2",
-            "date": "2022-01-14 08:31:10.789314",
-        },
-    ]
-}

-def test_graphql_get_entire_api_data(authorized_client, tokens_file):
+def test_graphql_get_entire_api_data(authorized_client):
     response = authorized_client.post(
         "/graphql",
         json={

@@ -35,20 +21,11 @@ def test_graphql_get_entire_api_data(authorized_client, tokens_file):
     assert response.status_code == 200
     assert response.json().get("data") is not None
     assert "version" in response.json()["data"]["api"]

-    assert response.json()["data"]["api"]["devices"] is not None
-    assert len(response.json()["data"]["api"]["devices"]) == 2
-    assert (
-        response.json()["data"]["api"]["devices"][0]["creationDate"]
-        == "2022-01-14T08:31:10.789314"
-    )
-    assert response.json()["data"]["api"]["devices"][0]["isCaller"] is True
-    assert response.json()["data"]["api"]["devices"][0]["name"] == "test_token"
-    assert (
-        response.json()["data"]["api"]["devices"][1]["creationDate"]
-        == "2022-01-14T08:31:10.789314"
-    )
-    assert response.json()["data"]["api"]["devices"][1]["isCaller"] is False
-    assert response.json()["data"]["api"]["devices"][1]["name"] == "test_token2"
+    devices = response.json()["data"]["api"]["devices"]
+    assert devices is not None
+    assert_original_devices(devices)

     assert response.json()["data"]["api"]["recoveryKey"] is not None
     assert response.json()["data"]["api"]["recoveryKey"]["exists"] is False
     assert response.json()["data"]["api"]["recoveryKey"]["valid"] is False


@@ -1,9 +1,13 @@
 from os import path
-from tests.test_graphql.test_backup import dummy_service, backups, raw_dummy_service
+from tests.test_backup import backups
 from tests.common import generate_backup_query

 from selfprivacy_api.graphql.common_types.service import service_to_graphql_service
+from selfprivacy_api.graphql.common_types.backup import (
+    _AutobackupQuotas,
+    AutobackupQuotas,
+)
 from selfprivacy_api.jobs import Jobs, JobStatus

 API_RELOAD_SNAPSHOTS = """
@@ -38,6 +42,34 @@ mutation TestAutobackupPeriod($period: Int) {
 }
 """

+API_SET_AUTOBACKUP_QUOTAS_MUTATION = """
+mutation TestAutobackupQuotas($input: AutobackupQuotasInput!) {
+    backup {
+        setAutobackupQuotas(quotas: $input) {
+            success
+            message
+            code
+            configuration {
+                provider
+                encryptionKey
+                isInitialized
+                autobackupPeriod
+                locationName
+                locationId
+                autobackupQuotas {
+                    last
+                    daily
+                    weekly
+                    monthly
+                    yearly
+                }
+            }
+        }
+    }
+}
+"""

 API_REMOVE_REPOSITORY_MUTATION = """
 mutation TestRemoveRepo {
     backup {

@@ -113,6 +145,7 @@ allSnapshots {
             id
         }
         createdAt
+        reason
     }
 """
@@ -177,6 +210,17 @@ def api_set_period(authorized_client, period):
     return response

+def api_set_quotas(authorized_client, quotas: _AutobackupQuotas):
+    response = authorized_client.post(
+        "/graphql",
+        json={
+            "query": API_SET_AUTOBACKUP_QUOTAS_MUTATION,
+            "variables": {"input": quotas.dict()},
+        },
+    )
+    return response

 def api_remove(authorized_client):
     response = authorized_client.post(
         "/graphql",

@@ -221,6 +265,10 @@ def api_init_without_key(
 def assert_ok(data):
+    if data["success"] is False:
+        # convenience for debugging, this should display error
+        # if empty, consider adding helpful messages
+        raise ValueError(data["code"], data["message"])
     assert data["code"] == 200
     assert data["success"] is True

@@ -231,7 +279,7 @@ def get_data(response):
     if (
         "errors" in response.keys()
     ):  # convenience for debugging, this will display error
-        assert response["errors"] == []
+        raise ValueError(response["errors"])
     assert response["data"] is not None
     data = response["data"]
     return data
@@ -253,12 +301,12 @@ def test_dummy_service_convertible_to_gql(dummy_service):
     assert gql_service is not None

-def test_snapshots_empty(authorized_client, dummy_service):
+def test_snapshots_empty(authorized_client, dummy_service, backups):
     snaps = api_snapshots(authorized_client)
     assert snaps == []

-def test_start_backup(authorized_client, dummy_service):
+def test_start_backup(authorized_client, dummy_service, backups):
     response = api_backup(authorized_client, dummy_service)
     data = get_data(response)["backup"]["startBackup"]
     assert data["success"] is True

@@ -274,7 +322,7 @@ def test_start_backup(authorized_client, dummy_service, backups):
     assert snap["service"]["id"] == "testservice"

-def test_restore(authorized_client, dummy_service):
+def test_restore(authorized_client, dummy_service, backups):
     api_backup(authorized_client, dummy_service)
     snap = api_snapshots(authorized_client)[0]
     assert snap["id"] is not None

@@ -287,7 +335,7 @@ def test_restore(authorized_client, dummy_service, backups):
     assert Jobs.get_job(job["uid"]).status == JobStatus.FINISHED

-def test_reinit(authorized_client, dummy_service, tmpdir):
+def test_reinit(authorized_client, dummy_service, tmpdir, backups):
     test_repo_path = path.join(tmpdir, "not_at_all_sus")
     response = api_init_without_key(
         authorized_client, "FILE", "", "", test_repo_path, ""

@@ -309,7 +357,7 @@ def test_reinit(authorized_client, dummy_service, tmpdir, backups):
     assert Jobs.get_job(job["uid"]).status == JobStatus.FINISHED

-def test_remove(authorized_client, generic_userdata):
+def test_remove(authorized_client, generic_userdata, backups):
     response = api_remove(authorized_client)
     data = get_data(response)["backup"]["removeRepository"]
     assert_ok(data)

@@ -323,7 +371,23 @@ def test_remove(authorized_client, generic_userdata, backups):
     assert configuration["isInitialized"] is False

-def test_autobackup_period_nonzero(authorized_client):
+def test_autobackup_quotas_nonzero(authorized_client, backups):
+    quotas = _AutobackupQuotas(
+        last=3,
+        daily=2,
+        weekly=4,
+        monthly=13,
+        yearly=14,
+    )
+    response = api_set_quotas(authorized_client, quotas)
+    data = get_data(response)["backup"]["setAutobackupQuotas"]
+    assert_ok(data)
+
+    configuration = data["configuration"]
+    assert configuration["autobackupQuotas"] == quotas
+
+
+def test_autobackup_period_nonzero(authorized_client, backups):
     new_period = 11
     response = api_set_period(authorized_client, new_period)
     data = get_data(response)["backup"]["setAutobackupPeriod"]

@@ -333,7 +397,7 @@ def test_autobackup_period_nonzero(authorized_client, backups):
     assert configuration["autobackupPeriod"] == new_period

-def test_autobackup_period_zero(authorized_client):
+def test_autobackup_period_zero(authorized_client, backups):
     new_period = 0
     # since it is none by default, we better first set it to something non-negative
     response = api_set_period(authorized_client, 11)

@@ -346,7 +410,7 @@ def test_autobackup_period_zero(authorized_client, backups):
     assert configuration["autobackupPeriod"] == None

-def test_autobackup_period_none(authorized_client):
+def test_autobackup_period_none(authorized_client, backups):
     # since it is none by default, we better first set it to something non-negative
     response = api_set_period(authorized_client, 11)
     # and now we nullify it

@@ -358,7 +422,7 @@ def test_autobackup_period_none(authorized_client, backups):
     assert configuration["autobackupPeriod"] == None

-def test_autobackup_period_negative(authorized_client):
+def test_autobackup_period_negative(authorized_client, backups):
     # since it is none by default, we better first set it to something non-negative
     response = api_set_period(authorized_client, 11)
     # and now we nullify it

@@ -372,7 +436,7 @@ def test_autobackup_period_negative(authorized_client, backups):
 # We cannot really check the effect at this level, we leave it to backend tests
 # But we still make it run in both empty and full scenarios and ask for snaps afterwards

-def test_reload_snapshots_bare_bare_bare(authorized_client, dummy_service):
+def test_reload_snapshots_bare_bare_bare(authorized_client, dummy_service, backups):
     api_remove(authorized_client)
     response = api_reload_snapshots(authorized_client)

@@ -383,7 +447,7 @@ def test_reload_snapshots_bare_bare_bare(authorized_client, dummy_service, backups):
     assert snaps == []

-def test_reload_snapshots(authorized_client, dummy_service):
+def test_reload_snapshots(authorized_client, dummy_service, backups):
     response = api_backup(authorized_client, dummy_service)
     data = get_data(response)["backup"]["startBackup"]

@@ -395,7 +459,7 @@ def test_reload_snapshots(authorized_client, dummy_service, backups):
     assert len(snaps) == 1

-def test_forget_snapshot(authorized_client, dummy_service):
+def test_forget_snapshot(authorized_client, dummy_service, backups):
     response = api_backup(authorized_client, dummy_service)
     data = get_data(response)["backup"]["startBackup"]

@@ -410,7 +474,7 @@ def test_forget_snapshot(authorized_client, dummy_service, backups):
     assert len(snaps) == 0

-def test_forget_nonexistent_snapshot(authorized_client, dummy_service):
+def test_forget_nonexistent_snapshot(authorized_client, dummy_service, backups):
     snaps = api_snapshots(authorized_client)
     assert len(snaps) == 0
     response = api_forget(authorized_client, "898798uekiodpjoiweoiwuoeirueor")


@ -1,76 +1,78 @@
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
# pylint: disable=unused-argument # pylint: disable=unused-argument
# pylint: disable=missing-function-docstring # pylint: disable=missing-function-docstring
import datetime from tests.common import (
import pytest RECOVERY_KEY_VALIDATION_DATETIME,
from mnemonic import Mnemonic DEVICE_KEY_VALIDATION_DATETIME,
NearFuture,
from selfprivacy_api.repositories.tokens.json_tokens_repository import ( generate_api_query,
JsonTokensRepository, )
from tests.conftest import DEVICE_WE_AUTH_TESTS_WITH
from tests.test_graphql.common import (
get_data,
assert_empty,
assert_ok,
assert_errorcode,
assert_token_valid,
assert_original,
assert_same,
graphql_get_devices,
request_devices,
set_client_token,
API_DEVICES_QUERY,
ORIGINAL_DEVICES,
) )
from selfprivacy_api.models.tokens.token import Token
from tests.common import generate_api_query, read_json, write_json
TOKENS_FILE_CONTETS = {
"tokens": [
{
"token": "TEST_TOKEN",
"name": "test_token",
"date": "2022-01-14 08:31:10.789314",
},
{
"token": "TEST_TOKEN2",
"name": "test_token2",
"date": "2022-01-14 08:31:10.789314",
},
]
}
API_DEVICES_QUERY = """
devices {
creationDate
isCaller
name
}
"""
@pytest.fixture def graphql_get_caller_token_info(client):
def token_repo(): devices = graphql_get_devices(client)
return JsonTokensRepository() for device in devices:
if device["isCaller"] is True:
return device
def test_graphql_tokens_info(authorized_client, tokens_file): def graphql_get_new_device_key(authorized_client) -> str:
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={"query": generate_api_query([API_DEVICES_QUERY])}, json={"query": NEW_DEVICE_KEY_MUTATION},
) )
assert response.status_code == 200 assert_ok(get_data(response)["api"]["getNewDeviceApiKey"])
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["devices"] is not None key = response.json()["data"]["api"]["getNewDeviceApiKey"]["key"]
assert len(response.json()["data"]["api"]["devices"]) == 2 assert key.split(" ").__len__() == 12
assert ( return key
response.json()["data"]["api"]["devices"][0]["creationDate"]
== "2022-01-14T08:31:10.789314"
)
assert response.json()["data"]["api"]["devices"][0]["isCaller"] is True
assert response.json()["data"]["api"]["devices"][0]["name"] == "test_token"
assert (
response.json()["data"]["api"]["devices"][1]["creationDate"]
== "2022-01-14T08:31:10.789314"
)
assert response.json()["data"]["api"]["devices"][1]["isCaller"] is False
assert response.json()["data"]["api"]["devices"][1]["name"] == "test_token2"
def test_graphql_tokens_info_unauthorized(client, tokens_file): def graphql_try_auth_new_device(client, mnemonic_key, device_name):
response = client.post( return client.post(
"/graphql", "/graphql",
json={"query": generate_api_query([API_DEVICES_QUERY])}, json={
"query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
"variables": {
"input": {
"key": mnemonic_key,
"deviceName": device_name,
}
},
},
) )
assert response.status_code == 200
assert response.json()["data"] is None
def graphql_authorize_new_device(client, mnemonic_key, device_name) -> str:
response = graphql_try_auth_new_device(client, mnemonic_key, "new_device")
assert_ok(get_data(response)["api"]["authorizeWithNewDeviceApiKey"])
token = response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["token"]
assert_token_valid(client, token)
return token
def test_graphql_tokens_info(authorized_client):
assert_original(authorized_client)
def test_graphql_tokens_info_unauthorized(client):
response = request_devices(client)
assert_empty(response)
DELETE_TOKEN_MUTATION = """ DELETE_TOKEN_MUTATION = """
@ -86,7 +88,7 @@ mutation DeleteToken($device: String!) {
""" """
def test_graphql_delete_token_unauthorized(client, tokens_file): def test_graphql_delete_token_unauthorized(client):
response = client.post( response = client.post(
"/graphql", "/graphql",
json={ json={
@ -96,57 +98,45 @@ def test_graphql_delete_token_unauthorized(client, tokens_file):
}, },
}, },
) )
assert response.status_code == 200 assert_empty(response)
assert response.json()["data"] is None
def test_graphql_delete_token(authorized_client, tokens_file): def test_graphql_delete_token(authorized_client):
test_devices = ORIGINAL_DEVICES.copy()
device_to_delete = test_devices.pop(1)
assert device_to_delete != DEVICE_WE_AUTH_TESTS_WITH
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={ json={
"query": DELETE_TOKEN_MUTATION, "query": DELETE_TOKEN_MUTATION,
"variables": { "variables": {
"device": "test_token2", "device": device_to_delete["name"],
}, },
}, },
) )
assert response.status_code == 200 assert_ok(get_data(response)["api"]["deleteDeviceApiToken"])
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["success"] is True devices = graphql_get_devices(authorized_client)
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["message"] is not None assert_same(devices, test_devices)
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["code"] == 200
assert read_json(tokens_file) == {
"tokens": [
{
"token": "TEST_TOKEN",
"name": "test_token",
"date": "2022-01-14 08:31:10.789314",
}
]
}
def test_graphql_delete_self_token(authorized_client, tokens_file): def test_graphql_delete_self_token(authorized_client):
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={ json={
"query": DELETE_TOKEN_MUTATION, "query": DELETE_TOKEN_MUTATION,
"variables": { "variables": {
"device": "test_token", "device": DEVICE_WE_AUTH_TESTS_WITH["name"],
}, },
}, },
) )
assert response.status_code == 200 assert_errorcode(get_data(response)["api"]["deleteDeviceApiToken"], 400)
assert response.json().get("data") is not None assert_original(authorized_client)
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["success"] is False
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["message"] is not None
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["code"] == 400
assert read_json(tokens_file) == TOKENS_FILE_CONTETS
def test_graphql_delete_nonexistent_token( def test_graphql_delete_nonexistent_token(
authorized_client, authorized_client,
tokens_file,
): ):
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
@ -157,12 +147,9 @@ def test_graphql_delete_nonexistent_token(
}, },
}, },
) )
assert response.status_code == 200 assert_errorcode(get_data(response)["api"]["deleteDeviceApiToken"], 404)
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["success"] is False assert_original(authorized_client)
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["message"] is not None
assert response.json()["data"]["api"]["deleteDeviceApiToken"]["code"] == 404
assert read_json(tokens_file) == TOKENS_FILE_CONTETS
REFRESH_TOKEN_MUTATION = """ REFRESH_TOKEN_MUTATION = """
@ -179,37 +166,27 @@ mutation RefreshToken {
""" """
def test_graphql_refresh_token_unauthorized(client, tokens_file): def test_graphql_refresh_token_unauthorized(client):
response = client.post( response = client.post(
"/graphql", "/graphql",
json={"query": REFRESH_TOKEN_MUTATION}, json={"query": REFRESH_TOKEN_MUTATION},
) )
assert response.status_code == 200 assert_empty(response)
assert response.json()["data"] is None
def test_graphql_refresh_token( def test_graphql_refresh_token(authorized_client, client):
authorized_client, caller_name_and_date = graphql_get_caller_token_info(authorized_client)
tokens_file,
token_repo,
):
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={"query": REFRESH_TOKEN_MUTATION}, json={"query": REFRESH_TOKEN_MUTATION},
) )
assert response.status_code == 200 assert_ok(get_data(response)["api"]["refreshDeviceApiToken"])
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["refreshDeviceApiToken"]["success"] is True new_token = response.json()["data"]["api"]["refreshDeviceApiToken"]["token"]
assert ( assert_token_valid(client, new_token)
response.json()["data"]["api"]["refreshDeviceApiToken"]["message"] is not None
) set_client_token(client, new_token)
assert response.json()["data"]["api"]["refreshDeviceApiToken"]["code"] == 200 assert graphql_get_caller_token_info(client) == caller_name_and_date
token = token_repo.get_token_by_name("test_token")
assert token == Token(
token=response.json()["data"]["api"]["refreshDeviceApiToken"]["token"],
device_name="test_token",
created_at=datetime.datetime(2022, 1, 14, 8, 31, 10, 789314),
)
NEW_DEVICE_KEY_MUTATION = """ NEW_DEVICE_KEY_MUTATION = """
@ -228,39 +205,12 @@ mutation NewDeviceKey {
def test_graphql_get_new_device_auth_key_unauthorized( def test_graphql_get_new_device_auth_key_unauthorized(
client, client,
tokens_file,
): ):
response = client.post( response = client.post(
"/graphql", "/graphql",
json={"query": NEW_DEVICE_KEY_MUTATION}, json={"query": NEW_DEVICE_KEY_MUTATION},
) )
assert response.status_code == 200 assert_empty(response)
assert response.json()["data"] is None
def test_graphql_get_new_device_auth_key(
authorized_client,
tokens_file,
):
response = authorized_client.post(
"/graphql",
json={"query": NEW_DEVICE_KEY_MUTATION},
)
assert response.status_code == 200
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["success"] is True
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["message"] is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["code"] == 200
assert (
response.json()["data"]["api"]["getNewDeviceApiKey"]["key"].split(" ").__len__()
== 12
)
token = (
Mnemonic(language="english")
.to_entropy(response.json()["data"]["api"]["getNewDeviceApiKey"]["key"])
.hex()
)
assert read_json(tokens_file)["new_device"]["token"] == token
INVALIDATE_NEW_DEVICE_KEY_MUTATION = """ INVALIDATE_NEW_DEVICE_KEY_MUTATION = """
@ -278,7 +228,6 @@ mutation InvalidateNewDeviceKey {
def test_graphql_invalidate_new_device_token_unauthorized( def test_graphql_invalidate_new_device_token_unauthorized(
client, client,
tokens_file,
): ):
response = client.post( response = client.post(
"/graphql", "/graphql",
@ -289,48 +238,20 @@ def test_graphql_invalidate_new_device_token_unauthorized(
}, },
}, },
) )
assert response.status_code == 200 assert_empty(response)
assert response.json()["data"] is None
def test_graphql_get_and_delete_new_device_key( def test_graphql_get_and_delete_new_device_key(client, authorized_client):
authorized_client, mnemonic_key = graphql_get_new_device_key(authorized_client)
tokens_file,
):
response = authorized_client.post(
"/graphql",
json={"query": NEW_DEVICE_KEY_MUTATION},
)
assert response.status_code == 200
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["success"] is True
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["message"] is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["code"] == 200
assert (
response.json()["data"]["api"]["getNewDeviceApiKey"]["key"].split(" ").__len__()
== 12
)
token = (
Mnemonic(language="english")
.to_entropy(response.json()["data"]["api"]["getNewDeviceApiKey"]["key"])
.hex()
)
assert read_json(tokens_file)["new_device"]["token"] == token
response = authorized_client.post( response = authorized_client.post(
"/graphql", "/graphql",
json={"query": INVALIDATE_NEW_DEVICE_KEY_MUTATION}, json={"query": INVALIDATE_NEW_DEVICE_KEY_MUTATION},
) )
assert response.status_code == 200 assert_ok(get_data(response)["api"]["invalidateNewDeviceApiKey"])
assert response.json().get("data") is not None
assert ( response = graphql_try_auth_new_device(client, mnemonic_key, "new_device")
response.json()["data"]["api"]["invalidateNewDeviceApiKey"]["success"] is True assert_errorcode(get_data(response)["api"]["authorizeWithNewDeviceApiKey"], 404)
)
assert (
response.json()["data"]["api"]["invalidateNewDeviceApiKey"]["message"]
is not None
)
assert response.json()["data"]["api"]["invalidateNewDeviceApiKey"]["code"] == 200
assert read_json(tokens_file) == TOKENS_FILE_CONTETS
AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION = """ AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION = """
@ -347,214 +268,48 @@ mutation AuthorizeWithNewDeviceKey($input: UseNewDeviceKeyInput!) {
""" """
def test_graphql_get_and_authorize_new_device( def test_graphql_get_and_authorize_new_device(client, authorized_client):
client, mnemonic_key = graphql_get_new_device_key(authorized_client)
authorized_client, old_devices = graphql_get_devices(authorized_client)
tokens_file,
): graphql_authorize_new_device(client, mnemonic_key, "new_device")
response = authorized_client.post( new_devices = graphql_get_devices(authorized_client)
"/graphql",
json={"query": NEW_DEVICE_KEY_MUTATION}, assert len(new_devices) == len(old_devices) + 1
) assert "new_device" in [device["name"] for device in new_devices]
assert response.status_code == 200
assert response.json().get("data") is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["success"] is True
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["message"] is not None
assert response.json()["data"]["api"]["getNewDeviceApiKey"]["code"] == 200
mnemonic_key = response.json()["data"]["api"]["getNewDeviceApiKey"]["key"]
assert mnemonic_key.split(" ").__len__() == 12
key = Mnemonic(language="english").to_entropy(mnemonic_key).hex()
assert read_json(tokens_file)["new_device"]["token"] == key
response = client.post(
"/graphql",
json={
"query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
"variables": {
"input": {
"key": mnemonic_key,
"deviceName": "new_device",
}
},
},
)
assert response.status_code == 200
assert response.json().get("data") is not None
assert (
response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["success"]
is True
)
assert (
response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["message"]
is not None
)
assert response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["code"] == 200
token = response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["token"]
assert read_json(tokens_file)["tokens"][2]["token"] == token
assert read_json(tokens_file)["tokens"][2]["name"] == "new_device"
def test_graphql_authorize_new_device_with_invalid_key( def test_graphql_authorize_new_device_with_invalid_key(client, authorized_client):
client, response = graphql_try_auth_new_device(client, "invalid_token", "new_device")
tokens_file, assert_errorcode(get_data(response)["api"]["authorizeWithNewDeviceApiKey"], 404)
):
response = client.post( assert_original(authorized_client)
"/graphql",
json={
"query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
"variables": {
"input": {
"key": "invalid_token",
"deviceName": "test_token",
}
},
},
)
assert response.status_code == 200
assert response.json().get("data") is not None
assert (
response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["success"]
is False
)
assert (
response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["message"]
is not None
)
assert response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["code"] == 404
assert read_json(tokens_file) == TOKENS_FILE_CONTETS
-def test_graphql_get_and_authorize_used_key(
-    client,
-    authorized_client,
-    tokens_file,
-):
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": NEW_DEVICE_KEY_MUTATION},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["code"] == 200
-    mnemonic_key = response.json()["data"]["api"]["getNewDeviceApiKey"]["key"]
-    assert mnemonic_key.split(" ").__len__() == 12
-    key = Mnemonic(language="english").to_entropy(mnemonic_key).hex()
-    assert read_json(tokens_file)["new_device"]["token"] == key
-    response = client.post(
-        "/graphql",
-        json={
-            "query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
-            "variables": {
-                "input": {
-                    "key": mnemonic_key,
-                    "deviceName": "new_token",
-                }
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["success"]
-        is True
-    )
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["message"]
-        is not None
-    )
-    assert response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["code"] == 200
-    assert (
-        read_json(tokens_file)["tokens"][2]["token"]
-        == response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["token"]
-    )
-    assert read_json(tokens_file)["tokens"][2]["name"] == "new_token"
-    response = client.post(
-        "/graphql",
-        json={
-            "query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
-            "variables": {
-                "input": {
-                    "key": NEW_DEVICE_KEY_MUTATION,
-                    "deviceName": "test_token2",
-                }
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["success"]
-        is False
-    )
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["message"]
-        is not None
-    )
-    assert response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["code"] == 404
-    assert read_json(tokens_file)["tokens"].__len__() == 3
+def test_graphql_get_and_authorize_used_key(client, authorized_client):
+    mnemonic_key = graphql_get_new_device_key(authorized_client)
+
+    graphql_authorize_new_device(client, mnemonic_key, "new_device")
+    devices = graphql_get_devices(authorized_client)
+
+    response = graphql_try_auth_new_device(client, mnemonic_key, "new_device2")
+    assert_errorcode(get_data(response)["api"]["authorizeWithNewDeviceApiKey"], 404)
+
+    assert graphql_get_devices(authorized_client) == devices
 def test_graphql_get_and_authorize_key_after_12_minutes(
-    client,
-    authorized_client,
-    tokens_file,
+    client, authorized_client, mocker
 ):
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": NEW_DEVICE_KEY_MUTATION},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewDeviceApiKey"]["code"] == 200
-    assert (
-        response.json()["data"]["api"]["getNewDeviceApiKey"]["key"].split(" ").__len__()
-        == 12
-    )
-    key = (
-        Mnemonic(language="english")
-        .to_entropy(response.json()["data"]["api"]["getNewDeviceApiKey"]["key"])
-        .hex()
-    )
-    assert read_json(tokens_file)["new_device"]["token"] == key
-
-    file_data = read_json(tokens_file)
-    file_data["new_device"]["expiration"] = str(
-        datetime.datetime.now() - datetime.timedelta(minutes=13)
-    )
-    write_json(tokens_file, file_data)
-
-    response = client.post(
-        "/graphql",
-        json={
-            "query": AUTHORIZE_WITH_NEW_DEVICE_KEY_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "test_token",
-                }
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["success"]
-        is False
-    )
-    assert (
-        response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["message"]
-        is not None
-    )
-    assert response.json()["data"]["api"]["authorizeWithNewDeviceApiKey"]["code"] == 404
+    mnemonic_key = graphql_get_new_device_key(authorized_client)
+    mock = mocker.patch(DEVICE_KEY_VALIDATION_DATETIME, NearFuture)
+
+    response = graphql_try_auth_new_device(client, mnemonic_key, "new_device")
+    assert_errorcode(get_data(response)["api"]["authorizeWithNewDeviceApiKey"], 404)
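Rather than rewriting the key's expiration timestamp on disk, the new test "time-warps": it patches the datetime symbol that the validation code consults. Both DEVICE_KEY_VALIDATION_DATETIME (a dotted-path string) and NearFuture come from tests.common and are not shown in this diff; a minimal sketch of the idea, assuming NearFuture is a datetime subclass whose now() reads past the 10-minute device-key lifetime:

```python
# Sketch only: the real definitions live in tests/common.py.
from datetime import datetime, timedelta

# Hypothetical dotted path of the datetime object as imported by the
# validation code; the actual constant's value is not shown in this diff.
DEVICE_KEY_VALIDATION_DATETIME = "selfprivacy_api.models.tokens.time.datetime"


class NearFuture(datetime):
    @classmethod
    def now(cls, tz=None):
        # Pretend 13 minutes have passed, beyond the 10-minute key lifetime.
        return datetime.now(tz) + timedelta(minutes=13)
```

With `mocker.patch(DEVICE_KEY_VALIDATION_DATETIME, NearFuture)` every `datetime.now()` inside the patched module sees the shifted clock, so a key minted a moment ago already counts as expired.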
 def test_graphql_authorize_without_token(
     client,
-    tokens_file,
 ):
     response = client.post(
         "/graphql",
@@ -567,5 +322,4 @@ def test_graphql_authorize_without_token(
         },
     )
-    assert response.status_code == 200
-    assert response.json().get("data") is None
+    assert_empty(response)

View file

@@ -1,24 +1,33 @@
 # pylint: disable=redefined-outer-name
 # pylint: disable=unused-argument
 # pylint: disable=missing-function-docstring
-import datetime
-
-from tests.common import generate_api_query, mnemonic_to_hex, read_json, write_json
-
-TOKENS_FILE_CONTETS = {
-    "tokens": [
-        {
-            "token": "TEST_TOKEN",
-            "name": "test_token",
-            "date": "2022-01-14 08:31:10.789314",
-        },
-        {
-            "token": "TEST_TOKEN2",
-            "name": "test_token2",
-            "date": "2022-01-14 08:31:10.789314",
-        },
-    ]
-}
+import pytest
+
+from datetime import datetime, timezone
+
+from tests.common import (
+    generate_api_query,
+    assert_recovery_recent,
+    NearFuture,
+    RECOVERY_KEY_VALIDATION_DATETIME,
+)
+
+# Graphql API's output should be timezone-naive
+from tests.common import ten_minutes_into_future_naive_utc as ten_minutes_into_future
+from tests.common import ten_minutes_into_future as ten_minutes_into_future_tz
+from tests.common import ten_minutes_into_past_naive_utc as ten_minutes_into_past
+
+from tests.test_graphql.common import (
+    assert_empty,
+    get_data,
+    assert_ok,
+    assert_errorcode,
+    assert_token_valid,
+    assert_original,
+    graphql_get_devices,
+    set_client_token,
+)
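The import aliasing here is deliberate: the GraphQL API is expected to emit timezone-naive UTC timestamps, so the naive helper gets the short name, while the timezone-aware variant (ten_minutes_into_future_tz) is fed through the same parametrized tests to prove aware input is handled too. The helpers themselves live in tests/common.py and are not shown; plausible definitions, as a sketch under that assumption:

```python
# Hypothetical sketch of the tests.common time helpers imported above.
from datetime import datetime, timedelta, timezone


def ten_minutes_into_future_naive_utc() -> datetime:
    # Naive UTC, matching what the GraphQL API is expected to return.
    return datetime.utcnow() + timedelta(minutes=10)


def ten_minutes_into_future() -> datetime:
    # Timezone-aware variant, to check that aware input is also accepted.
    return datetime.now(timezone.utc) + timedelta(minutes=10)


def ten_minutes_into_past_naive_utc() -> datetime:
    return datetime.utcnow() - timedelta(minutes=10)
```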
 API_RECOVERY_QUERY = """
 recoveryKey {
@@ -31,28 +40,89 @@ recoveryKey {
 """
 
-def test_graphql_recovery_key_status_unauthorized(client, tokens_file):
-    response = client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is None
-
-
-def test_graphql_recovery_key_status_when_none_exists(authorized_client, tokens_file):
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is False
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is False
-    assert response.json()["data"]["api"]["recoveryKey"]["creationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["expirationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] is None
+def request_recovery_status(client):
+    return client.post(
+        "/graphql",
+        json={"query": generate_api_query([API_RECOVERY_QUERY])},
+    )
+
+
+def graphql_recovery_status(client):
+    response = request_recovery_status(client)
+    data = get_data(response)
+    status = data["api"]["recoveryKey"]
+    assert status is not None
+    return status
+
+
+def request_make_new_recovery_key(client, expires_at=None, uses=None):
+    json = {"query": API_RECOVERY_KEY_GENERATE_MUTATION}
+    limits = {}
+    if expires_at is not None:
+        limits["expirationDate"] = expires_at.isoformat()
+    if uses is not None:
+        limits["uses"] = uses
+    if limits != {}:
+        json["variables"] = {"limits": limits}
+    response = client.post("/graphql", json=json)
+    return response
+
+
+def graphql_make_new_recovery_key(client, expires_at=None, uses=None):
+    response = request_make_new_recovery_key(client, expires_at, uses)
+    output = get_data(response)["api"]["getNewRecoveryApiKey"]
+    assert_ok(output)
+
+    key = output["key"]
+    assert key is not None
+    assert key.split(" ").__len__() == 18
+    return key
+
+
+def request_recovery_auth(client, key, device_name):
+    return client.post(
+        "/graphql",
+        json={
+            "query": API_RECOVERY_KEY_USE_MUTATION,
+            "variables": {
+                "input": {
+                    "key": key,
+                    "deviceName": device_name,
+                },
+            },
+        },
+    )
+
+
+def graphql_use_recovery_key(client, key, device_name):
+    response = request_recovery_auth(client, key, device_name)
+    output = get_data(response)["api"]["useRecoveryApiKey"]
+    assert_ok(output)
+
+    token = output["token"]
+    assert token is not None
+    assert_token_valid(client, token)
+    set_client_token(client, token)
+    assert device_name in [device["name"] for device in graphql_get_devices(client)]
+    return token
+
+
+def test_graphql_recovery_key_status_unauthorized(client):
+    response = request_recovery_status(client)
+    assert_empty(response)
+
+
+def test_graphql_recovery_key_status_when_none_exists(authorized_client):
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is False
+    assert status["valid"] is False
+    assert status["creationDate"] is None
+    assert status["expirationDate"] is None
+    assert status["usesLeft"] is None
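Note the convention the helpers establish: each operation has a request_* function that only fires the HTTP call, and a graphql_* wrapper that also asserts success, so happy-path tests stay one-liners while error-path tests reuse the raw request. A hypothetical composition (illustration only, not a test from this commit):

```python
# Hypothetical example composing the helpers defined above.
def test_recovery_key_roundtrip(client, authorized_client):
    # Mint a single-use recovery key, then spend it to register a device.
    key = graphql_make_new_recovery_key(authorized_client, uses=1)
    token = graphql_use_recovery_key(client, key, "spare_laptop")
    assert token is not None

    # The key is exhausted now, so a second use must fail with 404.
    response = request_recovery_auth(client, key, "another_laptop")
    assert_errorcode(get_data(response)["api"]["useRecoveryApiKey"], 404)
```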
 API_RECOVERY_KEY_GENERATE_MUTATION = """
@@ -82,287 +152,86 @@ mutation TestUseRecoveryKey($input: UseRecoveryKeyInput!) {
 """
 
-def test_graphql_generate_recovery_key(client, authorized_client, tokens_file):
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is not None
-    assert (
-        response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"]
-        .split(" ")
-        .__len__()
-        == 18
-    )
-    assert read_json(tokens_file)["recovery_token"] is not None
-    time_generated = read_json(tokens_file)["recovery_token"]["date"]
-    assert time_generated is not None
-    key = response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"]
-    assert (
-        datetime.datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f")
-        - datetime.timedelta(seconds=5)
-        < datetime.datetime.now()
-    )
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is True
-    assert response.json()["data"]["api"]["recoveryKey"][
-        "creationDate"
-    ] == time_generated.replace("Z", "")
-    assert response.json()["data"]["api"]["recoveryKey"]["expirationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] is None
-
-    # Try to use token
-    response = client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "new_test_token",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-    assert (
-        response.json()["data"]["api"]["useRecoveryApiKey"]["token"]
-        == read_json(tokens_file)["tokens"][2]["token"]
-    )
-    assert read_json(tokens_file)["tokens"][2]["name"] == "new_test_token"
-
-    # Try to use token again
-    response = client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "new_test_token2",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-    assert (
-        response.json()["data"]["api"]["useRecoveryApiKey"]["token"]
-        == read_json(tokens_file)["tokens"][3]["token"]
-    )
-    assert read_json(tokens_file)["tokens"][3]["name"] == "new_test_token2"
+def test_graphql_generate_recovery_key(client, authorized_client):
+    key = graphql_make_new_recovery_key(authorized_client)
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is True
+    assert_recovery_recent(status["creationDate"])
+    assert status["expirationDate"] is None
+    assert status["usesLeft"] is None
+
+    graphql_use_recovery_key(client, key, "new_test_token")
+    # And again
+    graphql_use_recovery_key(client, key, "new_test_token2")
+@pytest.mark.parametrize(
+    "expiration_date", [ten_minutes_into_future(), ten_minutes_into_future_tz()]
+)
 def test_graphql_generate_recovery_key_with_expiration_date(
-    client, authorized_client, tokens_file
+    client, authorized_client, expiration_date: datetime
 ):
-    expiration_date = datetime.datetime.now() + datetime.timedelta(minutes=5)
-    expiration_date_str = expiration_date.strftime("%Y-%m-%dT%H:%M:%S.%f")
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-            "variables": {
-                "limits": {
-                    "expirationDate": expiration_date_str,
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is not None
-    assert (
-        response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"]
-        .split(" ")
-        .__len__()
-        == 18
-    )
-    assert read_json(tokens_file)["recovery_token"] is not None
-
-    key = response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"]
-    assert read_json(tokens_file)["recovery_token"]["expiration"] == expiration_date_str
-    assert read_json(tokens_file)["recovery_token"]["token"] == mnemonic_to_hex(key)
-
-    time_generated = read_json(tokens_file)["recovery_token"]["date"]
-    assert time_generated is not None
-    assert (
-        datetime.datetime.strptime(time_generated, "%Y-%m-%dT%H:%M:%S.%f")
-        - datetime.timedelta(seconds=5)
-        < datetime.datetime.now()
-    )
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is True
-    assert response.json()["data"]["api"]["recoveryKey"][
-        "creationDate"
-    ] == time_generated.replace("Z", "")
-    assert (
-        response.json()["data"]["api"]["recoveryKey"]["expirationDate"]
-        == expiration_date_str
-    )
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] is None
-
-    # Try to use token
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "new_test_token",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-    assert (
-        response.json()["data"]["api"]["useRecoveryApiKey"]["token"]
-        == read_json(tokens_file)["tokens"][2]["token"]
-    )
-
-    # Try to use token again
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "new_test_token2",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-    assert (
-        response.json()["data"]["api"]["useRecoveryApiKey"]["token"]
-        == read_json(tokens_file)["tokens"][3]["token"]
-    )
-
-    # Try to use token after expiration date
-    new_data = read_json(tokens_file)
-    new_data["recovery_token"]["expiration"] = (
-        datetime.datetime.now() - datetime.timedelta(minutes=5)
-    ).strftime("%Y-%m-%dT%H:%M:%S.%f")
-    write_json(tokens_file, new_data)
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": key,
-                    "deviceName": "new_test_token3",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is False
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 404
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is None
-    assert read_json(tokens_file)["tokens"] == new_data["tokens"]
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is False
-    assert (
-        response.json()["data"]["api"]["recoveryKey"]["creationDate"] == time_generated
-    )
-    assert (
-        response.json()["data"]["api"]["recoveryKey"]["expirationDate"]
-        == new_data["recovery_token"]["expiration"]
-    )
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] is None
+    key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date)
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is True
+    assert_recovery_recent(status["creationDate"])
+    # timezone-aware comparison. Should pass regardless of server's tz
+    assert datetime.fromisoformat(status["expirationDate"]) == expiration_date.replace(
+        tzinfo=timezone.utc
+    )
+    assert status["usesLeft"] is None
+
+    graphql_use_recovery_key(client, key, "new_test_token")
+    # And again
+    graphql_use_recovery_key(client, key, "new_test_token2")
-def test_graphql_generate_recovery_key_with_expiration_in_the_past(
-    authorized_client, tokens_file
-):
-    expiration_date = datetime.datetime.now() - datetime.timedelta(minutes=5)
-    expiration_date_str = expiration_date.strftime("%Y-%m-%dT%H:%M:%S.%f")
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-            "variables": {
-                "limits": {
-                    "expirationDate": expiration_date_str,
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is False
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 400
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is None
-    assert "recovery_token" not in read_json(tokens_file)
+def test_graphql_use_recovery_key_after_expiration(client, authorized_client, mocker):
+    expiration_date = ten_minutes_into_future()
+    key = graphql_make_new_recovery_key(authorized_client, expires_at=expiration_date)
+
+    # Timewarp to after it expires
+    mock = mocker.patch(RECOVERY_KEY_VALIDATION_DATETIME, NearFuture)
+
+    response = request_recovery_auth(client, key, "new_test_token3")
+    output = get_data(response)["api"]["useRecoveryApiKey"]
+    assert_errorcode(output, 404)
+
+    assert output["token"] is None
+    assert_original(authorized_client)
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is False
+    assert_recovery_recent(status["creationDate"])
+    # timezone-aware comparison. Should pass regardless of server's tz
+    assert datetime.fromisoformat(status["expirationDate"]) == expiration_date.replace(
+        tzinfo=timezone.utc
+    )
+    assert status["usesLeft"] is None
-def test_graphql_generate_recovery_key_with_invalid_time_format(
-    authorized_client, tokens_file
-):
+def test_graphql_generate_recovery_key_with_expiration_in_the_past(authorized_client):
+    expiration_date = ten_minutes_into_past()
+    response = request_make_new_recovery_key(
+        authorized_client, expires_at=expiration_date
+    )
+
+    output = get_data(response)["api"]["getNewRecoveryApiKey"]
+    assert_errorcode(output, 400)
+
+    assert output["key"] is None
+    assert graphql_recovery_status(authorized_client)["exists"] is False
+
+
+def test_graphql_generate_recovery_key_with_invalid_time_format(authorized_client):
     expiration_date = "invalid_time_format"
     expiration_date_str = expiration_date
@@ -377,183 +246,56 @@ def test_graphql_generate_recovery_key_with_invalid_time_format(
             },
         },
     )
-    assert response.status_code == 200
-    assert response.json().get("data") is None
-    assert "recovery_token" not in read_json(tokens_file)
+    assert_empty(response)
+    assert graphql_recovery_status(authorized_client)["exists"] is False
-def test_graphql_generate_recovery_key_with_limited_uses(
-    authorized_client, tokens_file
-):
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-            "variables": {
-                "limits": {
-                    "expirationDate": None,
-                    "uses": 2,
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is not None
-    mnemonic_key = response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"]
-    key = mnemonic_to_hex(mnemonic_key)
-    assert read_json(tokens_file)["recovery_token"]["token"] == key
-    assert read_json(tokens_file)["recovery_token"]["uses_left"] == 2
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["creationDate"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["expirationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] == 2
-
-    # Try to use token
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": mnemonic_key,
-                    "deviceName": "test_token1",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["creationDate"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["expirationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] == 1
-
-    # Try to use token
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": mnemonic_key,
-                    "deviceName": "test_token2",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is True
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 200
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is not None
-
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={"query": generate_api_query([API_RECOVERY_QUERY])},
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["recoveryKey"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["exists"] is True
-    assert response.json()["data"]["api"]["recoveryKey"]["valid"] is False
-    assert response.json()["data"]["api"]["recoveryKey"]["creationDate"] is not None
-    assert response.json()["data"]["api"]["recoveryKey"]["expirationDate"] is None
-    assert response.json()["data"]["api"]["recoveryKey"]["usesLeft"] == 0
-
-    # Try to use token
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_USE_MUTATION,
-            "variables": {
-                "input": {
-                    "key": mnemonic_key,
-                    "deviceName": "test_token3",
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["success"] is False
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["code"] == 404
-    assert response.json()["data"]["api"]["useRecoveryApiKey"]["token"] is None
+def test_graphql_generate_recovery_key_with_limited_uses(authorized_client, client):
+    mnemonic_key = graphql_make_new_recovery_key(authorized_client, uses=2)
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is True
+    assert status["creationDate"] is not None
+    assert status["expirationDate"] is None
+    assert status["usesLeft"] == 2
+
+    graphql_use_recovery_key(client, mnemonic_key, "new_test_token1")
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is True
+    assert status["creationDate"] is not None
+    assert status["expirationDate"] is None
+    assert status["usesLeft"] == 1
+
+    graphql_use_recovery_key(client, mnemonic_key, "new_test_token2")
+
+    status = graphql_recovery_status(authorized_client)
+    assert status["exists"] is True
+    assert status["valid"] is False
+    assert status["creationDate"] is not None
+    assert status["expirationDate"] is None
+    assert status["usesLeft"] == 0
+
+    response = request_recovery_auth(client, mnemonic_key, "new_test_token3")
+    output = get_data(response)["api"]["useRecoveryApiKey"]
+    assert_errorcode(output, 404)
-def test_graphql_generate_recovery_key_with_negative_uses(
-    authorized_client, tokens_file
-):
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-            "variables": {
-                "limits": {
-                    "uses": -1,
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is False
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 400
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is None
+def test_graphql_generate_recovery_key_with_negative_uses(authorized_client):
+    response = request_make_new_recovery_key(authorized_client, uses=-1)
+
+    output = get_data(response)["api"]["getNewRecoveryApiKey"]
+    assert_errorcode(output, 400)
+    assert output["key"] is None
+    assert graphql_recovery_status(authorized_client)["exists"] is False
-def test_graphql_generate_recovery_key_with_zero_uses(authorized_client, tokens_file):
-    # Try to get token status
-    response = authorized_client.post(
-        "/graphql",
-        json={
-            "query": API_RECOVERY_KEY_GENERATE_MUTATION,
-            "variables": {
-                "limits": {
-                    "uses": 0,
-                },
-            },
-        },
-    )
-    assert response.status_code == 200
-    assert response.json().get("data") is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["success"] is False
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["message"] is not None
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["code"] == 400
-    assert response.json()["data"]["api"]["getNewRecoveryApiKey"]["key"] is None
+def test_graphql_generate_recovery_key_with_zero_uses(authorized_client):
+    response = request_make_new_recovery_key(authorized_client, uses=0)
+
+    output = get_data(response)["api"]["getNewRecoveryApiKey"]
+    assert_errorcode(output, 400)
+    assert output["key"] is None
+    assert graphql_recovery_status(authorized_client)["exists"] is False
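Both key flows are BIP-39 mnemonics: device keys are 12 words (128 bits of entropy), recovery keys 18 words (192 bits), which is what the `split(" ").__len__()` assertions count. The removed tests converted a phrase back to the hex digest stored on disk with the `mnemonic` package, and `mnemonic_to_hex` from tests.common is presumably the same one-liner the old device tests used inline — a sketch:

```python
# Sketch of tests.common.mnemonic_to_hex, assuming it wraps the same
# Mnemonic call the removed tests invoked inline.
from mnemonic import Mnemonic


def mnemonic_to_hex(phrase: str) -> str:
    # 12 words decode to 16 bytes of entropy, 18 words to 24 bytes.
    return Mnemonic(language="english").to_entropy(phrase).hex()
```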

View file

@@ -1,218 +0,0 @@
# pylint: disable=redefined-outer-name
# pylint: disable=unused-argument
# pylint: disable=missing-function-docstring
"""
tests that restrict json token repository implementation
"""
import pytest
from datetime import datetime
from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.exceptions import (
TokenNotFound,
RecoveryKeyNotFound,
NewDeviceKeyNotFound,
)
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
JsonTokensRepository,
)
from tests.common import read_json
from test_tokens_repository import (
mock_recovery_key_generate,
mock_generate_token,
mock_new_device_key_generate,
empty_keys,
)
ORIGINAL_TOKEN_CONTENT = [
{
"token": "KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
"name": "primary_token",
"date": "2022-07-15 17:41:31.675698",
},
{
"token": "3JKgLOtFu6ZHgE4OU-R-VdW47IKpg-YQL0c6n7bol68",
"name": "second_token",
"date": "2022-07-15 17:41:31.675698Z",
},
{
"token": "LYiwFDekvALKTQSjk7vtMQuNP_6wqKuV-9AyMKytI_8",
"name": "third_token",
"date": "2022-07-15T17:41:31.675698Z",
},
{
"token": "dD3CFPcEZvapscgzWb7JZTLog7OMkP7NzJeu2fAazXM",
"name": "forth_token",
"date": "2022-07-15T17:41:31.675698",
},
]
@pytest.fixture
def tokens(mocker, datadir):
mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "tokens.json")
assert read_json(datadir / "tokens.json")["tokens"] == ORIGINAL_TOKEN_CONTENT
return datadir
@pytest.fixture
def null_keys(mocker, datadir):
mocker.patch("selfprivacy_api.utils.TOKENS_FILE", new=datadir / "null_keys.json")
assert read_json(datadir / "null_keys.json")["recovery_token"] is None
assert read_json(datadir / "null_keys.json")["new_device"] is None
return datadir
def test_delete_token(tokens):
repo = JsonTokensRepository()
input_token = Token(
token="KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
device_name="primary_token",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
)
repo.delete_token(input_token)
assert read_json(tokens / "tokens.json")["tokens"] == [
{
"token": "3JKgLOtFu6ZHgE4OU-R-VdW47IKpg-YQL0c6n7bol68",
"name": "second_token",
"date": "2022-07-15 17:41:31.675698Z",
},
{
"token": "LYiwFDekvALKTQSjk7vtMQuNP_6wqKuV-9AyMKytI_8",
"name": "third_token",
"date": "2022-07-15T17:41:31.675698Z",
},
{
"token": "dD3CFPcEZvapscgzWb7JZTLog7OMkP7NzJeu2fAazXM",
"name": "forth_token",
"date": "2022-07-15T17:41:31.675698",
},
]
def test_delete_not_found_token(tokens):
repo = JsonTokensRepository()
input_token = Token(
token="imbadtoken",
device_name="primary_token",
created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
)
with pytest.raises(TokenNotFound):
assert repo.delete_token(input_token) is None
assert read_json(tokens / "tokens.json")["tokens"] == ORIGINAL_TOKEN_CONTENT
def test_create_recovery_key(tokens, mock_recovery_key_generate):
repo = JsonTokensRepository()
assert repo.create_recovery_key(uses_left=1, expiration=None) is not None
assert read_json(tokens / "tokens.json")["recovery_token"] == {
"token": "889bf49c1d3199d71a2e704718772bd53a422020334db051",
"date": "2022-07-15T17:41:31.675698",
"expiration": None,
"uses_left": 1,
}
def test_use_mnemonic_recovery_key_when_null(null_keys):
repo = JsonTokensRepository()
with pytest.raises(RecoveryKeyNotFound):
assert (
repo.use_mnemonic_recovery_key(
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
device_name="primary_token",
)
is None
)
def test_use_mnemonic_recovery_key(tokens, mock_generate_token):
repo = JsonTokensRepository()
assert repo.use_mnemonic_recovery_key(
mnemonic_phrase="uniform clarify napkin bid dress search input armor police cross salon because myself uphold slice bamboo hungry park",
device_name="newdevice",
) == Token(
token="ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4",
device_name="newdevice",
created_at=datetime(2022, 11, 14, 6, 6, 32, 777123),
)
assert read_json(tokens / "tokens.json")["tokens"] == [
{
"date": "2022-07-15 17:41:31.675698",
"name": "primary_token",
"token": "KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
},
{
"token": "3JKgLOtFu6ZHgE4OU-R-VdW47IKpg-YQL0c6n7bol68",
"name": "second_token",
"date": "2022-07-15 17:41:31.675698Z",
},
{
"token": "LYiwFDekvALKTQSjk7vtMQuNP_6wqKuV-9AyMKytI_8",
"name": "third_token",
"date": "2022-07-15T17:41:31.675698Z",
},
{
"token": "dD3CFPcEZvapscgzWb7JZTLog7OMkP7NzJeu2fAazXM",
"name": "forth_token",
"date": "2022-07-15T17:41:31.675698",
},
{
"date": "2022-11-14T06:06:32.777123",
"name": "newdevice",
"token": "ur71mC4aiI6FIYAN--cTL-38rPHS5D6NuB1bgN_qKF4",
},
]
assert read_json(tokens / "tokens.json")["recovery_token"] == {
"date": "2022-11-11T11:48:54.228038",
"expiration": None,
"token": "ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
"uses_left": 1,
}
def test_get_new_device_key(tokens, mock_new_device_key_generate):
repo = JsonTokensRepository()
assert repo.get_new_device_key() is not None
assert read_json(tokens / "tokens.json")["new_device"] == {
"date": "2022-07-15T17:41:31.675698",
"expiration": "2022-07-15T17:41:31.675698",
"token": "43478d05b35e4781598acd76e33832bb",
}
def test_delete_new_device_key(tokens):
repo = JsonTokensRepository()
assert repo.delete_new_device_key() is None
assert "new_device" not in read_json(tokens / "tokens.json")
def test_delete_new_device_key_when_empty(empty_keys):
repo = JsonTokensRepository()
repo.delete_new_device_key()
assert "new_device" not in read_json(empty_keys / "empty_keys.json")
def test_use_mnemonic_new_device_key_when_null(null_keys):
repo = JsonTokensRepository()
with pytest.raises(NewDeviceKeyNotFound):
assert (
repo.use_mnemonic_new_device_key(
device_name="imnew",
mnemonic_phrase="captain ribbon toddler settle symbol minute step broccoli bless universe divide bulb",
)
is None
)
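The deleted file above pinned the JSON-backed token repository down to exact file contents; this merge removes it along with the datadir fixtures below, leaving the implementation-agnostic checks to the shared test_tokens_repository suite it imported its mocks from. For reference, the repository API it exercised, reconstructed from the deleted tests as a usage sketch:

```python
# Usage sketch of the repository API exercised by the deleted tests.
from datetime import datetime

from selfprivacy_api.models.tokens.token import Token
from selfprivacy_api.repositories.tokens.json_tokens_repository import (
    JsonTokensRepository,
)

repo = JsonTokensRepository()

# Remove a device token; raises TokenNotFound for an unknown token.
repo.delete_token(
    Token(
        token="KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
        device_name="primary_token",
        created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
    )
)

# Mint a recovery key and a new-device key.
repo.create_recovery_key(uses_left=1, expiration=None)
repo.get_new_device_key()

# Spending a mnemonic while the matching section is null raises
# RecoveryKeyNotFound / NewDeviceKeyNotFound, as the null_keys fixture shows.
```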

View file

@@ -1,9 +0,0 @@
{
"tokens": [
{
"token": "KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
"name": "primary_token",
"date": "2022-07-15 17:41:31.675698"
}
]
}

View file

@@ -1,26 +0,0 @@
{
"tokens": [
{
"token": "KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
"name": "primary_token",
"date": "2022-07-15 17:41:31.675698"
},
{
"token": "3JKgLOtFu6ZHgE4OU-R-VdW47IKpg-YQL0c6n7bol68",
"name": "second_token",
"date": "2022-07-15 17:41:31.675698Z"
},
{
"token": "LYiwFDekvALKTQSjk7vtMQuNP_6wqKuV-9AyMKytI_8",
"name": "third_token",
"date": "2022-07-15T17:41:31.675698Z"
},
{
"token": "dD3CFPcEZvapscgzWb7JZTLog7OMkP7NzJeu2fAazXM",
"name": "forth_token",
"date": "2022-07-15T17:41:31.675698"
}
],
"recovery_token": null,
"new_device": null
}

View file

@@ -1,35 +0,0 @@
{
"tokens": [
{
"token": "KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
"name": "primary_token",
"date": "2022-07-15 17:41:31.675698"
},
{
"token": "3JKgLOtFu6ZHgE4OU-R-VdW47IKpg-YQL0c6n7bol68",
"name": "second_token",
"date": "2022-07-15 17:41:31.675698Z"
},
{
"token": "LYiwFDekvALKTQSjk7vtMQuNP_6wqKuV-9AyMKytI_8",
"name": "third_token",
"date": "2022-07-15T17:41:31.675698Z"
},
{
"token": "dD3CFPcEZvapscgzWb7JZTLog7OMkP7NzJeu2fAazXM",
"name": "forth_token",
"date": "2022-07-15T17:41:31.675698"
}
],
"recovery_token": {
"token": "ed653e4b8b042b841d285fa7a682fa09e925ddb2d8906f54",
"date": "2022-11-11T11:48:54.228038",
"expiration": null,
"uses_left": 2
},
"new_device": {
"token": "2237238de23dc71ab558e317bdb8ff8e",
"date": "2022-10-26 20:50:47.973212",
"expiration": "2022-10-26 21:00:47.974153"
}
}
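Taken together, the deleted fixtures document the on-disk schema the JSON repository used: a `tokens` list of `{token, name, date}` records, plus `recovery_token` and `new_device` sections that stay null until the corresponding key is generated (note the deliberately mixed date formats, with and without the `T` separator and trailing `Z`, which the parser had to tolerate). A minimal reader for that layout, as a sketch using the same read_json helper the tests used:

```python
# Sketch: inspecting the legacy tokens.json layout shown above.
from tests.common import read_json

data = read_json("tokens.json")
device_names = [t["name"] for t in data["tokens"]]

# Both sections may be null when no such key has been generated yet.
recovery = data.get("recovery_token")  # {token, date, expiration, uses_left} or None
new_device = data.get("new_device")  # {token, date, expiration} or None
```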

Some files were not shown because too many files have changed in this diff.