refactor: use AppKey's instead of string identifiers for web application

This commit is contained in:
Evgenii Alekseev 2023-12-27 13:41:07 +02:00
parent e4b22fd620
commit ee3ccf70ac
13 changed files with 89 additions and 37 deletions

View File

@ -30,6 +30,14 @@ ahriman.web.cors module
:no-undoc-members: :no-undoc-members:
:show-inheritance: :show-inheritance:
ahriman.web.keys module
-----------------------
.. automodule:: ahriman.web.keys
:members:
:no-undoc-members:
:show-inheritance:
ahriman.web.routes module ahriman.web.routes module
------------------------- -------------------------

View File

@ -23,7 +23,7 @@ from aiohttp.web import Application
from typing import Any from typing import Any
from ahriman import __version__ from ahriman import __version__
from ahriman.core.configuration import Configuration from ahriman.web.keys import ConfigurationKey
__all__ = ["setup_apispec"] __all__ = ["setup_apispec"]
@ -89,7 +89,7 @@ def _servers(application: Application) -> list[dict[str, Any]]:
Returns: Returns:
list[dict[str, Any]]: list (actually only one) of defined web urls list[dict[str, Any]]: list (actually only one) of defined web urls
""" """
configuration: Configuration = application["configuration"] configuration = application[ConfigurationKey]
address = configuration.get("web", "address", fallback=None) address = configuration.get("web", "address", fallback=None)
if not address: if not address:
host = configuration.get("web", "host") host = configuration.get("web", "host")

40
src/ahriman/web/keys.py Normal file
View File

@ -0,0 +1,40 @@
#
# Copyright (c) 2021-2023 ahriman team.
#
# This file is part of ahriman
# (see https://github.com/arcan1s/ahriman).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from aiohttp.web import AppKey
from ahriman.core.auth import Auth
from ahriman.core.configuration import Configuration
from ahriman.core.spawn import Spawn
from ahriman.core.status.watcher import Watcher
from ahriman.models.repository_id import RepositoryId
__all__ = [
"AuthKey",
"ConfigurationKey",
"SpawnKey",
"WatcherKey",
]
AuthKey = AppKey("validator", Auth)
ConfigurationKey = AppKey("configuration", Configuration)
SpawnKey = AppKey("spawn", Spawn)
WatcherKey = AppKey("watcher", dict[RepositoryId, Watcher])

View File

@ -154,7 +154,7 @@ def setup_auth(application: Application, configuration: Configuration, validator
setup_session(application, storage) setup_session(application, storage)
authorization_policy = _AuthorizationPolicy(validator) authorization_policy = _AuthorizationPolicy(validator)
identity_policy = application["identity"] = aiohttp_security.SessionIdentityPolicy() identity_policy = aiohttp_security.SessionIdentityPolicy()
aiohttp_security.setup(application, identity_policy, authorization_policy) aiohttp_security.setup(application, identity_policy, authorization_policy)
application.middlewares.append(_auth_handler(validator.allow_read_only)) application.middlewares.append(_auth_handler(validator.allow_read_only))

View File

@ -29,6 +29,7 @@ from ahriman.core.spawn import Spawn
from ahriman.core.status.watcher import Watcher from ahriman.core.status.watcher import Watcher
from ahriman.models.repository_id import RepositoryId from ahriman.models.repository_id import RepositoryId
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.keys import AuthKey, ConfigurationKey, SpawnKey, WatcherKey
T = TypeVar("T", str, list[str]) T = TypeVar("T", str, list[str])
@ -54,8 +55,7 @@ class BaseView(View, CorsViewMixin):
Returns: Returns:
Configuration: configuration instance Configuration: configuration instance
""" """
configuration: Configuration = self.request.app["configuration"] return self.request.app[ConfigurationKey]
return configuration
@property @property
def services(self) -> dict[RepositoryId, Watcher]: def services(self) -> dict[RepositoryId, Watcher]:
@ -65,8 +65,7 @@ class BaseView(View, CorsViewMixin):
Returns: Returns:
dict[RepositoryId, Watcher]: map of loaded watchers per known repository dict[RepositoryId, Watcher]: map of loaded watchers per known repository
""" """
watchers: dict[RepositoryId, Watcher] = self.request.app["watcher"] return self.request.app[WatcherKey]
return watchers
@property @property
def sign(self) -> GPG: def sign(self) -> GPG:
@ -86,8 +85,7 @@ class BaseView(View, CorsViewMixin):
Returns: Returns:
Spawn: external process spawner instance Spawn: external process spawner instance
""" """
spawner: Spawn = self.request.app["spawn"] return self.request.app[SpawnKey]
return spawner
@property @property
def validator(self) -> Auth: def validator(self) -> Auth:
@ -97,8 +95,7 @@ class BaseView(View, CorsViewMixin):
Returns: Returns:
Auth: authorization service instance Auth: authorization service instance
""" """
validator: Auth = self.request.app["validator"] return self.request.app[AuthKey]
return validator
@classmethod @classmethod
async def get_permission(cls, request: Request) -> UserAccess: async def get_permission(cls, request: Request) -> UserAccess:

View File

@ -34,6 +34,7 @@ from ahriman.core.status.watcher import Watcher
from ahriman.models.repository_id import RepositoryId from ahriman.models.repository_id import RepositoryId
from ahriman.web.apispec import setup_apispec from ahriman.web.apispec import setup_apispec
from ahriman.web.cors import setup_cors from ahriman.web.cors import setup_cors
from ahriman.web.keys import AuthKey, ConfigurationKey, SpawnKey, WatcherKey
from ahriman.web.middlewares.exception_handler import exception_handler from ahriman.web.middlewares.exception_handler import exception_handler
from ahriman.web.routes import setup_routes from ahriman.web.routes import setup_routes
@ -97,7 +98,7 @@ async def _on_startup(application: Application) -> None:
application.logger.info("server started") application.logger.info("server started")
try: try:
for watcher in application["watcher"].values(): for watcher in application[WatcherKey].values():
watcher.load() watcher.load()
except Exception: except Exception:
message = "could not load packages" message = "could not load packages"
@ -114,7 +115,7 @@ def run_server(application: Application) -> None:
""" """
application.logger.info("start server") application.logger.info("start server")
configuration: Configuration = application["configuration"] configuration = application[ConfigurationKey]
host = configuration.get("web", "host") host = configuration.get("web", "host")
port = configuration.getint("web", "port") port = configuration.getint("web", "port")
unix_socket = _create_socket(configuration, application) unix_socket = _create_socket(configuration, application)
@ -156,7 +157,7 @@ def setup_server(configuration: Configuration, spawner: Spawn, repositories: lis
aiohttp_jinja2.setup(application, trim_blocks=True, lstrip_blocks=True, autoescape=True, loader=loader) aiohttp_jinja2.setup(application, trim_blocks=True, lstrip_blocks=True, autoescape=True, loader=loader)
application.logger.info("setup configuration") application.logger.info("setup configuration")
application["configuration"] = configuration application[ConfigurationKey] = configuration
application.logger.info("setup watchers") application.logger.info("setup watchers")
if not repositories: if not repositories:
@ -166,13 +167,13 @@ def setup_server(configuration: Configuration, spawner: Spawn, repositories: lis
for repository_id in repositories: for repository_id in repositories:
application.logger.info("load repository %s", repository_id) application.logger.info("load repository %s", repository_id)
watchers[repository_id] = Watcher(repository_id, database) watchers[repository_id] = Watcher(repository_id, database)
application["watcher"] = watchers application[WatcherKey] = watchers
application.logger.info("setup process spawner") application.logger.info("setup process spawner")
application["spawn"] = spawner application[SpawnKey] = spawner
application.logger.info("setup authorization") application.logger.info("setup authorization")
validator = application["validator"] = Auth.load(configuration, database) validator = application[AuthKey] = Auth.load(configuration, database)
if validator.enabled: if validator.enabled:
from ahriman.web.middlewares.auth_handler import setup_auth from ahriman.web.middlewares.auth_handler import setup_auth
setup_auth(application, configuration, validator) setup_auth(application, configuration, validator)

View File

@ -16,6 +16,7 @@ from ahriman.core.configuration import Configuration
from ahriman.core.database import SQLite from ahriman.core.database import SQLite
from ahriman.core.spawn import Spawn from ahriman.core.spawn import Spawn
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.web.keys import AuthKey
from ahriman.web.web import setup_server from ahriman.web.web import setup_server
@ -159,7 +160,7 @@ def application_with_auth(configuration: Configuration, user: User, spawner: Spa
_, repository_id = configuration.check_loaded() _, repository_id = configuration.check_loaded()
application = setup_server(configuration, spawner, [repository_id]) application = setup_server(configuration, spawner, [repository_id])
generated = user.hash_password(application["validator"].salt) generated = user.hash_password(application[AuthKey].salt)
mocker.patch("ahriman.core.database.SQLite.user_get", return_value=generated) mocker.patch("ahriman.core.database.SQLite.user_get", return_value=generated)
return application return application
@ -245,5 +246,5 @@ def client_with_oauth_auth(application_with_auth: Application, event_loop: BaseE
TestClient: web client test instance TestClient: web client test instance
""" """
mocker.patch("pathlib.Path.iterdir", return_value=[]) mocker.patch("pathlib.Path.iterdir", return_value=[])
application_with_auth["validator"] = MagicMock(spec=OAuth) application_with_auth[AuthKey] = MagicMock(spec=OAuth)
return event_loop.run_until_complete(aiohttp_client(application_with_auth)) return event_loop.run_until_complete(aiohttp_client(application_with_auth))

View File

@ -12,6 +12,7 @@ from ahriman.core.configuration import Configuration
from ahriman.models.build_status import BuildStatusEnum from ahriman.models.build_status import BuildStatusEnum
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.keys import AuthKey
from ahriman.web.middlewares.auth_handler import _AuthorizationPolicy, _auth_handler, _cookie_secret_key, setup_auth from ahriman.web.middlewares.auth_handler import _AuthorizationPolicy, _auth_handler, _cookie_secret_key, setup_auth
@ -192,5 +193,5 @@ def test_setup_auth(application_with_auth: Application, configuration: Configura
""" """
setup_mock = mocker.patch("aiohttp_security.setup") setup_mock = mocker.patch("aiohttp_security.setup")
application = setup_auth(application_with_auth, configuration, auth) application = setup_auth(application_with_auth, configuration, auth)
assert application.get("validator") is not None assert application.get(AuthKey) is not None
setup_mock.assert_called_once_with(application_with_auth, pytest.helpers.anyvar(int), pytest.helpers.anyvar(int)) setup_mock.assert_called_once_with(application_with_auth, pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))

View File

@ -5,6 +5,7 @@ from pytest_mock import MockerFixture
from ahriman import __version__ from ahriman import __version__
from ahriman.web.apispec import _info, _security, _servers, setup_apispec from ahriman.web.apispec import _info, _security, _servers, setup_apispec
from ahriman.web.keys import ConfigurationKey
def test_info() -> None: def test_info() -> None:
@ -36,7 +37,7 @@ def test_servers_address(application: Application) -> None:
""" """
must generate servers definitions with address must generate servers definitions with address
""" """
application["configuration"].set_option("web", "address", "https://example.com") application[ConfigurationKey].set_option("web", "address", "https://example.com")
servers = _servers(application) servers = _servers(application)
assert servers == [{"url": "https://example.com"}] assert servers == [{"url": "https://example.com"}]

View File

View File

@ -10,6 +10,7 @@ from ahriman.core.exceptions import InitializeError
from ahriman.core.log.filtered_access_logger import FilteredAccessLogger from ahriman.core.log.filtered_access_logger import FilteredAccessLogger
from ahriman.core.spawn import Spawn from ahriman.core.spawn import Spawn
from ahriman.core.status.watcher import Watcher from ahriman.core.status.watcher import Watcher
from ahriman.web.keys import ConfigurationKey
from ahriman.web.web import _create_socket, _on_shutdown, _on_startup, run_server, setup_server from ahriman.web.web import _create_socket, _on_shutdown, _on_startup, run_server, setup_server
@ -18,14 +19,14 @@ async def test_create_socket(application: Application, mocker: MockerFixture) ->
must create socket must create socket
""" """
path = "/run/ahriman.sock" path = "/run/ahriman.sock"
application["configuration"].set_option("web", "unix_socket", str(path)) application[ConfigurationKey].set_option("web", "unix_socket", str(path))
current_on_shutdown = len(application.on_shutdown) current_on_shutdown = len(application.on_shutdown)
bind_mock = mocker.patch("socket.socket.bind") bind_mock = mocker.patch("socket.socket.bind")
chmod_mock = mocker.patch("pathlib.Path.chmod") chmod_mock = mocker.patch("pathlib.Path.chmod")
unlink_mock = mocker.patch("pathlib.Path.unlink") unlink_mock = mocker.patch("pathlib.Path.unlink")
sock = _create_socket(application["configuration"], application) sock = _create_socket(application[ConfigurationKey], application)
assert sock.family == socket.AF_UNIX assert sock.family == socket.AF_UNIX
assert sock.type == socket.SOCK_STREAM assert sock.type == socket.SOCK_STREAM
bind_mock.assert_called_once_with(str(path)) bind_mock.assert_called_once_with(str(path))
@ -41,7 +42,7 @@ def test_create_socket_empty(application: Application) -> None:
""" """
must skip socket creation if not set by configuration must skip socket creation if not set by configuration
""" """
assert _create_socket(application["configuration"], application) is None assert _create_socket(application[ConfigurationKey], application) is None
def test_create_socket_safe(application: Application, mocker: MockerFixture) -> None: def test_create_socket_safe(application: Application, mocker: MockerFixture) -> None:
@ -49,14 +50,14 @@ def test_create_socket_safe(application: Application, mocker: MockerFixture) ->
must create socket with default permission set must create socket with default permission set
""" """
path = "/run/ahriman.sock" path = "/run/ahriman.sock"
application["configuration"].set_option("web", "unix_socket", str(path)) application[ConfigurationKey].set_option("web", "unix_socket", str(path))
application["configuration"].set_option("web", "unix_socket_unsafe", "no") application[ConfigurationKey].set_option("web", "unix_socket_unsafe", "no")
mocker.patch("socket.socket.bind") mocker.patch("socket.socket.bind")
mocker.patch("pathlib.Path.unlink") mocker.patch("pathlib.Path.unlink")
chmod_mock = mocker.patch("pathlib.Path.chmod") chmod_mock = mocker.patch("pathlib.Path.chmod")
sock = _create_socket(application["configuration"], application) sock = _create_socket(application[ConfigurationKey], application)
assert sock is not None assert sock is not None
chmod_mock.assert_not_called() chmod_mock.assert_not_called()
@ -97,7 +98,7 @@ def test_run(application: Application, mocker: MockerFixture) -> None:
must run application must run application
""" """
port = 8080 port = 8080
application["configuration"].set_option("web", "port", str(port)) application[ConfigurationKey].set_option("web", "port", str(port))
run_application_mock = mocker.patch("ahriman.web.web.run_app") run_application_mock = mocker.patch("ahriman.web.web.run_app")
run_server(application) run_server(application)
@ -112,7 +113,7 @@ def test_run_with_auth(application_with_auth: Application, mocker: MockerFixture
must run application with enabled authorization must run application with enabled authorization
""" """
port = 8080 port = 8080
application_with_auth["configuration"].set_option("web", "port", str(port)) application_with_auth[ConfigurationKey].set_option("web", "port", str(port))
run_application_mock = mocker.patch("ahriman.web.web.run_app") run_application_mock = mocker.patch("ahriman.web.web.run_app")
run_server(application_with_auth) run_server(application_with_auth)
@ -127,12 +128,12 @@ def test_run_with_socket(application: Application, mocker: MockerFixture) -> Non
must run application must run application
""" """
port = 8080 port = 8080
application["configuration"].set_option("web", "port", str(port)) application[ConfigurationKey].set_option("web", "port", str(port))
socket_mock = mocker.patch("ahriman.web.web._create_socket", return_value=42) socket_mock = mocker.patch("ahriman.web.web._create_socket", return_value=42)
run_application_mock = mocker.patch("ahriman.web.web.run_app") run_application_mock = mocker.patch("ahriman.web.web.run_app")
run_server(application) run_server(application)
socket_mock.assert_called_once_with(application["configuration"], application) socket_mock.assert_called_once_with(application[ConfigurationKey], application)
run_application_mock.assert_called_once_with( run_application_mock.assert_called_once_with(
application, host="127.0.0.1", port=port, sock=42, handle_signals=True, application, host="127.0.0.1", port=port, sock=42, handle_signals=True,
access_log=pytest.helpers.anyvar(int), access_log_class=FilteredAccessLogger access_log=pytest.helpers.anyvar(int), access_log_class=FilteredAccessLogger

View File

@ -9,6 +9,7 @@ from unittest.mock import AsyncMock
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.models.repository_id import RepositoryId from ahriman.models.repository_id import RepositoryId
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.keys import WatcherKey
from ahriman.web.views.base import BaseView from ahriman.web.views.base import BaseView
@ -172,9 +173,9 @@ def test_service(base: BaseView) -> None:
must return service for repository must return service for repository
""" """
repository_id = RepositoryId("i686", "repo") repository_id = RepositoryId("i686", "repo")
base.request.app["watcher"] = { base.request.app[WatcherKey] = {
repository_id: watcher repository_id: watcher
for watcher in base.request.app["watcher"].values() for watcher in base.request.app[WatcherKey].values()
} }
assert base.service(repository_id) == base.services[repository_id] assert base.service(repository_id) == base.services[repository_id]

View File

@ -5,6 +5,7 @@ from pytest_mock import MockerFixture
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.keys import AuthKey
from ahriman.web.views.v1.user.login import LoginView from ahriman.web.views.v1.user.login import LoginView
@ -45,7 +46,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None
""" """
must redirect to OAuth service provider in case if no code is supplied must redirect to OAuth service provider in case if no code is supplied
""" """
oauth = client_with_oauth_auth.app["validator"] oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_url.return_value = "http://localhost" oauth.get_oauth_url.return_value = "http://localhost"
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring") request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
@ -60,7 +61,7 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie
""" """
must redirect to OAuth service provider in case if empty code is supplied must redirect to OAuth service provider in case if empty code is supplied
""" """
oauth = client_with_oauth_auth.app["validator"] oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_url.return_value = "http://localhost" oauth.get_oauth_url.return_value = "http://localhost"
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring") request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
@ -75,7 +76,7 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) ->
""" """
must log in user correctly from OAuth must log in user correctly from OAuth
""" """
oauth = client_with_oauth_auth.app["validator"] oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_username.return_value = "user" oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True oauth.known_username.return_value = True
oauth.enabled = False # lol oauth.enabled = False # lol
@ -98,7 +99,7 @@ async def test_get_unauthorized(client_with_oauth_auth: TestClient, mocker: Mock
""" """
must return unauthorized from OAuth must return unauthorized from OAuth
""" """
oauth = client_with_oauth_auth.app["validator"] oauth = client_with_oauth_auth.app[AuthKey]
oauth.known_username.return_value = False oauth.known_username.return_value = False
oauth.max_age = 60 oauth.max_age = 60
remember_mock = mocker.patch("aiohttp_security.remember") remember_mock = mocker.patch("aiohttp_security.remember")