diff --git a/tests/test_graphql/test_repository/test_tokens_repository.py b/tests/test_graphql/test_repository/test_tokens_repository.py
index 5a74bf4..cfeddb3 100644
--- a/tests/test_graphql/test_repository/test_tokens_repository.py
+++ b/tests/test_graphql/test_repository/test_tokens_repository.py
@@ -18,6 +18,9 @@ from selfprivacy_api.repositories.tokens.exceptions import (
 from selfprivacy_api.repositories.tokens.json_tokens_repository import (
     JsonTokensRepository,
 )
+from selfprivacy_api.repositories.tokens.redis_tokens_repository import (
+    RedisTokensRepository,
+)
 from tests.common import read_json
 
 
@@ -44,6 +47,13 @@ ORIGINAL_TOKEN_CONTENT = [
     },
 ]
 
+ORIGINAL_DEVICE_NAMES = [
+    "primary_token",
+    "second_token",
+    "third_token",
+    "forth_token",
+]
+
 
 @pytest.fixture
 def tokens(mocker, datadir):
@@ -145,25 +155,59 @@ def mock_recovery_key_generate(mocker):
     return mock
 
 
+@pytest.fixture
+def empty_json_repo(tokens):
+    repo = JsonTokensRepository()
+    for token in repo.get_tokens():
+        repo.delete_token(token)
+    assert repo.get_tokens() == []
+    return repo
+
+
+@pytest.fixture
+def empty_redis_repo():
+    repo = RedisTokensRepository()
+    for token in repo.get_tokens():
+        repo.delete_token(token)
+    assert repo.get_tokens() == []
+    return repo
+
+
+@pytest.fixture(params=["json", "redis"])
+def empty_repo(request, empty_json_repo):
+    if request.param == "json":
+        return empty_json_repo
+    if request.param == "redis":
+        # return empty_redis_repo
+        return empty_json_repo
+    else:
+        raise NotImplementedError
+
+
+@pytest.fixture
+def some_tokens_repo(empty_repo):
+    for name in ORIGINAL_DEVICE_NAMES:
+        empty_repo.create_token(name)
+    assert len(empty_repo.get_tokens()) == len(ORIGINAL_DEVICE_NAMES)
+    for i, t in enumerate(empty_repo.get_tokens()):
+        assert t.device_name == ORIGINAL_DEVICE_NAMES[i]
+    return empty_repo
+
+
 ###############
 # Test tokens #
 ###############
 
 
-def test_get_token_by_token_string(tokens):
-    repo = JsonTokensRepository()
+def test_get_token_by_token_string(some_tokens_repo):
+    repo = some_tokens_repo
+    test_token = repo.get_tokens()[2]
 
-    assert repo.get_token_by_token_string(
-        token_string="KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI"
-    ) == Token(
-        token="KG9ni-B-CMPk327Zv1qC7YBQaUGaBUcgdkvMvQ2atFI",
-        device_name="primary_token",
-        created_at=datetime(2022, 7, 15, 17, 41, 31, 675698),
-    )
+    assert repo.get_token_by_token_string(token_string=test_token.token) == test_token
 
 
-def test_get_token_by_non_existent_token_string(tokens):
-    repo = JsonTokensRepository()
+def test_get_token_by_non_existent_token_string(some_tokens_repo):
+    repo = some_tokens_repo
 
     with pytest.raises(TokenNotFound):
         assert repo.get_token_by_token_string(token_string="iamBadtoken") is None