diff --git a/materia-server/src/materia_server/app.py b/materia-server/src/materia_server/app.py new file mode 100644 index 0000000..11bdf28 --- /dev/null +++ b/materia-server/src/materia_server/app.py @@ -0,0 +1,81 @@ +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager +from os import environ +import os +from pathlib import Path +import pwd +import sys +from typing import AsyncIterator, TypedDict +import click + +from pydantic import BaseModel +from pydanclick import from_pydantic +import pydantic +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from materia_server import config as _config +from materia_server.config import Config +from materia_server._logging import make_logger, uvicorn_log_config, Logger +from materia_server.models import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError +from materia_server import routers + + +class AppContext(TypedDict): + config: Config + logger: Logger + database: Database + cache: Cache + +def make_lifespan(config: Config, logger: Logger): + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]: + + try: + logger.info("Connecting to database {}", config.database.url()) + database = await Database.new(config.database.url()) # type: ignore + + logger.info("Running migrations") + await database.run_migrations() + + logger.info("Connecting to cache {}", config.cache.url()) + cache = await Cache.new(config.cache.url()) # type: ignore + except DatabaseError as e: + logger.error(f"Failed to connect postgres: {e}") + sys.exit() + except DatabaseMigrationError as e: + logger.error(f"Failed to run migrations: {e}") + sys.exit() + except CacheError as e: + logger.error(f"Failed to connect redis: {e}") + sys.exit() + + yield AppContext( + config = config, + database = database, + cache = cache, + logger = logger + ) + + if database.engine is not None: + await database.dispose() + + return lifespan + +def make_application(config: Config, logger: Logger): + app = FastAPI( + title = "materia", + version = "0.1.0", + docs_url = "/api/docs", + lifespan = make_lifespan(config, logger) + ) + app.add_middleware( + CORSMiddleware, + allow_origins = [ "http://localhost", "http://localhost:5173" ], + allow_credentials = True, + allow_methods = ["*"], + allow_headers = ["*"], + ) + app.include_router(routers.api.router) + + return app diff --git a/materia-server/src/materia_server/main.py b/materia-server/src/materia_server/main.py index 6559730..ba9ab89 100644 --- a/materia-server/src/materia_server/main.py +++ b/materia-server/src/materia_server/main.py @@ -17,49 +17,9 @@ from fastapi.middleware.cors import CORSMiddleware from materia_server import config as _config from materia_server.config import Config from materia_server._logging import make_logger, uvicorn_log_config, Logger -from materia_server.models import Database, Cache +from materia_server.models import Database, DatabaseError, Cache from materia_server import routers - - -# TODO: add cache -class AppContext(TypedDict): - config: Config - database: Database - cache: Cache - logger: Logger - -def create_lifespan(config: Config, logger): - @asynccontextmanager - async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]: - - try: - logger.info("Connecting {}", config.database.url()) - database = Database.new(config.database.url()) # type: ignore - except: - logger.error("Failed to connect postgres.") - sys.exit() - - try: - logger.info("Connecting {}", config.cache.url()) - cache = await Cache.new(config.cache.url()) # type: ignore - except: - logger.error("Failed to connect redis.") - sys.exit() - - async with database.connection() as connection: - await connection.run_sync(database.run_migrations) # type: ignore - - yield AppContext( - config = config, - database = database, - cache = cache, - logger = logger - ) - - if database.engine is not None: - await database.dispose() - - return lifespan +from materia_server.app import make_application @click.group() def server(): @@ -128,30 +88,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log) config.application.mode = application.mode - - - app = FastAPI( - title = "materia", - version = "0.1.0", - docs_url = "/api/docs", - lifespan = create_lifespan(config, logger) - ) - app.add_middleware( - CORSMiddleware, - allow_origins = [ "http://localhost", "http://localhost:5173" ], - allow_credentials = True, - allow_methods = ["*"], - allow_headers = ["*"], - ) - app.include_router(routers.api.router) - try: uvicorn.run( - app, + make_application(config, logger), port = config.server.port, host = str(config.server.address), # reload = config.application.mode == "development", - log_config = uvicorn_log_config(config) + log_config = uvicorn_log_config(config), ) except (KeyboardInterrupt, SystemExit): pass diff --git a/materia-server/src/materia_server/models/__init__.py b/materia-server/src/materia_server/models/__init__.py index 7ea79f2..15d0397 100644 --- a/materia-server/src/materia_server/models/__init__.py +++ b/materia-server/src/materia_server/models/__init__.py @@ -1,18 +1,7 @@ -#from materia_server.models.base import Base -#from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode -#from materia_server.models.user import User -#from materia_server.models.repository import Repository -#from materia_server.models.directory import Directory, DirectoryLink -#from materia_server.models.file import File, FileLink +from materia_server.models.auth import LoginType, LoginSource, OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode -#from materia_server.models.repository import * - -from materia_server.models.auth.source import LoginType, LoginSource -from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode - -from materia_server.models.database.database import Database -from materia_server.models.database.cache import Cache +from materia_server.models.database import Database, DatabaseError, DatabaseMigrationError, Cache, CacheError from materia_server.models.user import User, UserCredentials, UserInfo diff --git a/materia-server/src/materia_server/models/database/__init__.py b/materia-server/src/materia_server/models/database/__init__.py index b5ac5ac..2ebc3c1 100644 --- a/materia-server/src/materia_server/models/database/__init__.py +++ b/materia-server/src/materia_server/models/database/__init__.py @@ -1,2 +1,2 @@ -from materia_server.models.database.database import Database -from materia_server.models.database.cache import Cache +from materia_server.models.database.database import DatabaseError, DatabaseMigrationError, Database +from materia_server.models.database.cache import Cache, CacheError diff --git a/materia-server/src/materia_server/models/database/cache.py b/materia-server/src/materia_server/models/database/cache.py index 886f94a..a27b1ce 100644 --- a/materia-server/src/materia_server/models/database/cache.py +++ b/materia-server/src/materia_server/models/database/cache.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, RedisDsn from redis import asyncio as aioredis from redis.asyncio.client import Pipeline +class CacheError(Exception): + pass class Cache: def __init__(self, url: RedisDsn, pool: aioredis.ConnectionPool): @@ -11,16 +13,22 @@ class Cache: self.pool: aioredis.ConnectionPool = pool @staticmethod - async def new(url: RedisDsn, encoding: str = "utf-8", decode_responses: bool = True) -> Self: + async def new( + url: RedisDsn, + encoding: str = "utf-8", + decode_responses: bool = True, + test_connection: bool = True + ) -> Self: pool = aioredis.ConnectionPool.from_url(str(url), encoding = encoding, decode_responses = decode_responses) - try: - connection = pool.make_connection() - await connection.connect() - except ConnectionError as e: - raise e - else: - await connection.disconnect() + if test_connection: + try: + connection = pool.make_connection() + await connection.connect() + except ConnectionError as e: + raise CacheError(f"{e}") + else: + await connection.disconnect() return Cache( url = url, @@ -32,7 +40,7 @@ class Cache: try: yield aioredis.Redis(connection_pool = self.pool) except Exception as e: - raise e + raise CacheError(f"{e}") @asynccontextmanager async def pipeline(self, transaction: bool = True) -> AsyncGenerator[Pipeline, Any]: @@ -41,5 +49,5 @@ class Cache: try: yield client.pipeline(transaction = transaction) except Exception as e: - raise e + raise CacheError(f"{e}") diff --git a/materia-server/src/materia_server/models/database/database.py b/materia-server/src/materia_server/models/database/database.py index f63684b..5884580 100644 --- a/materia-server/src/materia_server/models/database/database.py +++ b/materia-server/src/materia_server/models/database/database.py @@ -16,6 +16,12 @@ from materia_server.models.base import Base __all__ = [ "Database" ] +class DatabaseError(Exception): + pass + +class DatabaseMigrationError(Exception): + pass + class Database: def __init__(self, url: PostgresDsn, engine: AsyncEngine, sessionmaker: async_sessionmaker[AsyncSession]): self.url: PostgresDsn = url @@ -23,7 +29,14 @@ class Database: self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker @staticmethod - def new(url: PostgresDsn, pool_size: int = 100, autocommit: bool = False, autoflush: bool = False, expire_on_commit: bool = False) -> Self: + async def new( + url: PostgresDsn, + pool_size: int = 100, + autocommit: bool = False, + autoflush: bool = False, + expire_on_commit: bool = False, + test_connection: bool = True + ) -> Self: engine = create_async_engine(str(url), pool_size = pool_size) sessionmaker = async_sessionmaker( bind = engine, @@ -32,12 +45,21 @@ class Database: expire_on_commit = expire_on_commit ) - return Database( + database = Database( url = url, engine = engine, sessionmaker = sessionmaker ) + if test_connection: + try: + async with database.connection() as connection: + await connection.rollback() + except Exception as e: + raise DatabaseError(f"{e}") + + return database + async def dispose(self): await self.engine.dispose() @@ -48,7 +70,7 @@ class Database: yield connection except Exception as e: await connection.rollback() - raise e + raise DatabaseError(f"{e}") @asynccontextmanager async def session(self) -> AsyncIterator[AsyncSession]: @@ -58,19 +80,14 @@ class Database: yield session except Exception as e: await session.rollback() - raise e + raise DatabaseError(f"{e}") finally: await session.close() - def run_migrations(self, connection: Connection): - #aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini") + def run_sync_migrations(self, connection: Connection): aconfig = AlembicConfig() aconfig.set_main_option("sqlalchemy.url", str(self.url)) - - aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations"))) - print(str(Path(__file__).parent.parent.joinpath("migrations"))) - context = MigrationContext.configure( connection = connection, # type: ignore @@ -79,9 +96,15 @@ class Database: "fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev) } ) - - with context.begin_transaction(): - with Operations.context(context): - context.run_migrations() + try: + with context.begin_transaction(): + with Operations.context(context): + context.run_migrations() + except Exception as e: + raise DatabaseMigrationError(f"{e}") + + async def run_migrations(self): + async with self.connection() as connection: + await connection.run_sync(self.run_sync_migrations) # type: ignore diff --git a/materia-server/src/materia_server/models/migrations/env.py b/materia-server/src/materia_server/models/migrations/env.py index 7bff352..44736d4 100644 --- a/materia-server/src/materia_server/models/migrations/env.py +++ b/materia-server/src/materia_server/models/migrations/env.py @@ -19,9 +19,7 @@ import materia_server.models.file # this is the Alembic Config object, which provides # access to the values within the .ini file in use. -context.configure( - version_table_schema = "public" - ) + config = context.config #config.set_main_option("sqlalchemy.url", Config().database.url()) diff --git a/materia-server/src/materia_server/models/migrations/versions/939b37d98be0_.py b/materia-server/src/materia_server/models/migrations/versions/939b37d98be0_.py new file mode 100644 index 0000000..357aaf0 --- /dev/null +++ b/materia-server/src/materia_server/models/migrations/versions/939b37d98be0_.py @@ -0,0 +1,140 @@ +"""empty message + +Revision ID: 939b37d98be0 +Revises: +Create Date: 2024-06-24 15:39:38.380581 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '939b37d98be0' +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind()) + op.create_table('login_source', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False), + sa.Column('created', sa.Integer(), nullable=False), + sa.Column('updated', sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('user', + sa.Column('id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('lower_name', sa.String(), nullable=False), + sa.Column('full_name', sa.String(), nullable=True), + sa.Column('email', sa.String(), nullable=False), + sa.Column('is_email_private', sa.Boolean(), nullable=False), + sa.Column('hashed_password', sa.String(), nullable=False), + sa.Column('must_change_password', sa.Boolean(), nullable=False), + sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('updated', sa.BigInteger(), nullable=False), + sa.Column('last_login', sa.BigInteger(), nullable=True), + sa.Column('is_active', sa.Boolean(), nullable=False), + sa.Column('is_admin', sa.Boolean(), nullable=False), + sa.Column('avatar', sa.String(), nullable=True), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('lower_name'), + sa.UniqueConstraint('name') + ) + op.create_table('oauth2_application', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('client_id', sa.Uuid(), nullable=False), + sa.Column('hashed_client_secret', sa.String(), nullable=False), + sa.Column('redirect_uris', sa.JSON(), nullable=False), + sa.Column('confidential_client', sa.Boolean(), nullable=False), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('updated', sa.BigInteger(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('repository', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('capacity', sa.BigInteger(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('directory', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('repository_id', sa.BigInteger(), nullable=False), + sa.Column('parent_id', sa.BigInteger(), nullable=True), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('updated', sa.BigInteger(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('path', sa.String(), nullable=True), + sa.Column('is_public', sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('oauth2_grant', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('user_id', sa.Uuid(), nullable=False), + sa.Column('application_id', sa.BigInteger(), nullable=False), + sa.Column('scope', sa.String(), nullable=False), + sa.Column('created', sa.Integer(), nullable=False), + sa.Column('updated', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'), + sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('directory_link', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('directory_id', sa.BigInteger(), nullable=False), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('url', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('file', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('repository_id', sa.BigInteger(), nullable=False), + sa.Column('parent_id', sa.BigInteger(), nullable=True), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('updated', sa.BigInteger(), nullable=False), + sa.Column('name', sa.String(), nullable=False), + sa.Column('path', sa.String(), nullable=True), + sa.Column('is_public', sa.Boolean(), nullable=False), + sa.Column('size', sa.BigInteger(), nullable=False), + sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ), + sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + op.create_table('file_link', + sa.Column('id', sa.BigInteger(), nullable=False), + sa.Column('file_id', sa.BigInteger(), nullable=False), + sa.Column('created', sa.BigInteger(), nullable=False), + sa.Column('url', sa.String(), nullable=False), + sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('file_link') + op.drop_table('file') + op.drop_table('directory_link') + op.drop_table('oauth2_grant') + op.drop_table('directory') + op.drop_table('repository') + op.drop_table('oauth2_application') + op.drop_table('user') + op.drop_table('login_source') + sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind()) + # ### end Alembic commands ###