materia-server: fix migrations, split app and cli

L-Nafaryus 2024-06-24 18:52:04 +05:00
parent 317085fc04
commit f7bac07837
Signed by: L-Nafaryus
GPG Key ID: 553C97999B363D38
8 changed files with 285 additions and 103 deletions

View File

@@ -0,0 +1,81 @@
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from os import environ
import os
from pathlib import Path
import pwd
import sys
from typing import AsyncIterator, TypedDict
import click
from pydantic import BaseModel
from pydanclick import from_pydantic
import pydantic
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from materia_server import config as _config
from materia_server.config import Config
from materia_server._logging import make_logger, uvicorn_log_config, Logger
from materia_server.models import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError
from materia_server import routers


class AppContext(TypedDict):
    config: Config
    logger: Logger
    database: Database
    cache: Cache


def make_lifespan(config: Config, logger: Logger):
    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
        try:
            logger.info("Connecting to database {}", config.database.url())
            database = await Database.new(config.database.url()) # type: ignore
            logger.info("Running migrations")
            await database.run_migrations()
            logger.info("Connecting to cache {}", config.cache.url())
            cache = await Cache.new(config.cache.url()) # type: ignore
        except DatabaseError as e:
            logger.error(f"Failed to connect postgres: {e}")
            sys.exit()
        except DatabaseMigrationError as e:
            logger.error(f"Failed to run migrations: {e}")
            sys.exit()
        except CacheError as e:
            logger.error(f"Failed to connect redis: {e}")
            sys.exit()

        yield AppContext(
            config = config,
            database = database,
            cache = cache,
            logger = logger
        )

        if database.engine is not None:
            await database.dispose()

    return lifespan


def make_application(config: Config, logger: Logger):
    app = FastAPI(
        title = "materia",
        version = "0.1.0",
        docs_url = "/api/docs",
        lifespan = make_lifespan(config, logger)
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins = [ "http://localhost", "http://localhost:5173" ],
        allow_credentials = True,
        allow_methods = ["*"],
        allow_headers = ["*"],
    )
    app.include_router(routers.api.router)

    return app
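
For reference, this factory is consumed by the CLI in the next file. A minimal sketch of serving it directly follows; it mirrors the `server start` command, and it assumes that `Config()` can be built from defaults and that `make_logger(config)` accepts the config object:

# Hypothetical standalone launcher; the real wiring is the `server start` CLI command below.
import uvicorn

from materia_server.config import Config
from materia_server._logging import make_logger, uvicorn_log_config
from materia_server.app import make_application

config = Config()             # assumption: defaults are usable outside the CLI
logger = make_logger(config)  # assumption: make_logger takes the config, as the CLI imports suggest
uvicorn.run(
    make_application(config, logger),
    host = str(config.server.address),
    port = config.server.port,
    log_config = uvicorn_log_config(config),
)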

View File

@@ -17,49 +17,9 @@ from fastapi.middleware.cors import CORSMiddleware
from materia_server import config as _config
from materia_server.config import Config
from materia_server._logging import make_logger, uvicorn_log_config, Logger
-from materia_server.models import Database, Cache
+from materia_server.models import Database, DatabaseError, Cache
from materia_server import routers
+from materia_server.app import make_application

-# TODO: add cache
-
-class AppContext(TypedDict):
-    config: Config
-    database: Database
-    cache: Cache
-    logger: Logger
-
-def create_lifespan(config: Config, logger):
-    @asynccontextmanager
-    async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
-        try:
-            logger.info("Connecting {}", config.database.url())
-            database = Database.new(config.database.url()) # type: ignore
-        except:
-            logger.error("Failed to connect postgres.")
-            sys.exit()
-
-        try:
-            logger.info("Connecting {}", config.cache.url())
-            cache = await Cache.new(config.cache.url()) # type: ignore
-        except:
-            logger.error("Failed to connect redis.")
-            sys.exit()
-
-        async with database.connection() as connection:
-            await connection.run_sync(database.run_migrations) # type: ignore
-
-        yield AppContext(
-            config = config,
-            database = database,
-            cache = cache,
-            logger = logger
-        )
-
-        if database.engine is not None:
-            await database.dispose()
-
-    return lifespan

@click.group()
def server():
@@ -128,30 +88,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
    config.application.mode = application.mode

-    app = FastAPI(
-        title = "materia",
-        version = "0.1.0",
-        docs_url = "/api/docs",
-        lifespan = create_lifespan(config, logger)
-    )
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins = [ "http://localhost", "http://localhost:5173" ],
-        allow_credentials = True,
-        allow_methods = ["*"],
-        allow_headers = ["*"],
-    )
-    app.include_router(routers.api.router)

    try:
        uvicorn.run(
-            app,
+            make_application(config, logger),
            port = config.server.port,
            host = str(config.server.address),
            # reload = config.application.mode == "development",
-            log_config = uvicorn_log_config(config)
+            log_config = uvicorn_log_config(config),
        )
    except (KeyboardInterrupt, SystemExit):
        pass

View File

@@ -1,18 +1,7 @@
-#from materia_server.models.base import Base
-#from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
-#from materia_server.models.user import User
-#from materia_server.models.repository import Repository
-#from materia_server.models.directory import Directory, DirectoryLink
-#from materia_server.models.file import File, FileLink
-#from materia_server.models.repository import *
-from materia_server.models.auth.source import LoginType, LoginSource
-from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
-from materia_server.models.database.database import Database
-from materia_server.models.database.cache import Cache
+from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
+from materia_server.models.database import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError
from materia_server.models.user import User, UserCredentials, UserInfo
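
In practice the package root now re-exports the models and the new error types, so callers (like app.py above) can import everything from materia_server.models directly, for example:

from materia_server.models import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError
from materia_server.models import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
from materia_server.models import User, UserCredentials, UserInfo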

View File

@@ -1,2 +1,2 @@
-from materia_server.models.database.database import Database
-from materia_server.models.database.cache import Cache
+from materia_server.models.database.database import DatabaseError, DatabaseMigrationError, Database
+from materia_server.models.database.cache import Cache, CacheError

View File

@@ -4,6 +4,8 @@ from pydantic import BaseModel, RedisDsn
from redis import asyncio as aioredis
from redis.asyncio.client import Pipeline

+class CacheError(Exception):
+    pass

class Cache:
    def __init__(self, url: RedisDsn, pool: aioredis.ConnectionPool):
@@ -11,16 +13,22 @@ class Cache:
        self.pool: aioredis.ConnectionPool = pool

    @staticmethod
-    async def new(url: RedisDsn, encoding: str = "utf-8", decode_responses: bool = True) -> Self:
+    async def new(
+        url: RedisDsn,
+        encoding: str = "utf-8",
+        decode_responses: bool = True,
+        test_connection: bool = True
+    ) -> Self:
        pool = aioredis.ConnectionPool.from_url(str(url), encoding = encoding, decode_responses = decode_responses)

-        try:
-            connection = pool.make_connection()
-            await connection.connect()
-        except ConnectionError as e:
-            raise e
-        else:
-            await connection.disconnect()
+        if test_connection:
+            try:
+                connection = pool.make_connection()
+                await connection.connect()
+            except ConnectionError as e:
+                raise CacheError(f"{e}")
+            else:
+                await connection.disconnect()

        return Cache(
            url = url,
@@ -32,7 +40,7 @@ class Cache:
        try:
            yield aioredis.Redis(connection_pool = self.pool)
        except Exception as e:
-            raise e
+            raise CacheError(f"{e}")

    @asynccontextmanager
    async def pipeline(self, transaction: bool = True) -> AsyncGenerator[Pipeline, Any]:
@@ -41,5 +49,5 @@ class Cache:
        try:
            yield client.pipeline(transaction = transaction)
        except Exception as e:
-            raise e
+            raise CacheError(f"{e}")
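
With connection failures now surfacing as CacheError (and the startup ping made optional via test_connection), a caller can handle Redis problems uniformly. A small sketch under those assumptions, where connect_cache is a hypothetical helper name:

from materia_server.config import Config
from materia_server.models import Cache, CacheError

async def connect_cache(config: Config) -> Cache:
    try:
        # test_connection=False skips the startup ping added in this commit
        return await Cache.new(config.cache.url(), test_connection=False)
    except CacheError as e:
        raise RuntimeError(f"Failed to connect redis: {e}") from e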

View File

@@ -16,6 +16,12 @@ from materia_server.models.base import Base
__all__ = [ "Database" ]

+class DatabaseError(Exception):
+    pass
+
+class DatabaseMigrationError(Exception):
+    pass

class Database:
    def __init__(self, url: PostgresDsn, engine: AsyncEngine, sessionmaker: async_sessionmaker[AsyncSession]):
        self.url: PostgresDsn = url
@@ -23,7 +29,14 @@ class Database:
        self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker

    @staticmethod
-    def new(url: PostgresDsn, pool_size: int = 100, autocommit: bool = False, autoflush: bool = False, expire_on_commit: bool = False) -> Self:
+    async def new(
+        url: PostgresDsn,
+        pool_size: int = 100,
+        autocommit: bool = False,
+        autoflush: bool = False,
+        expire_on_commit: bool = False,
+        test_connection: bool = True
+    ) -> Self:
        engine = create_async_engine(str(url), pool_size = pool_size)
        sessionmaker = async_sessionmaker(
            bind = engine,
@@ -32,12 +45,21 @@ class Database:
            expire_on_commit = expire_on_commit
        )

-        return Database(
+        database = Database(
            url = url,
            engine = engine,
            sessionmaker = sessionmaker
        )

+        if test_connection:
+            try:
+                async with database.connection() as connection:
+                    await connection.rollback()
+            except Exception as e:
+                raise DatabaseError(f"{e}")
+
+        return database

    async def dispose(self):
        await self.engine.dispose()
@@ -48,7 +70,7 @@ class Database:
            yield connection
        except Exception as e:
            await connection.rollback()
-            raise e
+            raise DatabaseError(f"{e}")

    @asynccontextmanager
    async def session(self) -> AsyncIterator[AsyncSession]:
@@ -58,19 +80,14 @@ class Database:
            yield session
        except Exception as e:
            await session.rollback()
-            raise e
+            raise DatabaseError(f"{e}")
        finally:
            await session.close()

-    def run_migrations(self, connection: Connection):
+    def run_sync_migrations(self, connection: Connection):
-        #aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
        aconfig = AlembicConfig()
        aconfig.set_main_option("sqlalchemy.url", str(self.url))
        aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations")))
-        print(str(Path(__file__).parent.parent.joinpath("migrations")))

        context = MigrationContext.configure(
            connection = connection, # type: ignore
@@ -80,8 +97,14 @@ class Database:
            }
        )

-        with context.begin_transaction():
-            with Operations.context(context):
-                context.run_migrations()
+        try:
+            with context.begin_transaction():
+                with Operations.context(context):
+                    context.run_migrations()
+        except Exception as e:
+            raise DatabaseMigrationError(f"{e}")
+
+    async def run_migrations(self):
+        async with self.connection() as connection:
+            await connection.run_sync(self.run_sync_migrations) # type: ignore
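
Taken together, the now-async Database.new and the run_migrations wrapper form the startup sequence used by make_lifespan in app.py. A condensed sketch, where prepare_database is a hypothetical helper name:

from materia_server.config import Config
from materia_server.models import Database

async def prepare_database(config: Config) -> Database:
    database = await Database.new(config.database.url())  # raises DatabaseError if the test connection fails
    await database.run_migrations()                        # raises DatabaseMigrationError on failure
    return database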

View File

@@ -19,9 +19,7 @@ import materia_server.models.file
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
-context.configure(
-    version_table_schema = "public"
-)
config = context.config

#config.set_main_option("sqlalchemy.url", Config().database.url())

View File

@@ -0,0 +1,140 @@
"""empty message
Revision ID: 939b37d98be0
Revises:
Create Date: 2024-06-24 15:39:38.380581
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '939b37d98be0'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind())
    op.create_table('login_source',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
    sa.Column('created', sa.Integer(), nullable=False),
    sa.Column('updated', sa.Integer(), nullable=False),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('user',
    sa.Column('id', sa.Uuid(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('lower_name', sa.String(), nullable=False),
    sa.Column('full_name', sa.String(), nullable=True),
    sa.Column('email', sa.String(), nullable=False),
    sa.Column('is_email_private', sa.Boolean(), nullable=False),
    sa.Column('hashed_password', sa.String(), nullable=False),
    sa.Column('must_change_password', sa.Boolean(), nullable=False),
    sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('updated', sa.BigInteger(), nullable=False),
    sa.Column('last_login', sa.BigInteger(), nullable=True),
    sa.Column('is_active', sa.Boolean(), nullable=False),
    sa.Column('is_admin', sa.Boolean(), nullable=False),
    sa.Column('avatar', sa.String(), nullable=True),
    sa.PrimaryKeyConstraint('id'),
    sa.UniqueConstraint('lower_name'),
    sa.UniqueConstraint('name')
    )
    op.create_table('oauth2_application',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('user_id', sa.Uuid(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('client_id', sa.Uuid(), nullable=False),
    sa.Column('hashed_client_secret', sa.String(), nullable=False),
    sa.Column('redirect_uris', sa.JSON(), nullable=False),
    sa.Column('confidential_client', sa.Boolean(), nullable=False),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('updated', sa.BigInteger(), nullable=False),
    sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('repository',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('user_id', sa.Uuid(), nullable=False),
    sa.Column('capacity', sa.BigInteger(), nullable=False),
    sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('directory',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('repository_id', sa.BigInteger(), nullable=False),
    sa.Column('parent_id', sa.BigInteger(), nullable=True),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('updated', sa.BigInteger(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('path', sa.String(), nullable=True),
    sa.Column('is_public', sa.Boolean(), nullable=False),
    sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('oauth2_grant',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('user_id', sa.Uuid(), nullable=False),
    sa.Column('application_id', sa.BigInteger(), nullable=False),
    sa.Column('scope', sa.String(), nullable=False),
    sa.Column('created', sa.Integer(), nullable=False),
    sa.Column('updated', sa.Integer(), nullable=False),
    sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'),
    sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('directory_link',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('directory_id', sa.BigInteger(), nullable=False),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('url', sa.String(), nullable=False),
    sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('file',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('repository_id', sa.BigInteger(), nullable=False),
    sa.Column('parent_id', sa.BigInteger(), nullable=True),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('updated', sa.BigInteger(), nullable=False),
    sa.Column('name', sa.String(), nullable=False),
    sa.Column('path', sa.String(), nullable=True),
    sa.Column('is_public', sa.Boolean(), nullable=False),
    sa.Column('size', sa.BigInteger(), nullable=False),
    sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
    sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    op.create_table('file_link',
    sa.Column('id', sa.BigInteger(), nullable=False),
    sa.Column('file_id', sa.BigInteger(), nullable=False),
    sa.Column('created', sa.BigInteger(), nullable=False),
    sa.Column('url', sa.String(), nullable=False),
    sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
    sa.PrimaryKeyConstraint('id')
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_table('file_link')
    op.drop_table('file')
    op.drop_table('directory_link')
    op.drop_table('oauth2_grant')
    op.drop_table('directory')
    op.drop_table('repository')
    op.drop_table('oauth2_application')
    op.drop_table('user')
    op.drop_table('login_source')
    sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind())
    # ### end Alembic commands ###