move some modules to core module

This commit is contained in:
parent b3be3d25ee
commit 3637ea99a8

pdm.lock (generated, 200 changed lines)
@@ -5,7 +5,7 @@
groups = ["default", "dev"]
strategy = ["cross_platform", "inherit_metadata"]
lock_version = "4.4.1"
content_hash = "sha256:47f5e7de3c9bda99b31aadaaabcc4a7efe77f94ff969135bb278cabcb41d1e20"
content_hash = "sha256:ba7a816a8bfe503b899a8eba3e5ca58e2d751a278c19b1c1db6e647f83fcd62d"

[[package]]
name = "aiofiles"
@@ -71,6 +71,20 @@ files = [
    {file = "alembic_postgresql_enum-1.3.0.tar.gz", hash = "sha256:64d5de7ac2ea39433afd965b057ca882fb420eb5cd6a7db8e2b4d0e7e673cae1"},
]

[[package]]
name = "amqp"
version = "5.2.0"
requires_python = ">=3.6"
summary = "Low-level AMQP client for Python (fork of amqplib)."
groups = ["default"]
dependencies = [
    "vine<6.0.0,>=5.0.0",
]
files = [
    {file = "amqp-5.2.0-py3-none-any.whl", hash = "sha256:827cb12fb0baa892aad844fd95258143bce4027fdac4fccddbc43330fd281637"},
    {file = "amqp-5.2.0.tar.gz", hash = "sha256:a1ecff425ad063ad42a486c902807d1482311481c8ad95a72694b2975e75f7fd"},
]

[[package]]
name = "annotated-types"
version = "0.6.0"
@@ -179,6 +193,17 @@ files = [
    {file = "bcrypt-4.1.2.tar.gz", hash = "sha256:33313a1200a3ae90b75587ceac502b048b840fc69e7f7a0905b5f87fac7a1258"},
]

[[package]]
name = "billiard"
version = "4.2.0"
requires_python = ">=3.7"
summary = "Python multiprocessing fork with improvements and bugfixes"
groups = ["default"]
files = [
    {file = "billiard-4.2.0-py3-none-any.whl", hash = "sha256:07aa978b308f334ff8282bd4a746e681b3513db5c9a514cbdd810cbbdc19714d"},
    {file = "billiard-4.2.0.tar.gz", hash = "sha256:9a3c3184cb275aa17a732f93f65b20c525d3d9f253722d26a82194803ade5a2c"},
]

[[package]]
name = "black"
version = "23.12.1"
@@ -212,6 +237,28 @@ files = [
    {file = "cachetools-5.4.0.tar.gz", hash = "sha256:b8adc2e7c07f105ced7bc56dbb6dfbe7c4a00acce20e2227b3f355be89bc6827"},
]

[[package]]
name = "celery"
version = "5.4.0"
requires_python = ">=3.8"
summary = "Distributed Task Queue."
groups = ["default"]
dependencies = [
    "billiard<5.0,>=4.2.0",
    "click-didyoumean>=0.3.0",
    "click-plugins>=1.1.1",
    "click-repl>=0.2.0",
    "click<9.0,>=8.1.2",
    "kombu<6.0,>=5.3.4",
    "python-dateutil>=2.8.2",
    "tzdata>=2022.7",
    "vine<6.0,>=5.1.0",
]
files = [
    {file = "celery-5.4.0-py3-none-any.whl", hash = "sha256:369631eb580cf8c51a82721ec538684994f8277637edde2dfc0dacd73ed97f64"},
    {file = "celery-5.4.0.tar.gz", hash = "sha256:504a19140e8d3029d5acad88330c541d4c3f64c789d85f94756762d8bca7e706"},
]

[[package]]
name = "certifi"
version = "2024.7.4"
@@ -298,6 +345,48 @@ files = [
    {file = "click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de"},
]

[[package]]
name = "click-didyoumean"
version = "0.3.1"
requires_python = ">=3.6.2"
summary = "Enables git-like *did-you-mean* feature in click"
groups = ["default"]
dependencies = [
    "click>=7",
]
files = [
    {file = "click_didyoumean-0.3.1-py3-none-any.whl", hash = "sha256:5c4bb6007cfea5f2fd6583a2fb6701a22a41eb98957e63d0fac41c10e7c3117c"},
    {file = "click_didyoumean-0.3.1.tar.gz", hash = "sha256:4f82fdff0dbe64ef8ab2279bd6aa3f6a99c3b28c05aa09cbfc07c9d7fbb5a463"},
]

[[package]]
name = "click-plugins"
version = "1.1.1"
summary = "An extension module for click to enable registering CLI commands via setuptools entry-points."
groups = ["default"]
dependencies = [
    "click>=4.0",
]
files = [
    {file = "click-plugins-1.1.1.tar.gz", hash = "sha256:46ab999744a9d831159c3411bb0c79346d94a444df9a3a3742e9ed63645f264b"},
    {file = "click_plugins-1.1.1-py2.py3-none-any.whl", hash = "sha256:5d262006d3222f5057fd81e1623d4443e41dcda5dc815c06b442aa3c02889fc8"},
]

[[package]]
name = "click-repl"
version = "0.3.0"
requires_python = ">=3.6"
summary = "REPL plugin for Click"
groups = ["default"]
dependencies = [
    "click>=7.0",
    "prompt-toolkit>=3.0.36",
]
files = [
    {file = "click-repl-0.3.0.tar.gz", hash = "sha256:17849c23dba3d667247dc4defe1757fff98694e90fe37474f3feebb69ced26a9"},
    {file = "click_repl-0.3.0-py3-none-any.whl", hash = "sha256:fb7e06deb8da8de86180a33a9da97ac316751c094c6899382da7feeeeb51b812"},
]

[[package]]
name = "colorama"
version = "0.4.6"
@@ -668,6 +757,21 @@ files = [
    {file = "jinja2-3.1.4.tar.gz", hash = "sha256:4a3aee7acbbe7303aede8e9648d13b8bf88a429282aa6122a993f0ac800cb369"},
]

[[package]]
name = "kombu"
version = "5.4.0"
requires_python = ">=3.8"
summary = "Messaging library for Python."
groups = ["default"]
dependencies = [
    "amqp<6.0.0,>=5.1.1",
    "vine==5.1.0",
]
files = [
    {file = "kombu-5.4.0-py3-none-any.whl", hash = "sha256:c8dd99820467610b4febbc7a9e8a0d3d7da2d35116b67184418b51cc520ea6b6"},
    {file = "kombu-5.4.0.tar.gz", hash = "sha256:ad200a8dbdaaa2bbc5f26d2ee7d707d9a1fded353a0f4bd751ce8c7d9f449c60"},
]

[[package]]
name = "loguru"
version = "0.7.2"
@@ -913,6 +1017,20 @@ files = [
    {file = "premailer-3.10.0.tar.gz", hash = "sha256:d1875a8411f5dc92b53ef9f193db6c0f879dc378d618e0ad292723e388bfe4c2"},
]

[[package]]
name = "prompt-toolkit"
version = "3.0.47"
requires_python = ">=3.7.0"
summary = "Library for building powerful interactive command lines in Python"
groups = ["default"]
dependencies = [
    "wcwidth",
]
files = [
    {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"},
    {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"},
]

[[package]]
name = "psycopg2-binary"
version = "2.9.9"
@@ -1227,6 +1345,20 @@ files = [
    {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"},
]

[[package]]
name = "smart-open"
version = "7.0.4"
requires_python = "<4.0,>=3.7"
summary = "Utils for streaming large files (S3, HDFS, GCS, Azure Blob Storage, gzip, bz2...)"
groups = ["default"]
dependencies = [
    "wrapt",
]
files = [
    {file = "smart_open-7.0.4-py3-none-any.whl", hash = "sha256:4e98489932b3372595cddc075e6033194775165702887216b65eba760dfd8d47"},
    {file = "smart_open-7.0.4.tar.gz", hash = "sha256:62b65852bdd1d1d516839fcb1f6bc50cd0f16e05b4ec44b52f43d38bcb838524"},
]

[[package]]
name = "sniffio"
version = "1.3.1"
@@ -1310,6 +1442,19 @@ files = [
    {file = "starlette-0.37.2.tar.gz", hash = "sha256:9af890290133b79fc3db55474ade20f6220a364a0402e0b556e7cd5e1e093823"},
]

[[package]]
name = "streaming-form-data"
version = "1.16.0"
requires_python = ">=3.8"
summary = "Streaming parser for multipart/form-data"
groups = ["default"]
dependencies = [
    "smart-open>=6.0",
]
files = [
    {file = "streaming-form-data-1.16.0.tar.gz", hash = "sha256:cd95cde7a1e362c0f2b6e8bf2bcaf7339df1d4727b06de29968d010fcbbb9f5c"},
]

[[package]]
name = "toml"
version = "0.10.2"
@@ -1332,6 +1477,17 @@ files = [
    {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"},
]

[[package]]
name = "tzdata"
version = "2024.1"
requires_python = ">=2"
summary = "Provider of IANA time zone data"
groups = ["default"]
files = [
    {file = "tzdata-2024.1-py2.py3-none-any.whl", hash = "sha256:9068bc196136463f5245e51efda838afa15aaeca9903f49050dfa2679db4d252"},
    {file = "tzdata-2024.1.tar.gz", hash = "sha256:2674120f8d891909751c38abcdfd386ac0a5a1127954fbc332af6b5ceae07efd"},
]

[[package]]
name = "urllib3"
version = "2.2.2"
@@ -1412,6 +1568,17 @@ files = [
    {file = "uvloop-0.19.0.tar.gz", hash = "sha256:0246f4fd1bf2bf702e06b0d45ee91677ee5c31242f39aab4ea6fe0c51aedd0fd"},
]

[[package]]
name = "vine"
version = "5.1.0"
requires_python = ">=3.6"
summary = "Python promises."
groups = ["default"]
files = [
    {file = "vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc"},
    {file = "vine-5.1.0.tar.gz", hash = "sha256:8b62e981d35c41049211cf62a0a1242d8c1ee9bd15bb196ce38aefd6799e61e0"},
]

[[package]]
name = "watchfiles"
version = "0.22.0"
@@ -1450,6 +1617,16 @@ files = [
    {file = "watchfiles-0.22.0.tar.gz", hash = "sha256:988e981aaab4f3955209e7e28c7794acdb690be1efa7f16f8ea5aba7ffdadacb"},
]

[[package]]
name = "wcwidth"
version = "0.2.13"
summary = "Measures the displayed width of unicode strings in a terminal"
groups = ["default"]
files = [
    {file = "wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859"},
    {file = "wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5"},
]

[[package]]
name = "websockets"
version = "12.0"
@@ -1498,3 +1675,24 @@ files = [
    {file = "win32_setctime-1.1.0-py3-none-any.whl", hash = "sha256:231db239e959c2fe7eb1d7dc129f11172354f98361c4fa2d6d2d7e278baa8aad"},
    {file = "win32_setctime-1.1.0.tar.gz", hash = "sha256:15cf5750465118d6929ae4de4eb46e8edae9a5634350c01ba582df868e932cb2"},
]

[[package]]
name = "wrapt"
version = "1.16.0"
requires_python = ">=3.6"
summary = "Module for decorators, wrappers and monkey patching."
groups = ["default"]
files = [
    {file = "wrapt-1.16.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b"},
    {file = "wrapt-1.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36"},
    {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73"},
    {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809"},
    {file = "wrapt-1.16.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b"},
    {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81"},
    {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9"},
    {file = "wrapt-1.16.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c"},
    {file = "wrapt-1.16.0-cp312-cp312-win32.whl", hash = "sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc"},
    {file = "wrapt-1.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8"},
    {file = "wrapt-1.16.0-py3-none-any.whl", hash = "sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1"},
    {file = "wrapt-1.16.0.tar.gz", hash = "sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d"},
]
pyproject.toml

@@ -36,6 +36,8 @@ dependencies = [
    "jinja2>=3.1.4",
    "aiofiles>=24.1.0",
    "aioshutil>=1.5",
    "Celery>=5.4.0",
    "streaming-form-data>=1.16.0",
]
requires-python = ">=3.12,<3.13"
readme = "README.md"
src/materia/__main__.py

@@ -1,3 +1,3 @@
from materia.main import main
from materia.app import cli

main()
cli()
src/materia/_logging.py (deleted)

@@ -1,83 +0,0 @@
import sys
from typing import Sequence
from loguru import logger
from loguru._logger import Logger
import logging
import inspect

from materia.config import Config


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        level: str | int
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        frame, depth = inspect.currentframe(), 2
        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
            frame = frame.f_back
            depth += 1

        logger.opt(depth = depth, exception = record.exc_info).log(level, record.getMessage())


def make_logger(config: Config, interceptions: Sequence[str] = ["uvicorn", "uvicorn.access", "uvicorn.error", "uvicorn.asgi", "fastapi"]) -> Logger:
    logger.remove()

    if config.log.mode in ["console", "all"]:
        logger.add(
            sys.stdout,
            enqueue = True,
            backtrace = True,
            level = config.log.level.upper(),
            format = config.log.console_format,
            filter = lambda record: record["level"].name in ["INFO", "WARNING", "DEBUG", "TRACE"]
        )
        logger.add(
            sys.stderr,
            enqueue = True,
            backtrace = True,
            level = config.log.level.upper(),
            format = config.log.console_format,
            filter = lambda record: record["level"].name in ["ERROR", "CRITICAL"]
        )

    if config.log.mode in ["file", "all"]:
        logger.add(
            str(config.log.file),
            rotation = config.log.file_rotation,
            retention = config.log.file_retention,
            enqueue = True,
            backtrace = True,
            level = config.log.level.upper(),
            format = config.log.file_format
        )

    logging.basicConfig(handlers = [InterceptHandler()], level = logging.NOTSET, force = True)

    for external_logger in interceptions:
        logging.getLogger(external_logger).handlers = [InterceptHandler()]

    return logger  # type: ignore


def uvicorn_log_config(config: Config) -> dict:
    return {
        "version": 1,
        "disable_existing_loggers": False,
        "handlers": {
            "default": {
                "class": "materia._logging.InterceptHandler"
            },
            "access": {
                "class": "materia._logging.InterceptHandler"
            },
        },
        "loggers": {
            "uvicorn": {"handlers": ["default"], "level": config.log.level.upper(), "propagate": False},
            "uvicorn.error": {"level": config.log.level.upper()},
            "uvicorn.access": {"handlers": ["access"], "level": config.log.level.upper(), "propagate": False},
        },
    }
src/materia/app/__init__.py

@@ -1 +1,2 @@
from materia.app.app import AppContext, make_lifespan, make_application
from materia.app.app import Context, Application
from materia.app.cli import cli
src/materia/app/app.py

@@ -1,93 +1,159 @@
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from os import environ
import os
from pathlib import Path
import pwd
import sys
from typing import AsyncIterator, TypedDict
import click
from typing import AsyncIterator, TypedDict, Self, Optional

from pydantic import BaseModel
from pydanclick import from_pydantic
import pydantic
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from materia import config as _config
from materia.config import Config
from materia._logging import make_logger, uvicorn_log_config, Logger
from materia.models import (
from materia.core import (
    Config,
    Logger,
    LoggerInstance,
    Database,
    DatabaseError,
    DatabaseMigrationError,
    Cache,
    CacheError,
    Cron,
)
from materia import routers


class AppContext(TypedDict):
class Context(TypedDict):
    config: Config
    logger: Logger
    logger: LoggerInstance
    database: Database
    cache: Cache


def make_lifespan(config: Config, logger: Logger):
    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[AppContext]:
class ApplicationError(Exception):
    pass


class Application:
    __instance__: Optional[Self] = None

    def __init__(
        self,
        config: Config,
        logger: LoggerInstance,
        database: Database,
        cache: Cache,
        cron: Cron,
        backend: FastAPI,
    ):
        if Application.__instance__:
            raise ApplicationError("Cannot create multiple applications")

        self.config = config
        self.logger = logger
        self.database = database
        self.cache = cache
        self.cron = cron
        self.backend = backend

    @staticmethod
    async def new(config: Config):
        if Application.__instance__:
            raise ApplicationError("Cannot create multiple applications")

        logger = Logger.new(**config.log.model_dump())

        # if user := config.application.user:
        #     os.setuid(pwd.getpwnam(user).pw_uid)
        # if group := config.application.group:
        #     os.setgid(pwd.getpwnam(user).pw_gid)
        logger.debug("Initializing application...")

        try:
            logger.debug("Changing working directory")
            os.chdir(config.application.working_directory.resolve())
        except FileNotFoundError as e:
            logger.error("Failed to change working directory: {}", e)
            sys.exit()

        try:
            logger.info("Connecting to database {}", config.database.url())
            database = await Database.new(config.database.url())  # type: ignore

            logger.info("Running migrations")
            await database.run_migrations()

            logger.info("Connecting to cache {}", config.cache.url())
            logger.info("Connecting to cache server {}", config.cache.url())
            cache = await Cache.new(config.cache.url())  # type: ignore
        except DatabaseError as e:
            logger.error(f"Failed to connect postgres: {e}")
            sys.exit()
        except DatabaseMigrationError as e:
            logger.error(f"Failed to run migrations: {e}")
            sys.exit()
        except CacheError as e:
            logger.error(f"Failed to connect redis: {e}")

            logger.info("Preparing cron")
            cron = Cron.new(
                config.cron.workers_count,
                backend_url=config.cache.url(),
                broker_url=config.cache.url(),
            )

            logger.info("Running database migrations")
            await database.run_migrations()
        except Exception as e:
            logger.error(" ".join(e.args))
            sys.exit()

        yield AppContext(config=config, database=database, cache=cache, logger=logger)
        try:
            import materia_frontend
        except ModuleNotFoundError:
            logger.warning(
                "`materia_frontend` is not installed. No user interface will be served."
            )

        if database.engine is not None:
            await database.dispose()
        @asynccontextmanager
        async def lifespan(app: FastAPI) -> AsyncIterator[Context]:
            yield Context(config=config, logger=logger, database=database, cache=cache)

    return lifespan
            if database.engine is not None:
                await database.dispose()

        backend = FastAPI(
            title="materia",
            version="0.1.0",
            docs_url="/api/docs",
            lifespan=lifespan,
        )
        backend.add_middleware(
            CORSMiddleware,
            allow_origins=["http://localhost", "http://localhost:5173"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        backend.include_router(routers.api.router)
        backend.include_router(routers.resources.router)
        backend.include_router(routers.root.router)

def make_application(config: Config, logger: Logger):
    try:
        import materia_frontend
    except ModuleNotFoundError:
        logger.warning(
            "`materia_frontend` is not installed. No user interface will be served."
        return Application(
            config=config,
            logger=logger,
            database=database,
            cache=cache,
            cron=cron,
            backend=backend,
        )

    app = FastAPI(
        title="materia",
        version="0.1.0",
        docs_url="/api/docs",
        lifespan=make_lifespan(config, logger),
    )
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["http://localhost", "http://localhost:5173"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(routers.api.router)
    app.include_router(routers.resources.router)
    app.include_router(routers.root.router)
    @staticmethod
    def instance() -> Optional[Self]:
        return Application.__instance__

    return app
    async def start(self):
        self.logger.info(f"Spinning up cron workers [{self.config.cron.workers_count}]")
        self.cron.run_workers()

        try:
            # uvicorn.run(
            #     self.backend,
            #     port=self.config.server.port,
            #     host=str(self.config.server.address),
            #     # reload = config.application.mode == "development",
            #     log_config=Logger.uvicorn_config(self.config.log.level),
            # )
            uvicorn_config = uvicorn.Config(
                self.backend,
                port=self.config.server.port,
                host=str(self.config.server.address),
                log_config=Logger.uvicorn_config(self.config.log.level),
            )
            server = uvicorn.Server(uvicorn_config)

            await server.serve()
        except (KeyboardInterrupt, SystemExit):
            self.logger.info("Exiting...")
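For orientation, a minimal sketch of driving the new Application lifecycle end to end; it mirrors what the reworked cli.py below does, and the default Config is an assumption (real runs load a file or environment variables):

    import asyncio

    from materia.core import Config
    from materia.app import Application

    async def main():
        config = Config()  # assumption: defaults stand in for a loaded config file
        app = await Application.new(config)  # wires logger, database, cache, cron
        await app.start()  # starts cron workers, then serves the FastAPI backend

    asyncio.run(main())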
src/materia/main.py → src/materia/app/cli.py

@@ -1,55 +1,22 @@
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from os import environ
import os
from pathlib import Path
import pwd
import sys
from typing import AsyncIterator, TypedDict
import click

from pydantic import BaseModel
from pydanclick import from_pydantic
import pydantic
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from materia import config as _config
from materia.config import Config
from materia._logging import make_logger, uvicorn_log_config, Logger
from materia.models import Database, DatabaseError, Cache
from materia import routers
from materia.app import make_application
from materia.core.config import Config
from materia.core.logging import Logger
from materia.app import Application
import asyncio


@click.group()
def main():
def cli():
    pass


@main.command()
@click.option("--config_path", type=Path)
@from_pydantic("application", _config.Application, prefix="app")
@from_pydantic("log", _config.Log, prefix="log")
def start(application: _config.Application, config_path: Path, log: _config.Log):
    config = Config()
    config.log = log
    logger = make_logger(config)

    # if user := application.user:
    #     os.setuid(pwd.getpwnam(user).pw_uid)
    # if group := application.group:
    #     os.setgid(pwd.getpwnam(user).pw_gid)
    # TODO: merge cli options with config
    if working_directory := (
        application.working_directory or config.application.working_directory
    ).resolve():
        try:
            os.chdir(working_directory)
        except FileNotFoundError as e:
            logger.error("Failed to change working directory: {}", e)
            sys.exit()
        logger.debug(f"Current working directory: {working_directory}")
@cli.command()
@click.option("--config", type=Path)
def start(config: Path):
    config_path = config
    logger = Logger.new()

    # check the configuration file or use default
    if config_path is not None:
@@ -80,31 +47,14 @@ def start(application: _config.Application, config_path: Path, log: _config.Log):
        logger.info("Using the default configuration.")
        config = Config()

    config.log.level = log.level
    logger = make_logger(config)
    if working_directory := config.application.working_directory.resolve():
        logger.debug(f"Change working directory: {working_directory}")
        try:
            os.chdir(working_directory)
        except FileNotFoundError as e:
            logger.error("Failed to change working directory: {}", e)
            sys.exit()
    async def main():
        app = await Application.new(config)
        await app.start()

    config.application.mode = application.mode

    try:
        uvicorn.run(
            make_application(config, logger),
            port=config.server.port,
            host=str(config.server.address),
            # reload = config.application.mode == "development",
            log_config=uvicorn_log_config(config),
        )
    except (KeyboardInterrupt, SystemExit):
        pass
    asyncio.run(main())


@main.group()
@cli.group()
def config():
    pass

@@ -123,7 +73,7 @@ def config():
def config_create(path: Path, force: bool):
    path = path.resolve()
    config = Config()
    logger = make_logger(config)
    logger = Logger.new()

    if path.exists() and not force:
        logger.warning("File already exists at the given path. Exit.")
@@ -148,8 +98,7 @@ def config_create(path: Path, force: bool):
)
def config_check(path: Path):
    path = path.resolve()
    config = Config()
    logger = make_logger(config)
    logger = Logger.new()

    if not path.exists():
        logger.error("Configuration file was not found at the given path. Exit.")
@@ -164,4 +113,4 @@ def config_check(path: Path):


if __name__ == "__main__":
    main()
    cli()
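A hedged sketch of exercising the renamed Click entrypoint programmatically; the config path is illustrative, and Click's standard argument-list invocation is used:

    from materia.app import cli

    # Equivalent to `python -m materia start --config materia.toml`;
    # standalone_mode=False makes Click raise instead of calling sys.exit().
    cli(["start", "--config", "materia.toml"], standalone_mode=False)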
src/materia/core/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from materia.core.logging import Logger, LoggerInstance, LogLevel, LogMode
from materia.core.database import (
    DatabaseError,
    DatabaseMigrationError,
    Database,
    SessionMaker,
    SessionContext,
    ConnectionContext,
)
from materia.core.filesystem import FileSystem, FileSystemError, TemporaryFileTarget
from materia.core.config import Config
from materia.core.cache import Cache, CacheError
from materia.core.cron import Cron, CronError
src/materia/core/cache.py

@@ -1,53 +1,56 @@
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Self
from pydantic import BaseModel, RedisDsn
from pydantic import RedisDsn
from redis import asyncio as aioredis
from redis.asyncio.client import Pipeline
from materia.core.logging import Logger


class CacheError(Exception):
    pass
    pass


class Cache:
    def __init__(self, url: RedisDsn, pool: aioredis.ConnectionPool):
        self.url: RedisDsn = url
        self.url: RedisDsn = url
        self.pool: aioredis.ConnectionPool = pool

    @staticmethod
    async def new(
        url: RedisDsn,
        encoding: str = "utf-8",
        decode_responses: bool = True,
        test_connection: bool = True
    ) -> Self:
        pool = aioredis.ConnectionPool.from_url(str(url), encoding = encoding, decode_responses = decode_responses)
        url: RedisDsn,
        encoding: str = "utf-8",
        decode_responses: bool = True,
        test_connection: bool = True,
    ) -> Self:
        pool = aioredis.ConnectionPool.from_url(
            str(url), encoding=encoding, decode_responses=decode_responses
        )

        if test_connection:
            try:
                if logger := Logger.instance():
                    logger.debug("Testing cache connection")
                connection = pool.make_connection()
                await connection.connect()
            except ConnectionError as e:
                raise CacheError(f"{e}")
                raise CacheError(f"{e}")
            else:
                await connection.disconnect()

        return Cache(
            url = url,
            pool = pool
        )
        return Cache(url=url, pool=pool)

    @asynccontextmanager
    async def client(self) -> AsyncGenerator[aioredis.Redis, Any]:
    @asynccontextmanager
    async def client(self) -> AsyncGenerator[aioredis.Redis, Any]:
        try:
            yield aioredis.Redis(connection_pool = self.pool)
            yield aioredis.Redis(connection_pool=self.pool)
        except Exception as e:
            raise CacheError(f"{e}")

    @asynccontextmanager
    @asynccontextmanager
    async def pipeline(self, transaction: bool = True) -> AsyncGenerator[Pipeline, Any]:
        client = await aioredis.Redis(connection_pool = self.pool)
        client = await aioredis.Redis(connection_pool=self.pool)

        try:
            yield client.pipeline(transaction = transaction)
            yield client.pipeline(transaction=transaction)
        except Exception as e:
            raise CacheError(f"{e}")
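A hedged usage sketch of the reworked Cache; the Redis URL is an assumption:

    import asyncio
    from materia.core import Cache

    async def demo():
        cache = await Cache.new("redis://localhost:6379/0")  # assumption
        async with cache.client() as client:
            await client.set("greeting", "hello")
            print(await client.get("greeting"))

    asyncio.run(demo())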
src/materia/core/config.py

@@ -1,15 +1,9 @@
from os import environ
from pathlib import Path
import sys
from typing import Any, Literal, Optional, Self, Union

from typing import Literal, Optional, Self, Union
from pydantic import (
    BaseModel,
    Field,
    HttpUrl,
    model_validator,
    TypeAdapter,
    PostgresDsn,
    NameEmail,
)
from pydantic_settings import BaseSettings
@@ -149,11 +143,11 @@ class Mailer(BaseModel):


class Cron(BaseModel):
    pass
    workers_count: int = 1


class Repository(BaseModel):
    capacity: int = 41943040
    capacity: int = 5 << 30


class Config(BaseSettings, env_prefix="materia_", env_nested_delimiter="_"):
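The new Repository default is written as a bit shift rather than a literal; for the record, the arithmetic:

    # 5 << 30 == 5 * 2**30 bytes, i.e. 5 GiB; the old literal was 40 MiB.
    assert 5 << 30 == 5368709120
    assert 41943040 == 40 * 2**20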
src/materia/core/cron.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from typing import Optional, Self
from celery import Celery
from pydantic import RedisDsn
from threading import Thread
from materia.core.logging import Logger


class CronError(Exception):
    pass


class Cron:
    __instance__: Optional[Self] = None

    def __init__(
        self,
        workers_count: int,
        backend: Celery,
    ):
        self.workers_count = workers_count
        self.backend = backend
        self.workers = []
        self.worker_threads = []

        Cron.__instance__ = self

    @staticmethod
    def new(
        workers_count: int = 1,
        backend_url: Optional[RedisDsn] = None,
        broker_url: Optional[RedisDsn] = None,
        test_connection: bool = True,
        **kwargs,
    ):
        cron = Cron(
            workers_count,
            Celery(
                "cron",
                backend=backend_url,
                broker=broker_url,
                task_serializer="pickle",
                accept_content=["pickle", "json"],
                **kwargs,
            ),
        )

        for _ in range(workers_count):
            cron.workers.append(cron.backend.Worker())

        if test_connection:
            try:
                if logger := Logger.instance():
                    logger.debug("Testing cron broker connection")
                cron.backend.broker_connection().ensure_connection(max_retries=3)
            except Exception as e:
                raise CronError(f"Failed to connect cron broker: {broker_url}") from e

        return cron

    @staticmethod
    def instance() -> Optional[Self]:
        return Cron.__instance__

    def run_workers(self):
        for worker in self.workers:
            thread = Thread(target=worker.start, daemon=True)
            self.worker_threads.append(thread)
            thread.start()
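A minimal usage sketch of the new Cron wrapper; the Redis URLs are assumptions, and tasks would be registered on cron.backend, the underlying Celery app:

    from materia.core import Cron

    cron = Cron.new(
        workers_count=2,
        backend_url="redis://localhost:6379/0",  # assumption
        broker_url="redis://localhost:6379/0",   # assumption
        test_connection=False,  # skip the broker ping in this sketch
    )
    cron.run_workers()  # each Celery worker starts in its own daemon thread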
src/materia/core/database.py

@@ -1,9 +1,8 @@
from contextlib import asynccontextmanager
import os
from typing import AsyncIterator, Self, TypeAlias
from pathlib import Path

from pydantic import BaseModel, PostgresDsn, ValidationError
from pydantic import PostgresDsn, ValidationError
from sqlalchemy.ext.asyncio import (
    AsyncConnection,
    AsyncEngine,
@@ -19,11 +18,7 @@ from alembic.runtime.migration import MigrationContext
from alembic.script.base import ScriptDirectory
import alembic_postgresql_enum
from fastapi import HTTPException

from materia.config import Config
from materia.models.base import Base

__all__ = ["Database"]
from materia.core.logging import Logger


class DatabaseError(Exception):
@@ -77,6 +72,8 @@ class Database:

        if test_connection:
            try:
                if logger := Logger.instance():
                    logger.debug("Testing database connection")
                async with database.connection() as connection:
                    await connection.rollback()
            except Exception as e:
@@ -112,10 +109,13 @@
            await session.close()

    def run_sync_migrations(self, connection: Connection):
        from materia.models.base import Base

        aconfig = AlembicConfig()
        aconfig.set_main_option("sqlalchemy.url", str(self.url))
        aconfig.set_main_option(
            "script_location", str(Path(__file__).parent.parent.joinpath("migrations"))
            "script_location",
            str(Path(__file__).parent.parent.joinpath("models", "migrations")),
        )

        context = MigrationContext.configure(
@@ -140,10 +140,13 @@
        await connection.run_sync(self.run_sync_migrations)  # type: ignore

    def rollback_sync_migrations(self, connection: Connection):
        from materia.models.base import Base

        aconfig = AlembicConfig()
        aconfig.set_main_option("sqlalchemy.url", str(self.url))
        aconfig.set_main_option(
            "script_location", str(Path(__file__).parent.parent.joinpath("migrations"))
            "script_location",
            str(Path(__file__).parent.parent.joinpath("models", "migrations")),
        )

        context = MigrationContext.configure(
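The script_location change matters because database.py now lives under core/; a quick illustration of where the Alembic scripts resolve (paths are illustrative):

    from pathlib import Path

    # Inside src/materia/core/database.py, __file__.parent.parent is src/materia,
    # so migration scripts are now looked up under the models package:
    core_file = Path("src/materia/core/database.py")
    print(core_file.parent.parent.joinpath("models", "migrations"))
    # -> src/materia/models/migrations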
src/materia/core/filesystem.py

@@ -5,6 +5,11 @@ from aiofiles import os as async_os
from aiofiles import ospath as async_path
import aioshutil
import re
from tempfile import NamedTemporaryFile
from streaming_form_data.targets import BaseTarget
from uuid import uuid4
from materia.core.misc import optional


valid_path = re.compile(r"^/(.*/)*([^/]*)$")

@@ -13,26 +18,19 @@ class FileSystemError(Exception):
    pass


T = TypeVar("T")


def wrapped_next(i: Iterator[T]) -> Optional[T]:
    try:
        return next(i)
    except StopIteration:
        return None


class FileSystem:
    def __init__(self, path: Path, working_directory: Path):
        if path == Path():
    def __init__(self, path: Path, isolated_directory: Optional[Path] = None):
        if path == Path() or path is None:
            raise FileSystemError("The given path is empty")
        if working_directory == Path():
            raise FileSystemError("The given working directory is empty")

        self.path = path
        self.working_directory = working_directory
        self.relative_path = path.relative_to(working_directory)

        if isolated_directory and not isolated_directory.is_absolute():
            raise FileSystemError("The isolated directory must be absolute")

        self.isolated_directory = isolated_directory
        # self.working_directory = working_directory
        # self.relative_path = path.relative_to(working_directory)

    async def exists(self) -> bool:
        return await async_path.exists(self.path)
@@ -49,19 +47,28 @@
    def name(self) -> str:
        return self.path.name

    async def remove(self):
    async def check_isolation(self, path: Path):
        if not self.isolated_directory:
            return
        if not (await async_path.exists(self.isolated_directory)):
            raise FileSystemError("Missed isolated directory")
        if not optional(path.relative_to, self.isolated_directory):
            raise FileSystemError(
                "Attempting to work with a path that is outside the isolated directory"
            )
        if self.path == self.isolated_directory:
            raise FileSystemError("Attempting to modify the isolated directory")

    async def remove(self, shallow: bool = False):
        await self.check_isolation(self.path)
        try:
            if await self.is_file():
            if await self.exists() and await self.is_file() and not shallow:
                await aiofiles.os.remove(self.path)

            if await self.is_directory():
            if await self.exists() and await self.is_directory() and not shallow:
                await aioshutil.rmtree(str(self.path))

        except OSError as e:
            raise FileSystemError(
                f"Failed to remove content at /{self.relative_path}:",
                *e.args,
            )
            raise FileSystemError(*e.args) from e

    async def generate_name(self, target_directory: Path, name: str) -> str:
        """Generate name based on target directory contents and self type."""
@@ -98,18 +105,13 @@
        force: bool = False,
        shallow: bool = False,
    ) -> Path:
        if self.path == self.working_directory:
            raise FileSystemError("Cannot modify working directory")

        new_name = new_name or self.path.name

        if await async_path.exists(target_directory.joinpath(new_name)) and not shallow:
            if force:
        if await async_path.exists(target_directory.joinpath(new_name)):
            if force or shallow:
                new_name = await self.generate_name(target_directory, new_name)
            else:
                raise FileSystemError(
                    f"Target destination already exists /{target_directory.joinpath(new_name)}"
                )
                raise FileSystemError("Target destination already exists")

        return target_directory.joinpath(new_name)

@@ -119,26 +121,24 @@
        new_name: Optional[str] = None,
        force: bool = False,
        shallow: bool = False,
    ):
    ) -> Self:
        await self.check_isolation(self.path)
        new_path = await self._generate_new_path(
            target_directory, new_name, force=force, shallow=shallow
        )
        target = FileSystem(new_path, self.isolated_directory)

        try:
            if not shallow:
            if await self.exists() and not shallow:
                await aioshutil.move(self.path, new_path)

        except Exception as e:
            raise FileSystemError(
                f"Failed to move content from /{self.relative_path}:",
                *e.args,
            )
            raise FileSystemError(*e.args) from e

        return FileSystem(new_path, self.working_directory)
        return target

    async def rename(
        self, new_name: str, force: bool = False, shallow: bool = False
    ) -> Path:
    ) -> Self:
        return await self.move(
            self.path.parent, new_name=new_name, force=force, shallow=shallow
        )
@@ -150,50 +150,41 @@
        force: bool = False,
        shallow: bool = False,
    ) -> Self:
        await self.check_isolation(self.path)
        new_path = await self._generate_new_path(
            target_directory, new_name, force=force, shallow=shallow
        )
        target = FileSystem(new_path, self.isolated_directory)

        try:
            if not shallow:
                if await self.is_file():
                    await aioshutil.copy(self.path, new_path)

                if await self.is_directory():
                    await aioshutil.copytree(self.path, new_path)
            if await self.is_file() and not shallow:
                await aioshutil.copy(self.path, new_path)

            if await self.is_directory() and not shallow:
                await aioshutil.copytree(self.path, new_path)
        except Exception as e:
            raise FileSystemError(
                f"Failed to copy content from /{new_path}:",
                *e.args,
            )
            raise FileSystemError(*e.args) from e

        return FileSystem(new_path, self.working_directory)
        return target

    async def make_directory(self):
    async def make_directory(self, force: bool = False):
        try:
            if await self.exists():
                raise FileSystemError("Failed to create directory: already exists")
            if await self.exists() and not force:
                raise FileSystemError("Already exists")

            await async_os.mkdir(self.path)
            await async_os.makedirs(self.path, exist_ok=force)
        except Exception as e:
            raise FileSystemError(
                f"Failed to create directory at /{self.relative_path}:",
                *e.args,
            )
            raise FileSystemError(*e.args)

    async def write_file(self, data: bytes):
    async def write_file(self, data: bytes, force: bool = False):
        try:
            if await self.exists():
                raise FileSystemError("Failed to write file: already exists")
            if await self.exists() and not force:
                raise FileSystemError("Already exists")

            async with aiofiles.open(self.path, mode="wb") as file:
                await file.write(data)
        except Exception as e:
            raise FileSystemError(
                f"Failed to write file to /{self.relative_path}:",
                *e.args,
            )
            raise FileSystemError(*e.args)

    @staticmethod
    def check_path(path: Path) -> bool:
@@ -206,3 +197,39 @@
        path = Path("/").joinpath(path)

        return Path(*path.resolve().parts[1:])


class TemporaryFileTarget(BaseTarget):
    def __init__(
        self, working_directory: Path, allow_overwrite: bool = True, *args, **kwargs
    ):
        if working_directory == Path():
            raise FileSystemError("The given working directory is empty")

        super().__init__(*args, **kwargs)

        self._mode = "wb" if allow_overwrite else "xb"
        self._fd = None
        self._path = working_directory.joinpath("cache", str(uuid4()))

    def on_start(self):
        if not self._path.parent.exists():
            self._path.parent.mkdir(exist_ok=True)

        self._fd = open(str(self._path), mode="wb")

    def on_data_received(self, chunk: bytes):
        if self._fd:
            self._fd.write(chunk)

    def on_finish(self):
        if self._fd:
            self._fd.close()

    def path(self) -> Optional[Path]:
        return self._path

    def remove(self):
        if self._fd:
            if (path := Path(self._fd.name)).exists():
                path.unlink()
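A hedged sketch of the new isolation contract; the directories are illustrative. check_isolation refuses paths outside isolated_directory and the isolated directory itself:

    import asyncio
    from pathlib import Path
    from materia.core import FileSystem, FileSystemError

    async def demo():
        root = Path("/srv/materia/repository")  # assumption
        inside = FileSystem(root / "user" / "file.txt", isolated_directory=root)
        outside = FileSystem(Path("/etc/passwd"), isolated_directory=root)
        try:
            await outside.remove()  # path escapes the isolated directory
        except FileSystemError as e:
            print("rejected:", e)

    asyncio.run(demo())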
src/materia/core/logging.py (new file, 128 lines)
@@ -0,0 +1,128 @@
import sys
from typing import Sequence, Literal, Optional, TypeAlias
from pathlib import Path
from loguru import logger
from loguru._logger import Logger as LoggerInstance
import logging
import inspect


class InterceptHandler(logging.Handler):
    def emit(self, record: logging.LogRecord) -> None:
        level: str | int
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        frame, depth = inspect.currentframe(), 2
        while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__):
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            level, record.getMessage()
        )


LogLevel: TypeAlias = Literal["info", "warning", "error", "critical", "debug", "trace"]
LogMode: TypeAlias = Literal["console", "file", "all"]


class Logger:
    __instance__: Optional[LoggerInstance] = None

    def __init__(self):
        raise NotImplementedError()

    @staticmethod
    def new(
        mode: LogMode = "console",
        level: LogLevel = "info",
        console_format: str = (
            "<level>{level: <8}</level> <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}"
        ),
        file_format: str = (
            "<level>{level: <8}</level>: <green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> - {message}"
        ),
        file: Optional[Path] = None,
        file_rotation: str = "3 days",
        file_retention: str = "1 week",
        interceptions: Sequence[str] = [
            "uvicorn",
            "uvicorn.access",
            "uvicorn.error",
            "uvicorn.asgi",
            "fastapi",
        ],
    ) -> LoggerInstance:
        logger.remove()

        if mode in ["console", "all"]:
            logger.add(
                sys.stdout,
                enqueue=True,
                backtrace=True,
                level=level.upper(),
                format=console_format,
                filter=lambda record: record["level"].name
                in ["INFO", "WARNING", "DEBUG", "TRACE"],
            )
            logger.add(
                sys.stderr,
                enqueue=True,
                backtrace=True,
                level=level.upper(),
                format=console_format,
                filter=lambda record: record["level"].name in ["ERROR", "CRITICAL"],
            )

        if mode in ["file", "all"]:
            logger.add(
                str(file),
                rotation=file_rotation,
                retention=file_retention,
                enqueue=True,
                backtrace=True,
                level=level.upper(),
                format=file_format,
            )

        logging.basicConfig(
            handlers=[InterceptHandler()], level=logging.NOTSET, force=True
        )

        for external_logger in interceptions:
            logging.getLogger(external_logger).handlers = [InterceptHandler()]

        Logger.__instance__ = logger

        return logger  # type: ignore

    @staticmethod
    def instance() -> Optional[LoggerInstance]:
        return Logger.__instance__

    @staticmethod
    def uvicorn_config(level: LogLevel) -> dict:
        return {
            "version": 1,
            "disable_existing_loggers": False,
            "handlers": {
                "default": {"class": "materia.core.logging.InterceptHandler"},
                "access": {"class": "materia.core.logging.InterceptHandler"},
            },
            "loggers": {
                "uvicorn": {
                    "handlers": ["default"],
                    "level": level.upper(),
                    "propagate": False,
                },
                "uvicorn.error": {"level": level.upper()},
                "uvicorn.access": {
                    "handlers": ["access"],
                    "level": level.upper(),
                    "propagate": False,
                },
            },
        }
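A short sketch of the new singleton-style logger; the values are assumptions, and other core modules fetch the instance via Logger.instance(), as cache.py and database.py above do:

    from materia.core import Logger

    logger = Logger.new(mode="console", level="debug")
    logger.info("logger ready")

    # Elsewhere, without threading the instance through call sites:
    if log := Logger.instance():
        log.debug("resolved via Logger.instance()")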
src/materia/core/misc.py (new file, 28 lines)
@@ -0,0 +1,28 @@
from typing import Optional, Self, Iterator, TypeVar, Callable, Any, ParamSpec
from functools import partial

T = TypeVar("T")
P = ParamSpec("P")


def optional(func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Optional[T]:
    try:
        res = func(*args, **kwargs)
    except TypeError as e:
        raise e
    except Exception:
        return None
    return res


def optional_next(it: Iterator[T]) -> Optional[T]:
    return optional(next, it)


def optional_string(value: Any, format_string: Optional[str] = None) -> str:
    if value is None:
        return ""
    res = optional(str, value)
    if res is None:
        return ""
    return format_string.format(res)
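For context, how the optional helpers behave: they swallow every exception except TypeError and return None instead (the examples are illustrative):

    from pathlib import Path
    from materia.core.misc import optional, optional_next

    # ValueError from relative_to is swallowed -> None
    print(optional(Path("a/b").relative_to, Path("/x")))  # None
    # StopIteration from an exhausted iterator -> None
    print(optional_next(iter([])))  # None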
src/materia/models/__init__.py

@@ -1,31 +1,17 @@
from materia.models.auth import (
    LoginType,
    LoginSource,
    OAuth2Application,
    OAuth2Grant,
    OAuth2AuthorizationCode,
    # OAuth2Application,
    # OAuth2Grant,
    # OAuth2AuthorizationCode,
)

from materia.models.database import (
    Database,
    DatabaseError,
    DatabaseMigrationError,
    Cache,
    CacheError,
    SessionContext,
)

from materia.models.user import User, UserCredentials, UserInfo

from materia.models.filesystem import FileSystem

from materia.models.repository import (
    Repository,
    RepositoryInfo,
    RepositoryContent,
    RepositoryError,
)

from materia.models.directory import (
    Directory,
    DirectoryLink,
@@ -34,7 +20,6 @@ from materia.models.directory import (
    DirectoryRename,
    DirectoryCopyMove,
)

from materia.models.file import (
    File,
    FileLink,
src/materia/models/auth/__init__.py

@@ -1,3 +1,3 @@
from materia.models.auth.source import LoginType, LoginSource
from materia.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode

# from materia.models.auth.oauth2 import OAuth2Application, OAuth2Grant, OAuth2AuthorizationCode
@ -1,33 +1,33 @@
|
||||
from time import time
|
||||
from typing import List, Optional, Self, Union
|
||||
from uuid import UUID, uuid4
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import bcrypt
|
||||
import httpx
|
||||
from sqlalchemy import BigInteger, ExceptionContext, ForeignKey, JSON, and_, delete, select, update
|
||||
from sqlalchemy import BigInteger, ForeignKey, JSON, and_, select
|
||||
from sqlalchemy.orm import mapped_column, Mapped, relationship
|
||||
from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from materia.models.base import Base
|
||||
from materia.models.database import Database, Cache
|
||||
from materia.core import Database, Cache
|
||||
from materia import security
|
||||
from materia.models import user
|
||||
|
||||
|
||||
class OAuth2Application(Base):
|
||||
__tablename__ = "oauth2_application"
|
||||
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id", ondelete = "CASCADE"))
|
||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"))
|
||||
name: Mapped[str]
|
||||
client_id: Mapped[UUID] = mapped_column(default = uuid4)
|
||||
client_id: Mapped[UUID] = mapped_column(default=uuid4)
|
||||
hashed_client_secret: Mapped[str]
|
||||
redirect_uris: Mapped[List[str]] = mapped_column(JSON)
|
||||
confidential_client: Mapped[bool] = mapped_column(default = True)
|
||||
created: Mapped[int] = mapped_column(BigInteger, default = time)
|
||||
updated: Mapped[int] = mapped_column(BigInteger, default = time)
|
||||
confidential_client: Mapped[bool] = mapped_column(default=True)
|
||||
created: Mapped[int] = mapped_column(BigInteger, default=time)
|
||||
updated: Mapped[int] = mapped_column(BigInteger, default=time)
|
||||
|
||||
#user: Mapped["user.User"] = relationship(back_populates = "oauth2_applications")
|
||||
grants: Mapped[List["OAuth2Grant"]] = relationship(back_populates = "application")
|
||||
# user: Mapped["user.User"] = relationship(back_populates = "oauth2_applications")
|
||||
grants: Mapped[List["OAuth2Grant"]] = relationship(back_populates="application")
|
||||
|
||||
def contains_redirect_uri(self, uri: HttpUrl) -> bool:
|
||||
if not self.confidential_client:
|
||||
@ -41,14 +41,14 @@ class OAuth2Application(Base):
|
||||
return False
|
||||
|
||||
async def generate_client_secret(self, db: Database) -> str:
|
||||
client_secret = security.generate_key()
|
||||
client_secret = security.generate_key()
|
||||
hashed_secret = bcrypt.hashpw(client_secret, bcrypt.gensalt())
|
||||
|
||||
self.hashed_client_secret = str(hashed_secret)
|
||||
|
||||
|
||||
async with db.session() as session:
|
||||
session.add(self)
|
||||
await session.commit()
|
||||
await session.commit()
|
||||
|
||||
return str(client_secret)
|
||||
|
||||
@ -64,30 +64,53 @@ class OAuth2Application(Base):
|
||||
@staticmethod
|
||||
async def delete(db: Database, id: int, user_id: int):
|
||||
async with db.session() as session:
|
||||
if not (application := (await session.scalars(
|
||||
select(OAuth2Application)
|
||||
.where(and_(OAuth2Application.id == id, OAuth2Application.user_id == user_id))
|
||||
)).first()):
|
||||
if not (
|
||||
application := (
|
||||
await session.scalars(
|
||||
select(OAuth2Application).where(
|
||||
and_(
|
||||
OAuth2Application.id == id,
|
||||
OAuth2Application.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
):
|
||||
raise Exception("OAuth2Application not found")
|
||||
|
||||
#await session.refresh(application, attribute_names = [ "grants" ])
|
||||
# await session.refresh(application, attribute_names = [ "grants" ])
|
||||
await session.delete(application)
|
||||
|
||||
@staticmethod
|
||||
async def by_client_id(client_id: str, db: Database) -> Union[Self, None]:
|
||||
async with db.session() as session:
|
||||
return await session.scalar(select(OAuth2Application).where(OAuth2Application.client_id == client_id))
|
||||
return await session.scalar(
|
||||
select(OAuth2Application).where(
|
||||
OAuth2Application.client_id == client_id
|
||||
)
|
||||
)
|
||||
|
||||
async def grant_by_user_id(self, user_id: UUID, db: Database) -> Union["OAuth2Grant", None]:
|
||||
async def grant_by_user_id(
|
||||
self, user_id: UUID, db: Database
|
||||
) -> Union["OAuth2Grant", None]:
|
||||
async with db.session() as session:
|
||||
return (await session.scalars(select(OAuth2Grant).where(and_(OAuth2Grant.application_id == self.id, OAuth2Grant.user_id == user_id)))).first()
|
||||
return (
|
||||
await session.scalars(
|
||||
select(OAuth2Grant).where(
|
||||
and_(
|
||||
OAuth2Grant.application_id == self.id,
|
||||
OAuth2Grant.user_id == user_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
).first()
|
||||
|
||||
|
||||
class OAuth2AuthorizationCode(BaseModel):
|
||||
grant: "OAuth2Grant"
|
||||
code: str
|
||||
code: str
|
||||
redirect_uri: HttpUrl
|
||||
created: int
|
||||
created: int
|
||||
lifetime: int
|
||||
|
||||
def generate_redirect_uri(self, state: Optional[str] = None) -> httpx.URL:
|
||||
@ -104,31 +127,36 @@ class OAuth2AuthorizationCode(BaseModel):
class OAuth2Grant(Base):
    __tablename__ = "oauth2_grant"

    id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
    user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id", ondelete = "CASCADE"))
    application_id: Mapped[int] = mapped_column(ForeignKey("oauth2_application.id", ondelete = "CASCADE"))
    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id", ondelete="CASCADE"))
    application_id: Mapped[int] = mapped_column(
        ForeignKey("oauth2_application.id", ondelete="CASCADE")
    )
    scope: Mapped[str]
    created: Mapped[int] = mapped_column(default = time)
    updated: Mapped[int] = mapped_column(default = time)
    created: Mapped[int] = mapped_column(default=time)
    updated: Mapped[int] = mapped_column(default=time)

    application: Mapped[OAuth2Application] = relationship(back_populates = "grants")
    application: Mapped[OAuth2Application] = relationship(back_populates="grants")

    async def generate_authorization_code(self, redirect_uri: HttpUrl, cache: Cache) -> OAuth2AuthorizationCode:
    async def generate_authorization_code(
        self, redirect_uri: HttpUrl, cache: Cache
    ) -> OAuth2AuthorizationCode:
        code = OAuth2AuthorizationCode(
            grant = self,
            redirect_uri = redirect_uri,
            code = security.generate_key().decode(),
            created = int(time()),
            lifetime = 3000
            grant=self,
            redirect_uri=redirect_uri,
            code=security.generate_key().decode(),
            created=int(time()),
            lifetime=3000,
        )

        async with cache.client() as client:
            client.set("oauth2_authorization_code_{}".format(code.created), code.code, ex = code.lifetime)

        return code
        async with cache.client() as client:
            client.set(
                "oauth2_authorization_code_{}".format(code.created),
                code.code,
                ex=code.lifetime,
            )

        return code

    def scope_contains(self, scope: str) -> bool:
        return scope in self.scope.split(" ")
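
Note: scopes are stored as a single space-delimited string, so scope_contains is an
exact-token check, not a substring match. A quick illustration (values hypothetical):

    grant = OAuth2Grant(scope="read write profile")
    assert grant.scope_contains("read") is True
    assert grant.scope_contains("rea") is False  # substrings of a token do not match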
@ -1,9 +1,7 @@
import enum
from time import time

from sqlalchemy import BigInteger, Enum
from sqlalchemy import BigInteger
from sqlalchemy.orm import Mapped, mapped_column

from materia.models.base import Base
@ -18,13 +16,13 @@ class LoginType(enum.Enum):
class LoginSource(Base):
    __tablename__ = "login_source"

    id: Mapped[int] = mapped_column(BigInteger, primary_key = True)
    id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
    type: Mapped[LoginType]
    created: Mapped[int] = mapped_column(default = time)
    updated: Mapped[int] = mapped_column(default = time)
    created: Mapped[int] = mapped_column(default=time)
    updated: Mapped[int] = mapped_column(default=time)

    def is_plain(self) -> bool:
        return self.type == LoginType.Plain

    def is_oauth2(self) -> bool:
        return self.type == LoginType.OAuth2
@ -1,9 +0,0 @@
from materia.models.database.database import (
    DatabaseError,
    DatabaseMigrationError,
    Database,
    SessionMaker,
    SessionContext,
    ConnectionContext,
)
from materia.models.database.cache import Cache, CacheError
@ -1,19 +1,14 @@
from time import time
from typing import List, Optional, Self
from pathlib import Path
import shutil
import aiofiles
import re

from sqlalchemy import BigInteger, ForeignKey, inspect
from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, ValidationError
from pydantic import BaseModel, ConfigDict

from materia.models.base import Base
from materia.models import database
from materia.models.database import SessionContext
from materia.config import Config
from materia.core import SessionContext, Config, FileSystem


class DirectoryError(Exception):
@ -307,4 +302,3 @@ class DirectoryCopyMove(BaseModel):

from materia.models.repository import Repository
from materia.models.file import File
from materia.models.filesystem import FileSystem
@ -1,19 +1,14 @@
from time import time
from typing import Optional, Self
from typing import Optional, Self, Union
from pathlib import Path
import aioshutil

from sqlalchemy import BigInteger, ForeignKey, inspect
from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict
import aiofiles
import aiofiles.os

from materia.models.base import Base
from materia.models import database
from materia.models.database import SessionContext
from materia.config import Config
from materia.core import SessionContext, Config, FileSystem


class FileError(Exception):
@ -41,18 +36,23 @@ class File(Base):
    link: Mapped["FileLink"] = relationship(back_populates="file")

    async def new(
        self, data: bytes, session: SessionContext, config: Config
        self, data: Union[bytes, Path], session: SessionContext, config: Config
    ) -> Optional[Self]:
        session.add(self)
        await session.flush()
        await session.refresh(self, attribute_names=["repository"])

        file_path = await self.real_path(session, config)
        repository_path = await self.repository.real_path(session, config)
        new_file = FileSystem(file_path, repository_path)

        new_file = FileSystem(
            file_path, await self.repository.real_path(session, config)
        )
        await new_file.write_file(data)
        if isinstance(data, bytes):
            await new_file.write_file(data)
        elif isinstance(data, Path):
            from_file = FileSystem(data, config.application.working_directory)
            await from_file.move(file_path.parent, new_name=file_path.name)
        else:
            raise FileError(f"Unknown data type passed: {type(data)}")

        self.size = await new_file.size()
        await session.flush()
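
Note: File.new now accepts either raw bytes (written directly into the repository) or a
Path to an already-staged temporary file (moved into place instead of copied). A sketch of
both call paths, assuming an open session, a loaded config, and a repository row (names
hypothetical):

    # in-memory payload: written straight to the repository tree
    await File(repository_id=repo.id, name="notes.txt", size=5).new(b"hello", session, config)

    # staged upload: the temporary file is moved under the repository, not re-read
    await File(repository_id=repo.id, name="big.bin", size=size).new(Path("/tmp/upload-123"), session, config)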
@ -113,8 +113,10 @@ class File(Base):
        if path == Path():
            raise FileError("Cannot find file by empty path")

        parent_directory = await Directory.by_path(
            repository, path.parent, session, config
        parent_directory = (
            None
            if path.parent == Path()
            else await Directory.by_path(repository, path.parent, session, config)
        )

        current_file = (
@ -214,10 +216,10 @@ class File(Base):
        await session.flush()
        return self

    async def info(self) -> Optional["FileInfo"]:
        if self.is_public:
            return FileInfo.model_validate(self)
        return None
    def info(self) -> Optional["FileInfo"]:
        # if self.is_public:
        return FileInfo.model_validate(self)
        # return None


def convert_bytes(size: int):
@ -269,4 +271,3 @@ class FileCopyMove(BaseModel):

from materia.models.repository import Repository
from materia.models.directory import Directory
from materia.models.filesystem import FileSystem
@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import async_engine_from_config
from alembic import context
import alembic_postgresql_enum

from materia.config import Config
from materia.core import Config
from materia.models.base import Base
import materia.models.user
import materia.models.auth
@ -22,12 +22,12 @@ import materia.models.file

config = context.config

#config.set_main_option("sqlalchemy.url", Config().database.url())
# config.set_main_option("sqlalchemy.url", Config().database.url())

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name, disable_existing_loggers = False)
    fileConfig(config.config_file_name, disable_existing_loggers=False)

# add your model's MetaData object here
# for 'autogenerate' support
@ -61,7 +61,7 @@ def run_migrations_offline() -> None:
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        version_table_schema = "public"
        version_table_schema="public",
    )

    with context.begin_transaction():
@ -1,19 +1,15 @@
from time import time
from typing import List, Self, Optional
from uuid import UUID, uuid4
from uuid import UUID
from pathlib import Path
import shutil

from sqlalchemy import BigInteger, ForeignKey, inspect
from sqlalchemy import BigInteger, ForeignKey
from sqlalchemy.orm import mapped_column, Mapped, relationship
from sqlalchemy.orm.attributes import InstrumentedAttribute
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict

from materia.models.base import Base
from materia.models import database
from materia.models.database import SessionContext
from materia.config import Config
from materia.core import SessionContext, Config


class RepositoryError(Exception):
@ -99,12 +95,19 @@ class Repository(Base):
        await session.refresh(user, attribute_names=["repository"])
        return user.repository

    async def info(self, session: SessionContext) -> "RepositoryInfo":
    async def used(self, session: SessionContext) -> int:
        session.add(self)
        await session.refresh(self, attribute_names=["files"])

        return sum([file.size for file in self.files])

    async def remaining_capacity(self, session: SessionContext) -> int:
        used = await self.used(session)
        return self.capacity - used

    async def info(self, session: SessionContext) -> "RepositoryInfo":
        info = RepositoryInfo.model_validate(self)
        info.used = sum([file.size for file in self.files])
        info.used = await self.used(session)

        return info
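
Note: remaining_capacity is what the upload route checks before streaming a request body.
A usage sketch, assuming an open session, a loaded repository, and a known upload size:

    async with database.session() as session:
        free = await repository.remaining_capacity(session)
        if upload_size > free:
            raise RepositoryError("Upload exceeds remaining capacity")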
@ -4,8 +4,7 @@ import time
import re

from pydantic import BaseModel, EmailStr, ConfigDict
import pydantic
from sqlalchemy import BigInteger, Enum
from sqlalchemy import BigInteger
from sqlalchemy.orm import mapped_column, Mapped, relationship
import sqlalchemy as sa
from PIL import Image
@ -15,10 +14,7 @@ from aiofiles import os as async_os
from materia import security
from materia.models.base import Base
from materia.models.auth.source import LoginType
from materia.models import database
from materia.models.database import SessionContext
from materia.config import Config
from loguru import logger
from materia.core import SessionContext, Config, FileSystem

valid_username = re.compile(r"^[\da-zA-Z][-.\w]*$")
invalid_username = re.compile(r"[-._]{2,}|[-._]$")
@ -230,4 +226,3 @@ class UserInfo(BaseModel):

from materia.models.repository import Repository
from materia.models.filesystem import FileSystem
@ -1,9 +1,5 @@
import os
from pathlib import Path
import shutil

from fastapi import APIRouter, Depends, HTTPException, status

from materia.models import (
    User,
    Directory,
@ -11,14 +7,10 @@ from materia.models import (
    DirectoryPath,
    DirectoryRename,
    DirectoryCopyMove,
    FileSystem,
    Repository,
)
from materia.models.database import SessionContext
from materia.core import SessionContext, Config, FileSystem
from materia.routers import middleware
from materia.config import Config

from pydantic import BaseModel

router = APIRouter(tags=["directory"])
@ -1,24 +1,39 @@
import os
from typing import Annotated, Optional
from pathlib import Path

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile

from fastapi import (
    Request,
    APIRouter,
    Depends,
    HTTPException,
    status,
    UploadFile,
    File as _File,
    Form,
)
from fastapi.responses import JSONResponse
from materia.models import (
    User,
    File,
    FileInfo,
    Directory,
    DirectoryPath,
    Repository,
    FileSystem,
    FileRename,
    FilePath,
    FileCopyMove,
)
from materia.models.database import SessionContext
from materia.core import (
    SessionContext,
    Config,
    FileSystem,
    TemporaryFileTarget,
    Database,
)
from materia.routers import middleware
from materia.config import Config
from materia.routers.api.directory import validate_target_directory
from streaming_form_data import StreamingFormDataParser
from streaming_form_data.targets import ValueTarget
from starlette.requests import ClientDisconnect
from aiofiles import ospath as async_path
from materia.tasks import remove_cache_file

router = APIRouter(tags=["file"])
@ -42,36 +57,86 @@ async def validate_current_file(
    return file


class FileSizeValidator:
    def __init__(self, capacity: int):
        self.body = 0
        self.capacity = capacity

    def __call__(self, chunk: bytes):
        self.body += len(chunk)
        if self.body > self.capacity:
            raise HTTPException(status.HTTP_413_REQUEST_ENTITY_TOO_LARGE)
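
Note: the validator is invoked once per streamed chunk, so an oversized upload is rejected
as soon as the running total passes the repository's remaining capacity, without buffering
the whole body. A standalone sketch:

    validator = FileSizeValidator(capacity=10)
    validator(b"12345")        # body = 5, within capacity
    try:
        validator(b"6789012")  # body = 12 > 10, raises HTTP 413
    except HTTPException:
        pass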


@router.post("/file")
async def create(
    file: UploadFile,
    path: DirectoryPath,
    request: Request,
    repository: Repository = Depends(middleware.repository),
    ctx: middleware.Context = Depends(),
):
    if not file.filename:
    database = await Database.new(ctx.config.database.url(), test_connection=False)
    async with database.session() as session:
        capacity = await repository.remaining_capacity(session)

    try:
        file = TemporaryFileTarget(
            ctx.config.application.working_directory,
            validator=FileSizeValidator(capacity),
        )
        path = ValueTarget()

        ctx.logger.debug(f"Schedule remove cache file: {file.path().name}")
        remove_cache_file.apply_async(args=(file.path(), ctx.config), countdown=10)

        parser = StreamingFormDataParser(headers=request.headers)
        parser.register("file", file)
        parser.register("path", path)

        async for chunk in request.stream():
            parser.data_received(chunk)

    except ClientDisconnect:
        file.remove()
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Client disconnect")
    except HTTPException as e:
        file.remove()
        raise e
    except Exception as e:
        file.remove()
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, " ".join(e.args))

    path = Path(path.value.decode())

    if not file.multipart_filename:
        file.remove()
        raise HTTPException(
            status.HTTP_417_EXPECTATION_FAILED, "Cannot upload file without name"
        )
    if not FileSystem.check_path(path.path):
    if not FileSystem.check_path(path):
        file.remove()
        raise HTTPException(status.HTTP_500_INTERNAL_SERVER_ERROR, "Invalid path")

    async with ctx.database.session() as session:
    async with database.session() as session:
        target_directory = await validate_target_directory(
            path.path, repository, session, ctx.config
            path, repository, session, ctx.config
        )

        await File(
            repository_id=repository.id,
            parent_id=target_directory.id if target_directory else None,
            name=file.filename,
            size=file.size,
        ).new(await file.read(), session, ctx.config)

        await session.commit()
        try:
            await File(
                repository_id=repository.id,
                parent_id=target_directory.id if target_directory else None,
                name=file.multipart_filename,
                size=await async_path.getsize(file.path()),
            ).new(file.path(), session, ctx.config)
        except Exception:
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR, "Failed to create file"
            )
        else:
            await session.commit()
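
Note: the endpoint now consumes a streamed multipart body with a "file" part and a "path"
part. A hedged client sketch using httpx; the base URL and route prefix are assumptions:

    import httpx

    with open("notes.txt", "rb") as f:
        response = httpx.post(
            "http://localhost:8000/api/file",  # assumed mount point
            files={"file": ("notes.txt", f)},  # streamed as the "file" part
            data={"path": "/"},                # target directory as the "path" part
        )
    response.raise_for_status()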


@router.get("/file")
@router.get("/file", response_model=FileInfo)
async def info(
    path: Path,
    repository: Repository = Depends(middleware.repository),
@ -80,7 +145,7 @@ async def info(
    async with ctx.database.session() as session:
        file = await validate_current_file(path, repository, session, ctx.config)

        info = await file.info(session)
        info = file.info()

        return info
@ -1,7 +1,4 @@
import shutil
from pathlib import Path
from fastapi import APIRouter, Depends, HTTPException, status

from materia.models import (
    User,
    Repository,
@ -11,7 +8,6 @@ from materia.models import (
    DirectoryInfo,
)
from materia.routers import middleware
from materia.config import Config


router = APIRouter(tags=["repository"])
16
src/materia/routers/api/tasks.py
Normal file
@ -0,0 +1,16 @@
from celery.result import AsyncResult
from fastapi import APIRouter
from fastapi.responses import JSONResponse

router = APIRouter(tags=["tasks"])


@router.get("/tasks/{task_id}")
async def status_task(task_id):
    task_result = AsyncResult(task_id)
    result = {
        "task_id": task_id,
        "task_status": task_result.status,
        "task_result": task_result.result,
    }
    return JSONResponse(result)
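
Note: the status endpoint can be polled until Celery reports a terminal state. A hedged
client sketch with an assumed base URL and route prefix:

    import time
    import httpx

    def wait_for_task(task_id: str) -> dict:
        while True:
            result = httpx.get(f"http://localhost:8000/api/tasks/{task_id}").json()
            if result["task_status"] in ("SUCCESS", "FAILURE"):
                return result
            time.sleep(1)  # poll once per second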
@ -1,13 +1,6 @@
import uuid
import io
import shutil

from fastapi import APIRouter, Depends, HTTPException, status, UploadFile
import sqlalchemy as sa
from sqids.sqids import Sqids
from PIL import Image

from materia.config import Config
from materia.models import User, UserInfo
from materia.routers import middleware
@ -1,8 +1,8 @@
from typing import Optional, Sequence
from typing import Optional
import uuid
from datetime import datetime
from pathlib import Path
from fastapi import HTTPException, Request, Response, status, Depends, Cookie
from fastapi import HTTPException, Request, Response, status, Depends
from fastapi.security.base import SecurityBase
import jwt
from sqlalchemy import select
@ -5,7 +5,7 @@ from pathlib import Path
import mimetypes

from materia.routers import middleware
from materia.config import Config
from materia.core import Config

router = APIRouter(tags=["resources"], prefix="/resources")
@ -1,16 +1,19 @@
from typing import Literal

import bcrypt


def hash_password(password: str, algo: Literal["bcrypt"] = "bcrypt") -> str:
    if algo == "bcrypt":
        return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
    else:
        raise NotImplemented(algo)
        raise NotImplementedError(algo)

def validate_password(password: str, hash: str, algo: Literal["bcrypt"] = "bcrypt") -> bool:

def validate_password(
    password: str, hash: str, algo: Literal["bcrypt"] = "bcrypt"
) -> bool:
    if algo == "bcrypt":
        return bcrypt.checkpw(password.encode(), hash.encode())
    else:
        raise NotImplemented(algo)
        raise NotImplementedError(algo)
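
Note: NotImplemented is a constant, not an exception type, so NotImplementedError is the
correct thing to raise for an unsupported algorithm. A quick round-trip of the two helpers:

    hashed = hash_password("s3cret")                    # a fresh bcrypt salt per call
    assert validate_password("s3cret", hashed) is True
    assert validate_password("wrong", hashed) is False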
1
src/materia/tasks/__init__.py
Normal file
@ -0,0 +1 @@
from materia.tasks.file import remove_cache_file
17
src/materia/tasks/file.py
Normal file
@ -0,0 +1,17 @@
from materia.core import Cron, CronError, SessionContext, Config, Database
from celery import shared_task
from fastapi import UploadFile
from materia.models import File
import asyncio
from pathlib import Path
from materia.core import FileSystem, Config


@shared_task(name="remove_cache_file")
def remove_cache_file(path: Path, config: Config):
    target = FileSystem(path, config.application.working_directory.joinpath("cache"))

    async def wrapper():
        await target.remove()

    asyncio.run(wrapper())
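
Note: the task is synchronous from Celery's point of view and wraps the async removal in
asyncio.run. The upload route schedules it with a delay, as in:

    # schedule cache cleanup 10 seconds from now (mirrors the upload route above)
    remove_cache_file.apply_async(args=(file.path(), ctx.config), countdown=10)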