Implement using multiple repos

This commit is contained in:
LoRiot 2022-10-16 22:50:37 +03:00
parent f817c79b01
commit 781d3c04ef
4 changed files with 58 additions and 28 deletions

View file

@ -1,20 +1,24 @@
from loguru import logger from loguru import logger
from aiogram import types, Dispatcher from aiogram import types, Dispatcher
from services.repository import Repo from keyboards.default import default_kb
from services.repositories import Repos, UserRepo
async def start(message: types.Message, repo: Repo): async def start(message: types.Message, repo: Repos):
user = await repo.get_user(message.from_user.id) user = await repo.get_repo(UserRepo).get(message.from_user.id)
if user is not None: if user is not None:
await message.answer( await message.answer(
"Добро пожаловать в Бота Теплицы Социальных Технологий " "Добро пожаловать в Бота Теплицы Социальных Технологий "
"для организации антивоенных проектов." "для организации антивоенных проектов.",
reply_markup=default_kb
) )
async def request_invite(message: types.Message, repo: Repo): async def request_invite(message: types.Message, repo: Repos):
await repo.add_user(message.from_user.id, message.chat.id) await repo.get_repo(UserRepo).add(
message.from_user.id, message.chat.id
)
logger.debug(f"Added new user {message.from_user.full_name}") logger.debug(f"Added new user {message.from_user.full_name}")
bot_info = await message.bot.get_me() bot_info = await message.bot.get_me()

View file

@ -1,7 +1,7 @@
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
from asyncpg.pool import Pool from asyncpg.pool import Pool
from services.repository import Repo from services.repositories import Repos
class DbMiddleware(LifetimeControllerMiddleware): class DbMiddleware(LifetimeControllerMiddleware):
@ -14,7 +14,7 @@ class DbMiddleware(LifetimeControllerMiddleware):
async def pre_process(self, obj, data, *args): async def pre_process(self, obj, data, *args):
conn = await self.pool.acquire() conn = await self.pool.acquire()
data["conn"] = conn data["conn"] = conn
data["repo"] = Repo(conn) data["repo"] = Repos(conn)
async def post_process(self, obj, data, *args): async def post_process(self, obj, data, *args):
del data["repo"] del data["repo"]

46
services/repositories.py Normal file
View file

@ -0,0 +1,46 @@
from loguru import logger
from typing import Type, TypeVar
from asyncpg.connection import Connection
class BaseRepo:
def __init__(self, conn: Connection):
self.conn = conn
T = TypeVar('T', bound=BaseRepo)
class Repos(BaseRepo):
def get_repo(self, Repo: Type[T]) -> T:
return Repo(self.conn)
class UserRepo(BaseRepo):
async def add(self, user_id: int, chat_id: int):
await self.conn.execute(
"INSERT INTO users (user_id, chat_id) "
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
user_id, chat_id
)
async def get(self, user_id: int):
return await self.conn.fetchval(
"SELECT user_id FROM users WHERE user_id = $1", user_id
)
class ProjectRepo(BaseRepo):
async def add(self, name: str, description: str, *args):
await self.conn.execute(
"INSERT INTO projects "
"(name, description, contributors, status) "
"VALUES ($1, $2, $3, $4)",
name, description, *args
)
async def get(self, user_id: int):
return await self.conn.fetchval(
"SELECT name, description FROM projects WHERE creator = $1",
user_id
)

View file

@ -1,20 +0,0 @@
from asyncpg.connection import Connection
class Repo:
"""Db abstraction layer"""
def __init__(self, conn: Connection):
self.conn = conn
async def add_user(self, user_id: int, chat_id: int):
await self.conn.execute(
"INSERT INTO users (user_id, chat_id) "
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
user_id, chat_id
)
async def get_user(self, user_id: int):
return await self.conn.fetchval(
"SELECT user_id FROM users WHERE user_id = $1", user_id
)