diff --git a/src/ahriman/models/user_identity.py b/src/ahriman/models/user_identity.py new file mode 100644 index 00000000..76d67245 --- /dev/null +++ b/src/ahriman/models/user_identity.py @@ -0,0 +1,84 @@ +# +# Copyright (c) 2021 ahriman team. +# +# This file is part of ahriman +# (see https://github.com/arcan1s/ahriman). +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . +# +from __future__ import annotations + +import time + +from dataclasses import dataclass +from typing import Optional, Type + + +@dataclass +class UserIdentity: + """ + user identity used inside web service + :ivar username: username + :ivar expire_at: identity expiration timestamp + """ + + username: str + expire_at: int + + @classmethod + def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]: + """ + parse identity into object + :param identity: identity from session data + :return: user identity object if it can be parsed and not expired and None otherwise + """ + try: + username, expire_at = identity.split() + user = cls(username, int(expire_at)) + return None if user.is_expired() else user + except ValueError: + return None + + @classmethod + def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]: + """ + generate identity from username + :param username: username + :param max_age: time to expire, seconds + :return: constructed identity object + """ + return cls(username, cls.expire_when(max_age)) if username is not None else None + + @staticmethod + def expire_when(max_age: int) -> int: + """ + generate expiration time using delta + :param max_age: time delta to generate. Must be usually TTE + :return: expiration timestamp + """ + return int(time.time()) + max_age + + def is_expired(self) -> bool: + """ + compare timestamp with current timestamp and return True in case if identity is expired + :return: True in case if identity is expired and False otherwise + """ + return self.expire_when(0) > self.expire_at + + def to_identity(self) -> str: + """ + convert object to identity representation + :return: web service identity + """ + return f"{self.username} {self.expire_at}" diff --git a/src/ahriman/web/middlewares/auth_handler.py b/src/ahriman/web/middlewares/auth_handler.py index 49eb9bbb..2822ead7 100644 --- a/src/ahriman/web/middlewares/auth_handler.py +++ b/src/ahriman/web/middlewares/auth_handler.py @@ -30,6 +30,7 @@ from typing import Optional from ahriman.core.auth.auth import Auth from ahriman.models.user_access import UserAccess +from ahriman.models.user_identity import UserIdentity from ahriman.web.middlewares import HandlerType, MiddlewareType @@ -52,7 +53,10 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type :param identity: username :return: user identity (username) in case if user exists and None otherwise """ - return identity if await self.validator.known_username(identity) else None + user = UserIdentity.from_identity(identity) + if user is None: + return None + return user.username if await self.validator.known_username(user.username) else None async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool: """ @@ -62,7 +66,10 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type :param context: URI request path :return: True in case if user is allowed to perform this request and False otherwise """ - return await self.validator.verify_access(identity, permission, context) + user = UserIdentity.from_identity(identity) + if user is None: + return False + return await self.validator.verify_access(user.username, permission, context) def auth_handler(validator: Auth) -> MiddlewareType: diff --git a/src/ahriman/web/views/user/login.py b/src/ahriman/web/views/user/login.py index 18c05dd0..7065d9e4 100644 --- a/src/ahriman/web/views/user/login.py +++ b/src/ahriman/web/views/user/login.py @@ -20,6 +20,7 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized, Response from ahriman.core.auth.helpers import remember +from ahriman.models.user_identity import UserIdentity from ahriman.web.views.base import BaseView @@ -49,8 +50,9 @@ class LoginView(BaseView): response = HTTPFound("/") username = await oauth_provider.get_oauth_username(code) - if await self.validator.known_username(username): - await remember(self.request, response, username) + identity = UserIdentity.from_username(username, self.validator.max_age) + if identity is not None and await self.validator.known_username(username): + await remember(self.request, response, identity.to_identity()) return response raise HTTPUnauthorized() @@ -71,8 +73,9 @@ class LoginView(BaseView): username = data.get("username") response = HTTPFound("/") - if await self.validator.check_credentials(username, data.get("password")): - await remember(self.request, response, username) + identity = UserIdentity.from_username(username, self.validator.max_age) + if identity is not None and await self.validator.check_credentials(username, data.get("password")): + await remember(self.request, response, identity.to_identity()) return response raise HTTPUnauthorized() diff --git a/tests/ahriman/models/conftest.py b/tests/ahriman/models/conftest.py index 8bd3e490..edfdaf00 100644 --- a/tests/ahriman/models/conftest.py +++ b/tests/ahriman/models/conftest.py @@ -1,4 +1,5 @@ import pytest +import time from unittest.mock import MagicMock, PropertyMock @@ -8,6 +9,7 @@ from ahriman.models.counters import Counters from ahriman.models.internal_status import InternalStatus from ahriman.models.package import Package from ahriman.models.package_description import PackageDescription +from ahriman.models.user_identity import UserIdentity @pytest.fixture @@ -104,3 +106,12 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides) type(mock).url = PropertyMock(return_value=package_description_ahriman.url) return mock + + +@pytest.fixture +def user_identity() -> UserIdentity: + """ + identity fixture + :return: user identity test instance + """ + return UserIdentity("username", int(time.time()) + 30) diff --git a/tests/ahriman/models/test_user_identity.py b/tests/ahriman/models/test_user_identity.py new file mode 100644 index 00000000..2047f4ba --- /dev/null +++ b/tests/ahriman/models/test_user_identity.py @@ -0,0 +1,64 @@ +from ahriman.models.user_identity import UserIdentity + + +def test_from_identity(user_identity: UserIdentity) -> None: + """ + must construct identity object from string + """ + identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") + assert identity == user_identity + + +def test_from_identity_expired(user_identity: UserIdentity) -> None: + """ + must construct None from expired identity + """ + user_identity.expire_at -= 60 + assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None + + +def test_from_identity_no_split() -> None: + """ + must construct None from invalid string + """ + assert UserIdentity.from_identity("username") is None + + +def test_from_identity_not_int() -> None: + """ + must construct None from invalid timestamp + """ + assert UserIdentity.from_identity("username timestamp") is None + + +def test_from_username() -> None: + """ + must construct identity from username + """ + identity = UserIdentity.from_username("username", 0) + assert identity.username == "username" + # we want to check timestamp too, but later + + +def test_expire_when() -> None: + """ + must return correct expiration time + """ + assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1) + + +def test_is_expired(user_identity: UserIdentity) -> None: + """ + must return expired flag for expired identities + """ + assert not user_identity.is_expired() + + user_identity.expire_at -= 60 + assert user_identity.is_expired() + + +def test_to_identity(user_identity: UserIdentity) -> None: + """ + must return correct identity string + """ + assert user_identity == UserIdentity.from_identity(user_identity.to_identity()) diff --git a/tests/ahriman/web/middlewares/test_auth_handler.py b/tests/ahriman/web/middlewares/test_auth_handler.py index 21a6aebf..d3d65004 100644 --- a/tests/ahriman/web/middlewares/test_auth_handler.py +++ b/tests/ahriman/web/middlewares/test_auth_handler.py @@ -7,15 +7,26 @@ from unittest.mock import AsyncMock from ahriman.core.auth.auth import Auth from ahriman.models.user import User from ahriman.models.user_access import UserAccess +from ahriman.models.user_identity import UserIdentity from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth +def _identity(username: str) -> str: + """ + generate identity from user + :param user: user fixture object + :return: user identity string + """ + return f"{username} {UserIdentity.expire_when(60)}" + + async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User) -> None: """ must return authorized user id """ - assert await authorization_policy.authorized_userid(user.username) == user.username - assert await authorization_policy.authorized_userid("some random name") is None + assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username + assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None + assert await authorization_policy.authorized_userid("somerandomname") is None async def test_permits(authorization_policy: AuthorizationPolicy, user: User) -> None: @@ -23,11 +34,14 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) -> must call validator check """ authorization_policy.validator = AsyncMock() - authorization_policy.validator.verify_access.return_value = True + authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username - assert await authorization_policy.permits(user.username, user.access, "/endpoint") + assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint") authorization_policy.validator.verify_access.assert_called_with(user.username, user.access, "/endpoint") + assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint") + assert not await authorization_policy.permits(user.username, user.access, "/endpoint") + async def test_auth_handler_api(auth: Auth, mocker: MockerFixture) -> None: """ diff --git a/tests/ahriman/web/views/user/test_views_user_login.py b/tests/ahriman/web/views/user/test_views_user_login.py index 622cf565..e524f6dc 100644 --- a/tests/ahriman/web/views/user/test_views_user_login.py +++ b/tests/ahriman/web/views/user/test_views_user_login.py @@ -45,7 +45,8 @@ async def test_get(client_with_auth: TestClient, mocker: MockerFixture) -> None: oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth) oauth.get_oauth_username.return_value = "user" oauth.known_username.return_value = True - oauth.enabled = False # lol + oauth.enabled = False # lol\ + oauth.max_age = 60 remember_mock = mocker.patch("aiohttp_security.remember") get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"}) @@ -62,6 +63,7 @@ async def test_get_unauthorized(client_with_auth: TestClient, mocker: MockerFixt """ oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth) oauth.known_username.return_value = False + oauth.max_age = 60 remember_mock = mocker.patch("aiohttp_security.remember") get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})