simplify login ttl processing

This commit is contained in:
Evgenii Alekseev 2023-02-23 17:54:30 +02:00
parent 20974dae6f
commit a93f43dcd0
7 changed files with 17 additions and 225 deletions

View File

@ -196,14 +196,6 @@ ahriman.models.user\_access module
:no-undoc-members: :no-undoc-members:
:show-inheritance: :show-inheritance:
ahriman.models.user\_identity module
------------------------------------
.. automodule:: ahriman.models.user_identity
:members:
:no-undoc-members:
:show-inheritance:
Module contents Module contents
--------------- ---------------

View File

@ -1,102 +0,0 @@
#
# Copyright (c) 2021-2023 ahriman team.
#
# This file is part of ahriman
# (see https://github.com/arcan1s/ahriman).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Optional, Type
@dataclass(frozen=True)
class UserIdentity:
"""
user identity used inside web service
Attributes:
username(str): username
expire_at(int): identity expiration timestamp
"""
username: str
expire_at: int
@classmethod
def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]:
"""
parse identity into object
Args:
identity(str): identity from session data
Returns:
Optional[UserIdentity]: user identity object if it can be parsed and not expired and None otherwise
"""
try:
username, expire_at = identity.split()
user = cls(username, int(expire_at))
return None if user.is_expired() else user
except ValueError:
return None
@classmethod
def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]:
"""
generate identity from username
Args:
username(Optional[str]): username
max_age(int): time to expire, seconds
Returns:
Optional[UserIdentity]: constructed identity object
"""
return cls(username, cls.expire_when(max_age)) if username is not None else None
@staticmethod
def expire_when(max_age: int) -> int:
"""
generate expiration time using delta
Args:
max_age(int): time delta to generate. Must be usually TTE
Returns:
int: expiration timestamp
"""
return int(time.time()) + max_age
def is_expired(self) -> bool:
"""
compare timestamp with current timestamp and return True in case if identity is expired
Returns:
bool: True in case if identity is expired and False otherwise
"""
return self.expire_when(0) > self.expire_at
def to_identity(self) -> str:
"""
convert object to identity representation
Returns:
str: web service identity
"""
return f"{self.username} {self.expire_at}"

View File

@ -33,7 +33,6 @@ from typing import Optional
from ahriman.core.auth import Auth from ahriman.core.auth import Auth
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.models.user_identity import UserIdentity
from ahriman.web.middlewares import HandlerType, MiddlewareType from ahriman.web.middlewares import HandlerType, MiddlewareType
@ -67,10 +66,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
Returns: Returns:
Optional[str]: user identity (username) in case if user exists and None otherwise Optional[str]: user identity (username) in case if user exists and None otherwise
""" """
user = UserIdentity.from_identity(identity) return identity if await self.validator.known_username(identity) else None
if user is None:
return None
return user.username if await self.validator.known_username(user.username) else None
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool: async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
""" """
@ -84,10 +80,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
Returns: Returns:
bool: True in case if user is allowed to perform this request and False otherwise bool: True in case if user is allowed to perform this request and False otherwise
""" """
user = UserIdentity.from_identity(identity) return await self.validator.verify_access(identity, permission, context)
if user is None:
return False
return await self.validator.verify_access(user.username, permission, context)
def auth_handler(allow_read_only: bool) -> MiddlewareType: def auth_handler(allow_read_only: bool) -> MiddlewareType:

View File

@ -21,7 +21,6 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
from ahriman.core.auth.helpers import remember from ahriman.core.auth.helpers import remember
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.models.user_identity import UserIdentity
from ahriman.web.views.base import BaseView from ahriman.web.views.base import BaseView
@ -64,10 +63,9 @@ class LoginView(BaseView):
raise HTTPFound(oauth_provider.get_oauth_url()) raise HTTPFound(oauth_provider.get_oauth_url())
response = HTTPFound("/") response = HTTPFound("/")
username = await oauth_provider.get_oauth_username(code) identity = await oauth_provider.get_oauth_username(code)
identity = UserIdentity.from_username(username, self.validator.max_age) if identity is not None and await self.validator.known_username(identity):
if identity is not None and await self.validator.known_username(username): await remember(self.request, response, identity)
await remember(self.request, response, identity.to_identity())
raise response raise response
raise HTTPUnauthorized() raise HTTPUnauthorized()
@ -111,12 +109,11 @@ class LoginView(BaseView):
302: Found 302: Found
""" """
data = await self.extract_data() data = await self.extract_data()
username = data.get("username") identity = data.get("username")
response = HTTPFound("/") response = HTTPFound("/")
identity = UserIdentity.from_username(username, self.validator.max_age) if identity is not None and await self.validator.check_credentials(identity, data.get("password")):
if identity is not None and await self.validator.check_credentials(username, data.get("password")): await remember(self.request, response, identity)
await remember(self.request, response, identity.to_identity())
raise response raise response
raise HTTPUnauthorized() raise HTTPUnauthorized()

View File

@ -13,7 +13,6 @@ from ahriman.models.package import Package
from ahriman.models.package_description import PackageDescription from ahriman.models.package_description import PackageDescription
from ahriman.models.package_source import PackageSource from ahriman.models.package_source import PackageSource
from ahriman.models.remote_source import RemoteSource from ahriman.models.remote_source import RemoteSource
from ahriman.models.user_identity import UserIdentity
@pytest.fixture @pytest.fixture
@ -149,14 +148,3 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr
type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides) type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides)
type(mock).url = PropertyMock(return_value=package_description_ahriman.url) type(mock).url = PropertyMock(return_value=package_description_ahriman.url)
return mock return mock
@pytest.fixture
def user_identity() -> UserIdentity:
"""
identity fixture
Returns:
UserIdentity: user identity test instance
"""
return UserIdentity("username", int(time.time()) + 30)

View File

@ -1,64 +0,0 @@
from ahriman.models.user_identity import UserIdentity
def test_from_identity(user_identity: UserIdentity) -> None:
"""
must construct identity object from string
"""
identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}")
assert identity == user_identity
def test_from_identity_expired(user_identity: UserIdentity) -> None:
"""
must construct None from expired identity
"""
user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None
def test_from_identity_no_split() -> None:
"""
must construct None from invalid string
"""
assert UserIdentity.from_identity("username") is None
def test_from_identity_not_int() -> None:
"""
must construct None from invalid timestamp
"""
assert UserIdentity.from_identity("username timestamp") is None
def test_from_username() -> None:
"""
must construct identity from username
"""
identity = UserIdentity.from_username("username", 0)
assert identity.username == "username"
# we want to check timestamp too, but later
def test_expire_when() -> None:
"""
must return correct expiration time
"""
assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1)
def test_is_expired(user_identity: UserIdentity) -> None:
"""
must return expired flag for expired identities
"""
assert not user_identity.is_expired()
user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
assert user_identity.is_expired()
def test_to_identity(user_identity: UserIdentity) -> None:
"""
must return correct identity string
"""
assert user_identity == UserIdentity.from_identity(user_identity.to_identity())

View File

@ -5,42 +5,28 @@ from aiohttp import web
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
from cryptography import fernet from cryptography import fernet
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from unittest.mock import AsyncMock from unittest.mock import AsyncMock, call as MockCall
from ahriman.core.auth import Auth from ahriman.core.auth import Auth
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.models.user_identity import UserIdentity
from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth
def _identity(username: str) -> str:
"""
generate identity from user
Args:
username(str): name of the user
Returns:
str: user identity string
"""
return f"{username} {UserIdentity.expire_when(60)}"
async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None: async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None:
""" """
must return authorized user id must return authorized user id
""" """
mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user) mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user)
assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username assert await authorization_policy.authorized_userid(user.username) == user.username
async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None: async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None:
""" """
must not allow unknown user id for authorization must not allow unknown user id for authorization
""" """
assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None assert await authorization_policy.authorized_userid("somerandomname") is None
assert await authorization_policy.authorized_userid("somerandomname") is None assert await authorization_policy.authorized_userid("somerandomname") is None
@ -51,11 +37,13 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
authorization_policy.validator = AsyncMock() authorization_policy.validator = AsyncMock()
authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username
assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint") assert await authorization_policy.permits(user.username, user.access, "/endpoint")
authorization_policy.validator.verify_access.assert_called_once_with(user.username, user.access, "/endpoint") assert not await authorization_policy.permits("somerandomname", user.access, "/endpoint")
assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint") authorization_policy.validator.verify_access.assert_has_calls([
assert not await authorization_policy.permits(user.username, user.access, "/endpoint") MockCall(user.username, user.access, "/endpoint"),
MockCall("somerandomname", user.access, "/endpoint"),
])
async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None: async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None: