materia-server: repository api, directory api, collapsed modules

This commit is contained in:
L-Nafaryus 2024-06-22 01:45:13 +05:00
parent d8b19da646
commit 317085fc04
Signed by: L-Nafaryus
GPG Key ID: 553C97999B363D38
40 changed files with 356 additions and 998 deletions

View File

@ -2,7 +2,7 @@
[alembic] [alembic]
# path to migration scripts # path to migration scripts
script_location = materia_server:models/migrations script_location = ./src/materia_server/models/migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s # template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time # Uncomment the line below if you want the files to be prepended with date and time
@ -60,7 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# are written from script.py.mako # are written from script.py.mako
# output_encoding = utf-8 # output_encoding = utf-8
#sqlalchemy.url = driver://user:pass@localhost/dbname sqlalchemy.url = postgresql+asyncpg://materia:materia@127.0.0.1:54320/materia
[post_write_hooks] [post_write_hooks]

View File

@ -35,7 +35,7 @@ readme = "README.md"
license = {text = "MIT"} license = {text = "MIT"}
[tool.pdm.build] [tool.pdm.build]
includes = ["src/materia_server", "src/materia_server/alembic.ini"] includes = ["src/materia_server"]
[build-system] [build-system]
requires = ["pdm-backend"] requires = ["pdm-backend"]

View File

@ -106,6 +106,10 @@ class OAuth2(BaseModel):
#def check(self) -> Self: #def check(self) -> Self:
# if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]: # if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
# assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms" # assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms"
# else:
# assert self.jwt_signing_key is not None, "JWT signing key must be set"
#
# return self
class Mailer(BaseModel): class Mailer(BaseModel):
@ -171,9 +175,6 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
else: else:
return cwd return cwd
@staticmethod
def create(path: Path, config: Self | None = None):
config = config or Config()
pass

View File

@ -17,7 +17,7 @@ from fastapi.middleware.cors import CORSMiddleware
from materia_server import config as _config from materia_server import config as _config
from materia_server.config import Config from materia_server.config import Config
from materia_server._logging import make_logger, uvicorn_log_config, Logger from materia_server._logging import make_logger, uvicorn_log_config, Logger
from materia_server.models.database import Database, Cache from materia_server.models import Database, Cache
from materia_server import routers from materia_server import routers
@ -31,12 +31,19 @@ class AppContext(TypedDict):
def create_lifespan(config: Config, logger): def create_lifespan(config: Config, logger):
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]: async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
database = Database.new(config.database.url()) # type: ignore
try:
logger.info("Connecting {}", config.database.url())
database = Database.new(config.database.url()) # type: ignore
except:
logger.error("Failed to connect postgres.")
sys.exit()
try: try:
logger.info("Connecting {}", config.cache.url())
cache = await Cache.new(config.cache.url()) # type: ignore cache = await Cache.new(config.cache.url()) # type: ignore
except: except:
logger.error("Failed to connect redis {}", config.cache.url()) logger.error("Failed to connect redis.")
sys.exit() sys.exit()
async with database.connection() as connection: async with database.connection() as connection:
@ -64,6 +71,7 @@ def server():
@from_pydantic("log", _config.Log, prefix = "log") @from_pydantic("log", _config.Log, prefix = "log")
def start(application: _config.Application, config_path: Path, log: _config.Log): def start(application: _config.Application, config_path: Path, log: _config.Log):
config = Config() config = Config()
config.log = log
logger = make_logger(config) logger = make_logger(config)
#if user := application.user: #if user := application.user:
@ -71,8 +79,12 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
#if group := application.group: #if group := application.group:
# os.setgid(pwd.getpwnam(user).pw_gid) # os.setgid(pwd.getpwnam(user).pw_gid)
# TODO: merge cli options with config # TODO: merge cli options with config
if working_directory := (application.working_directory or config.application.working_directory): if working_directory := (application.working_directory or config.application.working_directory).resolve():
os.chdir(working_directory.resolve()) try:
os.chdir(working_directory)
except FileNotFoundError as e:
logger.error("Failed to change working directory: {}", e)
sys.exit()
logger.debug(f"Current working directory: {working_directory}") logger.debug(f"Current working directory: {working_directory}")
# check the configuration file or use default # check the configuration file or use default
@ -106,6 +118,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
config.log.level = log.level config.log.level = log.level
logger = make_logger(config) logger = make_logger(config)
if (working_directory := config.application.working_directory.resolve()):
logger.debug(f"Change working directory: {working_directory}")
try:
os.chdir(working_directory)
except FileNotFoundError as e:
logger.error("Failed to change working directory: {}", e)
sys.exit()
config.application.mode = application.mode config.application.mode = application.mode

View File

@ -6,4 +6,18 @@
#from materia_server.models.directory import Directory, DirectoryLink #from materia_server.models.directory import Directory, DirectoryLink
#from materia_server.models.file import File, FileLink #from materia_server.models.file import File, FileLink
#from materia_server.models.repository import *
from materia_server.models.auth.source import LoginType, LoginSource
from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
from materia_server.models.database.database import Database
from materia_server.models.database.cache import Cache
from materia_server.models.user import User, UserCredentials, UserInfo
from materia_server.models.repository import Repository, RepositoryInfo
from materia_server.models.directory import Directory, DirectoryLink, DirectoryInfo
from materia_server.models.file import File, FileLink

View File

@ -1,4 +1,5 @@
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
import os
from typing import AsyncIterator, Self from typing import AsyncIterator, Self
from pathlib import Path from pathlib import Path
@ -10,6 +11,7 @@ from alembic.operations import Operations
from alembic.runtime.migration import MigrationContext from alembic.runtime.migration import MigrationContext
from alembic.script.base import ScriptDirectory from alembic.script.base import ScriptDirectory
from materia_server.config import Config
from materia_server.models.base import Base from materia_server.models.base import Base
__all__ = [ "Database" ] __all__ = [ "Database" ]
@ -61,14 +63,20 @@ class Database:
await session.close() await session.close()
def run_migrations(self, connection: Connection): def run_migrations(self, connection: Connection):
config = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini") #aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
config.set_main_option("sqlalchemy.url", self.url) # type: ignore aconfig = AlembicConfig()
aconfig.set_main_option("sqlalchemy.url", str(self.url))
aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations")))
print(str(Path(__file__).parent.parent.joinpath("migrations")))
context = MigrationContext.configure( context = MigrationContext.configure(
connection = connection, # type: ignore connection = connection, # type: ignore
opts = { opts = {
"target_metadata": Base.metadata, "target_metadata": Base.metadata,
"fn": lambda rev, _: ScriptDirectory.from_config(config)._upgrade_revs("head", rev) "fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev)
} }
) )

View File

@ -1,10 +1,14 @@
from time import time from time import time
from typing import List from typing import List, Optional, Self
from pathlib import Path
from sqlalchemy import BigInteger, ForeignKey from sqlalchemy import BigInteger, ForeignKey
from sqlalchemy.orm import mapped_column, Mapped, relationship from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict
from materia_server.models.base import Base from materia_server.models.base import Base
from materia_server.models import database
class Directory(Base): class Directory(Base):
@ -12,7 +16,7 @@ class Directory(Base):
id: Mapped[int] = mapped_column(BigInteger, primary_key = True) id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
repository_id: Mapped[int] = mapped_column(ForeignKey("repository.id", ondelete = "CASCADE")) repository_id: Mapped[int] = mapped_column(ForeignKey("repository.id", ondelete = "CASCADE"))
parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id"), nullable = True) parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id", ondelete = "CASCADE"), nullable = True)
created: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time) created: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
updated: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time) updated: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
name: Mapped[str] name: Mapped[str]
@ -25,6 +29,14 @@ class Directory(Base):
files: Mapped[List["File"]] = relationship(back_populates = "parent") files: Mapped[List["File"]] = relationship(back_populates = "parent")
link: Mapped["DirectoryLink"] = relationship(back_populates = "directory") link: Mapped["DirectoryLink"] = relationship(back_populates = "directory")
@staticmethod
async def by_path(repository_id: int, path: Path | None, name: str, db: database.Database) -> Self | None:
async with db.session() as session:
query_path = Directory.path == str(path) if isinstance(path, Path) else Directory.path.is_(None)
return (await session
.scalars(sa.select(Directory)
.where(sa.and_(Directory.repository_id == repository_id, Directory.name == name, query_path)))
).first()
class DirectoryLink(Base): class DirectoryLink(Base):
__tablename__ = "directory_link" __tablename__ = "directory_link"
@ -36,5 +48,20 @@ class DirectoryLink(Base):
directory: Mapped["Directory"] = relationship(back_populates = "link") directory: Mapped["Directory"] = relationship(back_populates = "link")
from materia_server.models.repository.repository import Repository class DirectoryInfo(BaseModel):
from materia_server.models.file.file import File model_config = ConfigDict(from_attributes = True)
id: int
repository_id: int
parent_id: Optional[int]
created: int
updated: int
name: str
path: Optional[str]
is_public: bool
used: Optional[int] = None
from materia_server.models.repository import Repository
from materia_server.models.file import File

View File

@ -1 +0,0 @@
from materia_server.models.directory.directory import Directory, DirectoryLink

View File

@ -35,5 +35,5 @@ class FileLink(Base):
file: Mapped["File"] = relationship(back_populates = "link") file: Mapped["File"] = relationship(back_populates = "link")
from materia_server.models.repository.repository import Repository from materia_server.models.repository import Repository
from materia_server.models.directory.directory import Directory from materia_server.models.directory import Directory

View File

@ -1 +0,0 @@
from materia_server.models.file.file import File, FileLink

View File

@ -19,8 +19,12 @@ import materia_server.models.file
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
context.configure(
version_table_schema = "public"
)
config = context.config config = context.config
config.set_main_option("sqlalchemy.url", Config().database.url())
#config.set_main_option("sqlalchemy.url", Config().database.url())
# Interpret the config file for Python logging. # Interpret the config file for Python logging.
# This line sets up loggers basically. # This line sets up loggers basically.
@ -59,6 +63,7 @@ def run_migrations_offline() -> None:
target_metadata=target_metadata, target_metadata=target_metadata,
literal_binds=True, literal_binds=True,
dialect_opts={"paramstyle": "named"}, dialect_opts={"paramstyle": "named"},
version_table_schema = "public"
) )
with context.begin_transaction(): with context.begin_transaction():
@ -99,4 +104,5 @@ def run_migrations_online() -> None:
if context.is_offline_mode(): if context.is_offline_mode():
run_migrations_offline() run_migrations_offline()
else: else:
print("online")
run_migrations_online() run_migrations_online()

View File

@ -1,140 +0,0 @@
"""empty message
Revision ID: 76191498b728
Revises:
Create Date: 2024-06-03 18:44:07.044588
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '76191498b728'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind())
op.create_table('login_source',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
sa.Column('created', sa.Integer(), nullable=False),
sa.Column('updated', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('id', sa.Uuid(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('lower_name', sa.String(), nullable=False),
sa.Column('full_name', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=False),
sa.Column('is_email_private', sa.Boolean(), nullable=False),
sa.Column('hashed_password', sa.String(), nullable=False),
sa.Column('must_change_password', sa.Boolean(), nullable=False),
sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('updated', sa.BigInteger(), nullable=False),
sa.Column('last_login', sa.BigInteger(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('is_admin', sa.Boolean(), nullable=False),
sa.Column('avatar', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('lower_name'),
sa.UniqueConstraint('name')
)
op.create_table('oauth2_application',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('client_id', sa.Uuid(), nullable=False),
sa.Column('hashed_client_secret', sa.String(), nullable=False),
sa.Column('redirect_uris', sa.JSON(), nullable=False),
sa.Column('confidential_client', sa.Boolean(), nullable=False),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('updated', sa.BigInteger(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('repository',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('capacity', sa.BigInteger(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
sa.PrimaryKeyConstraint('id')
)
op.create_table('directory',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('repository_id', sa.BigInteger(), nullable=False),
sa.Column('parent_id', sa.BigInteger(), nullable=True),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('updated', sa.BigInteger(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('path', sa.String(), nullable=True),
sa.Column('is_public', sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('oauth2_grant',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('user_id', sa.Uuid(), nullable=False),
sa.Column('application_id', sa.BigInteger(), nullable=False),
sa.Column('scope', sa.String(), nullable=False),
sa.Column('created', sa.Integer(), nullable=False),
sa.Column('updated', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('directory_link',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('directory_id', sa.BigInteger(), nullable=False),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('url', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('file',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('repository_id', sa.BigInteger(), nullable=False),
sa.Column('parent_id', sa.BigInteger(), nullable=True),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('updated', sa.BigInteger(), nullable=False),
sa.Column('name', sa.String(), nullable=False),
sa.Column('path', sa.String(), nullable=True),
sa.Column('is_public', sa.Boolean(), nullable=False),
sa.Column('size', sa.BigInteger(), nullable=False),
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_table('file_link',
sa.Column('id', sa.BigInteger(), nullable=False),
sa.Column('file_id', sa.BigInteger(), nullable=False),
sa.Column('created', sa.BigInteger(), nullable=False),
sa.Column('url', sa.String(), nullable=False),
sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('file_link')
op.drop_table('file')
op.drop_table('directory_link')
op.drop_table('oauth2_grant')
op.drop_table('directory')
op.drop_table('repository')
op.drop_table('oauth2_application')
op.drop_table('user')
op.drop_table('login_source')
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind())
# ### end Alembic commands ###

View File

@ -0,0 +1,51 @@
from time import time
from typing import List, Self
from uuid import UUID, uuid4
from sqlalchemy import BigInteger, ForeignKey
from sqlalchemy.orm import mapped_column, Mapped, relationship
from sqlalchemy.orm.attributes import InstrumentedAttribute
import sqlalchemy as sa
from pydantic import BaseModel
from materia_server.models.base import Base
from materia_server.models import database
class Repository(Base):
__tablename__ = "repository"
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
user: Mapped["User"] = relationship(back_populates = "repository")
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
files: Mapped[List["File"]] = relationship(back_populates = "repository")
def to_dict(self) -> dict:
return { k: getattr(self, k) for k, v in Repository.__dict__.items() if isinstance(v, InstrumentedAttribute) }
async def create(self, db: database.Database):
async with db.session() as session:
session.add(self)
await session.commit()
async def update(self, db: database.Database):
async with db.session() as session:
await session.execute(sa.update(Repository).where(Repository.id == self.id).values(self.to_dict()))
await session.commit()
@staticmethod
async def by_user_id(user_id: UUID, db: database.Database) -> Self | None:
async with db.session() as session:
return (await session.scalars(sa.select(Repository).where(Repository.user_id == user_id))).first()
class RepositoryInfo(BaseModel):
capacity: int
used: int
from materia_server.models.user import User
from materia_server.models.directory import Directory
from materia_server.models.file import File

View File

@ -1 +0,0 @@
from materia_server.models.repository.repository import Repository

View File

@ -1,25 +0,0 @@
from time import time
from typing import List
from uuid import UUID, uuid4
from sqlalchemy import BigInteger, ForeignKey
from sqlalchemy.orm import mapped_column, Mapped, relationship
from materia_server.models.base import Base
class Repository(Base):
__tablename__ = "repository"
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
user: Mapped["User"] = relationship(back_populates = "repository")
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
files: Mapped[List["File"]] = relationship(back_populates = "repository")
from materia_server.models.user.user import User
from materia_server.models.directory.directory import Directory
from materia_server.models.file.file import File

View File

@ -80,9 +80,10 @@ class UserCredentials(BaseModel):
password: str password: str
email: Optional[EmailStr] email: Optional[EmailStr]
class UserIdentity(BaseModel): class UserInfo(BaseModel):
model_config = ConfigDict(from_attributes = True) model_config = ConfigDict(from_attributes = True)
id: UUID
name: str name: str
lower_name: str lower_name: str
full_name: Optional[str] full_name: Optional[str]
@ -101,4 +102,4 @@ class UserIdentity(BaseModel):
avatar: Optional[str] avatar: Optional[str]
from materia_server.models.repository.repository import Repository from materia_server.models.repository import Repository

View File

@ -1 +0,0 @@
from materia_server.models.user.user import User, UserCredentials, UserIdentity

View File

@ -1,2 +1 @@
from materia_server.routers import api from materia_server.routers import middleware, api
from materia_server.routers import middleware

View File

@ -1,8 +1,10 @@
from fastapi import APIRouter from fastapi import APIRouter
from materia_server.routers.api import auth from materia_server.routers.api.auth import auth, oauth
from materia_server.routers.api import user from materia_server.routers.api import user, repository, directory
router = APIRouter(prefix = "/api")
router = APIRouter()
router.include_router(auth.router) router.include_router(auth.router)
router.include_router(oauth.router)
router.include_router(user.router) router.include_router(user.router)
router.include_router(repository.router)
router.include_router(directory.router)

View File

@ -1,8 +0,0 @@
from fastapi import APIRouter
from materia_server.routers.api.auth import auth
from materia_server.routers.api.auth import oauth
router = APIRouter()
router.include_router(auth.router)
router.include_router(oauth.router)

View File

@ -3,33 +3,32 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Response, status from fastapi import APIRouter, Depends, HTTPException, Response, status
from materia_server import security from materia_server import security
from materia_server.routers import context from materia_server.routers.middleware import Context
from materia_server.models import user from materia_server.models import LoginType, User, UserCredentials
from materia_server.models import auth
router = APIRouter(tags = ["auth"]) router = APIRouter(tags = ["auth"])
@router.post("/auth/signup") @router.post("/auth/signup")
async def signup(body: user.UserCredentials, ctx: context.Context = Depends()): async def signup(body: UserCredentials, ctx: Context = Depends()):
if not user.User.is_valid_username(body.name): if not User.is_valid_username(body.name):
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Invalid username") raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Invalid username")
if await user.User.by_name(body.name, ctx.database) is not None: if await User.by_name(body.name, ctx.database) is not None:
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "User already exists") raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "User already exists")
if await user.User.by_email(body.email, ctx.database) is not None: # type: ignore if await User.by_email(body.email, ctx.database) is not None: # type: ignore
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Email already used") raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Email already used")
if len(body.password) < ctx.config.security.password_min_length: if len(body.password) < ctx.config.security.password_min_length:
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = f"Password is too short (minimum length {ctx.config.security.password_min_length})") raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = f"Password is too short (minimum length {ctx.config.security.password_min_length})")
count: Optional[int] = await user.User.count(ctx.database) count: Optional[int] = await User.count(ctx.database)
new_user = user.User( new_user = User(
name = body.name, name = body.name,
lower_name = body.name.lower(), lower_name = body.name.lower(),
full_name = body.name, full_name = body.name,
email = body.email, email = body.email,
hashed_password = security.hash_password(body.password, algo = ctx.config.security.password_hash_algo), hashed_password = security.hash_password(body.password, algo = ctx.config.security.password_hash_algo),
login_type = auth.LoginType.Plain, login_type = LoginType.Plain,
# first registered user is admin # first registered user is admin
is_admin = count == 0 is_admin = count == 0
) )
@ -39,9 +38,9 @@ async def signup(body: user.UserCredentials, ctx: context.Context = Depends()):
await session.commit() await session.commit()
@router.post("/auth/signin") @router.post("/auth/signin")
async def signin(body: user.UserCredentials, response: Response, ctx: context.Context = Depends()): async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
if (current_user := await user.User.by_name(body.name, ctx.database)) is None: if (current_user := await User.by_name(body.name, ctx.database)) is None:
if (current_user := await user.User.by_email(str(body.email), ctx.database)) is None: if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid email") raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid email")
if not security.validate_password(body.password, current_user.hashed_password, algo = ctx.config.security.password_hash_algo): if not security.validate_password(body.password, current_user.hashed_password, algo = ctx.config.security.password_hash_algo):
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid password") raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid password")
@ -79,6 +78,6 @@ async def signin(body: user.UserCredentials, response: Response, ctx: context.Co
) )
@router.get("/auth/signout") @router.get("/auth/signout")
async def signout(response: Response, ctx: context.Context = Depends()): async def signout(response: Response, ctx: Context = Depends()):
response.delete_cookie(ctx.config.security.cookie_access_token_name) response.delete_cookie(ctx.config.security.cookie_access_token_name)
response.delete_cookie(ctx.config.security.cookie_refresh_token_name) response.delete_cookie(ctx.config.security.cookie_refresh_token_name)

View File

@ -7,9 +7,8 @@ from fastapi.security.oauth2 import OAuth2PasswordRequestForm
from pydantic import BaseModel, HttpUrl from pydantic import BaseModel, HttpUrl
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
from materia_server.models import auth from materia_server.models import User
from materia_server.models.user import user from materia_server.routers.middleware import Context
from materia_server.routers import context
router = APIRouter(tags = ["oauth2"]) router = APIRouter(tags = ["oauth2"])
@ -35,17 +34,17 @@ class AuthorizationCodeResponse(BaseModel):
code: str code: str
@router.post("/oauth2/authorize") @router.post("/oauth2/authorize")
async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: context.Context = Depends()): async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: Context = Depends()):
# grant_type: authorization_code, password_credentials, client_credentials, authorization_code (pkce) # grant_type: authorization_code, password_credentials, client_credentials, authorization_code (pkce)
ctx.logger.debug(form) ctx.logger.debug(form)
if form.grant_type == "authorization_code": if form.grant_type == "authorization_code":
# TODO: form validation # TODO: form validation
if not (app := await auth.OAuth2Application.by_client_id(form.client_id, ctx.database)): if not (app := await OAuth2Application.by_client_id(form.client_id, ctx.database)):
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "Client ID not registered") raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "Client ID not registered")
if not (owner := user.User.by_id(app.user_id, ctx.database)): if not (owner := await User.by_id(app.user_id, ctx.database)):
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "User not found") raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "User not found")
if not app.contains_redirect_uri(form.redirect_uri): if not app.contains_redirect_uri(form.redirect_uri):
@ -79,5 +78,5 @@ class AccessTokenResponse(BaseModel):
scope: Optional[str] scope: Optional[str]
@router.post("/oauth2/access_token") @router.post("/oauth2/access_token")
async def token(ctx: context.Context = Depends()): async def token(ctx: Context = Depends()):
pass pass

View File

@ -1,143 +0,0 @@
import os
import time
from pathlib import Path
from typing import Annotated
import bcrypt
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
from fastapi.responses import RedirectResponse, StreamingResponse
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordRequestFormStrict
import httpx
from sqlalchemy import and_, insert, select, update
from authlib.integrations.starlette_client import OAuth, OAuthError
import base64
from cryptography.fernet import Fernet
import json
from materia import db
from materia.api import schema
from materia.api.state import ConfigState, DatabaseState
from materia.api.middleware import JwtMiddleware
from materia.api.token import TokenClaims
from materia.config import Config
oauth = OAuth()
oauth.register(
"materia",
authorize_url = "http://127.0.0.1:54601/api/auth/authorize",
access_token_url = "http://127.0.0.1:54601/api/auth/token",
scope = "user:read",
client_id = "",
client_secret = ""
)
class OAuth2Provider:
pass
router = APIRouter(tags = ["auth"])
@router.get("/user/signin")
async def signin(request: Request, provider: str = None):
if not provider:
return RedirectResponse("/api/auth/authorize")
else:
return RedirectResponse(request.url_for(provider.authorize_url))
@router.post("/auth/test_auth")
async def test_auth(database: DatabaseState = Depends()):
async with httpx.AsyncClient() as client:
response = await client.post("https://vcs.elnafo.ru/login/oauth/authorize", data = {
"client_id": "1edfe-0bbe-4f53-bab6-7e24f0b842e3",
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
"response_type": "code",
"redirect_uri": "http://127.0.0.1:54601"
})
return response.content, response.status_code
@router.post("/auth/provider")
async def provider(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
async with httpx.AsyncClient() as client:
response = await client.post("https://vcs.elnafo.ru/login/oauth/access_token", data = {
"client_id": "1edfec03-0bbe-4f53-bab6-7e24f0b842e3",
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
"grant_type": "authorization_code",
"code": "gta_63l6zogw5wlnkeng4gf3buqtoekkaxk7zhr67zlkyrv2ukwfeava"
})
return response.content, response.status_code
@router.post("/auth/authorize")
async def authorize(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
if form.client_id:
async with database.session() as session:
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
await session.refresh(user, attribute_names = ["oauth2_apps"])
oauth2_app = None
for app in user.oauth2_apps:
if form.client_id == app.client_id and bcrypt.checkpw(form.client_secret.encode(), app.client_secret):
oauth2_app = app
if not oauth2_app:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid client id")
data = json.dumps({"client_id": form.client_id}).encode()
else:
async with database.session() as session:
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user credentials")
if not bcrypt.checkpw(form.password.encode(), user.hashed_password.encode()):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
data = json.dumps({"username": form.username}).encode()
key = b'sGEuUeKrooiNAy7L9sf6IFIjpv86TC9iYU_sbWqA-1c=' # Fernet.generate_key()
f = Fernet(key)
code = base64.b64encode(f.encrypt(data), b"-_").decode().replace("=", "")
global storage
storage = code
return code
storage = None
@router.post("/auth/token")
async def token(exchange: schema.Exchange, response: Response, config: ConfigState = Depends()):
if exchange.grant_type == "authorization_code":
if not exchange.code:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Missing authorization code")
# expiration
if exchange.code != storage:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Invalid authorization code")
token = TokenClaims.create(
"asd",
config.jwt.secret,
config.jwt.maxage
)
response.set_cookie(
"token",
value = token,
max_age = config.jwt.maxage,
secure = True,
httponly = True,
samesite = "none"
)
return schema.AccessToken(
access_token = token,
token_type = "Bearer",
expires_in = config.jwt.maxage,
refresh_token = token,
scope = "identify"
)
elif exchange.grant_type == "refresh_token":
pass
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST)

View File

@ -1,44 +1,40 @@
import os import os
import time
from pathlib import Path from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, insert, select, update
from materia import db from fastapi import APIRouter, Depends, HTTPException, status
from materia.api.state import ConfigState, DatabaseState
from materia.api.middleware import JwtMiddleware from materia_server.models import User, Directory, DirectoryInfo
from materia.config import Config from materia_server.models.directory import DirectoryInfo
from materia.api import schema from materia_server.routers import middleware
from materia_server.config import Config
router = APIRouter(tags = ["directory"]) router = APIRouter(tags = ["directory"])
@router.post("/directory", dependencies = [Depends(JwtMiddleware())]) @router.post("/directory")
async def create(request: Request, path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()): async def create(path: Path = Path(), user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
user = request.state.user repository_path = Config.data_dir() / "repository" / user.lower_name
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
blacklist = [os.sep, ".", "..", "*"] blacklist = [os.sep, ".", "..", "*"]
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, path.parts))) directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, path.parts)))
async with database.session() as session: async with ctx.database.session() as session:
session.add(user) session.add(user)
await session.refresh(user, attribute_names = ["repository"]) await session.refresh(user, attribute_names = ["repository"])
if not user.repository:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
current_directory = None current_directory = None
current_path = Path() current_path = Path()
directory = None directory = None
for part in directory_path.parts: for part in directory_path.parts:
if not (directory := (await session if not await Directory.by_path(user.repository.id, current_path, part, ctx.database):
.scalars(select(db.Directory) directory = Directory(
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
).first()):
directory = db.Directory(
repository_id = user.repository.id, repository_id = user.repository.id,
parent_id = current_directory.id if current_directory else None, parent_id = current_directory.id if current_directory else None,
name = part, name = part,
path = str(current_path) path = None if current_path == Path() else str(current_path)
) )
session.add(directory) session.add(directory)
@ -52,23 +48,20 @@ async def create(request: Request, path: Path = Path(), config: ConfigState = De
await session.commit() await session.commit()
@router.get("/directory", dependencies = [Depends(JwtMiddleware())]) @router.get("/directory")
async def info(request: Request, repository_id: int, path: Path, config: ConfigState = Depends(), database: DatabaseState = Depends()): async def info(path: Path, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
async with database.session() as session: async with ctx.database.session() as session:
if directory := (await session session.add(user)
.scalars(select(db.Directory) await session.refresh(user, attribute_names = ["repository"])
.where(and_(db.Directory.repository_id == repository_id, db.Directory.name == path.name, db.Directory.path == path.parent))
)).first(): if not(directory := await Directory.by_path(user.repository.id, None if path.parent == Path() else path.parent, path.name, ctx.database)):
await session.refresh(directory, attribute_names = ["files"]) raise HTTPException(status.HTTP_404_NOT_FOUND, "Directory is not found")
return schema.DirectoryInfo(
id = directory.id, session.add(directory)
created_at = directory.created_unix, await session.refresh(directory, attribute_names = ["files"])
updated_at = directory.updated_unix,
name = directory.name, info = DirectoryInfo.model_validate(directory)
path = directory.path, info.used = sum([ file.size for file in directory.files ])
is_public = directory.is_public,
used = sum([ file.size for file in directory.files ]) return info
)
else:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")

View File

@ -1,51 +0,0 @@
import os
import time
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, insert, select, update
from materia import db
from materia.api import schema
from materia.api.state import ConfigState, DatabaseState
from materia.api.middleware import JwtMiddleware
from materia.config import Config
from materia.api import repository, directory
router = APIRouter(tags = ["file"])
@router.put("/file", dependencies = [Depends(JwtMiddleware())])
async def upload(request: Request, file: UploadFile, directory_path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()):
user = request.state.user
try:
await repository.create(request, config = config, database = database)
except:
pass
#try:
# directory_info = directory.info
# await directory.create(request, path = directory_path, config = config, database = database)
async with database.session() as session:
if file_ := (await session
.scalars(select(db.File)
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
).first():
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
else:
file_ = db.File(
repository_id = user.repository.id,
parent_id = directory.id if directory else None,
name = file.filename,
path = str(directory_path),
size = file.size
)
session.add(file_)
try:
(repository_path / directory_path / file.filename).write_bytes(await file.read())
except OSError:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
await session.commit()

View File

@ -1,91 +0,0 @@
import os
import time
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, insert, select, update
from materia import db
from materia.api.state import ConfigState, DatabaseState
from materia.api.middleware import JwtMiddleware
from materia.config import Config
from materia.api import repository
router = APIRouter(tags = ["filesystem"])
@router.get("/play")
async def play():
def iterfile():
with open(Config.data_dir() / ".." / "bfg.mp3", mode="rb") as file_like: #
yield from file_like #
return StreamingResponse(iterfile(), media_type="audio/mp3")
@router.put("/file/upload", dependencies = [Depends(JwtMiddleware())])
async def upload(request: Request, file: UploadFile, config: ConfigState = Depends(), database: DatabaseState = Depends(), directory_path: Path = Path()):
user = request.state.user
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
blacklist = [os.sep, ".", "..", "*"]
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, directory_path.parts)))
try:
await repository.create(request, config = config, database = database)
except:
pass
async with database.session() as session:
session.add(user)
await session.refresh(user, attribute_names = ["repository"])
current_directory = None
current_path = Path()
directory = None
for part in directory_path.parts:
if not (directory := (await session
.scalars(select(db.Directory)
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
).first()):
directory = db.Directory(
repository_id = user.repository.id,
parent_id = current_directory.id if current_directory else None,
name = part,
path = str(current_path)
)
session.add(directory)
current_directory = directory
current_path /= part
try:
(repository_path / directory_path).mkdir(parents = True, exist_ok = True)
except OSError:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a directory")
await session.commit()
async with database.session() as session:
if file_ := (await session
.scalars(select(db.File)
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
).first():
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
else:
file_ = db.File(
repository_id = user.repository.id,
parent_id = directory.id if directory else None,
name = file.filename,
path = str(directory_path),
size = file.size
)
session.add(file_)
try:
(repository_path / directory_path / file.filename).write_bytes(await file.read())
except OSError:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
await session.commit()

View File

@ -1,60 +1,45 @@
import os from fastapi import APIRouter, Depends, HTTPException, status
import time
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
from fastapi.responses import StreamingResponse
from sqlalchemy import and_, insert, select, update
from materia import db from materia_server.models import User, Repository, RepositoryInfo
from materia.api import schema from materia_server.routers import middleware
from materia.api.state import ConfigState, DatabaseState from materia_server.config import Config
from materia.api.middleware import JwtMiddleware
from materia.config import Config
router = APIRouter(tags = ["repository"]) router = APIRouter(tags = ["repository"])
@router.post("/repository", dependencies = [Depends(JwtMiddleware())]) @router.post("/repository")
async def create(request: Request, config: ConfigState = Depends(), database: DatabaseState = Depends()): async def create(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
user = request.state.user repository_path = Config.data_dir() / "repository" / user.lower_name
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
async with database.session() as session: if await Repository.by_user_id(user.id, ctx.database):
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
repository = Repository(
user_id = user.id,
capacity = ctx.config.repository.capacity
)
try:
repository_path.mkdir(parents = True, exist_ok = True)
except OSError:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
await repository.create(ctx.database)
@router.get("/repository", response_model = RepositoryInfo)
async def info(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
async with ctx.database.session() as session:
session.add(user) session.add(user)
await session.refresh(user, attribute_names = ["repository"]) await session.refresh(user, attribute_names = ["repository"])
if not (repository := user.repository): if not (repository := user.repository):
repository = db.Repository(
owner_id = user.id,
capacity = config.repository.capacity
)
session.add(repository)
try:
repository_path.mkdir(parents = True, exist_ok = True)
except OSError:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
await session.commit()
else:
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
@router.get("/repository", dependencies = [Depends(JwtMiddleware())])
async def info(request: Request, database: DatabaseState = Depends()):
user = request.state.user
async with database.session() as session:
session.add(user)
await session.refresh(user, attribute_names = ["repository"])
if repository := user.repository:
await session.refresh(repository, attribute_names = ["files"])
return schema.RepositoryInfo(
capacity = repository.capacity,
used = sum([ file.size for file in repository.files ])
)
else:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found") raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
await session.refresh(repository, attribute_names = ["files"])
return RepositoryInfo(
capacity = repository.capacity,
used = sum([ file.size for file in repository.files ])
)

View File

@ -1,5 +0,0 @@
from materia.api.schema.user import NewUser, User, RemoveUser, LoginUser
from materia.api.schema.token import Token
from materia.api.schema.repository import RepositoryInfo
from materia.api.schema.directory import DirectoryInfo
from materia.api.schema.auth import AccessToken, Exchange

View File

@ -1,25 +0,0 @@
from typing import Optional
from pydantic import BaseModel
class AuthCode(BaseModel):
client_id: str
response_type: str
state: str
redirect_uri: Optional[str]
scope: Optional[str]
class Exchange(BaseModel):
grant_type: str
client_id: Optional[str] = None
client_secret: Optional[str] = None
redirect_uri: Optional[str] = None
code: Optional[str] = None
refresh_token: Optional[str] = None
class AccessToken(BaseModel):
access_token: str
token_type: str
expires_in: int
refresh_token: str
scope: Optional[str]

View File

@ -1,11 +0,0 @@
from pydantic import BaseModel
class DirectoryInfo(BaseModel):
id: int
created_at: int
updated_at: int
name: str
path: str
is_public: bool
used: int

View File

@ -1,6 +0,0 @@
from pydantic import BaseModel
class RepositoryInfo(BaseModel):
capacity: int
used: int

View File

@ -1,10 +0,0 @@
from typing import Optional, Self
from uuid import UUID
from pydantic import BaseModel
from materia.api.token import TokenClaims
class Token(BaseModel):
access_token: str

View File

@ -1,40 +0,0 @@
from typing import Optional, Self
from uuid import UUID
from pydantic import BaseModel
from materia import db
class NewUser(BaseModel):
login: str
password: str
email: str
class User(BaseModel):
id: str
login: str
name: str
email: str
is_admin: bool
avatar: Optional[str]
@staticmethod
def from_(user: db.User) -> Self:
return User(
id = str(user.id),
login = user.login_name,
name = user.name,
email = user.email,
is_admin = user.is_admin,
avatar = user.avatar
)
class RemoveUser(BaseModel):
id: UUID
class LoginUser(BaseModel):
email: Optional[str] = None
login: Optional[str] = None
password: str

View File

@ -1,27 +0,0 @@
from typing import Self
import jwt
from pydantic import BaseModel
import time
import datetime
class TokenClaims(BaseModel):
sub: str
exp: int
iat: int
@staticmethod
def create(sub: str, secret: str, duration: int) -> str:
now = datetime.datetime.now()
iat = now.timestamp()
exp = (now + datetime.timedelta(seconds = duration)).timestamp()
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat))
return jwt.encode(claims.model_dump(), secret)
@staticmethod
def verify(token: str, secret: str) -> Self:
data = jwt.decode(token, secret, algorithms = ["HS256"])
return TokenClaims(**data)

View File

@ -0,0 +1,50 @@
import uuid
import io
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile
import sqlalchemy as sa
from sqids.sqids import Sqids
from PIL import Image
from materia_server.config import Config
from materia_server.models import User, UserInfo
from materia_server.routers import middleware
router = APIRouter(tags = ["user"])
@router.get("/user", response_model = UserInfo)
async def info(claims = Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()):
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
return UserInfo.model_validate(current_user)
@router.post("/user/avatar")
async def avatar(file: UploadFile, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
async with ctx.database.session() as session:
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
try:
img = Image.open(io.BytesIO(await file.read()))
except OSError as _:
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
try:
if not (avatars_dir := Config.data_dir() / "avatars").exists():
avatars_dir.mkdir()
img.save(avatars_dir / avatar_id, format = img.format)
except OSError as _:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar")
if old_avatar := user.avatar:
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
old_file.unlink()
async with ctx.database.session() as session:
await session.execute(sa.update(user.User).where(user.User.id == user.id).values(avatar = avatar_id))
await session.commit()

View File

@ -1,5 +0,0 @@
from fastapi import APIRouter
from materia_server.routers.api.user import user
router = APIRouter()
router.include_router(user.router)

View File

@ -1,19 +0,0 @@
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from materia_server import security
from materia_server.routers import context
from materia_server.models import user
from materia_server.models import auth
from materia_server.routers.middleware import JwtMiddleware
router = APIRouter(tags = ["user"])
@router.get("/user/identity", response_model = user.UserIdentity)
async def identity(request: Request, claims = Depends(JwtMiddleware()), ctx: context.Context = Depends()):
if not (current_user := await user.User.by_id(uuid.UUID(claims.sub), ctx.database)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
return user.UserIdentity.model_validate(current_user)

View File

@ -1,134 +0,0 @@
import io
from typing import Any, Optional
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
from sqlalchemy import delete, select, insert, func, or_, update
import bcrypt
from sqids.sqids import Sqids
from PIL import Image
from materia.config import Config
from materia.api.middleware import JwtMiddleware
from materia import db
from materia.api import schema
from materia.api.state import ConfigState, DatabaseState
from materia.api.token import TokenClaims
router = APIRouter(tags = ["user"])
@router.post("/user/register", response_model = schema.User)
async def register(body: schema.NewUser, database: DatabaseState = Depends()):
async with database.session() as session:
count: Optional[int] = await session.scalar(select(func.count(db.User.id)))
user = (await session.scalars(
select(db.User)
.where(or_(db.User.login_name == body.login, db.User.email == body.email)
))).first()
if user is not None:
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "User already exists")
hashed_password = bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode()
new_user = db.User(
login_name = body.login,
hashed_password = hashed_password,
name = body.login,
email = body.email,
is_admin = count == 0,
)
async with database.session() as session:
user = (await session.scalars(insert(db.User).returning(db.User), [new_user.__dict__])).first()
await session.commit()
return schema.User.from_(user)
@router.post("/user/remove", status_code = 200)
async def remove(body: schema.RemoveUser, database: DatabaseState = Depends()):
async with database.session() as session:
await session.execute(delete(db.User).where(db.User.id == body.id))
await session.commit()
@router.post("/user/login", status_code = 200, response_model = schema.Token)
async def login(body: schema.LoginUser, response: Response, database: DatabaseState = Depends(), config: ConfigState = Depends()) -> Any:
query = select(db.User)
if login := body.login:
query = query.where(db.User.login_name == login)
elif email := body.email:
query = query.where(db.User.email == email)
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing credentials")
async with database.session() as session:
if not (user := (await session.scalars(query)).first()):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
if not bcrypt.checkpw(body.password.encode(), user.hashed_password.encode()):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
token = TokenClaims.create(
str(user.id),
config.jwt.secret,
config.jwt.maxage
)
response.set_cookie(
"token",
value = token,
max_age = config.jwt.maxage,
secure = True,
httponly = True,
samesite = "none"
)
return schema.Token(access_token = token)
@router.get("/user/logout", status_code = 200)
async def logout(response: Response):
response.set_cookie(
"token",
value = "",
max_age = -1,
secure = True,
httponly = True,
samesite = "none"
)
@router.get("/user/current", dependencies = [Depends(JwtMiddleware())], response_model = schema.User)
async def current(request: Request):
return schema.User.from_(request.state.user)
@router.post("/user/avatar", dependencies = [Depends(JwtMiddleware())])
async def avatar(request: Request, file: UploadFile, database: DatabaseState = Depends()):
async with database.session() as session:
avatars: list[str] = (await session.scalars(select(db.User.avatar))).all()
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
try:
img = Image.open(io.BytesIO(await file.read()))
except OSError as _:
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
try:
if not (avatars_dir := Config.data_dir() / "avatars").exists():
avatars_dir.mkdir()
img.save(avatars_dir / avatar_id, format = img.format)
except OSError as _:
raise HTTPException(status.WS_1011_INTERNAL_ERROR, "Failed to save avatar")
if old_avatar := request.state.user.avatar:
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
old_file.unlink()
async with database.session() as session:
await session.execute(update(db.User).where(db.User.id == request.state.user.id).values(avatar = avatar_id))
await session.commit()

View File

@ -1,15 +0,0 @@
from fastapi import Request
from materia_server.config import Config
from materia_server.models.database import Database, Cache
from materia_server._logging import Logger
class Context:
def __init__(self, request: Request):
self.config = request.state.config
self.database = request.state.database
#self.cache = request.state.cache
self.logger = request.state.logger

View File

@ -1,5 +1,6 @@
from typing import Optional, Sequence from typing import Optional, Sequence
import uuid import uuid
from datetime import datetime
from fastapi import HTTPException, Request, Response, status, Depends, Cookie from fastapi import HTTPException, Request, Response, status, Depends, Cookie
from fastapi.security.base import SecurityBase from fastapi.security.base import SecurityBase
import jwt import jwt
@ -10,109 +11,71 @@ from http import HTTPMethod as HttpMethod
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm, APIKeyQuery, APIKeyCookie, APIKeyHeader from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm, APIKeyQuery, APIKeyCookie, APIKeyHeader
from materia_server import security from materia_server import security
from materia_server.routers import context from materia_server.models import User
from materia_server.models import user
async def get_token_claims(token, ctx: context.Context = Depends()) -> security.TokenClaims: class Context:
def __init__(self, request: Request):
self.config = request.state.config
self.database = request.state.database
self.cache = request.state.cache
self.logger = request.state.logger
async def jwt_cookie(request: Request, response: Response, ctx: Context = Depends()):
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
secret = ctx.config.oauth2.jwt_secret
else:
secret = ctx.config.oauth2.jwt_signing_key
issuer = "{}://{}".format(ctx.config.server.scheme, ctx.config.server.domain)
try: try:
secret = ctx.config.oauth2.jwt_secret if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"] else ctx.config.oauth2.jwt_signing_key refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
claims = security.validate_token(token, secret)
user_id = uuid.UUID(claims.sub) # type: ignore if refresh_claims:
except jwt.PyJWKError as _: if refresh_claims.exp < datetime.now().timestamp():
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token") refresh_claims = None
except ValueError as _: except jwt.PyJWTError:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid token") refresh_claims = None
if not (current_user := await user.User.by_id(user_id, ctx.database)): try:
access_claims = security.validate_token(access_token, secret)
if access_claims.exp < datetime.now().timestamp():
if refresh_claims:
new_access_token = security.generate_token(
access_claims.sub,
str(secret),
ctx.config.oauth2.access_token_lifetime,
issuer
)
access_claims = security.validate_token(new_access_token, secret)
response.set_cookie(
ctx.config.security.cookie_access_token_name,
value = new_access_token,
max_age = ctx.config.oauth2.access_token_lifetime,
secure = True,
httponly = ctx.config.security.cookie_http_only,
samesite = "lax"
)
else:
access_claims = None
except jwt.PyJWTError as e:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
return access_claims
async def user(claims = Depends(jwt_cookie), ctx: Context = Depends()):
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
return claims return current_user
class JwtBearer(HTTPBearer):
def __init__(self, **kwargs):
super().__init__(scheme_name = "Bearer", **kwargs)
self.claims = None
async def __call__(self, request: Request, ctx: context.Context = Depends()):
if credentials := await super().__call__(request):
token = credentials.credentials
else:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing token")
self.claims = await get_token_claims(token, ctx)
class JwtCookie(SecurityBase):
def __init(self, *, auto_error: bool = True):
self.auto_error = auto_error
self.claims = None
async def __call__(self, request: Request, response: Response, ctx: context.Context = Depends()):
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
secret = ctx.config.oauth2.jwt_secret
else:
secret = ctx.config.oauth2.jwt_signing_key
try:
refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
# TODO: check expiration
except jwt.PyJWTError:
refresh_claims = None
try:
access_claims = security.validate_token(access_token, secret)
# TODO: if exp then check refresh token and create new else raise
except jwt.PyJWTError as e:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
else:
# TODO: validate user
pass
self.claims = access_claims
WILDCARD = "*"
NULL = "null"
class HttpHeader(StrEnum):
ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"
ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"
ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"
ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"
ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"
ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"
CONTENT_TYPE = "Content-Type"
AUTHORIZATION = "Authorization"
VARY = "Vary"
ORIGIN = "Origin"
class CorsMiddleware(BaseModel):
allow_credentials: bool = False
allow_headers: Sequence[HttpHeader | str] = []
allow_methods: Sequence[HttpMethod | str] = []
allow_origin: str = WILDCARD
expose_headers: Sequence[HttpHeader | str] = []
max_age: int = 600
async def __call__(self, request: Request, response: Response):
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_CREDENTIALS] = str(self.allow_credentials).lower()
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_HEADERS] = self.make_from(self.allow_headers)
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_METHODS] = self.make_from(self.allow_methods)
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_ORIGIN] = str(self.allow_origin)
response.headers[HttpHeader.ACCESS_CONTROL_EXPOSE_HEADERS] = self.make_from(self.expose_headers)
response.headers[HttpHeader.ACCESS_CONTROL_MAX_AGE] = str(self.max_age)
def make_from(self, value: Sequence[HttpHeader | HttpMethod | str]) -> str:
if WILDCARD in value:
return WILDCARD
elif NULL in value:
return NULL
else:
return ",".join(set(value))