diff --git a/bot.py b/bot.py index 7877ed6..f9a8a2b 100644 --- a/bot.py +++ b/bot.py @@ -1,23 +1,32 @@ #!/usr/bin/python import logging import os -import asyncio +import asyncio from aiogram import Bot, Dispatcher +import asyncpg + +from middlewares.db import DbMiddleware from handlers.commands import register_command_handlers async def main(): """Configures and runs the bot, connects to the db""" logging.basicConfig(level=logging.INFO) + DB_URL = os.getenv("DB_URL") API_TOKEN = os.getenv("API_TOKEN") + pool = await asyncpg.create_pool(DB_URL) + bot = Bot(API_TOKEN) dp = Dispatcher(bot) # Register handlers register_command_handlers(dp) + # Register middlewares + dp.middleware.setup(DbMiddleware(pool)) + try: await dp.start_polling() finally: diff --git a/middlewares/db.py b/middlewares/db.py new file mode 100644 index 0000000..354a6ad --- /dev/null +++ b/middlewares/db.py @@ -0,0 +1,23 @@ +from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware +from asyncpg.pool import Pool + +from services.repository import Repo + + +class DbMiddleware(LifetimeControllerMiddleware): + skip_patterns = ["error", "update"] + + def __init__(self, pool: Pool): + super().__init__() + self.pool = pool + + async def pre_process(self, obj, data, *args): + conn = await self.pool.acquire() + data["conn"] = conn + data["repo"] = Repo(conn) + + async def post_process(self, obj, data, *args): + del data["repo"] + conn = data.get("conn") + if conn: + await self.pool.release(conn)