Add database middleware
This commit is contained in:
parent
6daa0a5e9b
commit
cdae6b0c8f
11
bot.py
11
bot.py
|
@ -1,23 +1,32 @@
|
||||||
#!/usr/bin/python
|
#!/usr/bin/python
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from aiogram import Bot, Dispatcher
|
from aiogram import Bot, Dispatcher
|
||||||
|
import asyncpg
|
||||||
|
|
||||||
|
from middlewares.db import DbMiddleware
|
||||||
from handlers.commands import register_command_handlers
|
from handlers.commands import register_command_handlers
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
"""Configures and runs the bot, connects to the db"""
|
"""Configures and runs the bot, connects to the db"""
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
DB_URL = os.getenv("DB_URL")
|
||||||
API_TOKEN = os.getenv("API_TOKEN")
|
API_TOKEN = os.getenv("API_TOKEN")
|
||||||
|
|
||||||
|
pool = await asyncpg.create_pool(DB_URL)
|
||||||
|
|
||||||
bot = Bot(API_TOKEN)
|
bot = Bot(API_TOKEN)
|
||||||
dp = Dispatcher(bot)
|
dp = Dispatcher(bot)
|
||||||
|
|
||||||
# Register handlers
|
# Register handlers
|
||||||
register_command_handlers(dp)
|
register_command_handlers(dp)
|
||||||
|
|
||||||
|
# Register middlewares
|
||||||
|
dp.middleware.setup(DbMiddleware(pool))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await dp.start_polling()
|
await dp.start_polling()
|
||||||
finally:
|
finally:
|
||||||
|
|
23
middlewares/db.py
Normal file
23
middlewares/db.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
from aiogram.dispatcher.middlewares import LifetimeControllerMiddleware
|
||||||
|
from asyncpg.pool import Pool
|
||||||
|
|
||||||
|
from services.repository import Repo
|
||||||
|
|
||||||
|
|
||||||
|
class DbMiddleware(LifetimeControllerMiddleware):
|
||||||
|
skip_patterns = ["error", "update"]
|
||||||
|
|
||||||
|
def __init__(self, pool: Pool):
|
||||||
|
super().__init__()
|
||||||
|
self.pool = pool
|
||||||
|
|
||||||
|
async def pre_process(self, obj, data, *args):
|
||||||
|
conn = await self.pool.acquire()
|
||||||
|
data["conn"] = conn
|
||||||
|
data["repo"] = Repo(conn)
|
||||||
|
|
||||||
|
async def post_process(self, obj, data, *args):
|
||||||
|
del data["repo"]
|
||||||
|
conn = data.get("conn")
|
||||||
|
if conn:
|
||||||
|
await self.pool.release(conn)
|
Reference in a new issue