diff --git a/.gitignore b/.gitignore index 02dbd42..4fb9e1a 100644 --- a/.gitignore +++ b/.gitignore @@ -10,3 +10,6 @@ __pycache__/ .pdm.toml .pdm-python .pdm-build + +.pytest_cache +.coverage diff --git a/pdm.lock b/pdm.lock index 030fd82..c8d8953 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:fe3214096aaef3097e2009f717762fb370bb726aa89a52e7b2a40d60016be987" +content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20" [[package]] name = "aiofiles" @@ -97,6 +97,20 @@ files = [ {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, ] +[[package]] +name = "asgi-lifespan" +version = "2.1.0" +requires_python = ">=3.7" +summary = "Programmatic startup/shutdown of ASGI apps." +groups = ["dev"] +dependencies = [ + "sniffio", +] +files = [ + {file = "asgi-lifespan-2.1.0.tar.gz", hash = "sha256:5e2effaf0bfe39829cf2d64e7ecc47c7d86d676a6599f7afba378c31f5e3a308"}, + {file = "asgi_lifespan-2.1.0-py3-none-any.whl", hash = "sha256:ed840706680e28428c01e14afb3875d7d76d3206f3d5b2f2294e059b5c23804f"}, +] + [[package]] name = "asyncpg" version = "0.29.0" @@ -296,6 +310,52 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "coverage" +version = "7.6.1" +requires_python = ">=3.8" +summary = "Code coverage measurement for Python" +groups = ["dev"] +files = [ + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, +] + +[[package]] +name = "coverage" +version = "7.6.1" +extras = ["toml"] +requires_python = ">=3.8" +summary = "Code coverage measurement for Python" +groups = ["dev"] +dependencies = [ + "coverage==7.6.1", +] +files = [ + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, +] + [[package]] name = "cryptography" version = "43.0.0" @@ -1044,6 +1104,21 @@ files = [ {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, ] +[[package]] +name = "pytest-cov" +version = "5.0.0" +requires_python = ">=3.8" +summary = "Pytest plugin for measuring coverage." +groups = ["dev"] +dependencies = [ + "coverage[toml]>=5.2.1", + "pytest>=4.6", +] +files = [ + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1157,7 +1232,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["default"] +groups = ["default", "dev"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, diff --git a/pyproject.toml b/pyproject.toml index 3140b1a..f94323f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,16 +41,6 @@ requires-python = ">=3.12,<3.13" readme = "README.md" license = {text = "MIT"} -[tool.pdm.dev-dependencies] -dev = [ - "-e file:///${PROJECT_ROOT}/workspaces/frontend", - "black<24.0.0,>=23.3.0", - "pytest<8.0.0,>=7.3.2", - "pyflakes<4.0.0,>=3.0.1", - "pyright<2.0.0,>=1.1.314", - "pytest-asyncio>=0.23.7", -] - [build-system] requires = ["pdm-backend"] build-backend = "pdm.backend" @@ -65,8 +55,20 @@ reportGeneralTypeIssues = false pythonpath = ["."] testpaths = ["tests"] + [tool.pdm] distribution = true +[tool.pdm.dev-dependencies] +dev = [ + "-e file:///${PROJECT_ROOT}/workspaces/frontend", + "black<24.0.0,>=23.3.0", + "pytest<8.0.0,>=7.3.2", + "pyflakes<4.0.0,>=3.0.1", + "pyright<2.0.0,>=1.1.314", + "pytest-asyncio>=0.23.7", + "asgi-lifespan>=2.1.0", + "pytest-cov>=5.0.0", +] [tool.pdm.build] includes = ["src/materia"] diff --git a/src/materia/config.py b/src/materia/config.py index 751921f..638da4b 100644 --- a/src/materia/config.py +++ b/src/materia/config.py @@ -44,6 +44,9 @@ class Server(BaseModel): port: int = 54601 domain: str = "localhost" + def url(self) -> str: + return "{}://{}:{}".format(self.scheme, self.address, self.port) + class Database(BaseModel): backend: Literal["postgresql"] = "postgresql" diff --git a/src/materia/models/database/database.py b/src/materia/models/database/database.py index ced75f5..4c75378 100644 --- a/src/materia/models/database/database.py +++ b/src/materia/models/database/database.py @@ -80,7 +80,7 @@ class Database: async with database.connection() as connection: await connection.rollback() except Exception as e: - raise DatabaseError(f"{e}") + raise DatabaseError(f"Failed to connect to database: {url}") from e return database @@ -94,7 +94,7 @@ class Database: yield connection except Exception as e: await connection.rollback() - raise DatabaseError(f"{e}") + raise DatabaseError(*e.args) from e @asynccontextmanager async def session(self) -> SessionContext: @@ -102,12 +102,12 @@ class Database: try: yield session + except HTTPException as e: + await session.rollback() + raise e from None except Exception as e: await session.rollback() - raise DatabaseError(f"{e}") - except HTTPException: - # if the initial exception reaches HTTPException, then everything is handled fine (should be) - await session.rollback() + raise DatabaseError(*e.args) from e finally: await session.close() diff --git a/src/materia/models/repository.py b/src/materia/models/repository.py index d7a4141..9bda11d 100644 --- a/src/materia/models/repository.py +++ b/src/materia/models/repository.py @@ -98,8 +98,14 @@ class Repository(Base): await session.refresh(user, attribute_names=["repository"]) return user.repository - async def info(self) -> "RepositoryInfo": - return RepositoryInfo.model_validate(self) + async def info(self, session: SessionContext) -> "RepositoryInfo": + session.add(self) + await session.refresh(self, attribute_names=["files"]) + + info = RepositoryInfo.model_validate(self) + info.used = sum([file.size for file in self.files]) + + return info class RepositoryInfo(BaseModel): diff --git a/src/materia/models/user.py b/src/materia/models/user.py index 86a9745..e9055b1 100644 --- a/src/materia/models/user.py +++ b/src/materia/models/user.py @@ -1,5 +1,5 @@ from uuid import UUID, uuid4 -from typing import Optional, Self +from typing import Optional, Self, BinaryIO import time import re @@ -8,6 +8,9 @@ import pydantic from sqlalchemy import BigInteger, Enum from sqlalchemy.orm import mapped_column, Mapped, relationship import sqlalchemy as sa +from PIL import Image +from sqids.sqids import Sqids +from aiofiles import os as async_os from materia import security from materia.models.base import Base @@ -82,8 +85,7 @@ class User(Base): @staticmethod def check_password(password: str, config: Config) -> bool: - if len(password) < config.security.password_min_length: - return False + return not len(password) < config.security.password_min_length @staticmethod async def count(session: SessionContext) -> Optional[int]: @@ -146,6 +148,57 @@ class User(Base): return user_info + async def edit_avatar( + self, avatar: BinaryIO | None, session: SessionContext, config: Config + ): + avatar_dir = config.application.working_directory.joinpath("avatars") + + if avatar is None: + if self.avatar is None: + return + + avatar_file = FileSystem( + avatar_dir.joinpath(self.avatar), config.application.working_directory + ) + if await avatar_file.exists(): + await avatar_file.remove() + + session.add(self) + self.avatar = None + await session.flush() + + return + + try: + image = Image.open(avatar) + except Exception as e: + raise UserError("Failed to read avatar data") from e + + avatar_hashes: list[str] = ( + await session.scalars(sa.select(User.avatar).where(User.avatar.isnot(None))) + ).all() + avatar_id = Sqids(min_length=10, blocklist=avatar_hashes).encode( + [int(time.time())] + ) + + try: + if not avatar_dir.exists(): + await async_os.mkdir(avatar_dir) + image.save(avatar_dir.joinpath(avatar_id), format=image.format) + except Exception as e: + raise UserError(f"Failed to save avatar: {e}") from e + + if old_avatar := self.avatar: + avatar_file = FileSystem( + avatar_dir.joinpath(old_avatar), config.application.working_directory + ) + if await avatar_file.exists(): + await avatar_file.remove() + + session.add(self) + self.avatar = avatar_id + await session.flush() + class UserCredentials(BaseModel): name: str @@ -177,3 +230,4 @@ class UserInfo(BaseModel): from materia.models.repository import Repository +from materia.models.filesystem import FileSystem diff --git a/src/materia/routers/api/auth/auth.py b/src/materia/routers/api/auth/auth.py index 099f4ec..50bdeae 100644 --- a/src/materia/routers/api/auth/auth.py +++ b/src/materia/routers/api/auth/auth.py @@ -10,24 +10,21 @@ router = APIRouter(tags=["auth"]) @router.post("/auth/signup") async def signup(body: UserCredentials, ctx: Context = Depends()): + if not User.check_username(body.name): + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username" + ) + if not User.check_password(body.password, ctx.config): + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})", + ) + async with ctx.database.session() as session: - if not User.check_username(body.name): - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username" - ) - if not User.check_password(body.password, ctx.config): - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, - detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})", - ) if await User.by_name(body.name, session, with_lower=True): - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, detail="User already exists" - ) + raise HTTPException(status.HTTP_409_CONFLICT, detail="User already exists") if await User.by_email(body.email, session): # type: ignore - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email already used" - ) + raise HTTPException(status.HTTP_409_CONFLICT, detail="Email already used") count: Optional[int] = await User.count(session) @@ -42,14 +39,20 @@ async def signup(body: UserCredentials, ctx: Context = Depends()): login_type=LoginType.Plain, # first registered user is admin is_admin=count == 0, - ).new(session) + ).new(session, ctx.config) + + await session.commit() @router.post("/auth/signin") async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()): - if (current_user := await User.by_name(body.name, ctx.database)) is None: - if (current_user := await User.by_email(str(body.email), ctx.database)) is None: - raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid email") + async with ctx.database.session() as session: + if (current_user := await User.by_name(body.name, session)) is None: + if (current_user := await User.by_email(str(body.email), session)) is None: + raise HTTPException( + status.HTTP_401_UNAUTHORIZED, detail="Invalid email" + ) + if not security.validate_password( body.password, current_user.hashed_password, diff --git a/src/materia/routers/api/repository.py b/src/materia/routers/api/repository.py index 08b8263..c53676b 100644 --- a/src/materia/routers/api/repository.py +++ b/src/materia/routers/api/repository.py @@ -22,14 +22,15 @@ async def create( user: User = Depends(middleware.user), ctx: middleware.Context = Depends() ): async with ctx.database.session() as session: - if await Repository.by_user(user, session): + if await Repository.from_user(user, session): raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists") async with ctx.database.session() as session: try: await Repository( user_id=user.id, capacity=ctx.config.repository.capacity - ).new(session) + ).new(session, ctx.config) + await session.commit() except Exception as e: raise HTTPException( status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args) @@ -41,13 +42,7 @@ async def info( repository=Depends(middleware.repository), ctx: middleware.Context = Depends() ): async with ctx.database.session() as session: - session.add(repository) - await session.refresh(repository, attribute_names=["files"]) - - info = RepositoryInfo.model_validate(repository) - info.used = sum([file.size for file in repository.files]) - - return info + return await repository.info(session) @router.delete("/repository") diff --git a/src/materia/routers/api/user.py b/src/materia/routers/api/user.py index d8b8730..5918401 100644 --- a/src/materia/routers/api/user.py +++ b/src/materia/routers/api/user.py @@ -19,71 +19,56 @@ router = APIRouter(tags=["user"]) async def info( claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends() ): - if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") + async with ctx.database.session() as session: + if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") - info = UserInfo.model_validate(current_user) - if current_user.is_email_private: - info.email = None - - return info + return current_user.info() @router.delete("/user") async def remove( user: User = Depends(middleware.user), ctx: middleware.Context = Depends() ): - repository_path = Config.data_dir() / "repository" / user.lower_name - - async with ctx.database.session() as session: - session.add(user) - await session.refresh(user, attribute_names=["repository"]) - try: - if repository_path.exists(): - shutil.rmtree(str(repository_path)) - except OSError: + async with ctx.database.session() as session: + await user.remove(session) + await session.commit() + + except Exception as e: raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to remove user" - ) - - await user.repository.remove(ctx.database) + status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to remove user: {e}" + ) from e -@router.post("/user/avatar") +@router.put("/user/avatar") async def avatar( file: UploadFile, user: User = Depends(middleware.user), ctx: middleware.Context = Depends(), ): async with ctx.database.session() as session: - avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all() - avatars = list(filter(lambda avatar_hash: avatar_hash, avatars)) + try: + await user.edit_avatar(io.BytesIO(await file.read()), session, ctx.config) + await session.commit() + except Exception as e: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + f"{e}", + ) - avatar_id = Sqids(min_length=10, blocklist=avatars).encode([len(avatars)]) - - try: - img = Image.open(io.BytesIO(await file.read())) - except OSError as _: - raise HTTPException( - status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data" - ) - - try: - if not (avatars_dir := Config.data_dir() / "avatars").exists(): - avatars_dir.mkdir() - img.save(avatars_dir / avatar_id, format=img.format) - except OSError as _: - raise HTTPException( - status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar" - ) - - if old_avatar := user.avatar: - if (old_file := Config.data_dir() / "avatars" / old_avatar).exists(): - old_file.unlink() +@router.delete("/user/avatar") +async def remove_avatar( + user: User = Depends(middleware.user), + ctx: middleware.Context = Depends(), +): async with ctx.database.session() as session: - await session.execute( - sa.update(User).where(User.id == user.id).values(avatar=avatar_id) - ) - await session.commit() + try: + await user.edit_avatar(None, session, ctx.config) + await session.commit() + except Exception as e: + raise HTTPException( + status.HTTP_500_INTERNAL_SERVER_ERROR, + f"{e}", + ) diff --git a/src/materia/routers/middleware.py b/src/materia/routers/middleware.py index 1f0b383..995009b 100644 --- a/src/materia/routers/middleware.py +++ b/src/materia/routers/middleware.py @@ -82,15 +82,17 @@ async def jwt_cookie(request: Request, response: Response, ctx: Context = Depend except jwt.PyJWTError as e: raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}") - if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user") + async with ctx.database.session() as session: + if not await User.by_id(uuid.UUID(access_claims.sub), session): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user") return access_claims async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User: - if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)): - raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") + async with ctx.database.session() as session: + if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") return current_user diff --git a/src/materia/routers/root.py b/src/materia/routers/root.py index 6b26ac2..ad0f4e1 100644 --- a/src/materia/routers/root.py +++ b/src/materia/routers/root.py @@ -15,6 +15,4 @@ else: @router.get("/{spa:path}", response_class=HTMLResponse) async def root(request: Request): - return templates.TemplateResponse( - "base.html", {"request": request, "view": "app"} - ) + return templates.TemplateResponse(request, "base.html", {"view": "app"}) diff --git a/src/materia/security/token.py b/src/materia/security/token.py index 86d1266..43c5080 100644 --- a/src/materia/security/token.py +++ b/src/materia/security/token.py @@ -1,26 +1,29 @@ from typing import Optional -import datetime +import datetime from pydantic import BaseModel -import jwt +import jwt class TokenClaims(BaseModel): - sub: str - exp: int - iat: int + sub: str + exp: int + iat: int iss: Optional[str] = None -def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = None) -> str: +def generate_token( + sub: str, secret: str, duration: int, iss: Optional[str] = None +) -> str: now = datetime.datetime.now() iat = now.timestamp() - exp = (now + datetime.timedelta(seconds = duration)).timestamp() - claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat), iss = iss) + exp = (now + datetime.timedelta(seconds=duration)).timestamp() + claims = TokenClaims(sub=sub, exp=int(exp), iat=int(iat), iss=iss) return jwt.encode(claims.model_dump(), secret) + def validate_token(token: str, secret: str) -> TokenClaims: - payload = jwt.decode(token, secret, algorithms = [ "HS256" ]) + payload = jwt.decode(token, secret, algorithms=["HS256"]) return TokenClaims(**payload) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..3310da1 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,178 @@ +import pytest_asyncio +from materia.config import Config +from materia.models import ( + Database, + Cache, + User, + LoginType, +) +from materia.models.base import Base +from materia import security +import sqlalchemy as sa +from sqlalchemy.pool import NullPool +from materia.app import make_application, AppContext +from materia._logging import make_logger +from httpx import AsyncClient, ASGITransport, Cookies +import asyncio +from fastapi import FastAPI +from contextlib import asynccontextmanager +from typing import AsyncIterator +from asgi_lifespan import LifespanManager +from fastapi.middleware.cors import CORSMiddleware +from materia import routers +from pathlib import Path + + +@pytest_asyncio.fixture(scope="session") +async def config() -> Config: + conf = Config() + conf.database.port = 54320 + conf.cache.port = 63790 + + return conf + + +@pytest_asyncio.fixture(scope="session") +async def database(config: Config) -> Database: + config_postgres = config + config_postgres.database.user = "postgres" + config_postgres.database.name = "postgres" + database_postgres = await Database.new( + config_postgres.database.url(), poolclass=NullPool + ) + + async with database_postgres.connection() as connection: + await connection.execution_options(isolation_level="AUTOCOMMIT") + await connection.execute(sa.text("create role pytest login")) + await connection.execute(sa.text("create database pytest owner pytest")) + await connection.commit() + + await database_postgres.dispose() + + config.database.user = "pytest" + config.database.name = "pytest" + database_pytest = await Database.new(config.database.url(), poolclass=NullPool) + + yield database_pytest + + await database_pytest.dispose() + + async with database_postgres.connection() as connection: + await connection.execution_options(isolation_level="AUTOCOMMIT") + await connection.execute(sa.text("drop database pytest")), + await connection.execute(sa.text("drop role pytest")) + await connection.commit() + await database_postgres.dispose() + + +@pytest_asyncio.fixture(scope="session") +async def cache(config: Config) -> Cache: + config_pytest = config + config_pytest.cache.user = "pytest" + cache_pytest = await Cache.new(config_pytest.cache.url()) + + yield cache_pytest + + +@pytest_asyncio.fixture(scope="function", autouse=True) +async def setup_database(database: Database): + async with database.connection() as connection: + await connection.run_sync(Base.metadata.create_all) + await connection.commit() + yield + async with database.connection() as connection: + await connection.run_sync(Base.metadata.drop_all) + await connection.commit() + + +@pytest_asyncio.fixture() +async def session(database: Database, request): + session = database.sessionmaker() + yield session + await session.rollback() + await session.close() + + +@pytest_asyncio.fixture(scope="function") +async def data(config: Config): + class TestData: + user = User( + name="PyTest", + lower_name="pytest", + email="pytest@example.com", + hashed_password=security.hash_password( + "iampytest", algo=config.security.password_hash_algo + ), + login_type=LoginType.Plain, + is_admin=True, + ) + + return TestData() + + +@pytest_asyncio.fixture(scope="function") +async def api_config(config: Config, tmpdir) -> Config: + config.application.working_directory = Path(tmpdir) + config.oauth2.jwt_secret = "pytest_secret_key" + + yield config + + +@pytest_asyncio.fixture(scope="function") +async def api_client( + api_config: Config, database: Database, cache: Cache +) -> AsyncClient: + + logger = make_logger(api_config) + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]: + yield AppContext( + config=api_config, database=database, cache=cache, logger=logger + ) + + app = FastAPI(lifespan=lifespan) + app.include_router(routers.api.router) + app.include_router(routers.resources.router) + app.include_router(routers.root.router) + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + async with LifespanManager(app) as manager: + async with AsyncClient( + transport=ASGITransport(app=manager.app), base_url=api_config.server.url() + ) as client: + yield client + + +@pytest_asyncio.fixture(scope="function") +async def auth_client(api_client: AsyncClient, api_config: Config) -> AsyncClient: + data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"} + + await api_client.post( + "/api/auth/signup", + json=data, + ) + auth = await api_client.post( + "/api/auth/signin", + json=data, + ) + + cookies = Cookies() + cookies.set( + "materia_at", + auth.cookies[api_config.security.cookie_access_token_name], + ) + cookies.set( + "materia_rt", + auth.cookies[api_config.security.cookie_refresh_token_name], + ) + + api_client.cookies = cookies + + yield api_client diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 0000000..6e46696 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,86 @@ +import pytest +from materia.config import Config +from httpx import AsyncClient, Cookies +from materia.models.base import Base +import aiofiles +from io import BytesIO + + +@pytest.mark.asyncio +async def test_auth(api_client: AsyncClient, api_config: Config): + data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"} + + response = await api_client.post( + "/api/auth/signup", + json=data, + ) + assert response.status_code == 200 + + response = await api_client.post( + "/api/auth/signin", + json=data, + ) + assert response.status_code == 200 + assert response.cookies.get(api_config.security.cookie_access_token_name) + assert response.cookies.get(api_config.security.cookie_refresh_token_name) + + # TODO: conflict usernames and emails + + response = await api_client.get("/api/auth/signout") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_user(auth_client: AsyncClient, api_config: Config): + info = await auth_client.get("/api/user") + assert info.status_code == 200, info.text + + async with AsyncClient() as client: + pytest_logo_res = await client.get( + "https://docs.pytest.org/en/stable/_static/pytest1.png" + ) + assert isinstance(pytest_logo_res.content, bytes) + pytest_logo = BytesIO(pytest_logo_res.content) + + avatar = await auth_client.put( + "/api/user/avatar", + files={"file": ("pytest.png", pytest_logo)}, + ) + assert avatar.status_code == 200, avatar.text + + info = await auth_client.get("/api/user") + avatar_info = info.json()["avatar"] + assert avatar_info is not None + assert api_config.application.working_directory.joinpath( + "avatars", avatar_info + ).exists() + + avatar = await auth_client.delete("/api/user/avatar") + assert avatar.status_code == 200, avatar.text + + info = await auth_client.get("/api/user") + assert info.json()["avatar"] is None + assert not api_config.application.working_directory.joinpath( + "avatars", avatar_info + ).exists() + + delete = await auth_client.delete("/api/user") + assert delete.status_code == 200, delete.text + + info = await auth_client.get("/api/user") + assert info.status_code == 401, info.text + + +@pytest.mark.asyncio +async def test_repository(auth_client: AsyncClient, api_config: Config): + info = await auth_client.get("/api/repository") + assert info.status_code == 404, info.text + + create = await auth_client.post("/api/repository") + assert create.status_code == 200, create.text + + create = await auth_client.post("/api/repository") + assert create.status_code == 409, create.text + + info = await auth_client.get("/api/repository") + assert info.status_code == 200, info.text diff --git a/tests/test_database.py b/tests/test_models.py similarity index 69% rename from tests/test_database.py rename to tests/test_models.py index 4f2eba2..dc05c39 100644 --- a/tests/test_database.py +++ b/tests/test_models.py @@ -1,145 +1,23 @@ import pytest_asyncio import pytest -import os -import sys from pathlib import Path from materia.config import Config from materia.models import ( - Database, User, - LoginType, Repository, Directory, RepositoryError, File, ) -from materia.models.base import Base from materia.models.database import SessionContext from materia import security import sqlalchemy as sa -from sqlalchemy.pool import NullPool from sqlalchemy.orm.session import make_transient from sqlalchemy import inspect import aiofiles import aiofiles.os -@pytest_asyncio.fixture(scope="session") -async def config() -> Config: - conf = Config() - conf.database.port = 54320 - # conf.application.working_directory = conf.application.working_directory / "temp" - # if (cwd := conf.application.working_directory.resolve()).exists(): - # os.chdir(cwd) - # if local_conf := Config.open(cwd / "config.toml"): - # conf = local_conf - return conf - - -@pytest_asyncio.fixture(scope="session") -async def db(config: Config, request) -> Database: - config_postgres = config - config_postgres.database.user = "postgres" - config_postgres.database.name = "postgres" - database_postgres = await Database.new( - config_postgres.database.url(), poolclass=NullPool - ) - - async with database_postgres.connection() as connection: - await connection.execution_options(isolation_level="AUTOCOMMIT") - await connection.execute(sa.text("create role pytest login")) - await connection.execute(sa.text("create database pytest owner pytest")) - await connection.commit() - - await database_postgres.dispose() - - config.database.user = "pytest" - config.database.name = "pytest" - database = await Database.new(config.database.url(), poolclass=NullPool) - - yield database - - await database.dispose() - - async with database_postgres.connection() as connection: - await connection.execution_options(isolation_level="AUTOCOMMIT") - await connection.execute(sa.text("drop database pytest")), - await connection.execute(sa.text("drop role pytest")) - await connection.commit() - await database_postgres.dispose() - - -""" -@pytest.mark.asyncio -async def test_migrations(db): - await db.run_migrations() - await db.rollback_migrations() -""" - - -@pytest_asyncio.fixture(scope="session", autouse=True) -async def setup_db(db: Database, request): - async with db.connection() as connection: - await connection.run_sync(Base.metadata.create_all) - await connection.commit() - yield - async with db.connection() as connection: - await connection.run_sync(Base.metadata.drop_all) - await connection.commit() - - -@pytest_asyncio.fixture(autouse=True) -async def session(db: Database, request): - session = db.sessionmaker() - yield session - await session.rollback() - await session.close() - - -""" -@pytest_asyncio.fixture(scope="session") -async def user(config: Config, session) -> User: - test_user = User( - name="pytest", - lower_name="pytest", - email="pytest@example.com", - hashed_password=security.hash_password( - "iampytest", algo=config.security.password_hash_algo - ), - login_type=LoginType.Plain, - is_admin=True, - ) - - async with db.session() as session: - session.add(test_user) - await session.flush() - await session.refresh(test_user) - - yield test_user - - async with db.session() as session: - await session.delete(test_user) - await session.flush() -""" - - -@pytest_asyncio.fixture(scope="function") -async def data(config: Config): - class TestData: - user = User( - name="PyTest", - lower_name="pytest", - email="pytest@example.com", - hashed_password=security.hash_password( - "iampytest", algo=config.security.password_hash_algo - ), - login_type=LoginType.Plain, - is_admin=True, - ) - - return TestData() - - @pytest.mark.asyncio async def test_user(data, session: SessionContext, config: Config): # simple @@ -161,6 +39,10 @@ async def test_user(data, session: SessionContext, config: Config): await data.user.edit_name("AsyncPyTest", session) assert await User.by_name("asyncpytest", session, with_lower=True) == data.user + assert await User.by_email("pytest@example.com", session) == data.user + assert await User.by_id(data.user.id, session) == data.user + await data.user.edit_password("iamnotpytest", session, config) + assert security.validate_password("iamnotpytest", data.user.hashed_password) await data.user.remove(session) @@ -280,9 +162,9 @@ async def test_directory(data, tmpdir, session: SessionContext, config: Config): # rename assert (await directory.rename("test1", session, config)).name == "test1" - directory2 = await Directory( - repository_id=repository.id, parent_id=None, name="test2" - ).new(session, config) + await Directory(repository_id=repository.id, parent_id=None, name="test2").new( + session, config + ) assert (await directory.rename("test2", session, config)).name == "test2.1" assert (await repository.path(session, config)).joinpath("test2.1").exists() assert not (await repository.path(session, config)).joinpath("test1").exists() @@ -358,7 +240,7 @@ async def test_file(data, tmpdir, session: SessionContext, config: Config): assert ( await file.rename("test_file_rename.txt", session, config) ).name == "test_file_rename.txt" - file2 = await File( + await File( repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt" ).new(b"", session, config) assert (