expiration on server side support

This commit is contained in:
Evgenii Alekseev 2021-09-13 00:58:59 +03:00
parent d211cc17c6
commit 370af5854a
7 changed files with 196 additions and 11 deletions

View File

@ -0,0 +1,84 @@
#
# Copyright (c) 2021 ahriman team.
#
# This file is part of ahriman
# (see https://github.com/arcan1s/ahriman).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Optional, Type
@dataclass
class UserIdentity:
"""
user identity used inside web service
:ivar username: username
:ivar expire_at: identity expiration timestamp
"""
username: str
expire_at: int
@classmethod
def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]:
"""
parse identity into object
:param identity: identity from session data
:return: user identity object if it can be parsed and not expired and None otherwise
"""
try:
username, expire_at = identity.split()
user = cls(username, int(expire_at))
return None if user.is_expired() else user
except ValueError:
return None
@classmethod
def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]:
"""
generate identity from username
:param username: username
:param max_age: time to expire, seconds
:return: constructed identity object
"""
return cls(username, cls.expire_when(max_age)) if username is not None else None
@staticmethod
def expire_when(max_age: int) -> int:
"""
generate expiration time using delta
:param max_age: time delta to generate. Must be usually TTE
:return: expiration timestamp
"""
return int(time.time()) + max_age
def is_expired(self) -> bool:
"""
compare timestamp with current timestamp and return True in case if identity is expired
:return: True in case if identity is expired and False otherwise
"""
return self.expire_when(0) > self.expire_at
def to_identity(self) -> str:
"""
convert object to identity representation
:return: web service identity
"""
return f"{self.username} {self.expire_at}"

View File

@ -30,6 +30,7 @@ from typing import Optional
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.models.user_identity import UserIdentity
from ahriman.web.middlewares import HandlerType, MiddlewareType from ahriman.web.middlewares import HandlerType, MiddlewareType
@ -52,7 +53,10 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
:param identity: username :param identity: username
:return: user identity (username) in case if user exists and None otherwise :return: user identity (username) in case if user exists and None otherwise
""" """
return identity if await self.validator.known_username(identity) else None user = UserIdentity.from_identity(identity)
if user is None:
return None
return user.username if await self.validator.known_username(user.username) else None
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool: async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
""" """
@ -62,7 +66,10 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
:param context: URI request path :param context: URI request path
:return: True in case if user is allowed to perform this request and False otherwise :return: True in case if user is allowed to perform this request and False otherwise
""" """
return await self.validator.verify_access(identity, permission, context) user = UserIdentity.from_identity(identity)
if user is None:
return False
return await self.validator.verify_access(user.username, permission, context)
def auth_handler(validator: Auth) -> MiddlewareType: def auth_handler(validator: Auth) -> MiddlewareType:

View File

@ -20,6 +20,7 @@
from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized, Response from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized, Response
from ahriman.core.auth.helpers import remember from ahriman.core.auth.helpers import remember
from ahriman.models.user_identity import UserIdentity
from ahriman.web.views.base import BaseView from ahriman.web.views.base import BaseView
@ -49,8 +50,9 @@ class LoginView(BaseView):
response = HTTPFound("/") response = HTTPFound("/")
username = await oauth_provider.get_oauth_username(code) username = await oauth_provider.get_oauth_username(code)
if await self.validator.known_username(username): identity = UserIdentity.from_username(username, self.validator.max_age)
await remember(self.request, response, username) if identity is not None and await self.validator.known_username(username):
await remember(self.request, response, identity.to_identity())
return response return response
raise HTTPUnauthorized() raise HTTPUnauthorized()
@ -71,8 +73,9 @@ class LoginView(BaseView):
username = data.get("username") username = data.get("username")
response = HTTPFound("/") response = HTTPFound("/")
if await self.validator.check_credentials(username, data.get("password")): identity = UserIdentity.from_username(username, self.validator.max_age)
await remember(self.request, response, username) if identity is not None and await self.validator.check_credentials(username, data.get("password")):
await remember(self.request, response, identity.to_identity())
return response return response
raise HTTPUnauthorized() raise HTTPUnauthorized()

View File

@ -1,4 +1,5 @@
import pytest import pytest
import time
from unittest.mock import MagicMock, PropertyMock from unittest.mock import MagicMock, PropertyMock
@ -8,6 +9,7 @@ from ahriman.models.counters import Counters
from ahriman.models.internal_status import InternalStatus from ahriman.models.internal_status import InternalStatus
from ahriman.models.package import Package from ahriman.models.package import Package
from ahriman.models.package_description import PackageDescription from ahriman.models.package_description import PackageDescription
from ahriman.models.user_identity import UserIdentity
@pytest.fixture @pytest.fixture
@ -104,3 +106,12 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr
type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides) type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides)
type(mock).url = PropertyMock(return_value=package_description_ahriman.url) type(mock).url = PropertyMock(return_value=package_description_ahriman.url)
return mock return mock
@pytest.fixture
def user_identity() -> UserIdentity:
"""
identity fixture
:return: user identity test instance
"""
return UserIdentity("username", int(time.time()) + 30)

View File

@ -0,0 +1,64 @@
from ahriman.models.user_identity import UserIdentity
def test_from_identity(user_identity: UserIdentity) -> None:
"""
must construct identity object from string
"""
identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}")
assert identity == user_identity
def test_from_identity_expired(user_identity: UserIdentity) -> None:
"""
must construct None from expired identity
"""
user_identity.expire_at -= 60
assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None
def test_from_identity_no_split() -> None:
"""
must construct None from invalid string
"""
assert UserIdentity.from_identity("username") is None
def test_from_identity_not_int() -> None:
"""
must construct None from invalid timestamp
"""
assert UserIdentity.from_identity("username timestamp") is None
def test_from_username() -> None:
"""
must construct identity from username
"""
identity = UserIdentity.from_username("username", 0)
assert identity.username == "username"
# we want to check timestamp too, but later
def test_expire_when() -> None:
"""
must return correct expiration time
"""
assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1)
def test_is_expired(user_identity: UserIdentity) -> None:
"""
must return expired flag for expired identities
"""
assert not user_identity.is_expired()
user_identity.expire_at -= 60
assert user_identity.is_expired()
def test_to_identity(user_identity: UserIdentity) -> None:
"""
must return correct identity string
"""
assert user_identity == UserIdentity.from_identity(user_identity.to_identity())

View File

@ -7,15 +7,26 @@ from unittest.mock import AsyncMock
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.models.user_identity import UserIdentity
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
def _identity(username: str) -> str:
"""
generate identity from user
:param user: user fixture object
:return: user identity string
"""
return f"{username} {UserIdentity.expire_when(60)}"
async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User) -> None: async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User) -> None:
""" """
must return authorized user id must return authorized user id
""" """
assert await authorization_policy.authorized_userid(user.username) == user.username assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username
assert await authorization_policy.authorized_userid("some random name") is None assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None
assert await authorization_policy.authorized_userid("somerandomname") is None
async def test_permits(authorization_policy: AuthorizationPolicy, user: User) -> None: async def test_permits(authorization_policy: AuthorizationPolicy, user: User) -> None:
@ -23,11 +34,14 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
must call validator check must call validator check
""" """
authorization_policy.validator = AsyncMock() authorization_policy.validator = AsyncMock()
authorization_policy.validator.verify_access.return_value = True authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username
assert await authorization_policy.permits(user.username, user.access, "/endpoint") assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint")
authorization_policy.validator.verify_access.assert_called_with(user.username, user.access, "/endpoint") authorization_policy.validator.verify_access.assert_called_with(user.username, user.access, "/endpoint")
assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint")
assert not await authorization_policy.permits(user.username, user.access, "/endpoint")
async def test_auth_handler_api(auth: Auth, mocker: MockerFixture) -> None: async def test_auth_handler_api(auth: Auth, mocker: MockerFixture) -> None:
""" """

View File

@ -45,7 +45,8 @@ async def test_get(client_with_auth: TestClient, mocker: MockerFixture) -> None:
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth) oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_username.return_value = "user" oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True oauth.known_username.return_value = True
oauth.enabled = False # lol oauth.enabled = False # lol\
oauth.max_age = 60
remember_mock = mocker.patch("aiohttp_security.remember") remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"}) get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
@ -62,6 +63,7 @@ async def test_get_unauthorized(client_with_auth: TestClient, mocker: MockerFixt
""" """
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth) oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.known_username.return_value = False oauth.known_username.return_value = False
oauth.max_age = 60
remember_mock = mocker.patch("aiohttp_security.remember") remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"}) get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})