materia-server: repository api, directory api, collapsed modules
This commit is contained in:
parent
d8b19da646
commit
317085fc04
@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
[alembic]
|
[alembic]
|
||||||
# path to migration scripts
|
# path to migration scripts
|
||||||
script_location = materia_server:models/migrations
|
script_location = ./src/materia_server/models/migrations
|
||||||
|
|
||||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||||
# Uncomment the line below if you want the files to be prepended with date and time
|
# Uncomment the line below if you want the files to be prepended with date and time
|
||||||
@ -60,7 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
|||||||
# are written from script.py.mako
|
# are written from script.py.mako
|
||||||
# output_encoding = utf-8
|
# output_encoding = utf-8
|
||||||
|
|
||||||
#sqlalchemy.url = driver://user:pass@localhost/dbname
|
sqlalchemy.url = postgresql+asyncpg://materia:materia@127.0.0.1:54320/materia
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
[post_write_hooks]
|
@ -35,7 +35,7 @@ readme = "README.md"
|
|||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
|
|
||||||
[tool.pdm.build]
|
[tool.pdm.build]
|
||||||
includes = ["src/materia_server", "src/materia_server/alembic.ini"]
|
includes = ["src/materia_server"]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["pdm-backend"]
|
requires = ["pdm-backend"]
|
||||||
|
@ -106,6 +106,10 @@ class OAuth2(BaseModel):
|
|||||||
#def check(self) -> Self:
|
#def check(self) -> Self:
|
||||||
# if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
# if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
||||||
# assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms"
|
# assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms"
|
||||||
|
# else:
|
||||||
|
# assert self.jwt_signing_key is not None, "JWT signing key must be set"
|
||||||
|
#
|
||||||
|
# return self
|
||||||
|
|
||||||
|
|
||||||
class Mailer(BaseModel):
|
class Mailer(BaseModel):
|
||||||
@ -171,9 +175,6 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
|
|||||||
else:
|
else:
|
||||||
return cwd
|
return cwd
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create(path: Path, config: Self | None = None):
|
|
||||||
config = config or Config()
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -17,7 +17,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from materia_server import config as _config
|
from materia_server import config as _config
|
||||||
from materia_server.config import Config
|
from materia_server.config import Config
|
||||||
from materia_server._logging import make_logger, uvicorn_log_config, Logger
|
from materia_server._logging import make_logger, uvicorn_log_config, Logger
|
||||||
from materia_server.models.database import Database, Cache
|
from materia_server.models import Database, Cache
|
||||||
from materia_server import routers
|
from materia_server import routers
|
||||||
|
|
||||||
|
|
||||||
@ -31,12 +31,19 @@ class AppContext(TypedDict):
|
|||||||
def create_lifespan(config: Config, logger):
|
def create_lifespan(config: Config, logger):
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||||
database = Database.new(config.database.url()) # type: ignore
|
|
||||||
|
try:
|
||||||
|
logger.info("Connecting {}", config.database.url())
|
||||||
|
database = Database.new(config.database.url()) # type: ignore
|
||||||
|
except:
|
||||||
|
logger.error("Failed to connect postgres.")
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
logger.info("Connecting {}", config.cache.url())
|
||||||
cache = await Cache.new(config.cache.url()) # type: ignore
|
cache = await Cache.new(config.cache.url()) # type: ignore
|
||||||
except:
|
except:
|
||||||
logger.error("Failed to connect redis {}", config.cache.url())
|
logger.error("Failed to connect redis.")
|
||||||
sys.exit()
|
sys.exit()
|
||||||
|
|
||||||
async with database.connection() as connection:
|
async with database.connection() as connection:
|
||||||
@ -64,6 +71,7 @@ def server():
|
|||||||
@from_pydantic("log", _config.Log, prefix = "log")
|
@from_pydantic("log", _config.Log, prefix = "log")
|
||||||
def start(application: _config.Application, config_path: Path, log: _config.Log):
|
def start(application: _config.Application, config_path: Path, log: _config.Log):
|
||||||
config = Config()
|
config = Config()
|
||||||
|
config.log = log
|
||||||
logger = make_logger(config)
|
logger = make_logger(config)
|
||||||
|
|
||||||
#if user := application.user:
|
#if user := application.user:
|
||||||
@ -71,8 +79,12 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
|
|||||||
#if group := application.group:
|
#if group := application.group:
|
||||||
# os.setgid(pwd.getpwnam(user).pw_gid)
|
# os.setgid(pwd.getpwnam(user).pw_gid)
|
||||||
# TODO: merge cli options with config
|
# TODO: merge cli options with config
|
||||||
if working_directory := (application.working_directory or config.application.working_directory):
|
if working_directory := (application.working_directory or config.application.working_directory).resolve():
|
||||||
os.chdir(working_directory.resolve())
|
try:
|
||||||
|
os.chdir(working_directory)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error("Failed to change working directory: {}", e)
|
||||||
|
sys.exit()
|
||||||
logger.debug(f"Current working directory: {working_directory}")
|
logger.debug(f"Current working directory: {working_directory}")
|
||||||
|
|
||||||
# check the configuration file or use default
|
# check the configuration file or use default
|
||||||
@ -106,6 +118,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
|
|||||||
|
|
||||||
config.log.level = log.level
|
config.log.level = log.level
|
||||||
logger = make_logger(config)
|
logger = make_logger(config)
|
||||||
|
if (working_directory := config.application.working_directory.resolve()):
|
||||||
|
logger.debug(f"Change working directory: {working_directory}")
|
||||||
|
try:
|
||||||
|
os.chdir(working_directory)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
logger.error("Failed to change working directory: {}", e)
|
||||||
|
sys.exit()
|
||||||
|
|
||||||
config.application.mode = application.mode
|
config.application.mode = application.mode
|
||||||
|
|
||||||
|
@ -6,4 +6,18 @@
|
|||||||
#from materia_server.models.directory import Directory, DirectoryLink
|
#from materia_server.models.directory import Directory, DirectoryLink
|
||||||
#from materia_server.models.file import File, FileLink
|
#from materia_server.models.file import File, FileLink
|
||||||
|
|
||||||
|
#from materia_server.models.repository import *
|
||||||
|
|
||||||
|
from materia_server.models.auth.source import LoginType, LoginSource
|
||||||
|
from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
|
||||||
|
|
||||||
|
from materia_server.models.database.database import Database
|
||||||
|
from materia_server.models.database.cache import Cache
|
||||||
|
|
||||||
|
from materia_server.models.user import User, UserCredentials, UserInfo
|
||||||
|
|
||||||
|
from materia_server.models.repository import Repository, RepositoryInfo
|
||||||
|
|
||||||
|
from materia_server.models.directory import Directory, DirectoryLink, DirectoryInfo
|
||||||
|
|
||||||
|
from materia_server.models.file import File, FileLink
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import os
|
||||||
from typing import AsyncIterator, Self
|
from typing import AsyncIterator, Self
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ from alembic.operations import Operations
|
|||||||
from alembic.runtime.migration import MigrationContext
|
from alembic.runtime.migration import MigrationContext
|
||||||
from alembic.script.base import ScriptDirectory
|
from alembic.script.base import ScriptDirectory
|
||||||
|
|
||||||
|
from materia_server.config import Config
|
||||||
from materia_server.models.base import Base
|
from materia_server.models.base import Base
|
||||||
|
|
||||||
__all__ = [ "Database" ]
|
__all__ = [ "Database" ]
|
||||||
@ -61,14 +63,20 @@ class Database:
|
|||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
def run_migrations(self, connection: Connection):
|
def run_migrations(self, connection: Connection):
|
||||||
config = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
|
#aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
|
||||||
config.set_main_option("sqlalchemy.url", self.url) # type: ignore
|
aconfig = AlembicConfig()
|
||||||
|
aconfig.set_main_option("sqlalchemy.url", str(self.url))
|
||||||
|
|
||||||
|
|
||||||
|
aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||||
|
print(str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||||
|
|
||||||
|
|
||||||
context = MigrationContext.configure(
|
context = MigrationContext.configure(
|
||||||
connection = connection, # type: ignore
|
connection = connection, # type: ignore
|
||||||
opts = {
|
opts = {
|
||||||
"target_metadata": Base.metadata,
|
"target_metadata": Base.metadata,
|
||||||
"fn": lambda rev, _: ScriptDirectory.from_config(config)._upgrade_revs("head", rev)
|
"fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,10 +1,14 @@
|
|||||||
from time import time
|
from time import time
|
||||||
from typing import List
|
from typing import List, Optional, Self
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy import BigInteger, ForeignKey
|
from sqlalchemy import BigInteger, ForeignKey
|
||||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from materia_server.models.base import Base
|
from materia_server.models.base import Base
|
||||||
|
from materia_server.models import database
|
||||||
|
|
||||||
|
|
||||||
class Directory(Base):
|
class Directory(Base):
|
||||||
@ -12,7 +16,7 @@ class Directory(Base):
|
|||||||
|
|
||||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||||
repository_id: Mapped[int] = mapped_column(ForeignKey("repository.id", ondelete = "CASCADE"))
|
repository_id: Mapped[int] = mapped_column(ForeignKey("repository.id", ondelete = "CASCADE"))
|
||||||
parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id"), nullable = True)
|
parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id", ondelete = "CASCADE"), nullable = True)
|
||||||
created: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
created: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
||||||
updated: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
updated: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
||||||
name: Mapped[str]
|
name: Mapped[str]
|
||||||
@ -25,6 +29,14 @@ class Directory(Base):
|
|||||||
files: Mapped[List["File"]] = relationship(back_populates = "parent")
|
files: Mapped[List["File"]] = relationship(back_populates = "parent")
|
||||||
link: Mapped["DirectoryLink"] = relationship(back_populates = "directory")
|
link: Mapped["DirectoryLink"] = relationship(back_populates = "directory")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def by_path(repository_id: int, path: Path | None, name: str, db: database.Database) -> Self | None:
|
||||||
|
async with db.session() as session:
|
||||||
|
query_path = Directory.path == str(path) if isinstance(path, Path) else Directory.path.is_(None)
|
||||||
|
return (await session
|
||||||
|
.scalars(sa.select(Directory)
|
||||||
|
.where(sa.and_(Directory.repository_id == repository_id, Directory.name == name, query_path)))
|
||||||
|
).first()
|
||||||
|
|
||||||
class DirectoryLink(Base):
|
class DirectoryLink(Base):
|
||||||
__tablename__ = "directory_link"
|
__tablename__ = "directory_link"
|
||||||
@ -36,5 +48,20 @@ class DirectoryLink(Base):
|
|||||||
|
|
||||||
directory: Mapped["Directory"] = relationship(back_populates = "link")
|
directory: Mapped["Directory"] = relationship(back_populates = "link")
|
||||||
|
|
||||||
from materia_server.models.repository.repository import Repository
|
class DirectoryInfo(BaseModel):
|
||||||
from materia_server.models.file.file import File
|
model_config = ConfigDict(from_attributes = True)
|
||||||
|
|
||||||
|
id: int
|
||||||
|
repository_id: int
|
||||||
|
parent_id: Optional[int]
|
||||||
|
created: int
|
||||||
|
updated: int
|
||||||
|
name: str
|
||||||
|
path: Optional[str]
|
||||||
|
is_public: bool
|
||||||
|
|
||||||
|
used: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
from materia_server.models.repository import Repository
|
||||||
|
from materia_server.models.file import File
|
@ -1 +0,0 @@
|
|||||||
from materia_server.models.directory.directory import Directory, DirectoryLink
|
|
@ -35,5 +35,5 @@ class FileLink(Base):
|
|||||||
file: Mapped["File"] = relationship(back_populates = "link")
|
file: Mapped["File"] = relationship(back_populates = "link")
|
||||||
|
|
||||||
|
|
||||||
from materia_server.models.repository.repository import Repository
|
from materia_server.models.repository import Repository
|
||||||
from materia_server.models.directory.directory import Directory
|
from materia_server.models.directory import Directory
|
@ -1 +0,0 @@
|
|||||||
from materia_server.models.file.file import File, FileLink
|
|
@ -19,8 +19,12 @@ import materia_server.models.file
|
|||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
|
context.configure(
|
||||||
|
version_table_schema = "public"
|
||||||
|
)
|
||||||
config = context.config
|
config = context.config
|
||||||
config.set_main_option("sqlalchemy.url", Config().database.url())
|
|
||||||
|
#config.set_main_option("sqlalchemy.url", Config().database.url())
|
||||||
|
|
||||||
# Interpret the config file for Python logging.
|
# Interpret the config file for Python logging.
|
||||||
# This line sets up loggers basically.
|
# This line sets up loggers basically.
|
||||||
@ -59,6 +63,7 @@ def run_migrations_offline() -> None:
|
|||||||
target_metadata=target_metadata,
|
target_metadata=target_metadata,
|
||||||
literal_binds=True,
|
literal_binds=True,
|
||||||
dialect_opts={"paramstyle": "named"},
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
version_table_schema = "public"
|
||||||
)
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
@ -99,4 +104,5 @@ def run_migrations_online() -> None:
|
|||||||
if context.is_offline_mode():
|
if context.is_offline_mode():
|
||||||
run_migrations_offline()
|
run_migrations_offline()
|
||||||
else:
|
else:
|
||||||
|
print("online")
|
||||||
run_migrations_online()
|
run_migrations_online()
|
||||||
|
@ -1,140 +0,0 @@
|
|||||||
"""empty message
|
|
||||||
|
|
||||||
Revision ID: 76191498b728
|
|
||||||
Revises:
|
|
||||||
Create Date: 2024-06-03 18:44:07.044588
|
|
||||||
|
|
||||||
"""
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '76191498b728'
|
|
||||||
down_revision: Union[str, None] = None
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind())
|
|
||||||
op.create_table('login_source',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
|
||||||
sa.Column('created', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('updated', sa.Integer(), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('user',
|
|
||||||
sa.Column('id', sa.Uuid(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(), nullable=False),
|
|
||||||
sa.Column('lower_name', sa.String(), nullable=False),
|
|
||||||
sa.Column('full_name', sa.String(), nullable=False),
|
|
||||||
sa.Column('email', sa.String(), nullable=False),
|
|
||||||
sa.Column('is_email_private', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
|
||||||
sa.Column('must_change_password', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('last_login', sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('is_admin', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('avatar', sa.String(), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
sa.UniqueConstraint('lower_name'),
|
|
||||||
sa.UniqueConstraint('name')
|
|
||||||
)
|
|
||||||
op.create_table('oauth2_application',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(), nullable=False),
|
|
||||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
|
||||||
sa.Column('hashed_client_secret', sa.String(), nullable=False),
|
|
||||||
sa.Column('redirect_uris', sa.JSON(), nullable=False),
|
|
||||||
sa.Column('confidential_client', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('repository',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
|
||||||
sa.Column('capacity', sa.BigInteger(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('directory',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(), nullable=False),
|
|
||||||
sa.Column('path', sa.String(), nullable=True),
|
|
||||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
|
|
||||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('oauth2_grant',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
|
||||||
sa.Column('application_id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('scope', sa.String(), nullable=False),
|
|
||||||
sa.Column('created', sa.Integer(), nullable=False),
|
|
||||||
sa.Column('updated', sa.Integer(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('directory_link',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('directory_id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('url', sa.String(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('file',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('name', sa.String(), nullable=False),
|
|
||||||
sa.Column('path', sa.String(), nullable=True),
|
|
||||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
|
||||||
sa.Column('size', sa.BigInteger(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
|
|
||||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
op.create_table('file_link',
|
|
||||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('file_id', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
|
||||||
sa.Column('url', sa.String(), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id')
|
|
||||||
)
|
|
||||||
# ### end Alembic commands ###
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
|
||||||
op.drop_table('file_link')
|
|
||||||
op.drop_table('file')
|
|
||||||
op.drop_table('directory_link')
|
|
||||||
op.drop_table('oauth2_grant')
|
|
||||||
op.drop_table('directory')
|
|
||||||
op.drop_table('repository')
|
|
||||||
op.drop_table('oauth2_application')
|
|
||||||
op.drop_table('user')
|
|
||||||
op.drop_table('login_source')
|
|
||||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind())
|
|
||||||
# ### end Alembic commands ###
|
|
51
materia-server/src/materia_server/models/repository.py
Normal file
51
materia-server/src/materia_server/models/repository.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
from time import time
|
||||||
|
from typing import List, Self
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, ForeignKey
|
||||||
|
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||||
|
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from materia_server.models.base import Base
|
||||||
|
from materia_server.models import database
|
||||||
|
|
||||||
|
|
||||||
|
class Repository(Base):
|
||||||
|
__tablename__ = "repository"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||||
|
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||||
|
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
|
||||||
|
|
||||||
|
user: Mapped["User"] = relationship(back_populates = "repository")
|
||||||
|
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
|
||||||
|
files: Mapped[List["File"]] = relationship(back_populates = "repository")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return { k: getattr(self, k) for k, v in Repository.__dict__.items() if isinstance(v, InstrumentedAttribute) }
|
||||||
|
|
||||||
|
async def create(self, db: database.Database):
|
||||||
|
async with db.session() as session:
|
||||||
|
session.add(self)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def update(self, db: database.Database):
|
||||||
|
async with db.session() as session:
|
||||||
|
await session.execute(sa.update(Repository).where(Repository.id == self.id).values(self.to_dict()))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def by_user_id(user_id: UUID, db: database.Database) -> Self | None:
|
||||||
|
async with db.session() as session:
|
||||||
|
return (await session.scalars(sa.select(Repository).where(Repository.user_id == user_id))).first()
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryInfo(BaseModel):
|
||||||
|
capacity: int
|
||||||
|
used: int
|
||||||
|
|
||||||
|
from materia_server.models.user import User
|
||||||
|
from materia_server.models.directory import Directory
|
||||||
|
from materia_server.models.file import File
|
@ -1 +0,0 @@
|
|||||||
from materia_server.models.repository.repository import Repository
|
|
@ -1,25 +0,0 @@
|
|||||||
from time import time
|
|
||||||
from typing import List
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from sqlalchemy import BigInteger, ForeignKey
|
|
||||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
|
||||||
|
|
||||||
from materia_server.models.base import Base
|
|
||||||
|
|
||||||
|
|
||||||
class Repository(Base):
|
|
||||||
__tablename__ = "repository"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
|
||||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
|
||||||
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
|
|
||||||
|
|
||||||
user: Mapped["User"] = relationship(back_populates = "repository")
|
|
||||||
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
|
|
||||||
files: Mapped[List["File"]] = relationship(back_populates = "repository")
|
|
||||||
|
|
||||||
|
|
||||||
from materia_server.models.user.user import User
|
|
||||||
from materia_server.models.directory.directory import Directory
|
|
||||||
from materia_server.models.file.file import File
|
|
@ -80,9 +80,10 @@ class UserCredentials(BaseModel):
|
|||||||
password: str
|
password: str
|
||||||
email: Optional[EmailStr]
|
email: Optional[EmailStr]
|
||||||
|
|
||||||
class UserIdentity(BaseModel):
|
class UserInfo(BaseModel):
|
||||||
model_config = ConfigDict(from_attributes = True)
|
model_config = ConfigDict(from_attributes = True)
|
||||||
|
|
||||||
|
id: UUID
|
||||||
name: str
|
name: str
|
||||||
lower_name: str
|
lower_name: str
|
||||||
full_name: Optional[str]
|
full_name: Optional[str]
|
||||||
@ -101,4 +102,4 @@ class UserIdentity(BaseModel):
|
|||||||
|
|
||||||
avatar: Optional[str]
|
avatar: Optional[str]
|
||||||
|
|
||||||
from materia_server.models.repository.repository import Repository
|
from materia_server.models.repository import Repository
|
@ -1 +0,0 @@
|
|||||||
from materia_server.models.user.user import User, UserCredentials, UserIdentity
|
|
@ -1,2 +1 @@
|
|||||||
from materia_server.routers import api
|
from materia_server.routers import middleware, api
|
||||||
from materia_server.routers import middleware
|
|
||||||
|
@ -1,8 +1,10 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from materia_server.routers.api import auth
|
from materia_server.routers.api.auth import auth, oauth
|
||||||
from materia_server.routers.api import user
|
from materia_server.routers.api import user, repository, directory
|
||||||
|
|
||||||
router = APIRouter(prefix = "/api")
|
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
router.include_router(auth.router)
|
router.include_router(auth.router)
|
||||||
|
router.include_router(oauth.router)
|
||||||
router.include_router(user.router)
|
router.include_router(user.router)
|
||||||
|
router.include_router(repository.router)
|
||||||
|
router.include_router(directory.router)
|
||||||
|
@ -1,8 +0,0 @@
|
|||||||
from fastapi import APIRouter
|
|
||||||
from materia_server.routers.api.auth import auth
|
|
||||||
from materia_server.routers.api.auth import oauth
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
router.include_router(auth.router)
|
|
||||||
router.include_router(oauth.router)
|
|
||||||
|
|
@ -3,33 +3,32 @@ from typing import Optional
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||||
|
|
||||||
from materia_server import security
|
from materia_server import security
|
||||||
from materia_server.routers import context
|
from materia_server.routers.middleware import Context
|
||||||
from materia_server.models import user
|
from materia_server.models import LoginType, User, UserCredentials
|
||||||
from materia_server.models import auth
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["auth"])
|
router = APIRouter(tags = ["auth"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/signup")
|
@router.post("/auth/signup")
|
||||||
async def signup(body: user.UserCredentials, ctx: context.Context = Depends()):
|
async def signup(body: UserCredentials, ctx: Context = Depends()):
|
||||||
if not user.User.is_valid_username(body.name):
|
if not User.is_valid_username(body.name):
|
||||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Invalid username")
|
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Invalid username")
|
||||||
if await user.User.by_name(body.name, ctx.database) is not None:
|
if await User.by_name(body.name, ctx.database) is not None:
|
||||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "User already exists")
|
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "User already exists")
|
||||||
if await user.User.by_email(body.email, ctx.database) is not None: # type: ignore
|
if await User.by_email(body.email, ctx.database) is not None: # type: ignore
|
||||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Email already used")
|
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Email already used")
|
||||||
if len(body.password) < ctx.config.security.password_min_length:
|
if len(body.password) < ctx.config.security.password_min_length:
|
||||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = f"Password is too short (minimum length {ctx.config.security.password_min_length})")
|
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = f"Password is too short (minimum length {ctx.config.security.password_min_length})")
|
||||||
|
|
||||||
count: Optional[int] = await user.User.count(ctx.database)
|
count: Optional[int] = await User.count(ctx.database)
|
||||||
|
|
||||||
new_user = user.User(
|
new_user = User(
|
||||||
name = body.name,
|
name = body.name,
|
||||||
lower_name = body.name.lower(),
|
lower_name = body.name.lower(),
|
||||||
full_name = body.name,
|
full_name = body.name,
|
||||||
email = body.email,
|
email = body.email,
|
||||||
hashed_password = security.hash_password(body.password, algo = ctx.config.security.password_hash_algo),
|
hashed_password = security.hash_password(body.password, algo = ctx.config.security.password_hash_algo),
|
||||||
login_type = auth.LoginType.Plain,
|
login_type = LoginType.Plain,
|
||||||
# first registered user is admin
|
# first registered user is admin
|
||||||
is_admin = count == 0
|
is_admin = count == 0
|
||||||
)
|
)
|
||||||
@ -39,9 +38,9 @@ async def signup(body: user.UserCredentials, ctx: context.Context = Depends()):
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@router.post("/auth/signin")
|
@router.post("/auth/signin")
|
||||||
async def signin(body: user.UserCredentials, response: Response, ctx: context.Context = Depends()):
|
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
|
||||||
if (current_user := await user.User.by_name(body.name, ctx.database)) is None:
|
if (current_user := await User.by_name(body.name, ctx.database)) is None:
|
||||||
if (current_user := await user.User.by_email(str(body.email), ctx.database)) is None:
|
if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
|
||||||
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid email")
|
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid email")
|
||||||
if not security.validate_password(body.password, current_user.hashed_password, algo = ctx.config.security.password_hash_algo):
|
if not security.validate_password(body.password, current_user.hashed_password, algo = ctx.config.security.password_hash_algo):
|
||||||
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid password")
|
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid password")
|
||||||
@ -79,6 +78,6 @@ async def signin(body: user.UserCredentials, response: Response, ctx: context.Co
|
|||||||
)
|
)
|
||||||
|
|
||||||
@router.get("/auth/signout")
|
@router.get("/auth/signout")
|
||||||
async def signout(response: Response, ctx: context.Context = Depends()):
|
async def signout(response: Response, ctx: Context = Depends()):
|
||||||
response.delete_cookie(ctx.config.security.cookie_access_token_name)
|
response.delete_cookie(ctx.config.security.cookie_access_token_name)
|
||||||
response.delete_cookie(ctx.config.security.cookie_refresh_token_name)
|
response.delete_cookie(ctx.config.security.cookie_refresh_token_name)
|
||||||
|
@ -7,9 +7,8 @@ from fastapi.security.oauth2 import OAuth2PasswordRequestForm
|
|||||||
from pydantic import BaseModel, HttpUrl
|
from pydantic import BaseModel, HttpUrl
|
||||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
|
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
|
||||||
|
|
||||||
from materia_server.models import auth
|
from materia_server.models import User
|
||||||
from materia_server.models.user import user
|
from materia_server.routers.middleware import Context
|
||||||
from materia_server.routers import context
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["oauth2"])
|
router = APIRouter(tags = ["oauth2"])
|
||||||
@ -35,17 +34,17 @@ class AuthorizationCodeResponse(BaseModel):
|
|||||||
code: str
|
code: str
|
||||||
|
|
||||||
@router.post("/oauth2/authorize")
|
@router.post("/oauth2/authorize")
|
||||||
async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: context.Context = Depends()):
|
async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: Context = Depends()):
|
||||||
# grant_type: authorization_code, password_credentials, client_credentials, authorization_code (pkce)
|
# grant_type: authorization_code, password_credentials, client_credentials, authorization_code (pkce)
|
||||||
ctx.logger.debug(form)
|
ctx.logger.debug(form)
|
||||||
|
|
||||||
if form.grant_type == "authorization_code":
|
if form.grant_type == "authorization_code":
|
||||||
# TODO: form validation
|
# TODO: form validation
|
||||||
|
|
||||||
if not (app := await auth.OAuth2Application.by_client_id(form.client_id, ctx.database)):
|
if not (app := await OAuth2Application.by_client_id(form.client_id, ctx.database)):
|
||||||
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "Client ID not registered")
|
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "Client ID not registered")
|
||||||
|
|
||||||
if not (owner := user.User.by_id(app.user_id, ctx.database)):
|
if not (owner := await User.by_id(app.user_id, ctx.database)):
|
||||||
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "User not found")
|
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "User not found")
|
||||||
|
|
||||||
if not app.contains_redirect_uri(form.redirect_uri):
|
if not app.contains_redirect_uri(form.redirect_uri):
|
||||||
@ -79,5 +78,5 @@ class AccessTokenResponse(BaseModel):
|
|||||||
scope: Optional[str]
|
scope: Optional[str]
|
||||||
|
|
||||||
@router.post("/oauth2/access_token")
|
@router.post("/oauth2/access_token")
|
||||||
async def token(ctx: context.Context = Depends()):
|
async def token(ctx: Context = Depends()):
|
||||||
pass
|
pass
|
||||||
|
@ -1,143 +0,0 @@
|
|||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
|
|
||||||
from fastapi.responses import RedirectResponse, StreamingResponse
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordRequestFormStrict
|
|
||||||
import httpx
|
|
||||||
from sqlalchemy import and_, insert, select, update
|
|
||||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
|
||||||
import base64
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
import json
|
|
||||||
|
|
||||||
from materia import db
|
|
||||||
from materia.api import schema
|
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
|
||||||
from materia.api.middleware import JwtMiddleware
|
|
||||||
from materia.api.token import TokenClaims
|
|
||||||
from materia.config import Config
|
|
||||||
|
|
||||||
oauth = OAuth()
|
|
||||||
oauth.register(
|
|
||||||
"materia",
|
|
||||||
authorize_url = "http://127.0.0.1:54601/api/auth/authorize",
|
|
||||||
access_token_url = "http://127.0.0.1:54601/api/auth/token",
|
|
||||||
scope = "user:read",
|
|
||||||
client_id = "",
|
|
||||||
client_secret = ""
|
|
||||||
)
|
|
||||||
|
|
||||||
class OAuth2Provider:
|
|
||||||
pass
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["auth"])
|
|
||||||
|
|
||||||
@router.get("/user/signin")
|
|
||||||
async def signin(request: Request, provider: str = None):
|
|
||||||
if not provider:
|
|
||||||
return RedirectResponse("/api/auth/authorize")
|
|
||||||
else:
|
|
||||||
return RedirectResponse(request.url_for(provider.authorize_url))
|
|
||||||
|
|
||||||
@router.post("/auth/test_auth")
|
|
||||||
async def test_auth(database: DatabaseState = Depends()):
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post("https://vcs.elnafo.ru/login/oauth/authorize", data = {
|
|
||||||
"client_id": "1edfe-0bbe-4f53-bab6-7e24f0b842e3",
|
|
||||||
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
|
|
||||||
"response_type": "code",
|
|
||||||
"redirect_uri": "http://127.0.0.1:54601"
|
|
||||||
})
|
|
||||||
return response.content, response.status_code
|
|
||||||
|
|
||||||
@router.post("/auth/provider")
|
|
||||||
async def provider(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post("https://vcs.elnafo.ru/login/oauth/access_token", data = {
|
|
||||||
"client_id": "1edfec03-0bbe-4f53-bab6-7e24f0b842e3",
|
|
||||||
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"code": "gta_63l6zogw5wlnkeng4gf3buqtoekkaxk7zhr67zlkyrv2ukwfeava"
|
|
||||||
})
|
|
||||||
return response.content, response.status_code
|
|
||||||
|
|
||||||
@router.post("/auth/authorize")
|
|
||||||
async def authorize(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
|
|
||||||
|
|
||||||
|
|
||||||
if form.client_id:
|
|
||||||
async with database.session() as session:
|
|
||||||
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
|
||||||
|
|
||||||
await session.refresh(user, attribute_names = ["oauth2_apps"])
|
|
||||||
oauth2_app = None
|
|
||||||
|
|
||||||
for app in user.oauth2_apps:
|
|
||||||
if form.client_id == app.client_id and bcrypt.checkpw(form.client_secret.encode(), app.client_secret):
|
|
||||||
oauth2_app = app
|
|
||||||
|
|
||||||
if not oauth2_app:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid client id")
|
|
||||||
|
|
||||||
data = json.dumps({"client_id": form.client_id}).encode()
|
|
||||||
|
|
||||||
else:
|
|
||||||
async with database.session() as session:
|
|
||||||
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user credentials")
|
|
||||||
|
|
||||||
if not bcrypt.checkpw(form.password.encode(), user.hashed_password.encode()):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
|
|
||||||
|
|
||||||
data = json.dumps({"username": form.username}).encode()
|
|
||||||
|
|
||||||
key = b'sGEuUeKrooiNAy7L9sf6IFIjpv86TC9iYU_sbWqA-1c=' # Fernet.generate_key()
|
|
||||||
f = Fernet(key)
|
|
||||||
code = base64.b64encode(f.encrypt(data), b"-_").decode().replace("=", "")
|
|
||||||
global storage
|
|
||||||
storage = code
|
|
||||||
return code
|
|
||||||
|
|
||||||
storage = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/token")
|
|
||||||
async def token(exchange: schema.Exchange, response: Response, config: ConfigState = Depends()):
|
|
||||||
if exchange.grant_type == "authorization_code":
|
|
||||||
if not exchange.code:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Missing authorization code")
|
|
||||||
# expiration
|
|
||||||
if exchange.code != storage:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Invalid authorization code")
|
|
||||||
|
|
||||||
token = TokenClaims.create(
|
|
||||||
"asd",
|
|
||||||
config.jwt.secret,
|
|
||||||
config.jwt.maxage
|
|
||||||
)
|
|
||||||
|
|
||||||
response.set_cookie(
|
|
||||||
"token",
|
|
||||||
value = token,
|
|
||||||
max_age = config.jwt.maxage,
|
|
||||||
secure = True,
|
|
||||||
httponly = True,
|
|
||||||
samesite = "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
return schema.AccessToken(
|
|
||||||
access_token = token,
|
|
||||||
token_type = "Bearer",
|
|
||||||
expires_in = config.jwt.maxage,
|
|
||||||
refresh_token = token,
|
|
||||||
scope = "identify"
|
|
||||||
)
|
|
||||||
elif exchange.grant_type == "refresh_token":
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
|
||||||
|
|
@ -1,44 +1,40 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy import and_, insert, select, update
|
|
||||||
|
|
||||||
from materia import db
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
|
||||||
from materia.api.middleware import JwtMiddleware
|
from materia_server.models import User, Directory, DirectoryInfo
|
||||||
from materia.config import Config
|
from materia_server.models.directory import DirectoryInfo
|
||||||
from materia.api import schema
|
from materia_server.routers import middleware
|
||||||
|
from materia_server.config import Config
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["directory"])
|
router = APIRouter(tags = ["directory"])
|
||||||
|
|
||||||
@router.post("/directory", dependencies = [Depends(JwtMiddleware())])
|
@router.post("/directory")
|
||||||
async def create(request: Request, path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
async def create(path: Path = Path(), user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||||
user = request.state.user
|
repository_path = Config.data_dir() / "repository" / user.lower_name
|
||||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
|
||||||
blacklist = [os.sep, ".", "..", "*"]
|
blacklist = [os.sep, ".", "..", "*"]
|
||||||
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, path.parts)))
|
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, path.parts)))
|
||||||
|
|
||||||
async with database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.refresh(user, attribute_names = ["repository"])
|
await session.refresh(user, attribute_names = ["repository"])
|
||||||
|
|
||||||
|
if not user.repository:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
||||||
|
|
||||||
current_directory = None
|
current_directory = None
|
||||||
current_path = Path()
|
current_path = Path()
|
||||||
directory = None
|
directory = None
|
||||||
|
|
||||||
for part in directory_path.parts:
|
for part in directory_path.parts:
|
||||||
if not (directory := (await session
|
if not await Directory.by_path(user.repository.id, current_path, part, ctx.database):
|
||||||
.scalars(select(db.Directory)
|
directory = Directory(
|
||||||
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
|
|
||||||
).first()):
|
|
||||||
directory = db.Directory(
|
|
||||||
repository_id = user.repository.id,
|
repository_id = user.repository.id,
|
||||||
parent_id = current_directory.id if current_directory else None,
|
parent_id = current_directory.id if current_directory else None,
|
||||||
name = part,
|
name = part,
|
||||||
path = str(current_path)
|
path = None if current_path == Path() else str(current_path)
|
||||||
)
|
)
|
||||||
session.add(directory)
|
session.add(directory)
|
||||||
|
|
||||||
@ -52,23 +48,20 @@ async def create(request: Request, path: Path = Path(), config: ConfigState = De
|
|||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
@router.get("/directory", dependencies = [Depends(JwtMiddleware())])
|
@router.get("/directory")
|
||||||
async def info(request: Request, repository_id: int, path: Path, config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
async def info(path: Path, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||||
async with database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
if directory := (await session
|
session.add(user)
|
||||||
.scalars(select(db.Directory)
|
await session.refresh(user, attribute_names = ["repository"])
|
||||||
.where(and_(db.Directory.repository_id == repository_id, db.Directory.name == path.name, db.Directory.path == path.parent))
|
|
||||||
)).first():
|
if not(directory := await Directory.by_path(user.repository.id, None if path.parent == Path() else path.parent, path.name, ctx.database)):
|
||||||
await session.refresh(directory, attribute_names = ["files"])
|
raise HTTPException(status.HTTP_404_NOT_FOUND, "Directory is not found")
|
||||||
return schema.DirectoryInfo(
|
|
||||||
id = directory.id,
|
session.add(directory)
|
||||||
created_at = directory.created_unix,
|
await session.refresh(directory, attribute_names = ["files"])
|
||||||
updated_at = directory.updated_unix,
|
|
||||||
name = directory.name,
|
info = DirectoryInfo.model_validate(directory)
|
||||||
path = directory.path,
|
info.used = sum([ file.size for file in directory.files ])
|
||||||
is_public = directory.is_public,
|
|
||||||
used = sum([ file.size for file in directory.files ])
|
return info
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
|
||||||
|
@ -1,51 +0,0 @@
|
|||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy import and_, insert, select, update
|
|
||||||
|
|
||||||
from materia import db
|
|
||||||
from materia.api import schema
|
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
|
||||||
from materia.api.middleware import JwtMiddleware
|
|
||||||
from materia.config import Config
|
|
||||||
from materia.api import repository, directory
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["file"])
|
|
||||||
|
|
||||||
@router.put("/file", dependencies = [Depends(JwtMiddleware())])
|
|
||||||
async def upload(request: Request, file: UploadFile, directory_path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
|
||||||
user = request.state.user
|
|
||||||
|
|
||||||
try:
|
|
||||||
await repository.create(request, config = config, database = database)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
#try:
|
|
||||||
# directory_info = directory.info
|
|
||||||
# await directory.create(request, path = directory_path, config = config, database = database)
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
if file_ := (await session
|
|
||||||
.scalars(select(db.File)
|
|
||||||
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
|
|
||||||
).first():
|
|
||||||
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
|
|
||||||
else:
|
|
||||||
file_ = db.File(
|
|
||||||
repository_id = user.repository.id,
|
|
||||||
parent_id = directory.id if directory else None,
|
|
||||||
name = file.filename,
|
|
||||||
path = str(directory_path),
|
|
||||||
size = file.size
|
|
||||||
)
|
|
||||||
session.add(file_)
|
|
||||||
|
|
||||||
try:
|
|
||||||
(repository_path / directory_path / file.filename).write_bytes(await file.read())
|
|
||||||
except OSError:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
|
|
||||||
|
|
||||||
await session.commit()
|
|
@ -1,91 +0,0 @@
|
|||||||
import os
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy import and_, insert, select, update
|
|
||||||
|
|
||||||
from materia import db
|
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
|
||||||
from materia.api.middleware import JwtMiddleware
|
|
||||||
from materia.config import Config
|
|
||||||
from materia.api import repository
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["filesystem"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/play")
|
|
||||||
async def play():
|
|
||||||
def iterfile():
|
|
||||||
with open(Config.data_dir() / ".." / "bfg.mp3", mode="rb") as file_like: #
|
|
||||||
yield from file_like #
|
|
||||||
|
|
||||||
return StreamingResponse(iterfile(), media_type="audio/mp3")
|
|
||||||
|
|
||||||
@router.put("/file/upload", dependencies = [Depends(JwtMiddleware())])
|
|
||||||
async def upload(request: Request, file: UploadFile, config: ConfigState = Depends(), database: DatabaseState = Depends(), directory_path: Path = Path()):
|
|
||||||
user = request.state.user
|
|
||||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
|
||||||
blacklist = [os.sep, ".", "..", "*"]
|
|
||||||
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, directory_path.parts)))
|
|
||||||
|
|
||||||
try:
|
|
||||||
await repository.create(request, config = config, database = database)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
session.add(user)
|
|
||||||
await session.refresh(user, attribute_names = ["repository"])
|
|
||||||
|
|
||||||
current_directory = None
|
|
||||||
current_path = Path()
|
|
||||||
directory = None
|
|
||||||
|
|
||||||
for part in directory_path.parts:
|
|
||||||
if not (directory := (await session
|
|
||||||
.scalars(select(db.Directory)
|
|
||||||
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
|
|
||||||
).first()):
|
|
||||||
directory = db.Directory(
|
|
||||||
repository_id = user.repository.id,
|
|
||||||
parent_id = current_directory.id if current_directory else None,
|
|
||||||
name = part,
|
|
||||||
path = str(current_path)
|
|
||||||
)
|
|
||||||
session.add(directory)
|
|
||||||
|
|
||||||
current_directory = directory
|
|
||||||
current_path /= part
|
|
||||||
|
|
||||||
try:
|
|
||||||
(repository_path / directory_path).mkdir(parents = True, exist_ok = True)
|
|
||||||
except OSError:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a directory")
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
if file_ := (await session
|
|
||||||
.scalars(select(db.File)
|
|
||||||
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
|
|
||||||
).first():
|
|
||||||
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
|
|
||||||
else:
|
|
||||||
file_ = db.File(
|
|
||||||
repository_id = user.repository.id,
|
|
||||||
parent_id = directory.id if directory else None,
|
|
||||||
name = file.filename,
|
|
||||||
path = str(directory_path),
|
|
||||||
size = file.size
|
|
||||||
)
|
|
||||||
session.add(file_)
|
|
||||||
|
|
||||||
try:
|
|
||||||
(repository_path / directory_path / file.filename).write_bytes(await file.read())
|
|
||||||
except OSError:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
@ -1,60 +1,45 @@
|
|||||||
import os
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from sqlalchemy import and_, insert, select, update
|
|
||||||
|
|
||||||
from materia import db
|
from materia_server.models import User, Repository, RepositoryInfo
|
||||||
from materia.api import schema
|
from materia_server.routers import middleware
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
from materia_server.config import Config
|
||||||
from materia.api.middleware import JwtMiddleware
|
|
||||||
from materia.config import Config
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["repository"])
|
router = APIRouter(tags = ["repository"])
|
||||||
|
|
||||||
@router.post("/repository", dependencies = [Depends(JwtMiddleware())])
|
@router.post("/repository")
|
||||||
async def create(request: Request, config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
async def create(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||||
user = request.state.user
|
repository_path = Config.data_dir() / "repository" / user.lower_name
|
||||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
|
||||||
|
|
||||||
async with database.session() as session:
|
if await Repository.by_user_id(user.id, ctx.database):
|
||||||
|
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
||||||
|
|
||||||
|
repository = Repository(
|
||||||
|
user_id = user.id,
|
||||||
|
capacity = ctx.config.repository.capacity
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
repository_path.mkdir(parents = True, exist_ok = True)
|
||||||
|
except OSError:
|
||||||
|
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
|
||||||
|
|
||||||
|
await repository.create(ctx.database)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/repository", response_model = RepositoryInfo)
|
||||||
|
async def info(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||||
|
async with ctx.database.session() as session:
|
||||||
session.add(user)
|
session.add(user)
|
||||||
await session.refresh(user, attribute_names = ["repository"])
|
await session.refresh(user, attribute_names = ["repository"])
|
||||||
|
|
||||||
if not (repository := user.repository):
|
if not (repository := user.repository):
|
||||||
repository = db.Repository(
|
|
||||||
owner_id = user.id,
|
|
||||||
capacity = config.repository.capacity
|
|
||||||
)
|
|
||||||
session.add(repository)
|
|
||||||
|
|
||||||
try:
|
|
||||||
repository_path.mkdir(parents = True, exist_ok = True)
|
|
||||||
except OSError:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
|
||||||
|
|
||||||
@router.get("/repository", dependencies = [Depends(JwtMiddleware())])
|
|
||||||
async def info(request: Request, database: DatabaseState = Depends()):
|
|
||||||
user = request.state.user
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
session.add(user)
|
|
||||||
await session.refresh(user, attribute_names = ["repository"])
|
|
||||||
|
|
||||||
if repository := user.repository:
|
|
||||||
await session.refresh(repository, attribute_names = ["files"])
|
|
||||||
|
|
||||||
return schema.RepositoryInfo(
|
|
||||||
capacity = repository.capacity,
|
|
||||||
used = sum([ file.size for file in repository.files ])
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
||||||
|
|
||||||
|
await session.refresh(repository, attribute_names = ["files"])
|
||||||
|
|
||||||
|
return RepositoryInfo(
|
||||||
|
capacity = repository.capacity,
|
||||||
|
used = sum([ file.size for file in repository.files ])
|
||||||
|
)
|
||||||
|
|
||||||
|
@ -1,5 +0,0 @@
|
|||||||
from materia.api.schema.user import NewUser, User, RemoveUser, LoginUser
|
|
||||||
from materia.api.schema.token import Token
|
|
||||||
from materia.api.schema.repository import RepositoryInfo
|
|
||||||
from materia.api.schema.directory import DirectoryInfo
|
|
||||||
from materia.api.schema.auth import AccessToken, Exchange
|
|
@ -1,25 +0,0 @@
|
|||||||
from typing import Optional
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class AuthCode(BaseModel):
|
|
||||||
client_id: str
|
|
||||||
response_type: str
|
|
||||||
state: str
|
|
||||||
redirect_uri: Optional[str]
|
|
||||||
scope: Optional[str]
|
|
||||||
|
|
||||||
class Exchange(BaseModel):
|
|
||||||
grant_type: str
|
|
||||||
client_id: Optional[str] = None
|
|
||||||
client_secret: Optional[str] = None
|
|
||||||
redirect_uri: Optional[str] = None
|
|
||||||
code: Optional[str] = None
|
|
||||||
refresh_token: Optional[str] = None
|
|
||||||
|
|
||||||
class AccessToken(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
token_type: str
|
|
||||||
expires_in: int
|
|
||||||
refresh_token: str
|
|
||||||
scope: Optional[str]
|
|
@ -1,11 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class DirectoryInfo(BaseModel):
|
|
||||||
id: int
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
name: str
|
|
||||||
path: str
|
|
||||||
is_public: bool
|
|
||||||
used: int
|
|
@ -1,6 +0,0 @@
|
|||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class RepositoryInfo(BaseModel):
|
|
||||||
capacity: int
|
|
||||||
used: int
|
|
@ -1,10 +0,0 @@
|
|||||||
from typing import Optional, Self
|
|
||||||
from uuid import UUID
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from materia.api.token import TokenClaims
|
|
||||||
|
|
||||||
|
|
||||||
class Token(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
|
|
@ -1,40 +0,0 @@
|
|||||||
from typing import Optional, Self
|
|
||||||
from uuid import UUID
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from materia import db
|
|
||||||
|
|
||||||
|
|
||||||
class NewUser(BaseModel):
|
|
||||||
login: str
|
|
||||||
password: str
|
|
||||||
email: str
|
|
||||||
|
|
||||||
class User(BaseModel):
|
|
||||||
id: str
|
|
||||||
login: str
|
|
||||||
name: str
|
|
||||||
email: str
|
|
||||||
is_admin: bool
|
|
||||||
avatar: Optional[str]
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_(user: db.User) -> Self:
|
|
||||||
return User(
|
|
||||||
id = str(user.id),
|
|
||||||
login = user.login_name,
|
|
||||||
name = user.name,
|
|
||||||
email = user.email,
|
|
||||||
is_admin = user.is_admin,
|
|
||||||
avatar = user.avatar
|
|
||||||
)
|
|
||||||
|
|
||||||
class RemoveUser(BaseModel):
|
|
||||||
id: UUID
|
|
||||||
|
|
||||||
class LoginUser(BaseModel):
|
|
||||||
email: Optional[str] = None
|
|
||||||
login: Optional[str] = None
|
|
||||||
password: str
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
|||||||
from typing import Self
|
|
||||||
import jwt
|
|
||||||
from pydantic import BaseModel
|
|
||||||
import time
|
|
||||||
import datetime
|
|
||||||
|
|
||||||
|
|
||||||
class TokenClaims(BaseModel):
|
|
||||||
sub: str
|
|
||||||
exp: int
|
|
||||||
iat: int
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def create(sub: str, secret: str, duration: int) -> str:
|
|
||||||
now = datetime.datetime.now()
|
|
||||||
iat = now.timestamp()
|
|
||||||
exp = (now + datetime.timedelta(seconds = duration)).timestamp()
|
|
||||||
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat))
|
|
||||||
|
|
||||||
return jwt.encode(claims.model_dump(), secret)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def verify(token: str, secret: str) -> Self:
|
|
||||||
data = jwt.decode(token, secret, algorithms = ["HS256"])
|
|
||||||
|
|
||||||
return TokenClaims(**data)
|
|
||||||
|
|
50
materia-server/src/materia_server/routers/api/user.py
Normal file
50
materia-server/src/materia_server/routers/api/user.py
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
|
||||||
|
import uuid
|
||||||
|
import io
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqids.sqids import Sqids
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from materia_server.config import Config
|
||||||
|
from materia_server.models import User, UserInfo
|
||||||
|
from materia_server.routers import middleware
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(tags = ["user"])
|
||||||
|
|
||||||
|
@router.get("/user", response_model = UserInfo)
|
||||||
|
async def info(claims = Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()):
|
||||||
|
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||||
|
|
||||||
|
return UserInfo.model_validate(current_user)
|
||||||
|
|
||||||
|
@router.post("/user/avatar")
|
||||||
|
async def avatar(file: UploadFile, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||||
|
async with ctx.database.session() as session:
|
||||||
|
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
|
||||||
|
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
||||||
|
|
||||||
|
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
|
||||||
|
|
||||||
|
try:
|
||||||
|
img = Image.open(io.BytesIO(await file.read()))
|
||||||
|
except OSError as _:
|
||||||
|
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
||||||
|
avatars_dir.mkdir()
|
||||||
|
img.save(avatars_dir / avatar_id, format = img.format)
|
||||||
|
except OSError as _:
|
||||||
|
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar")
|
||||||
|
|
||||||
|
if old_avatar := user.avatar:
|
||||||
|
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
||||||
|
old_file.unlink()
|
||||||
|
|
||||||
|
async with ctx.database.session() as session:
|
||||||
|
await session.execute(sa.update(user.User).where(user.User.id == user.id).values(avatar = avatar_id))
|
||||||
|
await session.commit()
|
@ -1,5 +0,0 @@
|
|||||||
from fastapi import APIRouter
|
|
||||||
from materia_server.routers.api.user import user
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
router.include_router(user.router)
|
|
@ -1,19 +0,0 @@
|
|||||||
|
|
||||||
from typing import Optional
|
|
||||||
import uuid
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
|
||||||
|
|
||||||
from materia_server import security
|
|
||||||
from materia_server.routers import context
|
|
||||||
from materia_server.models import user
|
|
||||||
from materia_server.models import auth
|
|
||||||
from materia_server.routers.middleware import JwtMiddleware
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["user"])
|
|
||||||
|
|
||||||
@router.get("/user/identity", response_model = user.UserIdentity)
|
|
||||||
async def identity(request: Request, claims = Depends(JwtMiddleware()), ctx: context.Context = Depends()):
|
|
||||||
if not (current_user := await user.User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
|
||||||
|
|
||||||
return user.UserIdentity.model_validate(current_user)
|
|
@ -1,134 +0,0 @@
|
|||||||
import io
|
|
||||||
from typing import Any, Optional
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
|
|
||||||
from sqlalchemy import delete, select, insert, func, or_, update
|
|
||||||
import bcrypt
|
|
||||||
from sqids.sqids import Sqids
|
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
from materia.config import Config
|
|
||||||
from materia.api.middleware import JwtMiddleware
|
|
||||||
from materia import db
|
|
||||||
from materia.api import schema
|
|
||||||
from materia.api.state import ConfigState, DatabaseState
|
|
||||||
from materia.api.token import TokenClaims
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(tags = ["user"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/user/register", response_model = schema.User)
|
|
||||||
async def register(body: schema.NewUser, database: DatabaseState = Depends()):
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
count: Optional[int] = await session.scalar(select(func.count(db.User.id)))
|
|
||||||
|
|
||||||
user = (await session.scalars(
|
|
||||||
select(db.User)
|
|
||||||
.where(or_(db.User.login_name == body.login, db.User.email == body.email)
|
|
||||||
))).first()
|
|
||||||
|
|
||||||
if user is not None:
|
|
||||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "User already exists")
|
|
||||||
|
|
||||||
hashed_password = bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode()
|
|
||||||
|
|
||||||
new_user = db.User(
|
|
||||||
login_name = body.login,
|
|
||||||
hashed_password = hashed_password,
|
|
||||||
name = body.login,
|
|
||||||
email = body.email,
|
|
||||||
is_admin = count == 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
user = (await session.scalars(insert(db.User).returning(db.User), [new_user.__dict__])).first()
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return schema.User.from_(user)
|
|
||||||
|
|
||||||
@router.post("/user/remove", status_code = 200)
|
|
||||||
async def remove(body: schema.RemoveUser, database: DatabaseState = Depends()):
|
|
||||||
async with database.session() as session:
|
|
||||||
await session.execute(delete(db.User).where(db.User.id == body.id))
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
@router.post("/user/login", status_code = 200, response_model = schema.Token)
|
|
||||||
async def login(body: schema.LoginUser, response: Response, database: DatabaseState = Depends(), config: ConfigState = Depends()) -> Any:
|
|
||||||
query = select(db.User)
|
|
||||||
if login := body.login:
|
|
||||||
query = query.where(db.User.login_name == login)
|
|
||||||
elif email := body.email:
|
|
||||||
query = query.where(db.User.email == email)
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing credentials")
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
if not (user := (await session.scalars(query)).first()):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
|
||||||
|
|
||||||
if not bcrypt.checkpw(body.password.encode(), user.hashed_password.encode()):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
|
|
||||||
|
|
||||||
token = TokenClaims.create(
|
|
||||||
str(user.id),
|
|
||||||
config.jwt.secret,
|
|
||||||
config.jwt.maxage
|
|
||||||
)
|
|
||||||
|
|
||||||
response.set_cookie(
|
|
||||||
"token",
|
|
||||||
value = token,
|
|
||||||
max_age = config.jwt.maxage,
|
|
||||||
secure = True,
|
|
||||||
httponly = True,
|
|
||||||
samesite = "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
return schema.Token(access_token = token)
|
|
||||||
|
|
||||||
@router.get("/user/logout", status_code = 200)
|
|
||||||
async def logout(response: Response):
|
|
||||||
response.set_cookie(
|
|
||||||
"token",
|
|
||||||
value = "",
|
|
||||||
max_age = -1,
|
|
||||||
secure = True,
|
|
||||||
httponly = True,
|
|
||||||
samesite = "none"
|
|
||||||
)
|
|
||||||
|
|
||||||
@router.get("/user/current", dependencies = [Depends(JwtMiddleware())], response_model = schema.User)
|
|
||||||
async def current(request: Request):
|
|
||||||
return schema.User.from_(request.state.user)
|
|
||||||
|
|
||||||
@router.post("/user/avatar", dependencies = [Depends(JwtMiddleware())])
|
|
||||||
async def avatar(request: Request, file: UploadFile, database: DatabaseState = Depends()):
|
|
||||||
async with database.session() as session:
|
|
||||||
avatars: list[str] = (await session.scalars(select(db.User.avatar))).all()
|
|
||||||
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
|
||||||
|
|
||||||
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
|
|
||||||
|
|
||||||
try:
|
|
||||||
img = Image.open(io.BytesIO(await file.read()))
|
|
||||||
except OSError as _:
|
|
||||||
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
|
||||||
avatars_dir.mkdir()
|
|
||||||
img.save(avatars_dir / avatar_id, format = img.format)
|
|
||||||
except OSError as _:
|
|
||||||
raise HTTPException(status.WS_1011_INTERNAL_ERROR, "Failed to save avatar")
|
|
||||||
|
|
||||||
if old_avatar := request.state.user.avatar:
|
|
||||||
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
|
||||||
old_file.unlink()
|
|
||||||
|
|
||||||
async with database.session() as session:
|
|
||||||
await session.execute(update(db.User).where(db.User.id == request.state.user.id).values(avatar = avatar_id))
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
|
|
@ -1,15 +0,0 @@
|
|||||||
from fastapi import Request
|
|
||||||
|
|
||||||
from materia_server.config import Config
|
|
||||||
from materia_server.models.database import Database, Cache
|
|
||||||
from materia_server._logging import Logger
|
|
||||||
|
|
||||||
|
|
||||||
class Context:
|
|
||||||
def __init__(self, request: Request):
|
|
||||||
self.config = request.state.config
|
|
||||||
self.database = request.state.database
|
|
||||||
#self.cache = request.state.cache
|
|
||||||
self.logger = request.state.logger
|
|
||||||
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
|||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
from fastapi import HTTPException, Request, Response, status, Depends, Cookie
|
from fastapi import HTTPException, Request, Response, status, Depends, Cookie
|
||||||
from fastapi.security.base import SecurityBase
|
from fastapi.security.base import SecurityBase
|
||||||
import jwt
|
import jwt
|
||||||
@ -10,109 +11,71 @@ from http import HTTPMethod as HttpMethod
|
|||||||
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm, APIKeyQuery, APIKeyCookie, APIKeyHeader
|
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm, APIKeyQuery, APIKeyCookie, APIKeyHeader
|
||||||
|
|
||||||
from materia_server import security
|
from materia_server import security
|
||||||
from materia_server.routers import context
|
from materia_server.models import User
|
||||||
from materia_server.models import user
|
|
||||||
|
|
||||||
|
|
||||||
async def get_token_claims(token, ctx: context.Context = Depends()) -> security.TokenClaims:
|
class Context:
|
||||||
|
def __init__(self, request: Request):
|
||||||
|
self.config = request.state.config
|
||||||
|
self.database = request.state.database
|
||||||
|
self.cache = request.state.cache
|
||||||
|
self.logger = request.state.logger
|
||||||
|
|
||||||
|
|
||||||
|
async def jwt_cookie(request: Request, response: Response, ctx: Context = Depends()):
|
||||||
|
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
|
||||||
|
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
|
||||||
|
|
||||||
|
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
||||||
|
secret = ctx.config.oauth2.jwt_secret
|
||||||
|
else:
|
||||||
|
secret = ctx.config.oauth2.jwt_signing_key
|
||||||
|
|
||||||
|
issuer = "{}://{}".format(ctx.config.server.scheme, ctx.config.server.domain)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
secret = ctx.config.oauth2.jwt_secret if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"] else ctx.config.oauth2.jwt_signing_key
|
refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
|
||||||
claims = security.validate_token(token, secret)
|
|
||||||
user_id = uuid.UUID(claims.sub) # type: ignore
|
if refresh_claims:
|
||||||
except jwt.PyJWKError as _:
|
if refresh_claims.exp < datetime.now().timestamp():
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
|
refresh_claims = None
|
||||||
except ValueError as _:
|
except jwt.PyJWTError:
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid token")
|
refresh_claims = None
|
||||||
|
|
||||||
if not (current_user := await user.User.by_id(user_id, ctx.database)):
|
try:
|
||||||
|
access_claims = security.validate_token(access_token, secret)
|
||||||
|
|
||||||
|
if access_claims.exp < datetime.now().timestamp():
|
||||||
|
if refresh_claims:
|
||||||
|
new_access_token = security.generate_token(
|
||||||
|
access_claims.sub,
|
||||||
|
str(secret),
|
||||||
|
ctx.config.oauth2.access_token_lifetime,
|
||||||
|
issuer
|
||||||
|
)
|
||||||
|
access_claims = security.validate_token(new_access_token, secret)
|
||||||
|
response.set_cookie(
|
||||||
|
ctx.config.security.cookie_access_token_name,
|
||||||
|
value = new_access_token,
|
||||||
|
max_age = ctx.config.oauth2.access_token_lifetime,
|
||||||
|
secure = True,
|
||||||
|
httponly = ctx.config.security.cookie_http_only,
|
||||||
|
samesite = "lax"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
access_claims = None
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
||||||
|
|
||||||
|
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||||
|
|
||||||
|
return access_claims
|
||||||
|
|
||||||
|
|
||||||
|
async def user(claims = Depends(jwt_cookie), ctx: Context = Depends()):
|
||||||
|
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||||
|
|
||||||
return claims
|
return current_user
|
||||||
|
|
||||||
class JwtBearer(HTTPBearer):
|
|
||||||
def __init__(self, **kwargs):
|
|
||||||
super().__init__(scheme_name = "Bearer", **kwargs)
|
|
||||||
self.claims = None
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, ctx: context.Context = Depends()):
|
|
||||||
if credentials := await super().__call__(request):
|
|
||||||
token = credentials.credentials
|
|
||||||
else:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing token")
|
|
||||||
|
|
||||||
self.claims = await get_token_claims(token, ctx)
|
|
||||||
|
|
||||||
class JwtCookie(SecurityBase):
|
|
||||||
def __init(self, *, auto_error: bool = True):
|
|
||||||
self.auto_error = auto_error
|
|
||||||
self.claims = None
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, response: Response, ctx: context.Context = Depends()):
|
|
||||||
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
|
|
||||||
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
|
|
||||||
|
|
||||||
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
|
||||||
secret = ctx.config.oauth2.jwt_secret
|
|
||||||
else:
|
|
||||||
secret = ctx.config.oauth2.jwt_signing_key
|
|
||||||
|
|
||||||
try:
|
|
||||||
refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
|
|
||||||
# TODO: check expiration
|
|
||||||
except jwt.PyJWTError:
|
|
||||||
refresh_claims = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
access_claims = security.validate_token(access_token, secret)
|
|
||||||
# TODO: if exp then check refresh token and create new else raise
|
|
||||||
except jwt.PyJWTError as e:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
|
||||||
else:
|
|
||||||
# TODO: validate user
|
|
||||||
pass
|
|
||||||
|
|
||||||
self.claims = access_claims
|
|
||||||
|
|
||||||
|
|
||||||
WILDCARD = "*"
|
|
||||||
NULL = "null"
|
|
||||||
|
|
||||||
class HttpHeader(StrEnum):
|
|
||||||
ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"
|
|
||||||
ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"
|
|
||||||
ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"
|
|
||||||
ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"
|
|
||||||
ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"
|
|
||||||
ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"
|
|
||||||
CONTENT_TYPE = "Content-Type"
|
|
||||||
AUTHORIZATION = "Authorization"
|
|
||||||
VARY = "Vary"
|
|
||||||
ORIGIN = "Origin"
|
|
||||||
|
|
||||||
class CorsMiddleware(BaseModel):
|
|
||||||
allow_credentials: bool = False
|
|
||||||
allow_headers: Sequence[HttpHeader | str] = []
|
|
||||||
allow_methods: Sequence[HttpMethod | str] = []
|
|
||||||
allow_origin: str = WILDCARD
|
|
||||||
expose_headers: Sequence[HttpHeader | str] = []
|
|
||||||
max_age: int = 600
|
|
||||||
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, response: Response):
|
|
||||||
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_CREDENTIALS] = str(self.allow_credentials).lower()
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_HEADERS] = self.make_from(self.allow_headers)
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_METHODS] = self.make_from(self.allow_methods)
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_ORIGIN] = str(self.allow_origin)
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_EXPOSE_HEADERS] = self.make_from(self.expose_headers)
|
|
||||||
response.headers[HttpHeader.ACCESS_CONTROL_MAX_AGE] = str(self.max_age)
|
|
||||||
|
|
||||||
def make_from(self, value: Sequence[HttpHeader | HttpMethod | str]) -> str:
|
|
||||||
if WILDCARD in value:
|
|
||||||
return WILDCARD
|
|
||||||
elif NULL in value:
|
|
||||||
return NULL
|
|
||||||
else:
|
|
||||||
return ",".join(set(value))
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user