materia-server: repository api, directory api, collapsed modules
This commit is contained in:
parent
d8b19da646
commit
317085fc04
@ -2,7 +2,7 @@
|
||||
|
||||
[alembic]
|
||||
# path to migration scripts
|
||||
script_location = materia_server:models/migrations
|
||||
script_location = ./src/materia_server/models/migrations
|
||||
|
||||
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
|
||||
# Uncomment the line below if you want the files to be prepended with date and time
|
||||
@ -60,7 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
#sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||
sqlalchemy.url = postgresql+asyncpg://materia:materia@127.0.0.1:54320/materia
|
||||
|
||||
|
||||
[post_write_hooks]
|
@ -35,7 +35,7 @@ readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["src/materia_server", "src/materia_server/alembic.ini"]
|
||||
includes = ["src/materia_server"]
|
||||
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
|
@ -106,6 +106,10 @@ class OAuth2(BaseModel):
|
||||
#def check(self) -> Self:
|
||||
# if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
||||
# assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms"
|
||||
# else:
|
||||
# assert self.jwt_signing_key is not None, "JWT signing key must be set"
|
||||
#
|
||||
# return self
|
||||
|
||||
|
||||
class Mailer(BaseModel):
|
||||
@ -171,9 +175,6 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
|
||||
else:
|
||||
return cwd
|
||||
|
||||
@staticmethod
|
||||
def create(path: Path, config: Self | None = None):
|
||||
config = config or Config()
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
@ -17,7 +17,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from materia_server import config as _config
|
||||
from materia_server.config import Config
|
||||
from materia_server._logging import make_logger, uvicorn_log_config, Logger
|
||||
from materia_server.models.database import Database, Cache
|
||||
from materia_server.models import Database, Cache
|
||||
from materia_server import routers
|
||||
|
||||
|
||||
@ -31,12 +31,19 @@ class AppContext(TypedDict):
|
||||
def create_lifespan(config: Config, logger):
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||
database = Database.new(config.database.url()) # type: ignore
|
||||
|
||||
try:
|
||||
logger.info("Connecting {}", config.database.url())
|
||||
database = Database.new(config.database.url()) # type: ignore
|
||||
except:
|
||||
logger.error("Failed to connect postgres.")
|
||||
sys.exit()
|
||||
|
||||
try:
|
||||
logger.info("Connecting {}", config.cache.url())
|
||||
cache = await Cache.new(config.cache.url()) # type: ignore
|
||||
except:
|
||||
logger.error("Failed to connect redis {}", config.cache.url())
|
||||
logger.error("Failed to connect redis.")
|
||||
sys.exit()
|
||||
|
||||
async with database.connection() as connection:
|
||||
@ -64,6 +71,7 @@ def server():
|
||||
@from_pydantic("log", _config.Log, prefix = "log")
|
||||
def start(application: _config.Application, config_path: Path, log: _config.Log):
|
||||
config = Config()
|
||||
config.log = log
|
||||
logger = make_logger(config)
|
||||
|
||||
#if user := application.user:
|
||||
@ -71,8 +79,12 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
|
||||
#if group := application.group:
|
||||
# os.setgid(pwd.getpwnam(user).pw_gid)
|
||||
# TODO: merge cli options with config
|
||||
if working_directory := (application.working_directory or config.application.working_directory):
|
||||
os.chdir(working_directory.resolve())
|
||||
if working_directory := (application.working_directory or config.application.working_directory).resolve():
|
||||
try:
|
||||
os.chdir(working_directory)
|
||||
except FileNotFoundError as e:
|
||||
logger.error("Failed to change working directory: {}", e)
|
||||
sys.exit()
|
||||
logger.debug(f"Current working directory: {working_directory}")
|
||||
|
||||
# check the configuration file or use default
|
||||
@ -106,6 +118,13 @@ def start(application: _config.Application, config_path: Path, log: _config.Log)
|
||||
|
||||
config.log.level = log.level
|
||||
logger = make_logger(config)
|
||||
if (working_directory := config.application.working_directory.resolve()):
|
||||
logger.debug(f"Change working directory: {working_directory}")
|
||||
try:
|
||||
os.chdir(working_directory)
|
||||
except FileNotFoundError as e:
|
||||
logger.error("Failed to change working directory: {}", e)
|
||||
sys.exit()
|
||||
|
||||
config.application.mode = application.mode
|
||||
|
||||
|
@ -6,4 +6,18 @@
|
||||
#from materia_server.models.directory import Directory, DirectoryLink
|
||||
#from materia_server.models.file import File, FileLink
|
||||
|
||||
#from materia_server.models.repository import *
|
||||
|
||||
from materia_server.models.auth.source import LoginType, LoginSource
|
||||
from materia_server.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
|
||||
|
||||
from materia_server.models.database.database import Database
|
||||
from materia_server.models.database.cache import Cache
|
||||
|
||||
from materia_server.models.user import User, UserCredentials, UserInfo
|
||||
|
||||
from materia_server.models.repository import Repository, RepositoryInfo
|
||||
|
||||
from materia_server.models.directory import Directory, DirectoryLink, DirectoryInfo
|
||||
|
||||
from materia_server.models.file import File, FileLink
|
||||
|
@ -1,4 +1,5 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import os
|
||||
from typing import AsyncIterator, Self
|
||||
from pathlib import Path
|
||||
|
||||
@ -10,6 +11,7 @@ from alembic.operations import Operations
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script.base import ScriptDirectory
|
||||
|
||||
from materia_server.config import Config
|
||||
from materia_server.models.base import Base
|
||||
|
||||
__all__ = [ "Database" ]
|
||||
@ -61,14 +63,20 @@ class Database:
|
||||
await session.close()
|
||||
|
||||
def run_migrations(self, connection: Connection):
|
||||
config = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
|
||||
config.set_main_option("sqlalchemy.url", self.url) # type: ignore
|
||||
#aconfig = AlembicConfig(Path(__file__).parent.parent.parent / "alembic.ini")
|
||||
aconfig = AlembicConfig()
|
||||
aconfig.set_main_option("sqlalchemy.url", str(self.url))
|
||||
|
||||
|
||||
aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||
print(str(Path(__file__).parent.parent.joinpath("migrations")))
|
||||
|
||||
|
||||
context = MigrationContext.configure(
|
||||
connection = connection, # type: ignore
|
||||
opts = {
|
||||
"target_metadata": Base.metadata,
|
||||
"fn": lambda rev, _: ScriptDirectory.from_config(config)._upgrade_revs("head", rev)
|
||||
"fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev)
|
||||
}
|
||||
)
|
||||
|
||||
|
@ -1,10 +1,14 @@
|
||||
from time import time
|
||||
from typing import List
|
||||
from typing import List, Optional, Self
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey
|
||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from materia_server.models.base import Base
|
||||
from materia_server.models import database
|
||||
|
||||
|
||||
class Directory(Base):
|
||||
@ -12,7 +16,7 @@ class Directory(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||
repository_id: Mapped[int] = mapped_column(ForeignKey("repository.id", ondelete = "CASCADE"))
|
||||
parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id"), nullable = True)
|
||||
parent_id: Mapped[int] = mapped_column(ForeignKey("directory.id", ondelete = "CASCADE"), nullable = True)
|
||||
created: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
||||
updated: Mapped[int] = mapped_column(BigInteger, nullable = False, default = time)
|
||||
name: Mapped[str]
|
||||
@ -25,6 +29,14 @@ class Directory(Base):
|
||||
files: Mapped[List["File"]] = relationship(back_populates = "parent")
|
||||
link: Mapped["DirectoryLink"] = relationship(back_populates = "directory")
|
||||
|
||||
@staticmethod
|
||||
async def by_path(repository_id: int, path: Path | None, name: str, db: database.Database) -> Self | None:
|
||||
async with db.session() as session:
|
||||
query_path = Directory.path == str(path) if isinstance(path, Path) else Directory.path.is_(None)
|
||||
return (await session
|
||||
.scalars(sa.select(Directory)
|
||||
.where(sa.and_(Directory.repository_id == repository_id, Directory.name == name, query_path)))
|
||||
).first()
|
||||
|
||||
class DirectoryLink(Base):
|
||||
__tablename__ = "directory_link"
|
||||
@ -36,5 +48,20 @@ class DirectoryLink(Base):
|
||||
|
||||
directory: Mapped["Directory"] = relationship(back_populates = "link")
|
||||
|
||||
from materia_server.models.repository.repository import Repository
|
||||
from materia_server.models.file.file import File
|
||||
class DirectoryInfo(BaseModel):
|
||||
model_config = ConfigDict(from_attributes = True)
|
||||
|
||||
id: int
|
||||
repository_id: int
|
||||
parent_id: Optional[int]
|
||||
created: int
|
||||
updated: int
|
||||
name: str
|
||||
path: Optional[str]
|
||||
is_public: bool
|
||||
|
||||
used: Optional[int] = None
|
||||
|
||||
|
||||
from materia_server.models.repository import Repository
|
||||
from materia_server.models.file import File
|
@ -1 +0,0 @@
|
||||
from materia_server.models.directory.directory import Directory, DirectoryLink
|
@ -35,5 +35,5 @@ class FileLink(Base):
|
||||
file: Mapped["File"] = relationship(back_populates = "link")
|
||||
|
||||
|
||||
from materia_server.models.repository.repository import Repository
|
||||
from materia_server.models.directory.directory import Directory
|
||||
from materia_server.models.repository import Repository
|
||||
from materia_server.models.directory import Directory
|
@ -1 +0,0 @@
|
||||
from materia_server.models.file.file import File, FileLink
|
@ -19,8 +19,12 @@ import materia_server.models.file
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
context.configure(
|
||||
version_table_schema = "public"
|
||||
)
|
||||
config = context.config
|
||||
config.set_main_option("sqlalchemy.url", Config().database.url())
|
||||
|
||||
#config.set_main_option("sqlalchemy.url", Config().database.url())
|
||||
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
@ -59,6 +63,7 @@ def run_migrations_offline() -> None:
|
||||
target_metadata=target_metadata,
|
||||
literal_binds=True,
|
||||
dialect_opts={"paramstyle": "named"},
|
||||
version_table_schema = "public"
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
@ -99,4 +104,5 @@ def run_migrations_online() -> None:
|
||||
if context.is_offline_mode():
|
||||
run_migrations_offline()
|
||||
else:
|
||||
print("online")
|
||||
run_migrations_online()
|
||||
|
@ -1,140 +0,0 @@
|
||||
"""empty message
|
||||
|
||||
Revision ID: 76191498b728
|
||||
Revises:
|
||||
Create Date: 2024-06-03 18:44:07.044588
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '76191498b728'
|
||||
down_revision: Union[str, None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').create(op.get_bind())
|
||||
op.create_table('login_source',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
||||
sa.Column('created', sa.Integer(), nullable=False),
|
||||
sa.Column('updated', sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('user',
|
||||
sa.Column('id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('lower_name', sa.String(), nullable=False),
|
||||
sa.Column('full_name', sa.String(), nullable=False),
|
||||
sa.Column('email', sa.String(), nullable=False),
|
||||
sa.Column('is_email_private', sa.Boolean(), nullable=False),
|
||||
sa.Column('hashed_password', sa.String(), nullable=False),
|
||||
sa.Column('must_change_password', sa.Boolean(), nullable=False),
|
||||
sa.Column('login_type', postgresql.ENUM('Plain', 'OAuth2', 'Smtp', name='logintype', create_type=False), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('last_login', sa.BigInteger(), nullable=True),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('is_admin', sa.Boolean(), nullable=False),
|
||||
sa.Column('avatar', sa.String(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('lower_name'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
op.create_table('oauth2_application',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('client_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('hashed_client_secret', sa.String(), nullable=False),
|
||||
sa.Column('redirect_uris', sa.JSON(), nullable=False),
|
||||
sa.Column('confidential_client', sa.Boolean(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('repository',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('capacity', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('directory',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('path', sa.String(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
|
||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('oauth2_grant',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(), nullable=False),
|
||||
sa.Column('application_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('scope', sa.String(), nullable=False),
|
||||
sa.Column('created', sa.Integer(), nullable=False),
|
||||
sa.Column('updated', sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['application_id'], ['oauth2_application.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['user.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('directory_link',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('directory_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['directory_id'], ['directory.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('file',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('repository_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('parent_id', sa.BigInteger(), nullable=True),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('updated', sa.BigInteger(), nullable=False),
|
||||
sa.Column('name', sa.String(), nullable=False),
|
||||
sa.Column('path', sa.String(), nullable=True),
|
||||
sa.Column('is_public', sa.Boolean(), nullable=False),
|
||||
sa.Column('size', sa.BigInteger(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['parent_id'], ['directory.id'], ),
|
||||
sa.ForeignKeyConstraint(['repository_id'], ['repository.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_table('file_link',
|
||||
sa.Column('id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('file_id', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created', sa.BigInteger(), nullable=False),
|
||||
sa.Column('url', sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['file_id'], ['file.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('file_link')
|
||||
op.drop_table('file')
|
||||
op.drop_table('directory_link')
|
||||
op.drop_table('oauth2_grant')
|
||||
op.drop_table('directory')
|
||||
op.drop_table('repository')
|
||||
op.drop_table('oauth2_application')
|
||||
op.drop_table('user')
|
||||
op.drop_table('login_source')
|
||||
sa.Enum('Plain', 'OAuth2', 'Smtp', name='logintype').drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
51
materia-server/src/materia_server/models/repository.py
Normal file
51
materia-server/src/materia_server/models/repository.py
Normal file
@ -0,0 +1,51 @@
|
||||
from time import time
|
||||
from typing import List, Self
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey
|
||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel
|
||||
|
||||
from materia_server.models.base import Base
|
||||
from materia_server.models import database
|
||||
|
||||
|
||||
class Repository(Base):
|
||||
__tablename__ = "repository"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates = "repository")
|
||||
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
|
||||
files: Mapped[List["File"]] = relationship(back_populates = "repository")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return { k: getattr(self, k) for k, v in Repository.__dict__.items() if isinstance(v, InstrumentedAttribute) }
|
||||
|
||||
async def create(self, db: database.Database):
|
||||
async with db.session() as session:
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
|
||||
async def update(self, db: database.Database):
|
||||
async with db.session() as session:
|
||||
await session.execute(sa.update(Repository).where(Repository.id == self.id).values(self.to_dict()))
|
||||
await session.commit()
|
||||
|
||||
@staticmethod
|
||||
async def by_user_id(user_id: UUID, db: database.Database) -> Self | None:
|
||||
async with db.session() as session:
|
||||
return (await session.scalars(sa.select(Repository).where(Repository.user_id == user_id))).first()
|
||||
|
||||
|
||||
class RepositoryInfo(BaseModel):
|
||||
capacity: int
|
||||
used: int
|
||||
|
||||
from materia_server.models.user import User
|
||||
from materia_server.models.directory import Directory
|
||||
from materia_server.models.file import File
|
@ -1 +0,0 @@
|
||||
from materia_server.models.repository.repository import Repository
|
@ -1,25 +0,0 @@
|
||||
from time import time
|
||||
from typing import List
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from sqlalchemy import BigInteger, ForeignKey
|
||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||
|
||||
from materia_server.models.base import Base
|
||||
|
||||
|
||||
class Repository(Base):
|
||||
__tablename__ = "repository"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
capacity: Mapped[int] = mapped_column(BigInteger, nullable = False)
|
||||
|
||||
user: Mapped["User"] = relationship(back_populates = "repository")
|
||||
directories: Mapped[List["Directory"]] = relationship(back_populates = "repository")
|
||||
files: Mapped[List["File"]] = relationship(back_populates = "repository")
|
||||
|
||||
|
||||
from materia_server.models.user.user import User
|
||||
from materia_server.models.directory.directory import Directory
|
||||
from materia_server.models.file.file import File
|
@ -80,9 +80,10 @@ class UserCredentials(BaseModel):
|
||||
password: str
|
||||
email: Optional[EmailStr]
|
||||
|
||||
class UserIdentity(BaseModel):
|
||||
class UserInfo(BaseModel):
|
||||
model_config = ConfigDict(from_attributes = True)
|
||||
|
||||
id: UUID
|
||||
name: str
|
||||
lower_name: str
|
||||
full_name: Optional[str]
|
||||
@ -101,4 +102,4 @@ class UserIdentity(BaseModel):
|
||||
|
||||
avatar: Optional[str]
|
||||
|
||||
from materia_server.models.repository.repository import Repository
|
||||
from materia_server.models.repository import Repository
|
@ -1 +0,0 @@
|
||||
from materia_server.models.user.user import User, UserCredentials, UserIdentity
|
@ -1,2 +1 @@
|
||||
from materia_server.routers import api
|
||||
from materia_server.routers import middleware
|
||||
from materia_server.routers import middleware, api
|
||||
|
@ -1,8 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from materia_server.routers.api import auth
|
||||
from materia_server.routers.api import user
|
||||
|
||||
router = APIRouter(prefix = "/api")
|
||||
from materia_server.routers.api.auth import auth, oauth
|
||||
from materia_server.routers.api import user, repository, directory
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(auth.router)
|
||||
router.include_router(oauth.router)
|
||||
router.include_router(user.router)
|
||||
router.include_router(repository.router)
|
||||
router.include_router(directory.router)
|
||||
|
@ -1,8 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from materia_server.routers.api.auth import auth
|
||||
from materia_server.routers.api.auth import oauth
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(auth.router)
|
||||
router.include_router(oauth.router)
|
||||
|
@ -3,33 +3,32 @@ from typing import Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response, status
|
||||
|
||||
from materia_server import security
|
||||
from materia_server.routers import context
|
||||
from materia_server.models import user
|
||||
from materia_server.models import auth
|
||||
from materia_server.routers.middleware import Context
|
||||
from materia_server.models import LoginType, User, UserCredentials
|
||||
|
||||
router = APIRouter(tags = ["auth"])
|
||||
|
||||
|
||||
@router.post("/auth/signup")
|
||||
async def signup(body: user.UserCredentials, ctx: context.Context = Depends()):
|
||||
if not user.User.is_valid_username(body.name):
|
||||
async def signup(body: UserCredentials, ctx: Context = Depends()):
|
||||
if not User.is_valid_username(body.name):
|
||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Invalid username")
|
||||
if await user.User.by_name(body.name, ctx.database) is not None:
|
||||
if await User.by_name(body.name, ctx.database) is not None:
|
||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "User already exists")
|
||||
if await user.User.by_email(body.email, ctx.database) is not None: # type: ignore
|
||||
if await User.by_email(body.email, ctx.database) is not None: # type: ignore
|
||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = "Email already used")
|
||||
if len(body.password) < ctx.config.security.password_min_length:
|
||||
raise HTTPException(status_code = status.HTTP_500_INTERNAL_SERVER_ERROR, detail = f"Password is too short (minimum length {ctx.config.security.password_min_length})")
|
||||
|
||||
count: Optional[int] = await user.User.count(ctx.database)
|
||||
count: Optional[int] = await User.count(ctx.database)
|
||||
|
||||
new_user = user.User(
|
||||
new_user = User(
|
||||
name = body.name,
|
||||
lower_name = body.name.lower(),
|
||||
full_name = body.name,
|
||||
email = body.email,
|
||||
hashed_password = security.hash_password(body.password, algo = ctx.config.security.password_hash_algo),
|
||||
login_type = auth.LoginType.Plain,
|
||||
login_type = LoginType.Plain,
|
||||
# first registered user is admin
|
||||
is_admin = count == 0
|
||||
)
|
||||
@ -39,9 +38,9 @@ async def signup(body: user.UserCredentials, ctx: context.Context = Depends()):
|
||||
await session.commit()
|
||||
|
||||
@router.post("/auth/signin")
|
||||
async def signin(body: user.UserCredentials, response: Response, ctx: context.Context = Depends()):
|
||||
if (current_user := await user.User.by_name(body.name, ctx.database)) is None:
|
||||
if (current_user := await user.User.by_email(str(body.email), ctx.database)) is None:
|
||||
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
|
||||
if (current_user := await User.by_name(body.name, ctx.database)) is None:
|
||||
if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
|
||||
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid email")
|
||||
if not security.validate_password(body.password, current_user.hashed_password, algo = ctx.config.security.password_hash_algo):
|
||||
raise HTTPException(status_code = status.HTTP_401_UNAUTHORIZED, detail = "Invalid password")
|
||||
@ -79,6 +78,6 @@ async def signin(body: user.UserCredentials, response: Response, ctx: context.Co
|
||||
)
|
||||
|
||||
@router.get("/auth/signout")
|
||||
async def signout(response: Response, ctx: context.Context = Depends()):
|
||||
async def signout(response: Response, ctx: Context = Depends()):
|
||||
response.delete_cookie(ctx.config.security.cookie_access_token_name)
|
||||
response.delete_cookie(ctx.config.security.cookie_refresh_token_name)
|
||||
|
@ -7,9 +7,8 @@ from fastapi.security.oauth2 import OAuth2PasswordRequestForm
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR
|
||||
|
||||
from materia_server.models import auth
|
||||
from materia_server.models.user import user
|
||||
from materia_server.routers import context
|
||||
from materia_server.models import User
|
||||
from materia_server.routers.middleware import Context
|
||||
|
||||
|
||||
router = APIRouter(tags = ["oauth2"])
|
||||
@ -35,17 +34,17 @@ class AuthorizationCodeResponse(BaseModel):
|
||||
code: str
|
||||
|
||||
@router.post("/oauth2/authorize")
|
||||
async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: context.Context = Depends()):
|
||||
async def authorize(form: Annotated[OAuth2AuthorizationCodeRequestForm, Depends()], ctx: Context = Depends()):
|
||||
# grant_type: authorization_code, password_credentials, client_credentials, authorization_code (pkce)
|
||||
ctx.logger.debug(form)
|
||||
|
||||
if form.grant_type == "authorization_code":
|
||||
# TODO: form validation
|
||||
|
||||
if not (app := await auth.OAuth2Application.by_client_id(form.client_id, ctx.database)):
|
||||
if not (app := await OAuth2Application.by_client_id(form.client_id, ctx.database)):
|
||||
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "Client ID not registered")
|
||||
|
||||
if not (owner := user.User.by_id(app.user_id, ctx.database)):
|
||||
if not (owner := await User.by_id(app.user_id, ctx.database)):
|
||||
raise HTTPException(status_code = HTTP_500_INTERNAL_SERVER_ERROR, detail = "User not found")
|
||||
|
||||
if not app.contains_redirect_uri(form.redirect_uri):
|
||||
@ -79,5 +78,5 @@ class AccessTokenResponse(BaseModel):
|
||||
scope: Optional[str]
|
||||
|
||||
@router.post("/oauth2/access_token")
|
||||
async def token(ctx: context.Context = Depends()):
|
||||
async def token(ctx: Context = Depends()):
|
||||
pass
|
||||
|
@ -1,143 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
|
||||
from fastapi.responses import RedirectResponse, StreamingResponse
|
||||
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordRequestFormStrict
|
||||
import httpx
|
||||
from sqlalchemy import and_, insert, select, update
|
||||
from authlib.integrations.starlette_client import OAuth, OAuthError
|
||||
import base64
|
||||
from cryptography.fernet import Fernet
|
||||
import json
|
||||
|
||||
from materia import db
|
||||
from materia.api import schema
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia.api.token import TokenClaims
|
||||
from materia.config import Config
|
||||
|
||||
oauth = OAuth()
|
||||
oauth.register(
|
||||
"materia",
|
||||
authorize_url = "http://127.0.0.1:54601/api/auth/authorize",
|
||||
access_token_url = "http://127.0.0.1:54601/api/auth/token",
|
||||
scope = "user:read",
|
||||
client_id = "",
|
||||
client_secret = ""
|
||||
)
|
||||
|
||||
class OAuth2Provider:
|
||||
pass
|
||||
|
||||
router = APIRouter(tags = ["auth"])
|
||||
|
||||
@router.get("/user/signin")
|
||||
async def signin(request: Request, provider: str = None):
|
||||
if not provider:
|
||||
return RedirectResponse("/api/auth/authorize")
|
||||
else:
|
||||
return RedirectResponse(request.url_for(provider.authorize_url))
|
||||
|
||||
@router.post("/auth/test_auth")
|
||||
async def test_auth(database: DatabaseState = Depends()):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post("https://vcs.elnafo.ru/login/oauth/authorize", data = {
|
||||
"client_id": "1edfe-0bbe-4f53-bab6-7e24f0b842e3",
|
||||
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
|
||||
"response_type": "code",
|
||||
"redirect_uri": "http://127.0.0.1:54601"
|
||||
})
|
||||
return response.content, response.status_code
|
||||
|
||||
@router.post("/auth/provider")
|
||||
async def provider(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post("https://vcs.elnafo.ru/login/oauth/access_token", data = {
|
||||
"client_id": "1edfec03-0bbe-4f53-bab6-7e24f0b842e3",
|
||||
"client_secret": "gto_7ecfnqg2c6kbe2qf25wjee237mmkxvbkb7arjacyvtypi24hqv4q",
|
||||
"grant_type": "authorization_code",
|
||||
"code": "gta_63l6zogw5wlnkeng4gf3buqtoekkaxk7zhr67zlkyrv2ukwfeava"
|
||||
})
|
||||
return response.content, response.status_code
|
||||
|
||||
@router.post("/auth/authorize")
|
||||
async def authorize(form: Annotated[OAuth2PasswordRequestForm, Depends()], database: DatabaseState = Depends()):
|
||||
|
||||
|
||||
if form.client_id:
|
||||
async with database.session() as session:
|
||||
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||
|
||||
await session.refresh(user, attribute_names = ["oauth2_apps"])
|
||||
oauth2_app = None
|
||||
|
||||
for app in user.oauth2_apps:
|
||||
if form.client_id == app.client_id and bcrypt.checkpw(form.client_secret.encode(), app.client_secret):
|
||||
oauth2_app = app
|
||||
|
||||
if not oauth2_app:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid client id")
|
||||
|
||||
data = json.dumps({"client_id": form.client_id}).encode()
|
||||
|
||||
else:
|
||||
async with database.session() as session:
|
||||
if not (user := (await session.scalars(select(db.User).where(db.User.login_name == form.username))).first()):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user credentials")
|
||||
|
||||
if not bcrypt.checkpw(form.password.encode(), user.hashed_password.encode()):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
|
||||
|
||||
data = json.dumps({"username": form.username}).encode()
|
||||
|
||||
key = b'sGEuUeKrooiNAy7L9sf6IFIjpv86TC9iYU_sbWqA-1c=' # Fernet.generate_key()
|
||||
f = Fernet(key)
|
||||
code = base64.b64encode(f.encrypt(data), b"-_").decode().replace("=", "")
|
||||
global storage
|
||||
storage = code
|
||||
return code
|
||||
|
||||
storage = None
|
||||
|
||||
|
||||
@router.post("/auth/token")
|
||||
async def token(exchange: schema.Exchange, response: Response, config: ConfigState = Depends()):
|
||||
if exchange.grant_type == "authorization_code":
|
||||
if not exchange.code:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Missing authorization code")
|
||||
# expiration
|
||||
if exchange.code != storage:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Invalid authorization code")
|
||||
|
||||
token = TokenClaims.create(
|
||||
"asd",
|
||||
config.jwt.secret,
|
||||
config.jwt.maxage
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
"token",
|
||||
value = token,
|
||||
max_age = config.jwt.maxage,
|
||||
secure = True,
|
||||
httponly = True,
|
||||
samesite = "none"
|
||||
)
|
||||
|
||||
return schema.AccessToken(
|
||||
access_token = token,
|
||||
token_type = "Bearer",
|
||||
expires_in = config.jwt.maxage,
|
||||
refresh_token = token,
|
||||
scope = "identify"
|
||||
)
|
||||
elif exchange.grant_type == "refresh_token":
|
||||
pass
|
||||
else:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST)
|
||||
|
@ -1,44 +1,40 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, insert, select, update
|
||||
|
||||
from materia import db
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia.config import Config
|
||||
from materia.api import schema
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from materia_server.models import User, Directory, DirectoryInfo
|
||||
from materia_server.models.directory import DirectoryInfo
|
||||
from materia_server.routers import middleware
|
||||
from materia_server.config import Config
|
||||
|
||||
|
||||
router = APIRouter(tags = ["directory"])
|
||||
|
||||
@router.post("/directory", dependencies = [Depends(JwtMiddleware())])
|
||||
async def create(request: Request, path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
||||
user = request.state.user
|
||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
||||
@router.post("/directory")
|
||||
async def create(path: Path = Path(), user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||
repository_path = Config.data_dir() / "repository" / user.lower_name
|
||||
blacklist = [os.sep, ".", "..", "*"]
|
||||
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, path.parts)))
|
||||
|
||||
async with database.session() as session:
|
||||
async with ctx.database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names = ["repository"])
|
||||
|
||||
if not user.repository:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
||||
|
||||
current_directory = None
|
||||
current_path = Path()
|
||||
directory = None
|
||||
|
||||
for part in directory_path.parts:
|
||||
if not (directory := (await session
|
||||
.scalars(select(db.Directory)
|
||||
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
|
||||
).first()):
|
||||
directory = db.Directory(
|
||||
if not await Directory.by_path(user.repository.id, current_path, part, ctx.database):
|
||||
directory = Directory(
|
||||
repository_id = user.repository.id,
|
||||
parent_id = current_directory.id if current_directory else None,
|
||||
name = part,
|
||||
path = str(current_path)
|
||||
path = None if current_path == Path() else str(current_path)
|
||||
)
|
||||
session.add(directory)
|
||||
|
||||
@ -52,23 +48,20 @@ async def create(request: Request, path: Path = Path(), config: ConfigState = De
|
||||
|
||||
await session.commit()
|
||||
|
||||
@router.get("/directory", dependencies = [Depends(JwtMiddleware())])
|
||||
async def info(request: Request, repository_id: int, path: Path, config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
||||
async with database.session() as session:
|
||||
if directory := (await session
|
||||
.scalars(select(db.Directory)
|
||||
.where(and_(db.Directory.repository_id == repository_id, db.Directory.name == path.name, db.Directory.path == path.parent))
|
||||
)).first():
|
||||
await session.refresh(directory, attribute_names = ["files"])
|
||||
return schema.DirectoryInfo(
|
||||
id = directory.id,
|
||||
created_at = directory.created_unix,
|
||||
updated_at = directory.updated_unix,
|
||||
name = directory.name,
|
||||
path = directory.path,
|
||||
is_public = directory.is_public,
|
||||
used = sum([ file.size for file in directory.files ])
|
||||
)
|
||||
@router.get("/directory")
|
||||
async def info(path: Path, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||
async with ctx.database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names = ["repository"])
|
||||
|
||||
if not(directory := await Directory.by_path(user.repository.id, None if path.parent == Path() else path.parent, path.name, ctx.database)):
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Directory is not found")
|
||||
|
||||
session.add(directory)
|
||||
await session.refresh(directory, attribute_names = ["files"])
|
||||
|
||||
info = DirectoryInfo.model_validate(directory)
|
||||
info.used = sum([ file.size for file in directory.files ])
|
||||
|
||||
return info
|
||||
|
||||
else:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
||||
|
@ -1,51 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, insert, select, update
|
||||
|
||||
from materia import db
|
||||
from materia.api import schema
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia.config import Config
|
||||
from materia.api import repository, directory
|
||||
|
||||
router = APIRouter(tags = ["file"])
|
||||
|
||||
@router.put("/file", dependencies = [Depends(JwtMiddleware())])
|
||||
async def upload(request: Request, file: UploadFile, directory_path: Path = Path(), config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
||||
user = request.state.user
|
||||
|
||||
try:
|
||||
await repository.create(request, config = config, database = database)
|
||||
except:
|
||||
pass
|
||||
|
||||
#try:
|
||||
# directory_info = directory.info
|
||||
# await directory.create(request, path = directory_path, config = config, database = database)
|
||||
|
||||
async with database.session() as session:
|
||||
if file_ := (await session
|
||||
.scalars(select(db.File)
|
||||
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
|
||||
).first():
|
||||
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
|
||||
else:
|
||||
file_ = db.File(
|
||||
repository_id = user.repository.id,
|
||||
parent_id = directory.id if directory else None,
|
||||
name = file.filename,
|
||||
path = str(directory_path),
|
||||
size = file.size
|
||||
)
|
||||
session.add(file_)
|
||||
|
||||
try:
|
||||
(repository_path / directory_path / file.filename).write_bytes(await file.read())
|
||||
except OSError:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
|
||||
|
||||
await session.commit()
|
@ -1,91 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, insert, select, update
|
||||
|
||||
from materia import db
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia.config import Config
|
||||
from materia.api import repository
|
||||
|
||||
|
||||
router = APIRouter(tags = ["filesystem"])
|
||||
|
||||
|
||||
@router.get("/play")
|
||||
async def play():
|
||||
def iterfile():
|
||||
with open(Config.data_dir() / ".." / "bfg.mp3", mode="rb") as file_like: #
|
||||
yield from file_like #
|
||||
|
||||
return StreamingResponse(iterfile(), media_type="audio/mp3")
|
||||
|
||||
@router.put("/file/upload", dependencies = [Depends(JwtMiddleware())])
|
||||
async def upload(request: Request, file: UploadFile, config: ConfigState = Depends(), database: DatabaseState = Depends(), directory_path: Path = Path()):
|
||||
user = request.state.user
|
||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
||||
blacklist = [os.sep, ".", "..", "*"]
|
||||
directory_path = Path(os.sep.join(filter(lambda part: part not in blacklist, directory_path.parts)))
|
||||
|
||||
try:
|
||||
await repository.create(request, config = config, database = database)
|
||||
except:
|
||||
pass
|
||||
|
||||
async with database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names = ["repository"])
|
||||
|
||||
current_directory = None
|
||||
current_path = Path()
|
||||
directory = None
|
||||
|
||||
for part in directory_path.parts:
|
||||
if not (directory := (await session
|
||||
.scalars(select(db.Directory)
|
||||
.where(and_(db.Directory.name == part, db.Directory.path == str(current_path))))
|
||||
).first()):
|
||||
directory = db.Directory(
|
||||
repository_id = user.repository.id,
|
||||
parent_id = current_directory.id if current_directory else None,
|
||||
name = part,
|
||||
path = str(current_path)
|
||||
)
|
||||
session.add(directory)
|
||||
|
||||
current_directory = directory
|
||||
current_path /= part
|
||||
|
||||
try:
|
||||
(repository_path / directory_path).mkdir(parents = True, exist_ok = True)
|
||||
except OSError:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a directory")
|
||||
|
||||
await session.commit()
|
||||
|
||||
async with database.session() as session:
|
||||
if file_ := (await session
|
||||
.scalars(select(db.File)
|
||||
.where(and_(db.File.name == file.filename, db.File.path == str(directory_path))))
|
||||
).first():
|
||||
await session.execute(update(db.File).where(db.File.id == file_.id).values(updated_unix = time.time(), size = file.size))
|
||||
else:
|
||||
file_ = db.File(
|
||||
repository_id = user.repository.id,
|
||||
parent_id = directory.id if directory else None,
|
||||
name = file.filename,
|
||||
path = str(directory_path),
|
||||
size = file.size
|
||||
)
|
||||
session.add(file_)
|
||||
|
||||
try:
|
||||
(repository_path / directory_path / file.filename).write_bytes(await file.read())
|
||||
except OSError:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to write a file")
|
||||
|
||||
await session.commit()
|
||||
|
@ -1,60 +1,45 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import and_, insert, select, update
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from materia import db
|
||||
from materia.api import schema
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia.config import Config
|
||||
from materia_server.models import User, Repository, RepositoryInfo
|
||||
from materia_server.routers import middleware
|
||||
from materia_server.config import Config
|
||||
|
||||
|
||||
router = APIRouter(tags = ["repository"])
|
||||
|
||||
@router.post("/repository", dependencies = [Depends(JwtMiddleware())])
|
||||
async def create(request: Request, config: ConfigState = Depends(), database: DatabaseState = Depends()):
|
||||
user = request.state.user
|
||||
repository_path = Config.data_dir() / "repository" / user.login_name.lower()
|
||||
@router.post("/repository")
|
||||
async def create(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||
repository_path = Config.data_dir() / "repository" / user.lower_name
|
||||
|
||||
async with database.session() as session:
|
||||
if await Repository.by_user_id(user.id, ctx.database):
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
||||
|
||||
repository = Repository(
|
||||
user_id = user.id,
|
||||
capacity = ctx.config.repository.capacity
|
||||
)
|
||||
|
||||
try:
|
||||
repository_path.mkdir(parents = True, exist_ok = True)
|
||||
except OSError:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
|
||||
|
||||
await repository.create(ctx.database)
|
||||
|
||||
|
||||
@router.get("/repository", response_model = RepositoryInfo)
|
||||
async def info(user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||
async with ctx.database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names = ["repository"])
|
||||
|
||||
if not (repository := user.repository):
|
||||
repository = db.Repository(
|
||||
owner_id = user.id,
|
||||
capacity = config.repository.capacity
|
||||
)
|
||||
session.add(repository)
|
||||
|
||||
try:
|
||||
repository_path.mkdir(parents = True, exist_ok = True)
|
||||
except OSError:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to created a repository")
|
||||
|
||||
await session.commit()
|
||||
|
||||
else:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
||||
|
||||
@router.get("/repository", dependencies = [Depends(JwtMiddleware())])
|
||||
async def info(request: Request, database: DatabaseState = Depends()):
|
||||
user = request.state.user
|
||||
|
||||
async with database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names = ["repository"])
|
||||
|
||||
if repository := user.repository:
|
||||
await session.refresh(repository, attribute_names = ["files"])
|
||||
|
||||
return schema.RepositoryInfo(
|
||||
capacity = repository.capacity,
|
||||
used = sum([ file.size for file in repository.files ])
|
||||
)
|
||||
|
||||
else:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, "Repository is not found")
|
||||
|
||||
await session.refresh(repository, attribute_names = ["files"])
|
||||
|
||||
return RepositoryInfo(
|
||||
capacity = repository.capacity,
|
||||
used = sum([ file.size for file in repository.files ])
|
||||
)
|
||||
|
||||
|
@ -1,5 +0,0 @@
|
||||
from materia.api.schema.user import NewUser, User, RemoveUser, LoginUser
|
||||
from materia.api.schema.token import Token
|
||||
from materia.api.schema.repository import RepositoryInfo
|
||||
from materia.api.schema.directory import DirectoryInfo
|
||||
from materia.api.schema.auth import AccessToken, Exchange
|
@ -1,25 +0,0 @@
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthCode(BaseModel):
|
||||
client_id: str
|
||||
response_type: str
|
||||
state: str
|
||||
redirect_uri: Optional[str]
|
||||
scope: Optional[str]
|
||||
|
||||
class Exchange(BaseModel):
|
||||
grant_type: str
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
redirect_uri: Optional[str] = None
|
||||
code: Optional[str] = None
|
||||
refresh_token: Optional[str] = None
|
||||
|
||||
class AccessToken(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
expires_in: int
|
||||
refresh_token: str
|
||||
scope: Optional[str]
|
@ -1,11 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class DirectoryInfo(BaseModel):
|
||||
id: int
|
||||
created_at: int
|
||||
updated_at: int
|
||||
name: str
|
||||
path: str
|
||||
is_public: bool
|
||||
used: int
|
@ -1,6 +0,0 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RepositoryInfo(BaseModel):
|
||||
capacity: int
|
||||
used: int
|
@ -1,10 +0,0 @@
|
||||
from typing import Optional, Self
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from materia.api.token import TokenClaims
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
|
@ -1,40 +0,0 @@
|
||||
from typing import Optional, Self
|
||||
from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from materia import db
|
||||
|
||||
|
||||
class NewUser(BaseModel):
|
||||
login: str
|
||||
password: str
|
||||
email: str
|
||||
|
||||
class User(BaseModel):
|
||||
id: str
|
||||
login: str
|
||||
name: str
|
||||
email: str
|
||||
is_admin: bool
|
||||
avatar: Optional[str]
|
||||
|
||||
|
||||
@staticmethod
|
||||
def from_(user: db.User) -> Self:
|
||||
return User(
|
||||
id = str(user.id),
|
||||
login = user.login_name,
|
||||
name = user.name,
|
||||
email = user.email,
|
||||
is_admin = user.is_admin,
|
||||
avatar = user.avatar
|
||||
)
|
||||
|
||||
class RemoveUser(BaseModel):
|
||||
id: UUID
|
||||
|
||||
class LoginUser(BaseModel):
|
||||
email: Optional[str] = None
|
||||
login: Optional[str] = None
|
||||
password: str
|
||||
|
@ -1,27 +0,0 @@
|
||||
from typing import Self
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
import time
|
||||
import datetime
|
||||
|
||||
|
||||
class TokenClaims(BaseModel):
|
||||
sub: str
|
||||
exp: int
|
||||
iat: int
|
||||
|
||||
@staticmethod
|
||||
def create(sub: str, secret: str, duration: int) -> str:
|
||||
now = datetime.datetime.now()
|
||||
iat = now.timestamp()
|
||||
exp = (now + datetime.timedelta(seconds = duration)).timestamp()
|
||||
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat))
|
||||
|
||||
return jwt.encode(claims.model_dump(), secret)
|
||||
|
||||
@staticmethod
|
||||
def verify(token: str, secret: str) -> Self:
|
||||
data = jwt.decode(token, secret, algorithms = ["HS256"])
|
||||
|
||||
return TokenClaims(**data)
|
||||
|
50
materia-server/src/materia_server/routers/api/user.py
Normal file
50
materia-server/src/materia_server/routers/api/user.py
Normal file
@ -0,0 +1,50 @@
|
||||
|
||||
import uuid
|
||||
import io
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile
|
||||
import sqlalchemy as sa
|
||||
from sqids.sqids import Sqids
|
||||
from PIL import Image
|
||||
|
||||
from materia_server.config import Config
|
||||
from materia_server.models import User, UserInfo
|
||||
from materia_server.routers import middleware
|
||||
|
||||
|
||||
router = APIRouter(tags = ["user"])
|
||||
|
||||
@router.get("/user", response_model = UserInfo)
|
||||
async def info(claims = Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()):
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
|
||||
return UserInfo.model_validate(current_user)
|
||||
|
||||
@router.post("/user/avatar")
|
||||
async def avatar(file: UploadFile, user: User = Depends(middleware.user), ctx: middleware.Context = Depends()):
|
||||
async with ctx.database.session() as session:
|
||||
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
|
||||
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
||||
|
||||
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(await file.read()))
|
||||
except OSError as _:
|
||||
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
|
||||
|
||||
try:
|
||||
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
||||
avatars_dir.mkdir()
|
||||
img.save(avatars_dir / avatar_id, format = img.format)
|
||||
except OSError as _:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar")
|
||||
|
||||
if old_avatar := user.avatar:
|
||||
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
||||
old_file.unlink()
|
||||
|
||||
async with ctx.database.session() as session:
|
||||
await session.execute(sa.update(user.User).where(user.User.id == user.id).values(avatar = avatar_id))
|
||||
await session.commit()
|
@ -1,5 +0,0 @@
|
||||
from fastapi import APIRouter
|
||||
from materia_server.routers.api.user import user
|
||||
|
||||
router = APIRouter()
|
||||
router.include_router(user.router)
|
@ -1,19 +0,0 @@
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
|
||||
from materia_server import security
|
||||
from materia_server.routers import context
|
||||
from materia_server.models import user
|
||||
from materia_server.models import auth
|
||||
from materia_server.routers.middleware import JwtMiddleware
|
||||
|
||||
router = APIRouter(tags = ["user"])
|
||||
|
||||
@router.get("/user/identity", response_model = user.UserIdentity)
|
||||
async def identity(request: Request, claims = Depends(JwtMiddleware()), ctx: context.Context = Depends()):
|
||||
if not (current_user := await user.User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
|
||||
return user.UserIdentity.model_validate(current_user)
|
@ -1,134 +0,0 @@
|
||||
import io
|
||||
from typing import Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, UploadFile, status
|
||||
from sqlalchemy import delete, select, insert, func, or_, update
|
||||
import bcrypt
|
||||
from sqids.sqids import Sqids
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from materia.config import Config
|
||||
from materia.api.middleware import JwtMiddleware
|
||||
from materia import db
|
||||
from materia.api import schema
|
||||
from materia.api.state import ConfigState, DatabaseState
|
||||
from materia.api.token import TokenClaims
|
||||
|
||||
|
||||
router = APIRouter(tags = ["user"])
|
||||
|
||||
|
||||
@router.post("/user/register", response_model = schema.User)
|
||||
async def register(body: schema.NewUser, database: DatabaseState = Depends()):
|
||||
|
||||
async with database.session() as session:
|
||||
count: Optional[int] = await session.scalar(select(func.count(db.User.id)))
|
||||
|
||||
user = (await session.scalars(
|
||||
select(db.User)
|
||||
.where(or_(db.User.login_name == body.login, db.User.email == body.email)
|
||||
))).first()
|
||||
|
||||
if user is not None:
|
||||
raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "User already exists")
|
||||
|
||||
hashed_password = bcrypt.hashpw(body.password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
new_user = db.User(
|
||||
login_name = body.login,
|
||||
hashed_password = hashed_password,
|
||||
name = body.login,
|
||||
email = body.email,
|
||||
is_admin = count == 0,
|
||||
)
|
||||
|
||||
async with database.session() as session:
|
||||
user = (await session.scalars(insert(db.User).returning(db.User), [new_user.__dict__])).first()
|
||||
await session.commit()
|
||||
|
||||
return schema.User.from_(user)
|
||||
|
||||
@router.post("/user/remove", status_code = 200)
|
||||
async def remove(body: schema.RemoveUser, database: DatabaseState = Depends()):
|
||||
async with database.session() as session:
|
||||
await session.execute(delete(db.User).where(db.User.id == body.id))
|
||||
await session.commit()
|
||||
|
||||
@router.post("/user/login", status_code = 200, response_model = schema.Token)
|
||||
async def login(body: schema.LoginUser, response: Response, database: DatabaseState = Depends(), config: ConfigState = Depends()) -> Any:
|
||||
query = select(db.User)
|
||||
if login := body.login:
|
||||
query = query.where(db.User.login_name == login)
|
||||
elif email := body.email:
|
||||
query = query.where(db.User.email == email)
|
||||
else:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing credentials")
|
||||
|
||||
async with database.session() as session:
|
||||
if not (user := (await session.scalars(query)).first()):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
|
||||
if not bcrypt.checkpw(body.password.encode(), user.hashed_password.encode()):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid password")
|
||||
|
||||
token = TokenClaims.create(
|
||||
str(user.id),
|
||||
config.jwt.secret,
|
||||
config.jwt.maxage
|
||||
)
|
||||
|
||||
response.set_cookie(
|
||||
"token",
|
||||
value = token,
|
||||
max_age = config.jwt.maxage,
|
||||
secure = True,
|
||||
httponly = True,
|
||||
samesite = "none"
|
||||
)
|
||||
|
||||
return schema.Token(access_token = token)
|
||||
|
||||
@router.get("/user/logout", status_code = 200)
|
||||
async def logout(response: Response):
|
||||
response.set_cookie(
|
||||
"token",
|
||||
value = "",
|
||||
max_age = -1,
|
||||
secure = True,
|
||||
httponly = True,
|
||||
samesite = "none"
|
||||
)
|
||||
|
||||
@router.get("/user/current", dependencies = [Depends(JwtMiddleware())], response_model = schema.User)
|
||||
async def current(request: Request):
|
||||
return schema.User.from_(request.state.user)
|
||||
|
||||
@router.post("/user/avatar", dependencies = [Depends(JwtMiddleware())])
|
||||
async def avatar(request: Request, file: UploadFile, database: DatabaseState = Depends()):
|
||||
async with database.session() as session:
|
||||
avatars: list[str] = (await session.scalars(select(db.User.avatar))).all()
|
||||
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
||||
|
||||
avatar_id = Sqids(min_length = 10, blocklist = avatars).encode([len(avatars)])
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(await file.read()))
|
||||
except OSError as _:
|
||||
raise HTTPException(status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data")
|
||||
|
||||
try:
|
||||
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
||||
avatars_dir.mkdir()
|
||||
img.save(avatars_dir / avatar_id, format = img.format)
|
||||
except OSError as _:
|
||||
raise HTTPException(status.WS_1011_INTERNAL_ERROR, "Failed to save avatar")
|
||||
|
||||
if old_avatar := request.state.user.avatar:
|
||||
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
||||
old_file.unlink()
|
||||
|
||||
async with database.session() as session:
|
||||
await session.execute(update(db.User).where(db.User.id == request.state.user.id).values(avatar = avatar_id))
|
||||
await session.commit()
|
||||
|
||||
|
@ -1,15 +0,0 @@
|
||||
from fastapi import Request
|
||||
|
||||
from materia_server.config import Config
|
||||
from materia_server.models.database import Database, Cache
|
||||
from materia_server._logging import Logger
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, request: Request):
|
||||
self.config = request.state.config
|
||||
self.database = request.state.database
|
||||
#self.cache = request.state.cache
|
||||
self.logger = request.state.logger
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
from typing import Optional, Sequence
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from fastapi import HTTPException, Request, Response, status, Depends, Cookie
|
||||
from fastapi.security.base import SecurityBase
|
||||
import jwt
|
||||
@ -10,109 +11,71 @@ from http import HTTPMethod as HttpMethod
|
||||
from fastapi.security import HTTPBearer, OAuth2PasswordBearer, OAuth2PasswordRequestForm, APIKeyQuery, APIKeyCookie, APIKeyHeader
|
||||
|
||||
from materia_server import security
|
||||
from materia_server.routers import context
|
||||
from materia_server.models import user
|
||||
from materia_server.models import User
|
||||
|
||||
|
||||
async def get_token_claims(token, ctx: context.Context = Depends()) -> security.TokenClaims:
|
||||
class Context:
|
||||
def __init__(self, request: Request):
|
||||
self.config = request.state.config
|
||||
self.database = request.state.database
|
||||
self.cache = request.state.cache
|
||||
self.logger = request.state.logger
|
||||
|
||||
|
||||
async def jwt_cookie(request: Request, response: Response, ctx: Context = Depends()):
|
||||
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
|
||||
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
|
||||
|
||||
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
||||
secret = ctx.config.oauth2.jwt_secret
|
||||
else:
|
||||
secret = ctx.config.oauth2.jwt_signing_key
|
||||
|
||||
issuer = "{}://{}".format(ctx.config.server.scheme, ctx.config.server.domain)
|
||||
|
||||
try:
|
||||
secret = ctx.config.oauth2.jwt_secret if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"] else ctx.config.oauth2.jwt_signing_key
|
||||
claims = security.validate_token(token, secret)
|
||||
user_id = uuid.UUID(claims.sub) # type: ignore
|
||||
except jwt.PyJWKError as _:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid token")
|
||||
except ValueError as _:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid token")
|
||||
refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
|
||||
|
||||
if refresh_claims:
|
||||
if refresh_claims.exp < datetime.now().timestamp():
|
||||
refresh_claims = None
|
||||
except jwt.PyJWTError:
|
||||
refresh_claims = None
|
||||
|
||||
if not (current_user := await user.User.by_id(user_id, ctx.database)):
|
||||
try:
|
||||
access_claims = security.validate_token(access_token, secret)
|
||||
|
||||
if access_claims.exp < datetime.now().timestamp():
|
||||
if refresh_claims:
|
||||
new_access_token = security.generate_token(
|
||||
access_claims.sub,
|
||||
str(secret),
|
||||
ctx.config.oauth2.access_token_lifetime,
|
||||
issuer
|
||||
)
|
||||
access_claims = security.validate_token(new_access_token, secret)
|
||||
response.set_cookie(
|
||||
ctx.config.security.cookie_access_token_name,
|
||||
value = new_access_token,
|
||||
max_age = ctx.config.oauth2.access_token_lifetime,
|
||||
secure = True,
|
||||
httponly = ctx.config.security.cookie_http_only,
|
||||
samesite = "lax"
|
||||
)
|
||||
else:
|
||||
access_claims = None
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
||||
|
||||
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||
|
||||
return access_claims
|
||||
|
||||
|
||||
async def user(claims = Depends(jwt_cookie), ctx: Context = Depends()):
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
|
||||
return claims
|
||||
|
||||
class JwtBearer(HTTPBearer):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(scheme_name = "Bearer", **kwargs)
|
||||
self.claims = None
|
||||
|
||||
async def __call__(self, request: Request, ctx: context.Context = Depends()):
|
||||
if credentials := await super().__call__(request):
|
||||
token = credentials.credentials
|
||||
else:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Missing token")
|
||||
|
||||
self.claims = await get_token_claims(token, ctx)
|
||||
|
||||
class JwtCookie(SecurityBase):
|
||||
def __init(self, *, auto_error: bool = True):
|
||||
self.auto_error = auto_error
|
||||
self.claims = None
|
||||
|
||||
async def __call__(self, request: Request, response: Response, ctx: context.Context = Depends()):
|
||||
if not (access_token := request.cookies.get(ctx.config.security.cookie_access_token_name)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing token")
|
||||
refresh_token = request.cookies.get(ctx.config.security.cookie_refresh_token_name)
|
||||
|
||||
if ctx.config.oauth2.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
|
||||
secret = ctx.config.oauth2.jwt_secret
|
||||
else:
|
||||
secret = ctx.config.oauth2.jwt_signing_key
|
||||
|
||||
try:
|
||||
refresh_claims = security.validate_token(refresh_token, secret) if refresh_token else None
|
||||
# TODO: check expiration
|
||||
except jwt.PyJWTError:
|
||||
refresh_claims = None
|
||||
|
||||
try:
|
||||
access_claims = security.validate_token(access_token, secret)
|
||||
# TODO: if exp then check refresh token and create new else raise
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
||||
else:
|
||||
# TODO: validate user
|
||||
pass
|
||||
|
||||
self.claims = access_claims
|
||||
|
||||
|
||||
WILDCARD = "*"
|
||||
NULL = "null"
|
||||
|
||||
class HttpHeader(StrEnum):
|
||||
ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"
|
||||
ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"
|
||||
ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"
|
||||
ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"
|
||||
ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"
|
||||
ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"
|
||||
CONTENT_TYPE = "Content-Type"
|
||||
AUTHORIZATION = "Authorization"
|
||||
VARY = "Vary"
|
||||
ORIGIN = "Origin"
|
||||
|
||||
class CorsMiddleware(BaseModel):
|
||||
allow_credentials: bool = False
|
||||
allow_headers: Sequence[HttpHeader | str] = []
|
||||
allow_methods: Sequence[HttpMethod | str] = []
|
||||
allow_origin: str = WILDCARD
|
||||
expose_headers: Sequence[HttpHeader | str] = []
|
||||
max_age: int = 600
|
||||
|
||||
|
||||
async def __call__(self, request: Request, response: Response):
|
||||
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_CREDENTIALS] = str(self.allow_credentials).lower()
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_HEADERS] = self.make_from(self.allow_headers)
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_METHODS] = self.make_from(self.allow_methods)
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_ALLOW_ORIGIN] = str(self.allow_origin)
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_EXPOSE_HEADERS] = self.make_from(self.expose_headers)
|
||||
response.headers[HttpHeader.ACCESS_CONTROL_MAX_AGE] = str(self.max_age)
|
||||
|
||||
def make_from(self, value: Sequence[HttpHeader | HttpMethod | str]) -> str:
|
||||
if WILDCARD in value:
|
||||
return WILDCARD
|
||||
elif NULL in value:
|
||||
return NULL
|
||||
else:
|
||||
return ",".join(set(value))
|
||||
|
||||
return current_user
|
||||
|
Loading…
Reference in New Issue
Block a user