From 781d3c04efa4ff31a4f9826caa88f5e6265a19bc Mon Sep 17 00:00:00 2001 From: LoRiot Date: Sun, 16 Oct 2022 22:50:37 +0300 Subject: [PATCH] Implement using multiple repos --- handlers/commands.py | 16 ++++++++------ middlewares/db.py | 4 ++-- services/repositories.py | 46 ++++++++++++++++++++++++++++++++++++++++ services/repository.py | 20 ----------------- 4 files changed, 58 insertions(+), 28 deletions(-) create mode 100644 services/repositories.py delete mode 100644 services/repository.py diff --git a/handlers/commands.py b/handlers/commands.py index 21dd6d2..9aa3358 100644 --- a/handlers/commands.py +++ b/handlers/commands.py @@ -1,20 +1,24 @@ from loguru import logger from aiogram import types, Dispatcher -from services.repository import Repo +from keyboards.default import default_kb +from services.repositories import Repos, UserRepo -async def start(message: types.Message, repo: Repo): - user = await repo.get_user(message.from_user.id) +async def start(message: types.Message, repo: Repos): + user = await repo.get_repo(UserRepo).get(message.from_user.id) if user is not None: await message.answer( "Добро пожаловать в Бота Теплицы Социальных Технологий " - "для организации антивоенных проектов." + "для организации антивоенных проектов.", + reply_markup=default_kb ) -async def request_invite(message: types.Message, repo: Repo): - await repo.add_user(message.from_user.id, message.chat.id) +async def request_invite(message: types.Message, repo: Repos): + await repo.get_repo(UserRepo).add( + message.from_user.id, message.chat.id + ) logger.debug(f"Added new user {message.from_user.full_name}") bot_info = await message.bot.get_me() diff --git a/middlewares/db.py b/middlewares/db.py index 354a6ad..d93fe88 100644 --- a/middlewares/db.py +++ b/middlewares/db.py @@ -1,7 +1,7 @@ from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware from asyncpg.pool import Pool -from services.repository import Repo +from services.repositories import Repos class DbMiddleware(LifetimeControllerMiddleware): @@ -14,7 +14,7 @@ class DbMiddleware(LifetimeControllerMiddleware): async def pre_process(self, obj, data, *args): conn = await self.pool.acquire() data["conn"] = conn - data["repo"] = Repo(conn) + data["repo"] = Repos(conn) async def post_process(self, obj, data, *args): del data["repo"] diff --git a/services/repositories.py b/services/repositories.py new file mode 100644 index 0000000..9f7714a --- /dev/null +++ b/services/repositories.py @@ -0,0 +1,46 @@ +from loguru import logger +from typing import Type, TypeVar +from asyncpg.connection import Connection + + +class BaseRepo: + def __init__(self, conn: Connection): + self.conn = conn + + +T = TypeVar('T', bound=BaseRepo) + + +class Repos(BaseRepo): + def get_repo(self, Repo: Type[T]) -> T: + return Repo(self.conn) + + +class UserRepo(BaseRepo): + async def add(self, user_id: int, chat_id: int): + await self.conn.execute( + "INSERT INTO users (user_id, chat_id) " + "VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING", + user_id, chat_id + ) + + async def get(self, user_id: int): + return await self.conn.fetchval( + "SELECT user_id FROM users WHERE user_id = $1", user_id + ) + + +class ProjectRepo(BaseRepo): + async def add(self, name: str, description: str, *args): + await self.conn.execute( + "INSERT INTO projects " + "(name, description, contributors, status) " + "VALUES ($1, $2, $3, $4)", + name, description, *args + ) + + async def get(self, user_id: int): + return await self.conn.fetchval( + "SELECT name, description FROM projects WHERE creator = $1", + user_id + ) diff --git a/services/repository.py b/services/repository.py deleted file mode 100644 index b98be59..0000000 --- a/services/repository.py +++ /dev/null @@ -1,20 +0,0 @@ -from asyncpg.connection import Connection - - -class Repo: - """Db abstraction layer""" - - def __init__(self, conn: Connection): - self.conn = conn - - async def add_user(self, user_id: int, chat_id: int): - await self.conn.execute( - "INSERT INTO users (user_id, chat_id) " - "VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING", - user_id, chat_id - ) - - async def get_user(self, user_id: int): - return await self.conn.fetchval( - "SELECT user_id FROM users WHERE user_id = $1", user_id - )