From a93f43dcd050e6604043d6113f3794366fc8a8d9 Mon Sep 17 00:00:00 2001 From: Evgeniy Alekseev Date: Thu, 23 Feb 2023 17:54:30 +0200 Subject: [PATCH] simplify login ttl processing --- docs/ahriman.models.rst | 8 -- src/ahriman/models/user_identity.py | 102 ------------------ src/ahriman/web/middlewares/auth_handler.py | 11 +- src/ahriman/web/views/user/login.py | 15 ++- tests/ahriman/models/conftest.py | 12 --- tests/ahriman/models/test_user_identity.py | 64 ----------- .../web/middlewares/test_auth_handler.py | 30 ++---- 7 files changed, 17 insertions(+), 225 deletions(-) delete mode 100644 src/ahriman/models/user_identity.py delete mode 100644 tests/ahriman/models/test_user_identity.py diff --git a/docs/ahriman.models.rst b/docs/ahriman.models.rst index 57fb3490..72193b88 100644 --- a/docs/ahriman.models.rst +++ b/docs/ahriman.models.rst @@ -196,14 +196,6 @@ ahriman.models.user\_access module :no-undoc-members: :show-inheritance: -ahriman.models.user\_identity module ------------------------------------- - -.. automodule:: ahriman.models.user_identity - :members: - :no-undoc-members: - :show-inheritance: - Module contents --------------- diff --git a/src/ahriman/models/user_identity.py b/src/ahriman/models/user_identity.py deleted file mode 100644 index 0b3781ab..00000000 --- a/src/ahriman/models/user_identity.py +++ /dev/null @@ -1,102 +0,0 @@ -# -# Copyright (c) 2021-2023 ahriman team. -# -# This file is part of ahriman -# (see https://github.com/arcan1s/ahriman). -# -# This program is free software: you can redistribute it and/or modify -# it under the terms of the GNU General Public License as published by -# the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# This program is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU General Public License for more details. -# -# You should have received a copy of the GNU General Public License -# along with this program. If not, see . -# -from __future__ import annotations - -import time - -from dataclasses import dataclass -from typing import Optional, Type - - -@dataclass(frozen=True) -class UserIdentity: - """ - user identity used inside web service - - Attributes: - username(str): username - expire_at(int): identity expiration timestamp - """ - - username: str - expire_at: int - - @classmethod - def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]: - """ - parse identity into object - - Args: - identity(str): identity from session data - - Returns: - Optional[UserIdentity]: user identity object if it can be parsed and not expired and None otherwise - """ - try: - username, expire_at = identity.split() - user = cls(username, int(expire_at)) - return None if user.is_expired() else user - except ValueError: - return None - - @classmethod - def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]: - """ - generate identity from username - - Args: - username(Optional[str]): username - max_age(int): time to expire, seconds - - Returns: - Optional[UserIdentity]: constructed identity object - """ - return cls(username, cls.expire_when(max_age)) if username is not None else None - - @staticmethod - def expire_when(max_age: int) -> int: - """ - generate expiration time using delta - - Args: - max_age(int): time delta to generate. Must be usually TTE - - Returns: - int: expiration timestamp - """ - return int(time.time()) + max_age - - def is_expired(self) -> bool: - """ - compare timestamp with current timestamp and return True in case if identity is expired - - Returns: - bool: True in case if identity is expired and False otherwise - """ - return self.expire_when(0) > self.expire_at - - def to_identity(self) -> str: - """ - convert object to identity representation - - Returns: - str: web service identity - """ - return f"{self.username} {self.expire_at}" diff --git a/src/ahriman/web/middlewares/auth_handler.py b/src/ahriman/web/middlewares/auth_handler.py index b47ef8ba..27effb5c 100644 --- a/src/ahriman/web/middlewares/auth_handler.py +++ b/src/ahriman/web/middlewares/auth_handler.py @@ -33,7 +33,6 @@ from typing import Optional from ahriman.core.auth import Auth from ahriman.core.configuration import Configuration from ahriman.models.user_access import UserAccess -from ahriman.models.user_identity import UserIdentity from ahriman.web.middlewares import HandlerType, MiddlewareType @@ -67,10 +66,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type Returns: Optional[str]: user identity (username) in case if user exists and None otherwise """ - user = UserIdentity.from_identity(identity) - if user is None: - return None - return user.username if await self.validator.known_username(user.username) else None + return identity if await self.validator.known_username(identity) else None async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool: """ @@ -84,10 +80,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type Returns: bool: True in case if user is allowed to perform this request and False otherwise """ - user = UserIdentity.from_identity(identity) - if user is None: - return False - return await self.validator.verify_access(user.username, permission, context) + return await self.validator.verify_access(identity, permission, context) def auth_handler(allow_read_only: bool) -> MiddlewareType: diff --git a/src/ahriman/web/views/user/login.py b/src/ahriman/web/views/user/login.py index e579060b..b73810f0 100644 --- a/src/ahriman/web/views/user/login.py +++ b/src/ahriman/web/views/user/login.py @@ -21,7 +21,6 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized from ahriman.core.auth.helpers import remember from ahriman.models.user_access import UserAccess -from ahriman.models.user_identity import UserIdentity from ahriman.web.views.base import BaseView @@ -64,10 +63,9 @@ class LoginView(BaseView): raise HTTPFound(oauth_provider.get_oauth_url()) response = HTTPFound("/") - username = await oauth_provider.get_oauth_username(code) - identity = UserIdentity.from_username(username, self.validator.max_age) - if identity is not None and await self.validator.known_username(username): - await remember(self.request, response, identity.to_identity()) + identity = await oauth_provider.get_oauth_username(code) + if identity is not None and await self.validator.known_username(identity): + await remember(self.request, response, identity) raise response raise HTTPUnauthorized() @@ -111,12 +109,11 @@ class LoginView(BaseView): 302: Found """ data = await self.extract_data() - username = data.get("username") + identity = data.get("username") response = HTTPFound("/") - identity = UserIdentity.from_username(username, self.validator.max_age) - if identity is not None and await self.validator.check_credentials(username, data.get("password")): - await remember(self.request, response, identity.to_identity()) + if identity is not None and await self.validator.check_credentials(identity, data.get("password")): + await remember(self.request, response, identity) raise response raise HTTPUnauthorized() diff --git a/tests/ahriman/models/conftest.py b/tests/ahriman/models/conftest.py index a25234b8..aec4f156 100644 --- a/tests/ahriman/models/conftest.py +++ b/tests/ahriman/models/conftest.py @@ -13,7 +13,6 @@ from ahriman.models.package import Package from ahriman.models.package_description import PackageDescription from ahriman.models.package_source import PackageSource from ahriman.models.remote_source import RemoteSource -from ahriman.models.user_identity import UserIdentity @pytest.fixture @@ -149,14 +148,3 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides) type(mock).url = PropertyMock(return_value=package_description_ahriman.url) return mock - - -@pytest.fixture -def user_identity() -> UserIdentity: - """ - identity fixture - - Returns: - UserIdentity: user identity test instance - """ - return UserIdentity("username", int(time.time()) + 30) diff --git a/tests/ahriman/models/test_user_identity.py b/tests/ahriman/models/test_user_identity.py deleted file mode 100644 index 83211350..00000000 --- a/tests/ahriman/models/test_user_identity.py +++ /dev/null @@ -1,64 +0,0 @@ -from ahriman.models.user_identity import UserIdentity - - -def test_from_identity(user_identity: UserIdentity) -> None: - """ - must construct identity object from string - """ - identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") - assert identity == user_identity - - -def test_from_identity_expired(user_identity: UserIdentity) -> None: - """ - must construct None from expired identity - """ - user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60) - assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None - - -def test_from_identity_no_split() -> None: - """ - must construct None from invalid string - """ - assert UserIdentity.from_identity("username") is None - - -def test_from_identity_not_int() -> None: - """ - must construct None from invalid timestamp - """ - assert UserIdentity.from_identity("username timestamp") is None - - -def test_from_username() -> None: - """ - must construct identity from username - """ - identity = UserIdentity.from_username("username", 0) - assert identity.username == "username" - # we want to check timestamp too, but later - - -def test_expire_when() -> None: - """ - must return correct expiration time - """ - assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1) - - -def test_is_expired(user_identity: UserIdentity) -> None: - """ - must return expired flag for expired identities - """ - assert not user_identity.is_expired() - - user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60) - assert user_identity.is_expired() - - -def test_to_identity(user_identity: UserIdentity) -> None: - """ - must return correct identity string - """ - assert user_identity == UserIdentity.from_identity(user_identity.to_identity()) diff --git a/tests/ahriman/web/middlewares/test_auth_handler.py b/tests/ahriman/web/middlewares/test_auth_handler.py index e23ff17b..e6aa9e81 100644 --- a/tests/ahriman/web/middlewares/test_auth_handler.py +++ b/tests/ahriman/web/middlewares/test_auth_handler.py @@ -5,42 +5,28 @@ from aiohttp import web from aiohttp.test_utils import TestClient from cryptography import fernet from pytest_mock import MockerFixture -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, call as MockCall from ahriman.core.auth import Auth from ahriman.core.configuration import Configuration from ahriman.models.user import User from ahriman.models.user_access import UserAccess -from ahriman.models.user_identity import UserIdentity from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth -def _identity(username: str) -> str: - """ - generate identity from user - - Args: - username(str): name of the user - - Returns: - str: user identity string - """ - return f"{username} {UserIdentity.expire_when(60)}" - - async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None: """ must return authorized user id """ mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user) - assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username + assert await authorization_policy.authorized_userid(user.username) == user.username async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None: """ must not allow unknown user id for authorization """ - assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None + assert await authorization_policy.authorized_userid("somerandomname") is None assert await authorization_policy.authorized_userid("somerandomname") is None @@ -51,11 +37,13 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) -> authorization_policy.validator = AsyncMock() authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username - assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint") - authorization_policy.validator.verify_access.assert_called_once_with(user.username, user.access, "/endpoint") + assert await authorization_policy.permits(user.username, user.access, "/endpoint") + assert not await authorization_policy.permits("somerandomname", user.access, "/endpoint") - assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint") - assert not await authorization_policy.permits(user.username, user.access, "/endpoint") + authorization_policy.validator.verify_access.assert_has_calls([ + MockCall(user.username, user.access, "/endpoint"), + MockCall("somerandomname", user.access, "/endpoint"), + ]) async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None: