Compare commits

..

14 Commits

Author SHA1 Message Date
9ec566f095 feat: add retry policy 2026-02-20 02:49:32 +02:00
dec025b45a feat: raise OptionError on missing OAuth provider class instead of generic AttributeError 2026-02-19 10:19:03 +02:00
89008e5350 fix: use context manager for selector and smtp session 2026-02-19 10:19:03 +02:00
422196d413 fix: force data filter for tar archive extraction
(python3.14 default anyway)
2026-02-19 10:19:03 +02:00
6fe2eade26 feat: (more) secure cookies 2026-02-19 10:19:03 +02:00
5266f54257 fix: speedup table reload by updating only changed statuses
it has been found that on big (>100) repos it starts lagging on reload.
This commit adds guard to avoid updating rows whose package statuses
were not changed
2026-02-19 10:19:03 +02:00
bbf9e38fda Release 2.20.0rc1 2026-02-18 13:34:08 +02:00
ba80a91d95 feat: implement CSRF protection 2026-02-18 13:34:08 +02:00
536d040a6a feat: handle only unknownpackageerror on aur load 2026-02-18 13:34:08 +02:00
bed8752f3a fix: filter logs by repository (twice) before rotation 2026-02-18 13:34:05 +02:00
4093ca8986 fix: do not clear queue on queue fetch failures 2026-02-18 13:34:03 +02:00
f027155885 docs: correct docstring for list_flatmap method 2026-02-18 13:34:03 +02:00
443d4ae667 fix: correct vcs definition for cvs packages 2026-02-18 13:34:00 +02:00
c8f7fa8c51 fix: load gitremote triggers configuration schema from non-standard
paths
2026-02-18 13:33:57 +02:00
26 changed files with 295 additions and 114 deletions

View File

@@ -158,7 +158,9 @@ Reporting to web service related settings. In most cases there is fallback to we
* ``enabled`` - enable reporting to web service, boolean, optional, default ``yes`` for backward compatibility. * ``enabled`` - enable reporting to web service, boolean, optional, default ``yes`` for backward compatibility.
* ``address`` - remote web service address with protocol, string, optional. In case of websocket, the ``http+unix`` scheme and URL encoded address (e.g. ``%2Fvar%2Flib%2Fahriman`` for ``/var/lib/ahriman``) must be used, e.g. ``http+unix://%2Fvar%2Flib%2Fahriman%2Fsocket``. In case if none set, it will be guessed from ``web`` section. * ``address`` - remote web service address with protocol, string, optional. In case of websocket, the ``http+unix`` scheme and URL encoded address (e.g. ``%2Fvar%2Flib%2Fahriman`` for ``/var/lib/ahriman``) must be used, e.g. ``http+unix://%2Fvar%2Flib%2Fahriman%2Fsocket``. In case if none set, it will be guessed from ``web`` section.
* ``max_retries`` - maximum amount of retries of HTTP requests, integer, optional, default ``0``.
* ``password`` - password to authorize in web service in order to update service status, string, required in case if authorization enabled. * ``password`` - password to authorize in web service in order to update service status, string, required in case if authorization enabled.
* ``retry_backoff`` - retry exponential backoff, float, optional, default ``0.0``.
* ``suppress_http_log_errors`` - suppress HTTP log errors, boolean, optional, default ``no``. If set to ``yes``, any HTTP log errors (e.g. if web server is not available, but HTTP logging is enabled) will be suppressed. * ``suppress_http_log_errors`` - suppress HTTP log errors, boolean, optional, default ``no``. If set to ``yes``, any HTTP log errors (e.g. if web server is not available, but HTTP logging is enabled) will be suppressed.
* ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``. * ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``.
* ``username`` - username to authorize in web service in order to update service status, string, required in case if authorization enabled. * ``username`` - username to authorize in web service in order to update service status, string, required in case if authorization enabled.
@@ -367,6 +369,8 @@ Section name must be either ``telegram`` (plus optional architecture name, e.g.
* ``chat_id`` - telegram chat id, either string with ``@`` or integer value, required. * ``chat_id`` - telegram chat id, either string with ``@`` or integer value, required.
* ``homepage`` - link to homepage, string, optional. * ``homepage`` - link to homepage, string, optional.
* ``link_path`` - prefix for HTML links, string, required. * ``link_path`` - prefix for HTML links, string, required.
* ``max_retries`` - maximum amount of retries of HTTP requests, integer, optional, default ``0``.
* ``retry_backoff`` - retry exponential backoff, float, optional, default ``0.0``.
* ``rss_url`` - link to RSS feed, string, optional. * ``rss_url`` - link to RSS feed, string, optional.
* ``template`` - Jinja2 template name, string, required. * ``template`` - Jinja2 template name, string, required.
* ``template_type`` - ``parse_mode`` to be passed to telegram API, one of ``MarkdownV2``, ``HTML``, ``Markdown``, string, optional, default ``HTML``. * ``template_type`` - ``parse_mode`` to be passed to telegram API, one of ``MarkdownV2``, ``HTML``, ``Markdown``, string, optional, default ``HTML``.
@@ -392,6 +396,7 @@ Type will be read from several sources:
This feature requires GitHub key creation (see below). Section name must be either ``github`` (plus optional architecture name, e.g. ``github:x86_64``) or random name with ``type`` set. This feature requires GitHub key creation (see below). Section name must be either ``github`` (plus optional architecture name, e.g. ``github:x86_64``) or random name with ``type`` set.
* ``type`` - type of the upload, string, optional, must be set to ``github`` if exists. * ``type`` - type of the upload, string, optional, must be set to ``github`` if exists.
* ``max_retries`` - maximum amount of retries of HTTP requests, integer, optional, default ``0``.
* ``owner`` - GitHub repository owner, string, required. * ``owner`` - GitHub repository owner, string, required.
* ``password`` - created GitHub API key. In order to create it do the following: * ``password`` - created GitHub API key. In order to create it do the following:
@@ -401,6 +406,7 @@ This feature requires GitHub key creation (see below). Section name must be eith
#. Generate new token. Required scope is ``public_repo`` (or ``repo`` for private repository support). #. Generate new token. Required scope is ``public_repo`` (or ``repo`` for private repository support).
* ``repository`` - GitHub repository name, string, required. Repository must be created before any action and must have active branch (e.g. with readme). * ``repository`` - GitHub repository name, string, required. Repository must be created before any action and must have active branch (e.g. with readme).
* ``retry_backoff`` - retry exponential backoff, float, optional, default ``0.0``.
* ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``. * ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``.
* ``use_full_release_name`` - if set to ``yes``, the release will contain both repository name and architecture, and only architecture otherwise, boolean, optional, default ``no`` (legacy behavior). * ``use_full_release_name`` - if set to ``yes``, the release will contain both repository name and architecture, and only architecture otherwise, boolean, optional, default ``no`` (legacy behavior).
* ``username`` - GitHub authorization user, string, required. Basically the same as ``owner``. * ``username`` - GitHub authorization user, string, required. Basically the same as ``owner``.
@@ -411,6 +417,8 @@ This feature requires GitHub key creation (see below). Section name must be eith
Section name must be either ``remote-service`` (plus optional architecture name, e.g. ``remote-service:x86_64``) or random name with ``type`` set. Section name must be either ``remote-service`` (plus optional architecture name, e.g. ``remote-service:x86_64``) or random name with ``type`` set.
* ``type`` - type of the report, string, optional, must be set to ``remote-service`` if exists. * ``type`` - type of the report, string, optional, must be set to ``remote-service`` if exists.
* ``max_retries`` - maximum amount of retries of HTTP requests, integer, optional, default ``0``.
* ``retry_backoff`` - retry exponential backoff, float, optional, default ``0.0``.
* ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``. * ``timeout`` - HTTP request timeout in seconds, integer, optional, default is ``30``.
``rsync`` type ``rsync`` type

View File

@@ -73,8 +73,12 @@ enabled = yes
; In case if unix sockets are used, it might point to the valid socket with encoded path, e.g.: ; In case if unix sockets are used, it might point to the valid socket with encoded path, e.g.:
; address = http+unix://%2Fvar%2Flib%2Fahriman%2Fsocket ; address = http+unix://%2Fvar%2Flib%2Fahriman%2Fsocket
;address = http://${web:host}:${web:port} ;address = http://${web:host}:${web:port}
; Maximum amount of retries of HTTP requests.
;max_retries = 0
; Optional password for authentication (if enabled). ; Optional password for authentication (if enabled).
;password = ;password =
; Retry exponential backoff.
;retry_backoff = 0.0
; Do not log HTTP errors if occurs. ; Do not log HTTP errors if occurs.
suppress_http_log_errors = yes suppress_http_log_errors = yes
; HTTP request timeout in seconds. ; HTTP request timeout in seconds.
@@ -216,6 +220,10 @@ templates[] = ${prefix}/share/ahriman/templates
;homepage= ;homepage=
; Prefix for packages links. Link to a package will be formed as link_path / filename. ; Prefix for packages links. Link to a package will be formed as link_path / filename.
;link_path = ;link_path =
; Maximum amount of retries of HTTP requests.
;max_retries = 0
; Retry exponential backoff.
;retry_backoff = 0.0
; Optional link to the RSS feed. ; Optional link to the RSS feed.
;rss_url = ;rss_url =
; Template name to be used. ; Template name to be used.
@@ -236,12 +244,16 @@ target =
[github] [github]
; Trigger type name. ; Trigger type name.
;type = github ;type = github
; Maximum amount of retries of HTTP requests.
;max_retries = 0
; GitHub repository owner username. ; GitHub repository owner username.
;owner = ;owner =
; GitHub API key. public_repo (repo) scope is required. ; GitHub API key. public_repo (repo) scope is required.
;password = ;password =
; GitHub repository name. ; GitHub repository name.
;repository = ;repository =
; Retry exponential backoff.
;retry_backoff = 0.0
; HTTP request timeout in seconds. ; HTTP request timeout in seconds.
;timeout = 30 ;timeout = 30
; Include repository name to release name (recommended). ; Include repository name to release name (recommended).
@@ -253,6 +265,10 @@ target =
[remote-service] [remote-service]
; Trigger type name. ; Trigger type name.
;type = remote-service ;type = remote-service
; Maximum amount of retries of HTTP requests.
;max_retries = 0
; Retry exponential backoff.
;retry_backoff = 0.0
; HTTP request timeout in seconds. ; HTTP request timeout in seconds.
;timeout = 30 ;timeout = 30

View File

@@ -87,7 +87,7 @@
}; };
}); });
updateTable(table, payload); updateTable(table, payload, row => row.timestamp);
table.bootstrapTable("hideLoading"); table.bootstrapTable("hideLoading");
}, },
onFailure, onFailure,

View File

@@ -195,16 +195,19 @@
return intervalId; return intervalId;
} }
function updateTable(table, rows) { function updateTable(table, rows, rowChangedKey) {
// instead of using load method here, we just update rows manually to avoid table reinitialization // instead of using load method here, we just update rows manually to avoid table reinitialization
const currentData = table.bootstrapTable("getData").reduce((accumulator, row) => { const currentData = table.bootstrapTable("getData").reduce((accumulator, row) => {
accumulator[row.id] = row["0"]; accumulator[row.id] = {state: row["0"], key: rowChangedKey(row)};
return accumulator; return accumulator;
}, {}); }, {});
// insert or update rows // insert or update rows, skipping ones whose status hasn't changed
rows.forEach(row => { rows.forEach(row => {
if (Object.hasOwn(currentData, row.id)) { if (Object.hasOwn(currentData, row.id)) {
row["0"] = currentData[row.id]; // copy checkbox state if (rowChangedKey(row) === currentData[row.id].key) {
return;
}
row["0"] = currentData[row.id].state; // copy checkbox state
table.bootstrapTable("updateByUniqueId", { table.bootstrapTable("updateByUniqueId", {
id: row.id, id: row.id,
row: row, row: row,

View File

@@ -81,11 +81,13 @@ class Backup(Handler):
Returns: Returns:
set[Path]: map of the filesystem paths set[Path]: map of the filesystem paths
""" """
paths = set(configuration.include.glob("*.ini")) # configuration files
root, _ = configuration.check_loaded() root, _ = configuration.check_loaded()
paths.add(root) # the configuration itself paths = set(configuration.includes)
paths.add(SQLite.database_path(configuration)) # database paths.add(root)
# database
paths.add(SQLite.database_path(configuration))
# local caches # local caches
repository_paths = configuration.repository_paths repository_paths = configuration.repository_paths

View File

@@ -47,7 +47,7 @@ class Restore(Handler):
report(bool): force enable or disable reporting report(bool): force enable or disable reporting
""" """
with tarfile.open(args.path) as archive: with tarfile.open(args.path) as archive:
archive.extractall(path=args.output) # nosec archive.extractall(path=args.output, filter="data")
@staticmethod @staticmethod
def _set_repo_restore_parser(root: SubParserAction) -> argparse.ArgumentParser: def _set_repo_restore_parser(root: SubParserAction) -> argparse.ArgumentParser:

View File

@@ -86,7 +86,7 @@ class OAuth(Mapping):
Raises: Raises:
OptionError: in case if invalid OAuth provider name supplied OptionError: in case if invalid OAuth provider name supplied
""" """
provider: type[aioauth_client.OAuth2Client] = getattr(aioauth_client, name) provider: type = getattr(aioauth_client, name, type(None))
try: try:
is_oauth2_client = issubclass(provider, aioauth_client.OAuth2Client) is_oauth2_client = issubclass(provider, aioauth_client.OAuth2Client)
except TypeError: # what if it is random string? except TypeError: # what if it is random string?

View File

@@ -296,10 +296,20 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
"empty": False, "empty": False,
"is_url": [], "is_url": [],
}, },
"max_retries": {
"type": "integer",
"coerce": "integer",
"min": 0,
},
"password": { "password": {
"type": "string", "type": "string",
"empty": False, "empty": False,
}, },
"retry_backoff": {
"type": "float",
"coerce": "float",
"min": 0,
},
"suppress_http_log_errors": { "suppress_http_log_errors": {
"type": "boolean", "type": "boolean",
"coerce": "boolean", "coerce": "boolean",

View File

@@ -76,6 +76,19 @@ class Validator(RootValidator):
converted: bool = self.configuration._convert_to_boolean(value) # type: ignore[attr-defined] converted: bool = self.configuration._convert_to_boolean(value) # type: ignore[attr-defined]
return converted return converted
def _normalize_coerce_float(self, value: str) -> float:
"""
extract float from string value
Args:
value(str): converting value
Returns:
float: value converted to float according to configuration rules
"""
del self
return float(value)
def _normalize_coerce_integer(self, value: str) -> int: def _normalize_coerce_integer(self, value: str) -> int:
""" """
extract integer from string value extract integer from string value

View File

@@ -20,10 +20,9 @@
import contextlib import contextlib
import requests import requests
from functools import cached_property from requests.adapters import BaseAdapter
from urllib.parse import urlparse from urllib.parse import urlparse
from ahriman import __version__
from ahriman.core.http.sync_http_client import SyncHttpClient from ahriman.core.http.sync_http_client import SyncHttpClient
@@ -37,32 +36,36 @@ class SyncAhrimanClient(SyncHttpClient):
address: str address: str
@cached_property def _login_url(self) -> str:
def session(self) -> requests.Session:
""" """
get or create session get url for the login api
Returns: Returns:
request.Session: created session object str: full url for web service to log in
""" """
if urlparse(self.address).scheme == "http+unix": return f"{self.address}/api/v1/login"
import requests_unixsocket
session: requests.Session = requests_unixsocket.Session() # type: ignore[no-untyped-call]
session.headers["User-Agent"] = f"ahriman/{__version__}"
return session
session = requests.Session() def adapters(self) -> dict[str, BaseAdapter]:
session.headers["User-Agent"] = f"ahriman/{__version__}"
self._login(session)
return session
def _login(self, session: requests.Session) -> None:
""" """
process login to the service get registered adapters
Returns:
dict[str, BaseAdapter]: map of protocol and adapter used for this protocol
"""
adapters = SyncHttpClient.adapters(self)
if (scheme := urlparse(self.address).scheme) == "http+unix":
from requests_unixsocket.adapters import UnixAdapter
adapters[f"{scheme}://"] = UnixAdapter() # type: ignore[no-untyped-call]
return adapters
def start(self, session: requests.Session) -> None:
"""
method which will be called on session creation
Args: Args:
session(requests.Session): request session to login session(requests.Session): created requests session
""" """
if self.auth is None: if self.auth is None:
return # no auth configured return # no auth configured
@@ -74,12 +77,3 @@ class SyncAhrimanClient(SyncHttpClient):
} }
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
self.make_request("POST", self._login_url(), json=payload, session=session) self.make_request("POST", self._login_url(), json=payload, session=session)
def _login_url(self) -> str:
"""
get url for the login api
Returns:
str: full url for web service to log in
"""
return f"{self.address}/api/v1/login"

View File

@@ -21,7 +21,9 @@ import requests
import sys import sys
from functools import cached_property from functools import cached_property
from requests.adapters import BaseAdapter, HTTPAdapter
from typing import Any, IO, Literal from typing import Any, IO, Literal
from urllib3.util.retry import Retry
from ahriman import __version__ from ahriman import __version__
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
@@ -62,6 +64,16 @@ class SyncHttpClient(LazyLogging):
self.timeout: int | None = configuration.getint(section, "timeout", fallback=30) self.timeout: int | None = configuration.getint(section, "timeout", fallback=30)
self.suppress_errors = suppress_errors self.suppress_errors = suppress_errors
retries = configuration.getint(section, "max_retries", fallback=0)
self.retry = Retry(
total=retries,
connect=retries,
read=retries,
status=retries,
status_forcelist=[429, 500, 502, 503, 504],
backoff_factor=configuration.getfloat(section, "retry_backoff", fallback=0.0),
)
@cached_property @cached_property
def session(self) -> requests.Session: def session(self) -> requests.Session:
""" """
@@ -71,11 +83,17 @@ class SyncHttpClient(LazyLogging):
request.Session: created session object request.Session: created session object
""" """
session = requests.Session() session = requests.Session()
for protocol, adapter in self.adapters().items():
session.mount(protocol, adapter)
python_version = ".".join(map(str, sys.version_info[:3])) # just major.minor.patch python_version = ".".join(map(str, sys.version_info[:3])) # just major.minor.patch
session.headers["User-Agent"] = f"ahriman/{__version__} " \ session.headers["User-Agent"] = f"ahriman/{__version__} " \
f"{requests.utils.default_user_agent()} " \ f"{requests.utils.default_user_agent()} " \
f"python/{python_version}" f"python/{python_version}"
self.start(session)
return session return session
@staticmethod @staticmethod
@@ -92,6 +110,19 @@ class SyncHttpClient(LazyLogging):
result: str = exception.response.text if exception.response is not None else "" result: str = exception.response.text if exception.response is not None else ""
return result return result
def adapters(self) -> dict[str, BaseAdapter]:
"""
get registered adapters
Returns:
dict[str, BaseAdapter]: map of protocol and adapter used for this protocol
"""
adapter = HTTPAdapter(max_retries=self.retry)
return {
"http://": adapter,
"https://": adapter,
}
def make_request(self, method: Literal["DELETE", "GET", "HEAD", "POST", "PUT"], url: str, *, def make_request(self, method: Literal["DELETE", "GET", "HEAD", "POST", "PUT"], url: str, *,
headers: dict[str, str] | None = None, headers: dict[str, str] | None = None,
params: list[tuple[str, str]] | None = None, params: list[tuple[str, str]] | None = None,
@@ -139,3 +170,11 @@ class SyncHttpClient(LazyLogging):
if not suppress_errors: if not suppress_errors:
self.logger.exception("could not perform http request") self.logger.exception("could not perform http request")
raise raise
def start(self, session: requests.Session) -> None:
"""
method which will be called on session creation
Args:
session(requests.Session): created requests session
"""

View File

@@ -74,6 +74,18 @@ class Email(Report, JinjaTemplate):
self.ssl = SmtpSSLSettings.from_option(configuration.get(section, "ssl", fallback="disabled")) self.ssl = SmtpSSLSettings.from_option(configuration.get(section, "ssl", fallback="disabled"))
self.user = configuration.get(section, "user", fallback=None) self.user = configuration.get(section, "user", fallback=None)
@property
def _smtp_session(self) -> type[smtplib.SMTP]:
"""
build SMTP session based on configuration settings
Returns:
type[smtplib.SMTP]: SMTP or SMTP_SSL session depending on whether SSL is enabled or not
"""
if self.ssl == SmtpSSLSettings.SSL:
return smtplib.SMTP_SSL
return smtplib.SMTP
def _send(self, text: str, attachment: dict[str, str]) -> None: def _send(self, text: str, attachment: dict[str, str]) -> None:
""" """
send email callback send email callback
@@ -93,16 +105,13 @@ class Email(Report, JinjaTemplate):
attach.add_header("Content-Disposition", "attachment", filename=filename) attach.add_header("Content-Disposition", "attachment", filename=filename)
message.attach(attach) message.attach(attach)
if self.ssl != SmtpSSLSettings.SSL: with self._smtp_session(self.host, self.port) as session:
session = smtplib.SMTP(self.host, self.port)
if self.ssl == SmtpSSLSettings.STARTTLS: if self.ssl == SmtpSSLSettings.STARTTLS:
session.starttls() session.starttls()
else:
session = smtplib.SMTP_SSL(self.host, self.port) if self.user is not None and self.password is not None:
if self.user is not None and self.password is not None: session.login(self.user, self.password)
session.login(self.user, self.password) session.sendmail(self.sender, self.receivers, message.as_string())
session.sendmail(self.sender, self.receivers, message.as_string())
session.quit()
def generate(self, packages: list[Package], result: Result) -> None: def generate(self, packages: list[Package], result: Result) -> None:
""" """

View File

@@ -302,6 +302,16 @@ class ReportTrigger(Trigger):
"empty": False, "empty": False,
"is_url": [], "is_url": [],
}, },
"max_retries": {
"type": "integer",
"coerce": "integer",
"min": 0,
},
"retry_backoff": {
"type": "float",
"coerce": "float",
"min": 0,
},
"rss_url": { "rss_url": {
"type": "string", "type": "string",
"empty": False, "empty": False,

View File

@@ -54,6 +54,11 @@ class UploadTrigger(Trigger):
"type": "string", "type": "string",
"allowed": ["github"], "allowed": ["github"],
}, },
"max_retries": {
"type": "integer",
"coerce": "integer",
"min": 0,
},
"owner": { "owner": {
"type": "string", "type": "string",
"required": True, "required": True,
@@ -68,6 +73,11 @@ class UploadTrigger(Trigger):
"required": True, "required": True,
"empty": False, "empty": False,
}, },
"retry_backoff": {
"type": "float",
"coerce": "float",
"min": 0,
},
"timeout": { "timeout": {
"type": "integer", "type": "integer",
"coerce": "integer", "coerce": "integer",
@@ -90,6 +100,16 @@ class UploadTrigger(Trigger):
"type": "string", "type": "string",
"allowed": ["ahriman", "remote-service"], "allowed": ["ahriman", "remote-service"],
}, },
"max_retries": {
"type": "integer",
"coerce": "integer",
"min": 0,
},
"retry_backoff": {
"type": "float",
"coerce": "float",
"min": 0,
},
"timeout": { "timeout": {
"type": "integer", "type": "integer",
"coerce": "integer", "coerce": "integer",

View File

@@ -164,6 +164,11 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st
if key in ("PATH",) # whitelisted variables only if key in ("PATH",) # whitelisted variables only
} | environment } | environment
result: dict[str, list[str]] = {
"stdout": [],
"stderr": [],
}
with subprocess.Popen(args, cwd=cwd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, with subprocess.Popen(args, cwd=cwd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
user=user, env=full_environment, text=True, encoding="utf8", errors="backslashreplace", user=user, env=full_environment, text=True, encoding="utf8", errors="backslashreplace",
bufsize=1) as process: bufsize=1) as process:
@@ -172,30 +177,27 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st
input_channel.write(input_data) input_channel.write(input_data)
input_channel.close() input_channel.close()
selector = selectors.DefaultSelector() with selectors.DefaultSelector() as selector:
selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout") selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout")
selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr") selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr")
result: dict[str, list[str]] = { while selector.get_map(): # while there are unread selectors, keep reading
"stdout": [], for key_data, output in poll(selector):
"stderr": [], result[key_data].append(output)
}
while selector.get_map(): # while there are unread selectors, keep reading
for key_data, output in poll(selector):
result[key_data].append(output)
stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any
stderr = "\n".join(result["stderr"]).rstrip("\n")
status_code = process.wait() status_code = process.wait()
if status_code != 0:
if isinstance(exception, Exception):
raise exception
if callable(exception):
raise exception(status_code, list(args), stdout, stderr)
raise CalledProcessError(status_code, list(args), stderr)
return stdout stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any
stderr = "\n".join(result["stderr"]).rstrip("\n")
if status_code != 0:
if isinstance(exception, Exception):
raise exception
if callable(exception):
raise exception(status_code, list(args), stdout, stderr)
raise CalledProcessError(status_code, list(args), stderr)
return stdout
def check_user(root: Path, *, unsafe: bool) -> None: def check_user(root: Path, *, unsafe: bool) -> None:

View File

@@ -72,7 +72,7 @@ def _security() -> list[dict[str, Any]]:
return [{ return [{
"token": { "token": {
"type": "apiKey", # as per specification we are using api key "type": "apiKey", # as per specification we are using api key
"name": "API_SESSION", "name": "AHRIMAN",
"in": "cookie", "in": "cookie",
} }
}] }]

View File

@@ -149,11 +149,17 @@ def setup_auth(application: Application, configuration: Configuration, validator
Application: configured web application Application: configured web application
""" """
secret_key = _cookie_secret_key(configuration) secret_key = _cookie_secret_key(configuration)
storage = EncryptedCookieStorage(secret_key, cookie_name="API_SESSION", max_age=validator.max_age) storage = EncryptedCookieStorage(
secret_key,
cookie_name="AHRIMAN",
max_age=validator.max_age,
httponly=True,
samesite="Strict",
)
setup_session(application, storage) setup_session(application, storage)
authorization_policy = _AuthorizationPolicy(validator) authorization_policy = _AuthorizationPolicy(validator)
identity_policy = aiohttp_security.SessionIdentityPolicy() identity_policy = aiohttp_security.SessionIdentityPolicy("SESSION")
aiohttp_security.setup(application, identity_policy, authorization_policy) aiohttp_security.setup(application, identity_policy, authorization_policy)
application.middlewares.append(_auth_handler(validator.allow_read_only)) application.middlewares.append(_auth_handler(validator.allow_read_only))

View File

@@ -25,6 +25,6 @@ class AuthSchema(Schema):
request cookie authorization schema request cookie authorization schema
""" """
API_SESSION = fields.String(required=True, metadata={ AHRIMAN = fields.String(required=True, metadata={
"description": "API session key as returned from authorization", "description": "API session key as returned from authorization",
}) })

View File

@@ -34,7 +34,7 @@ def test_run(args: argparse.Namespace, configuration: Configuration, mocker: Moc
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
Restore.run(args, repository_id, configuration, report=False) Restore.run(args, repository_id, configuration, report=False)
extract_mock.extractall.assert_called_once_with(path=args.output) extract_mock.extractall.assert_called_once_with(path=args.output, filter="data")
def test_disallow_multi_architecture_run() -> None: def test_disallow_multi_architecture_run() -> None:

View File

@@ -33,6 +33,14 @@ def test_normalize_coerce_boolean(validator: Validator, mocker: MockerFixture) -
convert_mock.assert_called_once_with("1") convert_mock.assert_called_once_with("1")
def test_normalize_coerce_float(validator: Validator) -> None:
"""
must convert string value to float by using configuration converters
"""
assert validator._normalize_coerce_float("1.5") == 1.5
assert validator._normalize_coerce_float("0.0") == 0.0
def test_normalize_coerce_integer(validator: Validator) -> None: def test_normalize_coerce_integer(validator: Validator) -> None:
""" """
must convert string value to integer by using configuration converters must convert string value to integer by using configuration converters

View File

@@ -1,6 +1,5 @@
import pytest import pytest
import requests import requests
import requests_unixsocket
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@@ -8,31 +7,32 @@ from ahriman.core.http import SyncAhrimanClient
from ahriman.models.user import User from ahriman.models.user import User
def test_session(ahriman_client: SyncAhrimanClient, mocker: MockerFixture) -> None: def test_adapters(ahriman_client: SyncAhrimanClient) -> None:
""" """
must create normal requests session must return native adapters
""" """
login_mock = mocker.patch("ahriman.core.http.SyncAhrimanClient._login") assert "http+unix://" not in ahriman_client.adapters()
assert isinstance(ahriman_client.session, requests.Session)
assert not isinstance(ahriman_client.session, requests_unixsocket.Session)
login_mock.assert_called_once_with(pytest.helpers.anyvar(int))
def test_session_unix_socket(ahriman_client: SyncAhrimanClient, mocker: MockerFixture) -> None: def test_adapters_unix_socket(ahriman_client: SyncAhrimanClient) -> None:
""" """
must create unix socket session must register unix socket adapter
""" """
login_mock = mocker.patch("ahriman.core.http.SyncAhrimanClient._login")
ahriman_client.address = "http+unix://path" ahriman_client.address = "http+unix://path"
assert "http+unix://" in ahriman_client.adapters()
assert isinstance(ahriman_client.session, requests_unixsocket.Session)
login_mock.assert_not_called()
def test_login(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None: def test_login_url(ahriman_client: SyncAhrimanClient) -> None:
""" """
must login user must generate login url correctly
"""
assert ahriman_client._login_url().startswith(ahriman_client.address)
assert ahriman_client._login_url().endswith("/api/v1/login")
def test_start(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None:
"""
must log in user on start
""" """
ahriman_client.auth = (user.username, user.password) ahriman_client.auth = (user.username, user.password)
requests_mock = mocker.patch("ahriman.core.http.SyncAhrimanClient.make_request") requests_mock = mocker.patch("ahriman.core.http.SyncAhrimanClient.make_request")
@@ -42,40 +42,32 @@ def test_login(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixt
} }
session = requests.Session() session = requests.Session()
ahriman_client._login(session) ahriman_client.start(session)
requests_mock.assert_called_once_with("POST", pytest.helpers.anyvar(str, True), json=payload, session=session) requests_mock.assert_called_once_with("POST", pytest.helpers.anyvar(str, True), json=payload, session=session)
def test_login_failed(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None: def test_start_failed(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None:
""" """
must suppress any exception happened during login must suppress any exception happened during session start
""" """
ahriman_client.user = user ahriman_client.user = user
mocker.patch("requests.Session.request", side_effect=Exception) mocker.patch("requests.Session.request", side_effect=Exception)
ahriman_client._login(requests.Session()) ahriman_client.start(requests.Session())
def test_login_failed_http_error(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None: def test_start_failed_http_error(ahriman_client: SyncAhrimanClient, user: User, mocker: MockerFixture) -> None:
""" """
must suppress HTTP exception happened during login must suppress HTTP exception happened during session start
""" """
ahriman_client.user = user ahriman_client.user = user
mocker.patch("requests.Session.request", side_effect=requests.HTTPError) mocker.patch("requests.Session.request", side_effect=requests.HTTPError)
ahriman_client._login(requests.Session()) ahriman_client.start(requests.Session())
def test_login_skip(ahriman_client: SyncAhrimanClient, mocker: MockerFixture) -> None: def test_start_skip(ahriman_client: SyncAhrimanClient, mocker: MockerFixture) -> None:
""" """
must skip login if no user set must skip login if no user set
""" """
requests_mock = mocker.patch("requests.Session.request") requests_mock = mocker.patch("requests.Session.request")
ahriman_client._login(requests.Session()) ahriman_client.start(requests.Session())
requests_mock.assert_not_called() requests_mock.assert_not_called()
def test_login_url(ahriman_client: SyncAhrimanClient) -> None:
"""
must generate login url correctly
"""
assert ahriman_client._login_url().startswith(ahriman_client.address)
assert ahriman_client._login_url().endswith("/api/v1/login")

View File

@@ -33,12 +33,15 @@ def test_init_auth_empty() -> None:
assert SyncHttpClient().auth is None assert SyncHttpClient().auth is None
def test_session() -> None: def test_session(mocker: MockerFixture) -> None:
""" """
must generate valid session must generate valid session
""" """
start_mock = mocker.patch("ahriman.core.http.sync_http_client.SyncHttpClient.start")
session = SyncHttpClient().session session = SyncHttpClient().session
assert "User-Agent" in session.headers assert "User-Agent" in session.headers
start_mock.assert_called_once_with(pytest.helpers.anyvar(int))
def test_exception_response_text() -> None: def test_exception_response_text() -> None:
@@ -60,6 +63,18 @@ def test_exception_response_text_empty() -> None:
assert SyncHttpClient.exception_response_text(exception) == "" assert SyncHttpClient.exception_response_text(exception) == ""
def test_adapters() -> None:
"""
must create adapters with retry policy
"""
client = SyncHttpClient()
adapers = client.adapters()
assert "http://" in adapers
assert "https://" in adapers
assert all(adapter.max_retries == client.retry for adapter in adapers.values())
def test_make_request(mocker: MockerFixture) -> None: def test_make_request(mocker: MockerFixture) -> None:
""" """
must make HTTP request must make HTTP request
@@ -158,3 +173,11 @@ def test_make_request_session() -> None:
session_mock.request.assert_called_once_with( session_mock.request.assert_called_once_with(
"GET", "url", params=None, data=None, headers=None, files=None, json=None, "GET", "url", params=None, data=None, headers=None, files=None, json=None,
stream=None, auth=None, timeout=client.timeout) stream=None, auth=None, timeout=client.timeout)
def test_start() -> None:
"""
must do nothing on start
"""
client = SyncHttpClient()
client.start(client.session)

View File

@@ -1,3 +1,5 @@
import smtplib
import pytest import pytest
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
@@ -6,6 +8,7 @@ from ahriman.core.configuration import Configuration
from ahriman.core.report.email import Email from ahriman.core.report.email import Email
from ahriman.models.package import Package from ahriman.models.package import Package
from ahriman.models.result import Result from ahriman.models.result import Result
from ahriman.models.smtp_ssl_settings import SmtpSSLSettings
def test_template(configuration: Configuration) -> None: def test_template(configuration: Configuration) -> None:
@@ -37,17 +40,36 @@ def test_template_full(configuration: Configuration) -> None:
assert Email(repository_id, configuration, "email").template_full == root.parent / template assert Email(repository_id, configuration, "email").template_full == root.parent / template
def test_smtp_session(email: Email) -> None:
"""
must build normal SMTP session if SSL is disabled
"""
email.ssl = SmtpSSLSettings.Disabled
assert email._smtp_session == smtplib.SMTP
email.ssl = SmtpSSLSettings.STARTTLS
assert email._smtp_session == smtplib.SMTP
def test_smtp_session_ssl(email: Email) -> None:
"""
must build SMTP_SSL session if SSL is enabled
"""
email.ssl = SmtpSSLSettings.SSL
assert email._smtp_session == smtplib.SMTP_SSL
def test_send(email: Email, mocker: MockerFixture) -> None: def test_send(email: Email, mocker: MockerFixture) -> None:
""" """
must send an email with attachment must send an email with attachment
""" """
smtp_mock = mocker.patch("smtplib.SMTP") smtp_mock = mocker.patch("smtplib.SMTP")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
email._send("a text", {"attachment.html": "an attachment"}) email._send("a text", {"attachment.html": "an attachment"})
smtp_mock.return_value.starttls.assert_not_called() smtp_mock.return_value.starttls.assert_not_called()
smtp_mock.return_value.login.assert_not_called() smtp_mock.return_value.login.assert_not_called()
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int)) smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
smtp_mock.return_value.quit.assert_called_once_with()
def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None: def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None:
@@ -57,6 +79,7 @@ def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None:
configuration.set_option("email", "user", "username") configuration.set_option("email", "user", "username")
configuration.set_option("email", "password", "password") configuration.set_option("email", "password", "password")
smtp_mock = mocker.patch("smtplib.SMTP") smtp_mock = mocker.patch("smtplib.SMTP")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
email = Email(repository_id, configuration, "email") email = Email(repository_id, configuration, "email")
@@ -70,6 +93,7 @@ def test_send_auth_no_password(configuration: Configuration, mocker: MockerFixtu
""" """
configuration.set_option("email", "user", "username") configuration.set_option("email", "user", "username")
smtp_mock = mocker.patch("smtplib.SMTP") smtp_mock = mocker.patch("smtplib.SMTP")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
email = Email(repository_id, configuration, "email") email = Email(repository_id, configuration, "email")
@@ -83,6 +107,7 @@ def test_send_auth_no_user(configuration: Configuration, mocker: MockerFixture)
""" """
configuration.set_option("email", "password", "password") configuration.set_option("email", "password", "password")
smtp_mock = mocker.patch("smtplib.SMTP") smtp_mock = mocker.patch("smtplib.SMTP")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
email = Email(repository_id, configuration, "email") email = Email(repository_id, configuration, "email")
@@ -96,6 +121,7 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No
""" """
configuration.set_option("email", "ssl", "ssl") configuration.set_option("email", "ssl", "ssl")
smtp_mock = mocker.patch("smtplib.SMTP_SSL") smtp_mock = mocker.patch("smtplib.SMTP_SSL")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
email = Email(repository_id, configuration, "email") email = Email(repository_id, configuration, "email")
@@ -103,7 +129,6 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No
smtp_mock.return_value.starttls.assert_not_called() smtp_mock.return_value.starttls.assert_not_called()
smtp_mock.return_value.login.assert_not_called() smtp_mock.return_value.login.assert_not_called()
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int)) smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
smtp_mock.return_value.quit.assert_called_once_with()
def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> None: def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> None:
@@ -112,6 +137,7 @@ def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> N
""" """
configuration.set_option("email", "ssl", "starttls") configuration.set_option("email", "ssl", "starttls")
smtp_mock = mocker.patch("smtplib.SMTP") smtp_mock = mocker.patch("smtplib.SMTP")
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
email = Email(repository_id, configuration, "email") email = Email(repository_id, configuration, "email")

View File

@@ -23,7 +23,7 @@ def test_security() -> None:
must generate security definitions for swagger must generate security definitions for swagger
""" """
token = next(iter(_security()))["token"] token = next(iter(_security()))["token"]
assert token == {"type": "apiKey", "name": "API_SESSION", "in": "cookie"} assert token == {"type": "apiKey", "name": "AHRIMAN", "in": "cookie"}
def test_servers(application: Application) -> None: def test_servers(application: Application) -> None:

View File

@@ -6,4 +6,4 @@ def test_schema() -> None:
must return valid schema must return valid schema
""" """
schema = AuthSchema() schema = AuthSchema()
assert not schema.validate({"API_SESSION": "key"}) assert not schema.validate({"AHRIMAN": "key"})

View File

@@ -27,7 +27,7 @@ def _client(client: TestClient, mocker: MockerFixture) -> TestClient:
"parameters": [ "parameters": [
{ {
"in": "cookie", "in": "cookie",
"name": "API_SESSION", "name": "AHRIMAN",
"schema": { "schema": {
"type": "string", "type": "string",
}, },
@@ -39,7 +39,7 @@ def _client(client: TestClient, mocker: MockerFixture) -> TestClient:
"parameters": [ "parameters": [
{ {
"in": "cookie", "in": "cookie",
"name": "API_SESSION", "name": "AHRIMAN",
"schema": { "schema": {
"type": "string", "type": "string",
}, },
@@ -60,7 +60,7 @@ def _client(client: TestClient, mocker: MockerFixture) -> TestClient:
{ {
"token": { "token": {
"type": "apiKey", "type": "apiKey",
"name": "API_SESSION", "name": "AHRIMAN",
"in": "cookie", "in": "cookie",
}, },
}, },