materia-server: fix migrations, split app and cli
This commit is contained in:
parent
317085fc04
commit
f7bac07837
81
materia-server/src/materia_server/app.py
Normal file
81
materia-server/src/materia_server/app.py
Normal file
@ -0,0 +1,81 @@
|
||||
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
|
||||
from os import environ
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pwd
|
||||
import sys
|
||||
from typing import AsyncIterator, TypedDict
|
||||
import click
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydanclick import from_pydantic
|
||||
import pydantic
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from materia_server import config as _config
|
||||
from materia_server.config import Config
|
||||
from materia_server._logging import make_logger, uvicorn_log_config, Logger
|
||||
from materia_server.models import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError
|
||||
from materia_server import routers
|
||||
|
||||
|
||||
class AppContext(TypedDict):
|
||||
config: Config
|
||||
logger: Logger
|
||||
database: Database
|
||||
cache: Cache
|
||||
|
||||
def make_lifespan(config: Config, logger: Logger):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||
|
||||
try:
|
||||
logger.info("Connecting to database {}", config.database.url())
|
||||
database = await Database.new(config.database.url()) # type: ignore
|
||||
|
||||
logger.info("Running migrations")
|
||||
await database.run_migrations()
|
||||
|
||||
logger.info("Connecting to cache {}", config.cache.url())
|
||||
cache = await Cache.new(config.cache.url()) # type: ignore
|
||||
except DatabaseError as e:
|
||||
logger.error(f"Failed to connect postgres: {e}")
|
||||
sys.exit()
|
||||
except DatabaseMigrationError as e:
|
||||
logger.error(f"Failed to run migrations: {e}")
|
||||
sys.exit()
|
||||
except CacheError as e:
|
||||
logger.error(f"Failed to connect redis: {e}")
|
||||
sys.exit()
|
||||
|
||||
yield AppContext(
|
||||
config = config,
|
||||
database = database,
|
||||
cache = cache,
|
||||
logger = logger
|
||||
)
|
||||
|
||||
if database.engine is not None:
|
||||
await database.dispose()
|
||||
|
||||
return lifespan
|
||||
|
||||
def make_application(config: Config, logger: Logger):
|
||||
app = FastAPI(
|
||||
title = "materia",
|
||||
version = "0.1.0",
|
||||
docs_url = "/api/docs",
|
||||
lifespan = make_lifespan(config, logger)
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins = [ "http://localhost", "http://localhost:5173" ],
|
||||
allow_credentials = True,
|
||||
allow_methods = ["*"],
|
||||
allow_headers = ["*"],
|
||||
)
|
||||
app.include_router(routers.api.router)
|
||||
|
||||
return app
|
@ -17,49 +17,9 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from materia_server import config as _config
|
||||
from materia_server.config import Config
|
||||
from materia_server._logging import make_logger, uvicorn_log_config, Logger
|
||||
from materia_server.models import Database, Cache
|
||||
from materia_server.models import Database, DatabaseError, Cache
|
||||
from materia_server import routers
|
||||
|
||||
|
||||
# TODO: add cache
|
||||
class AppContext(TypedDict):
|
||||
config: Config
|
||||
database: Database
|
||||
cache: Cache
|
||||
logger: Logger
|
||||
|
||||
def create_lifespan(config: Config, logger):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||
|
||||
try:
|
||||
logger.info("Connecting {}", config.database.url())
|
||||
database = Database.new(config.database.url()) # type: ignore
|
||||
except:
|
||||
logger.error("Failed to connect postgres.")
|
||||
sys.exit()
|
||||
|
||||
try:
|
||||
logger.info("Connecting {}", config.cache.url())
|
||||
cache = await Cache.new(config.cache.url()) # type: ignore
|
||||
except:
|
||||
logger.error("Failed to connect redis.")
|
||||
sys.exit()
|
||||
|
||||
async with database.connection() as connection:
|
||||
await connection.run_sync(database.run_migrations) # type: ignore
|
||||
|
||||
yield AppContext(
|
||||
config = config,
|
||||
database = database,
|
||||
cache = cache,
|
||||
logger = logger
|
||||
)
|
||||
|
||||
if database.engine is not None:
|
||||
await database.dispose()
|
||||
|
||||
return lifespan
|
||||
from materia_server.app import make_application
|
||||
|
||||
@click.group()
|
||||
def server():
|
||||
@ -128,30 +88,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
|
||||
|
||||
config.application.mode = application.mode
|
||||
|
||||
|
||||
|
||||
app = FastAPI(
|
||||
title = "materia",
|
||||
version = "0.1.0",
|
||||
docs_url = "/api/docs",
|
||||
lifespan = create_lifespan(config, logger)
|
||||
)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins = [ "http://localhost", "http://localhost:5173" ],
|
||||
allow_credentials = True,
|
||||
allow_methods = ["*"],
|
||||
allow_headers = ["*"],
|
||||
)
|
||||
app.include_router(routers.api.router)
|
||||
|
||||
try:
|
||||
uvicorn.run(
|
||||
app,
|
||||
make_application(config, logger),
|
||||
port = config.server.port,
|
||||
host = str(config.server.address),
|
||||
# reload = config.application.mode == "development",
|
||||
log_config = uvicorn_log_config(config)
|
||||
log_config = uvicorn_log_config(config),
|
||||
)
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
pass
|
||||
|
@ -1,18 +1,7 @@
|
||||
|
||||
#from materia_server.models.base import Base
|
||||
#from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
|
||||
#from materia_server.models.user import User
|
||||
#from materia_server.models.repository import Repository
|
||||
#from materia_server.models.directory import Directory, DirectoryLink
|
||||
#from materia_server.models.file import File, FileLink
|
||||
from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
|
||||
|
||||
#from materia_server.models.repository import *
|
||||
|
||||
from materia_server.models.auth.source import LoginType, LoginSource
|
||||
from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
|
||||
|
||||
from materia_server.models.database.database import Database
|
||||
from materia_server.models.database.cache import Cache
|
||||
from materia_server.models.database import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError
|
||||
|
||||
from materia_server.models.user import User, UserCredentials, UserInfo
|
||||
|
||||
|
@ -1,2 +1,2 @@
|
||||
from materia_server.models.database.database import Database
|
||||
from materia_server.models.database.cache import Cache
|
||||
from materia_server.models.database.database import DatabaseError, DatabaseMigrationError, Database
|
||||
from materia_server.models.database.cache import Cache, CacheError
|
||||
|
@ -4,6 +4,8 @@ from pydantic import BaseModel, RedisDsn
|
||||
from redis import asyncio as aioredis
|
||||
from redis.asyncio.client import Pipeline
|
||||
|
||||
class CacheError(Exception):
|
||||
pass
|
||||
|
||||
class Cache:
|
||||
def __init__(self, url: RedisDsn, pool: aioredis.ConnectionPool):
|
||||
@ -11,16 +13,22 @@ class Cache:
|
||||
self.pool: aioredis.ConnectionPool = pool
|
||||
|
||||
@staticmethod
|
||||
async def new(url: RedisDsn, encoding: str = "utf-8", decode_responses: bool = True) -> Self:
|
||||
async def new(
|
||||
url: RedisDsn,
|
||||
encoding: str = "utf-8",
|
||||
decode_responses: bool = True,
|
||||
test_connection: bool = True
|
||||
) -> Self:
|
||||
pool = aioredis.ConnectionPool.from_url(str(url), encoding = encoding, decode_responses = decode_responses)
|
||||
|
||||
try:
|
||||
connection = pool.make_connection()
|
||||
await connection.connect()
|
||||
except ConnectionError as e:
|
||||
raise e
|
||||
else:
|
||||
await connection.disconnect()
|
||||
if test_connection:
|
||||
try:
|
||||
connection = pool.make_connection()
|
||||
await connection.connect()
|
||||
except ConnectionError as e:
|
||||
raise CacheError(f"{e}")
|
||||
else:
|
||||
await connection.disconnect()
|
||||
|
||||
return Cache(
|
||||
url = url,
|
||||
@ -32,7 +40,7 @@ class Cache:
|
||||
try:
|
||||
yield aioredis.Redis(connection_pool = self.pool)
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise CacheError(f"{e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def pipeline(self, transaction: bool = True) -> AsyncGenerator[Pipeline, Any]:
|
||||
@ -41,5 +49,5 @@ class Cache:
|
||||
try:
|
||||
yield client.pipeline(transaction = transaction)
|
||||
except Exception as e:
|
||||
raise e
|
||||
raise CacheError(f"{e}")
|
||||
|
||||
|
@ -16,6 +16,12 @@ from materia_server.models.base import Base
|
||||
|
||||
__all__ = [ "Database" ]
|
||||
|
||||
class DatabaseError(Exception):
|
||||
pass
|
||||
|
||||
class DatabaseMigrationError(Exception):
|
||||
pass
|
||||
|
||||
class Database:
|
||||
def __init__(self, url: PostgresDsn, engine: AsyncEngine, sessionmaker: async_sessionmaker[AsyncSession]):
|
||||
self.url: PostgresDsn = url
|
||||
@ -23,7 +29,14 @@ class Database:
|
||||
self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker
|
||||
|
||||
@staticmethod
|
||||
def new(url: PostgresDsn, pool_size: int = 100, autocommit: bool = False, autoflush: bool = False, expire_on_commit: bool = False) -> Self:
|
||||
async def new(
|
||||
url: PostgresDsn,
|
||||
pool_size: int = 100,
|
||||
autocommit: bool = False,
|
||||
autoflush: bool = False,
|
||||
expire_on_commit: bool = False,
|
||||
test_connection: bool = True
|
||||
) -> Self:
|
||||
engine = create_async_engine(str(url), pool_size = pool_size)
|
||||
sessionmaker = async_sessionmaker(
|
||||
bind = engine,
|
||||
@ -32,12 +45,21 @@ class Database:
|
||||
expire_on_commit = expire_on_commit
|
||||
)
|
||||
|
||||
return Database(
|
||||
database = Database(
|
||||
url = url,
|
||||
engine = engine,
|
||||
sessionmaker = sessionmaker
|
||||
)
|
||||
|
||||
if test_connection:
|
||||
try:
|
||||
async with database.connection() as connection:
|
||||
await connection.rollback()
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"{e}")
|
||||
|
||||
return database
|
||||
|
||||
async def dispose(self):
|
||||
await self.engine.dispose()
|
||||
|
||||
@ -48,7 +70,7 @@ class Database:
|
||||
yield connection
|
||||
except Exception as e:
|
||||
await connection.rollback()
|
||||
raise e
|
||||
raise DatabaseError(f"{e}")
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> AsyncIterator[AsyncSession]:
|
||||
@ -58,19 +80,14 @@ class Database:
|
||||
yield session
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise e
|
||||
raise DatabaseError(f"{e}")
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
def run_migrations(self, connection: Connection):
|
||||
#aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
|
||||
def run_sync_migrations(self, connection: Connection):
|
||||
aconfig = AlembicConfig()
|
||||
aconfig.set_main_option("sqlalchemy.url", str(self.url))
|
||||
|
||||
|
||||
aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||
print(str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||
|
||||
|
||||
context = MigrationContext.configure(
|
||||
connection = connection, # type: ignore
|
||||
@ -79,9 +96,15 @@ class Database:
|
||||
"fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev)
|
||||
}
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
with Operations.context(context):
|
||||
context.run_migrations()
|
||||
|
||||
try:
|
||||
with context.begin_transaction():
|
||||
with Operations.context(context):
|
||||
context.run_migrations()
|
||||
except Exception as e:
|
||||
raise DatabaseMigrationError(f"{e}")
|
||||
|
||||
async def run_migrations(self):
|
||||
async with self.connection() as connection:
|
||||
await connection.run_sync(self.run_sync_migrations) # type: ignore
|
||||
|
||||
|
@ -19,9 +19,7 @@ import materia_server.models.file
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
context.configure(
|
||||
version_table_schema = "public"
|
||||
)
|
||||
|
||||
config = context.config
|
||||
|
||||
#config.set_main_option("sqlalchemy.url", Config().database.url())
|
||||
|
@ -0,0 +1,140 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 939b37d98be0
|
||||
Revises:
|
||||
Create Date: 2024-06-24 15:39:38.380581
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '939b37d98be0'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind())
|
||||
op.create_table('login_source',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
||||
sa.Column('created', sa.Integer(), nullable=False),
|
||||
sa.Column('updated', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('lower_name', sa.String(), nullable=False),
|
||||
sa.Column('full_name', sa.String(), nullable=True),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('is_email_private', sa.Boolean(), nullable=False),
|
||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||
sa.Column('must_change_password', sa.Boolean(), nullable=False),
|
||||
sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('last_login', sa.BigInteger(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_admin', sa.Boolean(), nullable=False),
|
||||
sa.Column('avatar', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('lower_name'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
op.create_table('oauth2_application',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('hashed_client_secret', sa.String(), nullable=False),
|
||||
sa.Column('redirect_uris', sa.JSON(), nullable=False),
|
||||
sa.Column('confidential_client', sa.Boolean(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('repository',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('capacity', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('directory',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('path', sa.String(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('oauth2_grant',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('application_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('scope', sa.String(), nullable=False),
|
||||
sa.Column('created', sa.Integer(), nullable=False),
|
||||
sa.Column('updated', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('directory_link',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('directory_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('file',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('path', sa.String(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
||||
sa.Column('size', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
|
||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('file_link',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('file_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('file_link')
|
||||
op.drop_table('file')
|
||||
op.drop_table('directory_link')
|
||||
op.drop_table('oauth2_grant')
|
||||
op.drop_table('directory')
|
||||
op.drop_table('repository')
|
||||
op.drop_table('oauth2_application')
|
||||
op.drop_table('user')
|
||||
op.drop_table('login_source')
|
||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
Loading…
Reference in New Issue
Block a user