mirror of
https://github.com/arcan1s/ahriman.git
synced 2025-07-14 22:45:47 +00:00
simplify login ttl processing
This commit is contained in:
@ -1,102 +0,0 @@
|
||||
#
|
||||
# Copyright (c) 2021-2023 ahriman team.
|
||||
#
|
||||
# This file is part of ahriman
|
||||
# (see https://github.com/arcan1s/ahriman).
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU General Public License as published by
|
||||
# the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Type
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UserIdentity:
|
||||
"""
|
||||
user identity used inside web service
|
||||
|
||||
Attributes:
|
||||
username(str): username
|
||||
expire_at(int): identity expiration timestamp
|
||||
"""
|
||||
|
||||
username: str
|
||||
expire_at: int
|
||||
|
||||
@classmethod
|
||||
def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]:
|
||||
"""
|
||||
parse identity into object
|
||||
|
||||
Args:
|
||||
identity(str): identity from session data
|
||||
|
||||
Returns:
|
||||
Optional[UserIdentity]: user identity object if it can be parsed and not expired and None otherwise
|
||||
"""
|
||||
try:
|
||||
username, expire_at = identity.split()
|
||||
user = cls(username, int(expire_at))
|
||||
return None if user.is_expired() else user
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]:
|
||||
"""
|
||||
generate identity from username
|
||||
|
||||
Args:
|
||||
username(Optional[str]): username
|
||||
max_age(int): time to expire, seconds
|
||||
|
||||
Returns:
|
||||
Optional[UserIdentity]: constructed identity object
|
||||
"""
|
||||
return cls(username, cls.expire_when(max_age)) if username is not None else None
|
||||
|
||||
@staticmethod
|
||||
def expire_when(max_age: int) -> int:
|
||||
"""
|
||||
generate expiration time using delta
|
||||
|
||||
Args:
|
||||
max_age(int): time delta to generate. Must be usually TTE
|
||||
|
||||
Returns:
|
||||
int: expiration timestamp
|
||||
"""
|
||||
return int(time.time()) + max_age
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
"""
|
||||
compare timestamp with current timestamp and return True in case if identity is expired
|
||||
|
||||
Returns:
|
||||
bool: True in case if identity is expired and False otherwise
|
||||
"""
|
||||
return self.expire_when(0) > self.expire_at
|
||||
|
||||
def to_identity(self) -> str:
|
||||
"""
|
||||
convert object to identity representation
|
||||
|
||||
Returns:
|
||||
str: web service identity
|
||||
"""
|
||||
return f"{self.username} {self.expire_at}"
|
@ -33,7 +33,6 @@ from typing import Optional
|
||||
from ahriman.core.auth import Auth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.models.user_access import UserAccess
|
||||
from ahriman.models.user_identity import UserIdentity
|
||||
from ahriman.web.middlewares import HandlerType, MiddlewareType
|
||||
|
||||
|
||||
@ -67,10 +66,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
||||
Returns:
|
||||
Optional[str]: user identity (username) in case if user exists and None otherwise
|
||||
"""
|
||||
user = UserIdentity.from_identity(identity)
|
||||
if user is None:
|
||||
return None
|
||||
return user.username if await self.validator.known_username(user.username) else None
|
||||
return identity if await self.validator.known_username(identity) else None
|
||||
|
||||
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
|
||||
"""
|
||||
@ -84,10 +80,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
||||
Returns:
|
||||
bool: True in case if user is allowed to perform this request and False otherwise
|
||||
"""
|
||||
user = UserIdentity.from_identity(identity)
|
||||
if user is None:
|
||||
return False
|
||||
return await self.validator.verify_access(user.username, permission, context)
|
||||
return await self.validator.verify_access(identity, permission, context)
|
||||
|
||||
|
||||
def auth_handler(allow_read_only: bool) -> MiddlewareType:
|
||||
|
@ -21,7 +21,6 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
|
||||
|
||||
from ahriman.core.auth.helpers import remember
|
||||
from ahriman.models.user_access import UserAccess
|
||||
from ahriman.models.user_identity import UserIdentity
|
||||
from ahriman.web.views.base import BaseView
|
||||
|
||||
|
||||
@ -64,10 +63,9 @@ class LoginView(BaseView):
|
||||
raise HTTPFound(oauth_provider.get_oauth_url())
|
||||
|
||||
response = HTTPFound("/")
|
||||
username = await oauth_provider.get_oauth_username(code)
|
||||
identity = UserIdentity.from_username(username, self.validator.max_age)
|
||||
if identity is not None and await self.validator.known_username(username):
|
||||
await remember(self.request, response, identity.to_identity())
|
||||
identity = await oauth_provider.get_oauth_username(code)
|
||||
if identity is not None and await self.validator.known_username(identity):
|
||||
await remember(self.request, response, identity)
|
||||
raise response
|
||||
|
||||
raise HTTPUnauthorized()
|
||||
@ -111,12 +109,11 @@ class LoginView(BaseView):
|
||||
302: Found
|
||||
"""
|
||||
data = await self.extract_data()
|
||||
username = data.get("username")
|
||||
identity = data.get("username")
|
||||
|
||||
response = HTTPFound("/")
|
||||
identity = UserIdentity.from_username(username, self.validator.max_age)
|
||||
if identity is not None and await self.validator.check_credentials(username, data.get("password")):
|
||||
await remember(self.request, response, identity.to_identity())
|
||||
if identity is not None and await self.validator.check_credentials(identity, data.get("password")):
|
||||
await remember(self.request, response, identity)
|
||||
raise response
|
||||
|
||||
raise HTTPUnauthorized()
|
||||
|
Reference in New Issue
Block a user