make auth method asyncs

This commit is contained in:
Evgenii Alekseev 2021-09-12 03:22:17 +03:00
parent 1b29b5773d
commit c4e7f63d7c
12 changed files with 199 additions and 192 deletions

View File

@ -35,7 +35,7 @@ Authorization mapping. Group name must refer to user access level, i.e. it shoul
Key is always username (case-insensitive), option value depends on authorization provider: Key is always username (case-insensitive), option value depends on authorization provider:
* `MappingAuth` (default) - reads salted password hashes from values, uses SHA512 in order to hash passwords. Password can be set by using `create-user` subcommand. * `Mapping` (default) - reads salted password hashes from values, uses SHA512 in order to hash passwords. Password can be set by using `create-user` subcommand.
## `build:*` groups ## `build:*` groups

View File

@ -19,10 +19,12 @@
# #
from __future__ import annotations from __future__ import annotations
from typing import Optional, Type from typing import Dict, Optional, Type
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.auth_settings import AuthSettings from ahriman.models.auth_settings import AuthSettings
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
@ -62,11 +64,30 @@ class Auth:
""" """
provider = AuthSettings.from_option(configuration.get("auth", "target", fallback="disabled")) provider = AuthSettings.from_option(configuration.get("auth", "target", fallback="disabled"))
if provider == AuthSettings.Configuration: if provider == AuthSettings.Configuration:
from ahriman.core.auth.mapping_auth import MappingAuth from ahriman.core.auth.mapping import Mapping
return MappingAuth(configuration) return Mapping(configuration)
return cls(configuration) return cls(configuration)
def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use @staticmethod
def get_users(configuration: Configuration) -> Dict[str, User]:
"""
load users from settings
:param configuration: configuration instance
:return: map of username to its descriptor
"""
users: Dict[str, User] = {}
for role in UserAccess:
section = configuration.section_name("auth", role.value)
if not configuration.has_section(section):
continue
for user, password in configuration[section].items():
normalized_user = user.lower()
if normalized_user in users:
raise DuplicateUser(normalized_user)
users[normalized_user] = User(normalized_user, password, role)
return users
async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use
""" """
validate user password validate user password
:param username: username :param username: username
@ -76,20 +97,20 @@ class Auth:
del username, password del username, password
return True return True
def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool: async def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool:
""" """
check if requested path are allowed without authorization check if requested path are allowed without authorization
:param uri: request uri :param uri: request uri
:param required: required access level :param required: required access level
:return: True in case if this URI can be requested without authorization and False otherwise :return: True in case if this URI can be requested without authorization and False otherwise
""" """
if not uri:
return False # request without context is not allowed
if required == UserAccess.Read and self.allow_read_only: if required == UserAccess.Read and self.allow_read_only:
return True # in case if read right requested and allowed in options return True # in case if read right requested and allowed in options
if not uri:
return False # request without context is not allowed
return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups) return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups)
def known_username(self, username: str) -> bool: # pylint: disable=no-self-use async def known_username(self, username: str) -> bool: # pylint: disable=no-self-use
""" """
check if user is known check if user is known
:param username: username :param username: username
@ -98,7 +119,7 @@ class Auth:
del username del username
return True return True
def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use
""" """
validate if user has access to requested resource validate if user has access to requested resource
:param username: username :param username: username

View File

@ -17,17 +17,16 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
from typing import Dict, Optional from typing import Optional
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.auth_settings import AuthSettings from ahriman.models.auth_settings import AuthSettings
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
class MappingAuth(Auth): class Mapping(Auth):
""" """
user authorization based on mapping from configuration file user authorization based on mapping from configuration file
:ivar salt: random generated string to salt passwords :ivar salt: random generated string to salt passwords
@ -44,26 +43,7 @@ class MappingAuth(Auth):
self.salt = configuration.get("auth", "salt") self.salt = configuration.get("auth", "salt")
self._users = self.get_users(configuration) self._users = self.get_users(configuration)
@staticmethod async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
def get_users(configuration: Configuration) -> Dict[str, User]:
"""
load users from settings
:param configuration: configuration instance
:return: map of username to its descriptor
"""
users: Dict[str, User] = {}
for role in UserAccess:
section = configuration.section_name("auth", role.value)
if not configuration.has_section(section):
continue
for user, password in configuration[section].items():
normalized_user = user.lower()
if normalized_user in users:
raise DuplicateUser(normalized_user)
users[normalized_user] = User(normalized_user, password, role)
return users
def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
""" """
validate user password validate user password
:param username: username :param username: username
@ -84,7 +64,7 @@ class MappingAuth(Auth):
normalized_user = username.lower() normalized_user = username.lower()
return self._users.get(normalized_user) return self._users.get(normalized_user)
def known_username(self, username: str) -> bool: async def known_username(self, username: str) -> bool:
""" """
check if user is known check if user is known
:param username: username :param username: username
@ -92,7 +72,7 @@ class MappingAuth(Auth):
""" """
return self.get_user(username) is not None return self.get_user(username) is not None
def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
""" """
validate if user has access to requested resource validate if user has access to requested resource
:param username: username :param username: username
@ -100,6 +80,5 @@ class MappingAuth(Auth):
:param context: URI request path :param context: URI request path
:return: True in case if user is allowed to do this request and False otherwise :return: True in case if user is allowed to do this request and False otherwise
""" """
del context
user = self.get_user(username) user = self.get_user(username)
return user is not None and user.verify_access(required) return user is not None and user.verify_access(required)

View File

@ -30,10 +30,12 @@ class AuthSettings(Enum):
web authorization type web authorization type
:cvar Disabled: authorization is disabled :cvar Disabled: authorization is disabled
:cvar Configuration: configuration based authorization :cvar Configuration: configuration based authorization
:cvar OAuth: OAuth based provider
""" """
Disabled = auto() Disabled = auto()
Configuration = auto() Configuration = auto()
OAuth = auto()
@classmethod @classmethod
def from_option(cls: Type[AuthSettings], value: str) -> AuthSettings: def from_option(cls: Type[AuthSettings], value: str) -> AuthSettings:
@ -46,6 +48,8 @@ class AuthSettings(Enum):
return cls.Disabled return cls.Disabled
if value.lower() in ("configuration", "mapping"): if value.lower() in ("configuration", "mapping"):
return cls.Configuration return cls.Configuration
if value.lower() in ('oauth', 'oauth2'):
return cls.OAuth
raise InvalidOption(value) raise InvalidOption(value)
@property @property

View File

@ -52,7 +52,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
:param identity: username :param identity: username
:return: user identity (username) in case if user exists and None otherwise :return: user identity (username) in case if user exists and None otherwise
""" """
return identity if self.validator.known_username(identity) else None return identity if await self.validator.known_username(identity) else None
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool: async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
""" """
@ -62,7 +62,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
:param context: URI request path :param context: URI request path
:return: True in case if user is allowed to perform this request and False otherwise :return: True in case if user is allowed to perform this request and False otherwise
""" """
return self.validator.verify_access(identity, permission, context) return await self.validator.verify_access(identity, permission, context)
def auth_handler(validator: Auth) -> MiddlewareType: def auth_handler(validator: Auth) -> MiddlewareType:
@ -78,7 +78,7 @@ def auth_handler(validator: Auth) -> MiddlewareType:
else: else:
permission = UserAccess.Write permission = UserAccess.Write
if not validator.is_safe_request(request.path, permission): if not await validator.is_safe_request(request.path, permission):
await aiohttp_security.check_permission(request, permission, request.path) await aiohttp_security.check_permission(request, permission, request.path)
return await handler(request) return await handler(request)

View File

@ -44,7 +44,7 @@ class LoginView(BaseView):
username = data.get("username") username = data.get("username")
response = HTTPFound("/") response = HTTPFound("/")
if self.validator.check_credentials(username, data.get("password")): if await self.validator.check_credentials(username, data.get("password")):
await remember(self.request, response, username) await remember(self.request, response, username)
return response return response

View File

@ -1,13 +1,13 @@
import pytest import pytest
from ahriman.core.auth.mapping_auth import MappingAuth from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
@pytest.fixture @pytest.fixture
def mapping_auth(configuration: Configuration) -> MappingAuth: def mapping_auth(configuration: Configuration) -> Mapping:
""" """
auth provider fixture auth provider fixture
:return: auth service instance :return: auth service instance
""" """
return MappingAuth(configuration) return Mapping(configuration)

View File

@ -1,6 +1,9 @@
import pytest
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.core.auth.mapping_auth import MappingAuth from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
@ -28,63 +31,109 @@ def test_load_mapping(configuration: Configuration) -> None:
""" """
configuration.set_option("auth", "target", "configuration") configuration.set_option("auth", "target", "configuration")
auth = Auth.load(configuration) auth = Auth.load(configuration)
assert isinstance(auth, MappingAuth) assert isinstance(auth, Mapping)
def test_check_credentials(auth: Auth, user: User) -> None: def test_get_users(mapping_auth: Auth, configuration: Configuration) -> None:
"""
must return valid user list
"""
user_write = User("user_write", "pwd_write", UserAccess.Write)
write_section = Configuration.section_name("auth", user_write.access.value)
configuration.set_option(write_section, user_write.username, user_write.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
users = mapping_auth.get_users(configuration)
expected = {user_write.username: user_write, user_read.username: user_read}
assert users == expected
def test_get_users_normalized(mapping_auth: Auth, configuration: Configuration) -> None:
"""
must return user list with normalized usernames in keys
"""
user = User("UsEr", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user.access.value)
configuration.set_option(read_section, user.username, user.password)
users = mapping_auth.get_users(configuration)
expected = user.username.lower()
assert expected in users
assert users[expected].username == expected
def test_get_users_duplicate(mapping_auth: Auth, configuration: Configuration, user: User) -> None:
"""
must raise exception on duplicate username
"""
write_section = Configuration.section_name("auth", UserAccess.Write.value)
configuration.set_option(write_section, user.username, user.password)
read_section = Configuration.section_name("auth", UserAccess.Read.value)
configuration.set_option(read_section, user.username, user.password)
with pytest.raises(DuplicateUser):
mapping_auth.get_users(configuration)
async def test_check_credentials(auth: Auth, user: User) -> None:
""" """
must pass any credentials must pass any credentials
""" """
assert auth.check_credentials(user.username, user.password) assert await auth.check_credentials(user.username, user.password)
assert auth.check_credentials(None, "") assert await auth.check_credentials(None, "")
assert auth.check_credentials("", None) assert await auth.check_credentials("", None)
assert auth.check_credentials(None, None) assert await auth.check_credentials(None, None)
def test_is_safe_request(auth: Auth) -> None: async def test_is_safe_request(auth: Auth) -> None:
""" """
must validate safe request must validate safe request
""" """
# login and logout are always safe # login and logout are always safe
assert auth.is_safe_request("/user-api/v1/login", UserAccess.Write) assert await auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
assert auth.is_safe_request("/user-api/v1/logout", UserAccess.Write) assert await auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
auth.allowed_paths.add("/safe") auth.allowed_paths.add("/safe")
auth.allowed_paths_groups.add("/unsafe/safe") auth.allowed_paths_groups.add("/unsafe/safe")
assert auth.is_safe_request("/safe", UserAccess.Write) assert await auth.is_safe_request("/safe", UserAccess.Write)
assert not auth.is_safe_request("/unsafe", UserAccess.Write) assert not await auth.is_safe_request("/unsafe", UserAccess.Write)
assert auth.is_safe_request("/unsafe/safe", UserAccess.Write) assert await auth.is_safe_request("/unsafe/safe", UserAccess.Write)
assert auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write) assert await auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
def test_is_safe_request_empty(auth: Auth) -> None: async def test_is_safe_request_empty(auth: Auth) -> None:
""" """
must not allow requests without path must not allow requests without path
""" """
assert not auth.is_safe_request(None, UserAccess.Read) assert not await auth.is_safe_request(None, UserAccess.Read)
assert not auth.is_safe_request("", UserAccess.Read) assert not await auth.is_safe_request("", UserAccess.Read)
def test_is_safe_request_read_only(auth: Auth) -> None: async def test_is_safe_request_read_only(auth: Auth) -> None:
""" """
must allow read-only requests if it is set in settings must allow read-only requests if it is set in settings
""" """
assert auth.is_safe_request("/", UserAccess.Read) assert await auth.is_safe_request("/", UserAccess.Read)
auth.allow_read_only = True auth.allow_read_only = True
assert auth.is_safe_request("/unsafe", UserAccess.Read) assert await auth.is_safe_request("/unsafe", UserAccess.Read)
def test_known_username(auth: Auth, user: User) -> None: async def test_known_username(auth: Auth, user: User) -> None:
""" """
must allow any username must allow any username
""" """
assert auth.known_username(user.username) assert await auth.known_username(user.username)
def test_verify_access(auth: Auth, user: User) -> None: async def test_verify_access(auth: Auth, user: User) -> None:
""" """
must allow any access must allow any access
""" """
assert auth.verify_access(user.username, user.access, None) assert await auth.verify_access(user.username, user.access, None)
assert auth.verify_access(user.username, UserAccess.Write, None) assert await auth.verify_access(user.username, UserAccess.Write, None)

View File

@ -0,0 +1,72 @@
from ahriman.core.auth.mapping import Mapping
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
async def test_check_credentials(mapping_auth: Mapping, user: User) -> None:
"""
must return true for valid credentials
"""
current_password = user.password
user.password = user.hash_password(mapping_auth.salt)
mapping_auth._users[user.username] = user
assert await mapping_auth.check_credentials(user.username, current_password)
# here password is hashed so it is invalid
assert not await mapping_auth.check_credentials(user.username, user.password)
async def test_check_credentials_empty(mapping_auth: Mapping) -> None:
"""
must reject on empty credentials
"""
assert not await mapping_auth.check_credentials(None, "")
assert not await mapping_auth.check_credentials("", None)
assert not await mapping_auth.check_credentials(None, None)
async def test_check_credentials_unknown(mapping_auth: Mapping, user: User) -> None:
"""
must reject on unknown user
"""
assert not await mapping_auth.check_credentials(user.username, user.password)
def test_get_user(mapping_auth: Mapping, user: User) -> None:
"""
must return user from storage by username
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username) == user
def test_get_user_normalized(mapping_auth: Mapping, user: User) -> None:
"""
must return user from storage by username case-insensitive
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username.upper()) == user
def test_get_user_unknown(mapping_auth: Mapping, user: User) -> None:
"""
must return None in case if no user found
"""
assert mapping_auth.get_user(user.username) is None
async def test_known_username(mapping_auth: Mapping, user: User) -> None:
"""
must allow only known users
"""
mapping_auth._users[user.username] = user
assert await mapping_auth.known_username(user.username)
assert not await mapping_auth.known_username(user.password)
async def test_verify_access(mapping_auth: Mapping, user: User) -> None:
"""
must verify user access
"""
mapping_auth._users[user.username] = user
assert await mapping_auth.verify_access(user.username, user.access, None)
assert not await mapping_auth.verify_access(user.username, UserAccess.Write, None)

View File

@ -1,121 +0,0 @@
import pytest
from ahriman.core.auth.mapping_auth import MappingAuth
from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
def test_get_users(mapping_auth: MappingAuth, configuration: Configuration) -> None:
"""
must return valid user list
"""
user_write = User("user_write", "pwd_write", UserAccess.Write)
write_section = Configuration.section_name("auth", user_write.access.value)
configuration.set_option(write_section, user_write.username, user_write.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
users = mapping_auth.get_users(configuration)
expected = {user_write.username: user_write, user_read.username: user_read}
assert users == expected
def test_get_users_normalized(mapping_auth: MappingAuth, configuration: Configuration) -> None:
"""
must return user list with normalized usernames in keys
"""
user = User("UsEr", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user.access.value)
configuration.set_option(read_section, user.username, user.password)
users = mapping_auth.get_users(configuration)
expected = user.username.lower()
assert expected in users
assert users[expected].username == expected
def test_get_users_duplicate(mapping_auth: MappingAuth, configuration: Configuration, user: User) -> None:
"""
must raise exception on duplicate username
"""
write_section = Configuration.section_name("auth", UserAccess.Write.value)
configuration.set_option(write_section, user.username, user.password)
read_section = Configuration.section_name("auth", UserAccess.Read.value)
configuration.set_option(read_section, user.username, user.password)
with pytest.raises(DuplicateUser):
mapping_auth.get_users(configuration)
def test_check_credentials(mapping_auth: MappingAuth, user: User) -> None:
"""
must return true for valid credentials
"""
current_password = user.password
user.password = user.hash_password(mapping_auth.salt)
mapping_auth._users[user.username] = user
assert mapping_auth.check_credentials(user.username, current_password)
assert not mapping_auth.check_credentials(user.username, user.password) # here password is hashed so it is invalid
def test_check_credentials_empty(mapping_auth: MappingAuth) -> None:
"""
must reject on empty credentials
"""
assert not mapping_auth.check_credentials(None, "")
assert not mapping_auth.check_credentials("", None)
assert not mapping_auth.check_credentials(None, None)
def test_check_credentials_unknown(mapping_auth: MappingAuth, user: User) -> None:
"""
must reject on unknown user
"""
assert not mapping_auth.check_credentials(user.username, user.password)
def test_get_user(mapping_auth: MappingAuth, user: User) -> None:
"""
must return user from storage by username
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username) == user
def test_get_user_normalized(mapping_auth: MappingAuth, user: User) -> None:
"""
must return user from storage by username case-insensitive
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username.upper()) == user
def test_get_user_unknown(mapping_auth: MappingAuth, user: User) -> None:
"""
must return None in case if no user found
"""
assert mapping_auth.get_user(user.username) is None
def test_known_username(mapping_auth: MappingAuth, user: User) -> None:
"""
must allow only known users
"""
mapping_auth._users[user.username] = user
assert mapping_auth.known_username(user.username)
assert not mapping_auth.known_username(user.password)
def test_verify_access(mapping_auth: MappingAuth, user: User) -> None:
"""
must verify user access
"""
mapping_auth._users[user.username] = user
assert mapping_auth.verify_access(user.username, user.access, None)
assert not mapping_auth.verify_access(user.username, UserAccess.Write, None)

View File

@ -21,6 +21,10 @@ def test_from_option_valid() -> None:
assert AuthSettings.from_option("no") == AuthSettings.Disabled assert AuthSettings.from_option("no") == AuthSettings.Disabled
assert AuthSettings.from_option("NO") == AuthSettings.Disabled assert AuthSettings.from_option("NO") == AuthSettings.Disabled
assert AuthSettings.from_option("oauth") == AuthSettings.OAuth
assert AuthSettings.from_option("OAuth") == AuthSettings.OAuth
assert AuthSettings.from_option("OAuth2") == AuthSettings.OAuth
assert AuthSettings.from_option("configuration") == AuthSettings.Configuration assert AuthSettings.from_option("configuration") == AuthSettings.Configuration
assert AuthSettings.from_option("ConFigUration") == AuthSettings.Configuration assert AuthSettings.from_option("ConFigUration") == AuthSettings.Configuration
assert AuthSettings.from_option("mapping") == AuthSettings.Configuration assert AuthSettings.from_option("mapping") == AuthSettings.Configuration

View File

@ -2,10 +2,9 @@ import pytest
from aiohttp import web from aiohttp import web
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from unittest.mock import AsyncMock, MagicMock from unittest.mock import AsyncMock
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.core.configuration import Configuration
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
@ -23,7 +22,7 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
""" """
must call validator check must call validator check
""" """
authorization_policy.validator = MagicMock() authorization_policy.validator = AsyncMock()
authorization_policy.validator.verify_access.return_value = True authorization_policy.validator.verify_access.return_value = True
assert await authorization_policy.permits(user.username, user.access, "/endpoint") assert await authorization_policy.permits(user.username, user.access, "/endpoint")