mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-03-09 11:43:39 +00:00
feat: support request-id header
This commit is contained in:
@@ -28,6 +28,14 @@ ahriman.core.log.lazy\_logging module
|
||||
:no-undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ahriman.core.log.log\_context module
|
||||
------------------------------------
|
||||
|
||||
.. automodule:: ahriman.core.log.log_context
|
||||
:members:
|
||||
:no-undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ahriman.core.log.log\_loader module
|
||||
-----------------------------------
|
||||
|
||||
|
||||
@@ -28,6 +28,14 @@ ahriman.web.middlewares.metrics\_handler module
|
||||
:no-undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
ahriman.web.middlewares.request\_id\_handler module
|
||||
---------------------------------------------------
|
||||
|
||||
.. automodule:: ahriman.web.middlewares.request_id_handler
|
||||
:members:
|
||||
:no-undoc-members:
|
||||
:show-inheritance:
|
||||
|
||||
Module contents
|
||||
---------------
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ export class Client {
|
||||
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "application/json",
|
||||
"X-Request-ID": crypto.randomUUID(),
|
||||
};
|
||||
if (json !== undefined) {
|
||||
headers["Content-Type"] = "application/json";
|
||||
|
||||
@@ -26,10 +26,12 @@ formatter = syslog_format
|
||||
args = ("/dev/log",)
|
||||
|
||||
[formatter_generic_format]
|
||||
format = [%(levelname)s %(asctime)s] [%(name)s]: %(message)s
|
||||
format = [{levelname} {asctime}] [{name}]: {message}
|
||||
style = {
|
||||
|
||||
[formatter_syslog_format]
|
||||
format = [%(levelname)s] [%(name)s]: %(message)s
|
||||
format = [{levelname}] [{name}]: {message}
|
||||
style = {
|
||||
|
||||
[logger_root]
|
||||
level = DEBUG
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#
|
||||
import contextlib
|
||||
import requests
|
||||
import uuid
|
||||
|
||||
from requests.adapters import BaseAdapter
|
||||
from urllib.parse import urlparse
|
||||
@@ -60,6 +61,15 @@ class SyncAhrimanClient(SyncHttpClient):
|
||||
|
||||
return adapters
|
||||
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""
|
||||
additional request headers
|
||||
|
||||
Returns:
|
||||
dict[str, str]: additional request headers defined by class
|
||||
"""
|
||||
return SyncHttpClient.headers(self) | {"X-Request-ID": str(uuid.uuid4())}
|
||||
|
||||
def on_session_creation(self, session: requests.Session) -> None:
|
||||
"""
|
||||
method which will be called on session creation
|
||||
|
||||
@@ -144,6 +144,15 @@ class SyncHttpClient(LazyLogging):
|
||||
"https://": HTTPAdapter(max_retries=self.retry),
|
||||
}
|
||||
|
||||
def headers(self) -> dict[str, str]:
|
||||
"""
|
||||
additional request headers
|
||||
|
||||
Returns:
|
||||
dict[str, str]: additional request headers defined by class
|
||||
"""
|
||||
return {}
|
||||
|
||||
def make_request(self, method: Literal["DELETE", "GET", "HEAD", "POST", "PUT"], url: str, *,
|
||||
headers: dict[str, str] | None = None,
|
||||
params: list[tuple[str, str]] | None = None,
|
||||
@@ -178,6 +187,9 @@ class SyncHttpClient(LazyLogging):
|
||||
if session is None:
|
||||
session = self.session
|
||||
|
||||
if additional_headers := self.headers():
|
||||
headers = additional_headers | (headers or {})
|
||||
|
||||
try:
|
||||
response = session.request(method, url, params=params, data=data, headers=headers, files=files, json=json,
|
||||
stream=stream, auth=self.auth, timeout=self.timeout)
|
||||
|
||||
@@ -24,6 +24,7 @@ from collections.abc import Iterator
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
from ahriman.core.log.log_context import LogContext
|
||||
from ahriman.models.log_record_id import LogRecordId
|
||||
|
||||
|
||||
@@ -54,30 +55,20 @@ class LazyLogging:
|
||||
prefix = "" if clazz.__module__ is None else f"{clazz.__module__}."
|
||||
return f"{prefix}{clazz.__qualname__}"
|
||||
|
||||
@staticmethod
|
||||
def _package_logger_reset() -> None:
|
||||
@contextlib.contextmanager
|
||||
def in_context(self, name: str, value: Any) -> Iterator[None]:
|
||||
"""
|
||||
reset package logger to empty one
|
||||
"""
|
||||
logging.setLogRecordFactory(logging.LogRecord)
|
||||
|
||||
@staticmethod
|
||||
def _package_logger_set(package_base: str, version: str | None) -> None:
|
||||
"""
|
||||
set package base as extra info to the logger
|
||||
execute function while setting log context. The context will be reset after the execution
|
||||
|
||||
Args:
|
||||
package_base(str): package base
|
||||
version(str | None): package version if available
|
||||
name(str): attribute name to set on log records
|
||||
value(Any): current value of the context variable
|
||||
"""
|
||||
current_factory = logging.getLogRecordFactory()
|
||||
|
||||
def package_record_factory(*args: Any, **kwargs: Any) -> logging.LogRecord:
|
||||
record = current_factory(*args, **kwargs)
|
||||
record.package_id = LogRecordId(package_base, version or "<unknown>")
|
||||
return record
|
||||
|
||||
logging.setLogRecordFactory(package_record_factory)
|
||||
token = LogContext.set(name, value)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
LogContext.reset(name, token)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def in_package_context(self, package_base: str, version: str | None) -> Iterator[None]:
|
||||
@@ -94,8 +85,5 @@ class LazyLogging:
|
||||
>>> with self.in_package_context(package.base, package.version):
|
||||
>>> build_package(package)
|
||||
"""
|
||||
try:
|
||||
self._package_logger_set(package_base, version)
|
||||
with self.in_context("package_id", LogRecordId(package_base, version or "<unknown>")):
|
||||
yield
|
||||
finally:
|
||||
self._package_logger_reset()
|
||||
|
||||
108
src/ahriman/core/log/log_context.py
Normal file
108
src/ahriman/core/log/log_context.py
Normal file
@@ -0,0 +1,108 @@
|
||||
#
|
||||
# Copyright (c) 2021-2026 ahriman team.
|
||||
#
|
||||
# This file is part of ahriman
|
||||
# (see https://github.com/arcan1s/ahriman).
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
import contextvars
|
||||
import logging
|
||||
|
||||
from typing import Any, ClassVar, TypeVar, cast
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class LogContext:
|
||||
"""
|
||||
logging context manager which provides context variables injection into log records
|
||||
"""
|
||||
|
||||
_context: ClassVar[dict[str, contextvars.ContextVar[Any]]] = {}
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> T | None:
|
||||
"""
|
||||
get context variable if available
|
||||
|
||||
Args:
|
||||
name(str): name of the context variable
|
||||
|
||||
Returns:
|
||||
T | None: context variable if available and ``None`` otherwise
|
||||
"""
|
||||
if (variable := cls._context.get(name)) is not None:
|
||||
return cast(T | None, variable.get())
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def log_record_factory(cls, *args: Any, **kwargs: Any) -> logging.LogRecord:
|
||||
"""
|
||||
log record factory which injects all registered context variables into log records
|
||||
|
||||
Args:
|
||||
*args(Any): positional arguments for the log factory
|
||||
**kwargs(Any): keyword arguments for the log factory
|
||||
|
||||
Returns:
|
||||
logging.LogRecord: log record with context variables set as attributes
|
||||
"""
|
||||
record = logging.LogRecord(*args, **kwargs)
|
||||
|
||||
for name, variable in cls._context.items():
|
||||
if (value := variable.get()) is not None:
|
||||
setattr(record, name, value)
|
||||
|
||||
return record
|
||||
|
||||
@classmethod
|
||||
def register(cls, name: str) -> contextvars.ContextVar[T]:
|
||||
"""
|
||||
(re)create context variable for log records
|
||||
|
||||
Args:
|
||||
name(str): name of the context variable
|
||||
|
||||
Returns:
|
||||
contextvars.ContextVar[T]: created context variable
|
||||
"""
|
||||
variable = cls._context[name] = contextvars.ContextVar(name, default=None)
|
||||
return variable
|
||||
|
||||
@classmethod
|
||||
def reset(cls, name: str, token: contextvars.Token[T]) -> None:
|
||||
"""
|
||||
reset context variable to its previous value
|
||||
|
||||
Args:
|
||||
name(str): attribute name to reset on log records
|
||||
token(contextvars.Token[T]): previously registered token
|
||||
"""
|
||||
cls._context[name].reset(token)
|
||||
|
||||
@classmethod
|
||||
def set(cls, name: str, value: T) -> contextvars.Token[T]:
|
||||
"""
|
||||
set context variable for log records. This value will be automatically emitted with each log record
|
||||
|
||||
Args:
|
||||
name(str): attribute name to set on log records
|
||||
value(T): current value of the context variable
|
||||
|
||||
Returns:
|
||||
contextvars.Token[T]: token created with this value
|
||||
"""
|
||||
return cls._context[name].set(value)
|
||||
@@ -21,10 +21,11 @@ import logging
|
||||
|
||||
from logging.config import fileConfig
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
from typing import ClassVar, Literal
|
||||
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.log.http_log_handler import HttpLogHandler
|
||||
from ahriman.core.log.log_context import LogContext
|
||||
from ahriman.models.log_handler import LogHandler
|
||||
from ahriman.models.repository_id import RepositoryId
|
||||
|
||||
@@ -36,11 +37,13 @@ class LogLoader:
|
||||
Attributes:
|
||||
DEFAULT_LOG_FORMAT(str): (class attribute) default log format (in case of fallback)
|
||||
DEFAULT_LOG_LEVEL(int): (class attribute) default log level (in case of fallback)
|
||||
DEFAULT_LOG_STYLE(str): (class attribute) default log style (in case of fallback)
|
||||
DEFAULT_SYSLOG_DEVICE(Path): (class attribute) default path to syslog device
|
||||
"""
|
||||
|
||||
DEFAULT_LOG_FORMAT: ClassVar[str] = "[%(levelname)s %(asctime)s] [%(name)s]: %(message)s"
|
||||
DEFAULT_LOG_FORMAT: ClassVar[str] = "[{levelname} {asctime}] [{name}]: {message}"
|
||||
DEFAULT_LOG_LEVEL: ClassVar[int] = logging.DEBUG
|
||||
DEFAULT_LOG_STYLE: ClassVar[Literal["%", "{", "$"]] = "{"
|
||||
DEFAULT_SYSLOG_DEVICE: ClassVar[Path] = Path("/") / "dev" / "log"
|
||||
|
||||
@staticmethod
|
||||
@@ -100,10 +103,22 @@ class LogLoader:
|
||||
fileConfig(log_configuration, disable_existing_loggers=True)
|
||||
logging.debug("using %s logger", default_handler)
|
||||
except Exception:
|
||||
logging.basicConfig(filename=None, format=LogLoader.DEFAULT_LOG_FORMAT, level=LogLoader.DEFAULT_LOG_LEVEL)
|
||||
logging.basicConfig(filename=None, format=LogLoader.DEFAULT_LOG_FORMAT,
|
||||
style=LogLoader.DEFAULT_LOG_STYLE, level=LogLoader.DEFAULT_LOG_LEVEL)
|
||||
logging.exception("could not load logging from configuration, fallback to stderr")
|
||||
|
||||
HttpLogHandler.load(repository_id, configuration, report=report)
|
||||
LogLoader.register_context()
|
||||
|
||||
if quiet:
|
||||
logging.disable(logging.WARNING) # only print errors here
|
||||
|
||||
@staticmethod
|
||||
def register_context() -> None:
|
||||
"""
|
||||
register logging context
|
||||
"""
|
||||
# predefined context variables
|
||||
for variable in ("package_id", "request_id"):
|
||||
LogContext.register(variable)
|
||||
logging.setLogRecordFactory(LogContext.log_record_factory)
|
||||
|
||||
51
src/ahriman/web/middlewares/request_id_handler.py
Normal file
51
src/ahriman/web/middlewares/request_id_handler.py
Normal file
@@ -0,0 +1,51 @@
|
||||
#
|
||||
# Copyright (c) 2021-2026 ahriman team.
|
||||
#
|
||||
# This file is part of ahriman
|
||||
# (see https://github.com/arcan1s/ahriman).
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
import uuid
|
||||
|
||||
from aiohttp.typedefs import Middleware
|
||||
from aiohttp.web import Request, StreamResponse, middleware
|
||||
|
||||
from ahriman.core.log.log_context import LogContext
|
||||
from ahriman.web.middlewares import HandlerType
|
||||
|
||||
|
||||
__all__ = ["request_id_handler"]
|
||||
|
||||
|
||||
def request_id_handler() -> Middleware:
|
||||
"""
|
||||
middleware to trace request id header
|
||||
|
||||
Returns:
|
||||
Middleware: request id processing middleware
|
||||
"""
|
||||
@middleware
|
||||
async def handle(request: Request, handler: HandlerType) -> StreamResponse:
|
||||
request_id = request.headers.getone("X-Request-ID", str(uuid.uuid4()))
|
||||
|
||||
token = LogContext.set("request_id", request_id)
|
||||
try:
|
||||
response = await handler(request)
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
return response
|
||||
finally:
|
||||
LogContext.reset("request_id", token)
|
||||
|
||||
return handle
|
||||
@@ -38,6 +38,7 @@ from ahriman.web.cors import setup_cors
|
||||
from ahriman.web.keys import AuthKey, ConfigurationKey, SpawnKey, WatcherKey, WorkersKey
|
||||
from ahriman.web.middlewares.exception_handler import exception_handler
|
||||
from ahriman.web.middlewares.metrics_handler import metrics_handler
|
||||
from ahriman.web.middlewares.request_id_handler import request_id_handler
|
||||
from ahriman.web.routes import setup_routes
|
||||
|
||||
|
||||
@@ -146,6 +147,7 @@ def setup_server(configuration: Configuration, spawner: Spawn, repositories: lis
|
||||
application.on_startup.append(_on_startup)
|
||||
|
||||
application.middlewares.append(normalize_path_middleware(append_slash=False, remove_slash=True))
|
||||
application.middlewares.append(request_id_handler())
|
||||
application.middlewares.append(exception_handler(application.logger))
|
||||
application.middlewares.append(metrics_handler())
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from ahriman.core.auth import Auth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.database import SQLite
|
||||
from ahriman.core.database.migrations import Migrations
|
||||
from ahriman.core.log.log_loader import LogLoader
|
||||
from ahriman.core.repository import Repository
|
||||
from ahriman.core.spawn import Spawn
|
||||
from ahriman.core.status import Client
|
||||
@@ -124,6 +125,14 @@ def import_error(package: str, components: list[str], mocker: MockerFixture) ->
|
||||
|
||||
|
||||
# generic fixtures
|
||||
@pytest.fixture(autouse=True)
|
||||
def _register_log_context() -> None:
|
||||
"""
|
||||
register log context variables and factory
|
||||
"""
|
||||
LogLoader.register_context()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def aur_package_ahriman() -> AURPackage:
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,13 @@ def test_login_url(ahriman_client: SyncAhrimanClient) -> None:
|
||||
assert ahriman_client._login_url().endswith("/api/v1/login")
|
||||
|
||||
|
||||
def test_headers(ahriman_client: SyncAhrimanClient) -> None:
|
||||
"""
|
||||
must inject request id header
|
||||
"""
|
||||
assert "X-Request-ID" in ahriman_client.headers()
|
||||
|
||||
|
||||
def test_on_session_creation(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must log in user on start
|
||||
|
||||
@@ -94,6 +94,13 @@ def test_adapters() -> None:
|
||||
assert all(adapter.max_retries == client.retry for adapter in adapters.values())
|
||||
|
||||
|
||||
def test_headers() -> None:
|
||||
"""
|
||||
must return empty additional headers
|
||||
"""
|
||||
assert SyncHttpClient().headers() == {}
|
||||
|
||||
|
||||
def test_make_request(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must make HTTP request
|
||||
@@ -194,6 +201,20 @@ def test_make_request_session() -> None:
|
||||
stream=None, auth=None, timeout=client.timeout)
|
||||
|
||||
|
||||
def test_make_request_with_additional_headers(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must merge additional headers into request
|
||||
"""
|
||||
request_mock = mocker.patch("requests.Session.request")
|
||||
mocker.patch("ahriman.core.http.sync_http_client.SyncHttpClient.headers", return_value={"X-Custom": "value"})
|
||||
client = SyncHttpClient()
|
||||
|
||||
client.make_request("GET", "url")
|
||||
request_mock.assert_called_once_with(
|
||||
"GET", "url", params=None, data=None, headers={"X-Custom": "value"}, files=None, json=None,
|
||||
stream=None, auth=None, timeout=client.timeout)
|
||||
|
||||
|
||||
def test_on_session_creation() -> None:
|
||||
"""
|
||||
must do nothing on start
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from ahriman.core.alpm.repo import Repo
|
||||
from ahriman.core.build_tools.task import Task
|
||||
from ahriman.core.database import SQLite
|
||||
@@ -30,59 +28,46 @@ def test_logger_name(database: SQLite, repo: Repo, task_ahriman: Task) -> None:
|
||||
assert task_ahriman.logger_name == "ahriman.core.build_tools.task.Task"
|
||||
|
||||
|
||||
def test_package_logger_set_reset(database: SQLite) -> None:
|
||||
def test_in_context(database: SQLite) -> None:
|
||||
"""
|
||||
must set and reset package base attribute
|
||||
must set and reset generic log context
|
||||
"""
|
||||
log_record_id = LogRecordId("base", "version")
|
||||
with database.in_context("package_id", "42"):
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == "42"
|
||||
|
||||
database._package_logger_set(log_record_id.package_base, log_record_id.version)
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == log_record_id
|
||||
assert not hasattr(record, "package_id")
|
||||
|
||||
|
||||
def test_in_context_failed(database: SQLite) -> None:
|
||||
"""
|
||||
must reset context even if exception occurs
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
with database.in_context("package_id", "42"):
|
||||
raise ValueError()
|
||||
|
||||
database._package_logger_reset()
|
||||
record = logging.makeLogRecord({})
|
||||
with pytest.raises(AttributeError):
|
||||
assert record.package_id
|
||||
assert not hasattr(record, "package_id")
|
||||
|
||||
|
||||
def test_in_package_context(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None:
|
||||
def test_in_package_context(database: SQLite, package_ahriman: Package) -> None:
|
||||
"""
|
||||
must set package log context
|
||||
"""
|
||||
set_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_set")
|
||||
reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset")
|
||||
|
||||
with database.in_package_context(package_ahriman.base, package_ahriman.version):
|
||||
pass
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == LogRecordId(package_ahriman.base, package_ahriman.version)
|
||||
|
||||
set_mock.assert_called_once_with(package_ahriman.base, package_ahriman.version)
|
||||
reset_mock.assert_called_once_with()
|
||||
record = logging.makeLogRecord({})
|
||||
assert not hasattr(record, "package_id")
|
||||
|
||||
|
||||
def test_in_package_context_empty_version(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None:
|
||||
def test_in_package_context_empty_version(database: SQLite, package_ahriman: Package) -> None:
|
||||
"""
|
||||
must set package log context with empty version
|
||||
"""
|
||||
set_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_set")
|
||||
reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset")
|
||||
|
||||
with database.in_package_context(package_ahriman.base, None):
|
||||
pass
|
||||
|
||||
set_mock.assert_called_once_with(package_ahriman.base, None)
|
||||
reset_mock.assert_called_once_with()
|
||||
|
||||
|
||||
def test_in_package_context_failed(database: SQLite, package_ahriman: Package, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must reset package context even if exception occurs
|
||||
"""
|
||||
mocker.patch("ahriman.core.log.LazyLogging._package_logger_set")
|
||||
reset_mock = mocker.patch("ahriman.core.log.LazyLogging._package_logger_reset")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with database.in_package_context(package_ahriman.base, ""):
|
||||
raise ValueError()
|
||||
|
||||
reset_mock.assert_called_once_with()
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == LogRecordId(package_ahriman.base, "<unknown>")
|
||||
|
||||
75
tests/ahriman/core/log/test_log_context.py
Normal file
75
tests/ahriman/core/log/test_log_context.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import logging
|
||||
|
||||
from ahriman.core.log.log_context import LogContext
|
||||
|
||||
|
||||
def test_get() -> None:
|
||||
"""
|
||||
must get context variable value
|
||||
"""
|
||||
token = LogContext.set("package_id", "value")
|
||||
assert LogContext.get("package_id") == "value"
|
||||
LogContext.reset("package_id", token)
|
||||
|
||||
|
||||
def test_get_empty() -> None:
|
||||
"""
|
||||
must return None when context variable is unknown or not set
|
||||
"""
|
||||
assert LogContext.get("package_id") is None
|
||||
assert LogContext.get("random") is None
|
||||
|
||||
|
||||
def test_log_record_factory() -> None:
|
||||
"""
|
||||
must inject all registered context variables into log records
|
||||
"""
|
||||
package_token = LogContext.set("package_id", "package")
|
||||
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == "package"
|
||||
|
||||
LogContext.reset("package_id", package_token)
|
||||
|
||||
|
||||
def test_log_record_factory_empty() -> None:
|
||||
"""
|
||||
must not inject context variable when value is None
|
||||
"""
|
||||
record = logging.makeLogRecord({})
|
||||
assert not hasattr(record, "package_id")
|
||||
|
||||
|
||||
def test_register() -> None:
|
||||
"""
|
||||
must register a context variable
|
||||
"""
|
||||
variable = LogContext.register("random")
|
||||
|
||||
assert "random" in LogContext._context
|
||||
assert LogContext._context["random"] is variable
|
||||
|
||||
del LogContext._context["random"]
|
||||
|
||||
|
||||
def test_reset() -> None:
|
||||
"""
|
||||
must reset context variable so it is no longer injected
|
||||
"""
|
||||
token = LogContext.set("package_id", "value")
|
||||
LogContext.reset("package_id", token)
|
||||
|
||||
record = logging.makeLogRecord({})
|
||||
assert not hasattr(record, "package_id")
|
||||
|
||||
|
||||
def test_set() -> None:
|
||||
"""
|
||||
must set context variable and inject it into log records
|
||||
"""
|
||||
token = LogContext.set("package_id", "value")
|
||||
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.package_id == "value"
|
||||
|
||||
LogContext.reset("package_id", token)
|
||||
@@ -7,6 +7,7 @@ from pytest_mock import MockerFixture
|
||||
from systemd.journal import JournalHandler
|
||||
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.log.log_context import LogContext
|
||||
from ahriman.core.log.log_loader import LogLoader
|
||||
from ahriman.models.log_handler import LogHandler
|
||||
|
||||
@@ -75,3 +76,13 @@ def test_load_quiet(configuration: Configuration, mocker: MockerFixture) -> None
|
||||
_, repository_id = configuration.check_loaded()
|
||||
LogLoader.load(repository_id, configuration, LogHandler.Journald, quiet=True, report=False)
|
||||
disable_mock.assert_called_once_with(logging.WARNING)
|
||||
|
||||
|
||||
def test_register_context() -> None:
|
||||
"""
|
||||
must register predefined context variables and install log record factory
|
||||
"""
|
||||
LogLoader.register_context()
|
||||
assert "package_id" in LogContext._context
|
||||
assert "request_id" in LogContext._context
|
||||
assert logging.getLogRecordFactory().__func__ is LogContext.log_record_factory.__func__
|
||||
|
||||
43
tests/ahriman/web/middlewares/test_request_id_handler.py
Normal file
43
tests/ahriman/web/middlewares/test_request_id_handler.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from typing import Any
|
||||
|
||||
from ahriman.web.middlewares.request_id_handler import request_id_handler
|
||||
|
||||
|
||||
async def test_request_id_handler() -> None:
|
||||
"""
|
||||
must use request id from request if available
|
||||
"""
|
||||
request = pytest.helpers.request("", "", "")
|
||||
request.headers = MagicMock()
|
||||
request.headers.getone.return_value = "request_id"
|
||||
|
||||
response = MagicMock()
|
||||
response.headers = {}
|
||||
|
||||
async def check_handler(_: Any) -> MagicMock:
|
||||
record = logging.makeLogRecord({})
|
||||
assert record.request_id == "request_id"
|
||||
return response
|
||||
|
||||
handler = request_id_handler()
|
||||
await handler(request, check_handler)
|
||||
assert response.headers["X-Request-ID"] == "request_id"
|
||||
|
||||
|
||||
async def test_request_id_handler_generate() -> None:
|
||||
"""
|
||||
must generate request id and set it in response header
|
||||
"""
|
||||
request = pytest.helpers.request("", "", "")
|
||||
|
||||
response = MagicMock()
|
||||
response.headers = {}
|
||||
request_handler = AsyncMock(return_value=response)
|
||||
|
||||
handler = request_id_handler()
|
||||
await handler(request, request_handler)
|
||||
assert "X-Request-ID" in response.headers
|
||||
Reference in New Issue
Block a user