From a809a4b67f239b0bce150222c61284ad4028c1d5 Mon Sep 17 00:00:00 2001 From: Evgenii Alekseev Date: Sun, 8 Mar 2026 02:12:46 +0200 Subject: [PATCH] feat: support request-id header --- docs/ahriman.core.log.rst | 8 ++ docs/ahriman.web.middlewares.rst | 8 ++ frontend/src/api/client/Client.ts | 1 + .../settings/ahriman.ini.d/logging.ini | 6 +- src/ahriman/core/http/sync_ahriman_client.py | 10 ++ src/ahriman/core/http/sync_http_client.py | 12 ++ src/ahriman/core/log/lazy_logging.py | 36 ++---- src/ahriman/core/log/log_context.py | 108 ++++++++++++++++++ src/ahriman/core/log/log_loader.py | 21 +++- .../web/middlewares/request_id_handler.py | 51 +++++++++ src/ahriman/web/web.py | 2 + tests/ahriman/conftest.py | 9 ++ .../core/http/test_sync_ahriman_client.py | 7 ++ .../core/http/test_sync_http_client.py | 21 ++++ tests/ahriman/core/log/test_lazy_logging.py | 63 ++++------ tests/ahriman/core/log/test_log_context.py | 75 ++++++++++++ tests/ahriman/core/log/test_log_loader.py | 11 ++ .../middlewares/test_request_id_handler.py | 43 +++++++ 18 files changed, 424 insertions(+), 68 deletions(-) create mode 100644 src/ahriman/core/log/log_context.py create mode 100644 src/ahriman/web/middlewares/request_id_handler.py create mode 100644 tests/ahriman/core/log/test_log_context.py create mode 100644 tests/ahriman/web/middlewares/test_request_id_handler.py diff --git a/docs/ahriman.core.log.rst b/docs/ahriman.core.log.rst index e4122756..df512a9b 100644 --- a/docs/ahriman.core.log.rst +++ b/docs/ahriman.core.log.rst @@ -28,6 +28,14 @@ ahriman.core.log.lazy\_logging module :no-undoc-members: :show-inheritance: +ahriman.core.log.log\_context module +------------------------------------ + +.. automodule:: ahriman.core.log.log_context + :members: + :no-undoc-members: + :show-inheritance: + ahriman.core.log.log\_loader module ----------------------------------- diff --git a/docs/ahriman.web.middlewares.rst b/docs/ahriman.web.middlewares.rst index db4d1022..8d05beaf 100644 --- a/docs/ahriman.web.middlewares.rst +++ b/docs/ahriman.web.middlewares.rst @@ -28,6 +28,14 @@ ahriman.web.middlewares.metrics\_handler module :no-undoc-members: :show-inheritance: +ahriman.web.middlewares.request\_id\_handler module +--------------------------------------------------- + +.. automodule:: ahriman.web.middlewares.request_id_handler + :members: + :no-undoc-members: + :show-inheritance: + Module contents --------------- diff --git a/frontend/src/api/client/Client.ts b/frontend/src/api/client/Client.ts index c96be413..296459ef 100644 --- a/frontend/src/api/client/Client.ts +++ b/frontend/src/api/client/Client.ts @@ -40,6 +40,7 @@ export class Client { const headers: Record = { Accept: "application/json", + "X-Request-ID": crypto.randomUUID(), }; if (json !== undefined) { headers["Content-Type"] = "application/json"; diff --git a/package/share/ahriman/settings/ahriman.ini.d/logging.ini b/package/share/ahriman/settings/ahriman.ini.d/logging.ini index e120a037..97bf8aa8 100644 --- a/package/share/ahriman/settings/ahriman.ini.d/logging.ini +++ b/package/share/ahriman/settings/ahriman.ini.d/logging.ini @@ -26,10 +26,12 @@ formatter = syslog_format args = ("/dev/log",) [formatter_generic_format] -format = [%(levelname)s %(asctime)s] [%(name)s]: %(message)s +format = [{levelname} {asctime}] [{name}]: {message} +style = { [formatter_syslog_format] -format = [%(levelname)s] [%(name)s]: %(message)s +format = [{levelname}] [{name}]: {message} +style = { [logger_root] level = DEBUG diff --git a/src/ahriman/core/http/sync_ahriman_client.py b/src/ahriman/core/http/sync_ahriman_client.py index 3d5544aa..f23f988e 100644 --- a/src/ahriman/core/http/sync_ahriman_client.py +++ b/src/ahriman/core/http/sync_ahriman_client.py @@ -19,6 +19,7 @@ # import contextlib import requests +import uuid from requests.adapters import BaseAdapter from urllib.parse import urlparse @@ -60,6 +61,15 @@ class SyncAhrimanClient(SyncHttpClient): return adapters + def headers(self) -> dict[str, str]: + """ + additional request headers + + Returns: + dict[str, str]: additional request headers defined by class + """ + return SyncHttpClient.headers(self) | {"X-Request-ID": str(uuid.uuid4())} + def on_session_creation(self, session: requests.Session) -> None: """ method which will be called on session creation diff --git a/src/ahriman/core/http/sync_http_client.py b/src/ahriman/core/http/sync_http_client.py index 2bb30d92..c07b55b3 100644 --- a/src/ahriman/core/http/sync_http_client.py +++ b/src/ahriman/core/http/sync_http_client.py @@ -144,6 +144,15 @@ class SyncHttpClient(LazyLogging): "https://": HTTPAdapter(max_retries=self.retry), } + def headers(self) -> dict[str, str]: + """ + additional request headers + + Returns: + dict[str, str]: additional request headers defined by class + """ + return {} + def make_request(self, method: Literal["DELETE", "GET", "HEAD", "POST", "PUT"], url: str, *, headers: dict[str, str] | None = None, params: list[tuple[str, str]] | None = None, @@ -178,6 +187,9 @@ class SyncHttpClient(LazyLogging): if session is None: session = self.session + if additional_headers := self.headers(): + headers = additional_headers | (headers or {}) + try: response = session.request(method, url, params=params, data=data, headers=headers, files=files, json=json, stream=stream, auth=self.auth, timeout=self.timeout) diff --git a/src/ahriman/core/log/lazy_logging.py b/src/ahriman/core/log/lazy_logging.py index 2dc930cc..278c7a17 100644 --- a/src/ahriman/core/log/lazy_logging.py +++ b/src/ahriman/core/log/lazy_logging.py @@ -24,6 +24,7 @@ from collections.abc import Iterator from functools import cached_property from typing import Any +from ahriman.core.log.log_context import LogContext from ahriman.models.log_record_id import LogRecordId @@ -54,30 +55,20 @@ class LazyLogging: prefix = "" if clazz.__module__ is None else f"{clazz.__module__}." return f"{prefix}{clazz.__qualname__}" - @staticmethod - def _package_logger_reset() -> None: + @contextlib.contextmanager + def in_context(self, name: str, value: Any) -> Iterator[None]: """ - reset package logger to empty one - """ - logging.setLogRecordFactory(logging.LogRecord) - - @staticmethod - def _package_logger_set(package_base: str, version: str | None) -> None: - """ - set package base as extra info to the logger + execute function while setting log context. The context will be reset after the execution Args: - package_base(str): package base - version(str | None): package version if available + name(str): attribute name to set on log records + value(Any): current value of the context variable """ - current_factory = logging.getLogRecordFactory() - - def package_record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord: - record = current_factory(*args, **kwargs) - record.package_id = LogRecordId(package_base, version or "") - return record - - logging.setLogRecordFactory(package_record_factory) + token = LogContext.set(name, value) + try: + yield + finally: + LogContext.reset(name, token) @contextlib.contextmanager def in_package_context(self, package_base: str, version: str | None) -> Iterator[None]: @@ -94,8 +85,5 @@ class LazyLogging: >>> with self.in_package_context(package.base, package.version): >>> build_package(package) """ - try: - self._package_logger_set(package_base, version) + with self.in_context("package_id", LogRecordId(package_base, version or "")): yield - finally: - self._package_logger_reset() diff --git a/src/ahriman/core/log/log_context.py b/src/ahriman/core/log/log_context.py new file mode 100644 index 00000000..c5a87e24 --- /dev/null +++ b/src/ahriman/core/log/log_context.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2021-2026 ahriman team. +# +# This file is part of ahriman +# (see https://github.com/arcan1s/ahriman). +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +import contextvars +import logging + +from typing import Any, ClassVar, TypeVar, cast + + +T = TypeVar("T") + + +class LogContext: + """ + logging context manager which provides context variables injection into log records + """ + + _context: ClassVar[dict[str, contextvars.ContextVar[Any]]] = {} + + @classmethod + def get(cls, name: str) -> T | None: + """ + get context variable if available + + Args: + name(str): name of the context variable + + Returns: + T | None: context variable if available and ``None`` otherwise + """ + if (variable := cls._context.get(name)) is not None: + return cast(T | None, variable.get()) + return None + + @classmethod + def log_record_factory(cls, *args: Any, **kwargs: Any) -> logging.LogRecord: + """ + log record factory which injects all registered context variables into log records + + Args: + *args(Any): positional arguments for the log factory + **kwargs(Any): keyword arguments for the log factory + + Returns: + logging.LogRecord: log record with context variables set as attributes + """ + record = logging.LogRecord(*args, **kwargs) + + for name, variable in cls._context.items(): + if (value := variable.get()) is not None: + setattr(record, name, value) + + return record + + @classmethod + def register(cls, name: str) -> contextvars.ContextVar[T]: + """ + (re)create context variable for log records + + Args: + name(str): name of the context variable + + Returns: + contextvars.ContextVar[T]: created context variable + """ + variable = cls._context[name] = contextvars.ContextVar(name, default=None) + return variable + + @classmethod + def reset(cls, name: str, token: contextvars.Token[T]) -> None: + """ + reset context variable to its previous value + + Args: + name(str): attribute name to reset on log records + token(contextvars.Token[T]): previously registered token + """ + cls._context[name].reset(token) + + @classmethod + def set(cls, name: str, value: T) -> contextvars.Token[T]: + """ + set context variable for log records. This value will be automatically emitted with each log record + + Args: + name(str): attribute name to set on log records + value(T): current value of the context variable + + Returns: + contextvars.Token[T]: token created with this value + """ + return cls._context[name].set(value) diff --git a/src/ahriman/core/log/log_loader.py b/src/ahriman/core/log/log_loader.py index e6e1ea11..6662c43a 100644 --- a/src/ahriman/core/log/log_loader.py +++ b/src/ahriman/core/log/log_loader.py @@ -21,10 +21,11 @@ import logging from logging.config import fileConfig from pathlib import Path -from typing import ClassVar +from typing import ClassVar, Literal from ahriman.core.configuration import Configuration from ahriman.core.log.http_log_handler import HttpLogHandler +from ahriman.core.log.log_context import LogContext from ahriman.models.log_handler import LogHandler from ahriman.models.repository_id import RepositoryId @@ -36,11 +37,13 @@ class LogLoader: Attributes: DEFAULT_LOG_FORMAT(str): (class attribute) default log format (in case of fallback) DEFAULT_LOG_LEVEL(int): (class attribute) default log level (in case of fallback) + DEFAULT_LOG_STYLE(str): (class attribute) default log style (in case of fallback) DEFAULT_SYSLOG_DEVICE(Path): (class attribute) default path to syslog device """ - DEFAULT_LOG_FORMAT: ClassVar[str] = "[%(levelname)s %(asctime)s] [%(name)s]: %(message)s" + DEFAULT_LOG_FORMAT: ClassVar[str] = "[{levelname} {asctime}] [{name}]: {message}" DEFAULT_LOG_LEVEL: ClassVar[int] = logging.DEBUG + DEFAULT_LOG_STYLE: ClassVar[Literal["%", "{", "$"]] = "{" DEFAULT_SYSLOG_DEVICE: ClassVar[Path] = Path("/") / "dev" / "log" @staticmethod @@ -100,10 +103,22 @@ class LogLoader: fileConfig(log_configuration, disable_existing_loggers=True) logging.debug("using %s logger", default_handler) except Exception: - logging.basicConfig(filename=None, format=LogLoader.DEFAULT_LOG_FORMAT, level=LogLoader.DEFAULT_LOG_LEVEL) + logging.basicConfig(filename=None, format=LogLoader.DEFAULT_LOG_FORMAT, + style=LogLoader.DEFAULT_LOG_STYLE, level=LogLoader.DEFAULT_LOG_LEVEL) logging.exception("could not load logging from configuration, fallback to stderr") HttpLogHandler.load(repository_id, configuration, report=report) + LogLoader.register_context() if quiet: logging.disable(logging.WARNING) # only print errors here + + @staticmethod + def register_context() -> None: + """ + register logging context + """ + # predefined context variables + for variable in ("package_id", "request_id"): + LogContext.register(variable) + logging.setLogRecordFactory(LogContext.log_record_factory) diff --git a/src/ahriman/web/middlewares/request_id_handler.py b/src/ahriman/web/middlewares/request_id_handler.py new file mode 100644 index 00000000..2fb93942 --- /dev/null +++ b/src/ahriman/web/middlewares/request_id_handler.py @@ -0,0 +1,51 @@ +# +# Copyright (c) 2021-2026 ahriman team. +# +# This file is part of ahriman +# (see https://github.com/arcan1s/ahriman). +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +import uuid + +from aiohttp.typedefs import Middleware +from aiohttp.web import Request, StreamResponse, middleware + +from ahriman.core.log.log_context import LogContext +from ahriman.web.middlewares import HandlerType + + +__all__ = ["request_id_handler"] + + +def request_id_handler() -> Middleware: + """ + middleware to trace request id header + + Returns: + Middleware: request id processing middleware + """ + @middleware + async def handle(request: Request, handler: HandlerType) -> StreamResponse: + request_id = request.headers.getone("X-Request-ID", str(uuid.uuid4())) + + token = LogContext.set("request_id", request_id) + try: + response = await handler(request) + response.headers["X-Request-ID"] = request_id + return response + finally: + LogContext.reset("request_id", token) + + return handle diff --git a/src/ahriman/web/web.py b/src/ahriman/web/web.py index 70676097..348ab66f 100644 --- a/src/ahriman/web/web.py +++ b/src/ahriman/web/web.py @@ -38,6 +38,7 @@ from ahriman.web.cors import setup_cors from ahriman.web.keys import AuthKey, ConfigurationKey, SpawnKey, WatcherKey, WorkersKey from ahriman.web.middlewares.exception_handler import exception_handler from ahriman.web.middlewares.metrics_handler import metrics_handler +from ahriman.web.middlewares.request_id_handler import request_id_handler from ahriman.web.routes import setup_routes @@ -146,6 +147,7 @@ def setup_server(configuration: Configuration, spawner: Spawn, repositories: lis application.on_startup.append(_on_startup) application.middlewares.append(normalize_path_middleware(append_slash=False, remove_slash=True)) + application.middlewares.append(request_id_handler()) application.middlewares.append(exception_handler(application.logger)) application.middlewares.append(metrics_handler()) diff --git a/tests/ahriman/conftest.py b/tests/ahriman/conftest.py index 6f5b427a..6c2c0975 100644 --- a/tests/ahriman/conftest.py +++ b/tests/ahriman/conftest.py @@ -14,6 +14,7 @@ from ahriman.core.auth import Auth from ahriman.core.configuration import Configuration from ahriman.core.database import SQLite from ahriman.core.database.migrations import Migrations +from ahriman.core.log.log_loader import LogLoader from ahriman.core.repository import Repository from ahriman.core.spawn import Spawn from ahriman.core.status import Client @@ -124,6 +125,14 @@ def import_error(package: str, components: list[str], mocker: MockerFixture) -> # generic fixtures +@pytest.fixture(autouse=True) +def _register_log_context() -> None: + """ + register log context variables and factory + """ + LogLoader.register_context() + + @pytest.fixture def aur_package_ahriman() -> AURPackage: """ diff --git a/tests/ahriman/core/http/test_sync_ahriman_client.py b/tests/ahriman/core/http/test_sync_ahriman_client.py index c9c96d4c..e1f0ea51 100644 --- a/tests/ahriman/core/http/test_sync_ahriman_client.py +++ b/tests/ahriman/core/http/test_sync_ahriman_client.py @@ -30,6 +30,13 @@ def test_login_url(ahriman_client: SyncAhrimanClient) -> None: assert ahriman_client._login_url().endswith("/api/v1/login") +def test_headers(ahriman_client: SyncAhrimanClient) -> None: + """ + must inject request id header + """ + assert "X-Request-ID" in ahriman_client.headers() + + def test_on_session_creation(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None: """ must log in user on start diff --git a/tests/ahriman/core/http/test_sync_http_client.py b/tests/ahriman/core/http/test_sync_http_client.py index 7fac1e3d..86daae20 100644 --- a/tests/ahriman/core/http/test_sync_http_client.py +++ b/tests/ahriman/core/http/test_sync_http_client.py @@ -94,6 +94,13 @@ def test_adapters() -> None: assert all(adapter.max_retries == client.retry for adapter in adapters.values()) +def test_headers() -> None: + """ + must return empty additional headers + """ + assert SyncHttpClient().headers() == {} + + def test_make_request(mocker: MockerFixture) -> None: """ must make HTTP request @@ -194,6 +201,20 @@ def test_make_request_session() -> None: stream=None, auth=None, timeout=client.timeout) +def test_make_request_with_additional_headers(mocker: MockerFixture) -> None: + """ + must merge additional headers into request + """ + request_mock = mocker.patch("requests.Session.request") + mocker.patch("ahriman.core.http.sync_http_client.SyncHttpClient.headers", return_value={"X-Custom": "value"}) + client = SyncHttpClient() + + client.make_request("GET", "url") + request_mock.assert_called_once_with( + "GET", "url", params=None, data=None, headers={"X-Custom": "value"}, files=None, json=None, + stream=None, auth=None, timeout=client.timeout) + + def test_on_session_creation() -> None: """ must do nothing on start diff --git a/tests/ahriman/core/log/test_lazy_logging.py b/tests/ahriman/core/log/test_lazy_logging.py index eb2579de..c2414076 100644 --- a/tests/ahriman/core/log/test_lazy_logging.py +++ b/tests/ahriman/core/log/test_lazy_logging.py @@ -1,8 +1,6 @@ import logging import pytest -from pytest_mock import MockerFixture - from ahriman.core.alpm.repo import Repo from ahriman.core.build_tools.task import Task from ahriman.core.database import SQLite @@ -30,59 +28,46 @@ def test_logger_name(database: SQLite, repo: Repo, task_ahriman: Task) -> None: assert task_ahriman.logger_name == "ahriman.core.build_tools.task.Task" -def test_package_logger_set_reset(database: SQLite) -> None: +def test_in_context(database: SQLite) -> None: """ - must set and reset package base attribute + must set and reset generic log context """ - log_record_id = LogRecordId("base", "version") + with database.in_context("package_id", "42"): + record = logging.makeLogRecord({}) + assert record.package_id == "42" - database._package_logger_set(log_record_id.package_base, log_record_id.version) record = logging.makeLogRecord({}) - assert record.package_id == log_record_id + assert not hasattr(record, "package_id") + + +def test_in_context_failed(database: SQLite) -> None: + """ + must reset context even if exception occurs + """ + with pytest.raises(ValueError): + with database.in_context("package_id", "42"): + raise ValueError() - database._package_logger_reset() record = logging.makeLogRecord({}) - with pytest.raises(AttributeError): - assert record.package_id + assert not hasattr(record, "package_id") -def test_in_package_context(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None: +def test_in_package_context(database: SQLite, package_ahriman: Package) -> None: """ must set package log context """ - set_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_set") - reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset") - with database.in_package_context(package_ahriman.base, package_ahriman.version): - pass + record = logging.makeLogRecord({}) + assert record.package_id == LogRecordId(package_ahriman.base, package_ahriman.version) - set_mock.assert_called_once_with(package_ahriman.base, package_ahriman.version) - reset_mock.assert_called_once_with() + record = logging.makeLogRecord({}) + assert not hasattr(record, "package_id") -def test_in_package_context_empty_version(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None: +def test_in_package_context_empty_version(database: SQLite, package_ahriman: Package) -> None: """ must set package log context with empty version """ - set_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_set") - reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset") - with database.in_package_context(package_ahriman.base, None): - pass - - set_mock.assert_called_once_with(package_ahriman.base, None) - reset_mock.assert_called_once_with() - - -def test_in_package_context_failed(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None: - """ - must reset package context even if exception occurs - """ - mocker.patch("ahriman.core.log.LazyLogging._package_logger_set") - reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset") - - with pytest.raises(ValueError): - with database.in_package_context(package_ahriman.base, ""): - raise ValueError() - - reset_mock.assert_called_once_with() + record = logging.makeLogRecord({}) + assert record.package_id == LogRecordId(package_ahriman.base, "") diff --git a/tests/ahriman/core/log/test_log_context.py b/tests/ahriman/core/log/test_log_context.py new file mode 100644 index 00000000..0b6e55e3 --- /dev/null +++ b/tests/ahriman/core/log/test_log_context.py @@ -0,0 +1,75 @@ +import logging + +from ahriman.core.log.log_context import LogContext + + +def test_get() -> None: + """ + must get context variable value + """ + token = LogContext.set("package_id", "value") + assert LogContext.get("package_id") == "value" + LogContext.reset("package_id", token) + + +def test_get_empty() -> None: + """ + must return None when context variable is unknown or not set + """ + assert LogContext.get("package_id") is None + assert LogContext.get("random") is None + + +def test_log_record_factory() -> None: + """ + must inject all registered context variables into log records + """ + package_token = LogContext.set("package_id", "package") + + record = logging.makeLogRecord({}) + assert record.package_id == "package" + + LogContext.reset("package_id", package_token) + + +def test_log_record_factory_empty() -> None: + """ + must not inject context variable when value is None + """ + record = logging.makeLogRecord({}) + assert not hasattr(record, "package_id") + + +def test_register() -> None: + """ + must register a context variable + """ + variable = LogContext.register("random") + + assert "random" in LogContext._context + assert LogContext._context["random"] is variable + + del LogContext._context["random"] + + +def test_reset() -> None: + """ + must reset context variable so it is no longer injected + """ + token = LogContext.set("package_id", "value") + LogContext.reset("package_id", token) + + record = logging.makeLogRecord({}) + assert not hasattr(record, "package_id") + + +def test_set() -> None: + """ + must set context variable and inject it into log records + """ + token = LogContext.set("package_id", "value") + + record = logging.makeLogRecord({}) + assert record.package_id == "value" + + LogContext.reset("package_id", token) diff --git a/tests/ahriman/core/log/test_log_loader.py b/tests/ahriman/core/log/test_log_loader.py index b7dd0c3d..40b7591c 100644 --- a/tests/ahriman/core/log/test_log_loader.py +++ b/tests/ahriman/core/log/test_log_loader.py @@ -7,6 +7,7 @@ from pytest_mock import MockerFixture from systemd.journal import JournalHandler from ahriman.core.configuration import Configuration +from ahriman.core.log.log_context import LogContext from ahriman.core.log.log_loader import LogLoader from ahriman.models.log_handler import LogHandler @@ -75,3 +76,13 @@ def test_load_quiet(configuration: Configuration, mocker: MockerFixture) -> None _, repository_id = configuration.check_loaded() LogLoader.load(repository_id, configuration, LogHandler.Journald, quiet=True, report=False) disable_mock.assert_called_once_with(logging.WARNING) + + +def test_register_context() -> None: + """ + must register predefined context variables and install log record factory + """ + LogLoader.register_context() + assert "package_id" in LogContext._context + assert "request_id" in LogContext._context + assert logging.getLogRecordFactory().__func__ is LogContext.log_record_factory.__func__ diff --git a/tests/ahriman/web/middlewares/test_request_id_handler.py b/tests/ahriman/web/middlewares/test_request_id_handler.py new file mode 100644 index 00000000..a0ef8957 --- /dev/null +++ b/tests/ahriman/web/middlewares/test_request_id_handler.py @@ -0,0 +1,43 @@ +import logging +import pytest + +from unittest.mock import AsyncMock, MagicMock +from typing import Any + +from ahriman.web.middlewares.request_id_handler import request_id_handler + + +async def test_request_id_handler() -> None: + """ + must use request id from request if available + """ + request = pytest.helpers.request("", "", "") + request.headers = MagicMock() + request.headers.getone.return_value = "request_id" + + response = MagicMock() + response.headers = {} + + async def check_handler(_: Any) -> MagicMock: + record = logging.makeLogRecord({}) + assert record.request_id == "request_id" + return response + + handler = request_id_handler() + await handler(request, check_handler) + assert response.headers["X-Request-ID"] == "request_id" + + +async def test_request_id_handler_generate() -> None: + """ + must generate request id and set it in response header + """ + request = pytest.helpers.request("", "", "") + + response = MagicMock() + response.headers = {} + request_handler = AsyncMock(return_value=response) + + handler = request_id_handler() + await handler(request, request_handler) + assert "X-Request-ID" in response.headers