mirror of
https://github.com/arcan1s/ahriman.git
synced 2025-04-24 07:17:17 +00:00
simplify login ttl processing
This commit is contained in:
parent
20974dae6f
commit
a93f43dcd0
@ -196,14 +196,6 @@ ahriman.models.user\_access module
|
|||||||
:no-undoc-members:
|
:no-undoc-members:
|
||||||
:show-inheritance:
|
:show-inheritance:
|
||||||
|
|
||||||
ahriman.models.user\_identity module
|
|
||||||
------------------------------------
|
|
||||||
|
|
||||||
.. automodule:: ahriman.models.user_identity
|
|
||||||
:members:
|
|
||||||
:no-undoc-members:
|
|
||||||
:show-inheritance:
|
|
||||||
|
|
||||||
Module contents
|
Module contents
|
||||||
---------------
|
---------------
|
||||||
|
|
||||||
|
@ -1,102 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2021-2023 ahriman team.
|
|
||||||
#
|
|
||||||
# This file is part of ahriman
|
|
||||||
# (see https://github.com/arcan1s/ahriman).
|
|
||||||
#
|
|
||||||
# This program is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU General Public License as published by
|
|
||||||
# the Free Software Foundation, either version 3 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
#
|
|
||||||
# This program is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the GNU General Public License
|
|
||||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
#
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Optional, Type
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class UserIdentity:
|
|
||||||
"""
|
|
||||||
user identity used inside web service
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
username(str): username
|
|
||||||
expire_at(int): identity expiration timestamp
|
|
||||||
"""
|
|
||||||
|
|
||||||
username: str
|
|
||||||
expire_at: int
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]:
|
|
||||||
"""
|
|
||||||
parse identity into object
|
|
||||||
|
|
||||||
Args:
|
|
||||||
identity(str): identity from session data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[UserIdentity]: user identity object if it can be parsed and not expired and None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
username, expire_at = identity.split()
|
|
||||||
user = cls(username, int(expire_at))
|
|
||||||
return None if user.is_expired() else user
|
|
||||||
except ValueError:
|
|
||||||
return None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]:
|
|
||||||
"""
|
|
||||||
generate identity from username
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username(Optional[str]): username
|
|
||||||
max_age(int): time to expire, seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Optional[UserIdentity]: constructed identity object
|
|
||||||
"""
|
|
||||||
return cls(username, cls.expire_when(max_age)) if username is not None else None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def expire_when(max_age: int) -> int:
|
|
||||||
"""
|
|
||||||
generate expiration time using delta
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_age(int): time delta to generate. Must be usually TTE
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: expiration timestamp
|
|
||||||
"""
|
|
||||||
return int(time.time()) + max_age
|
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
|
||||||
"""
|
|
||||||
compare timestamp with current timestamp and return True in case if identity is expired
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: True in case if identity is expired and False otherwise
|
|
||||||
"""
|
|
||||||
return self.expire_when(0) > self.expire_at
|
|
||||||
|
|
||||||
def to_identity(self) -> str:
|
|
||||||
"""
|
|
||||||
convert object to identity representation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: web service identity
|
|
||||||
"""
|
|
||||||
return f"{self.username} {self.expire_at}"
|
|
@ -33,7 +33,6 @@ from typing import Optional
|
|||||||
from ahriman.core.auth import Auth
|
from ahriman.core.auth import Auth
|
||||||
from ahriman.core.configuration import Configuration
|
from ahriman.core.configuration import Configuration
|
||||||
from ahriman.models.user_access import UserAccess
|
from ahriman.models.user_access import UserAccess
|
||||||
from ahriman.models.user_identity import UserIdentity
|
|
||||||
from ahriman.web.middlewares import HandlerType, MiddlewareType
|
from ahriman.web.middlewares import HandlerType, MiddlewareType
|
||||||
|
|
||||||
|
|
||||||
@ -67,10 +66,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
|||||||
Returns:
|
Returns:
|
||||||
Optional[str]: user identity (username) in case if user exists and None otherwise
|
Optional[str]: user identity (username) in case if user exists and None otherwise
|
||||||
"""
|
"""
|
||||||
user = UserIdentity.from_identity(identity)
|
return identity if await self.validator.known_username(identity) else None
|
||||||
if user is None:
|
|
||||||
return None
|
|
||||||
return user.username if await self.validator.known_username(user.username) else None
|
|
||||||
|
|
||||||
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
|
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
|
||||||
"""
|
"""
|
||||||
@ -84,10 +80,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True in case if user is allowed to perform this request and False otherwise
|
bool: True in case if user is allowed to perform this request and False otherwise
|
||||||
"""
|
"""
|
||||||
user = UserIdentity.from_identity(identity)
|
return await self.validator.verify_access(identity, permission, context)
|
||||||
if user is None:
|
|
||||||
return False
|
|
||||||
return await self.validator.verify_access(user.username, permission, context)
|
|
||||||
|
|
||||||
|
|
||||||
def auth_handler(allow_read_only: bool) -> MiddlewareType:
|
def auth_handler(allow_read_only: bool) -> MiddlewareType:
|
||||||
|
@ -21,7 +21,6 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
|
|||||||
|
|
||||||
from ahriman.core.auth.helpers import remember
|
from ahriman.core.auth.helpers import remember
|
||||||
from ahriman.models.user_access import UserAccess
|
from ahriman.models.user_access import UserAccess
|
||||||
from ahriman.models.user_identity import UserIdentity
|
|
||||||
from ahriman.web.views.base import BaseView
|
from ahriman.web.views.base import BaseView
|
||||||
|
|
||||||
|
|
||||||
@ -64,10 +63,9 @@ class LoginView(BaseView):
|
|||||||
raise HTTPFound(oauth_provider.get_oauth_url())
|
raise HTTPFound(oauth_provider.get_oauth_url())
|
||||||
|
|
||||||
response = HTTPFound("/")
|
response = HTTPFound("/")
|
||||||
username = await oauth_provider.get_oauth_username(code)
|
identity = await oauth_provider.get_oauth_username(code)
|
||||||
identity = UserIdentity.from_username(username, self.validator.max_age)
|
if identity is not None and await self.validator.known_username(identity):
|
||||||
if identity is not None and await self.validator.known_username(username):
|
await remember(self.request, response, identity)
|
||||||
await remember(self.request, response, identity.to_identity())
|
|
||||||
raise response
|
raise response
|
||||||
|
|
||||||
raise HTTPUnauthorized()
|
raise HTTPUnauthorized()
|
||||||
@ -111,12 +109,11 @@ class LoginView(BaseView):
|
|||||||
302: Found
|
302: Found
|
||||||
"""
|
"""
|
||||||
data = await self.extract_data()
|
data = await self.extract_data()
|
||||||
username = data.get("username")
|
identity = data.get("username")
|
||||||
|
|
||||||
response = HTTPFound("/")
|
response = HTTPFound("/")
|
||||||
identity = UserIdentity.from_username(username, self.validator.max_age)
|
if identity is not None and await self.validator.check_credentials(identity, data.get("password")):
|
||||||
if identity is not None and await self.validator.check_credentials(username, data.get("password")):
|
await remember(self.request, response, identity)
|
||||||
await remember(self.request, response, identity.to_identity())
|
|
||||||
raise response
|
raise response
|
||||||
|
|
||||||
raise HTTPUnauthorized()
|
raise HTTPUnauthorized()
|
||||||
|
@ -13,7 +13,6 @@ from ahriman.models.package import Package
|
|||||||
from ahriman.models.package_description import PackageDescription
|
from ahriman.models.package_description import PackageDescription
|
||||||
from ahriman.models.package_source import PackageSource
|
from ahriman.models.package_source import PackageSource
|
||||||
from ahriman.models.remote_source import RemoteSource
|
from ahriman.models.remote_source import RemoteSource
|
||||||
from ahriman.models.user_identity import UserIdentity
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -149,14 +148,3 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr
|
|||||||
type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides)
|
type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides)
|
||||||
type(mock).url = PropertyMock(return_value=package_description_ahriman.url)
|
type(mock).url = PropertyMock(return_value=package_description_ahriman.url)
|
||||||
return mock
|
return mock
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def user_identity() -> UserIdentity:
|
|
||||||
"""
|
|
||||||
identity fixture
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
UserIdentity: user identity test instance
|
|
||||||
"""
|
|
||||||
return UserIdentity("username", int(time.time()) + 30)
|
|
||||||
|
@ -1,64 +0,0 @@
|
|||||||
from ahriman.models.user_identity import UserIdentity
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_identity(user_identity: UserIdentity) -> None:
|
|
||||||
"""
|
|
||||||
must construct identity object from string
|
|
||||||
"""
|
|
||||||
identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}")
|
|
||||||
assert identity == user_identity
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_identity_expired(user_identity: UserIdentity) -> None:
|
|
||||||
"""
|
|
||||||
must construct None from expired identity
|
|
||||||
"""
|
|
||||||
user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
|
|
||||||
assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_identity_no_split() -> None:
|
|
||||||
"""
|
|
||||||
must construct None from invalid string
|
|
||||||
"""
|
|
||||||
assert UserIdentity.from_identity("username") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_identity_not_int() -> None:
|
|
||||||
"""
|
|
||||||
must construct None from invalid timestamp
|
|
||||||
"""
|
|
||||||
assert UserIdentity.from_identity("username timestamp") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_from_username() -> None:
|
|
||||||
"""
|
|
||||||
must construct identity from username
|
|
||||||
"""
|
|
||||||
identity = UserIdentity.from_username("username", 0)
|
|
||||||
assert identity.username == "username"
|
|
||||||
# we want to check timestamp too, but later
|
|
||||||
|
|
||||||
|
|
||||||
def test_expire_when() -> None:
|
|
||||||
"""
|
|
||||||
must return correct expiration time
|
|
||||||
"""
|
|
||||||
assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1)
|
|
||||||
|
|
||||||
|
|
||||||
def test_is_expired(user_identity: UserIdentity) -> None:
|
|
||||||
"""
|
|
||||||
must return expired flag for expired identities
|
|
||||||
"""
|
|
||||||
assert not user_identity.is_expired()
|
|
||||||
|
|
||||||
user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
|
|
||||||
assert user_identity.is_expired()
|
|
||||||
|
|
||||||
|
|
||||||
def test_to_identity(user_identity: UserIdentity) -> None:
|
|
||||||
"""
|
|
||||||
must return correct identity string
|
|
||||||
"""
|
|
||||||
assert user_identity == UserIdentity.from_identity(user_identity.to_identity())
|
|
@ -5,42 +5,28 @@ from aiohttp import web
|
|||||||
from aiohttp.test_utils import TestClient
|
from aiohttp.test_utils import TestClient
|
||||||
from cryptography import fernet
|
from cryptography import fernet
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock, call as MockCall
|
||||||
|
|
||||||
from ahriman.core.auth import Auth
|
from ahriman.core.auth import Auth
|
||||||
from ahriman.core.configuration import Configuration
|
from ahriman.core.configuration import Configuration
|
||||||
from ahriman.models.user import User
|
from ahriman.models.user import User
|
||||||
from ahriman.models.user_access import UserAccess
|
from ahriman.models.user_access import UserAccess
|
||||||
from ahriman.models.user_identity import UserIdentity
|
|
||||||
from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth
|
from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth
|
||||||
|
|
||||||
|
|
||||||
def _identity(username: str) -> str:
|
|
||||||
"""
|
|
||||||
generate identity from user
|
|
||||||
|
|
||||||
Args:
|
|
||||||
username(str): name of the user
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: user identity string
|
|
||||||
"""
|
|
||||||
return f"{username} {UserIdentity.expire_when(60)}"
|
|
||||||
|
|
||||||
|
|
||||||
async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None:
|
async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
must return authorized user id
|
must return authorized user id
|
||||||
"""
|
"""
|
||||||
mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user)
|
mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user)
|
||||||
assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username
|
assert await authorization_policy.authorized_userid(user.username) == user.username
|
||||||
|
|
||||||
|
|
||||||
async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None:
|
async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None:
|
||||||
"""
|
"""
|
||||||
must not allow unknown user id for authorization
|
must not allow unknown user id for authorization
|
||||||
"""
|
"""
|
||||||
assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None
|
assert await authorization_policy.authorized_userid("somerandomname") is None
|
||||||
assert await authorization_policy.authorized_userid("somerandomname") is None
|
assert await authorization_policy.authorized_userid("somerandomname") is None
|
||||||
|
|
||||||
|
|
||||||
@ -51,11 +37,13 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
|
|||||||
authorization_policy.validator = AsyncMock()
|
authorization_policy.validator = AsyncMock()
|
||||||
authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username
|
authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username
|
||||||
|
|
||||||
assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint")
|
assert await authorization_policy.permits(user.username, user.access, "/endpoint")
|
||||||
authorization_policy.validator.verify_access.assert_called_once_with(user.username, user.access, "/endpoint")
|
assert not await authorization_policy.permits("somerandomname", user.access, "/endpoint")
|
||||||
|
|
||||||
assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint")
|
authorization_policy.validator.verify_access.assert_has_calls([
|
||||||
assert not await authorization_policy.permits(user.username, user.access, "/endpoint")
|
MockCall(user.username, user.access, "/endpoint"),
|
||||||
|
MockCall("somerandomname", user.access, "/endpoint"),
|
||||||
|
])
|
||||||
|
|
||||||
|
|
||||||
async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None:
|
async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None:
|
||||||
|
Loading…
Reference in New Issue
Block a user