Implement using multiple repos
This commit is contained in:
parent
f817c79b01
commit
781d3c04ef
|
@ -1,20 +1,24 @@
|
|||
from loguru import logger
|
||||
|
||||
from aiogram import types, Dispatcher
|
||||
from services.repository import Repo
|
||||
from keyboards.default import default_kb
|
||||
from services.repositories import Repos, UserRepo
|
||||
|
||||
|
||||
async def start(message: types.Message, repo: Repo):
|
||||
user = await repo.get_user(message.from_user.id)
|
||||
async def start(message: types.Message, repo: Repos):
|
||||
user = await repo.get_repo(UserRepo).get(message.from_user.id)
|
||||
if user is not None:
|
||||
await message.answer(
|
||||
"Добро пожаловать в Бота Теплицы Социальных Технологий "
|
||||
"для организации антивоенных проектов."
|
||||
"для организации антивоенных проектов.",
|
||||
reply_markup=default_kb
|
||||
)
|
||||
|
||||
|
||||
async def request_invite(message: types.Message, repo: Repo):
|
||||
await repo.add_user(message.from_user.id, message.chat.id)
|
||||
async def request_invite(message: types.Message, repo: Repos):
|
||||
await repo.get_repo(UserRepo).add(
|
||||
message.from_user.id, message.chat.id
|
||||
)
|
||||
logger.debug(f"Added new user {message.from_user.full_name}")
|
||||
|
||||
bot_info = await message.bot.get_me()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||
from asyncpg.pool import Pool
|
||||
|
||||
from services.repository import Repo
|
||||
from services.repositories import Repos
|
||||
|
||||
|
||||
class DbMiddleware(LifetimeControllerMiddleware):
|
||||
|
@ -14,7 +14,7 @@ class DbMiddleware(LifetimeControllerMiddleware):
|
|||
async def pre_process(self, obj, data, *args):
|
||||
conn = await self.pool.acquire()
|
||||
data["conn"] = conn
|
||||
data["repo"] = Repo(conn)
|
||||
data["repo"] = Repos(conn)
|
||||
|
||||
async def post_process(self, obj, data, *args):
|
||||
del data["repo"]
|
||||
|
|
46
services/repositories.py
Normal file
46
services/repositories.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from loguru import logger
|
||||
from typing import Type, TypeVar
|
||||
from asyncpg.connection import Connection
|
||||
|
||||
|
||||
class BaseRepo:
|
||||
def __init__(self, conn: Connection):
|
||||
self.conn = conn
|
||||
|
||||
|
||||
T = TypeVar('T', bound=BaseRepo)
|
||||
|
||||
|
||||
class Repos(BaseRepo):
|
||||
def get_repo(self, Repo: Type[T]) -> T:
|
||||
return Repo(self.conn)
|
||||
|
||||
|
||||
class UserRepo(BaseRepo):
|
||||
async def add(self, user_id: int, chat_id: int):
|
||||
await self.conn.execute(
|
||||
"INSERT INTO users (user_id, chat_id) "
|
||||
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
|
||||
user_id, chat_id
|
||||
)
|
||||
|
||||
async def get(self, user_id: int):
|
||||
return await self.conn.fetchval(
|
||||
"SELECT user_id FROM users WHERE user_id = $1", user_id
|
||||
)
|
||||
|
||||
|
||||
class ProjectRepo(BaseRepo):
|
||||
async def add(self, name: str, description: str, *args):
|
||||
await self.conn.execute(
|
||||
"INSERT INTO projects "
|
||||
"(name, description, contributors, status) "
|
||||
"VALUES ($1, $2, $3, $4)",
|
||||
name, description, *args
|
||||
)
|
||||
|
||||
async def get(self, user_id: int):
|
||||
return await self.conn.fetchval(
|
||||
"SELECT name, description FROM projects WHERE creator = $1",
|
||||
user_id
|
||||
)
|
|
@ -1,20 +0,0 @@
|
|||
from asyncpg.connection import Connection
|
||||
|
||||
|
||||
class Repo:
|
||||
"""Db abstraction layer"""
|
||||
|
||||
def __init__(self, conn: Connection):
|
||||
self.conn = conn
|
||||
|
||||
async def add_user(self, user_id: int, chat_id: int):
|
||||
await self.conn.execute(
|
||||
"INSERT INTO users (user_id, chat_id) "
|
||||
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
|
||||
user_id, chat_id
|
||||
)
|
||||
|
||||
async def get_user(self, user_id: int):
|
||||
return await self.conn.fetchval(
|
||||
"SELECT user_id FROM users WHERE user_id = $1", user_id
|
||||
)
|
Reference in a new issue