diff --git a/src/ahriman/core/auth/helpers.py b/src/ahriman/core/auth/helpers.py index 36135386..5c383b0a 100644 --- a/src/ahriman/core/auth/helpers.py +++ b/src/ahriman/core/auth/helpers.py @@ -20,7 +20,7 @@ from typing import Any try: - import aiohttp_security # type: ignore[import-untyped] + import aiohttp_security _has_aiohttp_security = True except ImportError: _has_aiohttp_security = False diff --git a/src/ahriman/web/middlewares/auth_handler.py b/src/ahriman/web/middlewares/auth_handler.py index a3f32d62..640b1a09 100644 --- a/src/ahriman/web/middlewares/auth_handler.py +++ b/src/ahriman/web/middlewares/auth_handler.py @@ -17,7 +17,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # -import aiohttp_security # type: ignore[import-untyped] +import aiohttp_security import socket import types @@ -25,6 +25,7 @@ from aiohttp.web import Application, Request, StaticResource, StreamResponse, mi from aiohttp_session import setup as setup_session from aiohttp_session.cookie_storage import EncryptedCookieStorage from cryptography import fernet +from enum import Enum from ahriman.core.auth import Auth from ahriman.core.configuration import Configuration @@ -50,6 +51,7 @@ class _AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): Args: validator(Auth): authorization module instance """ + aiohttp_security.AbstractAuthorizationPolicy.__init__(self) self.validator = validator async def authorized_userid(self, identity: str) -> str | None: @@ -64,18 +66,21 @@ class _AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): """ return identity if await self.validator.known_username(identity) else None - async def permits(self, identity: str, permission: UserAccess, context: str | None = None) -> bool: + async def permits(self, identity: str | None, permission: str | Enum, context: str | None = None) -> bool: """ check user permissions Args: - identity(str): username - permission(UserAccess): requested permission level + identity(str | None): username + permission(str | Enum): requested permission level context(str | None, optional): URI request path (Default value = None) Returns: bool: True in case if user is allowed to perform this request and False otherwise """ + # some methods for type checking and parent class compatibility + if identity is None or not isinstance(permission, UserAccess): + return False # no identity provided or invalid access rights requested return await self.validator.verify_access(identity, permission, context) diff --git a/src/ahriman/web/views/base.py b/src/ahriman/web/views/base.py index 53209ed3..ddf15599 100644 --- a/src/ahriman/web/views/base.py +++ b/src/ahriman/web/views/base.py @@ -139,7 +139,7 @@ class BaseView(View, CorsViewMixin): return value # pylint: disable=not-callable,protected-access - async def head(self) -> StreamResponse: # type: ignore[return] + async def head(self) -> StreamResponse: """ HEAD method implementation based on the result of GET method diff --git a/tests/ahriman/web/middlewares/test_auth_handler.py b/tests/ahriman/web/middlewares/test_auth_handler.py index f7afdf1f..a05c9ec1 100644 --- a/tests/ahriman/web/middlewares/test_auth_handler.py +++ b/tests/ahriman/web/middlewares/test_auth_handler.py @@ -9,6 +9,7 @@ from unittest.mock import AsyncMock, call as MockCall from ahriman.core.auth import Auth from ahriman.core.configuration import Configuration +from ahriman.models.build_status import BuildStatusEnum from ahriman.models.user import User from ahriman.models.user_access import UserAccess from ahriman.web.middlewares.auth_handler import _AuthorizationPolicy, _auth_handler, _cookie_secret_key, setup_auth @@ -39,6 +40,9 @@ async def test_permits(authorization_policy: _AuthorizationPolicy, user: User) - assert await authorization_policy.permits(user.username, user.access, "/endpoint") assert not await authorization_policy.permits("somerandomname", user.access, "/endpoint") + assert not await authorization_policy.permits(None, user.access, "/endpoint") + assert not await authorization_policy.permits(user.username, "random", "/endpoint") + assert not await authorization_policy.permits(None, BuildStatusEnum.Building, "/endpoint") authorization_policy.validator.verify_access.assert_has_calls([ MockCall(user.username, user.access, "/endpoint"),