materia-server: add tests

This commit is contained in:
L-Nafaryus 2024-07-25 13:33:05 +05:00
parent 577f6f3ddf
commit 850bb89346
Signed by: L-Nafaryus
GPG Key ID: 553C97999B363D38
7 changed files with 366 additions and 185 deletions

View File

@ -19,8 +19,7 @@
}: let }: let
system = "x86_64-linux"; system = "x86_64-linux";
pkgs = import nixpkgs {inherit system;}; pkgs = import nixpkgs {inherit system;};
bonpkgs = bonfire.packages.${system}; bonLib = bonfire.lib;
bonlib = bonfire.lib;
dreamBuildPackage = { dreamBuildPackage = {
module, module,
@ -77,7 +76,7 @@
meta = with nixpkgs.lib; { meta = with nixpkgs.lib; {
description = "Materia frontend"; description = "Materia frontend";
license = licenses.mit; license = licenses.mit;
maintainers = with bonlib.maintainers; [L-Nafaryus]; maintainers = with bonLib.maintainers; [L-Nafaryus];
broken = false; broken = false;
}; };
}; };
@ -115,7 +114,7 @@
meta = with nixpkgs.lib; { meta = with nixpkgs.lib; {
description = "Materia web client"; description = "Materia web client";
license = licenses.mit; license = licenses.mit;
maintainers = with bonlib.maintainers; [L-Nafaryus]; maintainers = with bonLib.maintainers; [L-Nafaryus];
broken = false; broken = false;
}; };
}; };
@ -150,96 +149,15 @@
meta = with nixpkgs.lib; { meta = with nixpkgs.lib; {
description = "Materia"; description = "Materia";
license = licenses.mit; license = licenses.mit;
maintainers = with bonlib.maintainers; [L-Nafaryus]; maintainers = with bonLib.maintainers; [L-Nafaryus];
broken = false; broken = false;
mainProgram = "materia-server"; mainProgram = "materia-server";
}; };
}; };
postgresql = let postgresql-devel = bonfire.packages.x86_64-linux.postgresql;
user = "postgres";
database = "postgres";
dataDir = "/var/lib/postgresql";
entryPoint = pkgs.writeTextDir "entrypoint.sh" ''
initdb -U ${user}
postgres -k ${dataDir}
'';
in
pkgs.dockerTools.buildImage {
name = "postgresql";
tag = "devel";
copyToRoot = pkgs.buildEnv { redis-devel = bonfire.packages.x86_64-linux.redis;
name = "image-root";
pathsToLink = ["/bin" "/etc" "/"];
paths = with pkgs; [
bash
postgresql
entryPoint
];
};
runAsRoot = with pkgs; ''
#!${runtimeShell}
${dockerTools.shadowSetup}
groupadd -r ${user}
useradd -r -g ${user} --home-dir=${dataDir} ${user}
mkdir -p ${dataDir}
chown -R ${user}:${user} ${dataDir}
'';
config = {
Entrypoint = ["bash" "/entrypoint.sh"];
StopSignal = "SIGINT";
User = "${user}:${user}";
Env = ["PGDATA=${dataDir}"];
WorkingDir = dataDir;
ExposedPorts = {
"5432/tcp" = {};
};
};
};
redis = let
user = "redis";
dataDir = "/var/lib/redis";
entryPoint = pkgs.writeTextDir "entrypoint.sh" ''
redis-server \
--daemonize no \
--dir "${dataDir}"
'';
in
pkgs.dockerTools.buildImage {
name = "redis";
tag = "devel";
copyToRoot = pkgs.buildEnv {
name = "image-root";
pathsToLink = ["/bin" "/etc" "/"];
paths = with pkgs; [
bash
redis
entryPoint
];
};
runAsRoot = with pkgs; ''
#!${runtimeShell}
${dockerTools.shadowSetup}
groupadd -r ${user}
useradd -r -g ${user} --home-dir=${dataDir} ${user}
mkdir -p ${dataDir}
chown -R ${user}:${user} ${dataDir}
'';
config = {
Entrypoint = ["bash" "/entrypoint.sh"];
StopSignal = "SIGINT";
User = "${user}:${user}";
WorkingDir = dataDir;
ExposedPorts = {
"6379/tcp" = {};
};
};
};
}; };
apps.x86_64-linux = { apps.x86_64-linux = {

View File

@ -5,7 +5,7 @@
groups = ["default", "dev"] groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"] strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1" lock_version = "4.4.1"
content_hash = "sha256:4d8864659da597f26a1c544eaaba475fa1deb061210a05bf509dd0f6cc5fb11c" content_hash = "sha256:6bbe412ab2d74821a30f7deab8c2fe796e6a807a5d3009934c8b88364f8dc4b6"
[[package]] [[package]]
name = "aiosmtplib" name = "aiosmtplib"
@ -1259,6 +1259,20 @@ files = [
{file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"}, {file = "pytest-7.4.4.tar.gz", hash = "sha256:2cf0005922c6ace4a3e2ec8b4080eb0d9753fdc93107415332f50ce9e7994280"},
] ]
[[package]]
name = "pytest-asyncio"
version = "0.23.7"
requires_python = ">=3.8"
summary = "Pytest support for asyncio"
groups = ["dev"]
dependencies = [
"pytest<9,>=7.0.0",
]
files = [
{file = "pytest_asyncio-0.23.7-py3-none-any.whl", hash = "sha256:009b48127fbe44518a547bddd25611551b0e43ccdbf1e67d12479f569832c20b"},
{file = "pytest_asyncio-0.23.7.tar.gz", hash = "sha256:5f5c72948f4c49e7db4f29f2521d4031f1c27f86e57b046126654083d4770268"},
]
[[package]] [[package]]
name = "python-dateutil" name = "python-dateutil"
version = "2.9.0.post0" version = "2.9.0.post0"

View File

@ -36,9 +36,6 @@ requires-python = "<3.12,>=3.10"
readme = "README.md" readme = "README.md"
license = {text = "MIT"} license = {text = "MIT"}
[tool.pdm.build]
includes = ["src/materia_server"]
[build-system] [build-system]
requires = ["pdm-backend"] requires = ["pdm-backend"]
build-backend = "pdm.backend" build-backend = "pdm.backend"
@ -46,13 +43,6 @@ build-backend = "pdm.backend"
[project.scripts] [project.scripts]
materia-server = "materia_server.main:server" materia-server = "materia_server.main:server"
[tool.pdm.scripts]
start-server.cmd = "python ./src/materia_server/main.py {args:start --app-mode development --log-level debug}"
db-upgrade.cmd = "alembic -c ./src/materia_server/alembic.ini upgrade {args:head}"
db-downgrade.shell = "alembic -c ./src/materia_server/alembic.ini downgrade {args:base}"
db-revision.cmd = "alembic revision {args:--autogenerate}"
remove-revisions.shell = "rm -v ./src/materia_server/models/migrations/versions/*.py"
[tool.pyright] [tool.pyright]
reportGeneralTypeIssues = false reportGeneralTypeIssues = false
@ -61,8 +51,18 @@ pythonpath = ["."]
testpaths = ["tests"] testpaths = ["tests"]
[tool.pdm] [tool.pdm]
distribution = true distribution = true
[tool.pdm.build]
includes = ["src/materia_server"]
[tool.pdm.scripts]
start-server.cmd = "python ./src/materia_server/main.py {args:start --app-mode development --log-level debug}"
db-upgrade.cmd = "alembic -c ./src/materia_server/alembic.ini upgrade {args:head}"
db-downgrade.shell = "alembic -c ./src/materia_server/alembic.ini downgrade {args:base}"
db-revision.cmd = "alembic revision {args:--autogenerate}"
remove-revisions.shell = "rm -v ./src/materia_server/models/migrations/versions/*.py"
[tool.pdm.dev-dependencies] [tool.pdm.dev-dependencies]
dev = [ dev = [
@ -70,6 +70,7 @@ dev = [
"pytest<8.0.0,>=7.3.2", "pytest<8.0.0,>=7.3.2",
"pyflakes<4.0.0,>=3.0.1", "pyflakes<4.0.0,>=3.0.1",
"pyright<2.0.0,>=1.1.314", "pyright<2.0.0,>=1.1.314",
"pytest-asyncio>=0.23.7",
] ]

View File

@ -1,12 +1,20 @@
from os import environ from os import environ
from pathlib import Path from pathlib import Path
import sys import sys
from typing import Any, Literal, Optional, Self, Union from typing import Any, Literal, Optional, Self, Union
from pydantic import BaseModel, Field, HttpUrl, model_validator, TypeAdapter, PostgresDsn, NameEmail from pydantic import (
BaseModel,
Field,
HttpUrl,
model_validator,
TypeAdapter,
PostgresDsn,
NameEmail,
)
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
from pydantic.networks import IPvAnyAddress from pydantic.networks import IPvAnyAddress
import toml import toml
class Application(BaseModel): class Application(BaseModel):
@ -15,53 +23,61 @@ class Application(BaseModel):
mode: Literal["production", "development"] = "production" mode: Literal["production", "development"] = "production"
working_directory: Optional[Path] = Path.cwd() working_directory: Optional[Path] = Path.cwd()
class Log(BaseModel): class Log(BaseModel):
mode: Literal["console", "file", "all"] = "console" mode: Literal["console", "file", "all"] = "console"
level: Literal["info", "warning", "error", "critical", "debug", "trace"] = "info" level: Literal["info", "warning", "error", "critical", "debug", "trace"] = "info"
console_format: str = "<level>{level: <8}</level> <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}" console_format: str = (
file_format: str = "<level>{level: <8}</level>: <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}" "<level>{level: <8}</level> <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}"
)
file_format: str = (
"<level>{level: <8}</level>: <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}"
)
file: Optional[Path] = None file: Optional[Path] = None
file_rotation: str = "3 days" file_rotation: str = "3 days"
file_retention: str = "1 week" file_retention: str = "1 week"
class Server(BaseModel): class Server(BaseModel):
scheme: Literal["http", "https"] = "http" scheme: Literal["http", "https"] = "http"
address: IPvAnyAddress = Field(default = "127.0.0.1") address: IPvAnyAddress = Field(default="127.0.0.1")
port: int = 54601 port: int = 54601
domain: str = "localhost" domain: str = "localhost"
class Database(BaseModel): class Database(BaseModel):
backend: Literal["postgresql"] = "postgresql" backend: Literal["postgresql"] = "postgresql"
scheme: Literal["postgresql+asyncpg"] = "postgresql+asyncpg" scheme: Literal["postgresql+asyncpg"] = "postgresql+asyncpg"
address: IPvAnyAddress = Field(default = "127.0.0.1") address: IPvAnyAddress = Field(default="127.0.0.1")
port: int = 5432 port: int = 5432
name: str = "materia" name: Optional[str] = "materia"
user: str = "materia" user: str = "materia"
password: Optional[Union[str, Path]] = None password: Optional[Union[str, Path]] = None
# ssl: bool = False # ssl: bool = False
def url(self) -> str: def url(self) -> str:
if self.backend in ["postgresql"]: if self.backend in ["postgresql"]:
return "{}://{}:{}@{}:{}/{}".format( return (
self.scheme, "{}://{}:{}@{}:{}".format(
self.user, self.scheme, self.user, self.password, self.address, self.port
self.password, )
self.address, + f"/{self.name}"
self.port, if self.name
self.name else ""
) )
else: else:
raise NotImplemented() raise NotImplementedError()
class Cache(BaseModel): class Cache(BaseModel):
backend: Literal["redis"] = "redis" # add: memory backend: Literal["redis"] = "redis" # add: memory
# gc_interval: Optional[int] = 60 # for: memory # gc_interval: Optional[int] = 60 # for: memory
scheme: Literal["redis", "rediss"] = "redis" scheme: Literal["redis", "rediss"] = "redis"
address: Optional[IPvAnyAddress] = Field(default = "127.0.0.1") address: Optional[IPvAnyAddress] = Field(default="127.0.0.1")
port: Optional[int] = 6379 port: Optional[int] = 6379
user: Optional[str] = None user: Optional[str] = None
password: Optional[Union[str, Path]] = None password: Optional[Union[str, Path]] = None
database: Optional[int] = 0 # for: redis database: Optional[int] = 0 # for: redis
def url(self) -> str: def url(self) -> str:
if self.backend in ["redis"]: if self.backend in ["redis"]:
@ -72,38 +88,39 @@ class Cache(BaseModel):
self.password, self.password,
self.address, self.address,
self.port, self.port,
self.database self.database,
) )
else: else:
return "{}://{}:{}/{}".format( return "{}://{}:{}/{}".format(
self.scheme, self.scheme, self.address, self.port, self.database
self.address,
self.port,
self.database
) )
else: else:
raise NotImplemented() raise NotImplemented()
class Security(BaseModel): class Security(BaseModel):
secret_key: Optional[Union[str, Path]] = None secret_key: Optional[Union[str, Path]] = None
password_min_length: int = 8 password_min_length: int = 8
password_hash_algo: Literal["bcrypt"] = "bcrypt" password_hash_algo: Literal["bcrypt"] = "bcrypt"
cookie_http_only: bool = True cookie_http_only: bool = True
cookie_access_token_name: str = "materia_at" cookie_access_token_name: str = "materia_at"
cookie_refresh_token_name: str = "materia_rt" cookie_refresh_token_name: str = "materia_rt"
class OAuth2(BaseModel): class OAuth2(BaseModel):
enabled: bool = True enabled: bool = True
jwt_signing_algo: Literal["HS256"] = "HS256" jwt_signing_algo: Literal["HS256"] = "HS256"
# check if signing algo need a key or generate it | HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, EdDSA # check if signing algo need a key or generate it | HS256, HS384, HS512, RS256, RS384, RS512, ES256, ES384, ES512, EdDSA
jwt_signing_key: Optional[Union[str, Path]] = None jwt_signing_key: Optional[Union[str, Path]] = None
jwt_secret: Optional[Union[str, Path]] = None # only for HS256, HS384, HS512 | generate jwt_secret: Optional[Union[str, Path]] = (
access_token_lifetime: int = 3600 None # only for HS256, HS384, HS512 | generate
)
access_token_lifetime: int = 3600
refresh_token_lifetime: int = 730 * 60 refresh_token_lifetime: int = 730 * 60
refresh_token_validation: bool = False refresh_token_validation: bool = False
#@model_validator(mode = "after") # @model_validator(mode = "after")
#def check(self) -> Self: # def check(self) -> Self:
# if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]: # if self.jwt_signing_algo in ["HS256", "HS384", "HS512"]:
# assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms" # assert self.jwt_secret is not None, "JWT secret must be set for HS256, HS384, HS512 algorithms"
# else: # else:
@ -113,12 +130,12 @@ class OAuth2(BaseModel):
class Mailer(BaseModel): class Mailer(BaseModel):
enabled: bool = False enabled: bool = False
scheme: Optional[Literal["smtp", "smtps", "smtp+starttls"]] = None scheme: Optional[Literal["smtp", "smtps", "smtp+starttls"]] = None
address: Optional[IPvAnyAddress] = None address: Optional[IPvAnyAddress] = None
port: Optional[int] = None port: Optional[int] = None
helo: bool = True helo: bool = True
cert_file: Optional[Path] = None cert_file: Optional[Path] = None
key_file: Optional[Path] = None key_file: Optional[Path] = None
@ -127,22 +144,25 @@ class Mailer(BaseModel):
password: Optional[str] = None password: Optional[str] = None
plain_text: bool = False plain_text: bool = False
class Cron(BaseModel): class Cron(BaseModel):
pass pass
class Repository(BaseModel): class Repository(BaseModel):
capacity: int = 41943040 capacity: int = 41943040
class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
class Config(BaseSettings, env_prefix="materia_", env_nested_delimiter="_"):
application: Application = Application() application: Application = Application()
log: Log = Log() log: Log = Log()
server: Server = Server() server: Server = Server()
database: Database = Database() database: Database = Database()
cache: Cache = Cache() cache: Cache = Cache()
security: Security = Security() security: Security = Security()
oauth2: OAuth2 = OAuth2() oauth2: OAuth2 = OAuth2()
mailer: Mailer = Mailer() mailer: Mailer = Mailer()
cron: Cron = Cron() cron: Cron = Cron()
repository: Repository = Repository() repository: Repository = Repository()
@staticmethod @staticmethod
@ -151,7 +171,7 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
data: dict = toml.load(path) data: dict = toml.load(path)
except Exception as e: except Exception as e:
raise e raise e
#return None # return None
else: else:
return Config(**data) return Config(**data)
@ -163,7 +183,7 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
for key_second in dump[key_first].keys(): for key_second in dump[key_first].keys():
if isinstance(dump[key_first][key_second], Path): if isinstance(dump[key_first][key_second], Path):
dump[key_first][key_second] = str(dump[key_first][key_second]) dump[key_first][key_second] = str(dump[key_first][key_second])
with open(path, "w") as file: with open(path, "w") as file:
toml.dump(dump, file) toml.dump(dump, file)
@ -174,7 +194,3 @@ class Config(BaseSettings, env_prefix = "materia_", env_nested_delimiter = "_"):
return cwd / "temp" return cwd / "temp"
else: else:
return cwd return cwd

View File

@ -1,4 +1,3 @@
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import declarative_base
Base = declarative_base() Base = declarative_base()

View File

@ -4,7 +4,14 @@ from typing import AsyncIterator, Self
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, PostgresDsn from pydantic import BaseModel, PostgresDsn
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import (
AsyncConnection,
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.pool import NullPool
from asyncpg import Connection from asyncpg import Connection
from alembic.config import Config as AlembicConfig from alembic.config import Config as AlembicConfig
from alembic.operations import Operations from alembic.operations import Operations
@ -14,42 +21,52 @@ from alembic.script.base import ScriptDirectory
from materia_server.config import Config from materia_server.config import Config
from materia_server.models.base import Base from materia_server.models.base import Base
__all__ = [ "Database" ] __all__ = ["Database"]
class DatabaseError(Exception): class DatabaseError(Exception):
pass pass
class DatabaseMigrationError(Exception): class DatabaseMigrationError(Exception):
pass pass
class Database: class Database:
def __init__(self, url: PostgresDsn, engine: AsyncEngine, sessionmaker: async_sessionmaker[AsyncSession]): def __init__(
self.url: PostgresDsn = url self,
url: PostgresDsn,
engine: AsyncEngine,
sessionmaker: async_sessionmaker[AsyncSession],
):
self.url: PostgresDsn = url
self.engine: AsyncEngine = engine self.engine: AsyncEngine = engine
self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker self.sessionmaker: async_sessionmaker[AsyncSession] = sessionmaker
@staticmethod @staticmethod
async def new( async def new(
url: PostgresDsn, url: PostgresDsn,
pool_size: int = 100, pool_size: int = 100,
autocommit: bool = False, poolclass=None,
autoflush: bool = False, autocommit: bool = False,
expire_on_commit: bool = False, autoflush: bool = False,
test_connection: bool = True expire_on_commit: bool = False,
) -> Self: test_connection: bool = True,
engine = create_async_engine(str(url), pool_size = pool_size) ) -> Self:
engine_options = {"pool_size": pool_size}
if poolclass == NullPool:
engine_options = {"poolclass": NullPool}
engine = create_async_engine(str(url), **engine_options)
sessionmaker = async_sessionmaker( sessionmaker = async_sessionmaker(
bind = engine, bind=engine,
autocommit = autocommit, autocommit=autocommit,
autoflush = autoflush, autoflush=autoflush,
expire_on_commit = expire_on_commit expire_on_commit=expire_on_commit,
) )
database = Database( database = Database(url=url, engine=engine, sessionmaker=sessionmaker)
url = url,
engine = engine,
sessionmaker = sessionmaker
)
if test_connection: if test_connection:
try: try:
@ -63,38 +80,42 @@ class Database:
async def dispose(self): async def dispose(self):
await self.engine.dispose() await self.engine.dispose()
@asynccontextmanager @asynccontextmanager
async def connection(self) -> AsyncIterator[AsyncConnection]: async def connection(self) -> AsyncIterator[AsyncConnection]:
async with self.engine.begin() as connection: async with self.engine.connect() as connection:
try: try:
yield connection yield connection
except Exception as e: except Exception as e:
await connection.rollback() await connection.rollback()
raise DatabaseError(f"{e}") raise DatabaseError(f"{e}")
@asynccontextmanager @asynccontextmanager
async def session(self) -> AsyncIterator[AsyncSession]: async def session(self) -> AsyncIterator[AsyncSession]:
session = self.sessionmaker(); session = self.sessionmaker()
try: try:
yield session yield session
except Exception as e: except Exception as e:
await session.rollback() await session.rollback()
raise DatabaseError(f"{e}") raise DatabaseError(f"{e}")
finally: finally:
await session.close() await session.close()
def run_sync_migrations(self, connection: Connection): def run_sync_migrations(self, connection: Connection):
aconfig = AlembicConfig() aconfig = AlembicConfig()
aconfig.set_main_option("sqlalchemy.url", str(self.url)) aconfig.set_main_option("sqlalchemy.url", str(self.url))
aconfig.set_main_option("script_location", str(Path(__file__).parent.parent.joinpath("migrations"))) aconfig.set_main_option(
"script_location", str(Path(__file__).parent.parent.joinpath("migrations"))
)
context = MigrationContext.configure( context = MigrationContext.configure(
connection = connection, # type: ignore connection=connection, # type: ignore
opts = { opts={
"target_metadata": Base.metadata, "target_metadata": Base.metadata,
"fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs("head", rev) "fn": lambda rev, _: ScriptDirectory.from_config(aconfig)._upgrade_revs(
} "head", rev
),
},
) )
try: try:
@ -106,5 +127,32 @@ class Database:
async def run_migrations(self): async def run_migrations(self):
async with self.connection() as connection: async with self.connection() as connection:
await connection.run_sync(self.run_sync_migrations) # type: ignore await connection.run_sync(self.run_sync_migrations) # type: ignore
def rollback_sync_migrations(self, connection: Connection):
aconfig = AlembicConfig()
aconfig.set_main_option("sqlalchemy.url", str(self.url))
aconfig.set_main_option(
"script_location", str(Path(__file__).parent.parent.joinpath("migrations"))
)
context = MigrationContext.configure(
connection=connection, # type: ignore
opts={
"target_metadata": Base.metadata,
"fn": lambda rev, _: ScriptDirectory.from_config(
aconfig
)._downgrade_revs("base", rev),
},
)
try:
with context.begin_transaction():
with Operations.context(context):
context.run_migrations()
except Exception as e:
raise DatabaseMigrationError(f"{e}")
async def rollback_migrations(self):
async with self.connection() as connection:
await connection.run_sync(self.rollback_sync_migrations) # type: ignore

View File

@ -0,0 +1,185 @@
import pytest_asyncio
import pytest
import os
from materia_server.config import Config
from materia_server.models import Database, User, LoginType, Repository, Directory
from materia_server import security
import sqlalchemy as sa
from sqlalchemy.pool import NullPool
from dataclasses import dataclass
@pytest_asyncio.fixture(scope="session")
async def config() -> Config:
conf = Config()
conf.database.port = 54320
# conf.application.working_directory = conf.application.working_directory / "temp"
# if (cwd := conf.application.working_directory.resolve()).exists():
# os.chdir(cwd)
# if local_conf := Config.open(cwd / "config.toml"):
# conf = local_conf
return conf
@pytest_asyncio.fixture(scope="session")
async def db(config: Config, request) -> Database:
config_postgres = config
config_postgres.database.user = "postgres"
config_postgres.database.name = "postgres"
database_postgres = await Database.new(
config_postgres.database.url(), poolclass=NullPool
)
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("create role pytest login"))
await connection.execute(sa.text("create database pytest owner pytest"))
await connection.commit()
await database_postgres.dispose()
config.database.user = "pytest"
config.database.name = "pytest"
database = await Database.new(config.database.url(), poolclass=NullPool)
yield database
await database.dispose()
# database_postgres = await Database.new(config_postgres.database.url())
async with database_postgres.connection() as connection:
await connection.execution_options(isolation_level="AUTOCOMMIT")
await connection.execute(sa.text("drop database pytest")),
await connection.execute(sa.text("drop role pytest"))
await connection.commit()
await database_postgres.dispose()
@pytest_asyncio.fixture(scope="session", autouse=True)
async def setup_db(db: Database, request):
await db.run_migrations()
yield
# await db.rollback_migrations()
@pytest_asyncio.fixture(autouse=True)
async def session(db: Database, request):
session = db.sessionmaker()
yield session
await session.rollback()
await session.close()
@pytest_asyncio.fixture(scope="session")
async def user(config: Config, session) -> User:
test_user = User(
name="pytest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
async with db.session() as session:
session.add(test_user)
await session.flush()
await session.refresh(test_user)
yield test_user
async with db.session() as session:
await session.delete(test_user)
await session.flush()
@pytest_asyncio.fixture
async def data(config: Config):
class TestData:
user = User(
name="pytest",
lower_name="pytest",
email="pytest@example.com",
hashed_password=security.hash_password(
"iampytest", algo=config.security.password_hash_algo
),
login_type=LoginType.Plain,
is_admin=True,
)
return TestData()
@pytest.mark.asyncio
async def test_user(data, session):
session.add(data.user)
await session.flush()
assert data.user.id is not None
assert security.validate_password("iampytest", data.user.hashed_password)
@pytest.mark.asyncio
async def test_repository(data, session, config):
session.add(data.user)
await session.flush()
repository = Repository(user_id=data.user.id, capacity=config.repository.capacity)
session.add(repository)
await session.flush()
assert repository.id is not None
@pytest.mark.asyncio
async def test_directory(data, session, config):
session.add(data.user)
await session.flush()
repository = Repository(user_id=data.user.id, capacity=config.repository.capacity)
session.add(repository)
await session.flush()
directory = Directory(
repository_id=repository.id, parent_id=None, name="test1", path=None
)
session.add(directory)
await session.flush()
assert directory.id is not None
assert (
await session.scalars(
sa.select(Directory).where(
sa.and_(
Directory.repository_id == repository.id,
Directory.name == "test1",
Directory.path.is_(None),
)
)
)
).first() == directory
nested_directory = Directory(
repository_id=repository.id,
parent_id=directory.id,
name="test_nested",
path="test1",
)
session.add(nested_directory)
await session.flush()
assert nested_directory.id is not None
assert (
await session.scalars(
sa.select(Directory).where(
sa.and_(
Directory.repository_id == repository.id,
Directory.name == "test_nested",
Directory.path == "test1",
)
)
)
).first() == nested_directory
assert nested_directory.parent_id == directory.id