tests and fixtures

This commit is contained in:
L-Nafaryus 2024-08-14 00:56:30 +05:00
parent aefedfe187
commit 58e7175d45
Signed by: L-Nafaryus
GPG Key ID: 553C97999B363D38
16 changed files with 516 additions and 241 deletions

3
.gitignore vendored
View File

@ -10,3 +10,6 @@ __pycache__/
.pdm.toml .pdm.toml
.pdm-python .pdm-python
.pdm-build .pdm-build
.pytest_cache
.coverage

View File

@ -5,7 +5,7 @@
groups = ["default", "dev"] groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1" lock_version = "4.4.1"
content_hash = "sha256:fe3214096aaef3097e2009f717762fb370bb726aa89a52e7b2a40d60016be987" content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20"
[[package]] [[package]]
name = "aiofiles" name = "aiofiles"
@ -97,6 +97,20 @@ files = [
{file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
] ]
[[package]]
name = "asgi-lifespan"
version = "2.1.0"
requires_python = ">=3.7"
summary = "Programmatic startup/shutdown of ASGI apps."
groups = ["dev"]
dependencies = [
"sniffio",
]
files = [
{file = "asgi-lifespan-2.1.0.tar.gz", hash = "sha256:5e2effaf0bfe39829cf2d64e7ecc47c7d86d676a6599f7afba378c31f5e3a308"},
{file = "asgi_lifespan-2.1.0-py3-none-any.whl", hash = "sha256:ed840706680e28428c01e14afb3875d7d76d3206f3d5b2f2294e059b5c23804f"},
]
[[package]] [[package]]
name = "asyncpg" name = "asyncpg"
version = "0.29.0" version = "0.29.0"
@ -296,6 +310,52 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
] ]
[[package]]
name = "coverage"
version = "7.6.1"
requires_python = ">=3.8"
summary = "Code coverage measurement for Python"
groups = ["dev"]
files = [
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
]
[[package]]
name = "coverage"
version = "7.6.1"
extras = ["toml"]
requires_python = ">=3.8"
summary = "Code coverage measurement for Python"
groups = ["dev"]
dependencies = [
"coverage==7.6.1",
]
files = [
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
]
[[package]] [[package]]
name = "cryptography" name = "cryptography"
version = "43.0.0" version = "43.0.0"
@ -1044,6 +1104,21 @@ files = [
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
] ]
[[package]]
name = "pytest-cov"
version = "5.0.0"
requires_python = ">=3.8"
summary = "Pytest plugin for measuring coverage."
groups = ["dev"]
dependencies = [
"coverage[toml]>=5.2.1",
"pytest>=4.6",
]
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"
@ -1157,7 +1232,7 @@ name = "sniffio"
version = "1.3.1" version = "1.3.1"
requires_python = ">=3.7" requires_python = ">=3.7"
summary = "Sniff out which async library your code is running under" summary = "Sniff out which async library your code is running under"
groups = ["default"] groups = ["default", "dev"]
files = [ files = [
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},

View File

@ -41,16 +41,6 @@ requires-python = ">=3.12,<3.13"
readme = "README.md" readme = "README.md"
license = {text = "MIT"} license = {text = "MIT"}
[tool.pdm.dev-dependencies]
dev = [
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
"black<24.0.0,>=23.3.0",
"pytest<8.0.0,>=7.3.2",
"pyflakes<4.0.0,>=3.0.1",
"pyright<2.0.0,>=1.1.314",
"pytest-asyncio>=0.23.7",
]
[build-system] [build-system]
requires = ["pdm-backend"] requires = ["pdm-backend"]
build-backend = "pdm.backend" build-backend = "pdm.backend"
@ -65,8 +55,20 @@ reportGeneralTypeIssues = false
pythonpath = ["."] pythonpath = ["."]
testpaths = ["tests"] testpaths = ["tests"]
[tool.pdm] [tool.pdm]
distribution = true distribution = true
[tool.pdm.dev-dependencies]
dev = [
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
"black<24.0.0,>=23.3.0",
"pytest<8.0.0,>=7.3.2",
"pyflakes<4.0.0,>=3.0.1",
"pyright<2.0.0,>=1.1.314",
"pytest-asyncio>=0.23.7",
"asgi-lifespan>=2.1.0",
"pytest-cov>=5.0.0",
]
[tool.pdm.build] [tool.pdm.build]
includes = ["src/materia"] includes = ["src/materia"]

View File

@ -44,6 +44,9 @@ class Server(BaseModel):
port: int = 54601 port: int = 54601
domain: str = "localhost" domain: str = "localhost"
def url(self) -> str:
return "{}://{}:{}".format(self.scheme, self.address, self.port)
class Database(BaseModel): class Database(BaseModel):
backend: Literal["postgresql"] = "postgresql" backend: Literal["postgresql"] = "postgresql"

View File

@ -80,7 +80,7 @@ class Database:
async with database.connection() as connection: async with database.connection() as connection:
await connection.rollback() await connection.rollback()
except Exception as e: except Exception as e:
raise DatabaseError(f"{e}") raise DatabaseError(f"Failed to connect to database: {url}") from e
return database return database
@ -94,7 +94,7 @@ class Database:
yield connection yield connection
except Exception as e: except Exception as e:
await connection.rollback() await connection.rollback()
raise DatabaseError(f"{e}") raise DatabaseError(*e.args) from e
@asynccontextmanager @asynccontextmanager
async def session(self) -> SessionContext: async def session(self) -> SessionContext:
@ -102,12 +102,12 @@ class Database:
try: try:
yield session yield session
except HTTPException as e:
await session.rollback()
raise e from None
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise DatabaseError(f"{e}") raise DatabaseError(*e.args) from e
except HTTPException:
# if the initial exception reaches HTTPException, then everything is handled fine (should be)
await session.rollback()
finally: finally:
await session.close() await session.close()

View File

@ -98,8 +98,14 @@ class Repository(Base):
await session.refresh(user, attribute_names=["repository"]) await session.refresh(user, attribute_names=["repository"])
return user.repository return user.repository
async def info(self) -> "RepositoryInfo": async def info(self, session: SessionContext) -> "RepositoryInfo":
return RepositoryInfo.model_validate(self) session.add(self)
await session.refresh(self, attribute_names=["files"])
info = RepositoryInfo.model_validate(self)
info.used = sum([file.size for file in self.files])
return info
class RepositoryInfo(BaseModel): class RepositoryInfo(BaseModel):

View File

@ -1,5 +1,5 @@
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from typing import Optional, Self from typing import Optional, Self, BinaryIO
import time import time
import re import re
@ -8,6 +8,9 @@ import pydantic
from sqlalchemy import BigInteger, Enum from sqlalchemy import BigInteger, Enum
from sqlalchemy.orm import mapped_column, Mapped, relationship from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa import sqlalchemy as sa
from PIL import Image
from sqids.sqids import Sqids
from aiofiles import os as async_os
from materia import security from materia import security
from materia.models.base import Base from materia.models.base import Base
@ -82,8 +85,7 @@ class User(Base):
@staticmethod @staticmethod
def check_password(password: str, config: Config) -> bool: def check_password(password: str, config: Config) -> bool:
if len(password) < config.security.password_min_length: return not len(password) < config.security.password_min_length
return False
@staticmethod @staticmethod
async def count(session: SessionContext) -> Optional[int]: async def count(session: SessionContext) -> Optional[int]:
@ -146,6 +148,57 @@ class User(Base):
return user_info return user_info
async def edit_avatar(
self, avatar: BinaryIO | None, session: SessionContext, config: Config
):
avatar_dir = config.application.working_directory.joinpath("avatars")
if avatar is None:
if self.avatar is None:
return
avatar_file = FileSystem(
avatar_dir.joinpath(self.avatar), config.application.working_directory
)
if await avatar_file.exists():
await avatar_file.remove()
session.add(self)
self.avatar = None
await session.flush()
return
try:
image = Image.open(avatar)
except Exception as e:
raise UserError("Failed to read avatar data") from e
avatar_hashes: list[str] = (
await session.scalars(sa.select(User.avatar).where(User.avatar.isnot(None)))
).all()
avatar_id = Sqids(min_length=10, blocklist=avatar_hashes).encode(
[int(time.time())]
)
try:
if not avatar_dir.exists():
await async_os.mkdir(avatar_dir)
image.save(avatar_dir.joinpath(avatar_id), format=image.format)
except Exception as e:
raise UserError(f"Failed to save avatar: {e}") from e
if old_avatar := self.avatar:
avatar_file = FileSystem(
avatar_dir.joinpath(old_avatar), config.application.working_directory
)
if await avatar_file.exists():
await avatar_file.remove()
session.add(self)
self.avatar = avatar_id
await session.flush()
class UserCredentials(BaseModel): class UserCredentials(BaseModel):
name: str name: str
@ -177,3 +230,4 @@ class UserInfo(BaseModel):
from materia.models.repository import Repository from materia.models.repository import Repository
from materia.models.filesystem import FileSystem

View File

@ -10,24 +10,21 @@ router = APIRouter(tags=["auth"])
@router.post("/auth/signup") @router.post("/auth/signup")
async def signup(body: UserCredentials, ctx: Context = Depends()): async def signup(body: UserCredentials, ctx: Context = Depends()):
if not User.check_username(body.name):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
)
if not User.check_password(body.password, ctx.config):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
)
async with ctx.database.session() as session: async with ctx.database.session() as session:
if not User.check_username(body.name):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
)
if not User.check_password(body.password, ctx.config):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
)
if await User.by_name(body.name, session, with_lower=True): if await User.by_name(body.name, session, with_lower=True):
raise HTTPException( raise HTTPException(status.HTTP_409_CONFLICT, detail="User already exists")
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="User already exists"
)
if await User.by_email(body.email, session): # type: ignore if await User.by_email(body.email, session): # type: ignore
raise HTTPException( raise HTTPException(status.HTTP_409_CONFLICT, detail="Email already used")
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email already used"
)
count: Optional[int] = await User.count(session) count: Optional[int] = await User.count(session)
@ -42,14 +39,20 @@ async def signup(body: UserCredentials, ctx: Context = Depends()):
login_type=LoginType.Plain, login_type=LoginType.Plain,
# first registered user is admin # first registered user is admin
is_admin=count == 0, is_admin=count == 0,
).new(session) ).new(session, ctx.config)
await session.commit()
@router.post("/auth/signin") @router.post("/auth/signin")
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()): async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
if (current_user := await User.by_name(body.name, ctx.database)) is None: async with ctx.database.session() as session:
if (current_user := await User.by_email(str(body.email), ctx.database)) is None: if (current_user := await User.by_name(body.name, session)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid email") if (current_user := await User.by_email(str(body.email), session)) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, detail="Invalid email"
)
if not security.validate_password( if not security.validate_password(
body.password, body.password,
current_user.hashed_password, current_user.hashed_password,

View File

@ -22,14 +22,15 @@ async def create(
user: User = Depends(middleware.user), ctx: middleware.Context = Depends() user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
): ):
async with ctx.database.session() as session: async with ctx.database.session() as session:
if await Repository.by_user(user, session): if await Repository.from_user(user, session):
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists") raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
async with ctx.database.session() as session: async with ctx.database.session() as session:
try: try:
await Repository( await Repository(
user_id=user.id, capacity=ctx.config.repository.capacity user_id=user.id, capacity=ctx.config.repository.capacity
).new(session) ).new(session, ctx.config)
await session.commit()
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args) status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args)
@ -41,13 +42,7 @@ async def info(
repository=Depends(middleware.repository), ctx: middleware.Context = Depends() repository=Depends(middleware.repository), ctx: middleware.Context = Depends()
): ):
async with ctx.database.session() as session: async with ctx.database.session() as session:
session.add(repository) return await repository.info(session)
await session.refresh(repository, attribute_names=["files"])
info = RepositoryInfo.model_validate(repository)
info.used = sum([file.size for file in repository.files])
return info
@router.delete("/repository") @router.delete("/repository")

View File

@ -19,71 +19,56 @@ router = APIRouter(tags=["user"])
async def info( async def info(
claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends() claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()
): ):
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)): async with ctx.database.session() as session:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
info = UserInfo.model_validate(current_user) return current_user.info()
if current_user.is_email_private:
info.email = None
return info
@router.delete("/user") @router.delete("/user")
async def remove( async def remove(
user: User = Depends(middleware.user), ctx: middleware.Context = Depends() user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
): ):
repository_path = Config.data_dir() / "repository" / user.lower_name
async with ctx.database.session() as session:
session.add(user)
await session.refresh(user, attribute_names=["repository"])
try: try:
if repository_path.exists(): async with ctx.database.session() as session:
shutil.rmtree(str(repository_path)) await user.remove(session)
except OSError: await session.commit()
except Exception as e:
raise HTTPException( raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to remove user" status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to remove user: {e}"
) ) from e
await user.repository.remove(ctx.database)
@router.post("/user/avatar") @router.put("/user/avatar")
async def avatar( async def avatar(
file: UploadFile, file: UploadFile,
user: User = Depends(middleware.user), user: User = Depends(middleware.user),
ctx: middleware.Context = Depends(), ctx: middleware.Context = Depends(),
): ):
async with ctx.database.session() as session: async with ctx.database.session() as session:
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all() try:
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars)) await user.edit_avatar(io.BytesIO(await file.read()), session, ctx.config)
await session.commit()
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
f"{e}",
)
avatar_id = Sqids(min_length=10, blocklist=avatars).encode([len(avatars)])
try:
img = Image.open(io.BytesIO(await file.read()))
except OSError as _:
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data"
)
try:
if not (avatars_dir := Config.data_dir() / "avatars").exists():
avatars_dir.mkdir()
img.save(avatars_dir / avatar_id, format=img.format)
except OSError as _:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar"
)
if old_avatar := user.avatar:
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
old_file.unlink()
@router.delete("/user/avatar")
async def remove_avatar(
user: User = Depends(middleware.user),
ctx: middleware.Context = Depends(),
):
async with ctx.database.session() as session: async with ctx.database.session() as session:
await session.execute( try:
sa.update(User).where(User.id == user.id).values(avatar=avatar_id) await user.edit_avatar(None, session, ctx.config)
) await session.commit()
await session.commit() except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
f"{e}",
)

View File

@ -82,15 +82,17 @@ async def jwt_cookie(request: Request, response: Response, ctx: Context = Depend
except jwt.PyJWTError as e: except jwt.PyJWTError as e:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}") raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database): async with ctx.database.session() as session:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user") if not await User.by_id(uuid.UUID(access_claims.sub), session):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
return access_claims return access_claims
async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User: async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User:
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)): async with ctx.database.session() as session:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user") if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
return current_user return current_user

View File

@ -15,6 +15,4 @@ else:
@router.get("/{spa:path}", response_class=HTMLResponse) @router.get("/{spa:path}", response_class=HTMLResponse)
async def root(request: Request): async def root(request: Request):
return templates.TemplateResponse( return templates.TemplateResponse(request, "base.html", {"view": "app"})
"base.html", {"request": request, "view": "app"}
)

View File

@ -1,26 +1,29 @@
from typing import Optional from typing import Optional
import datetime import datetime
from pydantic import BaseModel from pydantic import BaseModel
import jwt import jwt
class TokenClaims(BaseModel): class TokenClaims(BaseModel):
sub: str sub: str
exp: int exp: int
iat: int iat: int
iss: Optional[str] = None iss: Optional[str] = None
def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = None) -> str: def generate_token(
sub: str, secret: str, duration: int, iss: Optional[str] = None
) -> str:
now = datetime.datetime.now() now = datetime.datetime.now()
iat = now.timestamp() iat = now.timestamp()
exp = (now + datetime.timedelta(seconds = duration)).timestamp() exp = (now + datetime.timedelta(seconds=duration)).timestamp()
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat), iss = iss) claims = TokenClaims(sub=sub, exp=int(exp), iat=int(iat), iss=iss)
return jwt.encode(claims.model_dump(), secret) return jwt.encode(claims.model_dump(), secret)
def validate_token(token: str, secret: str) -> TokenClaims: def validate_token(token: str, secret: str) -> TokenClaims:
payload = jwt.decode(token, secret, algorithms = [ "HS256" ]) payload = jwt.decode(token, secret, algorithms=["HS256"])
return TokenClaims(**payload) return TokenClaims(**payload)

178
tests/conftest.py Normal file
View File

@ -0,0 +1,178 @@
import pytest_asyncio
from materia.config import Config
from materia.models import (
Database,
Cache,
User,
LoginType,
)
from materia.models.base import Base
from materia import security
import sqlalchemy as sa
from sqlalchemy.pool import NullPool
from materia.app import make_application, AppContext
from materia._logging import make_logger
from httpx import AsyncClient, ASGITransport, Cookies
import asyncio
from fastapi import FastAPI
from contextlib import asynccontextmanager
from typing import AsyncIterator
from asgi_lifespan import LifespanManager
from fastapi.middleware.cors import CORSMiddleware
from materia import routers
from pathlib import Path
@pytest_asyncio.fixture(scope="session")
async def config() -> Config:
conf = Config()
conf.database.port = 54320
conf.cache.port = 63790
return conf
@pytest_asyncio.fixture(scope="session")
async def database(config: Config) -> Database:
config_postgres = config
config_postgres.database.user = "postgres"
config_postgres.database.name = "postgres"
database_postgres = await Database.new(
config_postgres.database.url(), poolclass=NullPool
)
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("create role pytest login"))
await connection.execute(sa.text("create database pytest owner pytest"))
await connection.commit()
await database_postgres.dispose()
config.database.user = "pytest"
config.database.name = "pytest"
database_pytest = await Database.new(config.database.url(), poolclass=NullPool)
yield database_pytest
await database_pytest.dispose()
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("drop database pytest")),
await connection.execute(sa.text("drop role pytest"))
await connection.commit()
await database_postgres.dispose()
@pytest_asyncio.fixture(scope="session")
async def cache(config: Config) -> Cache:
config_pytest = config
config_pytest.cache.user = "pytest"
cache_pytest = await Cache.new(config_pytest.cache.url())
yield cache_pytest
@pytest_asyncio.fixture(scope="function", autouse=True)
async def setup_database(database: Database):
async with database.connection() as connection:
await connection.run_sync(Base.metadata.create_all)
await connection.commit()
yield
async with database.connection() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.commit()
@pytest_asyncio.fixture()
async def session(database: Database, request):
session = database.sessionmaker()
yield session
await session.rollback()
await session.close()
@pytest_asyncio.fixture(scope="function")
async def data(config: Config):
class TestData:
user = User(
name="PyTest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
return TestData()
@pytest_asyncio.fixture(scope="function")
async def api_config(config: Config, tmpdir) -> Config:
config.application.working_directory = Path(tmpdir)
config.oauth2.jwt_secret = "pytest_secret_key"
yield config
@pytest_asyncio.fixture(scope="function")
async def api_client(
api_config: Config, database: Database, cache: Cache
) -> AsyncClient:
logger = make_logger(api_config)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
yield AppContext(
config=api_config, database=database, cache=cache, logger=logger
)
app = FastAPI(lifespan=lifespan)
app.include_router(routers.api.router)
app.include_router(routers.resources.router)
app.include_router(routers.root.router)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
async with LifespanManager(app) as manager:
async with AsyncClient(
transport=ASGITransport(app=manager.app), base_url=api_config.server.url()
) as client:
yield client
@pytest_asyncio.fixture(scope="function")
async def auth_client(api_client: AsyncClient, api_config: Config) -> AsyncClient:
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
await api_client.post(
"/api/auth/signup",
json=data,
)
auth = await api_client.post(
"/api/auth/signin",
json=data,
)
cookies = Cookies()
cookies.set(
"materia_at",
auth.cookies[api_config.security.cookie_access_token_name],
)
cookies.set(
"materia_rt",
auth.cookies[api_config.security.cookie_refresh_token_name],
)
api_client.cookies = cookies
yield api_client

86
tests/test_api.py Normal file
View File

@ -0,0 +1,86 @@
import pytest
from materia.config import Config
from httpx import AsyncClient, Cookies
from materia.models.base import Base
import aiofiles
from io import BytesIO
@pytest.mark.asyncio
async def test_auth(api_client: AsyncClient, api_config: Config):
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
response = await api_client.post(
"/api/auth/signup",
json=data,
)
assert response.status_code == 200
response = await api_client.post(
"/api/auth/signin",
json=data,
)
assert response.status_code == 200
assert response.cookies.get(api_config.security.cookie_access_token_name)
assert response.cookies.get(api_config.security.cookie_refresh_token_name)
# TODO: conflict usernames and emails
response = await api_client.get("/api/auth/signout")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_user(auth_client: AsyncClient, api_config: Config):
info = await auth_client.get("/api/user")
assert info.status_code == 200, info.text
async with AsyncClient() as client:
pytest_logo_res = await client.get(
"https://docs.pytest.org/en/stable/_static/pytest1.png"
)
assert isinstance(pytest_logo_res.content, bytes)
pytest_logo = BytesIO(pytest_logo_res.content)
avatar = await auth_client.put(
"/api/user/avatar",
files={"file": ("pytest.png", pytest_logo)},
)
assert avatar.status_code == 200, avatar.text
info = await auth_client.get("/api/user")
avatar_info = info.json()["avatar"]
assert avatar_info is not None
assert api_config.application.working_directory.joinpath(
"avatars", avatar_info
).exists()
avatar = await auth_client.delete("/api/user/avatar")
assert avatar.status_code == 200, avatar.text
info = await auth_client.get("/api/user")
assert info.json()["avatar"] is None
assert not api_config.application.working_directory.joinpath(
"avatars", avatar_info
).exists()
delete = await auth_client.delete("/api/user")
assert delete.status_code == 200, delete.text
info = await auth_client.get("/api/user")
assert info.status_code == 401, info.text
@pytest.mark.asyncio
async def test_repository(auth_client: AsyncClient, api_config: Config):
info = await auth_client.get("/api/repository")
assert info.status_code == 404, info.text
create = await auth_client.post("/api/repository")
assert create.status_code == 200, create.text
create = await auth_client.post("/api/repository")
assert create.status_code == 409, create.text
info = await auth_client.get("/api/repository")
assert info.status_code == 200, info.text

View File

@ -1,145 +1,23 @@
import pytest_asyncio import pytest_asyncio
import pytest import pytest
import os
import sys
from pathlib import Path from pathlib import Path
from materia.config import Config from materia.config import Config
from materia.models import ( from materia.models import (
Database,
User, User,
LoginType,
Repository, Repository,
Directory, Directory,
RepositoryError, RepositoryError,
File, File,
) )
from materia.models.base import Base
from materia.models.database import SessionContext from materia.models.database import SessionContext
from materia import security from materia import security
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.pool import NullPool
from sqlalchemy.orm.session import make_transient from sqlalchemy.orm.session import make_transient
from sqlalchemy import inspect from sqlalchemy import inspect
import aiofiles import aiofiles
import aiofiles.os import aiofiles.os
@pytest_asyncio.fixture(scope="session")
async def config() -> Config:
conf = Config()
conf.database.port = 54320
# conf.application.working_directory = conf.application.working_directory / "temp"
# if (cwd := conf.application.working_directory.resolve()).exists():
# os.chdir(cwd)
# if local_conf := Config.open(cwd / "config.toml"):
# conf = local_conf
return conf
@pytest_asyncio.fixture(scope="session")
async def db(config: Config, request) -> Database:
config_postgres = config
config_postgres.database.user = "postgres"
config_postgres.database.name = "postgres"
database_postgres = await Database.new(
config_postgres.database.url(), poolclass=NullPool
)
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("create role pytest login"))
await connection.execute(sa.text("create database pytest owner pytest"))
await connection.commit()
await database_postgres.dispose()
config.database.user = "pytest"
config.database.name = "pytest"
database = await Database.new(config.database.url(), poolclass=NullPool)
yield database
await database.dispose()
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("drop database pytest")),
await connection.execute(sa.text("drop role pytest"))
await connection.commit()
await database_postgres.dispose()
"""
@pytest.mark.asyncio
async def test_migrations(db):
await db.run_migrations()
await db.rollback_migrations()
"""
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_db(db: Database, request):
async with db.connection() as connection:
await connection.run_sync(Base.metadata.create_all)
await connection.commit()
yield
async with db.connection() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.commit()
@pytest_asyncio.fixture(autouse=True)
async def session(db: Database, request):
session = db.sessionmaker()
yield session
await session.rollback()
await session.close()
"""
@pytest_asyncio.fixture(scope="session")
async def user(config: Config, session) -> User:
test_user = User(
name="pytest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
async with db.session() as session:
session.add(test_user)
await session.flush()
await session.refresh(test_user)
yield test_user
async with db.session() as session:
await session.delete(test_user)
await session.flush()
"""
@pytest_asyncio.fixture(scope="function")
async def data(config: Config):
class TestData:
user = User(
name="PyTest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
return TestData()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_user(data, session: SessionContext, config: Config): async def test_user(data, session: SessionContext, config: Config):
# simple # simple
@ -161,6 +39,10 @@ async def test_user(data, session: SessionContext, config: Config):
await data.user.edit_name("AsyncPyTest", session) await data.user.edit_name("AsyncPyTest", session)
assert await User.by_name("asyncpytest", session, with_lower=True) == data.user assert await User.by_name("asyncpytest", session, with_lower=True) == data.user
assert await User.by_email("pytest@example.com", session) == data.user
assert await User.by_id(data.user.id, session) == data.user
await data.user.edit_password("iamnotpytest", session, config)
assert security.validate_password("iamnotpytest", data.user.hashed_password)
await data.user.remove(session) await data.user.remove(session)
@ -280,9 +162,9 @@ async def test_directory(data, tmpdir, session: SessionContext, config: Config):
# rename # rename
assert (await directory.rename("test1", session, config)).name == "test1" assert (await directory.rename("test1", session, config)).name == "test1"
directory2 = await Directory( await Directory(repository_id=repository.id, parent_id=None, name="test2").new(
repository_id=repository.id, parent_id=None, name="test2" session, config
).new(session, config) )
assert (await directory.rename("test2", session, config)).name == "test2.1" assert (await directory.rename("test2", session, config)).name == "test2.1"
assert (await repository.path(session, config)).joinpath("test2.1").exists() assert (await repository.path(session, config)).joinpath("test2.1").exists()
assert not (await repository.path(session, config)).joinpath("test1").exists() assert not (await repository.path(session, config)).joinpath("test1").exists()
@ -358,7 +240,7 @@ async def test_file(data, tmpdir, session: SessionContext, config: Config):
assert ( assert (
await file.rename("test_file_rename.txt", session, config) await file.rename("test_file_rename.txt", session, config)
).name == "test_file_rename.txt" ).name == "test_file_rename.txt"
file2 = await File( await File(
repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt" repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt"
).new(b"", session, config) ).new(b"", session, config)
assert ( assert (