* make auth method asyncs

* oauth2 demo support

* full coverage

* update docs
This commit is contained in:
2021-09-12 21:41:38 +03:00
committed by GitHub
parent 1b29b5773d
commit d19deb57e7
39 changed files with 695 additions and 251 deletions

View File

@ -1,13 +1,26 @@
import pytest
from ahriman.core.auth.mapping_auth import MappingAuth
from ahriman.core.auth.mapping import Mapping
from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration
@pytest.fixture
def mapping_auth(configuration: Configuration) -> MappingAuth:
def mapping(configuration: Configuration) -> Mapping:
"""
auth provider fixture
:param configuration: configuration fixture
:return: auth service instance
"""
return MappingAuth(configuration)
return Mapping(configuration)
@pytest.fixture
def oauth(configuration: Configuration) -> OAuth:
"""
OAuth provider fixture
:param configuration: configuration fixture
:return: OAuth2 service instance
"""
configuration.set("web", "address", "https://example.com")
return OAuth(configuration)

View File

@ -1,10 +1,22 @@
import pytest
from ahriman.core.auth.auth import Auth
from ahriman.core.auth.mapping_auth import MappingAuth
from ahriman.core.auth.mapping import Mapping
from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
def test_auth_control(auth: Auth) -> None:
"""
must return a control for authorization
"""
assert auth.auth_control
assert "button" in auth.auth_control # I think it should be button
def test_load_dummy(configuration: Configuration) -> None:
"""
must load dummy validator if authorization is not enabled
@ -28,63 +40,119 @@ def test_load_mapping(configuration: Configuration) -> None:
"""
configuration.set_option("auth", "target", "configuration")
auth = Auth.load(configuration)
assert isinstance(auth, MappingAuth)
assert isinstance(auth, Mapping)
def test_check_credentials(auth: Auth, user: User) -> None:
def test_load_oauth(configuration: Configuration) -> None:
"""
must load OAuth2 validator if option set
"""
configuration.set_option("auth", "target", "oauth")
configuration.set_option("web", "address", "https://example.com")
auth = Auth.load(configuration)
assert isinstance(auth, OAuth)
def test_get_users(mapping: Auth, configuration: Configuration) -> None:
"""
must return valid user list
"""
user_write = User("user_write", "pwd_write", UserAccess.Write)
write_section = Configuration.section_name("auth", user_write.access.value)
configuration.set_option(write_section, user_write.username, user_write.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
users = mapping.get_users(configuration)
expected = {user_write.username: user_write, user_read.username: user_read}
assert users == expected
def test_get_users_normalized(mapping: Auth, configuration: Configuration) -> None:
"""
must return user list with normalized usernames in keys
"""
user = User("UsEr", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user.access.value)
configuration.set_option(read_section, user.username, user.password)
users = mapping.get_users(configuration)
expected = user.username.lower()
assert expected in users
assert users[expected].username == expected
def test_get_users_duplicate(mapping: Auth, configuration: Configuration, user: User) -> None:
"""
must raise exception on duplicate username
"""
write_section = Configuration.section_name("auth", UserAccess.Write.value)
configuration.set_option(write_section, user.username, user.password)
read_section = Configuration.section_name("auth", UserAccess.Read.value)
configuration.set_option(read_section, user.username, user.password)
with pytest.raises(DuplicateUser):
mapping.get_users(configuration)
async def test_check_credentials(auth: Auth, user: User) -> None:
"""
must pass any credentials
"""
assert auth.check_credentials(user.username, user.password)
assert auth.check_credentials(None, "")
assert auth.check_credentials("", None)
assert auth.check_credentials(None, None)
assert await auth.check_credentials(user.username, user.password)
assert await auth.check_credentials(None, "")
assert await auth.check_credentials("", None)
assert await auth.check_credentials(None, None)
def test_is_safe_request(auth: Auth) -> None:
async def test_is_safe_request(auth: Auth) -> None:
"""
must validate safe request
"""
# login and logout are always safe
assert auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
assert auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
assert await auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
assert await auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
auth.allowed_paths.add("/safe")
auth.allowed_paths_groups.add("/unsafe/safe")
assert auth.is_safe_request("/safe", UserAccess.Write)
assert not auth.is_safe_request("/unsafe", UserAccess.Write)
assert auth.is_safe_request("/unsafe/safe", UserAccess.Write)
assert auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
assert await auth.is_safe_request("/safe", UserAccess.Write)
assert not await auth.is_safe_request("/unsafe", UserAccess.Write)
assert await auth.is_safe_request("/unsafe/safe", UserAccess.Write)
assert await auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
def test_is_safe_request_empty(auth: Auth) -> None:
async def test_is_safe_request_empty(auth: Auth) -> None:
"""
must not allow requests without path
"""
assert not auth.is_safe_request(None, UserAccess.Read)
assert not auth.is_safe_request("", UserAccess.Read)
assert not await auth.is_safe_request(None, UserAccess.Read)
assert not await auth.is_safe_request("", UserAccess.Read)
def test_is_safe_request_read_only(auth: Auth) -> None:
async def test_is_safe_request_read_only(auth: Auth) -> None:
"""
must allow read-only requests if it is set in settings
"""
assert auth.is_safe_request("/", UserAccess.Read)
assert await auth.is_safe_request("/", UserAccess.Read)
auth.allow_read_only = True
assert auth.is_safe_request("/unsafe", UserAccess.Read)
assert await auth.is_safe_request("/unsafe", UserAccess.Read)
def test_known_username(auth: Auth, user: User) -> None:
async def test_known_username(auth: Auth, user: User) -> None:
"""
must allow any username
"""
assert auth.known_username(user.username)
assert await auth.known_username(user.username)
def test_verify_access(auth: Auth, user: User) -> None:
async def test_verify_access(auth: Auth, user: User) -> None:
"""
must allow any access
"""
assert auth.verify_access(user.username, user.access, None)
assert auth.verify_access(user.username, UserAccess.Write, None)
assert await auth.verify_access(user.username, user.access, None)
assert await auth.verify_access(user.username, UserAccess.Write, None)

View File

@ -0,0 +1,73 @@
from ahriman.core.auth.mapping import Mapping
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
async def test_check_credentials(mapping: Mapping, user: User) -> None:
"""
must return true for valid credentials
"""
current_password = user.password
user.password = user.hash_password(mapping.salt)
mapping._users[user.username] = user
assert await mapping.check_credentials(user.username, current_password)
# here password is hashed so it is invalid
assert not await mapping.check_credentials(user.username, user.password)
async def test_check_credentials_empty(mapping: Mapping) -> None:
"""
must reject on empty credentials
"""
assert not await mapping.check_credentials(None, "")
assert not await mapping.check_credentials("", None)
assert not await mapping.check_credentials(None, None)
async def test_check_credentials_unknown(mapping: Mapping, user: User) -> None:
"""
must reject on unknown user
"""
assert not await mapping.check_credentials(user.username, user.password)
def test_get_user(mapping: Mapping, user: User) -> None:
"""
must return user from storage by username
"""
mapping._users[user.username] = user
assert mapping.get_user(user.username) == user
def test_get_user_normalized(mapping: Mapping, user: User) -> None:
"""
must return user from storage by username case-insensitive
"""
mapping._users[user.username] = user
assert mapping.get_user(user.username.upper()) == user
def test_get_user_unknown(mapping: Mapping, user: User) -> None:
"""
must return None in case if no user found
"""
assert mapping.get_user(user.username) is None
async def test_known_username(mapping: Mapping, user: User) -> None:
"""
must allow only known users
"""
mapping._users[user.username] = user
assert await mapping.known_username(user.username)
assert not await mapping.known_username(None)
assert not await mapping.known_username(user.password)
async def test_verify_access(mapping: Mapping, user: User) -> None:
"""
must verify user access
"""
mapping._users[user.username] = user
assert await mapping.verify_access(user.username, user.access, None)
assert not await mapping.verify_access(user.username, UserAccess.Write, None)

View File

@ -1,121 +0,0 @@
import pytest
from ahriman.core.auth.mapping_auth import MappingAuth
from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
def test_get_users(mapping_auth: MappingAuth, configuration: Configuration) -> None:
"""
must return valid user list
"""
user_write = User("user_write", "pwd_write", UserAccess.Write)
write_section = Configuration.section_name("auth", user_write.access.value)
configuration.set_option(write_section, user_write.username, user_write.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
user_read = User("user_read", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password)
users = mapping_auth.get_users(configuration)
expected = {user_write.username: user_write, user_read.username: user_read}
assert users == expected
def test_get_users_normalized(mapping_auth: MappingAuth, configuration: Configuration) -> None:
"""
must return user list with normalized usernames in keys
"""
user = User("UsEr", "pwd_read", UserAccess.Read)
read_section = Configuration.section_name("auth", user.access.value)
configuration.set_option(read_section, user.username, user.password)
users = mapping_auth.get_users(configuration)
expected = user.username.lower()
assert expected in users
assert users[expected].username == expected
def test_get_users_duplicate(mapping_auth: MappingAuth, configuration: Configuration, user: User) -> None:
"""
must raise exception on duplicate username
"""
write_section = Configuration.section_name("auth", UserAccess.Write.value)
configuration.set_option(write_section, user.username, user.password)
read_section = Configuration.section_name("auth", UserAccess.Read.value)
configuration.set_option(read_section, user.username, user.password)
with pytest.raises(DuplicateUser):
mapping_auth.get_users(configuration)
def test_check_credentials(mapping_auth: MappingAuth, user: User) -> None:
"""
must return true for valid credentials
"""
current_password = user.password
user.password = user.hash_password(mapping_auth.salt)
mapping_auth._users[user.username] = user
assert mapping_auth.check_credentials(user.username, current_password)
assert not mapping_auth.check_credentials(user.username, user.password) # here password is hashed so it is invalid
def test_check_credentials_empty(mapping_auth: MappingAuth) -> None:
"""
must reject on empty credentials
"""
assert not mapping_auth.check_credentials(None, "")
assert not mapping_auth.check_credentials("", None)
assert not mapping_auth.check_credentials(None, None)
def test_check_credentials_unknown(mapping_auth: MappingAuth, user: User) -> None:
"""
must reject on unknown user
"""
assert not mapping_auth.check_credentials(user.username, user.password)
def test_get_user(mapping_auth: MappingAuth, user: User) -> None:
"""
must return user from storage by username
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username) == user
def test_get_user_normalized(mapping_auth: MappingAuth, user: User) -> None:
"""
must return user from storage by username case-insensitive
"""
mapping_auth._users[user.username] = user
assert mapping_auth.get_user(user.username.upper()) == user
def test_get_user_unknown(mapping_auth: MappingAuth, user: User) -> None:
"""
must return None in case if no user found
"""
assert mapping_auth.get_user(user.username) is None
def test_known_username(mapping_auth: MappingAuth, user: User) -> None:
"""
must allow only known users
"""
mapping_auth._users[user.username] = user
assert mapping_auth.known_username(user.username)
assert not mapping_auth.known_username(user.password)
def test_verify_access(mapping_auth: MappingAuth, user: User) -> None:
"""
must verify user access
"""
mapping_auth._users[user.username] = user
assert mapping_auth.verify_access(user.username, user.access, None)
assert not mapping_auth.verify_access(user.username, UserAccess.Write, None)

View File

@ -0,0 +1,98 @@
import aioauth_client
import pytest
from pytest_mock import MockerFixture
from ahriman.core.auth.oauth import OAuth
from ahriman.core.exceptions import InvalidOption
def test_auth_control(oauth: OAuth) -> None:
"""
must return a control for authorization
"""
assert oauth.auth_control
assert "<a" in oauth.auth_control # I think it should be a link
def test_get_provider() -> None:
"""
must return valid provider type
"""
assert OAuth.get_provider("OAuth2Client") == aioauth_client.OAuth2Client
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
def test_get_provider_not_a_type() -> None:
"""
must raise an exception if attribute is not a type
"""
with pytest.raises(InvalidOption):
OAuth.get_provider("__version__")
def test_get_provider_invalid_type() -> None:
"""
must raise an exception if attribute is not an OAuth2 client
"""
with pytest.raises(InvalidOption):
OAuth.get_provider("User")
with pytest.raises(InvalidOption):
OAuth.get_provider("OAuth1Client")
def test_get_client(oauth: OAuth) -> None:
"""
must return valid OAuth2 client
"""
client = oauth.get_client()
assert isinstance(client, aioauth_client.GoogleClient)
assert client.client_id == oauth.client_id
assert client.client_secret == oauth.client_secret
def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must generate valid OAuth authorization URL
"""
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
oauth.get_oauth_url()
authorize_url_mock.assert_called_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return authorized user ID
"""
access_token_mock = mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
return_value=(aioauth_client.User(email="email"), ""))
email = await oauth.get_oauth_username("code")
access_token_mock.assert_called_with("code", redirect_uri=oauth.redirect_uri)
user_info_mock.assert_called_once()
assert email == "email"
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return None in case of OAuth request error (get_access_token)
"""
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception())
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
email = await oauth.get_oauth_username("code")
assert email is None
user_info_mock.assert_not_called()
async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return None in case of OAuth request error (user_info)
"""
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception())
email = await oauth.get_oauth_username("code")
assert email is None

View File

@ -21,6 +21,10 @@ def test_from_option_valid() -> None:
assert AuthSettings.from_option("no") == AuthSettings.Disabled
assert AuthSettings.from_option("NO") == AuthSettings.Disabled
assert AuthSettings.from_option("oauth") == AuthSettings.OAuth
assert AuthSettings.from_option("OAuth") == AuthSettings.OAuth
assert AuthSettings.from_option("OAuth2") == AuthSettings.OAuth
assert AuthSettings.from_option("configuration") == AuthSettings.Configuration
assert AuthSettings.from_option("ConFigUration") == AuthSettings.Configuration
assert AuthSettings.from_option("mapping") == AuthSettings.Configuration

View File

@ -10,6 +10,7 @@ def test_from_option(user: User) -> None:
# default is read access
user.access = UserAccess.Write
assert User.from_option(user.username, user.password) != user
assert User.from_option(user.username, user.password, user.access) == user
def test_from_option_empty() -> None:
@ -32,6 +33,26 @@ def test_check_credentials_hash_password(user: User) -> None:
assert not user.check_credentials(user.password, "salt")
def test_check_credentials_empty_hash(user: User) -> None:
"""
must reject any authorization if the hash is invalid
"""
current_password = user.password
assert not user.check_credentials(current_password, "salt")
user.password = ""
assert not user.check_credentials(current_password, "salt")
def test_hash_password_empty_hash(user: User) -> None:
"""
must return empty string after hash in case if password not set
"""
user.password = ""
assert user.hash_password("salt") == ""
user.password = None
assert user.hash_password("salt") == ""
def test_generate_password() -> None:
"""
must generate password with specified length

View File

@ -2,10 +2,9 @@ import pytest
from aiohttp import web
from pytest_mock import MockerFixture
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock
from ahriman.core.auth.auth import Auth
from ahriman.core.configuration import Configuration
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
@ -23,7 +22,7 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
"""
must call validator check
"""
authorization_policy.validator = MagicMock()
authorization_policy.validator = AsyncMock()
authorization_policy.validator.verify_access.return_value = True
assert await authorization_policy.permits(user.username, user.access, "/endpoint")

View File

@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"]})
assert response.status == 200
assert response.ok
add_mock.assert_called_with(["ahriman"], True)
@ -20,7 +20,7 @@ async def test_post_now(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"], "build_now": False})
assert response.status == 200
assert response.ok
add_mock.assert_called_with(["ahriman"], False)
@ -42,5 +42,5 @@ async def test_post_update(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/update", json={"packages": ["ahriman"]})
assert response.status == 200
assert response.ok
add_mock.assert_called_with(["ahriman"], True)

View File

@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_remove")
response = await client.post("/service-api/v1/remove", json={"packages": ["ahriman"]})
assert response.status == 200
assert response.ok
add_mock.assert_called_with(["ahriman"])

View File

@ -11,7 +11,7 @@ async def test_get(client: TestClient, aur_package_ahriman: aur.Package, mocker:
mocker.patch("aur.search", return_value=[aur_package_ahriman])
response = await client.get("/service-api/v1/search", params={"for": "ahriman"})
assert response.status == 200
assert response.ok
assert await response.json() == ["ahriman"]
@ -33,7 +33,7 @@ async def test_get_join(client: TestClient, mocker: MockerFixture) -> None:
search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ahriman"), ("for", "maybe")])
assert response.status == 200
assert response.ok
search_mock.assert_called_with("ahriman maybe")
@ -44,7 +44,7 @@ async def test_get_join_filter(client: TestClient, mocker: MockerFixture) -> Non
search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ah"), ("for", "maybe")])
assert response.status == 200
assert response.ok
search_mock.assert_called_with("maybe")

View File

@ -11,7 +11,7 @@ async def test_get(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json())
assert response.status == 200
assert response.ok
assert status.status == BuildStatusEnum.Unknown
@ -26,7 +26,7 @@ async def test_post(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json())
assert response.status == 200
assert response.ok
assert status.status == BuildStatusEnum.Success

View File

@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200
assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages
@ -45,7 +45,7 @@ async def test_delete(client: TestClient, package_ahriman: Package, package_pyth
assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
assert response.status == 200
assert response.ok
async def test_delete_unknown(client: TestClient, package_ahriman: Package, package_python_schedule: Package) -> None:
@ -62,7 +62,7 @@ async def test_delete_unknown(client: TestClient, package_ahriman: Package, pack
assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
assert response.status == 200
assert response.ok
async def test_post(client: TestClient, package_ahriman: Package) -> None:
@ -75,7 +75,7 @@ async def test_post(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200
assert response.ok
async def test_post_exception(client: TestClient, package_ahriman: Package) -> None:
@ -100,7 +100,7 @@ async def test_post_light(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200
assert response.ok
statuses = {
Package.from_json(item["package"]).base: BuildStatus.from_json(item["status"])
for item in await response.json()

View File

@ -15,7 +15,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get("/status-api/v1/packages")
assert response.status == 200
assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages

View File

@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package) -> None:
json={"status": BuildStatusEnum.Success.value, "package": package_ahriman.view()})
response = await client.get("/status-api/v1/status")
assert response.status == 200
assert response.ok
json = await response.json()
assert json["version"] == version.__version__

View File

@ -6,7 +6,7 @@ async def test_get(client_with_auth: TestClient) -> None:
must generate status page correctly (/)
"""
response = await client_with_auth.get("/")
assert response.status == 200
assert response.ok
assert await response.text()
@ -15,7 +15,7 @@ async def test_get_index(client_with_auth: TestClient) -> None:
must generate status page correctly (/index.html)
"""
response = await client_with_auth.get("/index.html")
assert response.status == 200
assert response.ok
assert await response.text()
@ -24,7 +24,7 @@ async def test_get_without_auth(client: TestClient) -> None:
must use dummy authorized_userid function in case if no security library installed
"""
response = await client.get("/")
assert response.status == 200
assert response.ok
assert await response.text()
@ -33,4 +33,4 @@ async def test_get_static(client: TestClient) -> None:
must return static files
"""
response = await client.get("/static/favicon.ico")
assert response.status == 200
assert response.ok

View File

@ -1,9 +1,75 @@
from aiohttp.test_utils import TestClient
from pytest_mock import MockerFixture
from unittest.mock import MagicMock
from ahriman.core.auth.oauth import OAuth
from ahriman.models.user import User
async def test_get_default_validator(client_with_auth: TestClient) -> None:
"""
must return 405 in case if no OAuth enabled
"""
get_response = await client_with_auth.get("/user-api/v1/login")
assert get_response.status == 405
async def test_get_redirect_to_oauth(client_with_auth: TestClient) -> None:
"""
must redirect to OAuth service provider in case if no code is supplied
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_url.return_value = "https://example.com"
get_response = await client_with_auth.get("/user-api/v1/login")
assert get_response.ok
oauth.get_oauth_url.assert_called_once()
async def test_get_redirect_to_oauth_empty_code(client_with_auth: TestClient) -> None:
"""
must redirect to OAuth service provider in case if empty code is supplied
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_url.return_value = "https://example.com"
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": ""})
assert get_response.ok
oauth.get_oauth_url.assert_called_once()
async def test_get(client_with_auth: TestClient, mocker: MockerFixture) -> None:
"""
must login user correctly from OAuth
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True
oauth.enabled = False # lol
remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
assert get_response.ok
oauth.get_oauth_username.assert_called_with("code")
oauth.known_username.assert_called_with("user")
remember_mock.assert_called_once()
async def test_get_unauthorized(client_with_auth: TestClient, mocker: MockerFixture) -> None:
"""
must return unauthorized from OAuth
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.known_username.return_value = False
remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
assert get_response.status == 401
remember_mock.assert_not_called()
async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:
"""
must login user correctly
@ -12,10 +78,10 @@ async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixt
remember_mock = mocker.patch("aiohttp_security.remember")
post_response = await client_with_auth.post("/user-api/v1/login", json=payload)
assert post_response.status == 200
assert post_response.ok
post_response = await client_with_auth.post("/user-api/v1/login", data=payload)
assert post_response.status == 200
assert post_response.ok
remember_mock.assert_called()
@ -26,7 +92,7 @@ async def test_post_skip(client: TestClient, user: User) -> None:
"""
payload = {"username": user.username, "password": user.password}
post_response = await client.post("/user-api/v1/login", json=payload)
assert post_response.status == 200
assert post_response.ok
async def test_post_unauthorized(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:

View File

@ -11,7 +11,7 @@ async def test_post(client_with_auth: TestClient, mocker: MockerFixture) -> None
forget_mock = mocker.patch("aiohttp_security.forget")
post_response = await client_with_auth.post("/user-api/v1/logout")
assert post_response.status == 200
assert post_response.ok
forget_mock.assert_called_once()
@ -32,4 +32,4 @@ async def test_post_disabled(client: TestClient) -> None:
must raise exception if auth is disabled
"""
post_response = await client.post("/user-api/v1/logout")
assert post_response.status == 200
assert post_response.ok

View File

@ -10,6 +10,10 @@ root = /
[auth]
allow_read_only = no
client_id = client_id
client_secret = client_secret
oauth_provider = GoogleClient
oauth_scopes = https://www.googleapis.com/auth/userinfo.email
salt = salt
[build]