This commit is contained in:
Inex Code 2022-01-17 13:28:17 +02:00
parent fe86382819
commit ade7c77754
6 changed files with 90 additions and 64 deletions

View file

@ -45,7 +45,7 @@ class CreateTokensJson(Migration):
"tokens": [ "tokens": [
{ {
"token": token, "token": token,
"name": "Master Token", "name": "primary_token",
"date": str(datetime.now()), "date": str(datetime.now()),
} }
] ]

View file

@ -8,6 +8,8 @@ from selfprivacy_api.utils.auth import (
delete_token, delete_token,
get_tokens_info, get_tokens_info,
delete_token, delete_token,
is_token_name_exists,
is_token_name_pair_valid,
refresh_token, refresh_token,
is_token_valid, is_token_valid,
) )
@ -48,13 +50,13 @@ class Tokens(Resource):
- in: body - in: body
name: token name: token
required: true required: true
description: Token to delete description: Token's name to delete
schema: schema:
type: object type: object
properties: properties:
token: token:
type: string type: string
description: Token to delete description: Token name to delete
responses: responses:
200: 200:
description: Token deleted description: Token deleted
@ -64,14 +66,14 @@ class Tokens(Resource):
description: Token not found description: Token not found
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("token", type=str, required=True, help="Token to delete") parser.add_argument("token_name", type=str, required=True, help="Token to delete")
args = parser.parse_args() args = parser.parse_args()
token = args["token"] token_name = args["token"]
if request.headers.get("Authorization") == f"Bearer {token}": if is_token_name_pair_valid(token_name, request.headers.get("Authorization").split(" ")[1]):
return {"message": "Cannot delete caller's token"}, 400 return {"message": "Cannot delete caller's token"}, 400
if not is_token_valid(token): if not is_token_name_exists(token_name):
return {"message": "Token not found"}, 404 return {"message": "Token not found"}, 404
delete_token(token) delete_token(token_name)
return {"message": "Token deleted"}, 200 return {"message": "Token deleted"}, 200
def post(self): def post(self):

View file

@ -113,17 +113,20 @@ class RecoveryToken(Resource):
""" """
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument( parser.add_argument(
"expiration", type=str, required=True, help="Token expiration date" "expiration", type=str, required=False, help="Token expiration date"
) )
parser.add_argument("uses", type=int, required=True, help="Token uses") parser.add_argument("uses", type=int, required=False, help="Token uses")
args = parser.parse_args() args = parser.parse_args()
# Convert expiration date to datetime and return 400 if it is not valid # Convert expiration date to datetime and return 400 if it is not valid
if args["expiration"]:
try: try:
expiration = datetime.strptime(args["expiration"], "%Y-%m-%dT%H:%M:%S.%fZ") expiration = datetime.strptime(args["expiration"], "%Y-%m-%dT%H:%M:%S.%fZ")
except ValueError: except ValueError:
return { return {
"error": "Invalid expiration date. Use YYYY-MM-DDTHH:MM:SS.SSSZ" "error": "Invalid expiration date. Use YYYY-MM-DDTHH:MM:SS.SSSZ"
}, 400 }, 400
else:
expiration = None
# Generate recovery token # Generate recovery token
token = generate_recovery_token(expiration, args["uses"]) token = generate_recovery_token(expiration, args["uses"])
return {"token": token} return {"token": token}

View file

@ -70,6 +70,18 @@ def is_token_valid(token):
return True return True
return False return False
def is_token_name_exists(token_name):
"""Check if token name exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
return token_name in [t["name"] for t in tokens["tokens"]]
def is_token_name_pair_valid(token_name, token):
"""Check if token name and token pair exists"""
with ReadUserData(UserDataFiles.TOKENS) as tokens:
for t in tokens["tokens"]:
if t["name"] == token_name and t["token"] == token:
return True
return False
def get_tokens_info(): def get_tokens_info():
"""Get all tokens info without tokens themselves""" """Get all tokens info without tokens themselves"""
@ -90,21 +102,22 @@ def _generate_token():
def create_token(name): def create_token(name):
"""Create new token""" """Create new token"""
token = _generate_token() token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append( tokens["tokens"].append(
{ {
"token": token, "token": token,
"name": _validate_token_name(name), "name": name,
"date": str(datetime.now()), "date": str(datetime.now()),
} }
) )
return token return token
def delete_token(token): def delete_token(token_name):
"""Delete token""" """Delete token"""
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"] = [t for t in tokens["tokens"] if t["token"] != token] tokens["tokens"] = [t for t in tokens["tokens"] if t["name"] != token_name]
def refresh_token(token): def refresh_token(token):
@ -198,6 +211,8 @@ def use_mnemonic_recoverery_token(mnemonic_phrase, name):
Substract 1 from uses_left if it exists. Substract 1 from uses_left if it exists.
mnemonic_phrase is a string representation of the mnemonic word list. mnemonic_phrase is a string representation of the mnemonic word list.
""" """
if not is_recovery_token_valid():
return None
recovery_token_str = _get_recovery_token() recovery_token_str = _get_recovery_token()
if recovery_token_str is None: if recovery_token_str is None:
return None return None
@ -208,11 +223,12 @@ def use_mnemonic_recoverery_token(mnemonic_phrase, name):
if phrase_bytes != recovery_token: if phrase_bytes != recovery_token:
return None return None
token = _generate_token() token = _generate_token()
name = _validate_token_name(name)
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["tokens"].append( tokens["tokens"].append(
{ {
"token": token, "token": token,
"name": _validate_token_name(name), "name": name,
"date": str(datetime.now()), "date": str(datetime.now()),
} }
) )
@ -226,7 +242,7 @@ def get_new_device_auth_token():
"""Generate a new device auth token which is valid for 10 minutes and return a mnemonic phrase representation """Generate a new device auth token which is valid for 10 minutes and return a mnemonic phrase representation
Write token to the new_device of the tokens.json file. Write token to the new_device of the tokens.json file.
""" """
token = secrets.token_bytes(24) token = secrets.token_bytes(16)
token_str = token.hex() token_str = token.hex()
with WriteUserData(UserDataFiles.TOKENS) as tokens: with WriteUserData(UserDataFiles.TOKENS) as tokens:
tokens["new_device"] = { tokens["new_device"] = {

View file

@ -14,3 +14,8 @@ def test_get_api_version_unauthorized(client):
response = client.get("/api/version") response = client.get("/api/version")
assert response.status_code == 200 assert response.status_code == 200
assert "version" in response.get_json() assert "version" in response.get_json()
def test_get_swagger_json(authorized_client):
response = authorized_client.get("/api/swagger.json")
assert response.status_code == 200
assert "swagger" in response.get_json()