tests and fixtures

This commit is contained in:
L-Nafaryus 2024-08-14 00:56:30 +05:00
parent aefedfe187
commit 58e7175d45
Signed by: L-Nafaryus
GPG Key ID: 553C97999B363D38
16 changed files with 516 additions and 241 deletions

3
.gitignore vendored
View File

@ -10,3 +10,6 @@ __pycache__/
.pdm.toml
.pdm-python
.pdm-build
.pytest_cache
.coverage

79
pdm.lock generated
View File

@ -5,7 +5,7 @@
groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1"
content_hash = "sha256:fe3214096aaef3097e2009f717762fb370bb726aa89a52e7b2a40d60016be987"
content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20"
[[package]]
name = "aiofiles"
@ -97,6 +97,20 @@ files = [
{file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
]
[[package]]
name = "asgi-lifespan"
version = "2.1.0"
requires_python = ">=3.7"
summary = "Programmatic startup/shutdown of ASGI apps."
groups = ["dev"]
dependencies = [
"sniffio",
]
files = [
{file = "asgi-lifespan-2.1.0.tar.gz", hash = "sha256:5e2effaf0bfe39829cf2d64e7ecc47c7d86d676a6599f7afba378c31f5e3a308"},
{file = "asgi_lifespan-2.1.0-py3-none-any.whl", hash = "sha256:ed840706680e28428c01e14afb3875d7d76d3206f3d5b2f2294e059b5c23804f"},
]
[[package]]
name = "asyncpg"
version = "0.29.0"
@ -296,6 +310,52 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "coverage"
version = "7.6.1"
requires_python = ">=3.8"
summary = "Code coverage measurement for Python"
groups = ["dev"]
files = [
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
]
[[package]]
name = "coverage"
version = "7.6.1"
extras = ["toml"]
requires_python = ">=3.8"
summary = "Code coverage measurement for Python"
groups = ["dev"]
dependencies = [
"coverage==7.6.1",
]
files = [
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
]
[[package]]
name = "cryptography"
version = "43.0.0"
@ -1044,6 +1104,21 @@ files = [
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
]
[[package]]
name = "pytest-cov"
version = "5.0.0"
requires_python = ">=3.8"
summary = "Pytest plugin for measuring coverage."
groups = ["dev"]
dependencies = [
"coverage[toml]>=5.2.1",
"pytest>=4.6",
]
files = [
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"
@ -1157,7 +1232,7 @@ name = "sniffio"
version = "1.3.1"
requires_python = ">=3.7"
summary = "Sniff out which async library your code is running under"
groups = ["default"]
groups = ["default", "dev"]
files = [
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},

View File

@ -41,16 +41,6 @@ requires-python = ">=3.12,<3.13"
readme = "README.md"
license = {text = "MIT"}
[tool.pdm.dev-dependencies]
dev = [
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
"black<24.0.0,>=23.3.0",
"pytest<8.0.0,>=7.3.2",
"pyflakes<4.0.0,>=3.0.1",
"pyright<2.0.0,>=1.1.314",
"pytest-asyncio>=0.23.7",
]
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"
@ -65,8 +55,20 @@ reportGeneralTypeIssues = false
pythonpath = ["."]
testpaths = ["tests"]
[tool.pdm]
distribution = true
[tool.pdm.dev-dependencies]
dev = [
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
"black<24.0.0,>=23.3.0",
"pytest<8.0.0,>=7.3.2",
"pyflakes<4.0.0,>=3.0.1",
"pyright<2.0.0,>=1.1.314",
"pytest-asyncio>=0.23.7",
"asgi-lifespan>=2.1.0",
"pytest-cov>=5.0.0",
]
[tool.pdm.build]
includes = ["src/materia"]

View File

@ -44,6 +44,9 @@ class Server(BaseModel):
port: int = 54601
domain: str = "localhost"
def url(self) -> str:
return "{}://{}:{}".format(self.scheme, self.address, self.port)
class Database(BaseModel):
backend: Literal["postgresql"] = "postgresql"

View File

@ -80,7 +80,7 @@ class Database:
async with database.connection() as connection:
await connection.rollback()
except Exception as e:
raise DatabaseError(f"{e}")
raise DatabaseError(f"Failed to connect to database: {url}") from e
return database
@ -94,7 +94,7 @@ class Database:
yield connection
except Exception as e:
await connection.rollback()
raise DatabaseError(f"{e}")
raise DatabaseError(*e.args) from e
@asynccontextmanager
async def session(self) -> SessionContext:
@ -102,12 +102,12 @@ class Database:
try:
yield session
except HTTPException as e:
await session.rollback()
raise e from None
except Exception as e:
await session.rollback()
raise DatabaseError(f"{e}")
except HTTPException:
# if the initial exception reaches HTTPException, then everything is handled fine (should be)
await session.rollback()
raise DatabaseError(*e.args) from e
finally:
await session.close()

View File

@ -98,8 +98,14 @@ class Repository(Base):
await session.refresh(user, attribute_names=["repository"])
return user.repository
async def info(self) -> "RepositoryInfo":
return RepositoryInfo.model_validate(self)
async def info(self, session: SessionContext) -> "RepositoryInfo":
session.add(self)
await session.refresh(self, attribute_names=["files"])
info = RepositoryInfo.model_validate(self)
info.used = sum([file.size for file in self.files])
return info
class RepositoryInfo(BaseModel):

View File

@ -1,5 +1,5 @@
from uuid import UUID, uuid4
from typing import Optional, Self
from typing import Optional, Self, BinaryIO
import time
import re
@ -8,6 +8,9 @@ import pydantic
from sqlalchemy import BigInteger, Enum
from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa
from PIL import Image
from sqids.sqids import Sqids
from aiofiles import os as async_os
from materia import security
from materia.models.base import Base
@ -82,8 +85,7 @@ class User(Base):
@staticmethod
def check_password(password: str, config: Config) -> bool:
if len(password) < config.security.password_min_length:
return False
return not len(password) < config.security.password_min_length
@staticmethod
async def count(session: SessionContext) -> Optional[int]:
@ -146,6 +148,57 @@ class User(Base):
return user_info
async def edit_avatar(
self, avatar: BinaryIO | None, session: SessionContext, config: Config
):
avatar_dir = config.application.working_directory.joinpath("avatars")
if avatar is None:
if self.avatar is None:
return
avatar_file = FileSystem(
avatar_dir.joinpath(self.avatar), config.application.working_directory
)
if await avatar_file.exists():
await avatar_file.remove()
session.add(self)
self.avatar = None
await session.flush()
return
try:
image = Image.open(avatar)
except Exception as e:
raise UserError("Failed to read avatar data") from e
avatar_hashes: list[str] = (
await session.scalars(sa.select(User.avatar).where(User.avatar.isnot(None)))
).all()
avatar_id = Sqids(min_length=10, blocklist=avatar_hashes).encode(
[int(time.time())]
)
try:
if not avatar_dir.exists():
await async_os.mkdir(avatar_dir)
image.save(avatar_dir.joinpath(avatar_id), format=image.format)
except Exception as e:
raise UserError(f"Failed to save avatar: {e}") from e
if old_avatar := self.avatar:
avatar_file = FileSystem(
avatar_dir.joinpath(old_avatar), config.application.working_directory
)
if await avatar_file.exists():
await avatar_file.remove()
session.add(self)
self.avatar = avatar_id
await session.flush()
class UserCredentials(BaseModel):
name: str
@ -177,3 +230,4 @@ class UserInfo(BaseModel):
from materia.models.repository import Repository
from materia.models.filesystem import FileSystem

View File

@ -10,24 +10,21 @@ router = APIRouter(tags=["auth"])
@router.post("/auth/signup")
async def signup(body: UserCredentials, ctx: Context = Depends()):
if not User.check_username(body.name):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
)
if not User.check_password(body.password, ctx.config):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
)
async with ctx.database.session() as session:
if not User.check_username(body.name):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
)
if not User.check_password(body.password, ctx.config):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
)
if await User.by_name(body.name, session, with_lower=True):
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="User already exists"
)
raise HTTPException(status.HTTP_409_CONFLICT, detail="User already exists")
if await User.by_email(body.email, session): # type: ignore
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email already used"
)
raise HTTPException(status.HTTP_409_CONFLICT, detail="Email already used")
count: Optional[int] = await User.count(session)
@ -42,14 +39,20 @@ async def signup(body: UserCredentials, ctx: Context = Depends()):
login_type=LoginType.Plain,
# first registered user is admin
is_admin=count == 0,
).new(session)
).new(session, ctx.config)
await session.commit()
@router.post("/auth/signin")
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
if (current_user := await User.by_name(body.name, ctx.database)) is None:
if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid email")
async with ctx.database.session() as session:
if (current_user := await User.by_name(body.name, session)) is None:
if (current_user := await User.by_email(str(body.email), session)) is None:
raise HTTPException(
status.HTTP_401_UNAUTHORIZED, detail="Invalid email"
)
if not security.validate_password(
body.password,
current_user.hashed_password,

View File

@ -22,14 +22,15 @@ async def create(
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
):
async with ctx.database.session() as session:
if await Repository.by_user(user, session):
if await Repository.from_user(user, session):
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
async with ctx.database.session() as session:
try:
await Repository(
user_id=user.id, capacity=ctx.config.repository.capacity
).new(session)
).new(session, ctx.config)
await session.commit()
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args)
@ -41,13 +42,7 @@ async def info(
repository=Depends(middleware.repository), ctx: middleware.Context = Depends()
):
async with ctx.database.session() as session:
session.add(repository)
await session.refresh(repository, attribute_names=["files"])
info = RepositoryInfo.model_validate(repository)
info.used = sum([file.size for file in repository.files])
return info
return await repository.info(session)
@router.delete("/repository")

View File

@ -19,71 +19,56 @@ router = APIRouter(tags=["user"])
async def info(
claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()
):
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
async with ctx.database.session() as session:
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
info = UserInfo.model_validate(current_user)
if current_user.is_email_private:
info.email = None
return info
return current_user.info()
@router.delete("/user")
async def remove(
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
):
repository_path = Config.data_dir() / "repository" / user.lower_name
async with ctx.database.session() as session:
session.add(user)
await session.refresh(user, attribute_names=["repository"])
try:
if repository_path.exists():
shutil.rmtree(str(repository_path))
except OSError:
async with ctx.database.session() as session:
await user.remove(session)
await session.commit()
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to remove user"
)
await user.repository.remove(ctx.database)
status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to remove user: {e}"
) from e
@router.post("/user/avatar")
@router.put("/user/avatar")
async def avatar(
file: UploadFile,
user: User = Depends(middleware.user),
ctx: middleware.Context = Depends(),
):
async with ctx.database.session() as session:
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
try:
await user.edit_avatar(io.BytesIO(await file.read()), session, ctx.config)
await session.commit()
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
f"{e}",
)
avatar_id = Sqids(min_length=10, blocklist=avatars).encode([len(avatars)])
try:
img = Image.open(io.BytesIO(await file.read()))
except OSError as _:
raise HTTPException(
status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data"
)
try:
if not (avatars_dir := Config.data_dir() / "avatars").exists():
avatars_dir.mkdir()
img.save(avatars_dir / avatar_id, format=img.format)
except OSError as _:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar"
)
if old_avatar := user.avatar:
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
old_file.unlink()
@router.delete("/user/avatar")
async def remove_avatar(
user: User = Depends(middleware.user),
ctx: middleware.Context = Depends(),
):
async with ctx.database.session() as session:
await session.execute(
sa.update(User).where(User.id == user.id).values(avatar=avatar_id)
)
await session.commit()
try:
await user.edit_avatar(None, session, ctx.config)
await session.commit()
except Exception as e:
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
f"{e}",
)

View File

@ -82,15 +82,17 @@ async def jwt_cookie(request: Request, response: Response, ctx: Context = Depend
except jwt.PyJWTError as e:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
async with ctx.database.session() as session:
if not await User.by_id(uuid.UUID(access_claims.sub), session):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
return access_claims
async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User:
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
async with ctx.database.session() as session:
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
return current_user

View File

@ -15,6 +15,4 @@ else:
@router.get("/{spa:path}", response_class=HTMLResponse)
async def root(request: Request):
return templates.TemplateResponse(
"base.html", {"request": request, "view": "app"}
)
return templates.TemplateResponse(request, "base.html", {"view": "app"})

View File

@ -1,26 +1,29 @@
from typing import Optional
import datetime
import datetime
from pydantic import BaseModel
import jwt
import jwt
class TokenClaims(BaseModel):
sub: str
exp: int
iat: int
sub: str
exp: int
iat: int
iss: Optional[str] = None
def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = None) -> str:
def generate_token(
sub: str, secret: str, duration: int, iss: Optional[str] = None
) -> str:
now = datetime.datetime.now()
iat = now.timestamp()
exp = (now + datetime.timedelta(seconds = duration)).timestamp()
claims = TokenClaims(sub = sub, exp = int(exp), iat = int(iat), iss = iss)
exp = (now + datetime.timedelta(seconds=duration)).timestamp()
claims = TokenClaims(sub=sub, exp=int(exp), iat=int(iat), iss=iss)
return jwt.encode(claims.model_dump(), secret)
def validate_token(token: str, secret: str) -> TokenClaims:
payload = jwt.decode(token, secret, algorithms = [ "HS256" ])
payload = jwt.decode(token, secret, algorithms=["HS256"])
return TokenClaims(**payload)

178
tests/conftest.py Normal file
View File

@ -0,0 +1,178 @@
import pytest_asyncio
from materia.config import Config
from materia.models import (
Database,
Cache,
User,
LoginType,
)
from materia.models.base import Base
from materia import security
import sqlalchemy as sa
from sqlalchemy.pool import NullPool
from materia.app import make_application, AppContext
from materia._logging import make_logger
from httpx import AsyncClient, ASGITransport, Cookies
import asyncio
from fastapi import FastAPI
from contextlib import asynccontextmanager
from typing import AsyncIterator
from asgi_lifespan import LifespanManager
from fastapi.middleware.cors import CORSMiddleware
from materia import routers
from pathlib import Path
@pytest_asyncio.fixture(scope="session")
async def config() -> Config:
conf = Config()
conf.database.port = 54320
conf.cache.port = 63790
return conf
@pytest_asyncio.fixture(scope="session")
async def database(config: Config) -> Database:
config_postgres = config
config_postgres.database.user = "postgres"
config_postgres.database.name = "postgres"
database_postgres = await Database.new(
config_postgres.database.url(), poolclass=NullPool
)
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("create role pytest login"))
await connection.execute(sa.text("create database pytest owner pytest"))
await connection.commit()
await database_postgres.dispose()
config.database.user = "pytest"
config.database.name = "pytest"
database_pytest = await Database.new(config.database.url(), poolclass=NullPool)
yield database_pytest
await database_pytest.dispose()
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("drop database pytest")),
await connection.execute(sa.text("drop role pytest"))
await connection.commit()
await database_postgres.dispose()
@pytest_asyncio.fixture(scope="session")
async def cache(config: Config) -> Cache:
config_pytest = config
config_pytest.cache.user = "pytest"
cache_pytest = await Cache.new(config_pytest.cache.url())
yield cache_pytest
@pytest_asyncio.fixture(scope="function", autouse=True)
async def setup_database(database: Database):
async with database.connection() as connection:
await connection.run_sync(Base.metadata.create_all)
await connection.commit()
yield
async with database.connection() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.commit()
@pytest_asyncio.fixture()
async def session(database: Database, request):
session = database.sessionmaker()
yield session
await session.rollback()
await session.close()
@pytest_asyncio.fixture(scope="function")
async def data(config: Config):
class TestData:
user = User(
name="PyTest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
return TestData()
@pytest_asyncio.fixture(scope="function")
async def api_config(config: Config, tmpdir) -> Config:
config.application.working_directory = Path(tmpdir)
config.oauth2.jwt_secret = "pytest_secret_key"
yield config
@pytest_asyncio.fixture(scope="function")
async def api_client(
api_config: Config, database: Database, cache: Cache
) -> AsyncClient:
logger = make_logger(api_config)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
yield AppContext(
config=api_config, database=database, cache=cache, logger=logger
)
app = FastAPI(lifespan=lifespan)
app.include_router(routers.api.router)
app.include_router(routers.resources.router)
app.include_router(routers.root.router)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
async with LifespanManager(app) as manager:
async with AsyncClient(
transport=ASGITransport(app=manager.app), base_url=api_config.server.url()
) as client:
yield client
@pytest_asyncio.fixture(scope="function")
async def auth_client(api_client: AsyncClient, api_config: Config) -> AsyncClient:
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
await api_client.post(
"/api/auth/signup",
json=data,
)
auth = await api_client.post(
"/api/auth/signin",
json=data,
)
cookies = Cookies()
cookies.set(
"materia_at",
auth.cookies[api_config.security.cookie_access_token_name],
)
cookies.set(
"materia_rt",
auth.cookies[api_config.security.cookie_refresh_token_name],
)
api_client.cookies = cookies
yield api_client

86
tests/test_api.py Normal file
View File

@ -0,0 +1,86 @@
import pytest
from materia.config import Config
from httpx import AsyncClient, Cookies
from materia.models.base import Base
import aiofiles
from io import BytesIO
@pytest.mark.asyncio
async def test_auth(api_client: AsyncClient, api_config: Config):
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
response = await api_client.post(
"/api/auth/signup",
json=data,
)
assert response.status_code == 200
response = await api_client.post(
"/api/auth/signin",
json=data,
)
assert response.status_code == 200
assert response.cookies.get(api_config.security.cookie_access_token_name)
assert response.cookies.get(api_config.security.cookie_refresh_token_name)
# TODO: conflict usernames and emails
response = await api_client.get("/api/auth/signout")
assert response.status_code == 200
@pytest.mark.asyncio
async def test_user(auth_client: AsyncClient, api_config: Config):
info = await auth_client.get("/api/user")
assert info.status_code == 200, info.text
async with AsyncClient() as client:
pytest_logo_res = await client.get(
"https://docs.pytest.org/en/stable/_static/pytest1.png"
)
assert isinstance(pytest_logo_res.content, bytes)
pytest_logo = BytesIO(pytest_logo_res.content)
avatar = await auth_client.put(
"/api/user/avatar",
files={"file": ("pytest.png", pytest_logo)},
)
assert avatar.status_code == 200, avatar.text
info = await auth_client.get("/api/user")
avatar_info = info.json()["avatar"]
assert avatar_info is not None
assert api_config.application.working_directory.joinpath(
"avatars", avatar_info
).exists()
avatar = await auth_client.delete("/api/user/avatar")
assert avatar.status_code == 200, avatar.text
info = await auth_client.get("/api/user")
assert info.json()["avatar"] is None
assert not api_config.application.working_directory.joinpath(
"avatars", avatar_info
).exists()
delete = await auth_client.delete("/api/user")
assert delete.status_code == 200, delete.text
info = await auth_client.get("/api/user")
assert info.status_code == 401, info.text
@pytest.mark.asyncio
async def test_repository(auth_client: AsyncClient, api_config: Config):
info = await auth_client.get("/api/repository")
assert info.status_code == 404, info.text
create = await auth_client.post("/api/repository")
assert create.status_code == 200, create.text
create = await auth_client.post("/api/repository")
assert create.status_code == 409, create.text
info = await auth_client.get("/api/repository")
assert info.status_code == 200, info.text

View File

@ -1,145 +1,23 @@
import pytest_asyncio
import pytest
import os
import sys
from pathlib import Path
from materia.config import Config
from materia.models import (
Database,
User,
LoginType,
Repository,
Directory,
RepositoryError,
File,
)
from materia.models.base import Base
from materia.models.database import SessionContext
from materia import security
import sqlalchemy as sa
from sqlalchemy.pool import NullPool
from sqlalchemy.orm.session import make_transient
from sqlalchemy import inspect
import aiofiles
import aiofiles.os
@pytest_asyncio.fixture(scope="session")
async def config() -> Config:
conf = Config()
conf.database.port = 54320
# conf.application.working_directory = conf.application.working_directory / "temp"
# if (cwd := conf.application.working_directory.resolve()).exists():
# os.chdir(cwd)
# if local_conf := Config.open(cwd / "config.toml"):
# conf = local_conf
return conf
@pytest_asyncio.fixture(scope="session")
async def db(config: Config, request) -> Database:
config_postgres = config
config_postgres.database.user = "postgres"
config_postgres.database.name = "postgres"
database_postgres = await Database.new(
config_postgres.database.url(), poolclass=NullPool
)
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("create role pytest login"))
await connection.execute(sa.text("create database pytest owner pytest"))
await connection.commit()
await database_postgres.dispose()
config.database.user = "pytest"
config.database.name = "pytest"
database = await Database.new(config.database.url(), poolclass=NullPool)
yield database
await database.dispose()
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("drop database pytest")),
await connection.execute(sa.text("drop role pytest"))
await connection.commit()
await database_postgres.dispose()
"""
@pytest.mark.asyncio
async def test_migrations(db):
await db.run_migrations()
await db.rollback_migrations()
"""
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_db(db: Database, request):
async with db.connection() as connection:
await connection.run_sync(Base.metadata.create_all)
await connection.commit()
yield
async with db.connection() as connection:
await connection.run_sync(Base.metadata.drop_all)
await connection.commit()
@pytest_asyncio.fixture(autouse=True)
async def session(db: Database, request):
session = db.sessionmaker()
yield session
await session.rollback()
await session.close()
"""
@pytest_asyncio.fixture(scope="session")
async def user(config: Config, session) -> User:
test_user = User(
name="pytest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
async with db.session() as session:
session.add(test_user)
await session.flush()
await session.refresh(test_user)
yield test_user
async with db.session() as session:
await session.delete(test_user)
await session.flush()
"""
@pytest_asyncio.fixture(scope="function")
async def data(config: Config):
class TestData:
user = User(
name="PyTest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
return TestData()
@pytest.mark.asyncio
async def test_user(data, session: SessionContext, config: Config):
# simple
@ -161,6 +39,10 @@ async def test_user(data, session: SessionContext, config: Config):
await data.user.edit_name("AsyncPyTest", session)
assert await User.by_name("asyncpytest", session, with_lower=True) == data.user
assert await User.by_email("pytest@example.com", session) == data.user
assert await User.by_id(data.user.id, session) == data.user
await data.user.edit_password("iamnotpytest", session, config)
assert security.validate_password("iamnotpytest", data.user.hashed_password)
await data.user.remove(session)
@ -280,9 +162,9 @@ async def test_directory(data, tmpdir, session: SessionContext, config: Config):
# rename
assert (await directory.rename("test1", session, config)).name == "test1"
directory2 = await Directory(
repository_id=repository.id, parent_id=None, name="test2"
).new(session, config)
await Directory(repository_id=repository.id, parent_id=None, name="test2").new(
session, config
)
assert (await directory.rename("test2", session, config)).name == "test2.1"
assert (await repository.path(session, config)).joinpath("test2.1").exists()
assert not (await repository.path(session, config)).joinpath("test1").exists()
@ -358,7 +240,7 @@ async def test_file(data, tmpdir, session: SessionContext, config: Config):
assert (
await file.rename("test_file_rename.txt", session, config)
).name == "test_file_rename.txt"
file2 = await File(
await File(
repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt"
).new(b"", session, config)
assert (