80 lines
2.7 KiB
Python
Raw Normal View History

from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncIterator, Self

from alembic.config import Config as AlembicConfig
from alembic.operations import Operations
from alembic.runtime.migration import MigrationContext
from alembic.script.base import ScriptDirectory
from asyncpg import Connection
from pydantic import BaseModel, PostgresDsn
from sqlalchemy.engine import Connection as SyncConnection
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

from materia_server.models.base import Base
# Public API of this module: only the Database wrapper is exported by star-imports.
__all__ = ["Database"]
class Database:
    """Async PostgreSQL access wrapper: owns an async engine and a session factory bound to it."""

    def __init__(
        self,
        url: PostgresDsn,
        engine: AsyncEngine,
        sessionmaker: async_sessionmaker[AsyncSession],
    ):
        # The URL is kept so run_migrations can hand it to the Alembic config.
        self.url: PostgresDsn = url
        self.engine: AsyncEngine = engine
        self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker

    @classmethod
    def new(
        cls,
        url: PostgresDsn,
        pool_size: int = 100,
        autocommit: bool = False,
        autoflush: bool = False,
        expire_on_commit: bool = False,
    ) -> Self:
        """Build an instance with a fresh async engine and session factory.

        Args:
            url: Database DSN; converted to ``str`` for ``create_async_engine``.
            pool_size: Size of the engine's connection pool.
            autocommit: Forwarded to the session factory.
                NOTE(review): SQLAlchemy 2.x sessions reject ``autocommit=True``;
                confirm against the installed version before passing it.
            autoflush: Forwarded to the session factory.
            expire_on_commit: Forwarded to the session factory.

        Returns:
            A new instance of ``cls`` (so subclasses get their own type back).
        """
        engine = create_async_engine(str(url), pool_size=pool_size)
        sessionmaker = async_sessionmaker(
            bind=engine,
            autocommit=autocommit,
            autoflush=autoflush,
            expire_on_commit=expire_on_commit,
        )
        return cls(url=url, engine=engine, sessionmaker=sessionmaker)

    async def dispose(self) -> None:
        """Dispose of the engine and its connection pool."""
        await self.engine.dispose()

    @asynccontextmanager
    async def connection(self) -> AsyncIterator[AsyncConnection]:
        """Yield a connection wrapped in a transaction.

        ``engine.begin()`` commits when the block exits cleanly and rolls back
        if it raises, so no manual rollback is needed here; exceptions propagate.
        """
        async with self.engine.begin() as connection:
            yield connection

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
        """Yield a new session from the factory; it is always closed afterwards.

        Nothing is committed automatically: callers commit explicitly. If the
        body raises an ``Exception``, the session is rolled back before the
        exception is re-raised.
        """
        # AsyncSession's own context manager closes the session on exit.
        async with self.sessionmaker() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise

    def run_migrations(self, connection: SyncConnection) -> None:
        """Upgrade the database schema to the Alembic ``head`` revision.

        This is synchronous and takes a *synchronous* SQLAlchemy connection,
        which is what ``MigrationContext.configure`` requires.
        NOTE(review): presumably invoked as
        ``await async_conn.run_sync(db.run_migrations)``; confirm against callers.

        Args:
            connection: Sync SQLAlchemy connection to run the migrations on.
        """
        # alembic.ini lives two directories above the one containing this file.
        config = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
        # ConfigParser only accepts str values and treats '%' as interpolation
        # syntax, so stringify the DSN and escape percent signs (e.g. from
        # URL-encoded passwords).
        config.set_main_option("sqlalchemy.url", str(self.url).replace("%", "%%"))
        script = ScriptDirectory.from_config(config)
        context = MigrationContext.configure(
            connection=connection,
            opts={
                "target_metadata": Base.metadata,
                # Plan the steps from the current revision up to "head"; this is
                # the same private helper alembic.command.upgrade relies on.
                "fn": lambda rev, _context: script._upgrade_revs("head", rev),
            },
        )
        with context.begin_transaction():
            with Operations.context(context):
                context.run_migrations()