mirror of
https://github.com/arcan1s/ahriman.git
synced 2025-04-24 15:27:17 +00:00
make auth method asyncs
This commit is contained in:
parent
1b29b5773d
commit
c4e7f63d7c
@ -35,7 +35,7 @@ Authorization mapping. Group name must refer to user access level, i.e. it shoul
|
||||
|
||||
Key is always username (case-insensitive), option value depends on authorization provider:
|
||||
|
||||
* `MappingAuth` (default) - reads salted password hashes from values, uses SHA512 in order to hash passwords. Password can be set by using `create-user` subcommand.
|
||||
* `Mapping` (default) - reads salted password hashes from values, uses SHA512 in order to hash passwords. Password can be set by using `create-user` subcommand.
|
||||
|
||||
## `build:*` groups
|
||||
|
||||
|
@ -19,10 +19,12 @@
|
||||
#
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Type
|
||||
from typing import Dict, Optional, Type
|
||||
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.auth_settings import AuthSettings
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
@ -62,11 +64,30 @@ class Auth:
|
||||
"""
|
||||
provider = AuthSettings.from_option(configuration.get("auth", "target", fallback="disabled"))
|
||||
if provider == AuthSettings.Configuration:
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
return MappingAuth(configuration)
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
return Mapping(configuration)
|
||||
return cls(configuration)
|
||||
|
||||
def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use
|
||||
@staticmethod
|
||||
def get_users(configuration: Configuration) -> Dict[str, User]:
|
||||
"""
|
||||
load users from settings
|
||||
:param configuration: configuration instance
|
||||
:return: map of username to its descriptor
|
||||
"""
|
||||
users: Dict[str, User] = {}
|
||||
for role in UserAccess:
|
||||
section = configuration.section_name("auth", role.value)
|
||||
if not configuration.has_section(section):
|
||||
continue
|
||||
for user, password in configuration[section].items():
|
||||
normalized_user = user.lower()
|
||||
if normalized_user in users:
|
||||
raise DuplicateUser(normalized_user)
|
||||
users[normalized_user] = User(normalized_user, password, role)
|
||||
return users
|
||||
|
||||
async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use
|
||||
"""
|
||||
validate user password
|
||||
:param username: username
|
||||
@ -76,20 +97,20 @@ class Auth:
|
||||
del username, password
|
||||
return True
|
||||
|
||||
def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool:
|
||||
async def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool:
|
||||
"""
|
||||
check if requested path are allowed without authorization
|
||||
:param uri: request uri
|
||||
:param required: required access level
|
||||
:return: True in case if this URI can be requested without authorization and False otherwise
|
||||
"""
|
||||
if not uri:
|
||||
return False # request without context is not allowed
|
||||
if required == UserAccess.Read and self.allow_read_only:
|
||||
return True # in case if read right requested and allowed in options
|
||||
if not uri:
|
||||
return False # request without context is not allowed
|
||||
return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups)
|
||||
|
||||
def known_username(self, username: str) -> bool: # pylint: disable=no-self-use
|
||||
async def known_username(self, username: str) -> bool: # pylint: disable=no-self-use
|
||||
"""
|
||||
check if user is known
|
||||
:param username: username
|
||||
@ -98,7 +119,7 @@ class Auth:
|
||||
del username
|
||||
return True
|
||||
|
||||
def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use
|
||||
async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use
|
||||
"""
|
||||
validate if user has access to requested resource
|
||||
:param username: username
|
||||
|
@ -17,17 +17,16 @@
|
||||
# You should have received a copy of the GNU General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional
|
||||
|
||||
from ahriman.core.auth.auth import Auth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.auth_settings import AuthSettings
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
class MappingAuth(Auth):
|
||||
class Mapping(Auth):
|
||||
"""
|
||||
user authorization based on mapping from configuration file
|
||||
:ivar salt: random generated string to salt passwords
|
||||
@ -44,26 +43,7 @@ class MappingAuth(Auth):
|
||||
self.salt = configuration.get("auth", "salt")
|
||||
self._users = self.get_users(configuration)
|
||||
|
||||
@staticmethod
|
||||
def get_users(configuration: Configuration) -> Dict[str, User]:
|
||||
"""
|
||||
load users from settings
|
||||
:param configuration: configuration instance
|
||||
:return: map of username to its descriptor
|
||||
"""
|
||||
users: Dict[str, User] = {}
|
||||
for role in UserAccess:
|
||||
section = configuration.section_name("auth", role.value)
|
||||
if not configuration.has_section(section):
|
||||
continue
|
||||
for user, password in configuration[section].items():
|
||||
normalized_user = user.lower()
|
||||
if normalized_user in users:
|
||||
raise DuplicateUser(normalized_user)
|
||||
users[normalized_user] = User(normalized_user, password, role)
|
||||
return users
|
||||
|
||||
def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
|
||||
async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
|
||||
"""
|
||||
validate user password
|
||||
:param username: username
|
||||
@ -84,7 +64,7 @@ class MappingAuth(Auth):
|
||||
normalized_user = username.lower()
|
||||
return self._users.get(normalized_user)
|
||||
|
||||
def known_username(self, username: str) -> bool:
|
||||
async def known_username(self, username: str) -> bool:
|
||||
"""
|
||||
check if user is known
|
||||
:param username: username
|
||||
@ -92,7 +72,7 @@ class MappingAuth(Auth):
|
||||
"""
|
||||
return self.get_user(username) is not None
|
||||
|
||||
def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
|
||||
async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
|
||||
"""
|
||||
validate if user has access to requested resource
|
||||
:param username: username
|
||||
@ -100,6 +80,5 @@ class MappingAuth(Auth):
|
||||
:param context: URI request path
|
||||
:return: True in case if user is allowed to do this request and False otherwise
|
||||
"""
|
||||
del context
|
||||
user = self.get_user(username)
|
||||
return user is not None and user.verify_access(required)
|
@ -30,10 +30,12 @@ class AuthSettings(Enum):
|
||||
web authorization type
|
||||
:cvar Disabled: authorization is disabled
|
||||
:cvar Configuration: configuration based authorization
|
||||
:cvar OAuth: OAuth based provider
|
||||
"""
|
||||
|
||||
Disabled = auto()
|
||||
Configuration = auto()
|
||||
OAuth = auto()
|
||||
|
||||
@classmethod
|
||||
def from_option(cls: Type[AuthSettings], value: str) -> AuthSettings:
|
||||
@ -46,6 +48,8 @@ class AuthSettings(Enum):
|
||||
return cls.Disabled
|
||||
if value.lower() in ("configuration", "mapping"):
|
||||
return cls.Configuration
|
||||
if value.lower() in ('oauth', 'oauth2'):
|
||||
return cls.OAuth
|
||||
raise InvalidOption(value)
|
||||
|
||||
@property
|
||||
|
@ -52,7 +52,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
||||
:param identity: username
|
||||
:return: user identity (username) in case if user exists and None otherwise
|
||||
"""
|
||||
return identity if self.validator.known_username(identity) else None
|
||||
return identity if await self.validator.known_username(identity) else None
|
||||
|
||||
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
|
||||
"""
|
||||
@ -62,7 +62,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
|
||||
:param context: URI request path
|
||||
:return: True in case if user is allowed to perform this request and False otherwise
|
||||
"""
|
||||
return self.validator.verify_access(identity, permission, context)
|
||||
return await self.validator.verify_access(identity, permission, context)
|
||||
|
||||
|
||||
def auth_handler(validator: Auth) -> MiddlewareType:
|
||||
@ -78,7 +78,7 @@ def auth_handler(validator: Auth) -> MiddlewareType:
|
||||
else:
|
||||
permission = UserAccess.Write
|
||||
|
||||
if not validator.is_safe_request(request.path, permission):
|
||||
if not await validator.is_safe_request(request.path, permission):
|
||||
await aiohttp_security.check_permission(request, permission, request.path)
|
||||
|
||||
return await handler(request)
|
||||
|
@ -44,7 +44,7 @@ class LoginView(BaseView):
|
||||
username = data.get("username")
|
||||
|
||||
response = HTTPFound("/")
|
||||
if self.validator.check_credentials(username, data.get("password")):
|
||||
if await self.validator.check_credentials(username, data.get("password")):
|
||||
await remember(self.request, response, username)
|
||||
return response
|
||||
|
||||
|
@ -1,13 +1,13 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.core.configuration import Configuration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mapping_auth(configuration: Configuration) -> MappingAuth:
|
||||
def mapping_auth(configuration: Configuration) -> Mapping:
|
||||
"""
|
||||
auth provider fixture
|
||||
:return: auth service instance
|
||||
"""
|
||||
return MappingAuth(configuration)
|
||||
return Mapping(configuration)
|
||||
|
@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.auth import Auth
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
@ -28,63 +31,109 @@ def test_load_mapping(configuration: Configuration) -> None:
|
||||
"""
|
||||
configuration.set_option("auth", "target", "configuration")
|
||||
auth = Auth.load(configuration)
|
||||
assert isinstance(auth, MappingAuth)
|
||||
assert isinstance(auth, Mapping)
|
||||
|
||||
|
||||
def test_check_credentials(auth: Auth, user: User) -> None:
|
||||
def test_get_users(mapping_auth: Auth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return valid user list
|
||||
"""
|
||||
user_write = User("user_write", "pwd_write", UserAccess.Write)
|
||||
write_section = Configuration.section_name("auth", user_write.access.value)
|
||||
configuration.set_option(write_section, user_write.username, user_write.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = {user_write.username: user_write, user_read.username: user_read}
|
||||
assert users == expected
|
||||
|
||||
|
||||
def test_get_users_normalized(mapping_auth: Auth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return user list with normalized usernames in keys
|
||||
"""
|
||||
user = User("UsEr", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user.access.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = user.username.lower()
|
||||
assert expected in users
|
||||
assert users[expected].username == expected
|
||||
|
||||
|
||||
def test_get_users_duplicate(mapping_auth: Auth, configuration: Configuration, user: User) -> None:
|
||||
"""
|
||||
must raise exception on duplicate username
|
||||
"""
|
||||
write_section = Configuration.section_name("auth", UserAccess.Write.value)
|
||||
configuration.set_option(write_section, user.username, user.password)
|
||||
read_section = Configuration.section_name("auth", UserAccess.Read.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
with pytest.raises(DuplicateUser):
|
||||
mapping_auth.get_users(configuration)
|
||||
|
||||
|
||||
async def test_check_credentials(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must pass any credentials
|
||||
"""
|
||||
assert auth.check_credentials(user.username, user.password)
|
||||
assert auth.check_credentials(None, "")
|
||||
assert auth.check_credentials("", None)
|
||||
assert auth.check_credentials(None, None)
|
||||
assert await auth.check_credentials(user.username, user.password)
|
||||
assert await auth.check_credentials(None, "")
|
||||
assert await auth.check_credentials("", None)
|
||||
assert await auth.check_credentials(None, None)
|
||||
|
||||
|
||||
def test_is_safe_request(auth: Auth) -> None:
|
||||
async def test_is_safe_request(auth: Auth) -> None:
|
||||
"""
|
||||
must validate safe request
|
||||
"""
|
||||
# login and logout are always safe
|
||||
assert auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
|
||||
assert auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
|
||||
|
||||
auth.allowed_paths.add("/safe")
|
||||
auth.allowed_paths_groups.add("/unsafe/safe")
|
||||
|
||||
assert auth.is_safe_request("/safe", UserAccess.Write)
|
||||
assert not auth.is_safe_request("/unsafe", UserAccess.Write)
|
||||
assert auth.is_safe_request("/unsafe/safe", UserAccess.Write)
|
||||
assert auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/safe", UserAccess.Write)
|
||||
assert not await auth.is_safe_request("/unsafe", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/unsafe/safe", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
|
||||
|
||||
|
||||
def test_is_safe_request_empty(auth: Auth) -> None:
|
||||
async def test_is_safe_request_empty(auth: Auth) -> None:
|
||||
"""
|
||||
must not allow requests without path
|
||||
"""
|
||||
assert not auth.is_safe_request(None, UserAccess.Read)
|
||||
assert not auth.is_safe_request("", UserAccess.Read)
|
||||
assert not await auth.is_safe_request(None, UserAccess.Read)
|
||||
assert not await auth.is_safe_request("", UserAccess.Read)
|
||||
|
||||
|
||||
def test_is_safe_request_read_only(auth: Auth) -> None:
|
||||
async def test_is_safe_request_read_only(auth: Auth) -> None:
|
||||
"""
|
||||
must allow read-only requests if it is set in settings
|
||||
"""
|
||||
assert auth.is_safe_request("/", UserAccess.Read)
|
||||
assert await auth.is_safe_request("/", UserAccess.Read)
|
||||
auth.allow_read_only = True
|
||||
assert auth.is_safe_request("/unsafe", UserAccess.Read)
|
||||
assert await auth.is_safe_request("/unsafe", UserAccess.Read)
|
||||
|
||||
|
||||
def test_known_username(auth: Auth, user: User) -> None:
|
||||
async def test_known_username(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must allow any username
|
||||
"""
|
||||
assert auth.known_username(user.username)
|
||||
assert await auth.known_username(user.username)
|
||||
|
||||
|
||||
def test_verify_access(auth: Auth, user: User) -> None:
|
||||
async def test_verify_access(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must allow any access
|
||||
"""
|
||||
assert auth.verify_access(user.username, user.access, None)
|
||||
assert auth.verify_access(user.username, UserAccess.Write, None)
|
||||
assert await auth.verify_access(user.username, user.access, None)
|
||||
assert await auth.verify_access(user.username, UserAccess.Write, None)
|
||||
|
72
tests/ahriman/core/auth/test_mapping.py
Normal file
72
tests/ahriman/core/auth/test_mapping.py
Normal file
@ -0,0 +1,72 @@
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
async def test_check_credentials(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return true for valid credentials
|
||||
"""
|
||||
current_password = user.password
|
||||
user.password = user.hash_password(mapping_auth.salt)
|
||||
mapping_auth._users[user.username] = user
|
||||
assert await mapping_auth.check_credentials(user.username, current_password)
|
||||
# here password is hashed so it is invalid
|
||||
assert not await mapping_auth.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
async def test_check_credentials_empty(mapping_auth: Mapping) -> None:
|
||||
"""
|
||||
must reject on empty credentials
|
||||
"""
|
||||
assert not await mapping_auth.check_credentials(None, "")
|
||||
assert not await mapping_auth.check_credentials("", None)
|
||||
assert not await mapping_auth.check_credentials(None, None)
|
||||
|
||||
|
||||
async def test_check_credentials_unknown(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must reject on unknown user
|
||||
"""
|
||||
assert not await mapping_auth.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
def test_get_user(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username) == user
|
||||
|
||||
|
||||
def test_get_user_normalized(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username case-insensitive
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username.upper()) == user
|
||||
|
||||
|
||||
def test_get_user_unknown(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return None in case if no user found
|
||||
"""
|
||||
assert mapping_auth.get_user(user.username) is None
|
||||
|
||||
|
||||
async def test_known_username(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must allow only known users
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert await mapping_auth.known_username(user.username)
|
||||
assert not await mapping_auth.known_username(user.password)
|
||||
|
||||
|
||||
async def test_verify_access(mapping_auth: Mapping, user: User) -> None:
|
||||
"""
|
||||
must verify user access
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert await mapping_auth.verify_access(user.username, user.access, None)
|
||||
assert not await mapping_auth.verify_access(user.username, UserAccess.Write, None)
|
@ -1,121 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
def test_get_users(mapping_auth: MappingAuth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return valid user list
|
||||
"""
|
||||
user_write = User("user_write", "pwd_write", UserAccess.Write)
|
||||
write_section = Configuration.section_name("auth", user_write.access.value)
|
||||
configuration.set_option(write_section, user_write.username, user_write.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = {user_write.username: user_write, user_read.username: user_read}
|
||||
assert users == expected
|
||||
|
||||
|
||||
def test_get_users_normalized(mapping_auth: MappingAuth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return user list with normalized usernames in keys
|
||||
"""
|
||||
user = User("UsEr", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user.access.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = user.username.lower()
|
||||
assert expected in users
|
||||
assert users[expected].username == expected
|
||||
|
||||
|
||||
def test_get_users_duplicate(mapping_auth: MappingAuth, configuration: Configuration, user: User) -> None:
|
||||
"""
|
||||
must raise exception on duplicate username
|
||||
"""
|
||||
write_section = Configuration.section_name("auth", UserAccess.Write.value)
|
||||
configuration.set_option(write_section, user.username, user.password)
|
||||
read_section = Configuration.section_name("auth", UserAccess.Read.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
with pytest.raises(DuplicateUser):
|
||||
mapping_auth.get_users(configuration)
|
||||
|
||||
|
||||
def test_check_credentials(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return true for valid credentials
|
||||
"""
|
||||
current_password = user.password
|
||||
user.password = user.hash_password(mapping_auth.salt)
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.check_credentials(user.username, current_password)
|
||||
assert not mapping_auth.check_credentials(user.username, user.password) # here password is hashed so it is invalid
|
||||
|
||||
|
||||
def test_check_credentials_empty(mapping_auth: MappingAuth) -> None:
|
||||
"""
|
||||
must reject on empty credentials
|
||||
"""
|
||||
assert not mapping_auth.check_credentials(None, "")
|
||||
assert not mapping_auth.check_credentials("", None)
|
||||
assert not mapping_auth.check_credentials(None, None)
|
||||
|
||||
|
||||
def test_check_credentials_unknown(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must reject on unknown user
|
||||
"""
|
||||
assert not mapping_auth.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
def test_get_user(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username) == user
|
||||
|
||||
|
||||
def test_get_user_normalized(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username case-insensitive
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username.upper()) == user
|
||||
|
||||
|
||||
def test_get_user_unknown(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return None in case if no user found
|
||||
"""
|
||||
assert mapping_auth.get_user(user.username) is None
|
||||
|
||||
|
||||
def test_known_username(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must allow only known users
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.known_username(user.username)
|
||||
assert not mapping_auth.known_username(user.password)
|
||||
|
||||
|
||||
def test_verify_access(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must verify user access
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.verify_access(user.username, user.access, None)
|
||||
assert not mapping_auth.verify_access(user.username, UserAccess.Write, None)
|
@ -21,6 +21,10 @@ def test_from_option_valid() -> None:
|
||||
assert AuthSettings.from_option("no") == AuthSettings.Disabled
|
||||
assert AuthSettings.from_option("NO") == AuthSettings.Disabled
|
||||
|
||||
assert AuthSettings.from_option("oauth") == AuthSettings.OAuth
|
||||
assert AuthSettings.from_option("OAuth") == AuthSettings.OAuth
|
||||
assert AuthSettings.from_option("OAuth2") == AuthSettings.OAuth
|
||||
|
||||
assert AuthSettings.from_option("configuration") == AuthSettings.Configuration
|
||||
assert AuthSettings.from_option("ConFigUration") == AuthSettings.Configuration
|
||||
assert AuthSettings.from_option("mapping") == AuthSettings.Configuration
|
||||
|
@ -2,10 +2,9 @@ import pytest
|
||||
|
||||
from aiohttp import web
|
||||
from pytest_mock import MockerFixture
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from ahriman.core.auth.auth import Auth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
|
||||
@ -23,7 +22,7 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
|
||||
"""
|
||||
must call validator check
|
||||
"""
|
||||
authorization_policy.validator = MagicMock()
|
||||
authorization_policy.validator = AsyncMock()
|
||||
authorization_policy.validator.verify_access.return_value = True
|
||||
|
||||
assert await authorization_policy.permits(user.username, user.access, "/endpoint")
|
||||
|
Loading…
Reference in New Issue
Block a user