Implement using multiple repos
This commit is contained in:
parent
f817c79b01
commit
781d3c04ef
|
@ -1,20 +1,24 @@
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from aiogram import types, Dispatcher
|
from aiogram import types, Dispatcher
|
||||||
from services.repository import Repo
|
from keyboards.default import default_kb
|
||||||
|
from services.repositories import Repos, UserRepo
|
||||||
|
|
||||||
|
|
||||||
async def start(message: types.Message, repo: Repo):
|
async def start(message: types.Message, repo: Repos):
|
||||||
user = await repo.get_user(message.from_user.id)
|
user = await repo.get_repo(UserRepo).get(message.from_user.id)
|
||||||
if user is not None:
|
if user is not None:
|
||||||
await message.answer(
|
await message.answer(
|
||||||
"Добро пожаловать в Бота Теплицы Социальных Технологий "
|
"Добро пожаловать в Бота Теплицы Социальных Технологий "
|
||||||
"для организации антивоенных проектов."
|
"для организации антивоенных проектов.",
|
||||||
|
reply_markup=default_kb
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def request_invite(message: types.Message, repo: Repo):
|
async def request_invite(message: types.Message, repo: Repos):
|
||||||
await repo.add_user(message.from_user.id, message.chat.id)
|
await repo.get_repo(UserRepo).add(
|
||||||
|
message.from_user.id, message.chat.id
|
||||||
|
)
|
||||||
logger.debug(f"Added new user {message.from_user.full_name}")
|
logger.debug(f"Added new user {message.from_user.full_name}")
|
||||||
|
|
||||||
bot_info = await message.bot.get_me()
|
bot_info = await message.bot.get_me()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||||
from asyncpg.pool import Pool
|
from asyncpg.pool import Pool
|
||||||
|
|
||||||
from services.repository import Repo
|
from services.repositories import Repos
|
||||||
|
|
||||||
|
|
||||||
class DbMiddleware(LifetimeControllerMiddleware):
|
class DbMiddleware(LifetimeControllerMiddleware):
|
||||||
|
@ -14,7 +14,7 @@ class DbMiddleware(LifetimeControllerMiddleware):
|
||||||
async def pre_process(self, obj, data, *args):
|
async def pre_process(self, obj, data, *args):
|
||||||
conn = await self.pool.acquire()
|
conn = await self.pool.acquire()
|
||||||
data["conn"] = conn
|
data["conn"] = conn
|
||||||
data["repo"] = Repo(conn)
|
data["repo"] = Repos(conn)
|
||||||
|
|
||||||
async def post_process(self, obj, data, *args):
|
async def post_process(self, obj, data, *args):
|
||||||
del data["repo"]
|
del data["repo"]
|
||||||
|
|
46
services/repositories.py
Normal file
46
services/repositories.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Type, TypeVar
|
||||||
|
from asyncpg.connection import Connection
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepo:
|
||||||
|
def __init__(self, conn: Connection):
|
||||||
|
self.conn = conn
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar('T', bound=BaseRepo)
|
||||||
|
|
||||||
|
|
||||||
|
class Repos(BaseRepo):
|
||||||
|
def get_repo(self, Repo: Type[T]) -> T:
|
||||||
|
return Repo(self.conn)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRepo(BaseRepo):
|
||||||
|
async def add(self, user_id: int, chat_id: int):
|
||||||
|
await self.conn.execute(
|
||||||
|
"INSERT INTO users (user_id, chat_id) "
|
||||||
|
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
|
||||||
|
user_id, chat_id
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get(self, user_id: int):
|
||||||
|
return await self.conn.fetchval(
|
||||||
|
"SELECT user_id FROM users WHERE user_id = $1", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectRepo(BaseRepo):
|
||||||
|
async def add(self, name: str, description: str, *args):
|
||||||
|
await self.conn.execute(
|
||||||
|
"INSERT INTO projects "
|
||||||
|
"(name, description, contributors, status) "
|
||||||
|
"VALUES ($1, $2, $3, $4)",
|
||||||
|
name, description, *args
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get(self, user_id: int):
|
||||||
|
return await self.conn.fetchval(
|
||||||
|
"SELECT name, description FROM projects WHERE creator = $1",
|
||||||
|
user_id
|
||||||
|
)
|
|
@ -1,20 +0,0 @@
|
||||||
from asyncpg.connection import Connection
|
|
||||||
|
|
||||||
|
|
||||||
class Repo:
|
|
||||||
"""Db abstraction layer"""
|
|
||||||
|
|
||||||
def __init__(self, conn: Connection):
|
|
||||||
self.conn = conn
|
|
||||||
|
|
||||||
async def add_user(self, user_id: int, chat_id: int):
|
|
||||||
await self.conn.execute(
|
|
||||||
"INSERT INTO users (user_id, chat_id) "
|
|
||||||
"VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING",
|
|
||||||
user_id, chat_id
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_user(self, user_id: int):
|
|
||||||
return await self.conn.fetchval(
|
|
||||||
"SELECT user_id FROM users WHERE user_id = $1", user_id
|
|
||||||
)
|
|
Reference in a new issue