tests and fixtures
This commit is contained in:
parent
aefedfe187
commit
58e7175d45
3
.gitignore
vendored
3
.gitignore
vendored
@ -10,3 +10,6 @@ __pycache__/
|
|||||||
.pdm.toml
|
.pdm.toml
|
||||||
.pdm-python
|
.pdm-python
|
||||||
.pdm-build
|
.pdm-build
|
||||||
|
|
||||||
|
.pytest_cache
|
||||||
|
.coverage
|
||||||
|
79
pdm.lock
generated
79
pdm.lock
generated
@ -5,7 +5,7 @@
|
|||||||
groups = ["default", "dev"]
|
groups = ["default", "dev"]
|
||||||
strategy = ["cross_platform", "inherit_metadata"]
|
strategy = ["cross_platform", "inherit_metadata"]
|
||||||
lock_version = "4.4.1"
|
lock_version = "4.4.1"
|
||||||
content_hash = "sha256:fe3214096aaef3097e2009f717762fb370bb726aa89a52e7b2a40d60016be987"
|
content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiofiles"
|
name = "aiofiles"
|
||||||
@ -97,6 +97,20 @@ files = [
|
|||||||
{file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
|
{file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "asgi-lifespan"
|
||||||
|
version = "2.1.0"
|
||||||
|
requires_python = ">=3.7"
|
||||||
|
summary = "Programmatic startup/shutdown of ASGI apps."
|
||||||
|
groups = ["dev"]
|
||||||
|
dependencies = [
|
||||||
|
"sniffio",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "asgi-lifespan-2.1.0.tar.gz", hash = "sha256:5e2effaf0bfe39829cf2d64e7ecc47c7d86d676a6599f7afba378c31f5e3a308"},
|
||||||
|
{file = "asgi_lifespan-2.1.0-py3-none-any.whl", hash = "sha256:ed840706680e28428c01e14afb3875d7d76d3206f3d5b2f2294e059b5c23804f"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "asyncpg"
|
name = "asyncpg"
|
||||||
version = "0.29.0"
|
version = "0.29.0"
|
||||||
@ -296,6 +310,52 @@ files = [
|
|||||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "coverage"
|
||||||
|
version = "7.6.1"
|
||||||
|
requires_python = ">=3.8"
|
||||||
|
summary = "Code coverage measurement for Python"
|
||||||
|
groups = ["dev"]
|
||||||
|
files = [
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
|
||||||
|
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
|
||||||
|
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "coverage"
|
||||||
|
version = "7.6.1"
|
||||||
|
extras = ["toml"]
|
||||||
|
requires_python = ">=3.8"
|
||||||
|
summary = "Code coverage measurement for Python"
|
||||||
|
groups = ["dev"]
|
||||||
|
dependencies = [
|
||||||
|
"coverage==7.6.1",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"},
|
||||||
|
{file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"},
|
||||||
|
{file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"},
|
||||||
|
{file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cryptography"
|
name = "cryptography"
|
||||||
version = "43.0.0"
|
version = "43.0.0"
|
||||||
@ -1044,6 +1104,21 @@ files = [
|
|||||||
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
|
{file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pytest-cov"
|
||||||
|
version = "5.0.0"
|
||||||
|
requires_python = ">=3.8"
|
||||||
|
summary = "Pytest plugin for measuring coverage."
|
||||||
|
groups = ["dev"]
|
||||||
|
dependencies = [
|
||||||
|
"coverage[toml]>=5.2.1",
|
||||||
|
"pytest>=4.6",
|
||||||
|
]
|
||||||
|
files = [
|
||||||
|
{file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"},
|
||||||
|
{file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"},
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "python-dateutil"
|
name = "python-dateutil"
|
||||||
version = "2.9.0.post0"
|
version = "2.9.0.post0"
|
||||||
@ -1157,7 +1232,7 @@ name = "sniffio"
|
|||||||
version = "1.3.1"
|
version = "1.3.1"
|
||||||
requires_python = ">=3.7"
|
requires_python = ">=3.7"
|
||||||
summary = "Sniff out which async library your code is running under"
|
summary = "Sniff out which async library your code is running under"
|
||||||
groups = ["default"]
|
groups = ["default", "dev"]
|
||||||
files = [
|
files = [
|
||||||
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
|
{file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"},
|
||||||
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
{file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"},
|
||||||
|
@ -41,16 +41,6 @@ requires-python = ">=3.12,<3.13"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = {text = "MIT"}
|
license = {text = "MIT"}
|
||||||
|
|
||||||
[tool.pdm.dev-dependencies]
|
|
||||||
dev = [
|
|
||||||
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
|
|
||||||
"black<24.0.0,>=23.3.0",
|
|
||||||
"pytest<8.0.0,>=7.3.2",
|
|
||||||
"pyflakes<4.0.0,>=3.0.1",
|
|
||||||
"pyright<2.0.0,>=1.1.314",
|
|
||||||
"pytest-asyncio>=0.23.7",
|
|
||||||
]
|
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["pdm-backend"]
|
requires = ["pdm-backend"]
|
||||||
build-backend = "pdm.backend"
|
build-backend = "pdm.backend"
|
||||||
@ -65,8 +55,20 @@ reportGeneralTypeIssues = false
|
|||||||
pythonpath = ["."]
|
pythonpath = ["."]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|
||||||
|
|
||||||
[tool.pdm]
|
[tool.pdm]
|
||||||
distribution = true
|
distribution = true
|
||||||
|
[tool.pdm.dev-dependencies]
|
||||||
|
dev = [
|
||||||
|
"-e file:///${PROJECT_ROOT}/workspaces/frontend",
|
||||||
|
"black<24.0.0,>=23.3.0",
|
||||||
|
"pytest<8.0.0,>=7.3.2",
|
||||||
|
"pyflakes<4.0.0,>=3.0.1",
|
||||||
|
"pyright<2.0.0,>=1.1.314",
|
||||||
|
"pytest-asyncio>=0.23.7",
|
||||||
|
"asgi-lifespan>=2.1.0",
|
||||||
|
"pytest-cov>=5.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.pdm.build]
|
[tool.pdm.build]
|
||||||
includes = ["src/materia"]
|
includes = ["src/materia"]
|
||||||
|
@ -44,6 +44,9 @@ class Server(BaseModel):
|
|||||||
port: int = 54601
|
port: int = 54601
|
||||||
domain: str = "localhost"
|
domain: str = "localhost"
|
||||||
|
|
||||||
|
def url(self) -> str:
|
||||||
|
return "{}://{}:{}".format(self.scheme, self.address, self.port)
|
||||||
|
|
||||||
|
|
||||||
class Database(BaseModel):
|
class Database(BaseModel):
|
||||||
backend: Literal["postgresql"] = "postgresql"
|
backend: Literal["postgresql"] = "postgresql"
|
||||||
|
@ -80,7 +80,7 @@ class Database:
|
|||||||
async with database.connection() as connection:
|
async with database.connection() as connection:
|
||||||
await connection.rollback()
|
await connection.rollback()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise DatabaseError(f"{e}")
|
raise DatabaseError(f"Failed to connect to database: {url}") from e
|
||||||
|
|
||||||
return database
|
return database
|
||||||
|
|
||||||
@ -94,7 +94,7 @@ class Database:
|
|||||||
yield connection
|
yield connection
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await connection.rollback()
|
await connection.rollback()
|
||||||
raise DatabaseError(f"{e}")
|
raise DatabaseError(*e.args) from e
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def session(self) -> SessionContext:
|
async def session(self) -> SessionContext:
|
||||||
@ -102,12 +102,12 @@ class Database:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
yield session
|
yield session
|
||||||
|
except HTTPException as e:
|
||||||
|
await session.rollback()
|
||||||
|
raise e from None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await session.rollback()
|
await session.rollback()
|
||||||
raise DatabaseError(f"{e}")
|
raise DatabaseError(*e.args) from e
|
||||||
except HTTPException:
|
|
||||||
# if the initial exception reaches HTTPException, then everything is handled fine (should be)
|
|
||||||
await session.rollback()
|
|
||||||
finally:
|
finally:
|
||||||
await session.close()
|
await session.close()
|
||||||
|
|
||||||
|
@ -98,8 +98,14 @@ class Repository(Base):
|
|||||||
await session.refresh(user, attribute_names=["repository"])
|
await session.refresh(user, attribute_names=["repository"])
|
||||||
return user.repository
|
return user.repository
|
||||||
|
|
||||||
async def info(self) -> "RepositoryInfo":
|
async def info(self, session: SessionContext) -> "RepositoryInfo":
|
||||||
return RepositoryInfo.model_validate(self)
|
session.add(self)
|
||||||
|
await session.refresh(self, attribute_names=["files"])
|
||||||
|
|
||||||
|
info = RepositoryInfo.model_validate(self)
|
||||||
|
info.used = sum([file.size for file in self.files])
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
class RepositoryInfo(BaseModel):
|
class RepositoryInfo(BaseModel):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from uuid import UUID, uuid4
|
from uuid import UUID, uuid4
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self, BinaryIO
|
||||||
import time
|
import time
|
||||||
import re
|
import re
|
||||||
|
|
||||||
@ -8,6 +8,9 @@ import pydantic
|
|||||||
from sqlalchemy import BigInteger, Enum
|
from sqlalchemy import BigInteger, Enum
|
||||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from PIL import Image
|
||||||
|
from sqids.sqids import Sqids
|
||||||
|
from aiofiles import os as async_os
|
||||||
|
|
||||||
from materia import security
|
from materia import security
|
||||||
from materia.models.base import Base
|
from materia.models.base import Base
|
||||||
@ -82,8 +85,7 @@ class User(Base):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def check_password(password: str, config: Config) -> bool:
|
def check_password(password: str, config: Config) -> bool:
|
||||||
if len(password) < config.security.password_min_length:
|
return not len(password) < config.security.password_min_length
|
||||||
return False
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def count(session: SessionContext) -> Optional[int]:
|
async def count(session: SessionContext) -> Optional[int]:
|
||||||
@ -146,6 +148,57 @@ class User(Base):
|
|||||||
|
|
||||||
return user_info
|
return user_info
|
||||||
|
|
||||||
|
async def edit_avatar(
|
||||||
|
self, avatar: BinaryIO | None, session: SessionContext, config: Config
|
||||||
|
):
|
||||||
|
avatar_dir = config.application.working_directory.joinpath("avatars")
|
||||||
|
|
||||||
|
if avatar is None:
|
||||||
|
if self.avatar is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
avatar_file = FileSystem(
|
||||||
|
avatar_dir.joinpath(self.avatar), config.application.working_directory
|
||||||
|
)
|
||||||
|
if await avatar_file.exists():
|
||||||
|
await avatar_file.remove()
|
||||||
|
|
||||||
|
session.add(self)
|
||||||
|
self.avatar = None
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
image = Image.open(avatar)
|
||||||
|
except Exception as e:
|
||||||
|
raise UserError("Failed to read avatar data") from e
|
||||||
|
|
||||||
|
avatar_hashes: list[str] = (
|
||||||
|
await session.scalars(sa.select(User.avatar).where(User.avatar.isnot(None)))
|
||||||
|
).all()
|
||||||
|
avatar_id = Sqids(min_length=10, blocklist=avatar_hashes).encode(
|
||||||
|
[int(time.time())]
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if not avatar_dir.exists():
|
||||||
|
await async_os.mkdir(avatar_dir)
|
||||||
|
image.save(avatar_dir.joinpath(avatar_id), format=image.format)
|
||||||
|
except Exception as e:
|
||||||
|
raise UserError(f"Failed to save avatar: {e}") from e
|
||||||
|
|
||||||
|
if old_avatar := self.avatar:
|
||||||
|
avatar_file = FileSystem(
|
||||||
|
avatar_dir.joinpath(old_avatar), config.application.working_directory
|
||||||
|
)
|
||||||
|
if await avatar_file.exists():
|
||||||
|
await avatar_file.remove()
|
||||||
|
|
||||||
|
session.add(self)
|
||||||
|
self.avatar = avatar_id
|
||||||
|
await session.flush()
|
||||||
|
|
||||||
|
|
||||||
class UserCredentials(BaseModel):
|
class UserCredentials(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@ -177,3 +230,4 @@ class UserInfo(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
from materia.models.repository import Repository
|
from materia.models.repository import Repository
|
||||||
|
from materia.models.filesystem import FileSystem
|
||||||
|
@ -10,7 +10,6 @@ router = APIRouter(tags=["auth"])
|
|||||||
|
|
||||||
@router.post("/auth/signup")
|
@router.post("/auth/signup")
|
||||||
async def signup(body: UserCredentials, ctx: Context = Depends()):
|
async def signup(body: UserCredentials, ctx: Context = Depends()):
|
||||||
async with ctx.database.session() as session:
|
|
||||||
if not User.check_username(body.name):
|
if not User.check_username(body.name):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
|
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Invalid username"
|
||||||
@ -20,14 +19,12 @@ async def signup(body: UserCredentials, ctx: Context = Depends()):
|
|||||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
|
detail=f"Password is too short (minimum length {ctx.config.security.password_min_length})",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async with ctx.database.session() as session:
|
||||||
if await User.by_name(body.name, session, with_lower=True):
|
if await User.by_name(body.name, session, with_lower=True):
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_409_CONFLICT, detail="User already exists")
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="User already exists"
|
|
||||||
)
|
|
||||||
if await User.by_email(body.email, session): # type: ignore
|
if await User.by_email(body.email, session): # type: ignore
|
||||||
raise HTTPException(
|
raise HTTPException(status.HTTP_409_CONFLICT, detail="Email already used")
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Email already used"
|
|
||||||
)
|
|
||||||
|
|
||||||
count: Optional[int] = await User.count(session)
|
count: Optional[int] = await User.count(session)
|
||||||
|
|
||||||
@ -42,14 +39,20 @@ async def signup(body: UserCredentials, ctx: Context = Depends()):
|
|||||||
login_type=LoginType.Plain,
|
login_type=LoginType.Plain,
|
||||||
# first registered user is admin
|
# first registered user is admin
|
||||||
is_admin=count == 0,
|
is_admin=count == 0,
|
||||||
).new(session)
|
).new(session, ctx.config)
|
||||||
|
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
@router.post("/auth/signin")
|
@router.post("/auth/signin")
|
||||||
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
|
async def signin(body: UserCredentials, response: Response, ctx: Context = Depends()):
|
||||||
if (current_user := await User.by_name(body.name, ctx.database)) is None:
|
async with ctx.database.session() as session:
|
||||||
if (current_user := await User.by_email(str(body.email), ctx.database)) is None:
|
if (current_user := await User.by_name(body.name, session)) is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, detail="Invalid email")
|
if (current_user := await User.by_email(str(body.email), session)) is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_401_UNAUTHORIZED, detail="Invalid email"
|
||||||
|
)
|
||||||
|
|
||||||
if not security.validate_password(
|
if not security.validate_password(
|
||||||
body.password,
|
body.password,
|
||||||
current_user.hashed_password,
|
current_user.hashed_password,
|
||||||
|
@ -22,14 +22,15 @@ async def create(
|
|||||||
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
||||||
):
|
):
|
||||||
async with ctx.database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
if await Repository.by_user(user, session):
|
if await Repository.from_user(user, session):
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
raise HTTPException(status.HTTP_409_CONFLICT, "Repository already exists")
|
||||||
|
|
||||||
async with ctx.database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
try:
|
try:
|
||||||
await Repository(
|
await Repository(
|
||||||
user_id=user.id, capacity=ctx.config.repository.capacity
|
user_id=user.id, capacity=ctx.config.repository.capacity
|
||||||
).new(session)
|
).new(session, ctx.config)
|
||||||
|
await session.commit()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args)
|
status.HTTP_500_INTERNAL_SERVER_ERROR, detail=" ".join(e.args)
|
||||||
@ -41,13 +42,7 @@ async def info(
|
|||||||
repository=Depends(middleware.repository), ctx: middleware.Context = Depends()
|
repository=Depends(middleware.repository), ctx: middleware.Context = Depends()
|
||||||
):
|
):
|
||||||
async with ctx.database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
session.add(repository)
|
return await repository.info(session)
|
||||||
await session.refresh(repository, attribute_names=["files"])
|
|
||||||
|
|
||||||
info = RepositoryInfo.model_validate(repository)
|
|
||||||
info.used = sum([file.size for file in repository.files])
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/repository")
|
@router.delete("/repository")
|
||||||
|
@ -19,71 +19,56 @@ router = APIRouter(tags=["user"])
|
|||||||
async def info(
|
async def info(
|
||||||
claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()
|
claims=Depends(middleware.jwt_cookie), ctx: middleware.Context = Depends()
|
||||||
):
|
):
|
||||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
async with ctx.database.session() as session:
|
||||||
|
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||||
|
|
||||||
info = UserInfo.model_validate(current_user)
|
return current_user.info()
|
||||||
if current_user.is_email_private:
|
|
||||||
info.email = None
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/user")
|
@router.delete("/user")
|
||||||
async def remove(
|
async def remove(
|
||||||
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
user: User = Depends(middleware.user), ctx: middleware.Context = Depends()
|
||||||
):
|
):
|
||||||
repository_path = Config.data_dir() / "repository" / user.lower_name
|
|
||||||
|
|
||||||
async with ctx.database.session() as session:
|
|
||||||
session.add(user)
|
|
||||||
await session.refresh(user, attribute_names=["repository"])
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if repository_path.exists():
|
async with ctx.database.session() as session:
|
||||||
shutil.rmtree(str(repository_path))
|
await user.remove(session)
|
||||||
except OSError:
|
await session.commit()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to remove user"
|
status.HTTP_500_INTERNAL_SERVER_ERROR, f"Failed to remove user: {e}"
|
||||||
)
|
) from e
|
||||||
|
|
||||||
await user.repository.remove(ctx.database)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/user/avatar")
|
@router.put("/user/avatar")
|
||||||
async def avatar(
|
async def avatar(
|
||||||
file: UploadFile,
|
file: UploadFile,
|
||||||
user: User = Depends(middleware.user),
|
user: User = Depends(middleware.user),
|
||||||
ctx: middleware.Context = Depends(),
|
ctx: middleware.Context = Depends(),
|
||||||
):
|
):
|
||||||
async with ctx.database.session() as session:
|
async with ctx.database.session() as session:
|
||||||
avatars: list[str] = (await session.scalars(sa.select(User.avatar))).all()
|
|
||||||
avatars = list(filter(lambda avatar_hash: avatar_hash, avatars))
|
|
||||||
|
|
||||||
avatar_id = Sqids(min_length=10, blocklist=avatars).encode([len(avatars)])
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
img = Image.open(io.BytesIO(await file.read()))
|
await user.edit_avatar(io.BytesIO(await file.read()), session, ctx.config)
|
||||||
except OSError as _:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_422_UNPROCESSABLE_ENTITY, "Failed to read file data"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not (avatars_dir := Config.data_dir() / "avatars").exists():
|
|
||||||
avatars_dir.mkdir()
|
|
||||||
img.save(avatars_dir / avatar_id, format=img.format)
|
|
||||||
except OSError as _:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to save avatar"
|
|
||||||
)
|
|
||||||
|
|
||||||
if old_avatar := user.avatar:
|
|
||||||
if (old_file := Config.data_dir() / "avatars" / old_avatar).exists():
|
|
||||||
old_file.unlink()
|
|
||||||
|
|
||||||
async with ctx.database.session() as session:
|
|
||||||
await session.execute(
|
|
||||||
sa.update(User).where(User.id == user.id).values(avatar=avatar_id)
|
|
||||||
)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
f"{e}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/user/avatar")
|
||||||
|
async def remove_avatar(
|
||||||
|
user: User = Depends(middleware.user),
|
||||||
|
ctx: middleware.Context = Depends(),
|
||||||
|
):
|
||||||
|
async with ctx.database.session() as session:
|
||||||
|
try:
|
||||||
|
await user.edit_avatar(None, session, ctx.config)
|
||||||
|
await session.commit()
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
f"{e}",
|
||||||
|
)
|
||||||
|
@ -82,14 +82,16 @@ async def jwt_cookie(request: Request, response: Response, ctx: Context = Depend
|
|||||||
except jwt.PyJWTError as e:
|
except jwt.PyJWTError as e:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, f"Invalid token: {e}")
|
||||||
|
|
||||||
if not await User.by_id(uuid.UUID(access_claims.sub), ctx.database):
|
async with ctx.database.session() as session:
|
||||||
|
if not await User.by_id(uuid.UUID(access_claims.sub), session):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid user")
|
||||||
|
|
||||||
return access_claims
|
return access_claims
|
||||||
|
|
||||||
|
|
||||||
async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User:
|
async def user(claims=Depends(jwt_cookie), ctx: Context = Depends()) -> User:
|
||||||
if not (current_user := await User.by_id(uuid.UUID(claims.sub), ctx.database)):
|
async with ctx.database.session() as session:
|
||||||
|
if not (current_user := await User.by_id(uuid.UUID(claims.sub), session)):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Missing user")
|
||||||
|
|
||||||
return current_user
|
return current_user
|
||||||
|
@ -15,6 +15,4 @@ else:
|
|||||||
|
|
||||||
@router.get("/{spa:path}", response_class=HTMLResponse)
|
@router.get("/{spa:path}", response_class=HTMLResponse)
|
||||||
async def root(request: Request):
|
async def root(request: Request):
|
||||||
return templates.TemplateResponse(
|
return templates.TemplateResponse(request, "base.html", {"view": "app"})
|
||||||
"base.html", {"request": request, "view": "app"}
|
|
||||||
)
|
|
||||||
|
@ -12,7 +12,9 @@ class TokenClaims(BaseModel):
|
|||||||
iss: Optional[str] = None
|
iss: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = None) -> str:
|
def generate_token(
|
||||||
|
sub: str, secret: str, duration: int, iss: Optional[str] = None
|
||||||
|
) -> str:
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
iat = now.timestamp()
|
iat = now.timestamp()
|
||||||
exp = (now + datetime.timedelta(seconds=duration)).timestamp()
|
exp = (now + datetime.timedelta(seconds=duration)).timestamp()
|
||||||
@ -20,6 +22,7 @@ def generate_token(sub: str, secret: str, duration: int, iss: Optional[str] = No
|
|||||||
|
|
||||||
return jwt.encode(claims.model_dump(), secret)
|
return jwt.encode(claims.model_dump(), secret)
|
||||||
|
|
||||||
|
|
||||||
def validate_token(token: str, secret: str) -> TokenClaims:
|
def validate_token(token: str, secret: str) -> TokenClaims:
|
||||||
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
payload = jwt.decode(token, secret, algorithms=["HS256"])
|
||||||
|
|
||||||
|
178
tests/conftest.py
Normal file
178
tests/conftest.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
import pytest_asyncio
|
||||||
|
from materia.config import Config
|
||||||
|
from materia.models import (
|
||||||
|
Database,
|
||||||
|
Cache,
|
||||||
|
User,
|
||||||
|
LoginType,
|
||||||
|
)
|
||||||
|
from materia.models.base import Base
|
||||||
|
from materia import security
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.pool import NullPool
|
||||||
|
from materia.app import make_application, AppContext
|
||||||
|
from materia._logging import make_logger
|
||||||
|
from httpx import AsyncClient, ASGITransport, Cookies
|
||||||
|
import asyncio
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator
|
||||||
|
from asgi_lifespan import LifespanManager
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from materia import routers
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def config() -> Config:
|
||||||
|
conf = Config()
|
||||||
|
conf.database.port = 54320
|
||||||
|
conf.cache.port = 63790
|
||||||
|
|
||||||
|
return conf
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def database(config: Config) -> Database:
|
||||||
|
config_postgres = config
|
||||||
|
config_postgres.database.user = "postgres"
|
||||||
|
config_postgres.database.name = "postgres"
|
||||||
|
database_postgres = await Database.new(
|
||||||
|
config_postgres.database.url(), poolclass=NullPool
|
||||||
|
)
|
||||||
|
|
||||||
|
async with database_postgres.connection() as connection:
|
||||||
|
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||||
|
await connection.execute(sa.text("create role pytest login"))
|
||||||
|
await connection.execute(sa.text("create database pytest owner pytest"))
|
||||||
|
await connection.commit()
|
||||||
|
|
||||||
|
await database_postgres.dispose()
|
||||||
|
|
||||||
|
config.database.user = "pytest"
|
||||||
|
config.database.name = "pytest"
|
||||||
|
database_pytest = await Database.new(config.database.url(), poolclass=NullPool)
|
||||||
|
|
||||||
|
yield database_pytest
|
||||||
|
|
||||||
|
await database_pytest.dispose()
|
||||||
|
|
||||||
|
async with database_postgres.connection() as connection:
|
||||||
|
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
||||||
|
await connection.execute(sa.text("drop database pytest")),
|
||||||
|
await connection.execute(sa.text("drop role pytest"))
|
||||||
|
await connection.commit()
|
||||||
|
await database_postgres.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="session")
|
||||||
|
async def cache(config: Config) -> Cache:
|
||||||
|
config_pytest = config
|
||||||
|
config_pytest.cache.user = "pytest"
|
||||||
|
cache_pytest = await Cache.new(config_pytest.cache.url())
|
||||||
|
|
||||||
|
yield cache_pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function", autouse=True)
|
||||||
|
async def setup_database(database: Database):
|
||||||
|
async with database.connection() as connection:
|
||||||
|
await connection.run_sync(Base.metadata.create_all)
|
||||||
|
await connection.commit()
|
||||||
|
yield
|
||||||
|
async with database.connection() as connection:
|
||||||
|
await connection.run_sync(Base.metadata.drop_all)
|
||||||
|
await connection.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture()
|
||||||
|
async def session(database: Database, request):
|
||||||
|
session = database.sessionmaker()
|
||||||
|
yield session
|
||||||
|
await session.rollback()
|
||||||
|
await session.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def data(config: Config):
|
||||||
|
class TestData:
|
||||||
|
user = User(
|
||||||
|
name="PyTest",
|
||||||
|
lower_name="pytest",
|
||||||
|
email="pytest@example.com",
|
||||||
|
hashed_password=security.hash_password(
|
||||||
|
"iampytest", algo=config.security.password_hash_algo
|
||||||
|
),
|
||||||
|
login_type=LoginType.Plain,
|
||||||
|
is_admin=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
return TestData()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def api_config(config: Config, tmpdir) -> Config:
|
||||||
|
config.application.working_directory = Path(tmpdir)
|
||||||
|
config.oauth2.jwt_secret = "pytest_secret_key"
|
||||||
|
|
||||||
|
yield config
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def api_client(
|
||||||
|
api_config: Config, database: Database, cache: Cache
|
||||||
|
) -> AsyncClient:
|
||||||
|
|
||||||
|
logger = make_logger(api_config)
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
|
||||||
|
yield AppContext(
|
||||||
|
config=api_config, database=database, cache=cache, logger=logger
|
||||||
|
)
|
||||||
|
|
||||||
|
app = FastAPI(lifespan=lifespan)
|
||||||
|
app.include_router(routers.api.router)
|
||||||
|
app.include_router(routers.resources.router)
|
||||||
|
app.include_router(routers.root.router)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
async with LifespanManager(app) as manager:
|
||||||
|
async with AsyncClient(
|
||||||
|
transport=ASGITransport(app=manager.app), base_url=api_config.server.url()
|
||||||
|
) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(scope="function")
|
||||||
|
async def auth_client(api_client: AsyncClient, api_config: Config) -> AsyncClient:
|
||||||
|
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
|
||||||
|
|
||||||
|
await api_client.post(
|
||||||
|
"/api/auth/signup",
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
auth = await api_client.post(
|
||||||
|
"/api/auth/signin",
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
|
||||||
|
cookies = Cookies()
|
||||||
|
cookies.set(
|
||||||
|
"materia_at",
|
||||||
|
auth.cookies[api_config.security.cookie_access_token_name],
|
||||||
|
)
|
||||||
|
cookies.set(
|
||||||
|
"materia_rt",
|
||||||
|
auth.cookies[api_config.security.cookie_refresh_token_name],
|
||||||
|
)
|
||||||
|
|
||||||
|
api_client.cookies = cookies
|
||||||
|
|
||||||
|
yield api_client
|
86
tests/test_api.py
Normal file
86
tests/test_api.py
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
import pytest
|
||||||
|
from materia.config import Config
|
||||||
|
from httpx import AsyncClient, Cookies
|
||||||
|
from materia.models.base import Base
|
||||||
|
import aiofiles
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_auth(api_client: AsyncClient, api_config: Config):
|
||||||
|
data = {"name": "PyTest", "password": "iampytest", "email": "pytest@example.com"}
|
||||||
|
|
||||||
|
response = await api_client.post(
|
||||||
|
"/api/auth/signup",
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
response = await api_client.post(
|
||||||
|
"/api/auth/signin",
|
||||||
|
json=data,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get(api_config.security.cookie_access_token_name)
|
||||||
|
assert response.cookies.get(api_config.security.cookie_refresh_token_name)
|
||||||
|
|
||||||
|
# TODO: conflict usernames and emails
|
||||||
|
|
||||||
|
response = await api_client.get("/api/auth/signout")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user(auth_client: AsyncClient, api_config: Config):
|
||||||
|
info = await auth_client.get("/api/user")
|
||||||
|
assert info.status_code == 200, info.text
|
||||||
|
|
||||||
|
async with AsyncClient() as client:
|
||||||
|
pytest_logo_res = await client.get(
|
||||||
|
"https://docs.pytest.org/en/stable/_static/pytest1.png"
|
||||||
|
)
|
||||||
|
assert isinstance(pytest_logo_res.content, bytes)
|
||||||
|
pytest_logo = BytesIO(pytest_logo_res.content)
|
||||||
|
|
||||||
|
avatar = await auth_client.put(
|
||||||
|
"/api/user/avatar",
|
||||||
|
files={"file": ("pytest.png", pytest_logo)},
|
||||||
|
)
|
||||||
|
assert avatar.status_code == 200, avatar.text
|
||||||
|
|
||||||
|
info = await auth_client.get("/api/user")
|
||||||
|
avatar_info = info.json()["avatar"]
|
||||||
|
assert avatar_info is not None
|
||||||
|
assert api_config.application.working_directory.joinpath(
|
||||||
|
"avatars", avatar_info
|
||||||
|
).exists()
|
||||||
|
|
||||||
|
avatar = await auth_client.delete("/api/user/avatar")
|
||||||
|
assert avatar.status_code == 200, avatar.text
|
||||||
|
|
||||||
|
info = await auth_client.get("/api/user")
|
||||||
|
assert info.json()["avatar"] is None
|
||||||
|
assert not api_config.application.working_directory.joinpath(
|
||||||
|
"avatars", avatar_info
|
||||||
|
).exists()
|
||||||
|
|
||||||
|
delete = await auth_client.delete("/api/user")
|
||||||
|
assert delete.status_code == 200, delete.text
|
||||||
|
|
||||||
|
info = await auth_client.get("/api/user")
|
||||||
|
assert info.status_code == 401, info.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_repository(auth_client: AsyncClient, api_config: Config):
|
||||||
|
info = await auth_client.get("/api/repository")
|
||||||
|
assert info.status_code == 404, info.text
|
||||||
|
|
||||||
|
create = await auth_client.post("/api/repository")
|
||||||
|
assert create.status_code == 200, create.text
|
||||||
|
|
||||||
|
create = await auth_client.post("/api/repository")
|
||||||
|
assert create.status_code == 409, create.text
|
||||||
|
|
||||||
|
info = await auth_client.get("/api/repository")
|
||||||
|
assert info.status_code == 200, info.text
|
@ -1,145 +1,23 @@
|
|||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from materia.config import Config
|
from materia.config import Config
|
||||||
from materia.models import (
|
from materia.models import (
|
||||||
Database,
|
|
||||||
User,
|
User,
|
||||||
LoginType,
|
|
||||||
Repository,
|
Repository,
|
||||||
Directory,
|
Directory,
|
||||||
RepositoryError,
|
RepositoryError,
|
||||||
File,
|
File,
|
||||||
)
|
)
|
||||||
from materia.models.base import Base
|
|
||||||
from materia.models.database import SessionContext
|
from materia.models.database import SessionContext
|
||||||
from materia import security
|
from materia import security
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy.pool import NullPool
|
|
||||||
from sqlalchemy.orm.session import make_transient
|
from sqlalchemy.orm.session import make_transient
|
||||||
from sqlalchemy import inspect
|
from sqlalchemy import inspect
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def config() -> Config:
|
|
||||||
conf = Config()
|
|
||||||
conf.database.port = 54320
|
|
||||||
# conf.application.working_directory = conf.application.working_directory / "temp"
|
|
||||||
# if (cwd := conf.application.working_directory.resolve()).exists():
|
|
||||||
# os.chdir(cwd)
|
|
||||||
# if local_conf := Config.open(cwd / "config.toml"):
|
|
||||||
# conf = local_conf
|
|
||||||
return conf
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def db(config: Config, request) -> Database:
|
|
||||||
config_postgres = config
|
|
||||||
config_postgres.database.user = "postgres"
|
|
||||||
config_postgres.database.name = "postgres"
|
|
||||||
database_postgres = await Database.new(
|
|
||||||
config_postgres.database.url(), poolclass=NullPool
|
|
||||||
)
|
|
||||||
|
|
||||||
async with database_postgres.connection() as connection:
|
|
||||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
|
||||||
await connection.execute(sa.text("create role pytest login"))
|
|
||||||
await connection.execute(sa.text("create database pytest owner pytest"))
|
|
||||||
await connection.commit()
|
|
||||||
|
|
||||||
await database_postgres.dispose()
|
|
||||||
|
|
||||||
config.database.user = "pytest"
|
|
||||||
config.database.name = "pytest"
|
|
||||||
database = await Database.new(config.database.url(), poolclass=NullPool)
|
|
||||||
|
|
||||||
yield database
|
|
||||||
|
|
||||||
await database.dispose()
|
|
||||||
|
|
||||||
async with database_postgres.connection() as connection:
|
|
||||||
await connection.execution_options(isolation_level="AUTOCOMMIT")
|
|
||||||
await connection.execute(sa.text("drop database pytest")),
|
|
||||||
await connection.execute(sa.text("drop role pytest"))
|
|
||||||
await connection.commit()
|
|
||||||
await database_postgres.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_migrations(db):
|
|
||||||
await db.run_migrations()
|
|
||||||
await db.rollback_migrations()
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="session", autouse=True)
|
|
||||||
async def setup_db(db: Database, request):
|
|
||||||
async with db.connection() as connection:
|
|
||||||
await connection.run_sync(Base.metadata.create_all)
|
|
||||||
await connection.commit()
|
|
||||||
yield
|
|
||||||
async with db.connection() as connection:
|
|
||||||
await connection.run_sync(Base.metadata.drop_all)
|
|
||||||
await connection.commit()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(autouse=True)
|
|
||||||
async def session(db: Database, request):
|
|
||||||
session = db.sessionmaker()
|
|
||||||
yield session
|
|
||||||
await session.rollback()
|
|
||||||
await session.close()
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
@pytest_asyncio.fixture(scope="session")
|
|
||||||
async def user(config: Config, session) -> User:
|
|
||||||
test_user = User(
|
|
||||||
name="pytest",
|
|
||||||
lower_name="pytest",
|
|
||||||
email="pytest@example.com",
|
|
||||||
hashed_password=security.hash_password(
|
|
||||||
"iampytest", algo=config.security.password_hash_algo
|
|
||||||
),
|
|
||||||
login_type=LoginType.Plain,
|
|
||||||
is_admin=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with db.session() as session:
|
|
||||||
session.add(test_user)
|
|
||||||
await session.flush()
|
|
||||||
await session.refresh(test_user)
|
|
||||||
|
|
||||||
yield test_user
|
|
||||||
|
|
||||||
async with db.session() as session:
|
|
||||||
await session.delete(test_user)
|
|
||||||
await session.flush()
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture(scope="function")
|
|
||||||
async def data(config: Config):
|
|
||||||
class TestData:
|
|
||||||
user = User(
|
|
||||||
name="PyTest",
|
|
||||||
lower_name="pytest",
|
|
||||||
email="pytest@example.com",
|
|
||||||
hashed_password=security.hash_password(
|
|
||||||
"iampytest", algo=config.security.password_hash_algo
|
|
||||||
),
|
|
||||||
login_type=LoginType.Plain,
|
|
||||||
is_admin=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
return TestData()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_user(data, session: SessionContext, config: Config):
|
async def test_user(data, session: SessionContext, config: Config):
|
||||||
# simple
|
# simple
|
||||||
@ -161,6 +39,10 @@ async def test_user(data, session: SessionContext, config: Config):
|
|||||||
|
|
||||||
await data.user.edit_name("AsyncPyTest", session)
|
await data.user.edit_name("AsyncPyTest", session)
|
||||||
assert await User.by_name("asyncpytest", session, with_lower=True) == data.user
|
assert await User.by_name("asyncpytest", session, with_lower=True) == data.user
|
||||||
|
assert await User.by_email("pytest@example.com", session) == data.user
|
||||||
|
assert await User.by_id(data.user.id, session) == data.user
|
||||||
|
await data.user.edit_password("iamnotpytest", session, config)
|
||||||
|
assert security.validate_password("iamnotpytest", data.user.hashed_password)
|
||||||
|
|
||||||
await data.user.remove(session)
|
await data.user.remove(session)
|
||||||
|
|
||||||
@ -280,9 +162,9 @@ async def test_directory(data, tmpdir, session: SessionContext, config: Config):
|
|||||||
|
|
||||||
# rename
|
# rename
|
||||||
assert (await directory.rename("test1", session, config)).name == "test1"
|
assert (await directory.rename("test1", session, config)).name == "test1"
|
||||||
directory2 = await Directory(
|
await Directory(repository_id=repository.id, parent_id=None, name="test2").new(
|
||||||
repository_id=repository.id, parent_id=None, name="test2"
|
session, config
|
||||||
).new(session, config)
|
)
|
||||||
assert (await directory.rename("test2", session, config)).name == "test2.1"
|
assert (await directory.rename("test2", session, config)).name == "test2.1"
|
||||||
assert (await repository.path(session, config)).joinpath("test2.1").exists()
|
assert (await repository.path(session, config)).joinpath("test2.1").exists()
|
||||||
assert not (await repository.path(session, config)).joinpath("test1").exists()
|
assert not (await repository.path(session, config)).joinpath("test1").exists()
|
||||||
@ -358,7 +240,7 @@ async def test_file(data, tmpdir, session: SessionContext, config: Config):
|
|||||||
assert (
|
assert (
|
||||||
await file.rename("test_file_rename.txt", session, config)
|
await file.rename("test_file_rename.txt", session, config)
|
||||||
).name == "test_file_rename.txt"
|
).name == "test_file_rename.txt"
|
||||||
file2 = await File(
|
await File(
|
||||||
repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt"
|
repository_id=repository.id, parent_id=directory.id, name="test_file_2.txt"
|
||||||
).new(b"", session, config)
|
).new(b"", session, config)
|
||||||
assert (
|
assert (
|
Loading…
x
Reference in New Issue
Block a user