48 lines
1.2 KiB
Python
48 lines
1.2 KiB
Python
from loguru import logger
|
|
from typing import Type, TypeVar
|
|
from asyncpg.connection import Connection
|
|
|
|
|
|
class BaseRepo:
    """Common base for repositories: holds the connection it was given."""

    def __init__(self, conn: Connection):
        # Stored as a plain attribute; the repository classes in this file
        # issue their queries through ``self.conn``.
        self.conn = conn
|
|
|
|
|
|
# Type variable bound to BaseRepo: lets Repos.get_repo() declare that it
# returns an instance of the same repository class it was passed.
T = TypeVar("T", bound=BaseRepo)
|
|
|
|
|
|
class Repos(BaseRepo):
    """Factory that builds other repositories on this object's connection."""

    def get_repo(self, Repo: Type[T]) -> T:
        """Instantiate ``Repo`` with the connection held by this object.

        :param Repo: repository class to construct; it is called with
            ``self.conn`` as its only argument.
        :return: the newly created ``Repo`` instance.
        """
        repo_instance = Repo(self.conn)
        return repo_instance
|
|
|
|
|
|
class UserRepo(BaseRepo):
    """Queries against the ``users`` table."""

    async def add(self, user_id: int, chat_id: int):
        """Insert a ``(user_id, chat_id)`` row into ``users``.

        The statement uses ``ON CONFLICT (user_id) DO NOTHING``, so an
        existing row with the same ``user_id`` is left untouched.

        :param user_id: value bound to ``$1`` (the ``user_id`` column).
        :param chat_id: value bound to ``$2`` (the ``chat_id`` column).
        """
        query = (
            "INSERT INTO users (user_id, chat_id) "
            "VALUES ($1, $2) ON CONFLICT (user_id) DO NOTHING"
        )
        await self.conn.execute(query, user_id, chat_id)

    async def get(self, user_id: int):
        """Look up ``user_id`` in ``users``.

        :param user_id: value bound to ``$1``.
        :return: whatever ``conn.fetchval`` yields for the query; with
            asyncpg that is the stored ``user_id``, or ``None`` when no row
            matches.
        """
        query = "SELECT user_id FROM users WHERE user_id = $1"
        return await self.conn.fetchval(query, user_id)
|
|
|
|
|
|
class ProjectRepo(BaseRepo):
    """Queries against the ``projects`` table."""

    async def add(self, *args, user_id: int):
        """Insert one row into ``projects``.

        :param args: positional values bound to ``$2``..``$8``, i.e. the
            ``name, description, contacts, contributors, status, tag,
            category`` columns in that order. The count is not checked here;
            the statement has eight placeholders in total.
        :param user_id: keyword-only; bound to ``$1`` (the ``creator``
            column).
        """
        query = (
            "INSERT INTO projects "
            "(creator, name, description, contacts, "
            "contributors, status, tag, category) "
            "VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
        )
        # The creator goes first so it lines up with $1; the rest follow.
        params = (user_id, *args)
        await self.conn.execute(query, *params)

    async def get(self, user_id: int):
        """Query ``projects`` for rows whose ``creator`` equals ``user_id``.

        :param user_id: value bound to ``$1``.
        :return: whatever ``conn.fetchval`` yields for the query.

        NOTE(review): asyncpg's ``fetchval`` returns a single value (first
        column of the first row by default), so only ``name`` of one project
        comes back even though ``description`` is also selected. Confirm
        whether ``fetchrow``/``fetch`` was intended before changing it, since
        callers may rely on the current return shape.
        """
        query = "SELECT name, description FROM projects WHERE creator = $1"
        return await self.conn.fetchval(query, user_id)
|