mirror of
https://github.com/arcan1s/ahriman.git
synced 2025-07-29 13:49:57 +00:00
OAuth2 (#32)
* make auth method asyncs * oauth2 demo support * full coverage * update docs
This commit is contained in:
@ -1,13 +1,26 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.core.auth.oauth import OAuth
|
||||
from ahriman.core.configuration import Configuration
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mapping_auth(configuration: Configuration) -> MappingAuth:
|
||||
def mapping(configuration: Configuration) -> Mapping:
|
||||
"""
|
||||
auth provider fixture
|
||||
:param configuration: configuration fixture
|
||||
:return: auth service instance
|
||||
"""
|
||||
return MappingAuth(configuration)
|
||||
return Mapping(configuration)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth(configuration: Configuration) -> OAuth:
|
||||
"""
|
||||
OAuth provider fixture
|
||||
:param configuration: configuration fixture
|
||||
:return: OAuth2 service instance
|
||||
"""
|
||||
configuration.set("web", "address", "https://example.com")
|
||||
return OAuth(configuration)
|
||||
|
@ -1,10 +1,22 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.auth import Auth
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.core.auth.oauth import OAuth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
def test_auth_control(auth: Auth) -> None:
|
||||
"""
|
||||
must return a control for authorization
|
||||
"""
|
||||
assert auth.auth_control
|
||||
assert "button" in auth.auth_control # I think it should be button
|
||||
|
||||
|
||||
def test_load_dummy(configuration: Configuration) -> None:
|
||||
"""
|
||||
must load dummy validator if authorization is not enabled
|
||||
@ -28,63 +40,119 @@ def test_load_mapping(configuration: Configuration) -> None:
|
||||
"""
|
||||
configuration.set_option("auth", "target", "configuration")
|
||||
auth = Auth.load(configuration)
|
||||
assert isinstance(auth, MappingAuth)
|
||||
assert isinstance(auth, Mapping)
|
||||
|
||||
|
||||
def test_check_credentials(auth: Auth, user: User) -> None:
|
||||
def test_load_oauth(configuration: Configuration) -> None:
|
||||
"""
|
||||
must load OAuth2 validator if option set
|
||||
"""
|
||||
configuration.set_option("auth", "target", "oauth")
|
||||
configuration.set_option("web", "address", "https://example.com")
|
||||
auth = Auth.load(configuration)
|
||||
assert isinstance(auth, OAuth)
|
||||
|
||||
|
||||
def test_get_users(mapping: Auth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return valid user list
|
||||
"""
|
||||
user_write = User("user_write", "pwd_write", UserAccess.Write)
|
||||
write_section = Configuration.section_name("auth", user_write.access.value)
|
||||
configuration.set_option(write_section, user_write.username, user_write.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
|
||||
users = mapping.get_users(configuration)
|
||||
expected = {user_write.username: user_write, user_read.username: user_read}
|
||||
assert users == expected
|
||||
|
||||
|
||||
def test_get_users_normalized(mapping: Auth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return user list with normalized usernames in keys
|
||||
"""
|
||||
user = User("UsEr", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user.access.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
users = mapping.get_users(configuration)
|
||||
expected = user.username.lower()
|
||||
assert expected in users
|
||||
assert users[expected].username == expected
|
||||
|
||||
|
||||
def test_get_users_duplicate(mapping: Auth, configuration: Configuration, user: User) -> None:
|
||||
"""
|
||||
must raise exception on duplicate username
|
||||
"""
|
||||
write_section = Configuration.section_name("auth", UserAccess.Write.value)
|
||||
configuration.set_option(write_section, user.username, user.password)
|
||||
read_section = Configuration.section_name("auth", UserAccess.Read.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
with pytest.raises(DuplicateUser):
|
||||
mapping.get_users(configuration)
|
||||
|
||||
|
||||
async def test_check_credentials(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must pass any credentials
|
||||
"""
|
||||
assert auth.check_credentials(user.username, user.password)
|
||||
assert auth.check_credentials(None, "")
|
||||
assert auth.check_credentials("", None)
|
||||
assert auth.check_credentials(None, None)
|
||||
assert await auth.check_credentials(user.username, user.password)
|
||||
assert await auth.check_credentials(None, "")
|
||||
assert await auth.check_credentials("", None)
|
||||
assert await auth.check_credentials(None, None)
|
||||
|
||||
|
||||
def test_is_safe_request(auth: Auth) -> None:
|
||||
async def test_is_safe_request(auth: Auth) -> None:
|
||||
"""
|
||||
must validate safe request
|
||||
"""
|
||||
# login and logout are always safe
|
||||
assert auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
|
||||
assert auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
|
||||
|
||||
auth.allowed_paths.add("/safe")
|
||||
auth.allowed_paths_groups.add("/unsafe/safe")
|
||||
|
||||
assert auth.is_safe_request("/safe", UserAccess.Write)
|
||||
assert not auth.is_safe_request("/unsafe", UserAccess.Write)
|
||||
assert auth.is_safe_request("/unsafe/safe", UserAccess.Write)
|
||||
assert auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/safe", UserAccess.Write)
|
||||
assert not await auth.is_safe_request("/unsafe", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/unsafe/safe", UserAccess.Write)
|
||||
assert await auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
|
||||
|
||||
|
||||
def test_is_safe_request_empty(auth: Auth) -> None:
|
||||
async def test_is_safe_request_empty(auth: Auth) -> None:
|
||||
"""
|
||||
must not allow requests without path
|
||||
"""
|
||||
assert not auth.is_safe_request(None, UserAccess.Read)
|
||||
assert not auth.is_safe_request("", UserAccess.Read)
|
||||
assert not await auth.is_safe_request(None, UserAccess.Read)
|
||||
assert not await auth.is_safe_request("", UserAccess.Read)
|
||||
|
||||
|
||||
def test_is_safe_request_read_only(auth: Auth) -> None:
|
||||
async def test_is_safe_request_read_only(auth: Auth) -> None:
|
||||
"""
|
||||
must allow read-only requests if it is set in settings
|
||||
"""
|
||||
assert auth.is_safe_request("/", UserAccess.Read)
|
||||
assert await auth.is_safe_request("/", UserAccess.Read)
|
||||
auth.allow_read_only = True
|
||||
assert auth.is_safe_request("/unsafe", UserAccess.Read)
|
||||
assert await auth.is_safe_request("/unsafe", UserAccess.Read)
|
||||
|
||||
|
||||
def test_known_username(auth: Auth, user: User) -> None:
|
||||
async def test_known_username(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must allow any username
|
||||
"""
|
||||
assert auth.known_username(user.username)
|
||||
assert await auth.known_username(user.username)
|
||||
|
||||
|
||||
def test_verify_access(auth: Auth, user: User) -> None:
|
||||
async def test_verify_access(auth: Auth, user: User) -> None:
|
||||
"""
|
||||
must allow any access
|
||||
"""
|
||||
assert auth.verify_access(user.username, user.access, None)
|
||||
assert auth.verify_access(user.username, UserAccess.Write, None)
|
||||
assert await auth.verify_access(user.username, user.access, None)
|
||||
assert await auth.verify_access(user.username, UserAccess.Write, None)
|
||||
|
73
tests/ahriman/core/auth/test_mapping.py
Normal file
73
tests/ahriman/core/auth/test_mapping.py
Normal file
@ -0,0 +1,73 @@
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
async def test_check_credentials(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return true for valid credentials
|
||||
"""
|
||||
current_password = user.password
|
||||
user.password = user.hash_password(mapping.salt)
|
||||
mapping._users[user.username] = user
|
||||
assert await mapping.check_credentials(user.username, current_password)
|
||||
# here password is hashed so it is invalid
|
||||
assert not await mapping.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
async def test_check_credentials_empty(mapping: Mapping) -> None:
|
||||
"""
|
||||
must reject on empty credentials
|
||||
"""
|
||||
assert not await mapping.check_credentials(None, "")
|
||||
assert not await mapping.check_credentials("", None)
|
||||
assert not await mapping.check_credentials(None, None)
|
||||
|
||||
|
||||
async def test_check_credentials_unknown(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must reject on unknown user
|
||||
"""
|
||||
assert not await mapping.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
def test_get_user(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username
|
||||
"""
|
||||
mapping._users[user.username] = user
|
||||
assert mapping.get_user(user.username) == user
|
||||
|
||||
|
||||
def test_get_user_normalized(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username case-insensitive
|
||||
"""
|
||||
mapping._users[user.username] = user
|
||||
assert mapping.get_user(user.username.upper()) == user
|
||||
|
||||
|
||||
def test_get_user_unknown(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must return None in case if no user found
|
||||
"""
|
||||
assert mapping.get_user(user.username) is None
|
||||
|
||||
|
||||
async def test_known_username(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must allow only known users
|
||||
"""
|
||||
mapping._users[user.username] = user
|
||||
assert await mapping.known_username(user.username)
|
||||
assert not await mapping.known_username(None)
|
||||
assert not await mapping.known_username(user.password)
|
||||
|
||||
|
||||
async def test_verify_access(mapping: Mapping, user: User) -> None:
|
||||
"""
|
||||
must verify user access
|
||||
"""
|
||||
mapping._users[user.username] = user
|
||||
assert await mapping.verify_access(user.username, user.access, None)
|
||||
assert not await mapping.verify_access(user.username, UserAccess.Write, None)
|
@ -1,121 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from ahriman.core.auth.mapping_auth import MappingAuth
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.exceptions import DuplicateUser
|
||||
from ahriman.models.user import User
|
||||
from ahriman.models.user_access import UserAccess
|
||||
|
||||
|
||||
def test_get_users(mapping_auth: MappingAuth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return valid user list
|
||||
"""
|
||||
user_write = User("user_write", "pwd_write", UserAccess.Write)
|
||||
write_section = Configuration.section_name("auth", user_write.access.value)
|
||||
configuration.set_option(write_section, user_write.username, user_write.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
user_read = User("user_read", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user_read.access.value)
|
||||
configuration.set_option(read_section, user_read.username, user_read.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = {user_write.username: user_write, user_read.username: user_read}
|
||||
assert users == expected
|
||||
|
||||
|
||||
def test_get_users_normalized(mapping_auth: MappingAuth, configuration: Configuration) -> None:
|
||||
"""
|
||||
must return user list with normalized usernames in keys
|
||||
"""
|
||||
user = User("UsEr", "pwd_read", UserAccess.Read)
|
||||
read_section = Configuration.section_name("auth", user.access.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
users = mapping_auth.get_users(configuration)
|
||||
expected = user.username.lower()
|
||||
assert expected in users
|
||||
assert users[expected].username == expected
|
||||
|
||||
|
||||
def test_get_users_duplicate(mapping_auth: MappingAuth, configuration: Configuration, user: User) -> None:
|
||||
"""
|
||||
must raise exception on duplicate username
|
||||
"""
|
||||
write_section = Configuration.section_name("auth", UserAccess.Write.value)
|
||||
configuration.set_option(write_section, user.username, user.password)
|
||||
read_section = Configuration.section_name("auth", UserAccess.Read.value)
|
||||
configuration.set_option(read_section, user.username, user.password)
|
||||
|
||||
with pytest.raises(DuplicateUser):
|
||||
mapping_auth.get_users(configuration)
|
||||
|
||||
|
||||
def test_check_credentials(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return true for valid credentials
|
||||
"""
|
||||
current_password = user.password
|
||||
user.password = user.hash_password(mapping_auth.salt)
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.check_credentials(user.username, current_password)
|
||||
assert not mapping_auth.check_credentials(user.username, user.password) # here password is hashed so it is invalid
|
||||
|
||||
|
||||
def test_check_credentials_empty(mapping_auth: MappingAuth) -> None:
|
||||
"""
|
||||
must reject on empty credentials
|
||||
"""
|
||||
assert not mapping_auth.check_credentials(None, "")
|
||||
assert not mapping_auth.check_credentials("", None)
|
||||
assert not mapping_auth.check_credentials(None, None)
|
||||
|
||||
|
||||
def test_check_credentials_unknown(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must reject on unknown user
|
||||
"""
|
||||
assert not mapping_auth.check_credentials(user.username, user.password)
|
||||
|
||||
|
||||
def test_get_user(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username) == user
|
||||
|
||||
|
||||
def test_get_user_normalized(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return user from storage by username case-insensitive
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.get_user(user.username.upper()) == user
|
||||
|
||||
|
||||
def test_get_user_unknown(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must return None in case if no user found
|
||||
"""
|
||||
assert mapping_auth.get_user(user.username) is None
|
||||
|
||||
|
||||
def test_known_username(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must allow only known users
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.known_username(user.username)
|
||||
assert not mapping_auth.known_username(user.password)
|
||||
|
||||
|
||||
def test_verify_access(mapping_auth: MappingAuth, user: User) -> None:
|
||||
"""
|
||||
must verify user access
|
||||
"""
|
||||
mapping_auth._users[user.username] = user
|
||||
assert mapping_auth.verify_access(user.username, user.access, None)
|
||||
assert not mapping_auth.verify_access(user.username, UserAccess.Write, None)
|
98
tests/ahriman/core/auth/test_oauth.py
Normal file
98
tests/ahriman/core/auth/test_oauth.py
Normal file
@ -0,0 +1,98 @@
|
||||
import aioauth_client
|
||||
import pytest
|
||||
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from ahriman.core.auth.oauth import OAuth
|
||||
from ahriman.core.exceptions import InvalidOption
|
||||
|
||||
|
||||
def test_auth_control(oauth: OAuth) -> None:
|
||||
"""
|
||||
must return a control for authorization
|
||||
"""
|
||||
assert oauth.auth_control
|
||||
assert "<a" in oauth.auth_control # I think it should be a link
|
||||
|
||||
|
||||
def test_get_provider() -> None:
|
||||
"""
|
||||
must return valid provider type
|
||||
"""
|
||||
assert OAuth.get_provider("OAuth2Client") == aioauth_client.OAuth2Client
|
||||
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
|
||||
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
|
||||
|
||||
|
||||
def test_get_provider_not_a_type() -> None:
|
||||
"""
|
||||
must raise an exception if attribute is not a type
|
||||
"""
|
||||
with pytest.raises(InvalidOption):
|
||||
OAuth.get_provider("__version__")
|
||||
|
||||
|
||||
def test_get_provider_invalid_type() -> None:
|
||||
"""
|
||||
must raise an exception if attribute is not an OAuth2 client
|
||||
"""
|
||||
with pytest.raises(InvalidOption):
|
||||
OAuth.get_provider("User")
|
||||
with pytest.raises(InvalidOption):
|
||||
OAuth.get_provider("OAuth1Client")
|
||||
|
||||
|
||||
def test_get_client(oauth: OAuth) -> None:
|
||||
"""
|
||||
must return valid OAuth2 client
|
||||
"""
|
||||
client = oauth.get_client()
|
||||
assert isinstance(client, aioauth_client.GoogleClient)
|
||||
assert client.client_id == oauth.client_id
|
||||
assert client.client_secret == oauth.client_secret
|
||||
|
||||
|
||||
def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must generate valid OAuth authorization URL
|
||||
"""
|
||||
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
|
||||
oauth.get_oauth_url()
|
||||
authorize_url_mock.assert_called_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
|
||||
|
||||
|
||||
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must return authorized user ID
|
||||
"""
|
||||
access_token_mock = mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
|
||||
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
|
||||
return_value=(aioauth_client.User(email="email"), ""))
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
access_token_mock.assert_called_with("code", redirect_uri=oauth.redirect_uri)
|
||||
user_info_mock.assert_called_once()
|
||||
assert email == "email"
|
||||
|
||||
|
||||
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must return None in case of OAuth request error (get_access_token)
|
||||
"""
|
||||
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception())
|
||||
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
assert email is None
|
||||
user_info_mock.assert_not_called()
|
||||
|
||||
|
||||
async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must return None in case of OAuth request error (user_info)
|
||||
"""
|
||||
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
|
||||
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception())
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
assert email is None
|
Reference in New Issue
Block a user