tests and fixtures
This commit is contained in:
parent
aefedfe187
commit
58e7175d45
3
.gitignore
vendored
3
.gitignore
vendored
@ -10,3 +10,6 @@ __pycache__/
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build
|
||||
|
||||
.pytest_cache
|
||||
.coverage
|
||||
|
79
pdm.lock
generated
79
pdm.lock
generated
@ -5,7 +5,7 @@
|
||||
groups = ["default", "dev"]
|
||||
strategy = ["cross_platform", "inherit_metadata"]
|
||||
lock_version = "4.4.1"
|
||||
content_hash = "sha256:fe3214096aaef3097e2009f717762fb370bb726aa89a52e7b2a40d60016be987"
|
||||
content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20"
|
||||
|
||||
[[package]]
|
||||
name = "aiofiles"
|
||||
@ -97,6 +97,20 @@ files = [
|
||||
{file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asgi-lifespan"
|
||||
version = "2.1.0"
|
||||
requires_python = ">=3.7"
|
||||
summary = "Programmatic startup/shutdown of ASGI apps."
|
||||
groups = ["dev"]
|
||||
dependencies = [
|
||||
"sniffio",
|
||||
]
|
||||
files = [
|
||||
{file = "asgi-lifespan-2.1.0.tar.gz", hash = "sha256:5e2effaf0bfe39829cf2d64e7ecc47c7d86d676a6599f7afba378c31f5e3a308"},
|
||||
{file = "asgi_lifespan-2.1.0-py3-none-any.whl", hash = "sha256:ed840706680e28428c01e14afb3875d7d76d3206f3d5b2f2294e059b5c23804f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "asyncpg"
|
||||
version = "0.29.0"
|
||||
@ -296,6 +310,52 @@ files = [
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.6.1"
|
||||
requires_python = ">=3.8"
|
||||
summary = "Code coverage measurement for Python"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
|
||||
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
|
||||
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "7.6.1"
|
||||
extras = ["toml"]
|
||||
requires_python = ">=3.8"
|
||||
summary = "Code coverage measurement for Python"
|
||||
groups = ["dev"]
|
||||
dependencies = [
|
||||
"coverage==7.6.1",
|
||||
]
|
||||
files = [
|
||||
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
|
||||
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
|
||||
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
|
||||
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "43.0.0"
|
||||
@ -1044,6 +1104,21 @@ files = [
|
||||
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pytest-cov"
|
||||
version = "5.0.0"
|
||||
requires_python = ">=3.8"
|
||||
summary = "Pytest plugin for measuring coverage."
|
||||
groups = ["dev"]
|
||||
dependencies = [
|
||||
"coverage[toml]>=5.2.1",
|
||||
"pytest>=4.6",
|
||||
]
|
||||
files = [
|
||||
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
|
||||
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "python-dateutil"
|
||||
version = "2.9.0.post0"
|
||||
@ -1157,7 +1232,7 @@ name = "sniffio"
|
||||
version = "1.3.1"
|
||||
requires_python = ">=3.7"
|
||||
summary = "Sniff out which async library your code is running under"
|
||||
groups = ["default"]
|
||||
groups = ["default", "dev"]
|
||||
files = [
|
||||
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
|
||||
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
||||
|
@ -41,16 +41,6 @@ requires-python = ">=3.12,<3.13"
|
||||
readme = "README.md"
|
||||
license = {text = "MIT"}
|
||||
|
||||
[tool.pdm.dev-dependencies]
|
||||
dev = [
|
||||
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
|
||||
"black<24.0.0,>=23.3.0",
|
||||
"pytest<8.0.0,>=7.3.2",
|
||||
"pyflakes<4.0.0,>=3.0.1",
|
||||
"pyright<2.0.0,>=1.1.314",
|
||||
"pytest-asyncio>=0.23.7",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
build-backend = "pdm.backend"
|
||||
@ -65,8 +55,20 @@ reportGeneralTypeIssues = false
|
||||
pythonpath = ["."]
|
||||
testpaths = ["tests"]
|
||||
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
[tool.pdm.dev-dependencies]
|
||||
dev = [
|
||||
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
|
||||
"black<24.0.0,>=23.3.0",
|
||||
"pytest<8.0.0,>=7.3.2",
|
||||
"pyflakes<4.0.0,>=3.0.1",
|
||||
"pyright<2.0.0,>=1.1.314",
|
||||
"pytest-asyncio>=0.23.7",
|
||||
"asgi-lifespan>=2.1.0",
|
||||
"pytest-cov>=5.0.0",
|
||||
]
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["src/materia"]
|
||||
|
@ -44,6 +44,9 @@ class Server(BaseModel):
|
||||
port: int = 54601
|
||||
domain: str = "localhost"
|
||||
|
||||
def url(self) -> str:
|
||||
return "{}://{}:{}".format(self.scheme, self.address, self.port)
|
||||
|
||||
|
||||
class Database(BaseModel):
|
||||
backend: Literal["postgresql"] = "postgresql"
|
||||
|
@ -80,7 +80,7 @@ class Database:
|
||||
async with database.connection() as connection:
|
||||
await connection.rollback()
|
||||
except Exception as e:
|
||||
raise DatabaseError(f"{e}")
|
||||
raise DatabaseError(f"Failed to connect to database: {url}") from e
|
||||
|
||||
return database
|
||||
|
||||
@ -94,7 +94,7 @@ class Database:
|
||||
yield connection
|
||||
except Exception as e:
|
||||
await connection.rollback()
|
||||
raise DatabaseError(f"{e}")
|
||||
raise DatabaseError(*e.args) from e
|
||||
|
||||
@asynccontextmanager
|
||||
async def session(self) -> SessionContext:
|
||||
@ -102,12 +102,12 @@ class Database:
|
||||
|
||||
try:
|
||||
yield session
|
||||
except HTTPException as e:
|
||||
await session.rollback()
|
||||
raise e from None
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
raise DatabaseError(f"{e}")
|
||||
except HTTPException:
|
||||
# if the initial exception reaches HTTPException, then everything is handled fine (should be)
|
||||
await session.rollback()
|
||||
raise DatabaseError(*e.args) from e
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
|
@ -98,8 +98,14 @@ class Repository(Base):
|
||||
await session.refresh(user, attribute_names=["repository"])
|
||||
return user.repository
|
||||
|
||||
async def info(self) -> "RepositoryInfo":
|
||||
return RepositoryInfo.model_validate(self)
|
||||
async def info(self, session: SessionContext) -> "RepositoryInfo":
|
||||
session.add(self)
|
||||
await session.refresh(self, attribute_names=["files"])
|
||||
|
||||
info = RepositoryInfo.model_validate(self)
|
||||
info.used = sum([file.size for file in self.files])
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class RepositoryInfo(BaseModel):
|
||||
|
@ -1,5 +1,5 @@
|
||||
from uuid import UUID, uuid4
|
||||
from typing import Optional, Self
|
||||
from typing import Optional, Self, BinaryIO
|
||||
import time
|
||||
import re
|
||||
|
||||
@ -8,6 +8,9 @@ import pydantic
|
||||
from sqlalchemy import BigInteger, Enum
|
||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||
import sqlalchemy as sa
|
||||
from PIL import Image
|
||||
from sqids.sqids import Sqids
|
||||
from aiofiles import os as async_os
|
||||
|
||||
from materia import security
|
||||
from materia.models.base import Base
|
||||
@ -82,8 +85,7 @@ class User(Base):
|
||||
|
||||
@staticmethod
|
||||
def check_password(password: str, config: Config) -> bool:
|
||||
if len(password) < config.security.password_min_length:
|
||||
return False
|
||||
return not len(password) < config.security.password_min_length
|
||||
|
||||
@staticmethod
|
||||
async def count(session: SessionContext) -> Optional[int]:
|
||||
@ -146,6 +148,57 @@ class User(Base):
|
||||
|
||||
return user_info
|
||||
|
||||
async def edit_avatar(
|
||||
self, avatar: BinaryIO | None, session: SessionContext, config: Config
|
||||
):
|
||||
avatar_dir = config.application.working_directory.joinpath("avatars")
|
||||
|
||||
if avatar is None:
|
||||
if self.avatar is None:
|
||||
return
|
||||
|
||||
avatar_file = FileSystem(
|
||||
avatar_dir.joinpath(self.avatar), config.application.working_directory
|
||||
)
|
||||
if await avatar_file.exists():
|
||||
await avatar_file.remove()
|
||||
|
||||
session.add(self)
|
||||
self.avatar = None
|
||||
await session.flush()
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
image = Image.open(avatar)
|
||||
except Exception as e:
|
||||
raise UserError("Failed to read avatar data") from e
|
||||
|
||||
avatar_hashes: list[str] = (
|
||||
await session.scalars(sa.select(User.avatar).where(User.avatar.isnot(None)))
|
||||
).all()
|
||||
avatar_id = Sqids(min_length=10, blocklist=avatar_hashes).encode(
|
||||
[int(time.time())]
|
||||
)
|
||||
|
||||
try:
|
||||
if not avatar_dir.exists():
|
||||
await async_os.mkdir(avatar_dir)
|
||||
image.save(avatar_dir.joinpath(avatar_id), format=image.format)
|
||||
except Exception as e:
|
||||
raise UserError(f"Failed to save avatar: {e}") from e
|
||||
|
||||
if old_avatar := self.avatar:
|
||||
avatar_file = FileSystem(
|
||||
avatar_dir.joinpath(old_avatar), config.application.working_directory
|
||||
)
|
||||
if await avatar_file.exists():
|
||||
await avatar_file.remove()
|
||||
|
||||
session.add(self)
|
||||
self.avatar = avatar_id
|
||||
await session.flush()
|
||||
|
||||
|
||||
class UserCredentials(BaseModel):
|
||||
name: str
|
||||
@ -177,3 +230,4 @@ class UserInfo(BaseModel):
|
||||
|
||||
|
||||
from materia.models.repository import Repository
|
||||
from materia.models.filesystem import FileSystem
|
||||
|
@ -10,24 +10,21 @@ router = APIRouter(tags=["auth"])
|
||||
|
||||
@router.post("/auth/signup")
|
||||
async def signup(body: UserCredentials, ctx: Context = Depends()):
|
||||
if not User.check_username(body.name):
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
|
||||
)
|
||||
if not User.check_password(body.password, ctx.config):
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
|
||||
)
|
||||
|
||||
async with ctx.database.session() as session:
|
||||
if not User.check_username(body.name):
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
|
||||
)
|
||||
if not User.check_password(body.password, ctx.config):
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
|
||||
)
|
||||
if await User.by_name(body.name, session, with_lower=True):
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="User already exists"
|
||||
)
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, detail="User already exists")
|
||||
if await User.by_email(body.email, session): # type: ignore
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email already used"
|
||||
)
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, detail="Email already used")
|
||||
|
||||
count: Optional[int] = await User.count(session)
|
||||
|
||||
@ -42,14 +39,20 @@ async def signup(body: UserCredentials, ctx: Context = Depends()):
|
||||
login_type=LoginType.Plain,
|
||||
# first registered user is admin
|
||||
is_admin=count == 0,
|
||||
).new(session)
|
||||
).new(session, ctx.config)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
@router.post("/auth/signin")
|
||||
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
|
||||
if (current_user := await User.by_name(body.name, ctx.database)) is None:
|
||||
if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid email")
|
||||
async with ctx.database.session() as session:
|
||||
if (current_user := await User.by_name(body.name, session)) is None:
|
||||
if (current_user := await User.by_email(str(body.email), session)) is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_401_UNAUTHORIZED, detail="Invalid email"
|
||||
)
|
||||
|
||||
if not security.validate_password(
|
||||
body.password,
|
||||
current_user.hashed_password,
|
||||
|
@ -22,14 +22,15 @@ async def create(
|
||||
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
||||
):
|
||||
async with ctx.database.session() as session:
|
||||
if await Repository.by_user(user, session):
|
||||
if await Repository.from_user(user, session):
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
||||
|
||||
async with ctx.database.session() as session:
|
||||
try:
|
||||
await Repository(
|
||||
user_id=user.id, capacity=ctx.config.repository.capacity
|
||||
).new(session)
|
||||
).new(session, ctx.config)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args)
|
||||
@ -41,13 +42,7 @@ async def info(
|
||||
repository=Depends(middleware.repository), ctx: middleware.Context = Depends()
|
||||
):
|
||||
async with ctx.database.session() as session:
|
||||
session.add(repository)
|
||||
await session.refresh(repository, attribute_names=["files"])
|
||||
|
||||
info = RepositoryInfo.model_validate(repository)
|
||||
info.used = sum([file.size for file in repository.files])
|
||||
|
||||
return info
|
||||
return await repository.info(session)
|
||||
|
||||
|
||||
@router.delete("/repository")
|
||||
|
@ -19,71 +19,56 @@ router = APIRouter(tags=["user"])
|
||||
async def info(
|
||||
claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()
|
||||
):
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
async with ctx.database.session() as session:
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
|
||||
info = UserInfo.model_validate(current_user)
|
||||
if current_user.is_email_private:
|
||||
info.email = None
|
||||
|
||||
return info
|
||||
return current_user.info()
|
||||
|
||||
|
||||
@router.delete("/user")
|
||||
async def remove(
|
||||
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
||||
):
|
||||
repository_path = Config.data_dir() / "repository" / user.lower_name
|
||||
|
||||
async with ctx.database.session() as session:
|
||||
session.add(user)
|
||||
await session.refresh(user, attribute_names=["repository"])
|
||||
|
||||
try:
|
||||
if repository_path.exists():
|
||||
shutil.rmtree(str(repository_path))
|
||||
except OSError:
|
||||
async with ctx.database.session() as session:
|
||||
await user.remove(session)
|
||||
await session.commit()
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to remove user"
|
||||
)
|
||||
|
||||
await user.repository.remove(ctx.database)
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to remove user: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
@router.post("/user/avatar")
|
||||
@router.put("/user/avatar")
|
||||
async def avatar(
|
||||
file: UploadFile,
|
||||
user: User = Depends(middleware.user),
|
||||
ctx: middleware.Context = Depends(),
|
||||
):
|
||||
async with ctx.database.session() as session:
|
||||
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
|
||||
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
||||
try:
|
||||
await user.edit_avatar(io.BytesIO(await file.read()), session, ctx.config)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
f"{e}",
|
||||
)
|
||||
|
||||
avatar_id = Sqids(min_length=10, blocklist=avatars).encode([len(avatars)])
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(await file.read()))
|
||||
except OSError as _:
|
||||
raise HTTPException(
|
||||
status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data"
|
||||
)
|
||||
|
||||
try:
|
||||
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
||||
avatars_dir.mkdir()
|
||||
img.save(avatars_dir / avatar_id, format=img.format)
|
||||
except OSError as _:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar"
|
||||
)
|
||||
|
||||
if old_avatar := user.avatar:
|
||||
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
||||
old_file.unlink()
|
||||
|
||||
@router.delete("/user/avatar")
|
||||
async def remove_avatar(
|
||||
user: User = Depends(middleware.user),
|
||||
ctx: middleware.Context = Depends(),
|
||||
):
|
||||
async with ctx.database.session() as session:
|
||||
await session.execute(
|
||||
sa.update(User).where(User.id == user.id).values(avatar=avatar_id)
|
||||
)
|
||||
await session.commit()
|
||||
try:
|
||||
await user.edit_avatar(None, session, ctx.config)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
f"{e}",
|
||||
)
|
||||
|
@ -82,15 +82,17 @@ async def jwt_cookie(request: Request, response: Response, ctx: Context = Depend
|
||||
except jwt.PyJWTError as e:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
||||
|
||||
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||
async with ctx.database.session() as session:
|
||||
if not await User.by_id(uuid.UUID(access_claims.sub), session):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||
|
||||
return access_claims
|
||||
|
||||
|
||||
async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User:
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
async with ctx.database.session() as session:
|
||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||
|
||||
return current_user
|
||||
|
||||
|
@ -15,6 +15,4 @@ else:
|
||||
|
||||
@router.get("/{spa:path}", response_class=HTMLResponse)
|
||||
async def root(request: Request):
|
||||
return templates.TemplateResponse(
|
||||
"base.html", {"request": request, "view": "app"}
|
||||
)
|
||||
return templates.TemplateResponse(request, "base.html", {"view": "app"})
|
||||
|
@ -1,26 +1,29 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
import jwt
|
||||
import jwt
|
||||
|
||||
|
||||
class TokenClaims(BaseModel):
|
||||
sub: str
|
||||
exp: int
|
||||
iat: int
|
||||
sub: str
|
||||
exp: int
|
||||
iat: int
|
||||
iss: Optional[str] = None
|
||||
|
||||
|
||||
def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = None) -> str:
|
||||
def generate_token(
|
||||
sub: str, secret: str, duration: int, iss: Optional[str] = None
|
||||
) -> str:
|
||||
now = datetime.datetime.now()
|
||||
iat = now.timestamp()
|
||||
exp = (now + datetime.timedelta(seconds = duration)).timestamp()
|
||||
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat), iss = iss)
|
||||
exp = (now + datetime.timedelta(seconds=duration)).timestamp()
|
||||
claims = TokenClaims(sub=sub, exp=int(exp), iat=int(iat), iss=iss)
|
||||
|
||||
return jwt.encode(claims.model_dump(), secret)
|
||||
|
||||
|
||||
def validate_token(token: str, secret: str) -> TokenClaims:
|
||||
payload = jwt.decode(token, secret, algorithms = [ "HS256" ])
|
||||
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
||||
|
||||
return TokenClaims(**payload)
|
||||
|
178
tests/conftest.py
Normal file
178
tests/conftest.py
Normal file
@ -0,0 +1,178 @@
|
||||
import pytest_asyncio
|
||||
from materia.config import Config
|
||||
from materia.models import (
|
||||
Database,
|
||||
Cache,
|
||||
User,
|
||||
LoginType,
|
||||
)
|
||||
from materia.models.base import Base
|
||||
from materia import security
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.pool import NullPool
|
||||
from materia.app import make_application, AppContext
|
||||
from materia._logging import make_logger
|
||||
from httpx import AsyncClient, ASGITransport, Cookies
|
||||
import asyncio
|
||||
from fastapi import FastAPI
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
from asgi_lifespan import LifespanManager
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from materia import routers
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def config() -> Config:
|
||||
conf = Config()
|
||||
conf.database.port = 54320
|
||||
conf.cache.port = 63790
|
||||
|
||||
return conf
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def database(config: Config) -> Database:
|
||||
config_postgres = config
|
||||
config_postgres.database.user = "postgres"
|
||||
config_postgres.database.name = "postgres"
|
||||
database_postgres = await Database.new(
|
||||
config_postgres.database.url(), poolclass=NullPool
|
||||
)
|
||||
|
||||
async with database_postgres.connection() as connection:
|
||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await connection.execute(sa.text("create role pytest login"))
|
||||
await connection.execute(sa.text("create database pytest owner pytest"))
|
||||
await connection.commit()
|
||||
|
||||
await database_postgres.dispose()
|
||||
|
||||
config.database.user = "pytest"
|
||||
config.database.name = "pytest"
|
||||
database_pytest = await Database.new(config.database.url(), poolclass=NullPool)
|
||||
|
||||
yield database_pytest
|
||||
|
||||
await database_pytest.dispose()
|
||||
|
||||
async with database_postgres.connection() as connection:
|
||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await connection.execute(sa.text("drop database pytest")),
|
||||
await connection.execute(sa.text("drop role pytest"))
|
||||
await connection.commit()
|
||||
await database_postgres.dispose()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def cache(config: Config) -> Cache:
|
||||
config_pytest = config
|
||||
config_pytest.cache.user = "pytest"
|
||||
cache_pytest = await Cache.new(config_pytest.cache.url())
|
||||
|
||||
yield cache_pytest
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||
async def setup_database(database: Database):
|
||||
async with database.connection() as connection:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
await connection.commit()
|
||||
yield
|
||||
async with database.connection() as connection:
|
||||
await connection.run_sync(Base.metadata.drop_all)
|
||||
await connection.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture()
|
||||
async def session(database: Database, request):
|
||||
session = database.sessionmaker()
|
||||
yield session
|
||||
await session.rollback()
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def data(config: Config):
|
||||
class TestData:
|
||||
user = User(
|
||||
name="PyTest",
|
||||
lower_name="pytest",
|
||||
email="pytest@example.com",
|
||||
hashed_password=security.hash_password(
|
||||
"iampytest", algo=config.security.password_hash_algo
|
||||
),
|
||||
login_type=LoginType.Plain,
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
return TestData()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def api_config(config: Config, tmpdir) -> Config:
|
||||
config.application.working_directory = Path(tmpdir)
|
||||
config.oauth2.jwt_secret = "pytest_secret_key"
|
||||
|
||||
yield config
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def api_client(
|
||||
api_config: Config, database: Database, cache: Cache
|
||||
) -> AsyncClient:
|
||||
|
||||
logger = make_logger(api_config)
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||
yield AppContext(
|
||||
config=api_config, database=database, cache=cache, logger=logger
|
||||
)
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.include_router(routers.api.router)
|
||||
app.include_router(routers.resources.router)
|
||||
app.include_router(routers.root.router)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
async with LifespanManager(app) as manager:
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=manager.app), base_url=api_config.server.url()
|
||||
) as client:
|
||||
yield client
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def auth_client(api_client: AsyncClient, api_config: Config) -> AsyncClient:
|
||||
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
|
||||
|
||||
await api_client.post(
|
||||
"/api/auth/signup",
|
||||
json=data,
|
||||
)
|
||||
auth = await api_client.post(
|
||||
"/api/auth/signin",
|
||||
json=data,
|
||||
)
|
||||
|
||||
cookies = Cookies()
|
||||
cookies.set(
|
||||
"materia_at",
|
||||
auth.cookies[api_config.security.cookie_access_token_name],
|
||||
)
|
||||
cookies.set(
|
||||
"materia_rt",
|
||||
auth.cookies[api_config.security.cookie_refresh_token_name],
|
||||
)
|
||||
|
||||
api_client.cookies = cookies
|
||||
|
||||
yield api_client
|
86
tests/test_api.py
Normal file
86
tests/test_api.py
Normal file
@ -0,0 +1,86 @@
|
||||
import pytest
|
||||
from materia.config import Config
|
||||
from httpx import AsyncClient, Cookies
|
||||
from materia.models.base import Base
|
||||
import aiofiles
|
||||
from io import BytesIO
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth(api_client: AsyncClient, api_config: Config):
|
||||
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
|
||||
|
||||
response = await api_client.post(
|
||||
"/api/auth/signup",
|
||||
json=data,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response = await api_client.post(
|
||||
"/api/auth/signin",
|
||||
json=data,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
assert response.cookies.get(api_config.security.cookie_access_token_name)
|
||||
assert response.cookies.get(api_config.security.cookie_refresh_token_name)
|
||||
|
||||
# TODO: conflict usernames and emails
|
||||
|
||||
response = await api_client.get("/api/auth/signout")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user(auth_client: AsyncClient, api_config: Config):
|
||||
info = await auth_client.get("/api/user")
|
||||
assert info.status_code == 200, info.text
|
||||
|
||||
async with AsyncClient() as client:
|
||||
pytest_logo_res = await client.get(
|
||||
"https://docs.pytest.org/en/stable/_static/pytest1.png"
|
||||
)
|
||||
assert isinstance(pytest_logo_res.content, bytes)
|
||||
pytest_logo = BytesIO(pytest_logo_res.content)
|
||||
|
||||
avatar = await auth_client.put(
|
||||
"/api/user/avatar",
|
||||
files={"file": ("pytest.png", pytest_logo)},
|
||||
)
|
||||
assert avatar.status_code == 200, avatar.text
|
||||
|
||||
info = await auth_client.get("/api/user")
|
||||
avatar_info = info.json()["avatar"]
|
||||
assert avatar_info is not None
|
||||
assert api_config.application.working_directory.joinpath(
|
||||
"avatars", avatar_info
|
||||
).exists()
|
||||
|
||||
avatar = await auth_client.delete("/api/user/avatar")
|
||||
assert avatar.status_code == 200, avatar.text
|
||||
|
||||
info = await auth_client.get("/api/user")
|
||||
assert info.json()["avatar"] is None
|
||||
assert not api_config.application.working_directory.joinpath(
|
||||
"avatars", avatar_info
|
||||
).exists()
|
||||
|
||||
delete = await auth_client.delete("/api/user")
|
||||
assert delete.status_code == 200, delete.text
|
||||
|
||||
info = await auth_client.get("/api/user")
|
||||
assert info.status_code == 401, info.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_repository(auth_client: AsyncClient, api_config: Config):
|
||||
info = await auth_client.get("/api/repository")
|
||||
assert info.status_code == 404, info.text
|
||||
|
||||
create = await auth_client.post("/api/repository")
|
||||
assert create.status_code == 200, create.text
|
||||
|
||||
create = await auth_client.post("/api/repository")
|
||||
assert create.status_code == 409, create.text
|
||||
|
||||
info = await auth_client.get("/api/repository")
|
||||
assert info.status_code == 200, info.text
|
@ -1,145 +1,23 @@
|
||||
import pytest_asyncio
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from materia.config import Config
|
||||
from materia.models import (
|
||||
Database,
|
||||
User,
|
||||
LoginType,
|
||||
Repository,
|
||||
Directory,
|
||||
RepositoryError,
|
||||
File,
|
||||
)
|
||||
from materia.models.base import Base
|
||||
from materia.models.database import SessionContext
|
||||
from materia import security
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.orm.session import make_transient
|
||||
from sqlalchemy import inspect
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def config() -> Config:
|
||||
conf = Config()
|
||||
conf.database.port = 54320
|
||||
# conf.application.working_directory = conf.application.working_directory / "temp"
|
||||
# if (cwd := conf.application.working_directory.resolve()).exists():
|
||||
# os.chdir(cwd)
|
||||
# if local_conf := Config.open(cwd / "config.toml"):
|
||||
# conf = local_conf
|
||||
return conf
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def db(config: Config, request) -> Database:
|
||||
config_postgres = config
|
||||
config_postgres.database.user = "postgres"
|
||||
config_postgres.database.name = "postgres"
|
||||
database_postgres = await Database.new(
|
||||
config_postgres.database.url(), poolclass=NullPool
|
||||
)
|
||||
|
||||
async with database_postgres.connection() as connection:
|
||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await connection.execute(sa.text("create role pytest login"))
|
||||
await connection.execute(sa.text("create database pytest owner pytest"))
|
||||
await connection.commit()
|
||||
|
||||
await database_postgres.dispose()
|
||||
|
||||
config.database.user = "pytest"
|
||||
config.database.name = "pytest"
|
||||
database = await Database.new(config.database.url(), poolclass=NullPool)
|
||||
|
||||
yield database
|
||||
|
||||
await database.dispose()
|
||||
|
||||
async with database_postgres.connection() as connection:
|
||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||
await connection.execute(sa.text("drop database pytest")),
|
||||
await connection.execute(sa.text("drop role pytest"))
|
||||
await connection.commit()
|
||||
await database_postgres.dispose()
|
||||
|
||||
|
||||
"""
|
||||
@pytest.mark.asyncio
|
||||
async def test_migrations(db):
|
||||
await db.run_migrations()
|
||||
await db.rollback_migrations()
|
||||
"""
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
||||
async def setup_db(db: Database, request):
|
||||
async with db.connection() as connection:
|
||||
await connection.run_sync(Base.metadata.create_all)
|
||||
await connection.commit()
|
||||
yield
|
||||
async with db.connection() as connection:
|
||||
await connection.run_sync(Base.metadata.drop_all)
|
||||
await connection.commit()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(autouse=True)
|
||||
async def session(db: Database, request):
|
||||
session = db.sessionmaker()
|
||||
yield session
|
||||
await session.rollback()
|
||||
await session.close()
|
||||
|
||||
|
||||
"""
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def user(config: Config, session) -> User:
|
||||
test_user = User(
|
||||
name="pytest",
|
||||
lower_name="pytest",
|
||||
email="pytest@example.com",
|
||||
hashed_password=security.hash_password(
|
||||
"iampytest", algo=config.security.password_hash_algo
|
||||
),
|
||||
login_type=LoginType.Plain,
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
async with db.session() as session:
|
||||
session.add(test_user)
|
||||
await session.flush()
|
||||
await session.refresh(test_user)
|
||||
|
||||
yield test_user
|
||||
|
||||
async with db.session() as session:
|
||||
await session.delete(test_user)
|
||||
await session.flush()
|
||||
"""
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="function")
|
||||
async def data(config: Config):
|
||||
class TestData:
|
||||
user = User(
|
||||
name="PyTest",
|
||||
lower_name="pytest",
|
||||
email="pytest@example.com",
|
||||
hashed_password=security.hash_password(
|
||||
"iampytest", algo=config.security.password_hash_algo
|
||||
),
|
||||
login_type=LoginType.Plain,
|
||||
is_admin=True,
|
||||
)
|
||||
|
||||
return TestData()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_user(data, session: SessionContext, config: Config):
|
||||
# simple
|
||||
@ -161,6 +39,10 @@ async def test_user(data, session: SessionContext, config: Config):
|
||||
|
||||
await data.user.edit_name("AsyncPyTest", session)
|
||||
assert await User.by_name("asyncpytest", session, with_lower=True) == data.user
|
||||
assert await User.by_email("pytest@example.com", session) == data.user
|
||||
assert await User.by_id(data.user.id, session) == data.user
|
||||
await data.user.edit_password("iamnotpytest", session, config)
|
||||
assert security.validate_password("iamnotpytest", data.user.hashed_password)
|
||||
|
||||
await data.user.remove(session)
|
||||
|
||||
@ -280,9 +162,9 @@ async def test_directory(data, tmpdir, session: SessionContext, config: Config):
|
||||
|
||||
# rename
|
||||
assert (await directory.rename("test1", session, config)).name == "test1"
|
||||
directory2 = await Directory(
|
||||
repository_id=repository.id, parent_id=None, name="test2"
|
||||
).new(session, config)
|
||||
await Directory(repository_id=repository.id, parent_id=None, name="test2").new(
|
||||
session, config
|
||||
)
|
||||
assert (await directory.rename("test2", session, config)).name == "test2.1"
|
||||
assert (await repository.path(session, config)).joinpath("test2.1").exists()
|
||||
assert not (await repository.path(session, config)).joinpath("test1").exists()
|
||||
@ -358,7 +240,7 @@ async def test_file(data, tmpdir, session: SessionContext, config: Config):
|
||||
assert (
|
||||
await file.rename("test_file_rename.txt", session, config)
|
||||
).name == "test_file_rename.txt"
|
||||
file2 = await File(
|
||||
await File(
|
||||
repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt"
|
||||
).new(b"", session, config)
|
||||
assert (
|
Loading…
Reference in New Issue
Block a user