full coverage

This commit is contained in:
Evgenii Alekseev 2021-09-12 19:59:06 +03:00
parent b6950ba554
commit 9e782bb0b1
22 changed files with 309 additions and 75 deletions

View File

@ -30,7 +30,6 @@ Base authorization settings. `OAuth2` provider requires `aioauth-client` library
* `client_secret` - OAuth2 application client secret key, string, required in case if `oauth2` is used. * `client_secret` - OAuth2 application client secret key, string, required in case if `oauth2` is used.
* `max_age` - parameter which controls both cookie expiration and token expiration inside the service, integer, optional, default is 7 days. * `max_age` - parameter which controls both cookie expiration and token expiration inside the service, integer, optional, default is 7 days.
* `oauth_provider` - OAuth2 provider class name as is in `aioauth-client` (e.g. `GoogleClient`, `GithubClient` etc), string, required in case if `oauth2` is used. * `oauth_provider` - OAuth2 provider class name as is in `aioauth-client` (e.g. `GoogleClient`, `GithubClient` etc), string, required in case if `oauth2` is used.
* `oauth_redirect_uri` - full URI for OAuth2 redirect, must point to `/user-api/v1/login`, e.g. `https://example.com/user-api/v1/login`, string, required in case if `oauth2` is used.
* `oauth_scopes` - scopes list for OAuth2 provider, which will allow retrieving user email (which is used for checking user permissions), e.g. `https://www.googleapis.com/auth/userinfo.email` for `GoogleClient` or `user:email` for `GithubClient`, space separated list of strings, required in case if `oauth2` is used. * `oauth_scopes` - scopes list for OAuth2 provider, which will allow retrieving user email (which is used for checking user permissions), e.g. `https://www.googleapis.com/auth/userinfo.email` for `GoogleClient` or `user:email` for `GithubClient`, space separated list of strings, required in case if `oauth2` is used.
* `salt` - password hash salt, string, required in case if authorization enabled (automatically generated by `create-user` subcommand). * `salt` - password hash salt, string, required in case if authorization enabled (automatically generated by `create-user` subcommand).
@ -127,7 +126,7 @@ Group name must refer to architecture, e.g. it should be `s3:x86_64` for x86_64
Web server settings. If any of `host`/`port` is not set, web integration will be disabled. Group name must refer to architecture, e.g. it should be `web:x86_64` for x86_64 architecture. Web server settings. If any of `host`/`port` is not set, web integration will be disabled. Group name must refer to architecture, e.g. it should be `web:x86_64` for x86_64 architecture.
* `address` - optional address in form `proto://host:port` (`port` can be omitted in case of default `proto` ports), will be used instead of `http://{host}:{port}` in case if set, string, optional. * `address` - optional address in form `proto://host:port` (`port` can be omitted in case of default `proto` ports), will be used instead of `http://{host}:{port}` in case if set, string, optional. This option is required in case if `OAuth` provider is used.
* `host` - host to bind, string, optional. * `host` - host to bind, string, optional.
* `password` - password to authorize in web service in order to update service status, string, required in case if authorization enabled. * `password` - password to authorize in web service in order to update service status, string, required in case if authorization enabled.
* `port` - port to bind, int, optional. * `port` - port to bind, int, optional.

View File

@ -12,6 +12,8 @@ root = /
target = disabled target = disabled
allow_read_only = yes allow_read_only = yes
max_age = 604800 max_age = 604800
oauth_provider = GoogleClient
oauth_scopes = https://www.googleapis.com/auth/userinfo.email
[build] [build]
archbuild_flags = archbuild_flags =

View File

@ -19,6 +19,8 @@
# #
from __future__ import annotations from __future__ import annotations
import logging
from typing import Dict, Optional, Type from typing import Dict, Optional, Type
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
@ -47,6 +49,8 @@ class Auth:
:param configuration: configuration instance :param configuration: configuration instance
:param provider: authorization type definition :param provider: authorization type definition
""" """
self.logger = logging.getLogger("http")
self.allow_read_only = configuration.getboolean("auth", "allow_read_only") self.allow_read_only = configuration.getboolean("auth", "allow_read_only")
self.allowed_paths = set(configuration.getlist("auth", "allowed_paths")) self.allowed_paths = set(configuration.getlist("auth", "allowed_paths"))
self.allowed_paths.update(self.ALLOWED_PATHS) self.allowed_paths.update(self.ALLOWED_PATHS)
@ -124,7 +128,7 @@ class Auth:
return False # request without context is not allowed return False # request without context is not allowed
return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups) return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups)
async def known_username(self, username: str) -> bool: # pylint: disable=no-self-use async def known_username(self, username: Optional[str]) -> bool: # pylint: disable=no-self-use
""" """
check if user is known check if user is known
:param username: username :param username: username

View File

@ -64,13 +64,13 @@ class Mapping(Auth):
normalized_user = username.lower() normalized_user = username.lower()
return self._users.get(normalized_user) return self._users.get(normalized_user)
async def known_username(self, username: str) -> bool: async def known_username(self, username: Optional[str]) -> bool:
""" """
check if user is known check if user is known
:param username: username :param username: username
:return: True in case if user is known and can be authorized and False otherwise :return: True in case if user is known and can be authorized and False otherwise
""" """
return self.get_user(username) is not None return username is not None and self.get_user(username) is not None
async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
""" """

View File

@ -19,7 +19,7 @@
# #
import aioauth_client # type: ignore import aioauth_client # type: ignore
from typing import Type from typing import Optional, Type
from ahriman.core.auth.mapping import Mapping from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
@ -47,7 +47,9 @@ class OAuth(Mapping):
Mapping.__init__(self, configuration, provider) Mapping.__init__(self, configuration, provider)
self.client_id = configuration.get("auth", "client_id") self.client_id = configuration.get("auth", "client_id")
self.client_secret = configuration.get("auth", "client_secret") self.client_secret = configuration.get("auth", "client_secret")
self.redirect_uri = configuration.get("auth", "oauth_redirect_uri") # in order to use OAuth feature the service must be publicity available
# thus we expect that address is set
self.redirect_uri = f"""{configuration.get("web", "address")}/user-api/v1/login"""
self.provider = self.get_provider(configuration.get("auth", "oauth_provider")) self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
# it is list but we will have to convert to string it anyway # it is list but we will have to convert to string it anyway
self.scopes = configuration.get("auth", "oauth_scopes") self.scopes = configuration.get("auth", "oauth_scopes")
@ -91,16 +93,21 @@ class OAuth(Mapping):
uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri) uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri)
return uri return uri
async def get_oauth_username(self, code: str) -> str: async def get_oauth_username(self, code: str) -> Optional[str]:
""" """
extract OAuth username from remote extract OAuth username from remote
:param code: authorization code provided by external service :param code: authorization code provided by external service
:return: username as is in OAuth provider :return: username as is in OAuth provider
""" """
client = self.get_client() try:
access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri) client = self.get_client()
client.access_token = access_token access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)
client.access_token = access_token
user, _ = await client.user_info() print(f"HEEELOOOO {client}")
username: str = user.email user, _ = await client.user_info()
return username username: str = user.email
return username
except Exception:
self.logger.exception("got exception while performing request")
return None

View File

@ -41,10 +41,10 @@ class LoginView(BaseView):
code = self.request.query.getone("code", default=None) code = self.request.query.getone("code", default=None)
oauth_provider = self.validator oauth_provider = self.validator
if not isinstance(oauth_provider, OAuth): if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway
raise HTTPMethodNotAllowed(self.request.method, ["POST"]) raise HTTPMethodNotAllowed(self.request.method, ["POST"])
if code is None: if not code:
return HTTPFound(oauth_provider.get_oauth_url()) return HTTPFound(oauth_provider.get_oauth_url())
response = HTTPFound("/") response = HTTPFound("/")

View File

@ -1,13 +1,26 @@
import pytest import pytest
from ahriman.core.auth.mapping import Mapping from ahriman.core.auth.mapping import Mapping
from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
@pytest.fixture @pytest.fixture
def mapping_auth(configuration: Configuration) -> Mapping: def mapping(configuration: Configuration) -> Mapping:
""" """
auth provider fixture auth provider fixture
:param configuration: configuration fixture
:return: auth service instance :return: auth service instance
""" """
return Mapping(configuration) return Mapping(configuration)
@pytest.fixture
def oauth(configuration: Configuration) -> OAuth:
"""
OAuth provider fixture
:param configuration: configuration fixture
:return: OAuth2 service instance
"""
configuration.set("web", "address", "https://example.com")
return OAuth(configuration)

View File

@ -2,12 +2,21 @@ import pytest
from ahriman.core.auth.auth import Auth from ahriman.core.auth.auth import Auth
from ahriman.core.auth.mapping import Mapping from ahriman.core.auth.mapping import Mapping
from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.exceptions import DuplicateUser from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
def test_auth_control(auth: Auth) -> None:
"""
must return a control for authorization
"""
assert auth.auth_control
assert "button" in auth.auth_control # I think it should be button
def test_load_dummy(configuration: Configuration) -> None: def test_load_dummy(configuration: Configuration) -> None:
""" """
must load dummy validator if authorization is not enabled must load dummy validator if authorization is not enabled
@ -34,7 +43,17 @@ def test_load_mapping(configuration: Configuration) -> None:
assert isinstance(auth, Mapping) assert isinstance(auth, Mapping)
def test_get_users(mapping_auth: Auth, configuration: Configuration) -> None: def test_load_oauth(configuration: Configuration) -> None:
"""
must load OAuth2 validator if option set
"""
configuration.set_option("auth", "target", "oauth")
configuration.set_option("web", "address", "https://example.com")
auth = Auth.load(configuration)
assert isinstance(auth, OAuth)
def test_get_users(mapping: Auth, configuration: Configuration) -> None:
""" """
must return valid user list must return valid user list
""" """
@ -48,12 +67,12 @@ def test_get_users(mapping_auth: Auth, configuration: Configuration) -> None:
read_section = Configuration.section_name("auth", user_read.access.value) read_section = Configuration.section_name("auth", user_read.access.value)
configuration.set_option(read_section, user_read.username, user_read.password) configuration.set_option(read_section, user_read.username, user_read.password)
users = mapping_auth.get_users(configuration) users = mapping.get_users(configuration)
expected = {user_write.username: user_write, user_read.username: user_read} expected = {user_write.username: user_write, user_read.username: user_read}
assert users == expected assert users == expected
def test_get_users_normalized(mapping_auth: Auth, configuration: Configuration) -> None: def test_get_users_normalized(mapping: Auth, configuration: Configuration) -> None:
""" """
must return user list with normalized usernames in keys must return user list with normalized usernames in keys
""" """
@ -61,13 +80,13 @@ def test_get_users_normalized(mapping_auth: Auth, configuration: Configuration)
read_section = Configuration.section_name("auth", user.access.value) read_section = Configuration.section_name("auth", user.access.value)
configuration.set_option(read_section, user.username, user.password) configuration.set_option(read_section, user.username, user.password)
users = mapping_auth.get_users(configuration) users = mapping.get_users(configuration)
expected = user.username.lower() expected = user.username.lower()
assert expected in users assert expected in users
assert users[expected].username == expected assert users[expected].username == expected
def test_get_users_duplicate(mapping_auth: Auth, configuration: Configuration, user: User) -> None: def test_get_users_duplicate(mapping: Auth, configuration: Configuration, user: User) -> None:
""" """
must raise exception on duplicate username must raise exception on duplicate username
""" """
@ -77,7 +96,7 @@ def test_get_users_duplicate(mapping_auth: Auth, configuration: Configuration, u
configuration.set_option(read_section, user.username, user.password) configuration.set_option(read_section, user.username, user.password)
with pytest.raises(DuplicateUser): with pytest.raises(DuplicateUser):
mapping_auth.get_users(configuration) mapping.get_users(configuration)
async def test_check_credentials(auth: Auth, user: User) -> None: async def test_check_credentials(auth: Auth, user: User) -> None:

View File

@ -3,70 +3,71 @@ from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
async def test_check_credentials(mapping_auth: Mapping, user: User) -> None: async def test_check_credentials(mapping: Mapping, user: User) -> None:
""" """
must return true for valid credentials must return true for valid credentials
""" """
current_password = user.password current_password = user.password
user.password = user.hash_password(mapping_auth.salt) user.password = user.hash_password(mapping.salt)
mapping_auth._users[user.username] = user mapping._users[user.username] = user
assert await mapping_auth.check_credentials(user.username, current_password) assert await mapping.check_credentials(user.username, current_password)
# here password is hashed so it is invalid # here password is hashed so it is invalid
assert not await mapping_auth.check_credentials(user.username, user.password) assert not await mapping.check_credentials(user.username, user.password)
async def test_check_credentials_empty(mapping_auth: Mapping) -> None: async def test_check_credentials_empty(mapping: Mapping) -> None:
""" """
must reject on empty credentials must reject on empty credentials
""" """
assert not await mapping_auth.check_credentials(None, "") assert not await mapping.check_credentials(None, "")
assert not await mapping_auth.check_credentials("", None) assert not await mapping.check_credentials("", None)
assert not await mapping_auth.check_credentials(None, None) assert not await mapping.check_credentials(None, None)
async def test_check_credentials_unknown(mapping_auth: Mapping, user: User) -> None: async def test_check_credentials_unknown(mapping: Mapping, user: User) -> None:
""" """
must reject on unknown user must reject on unknown user
""" """
assert not await mapping_auth.check_credentials(user.username, user.password) assert not await mapping.check_credentials(user.username, user.password)
def test_get_user(mapping_auth: Mapping, user: User) -> None: def test_get_user(mapping: Mapping, user: User) -> None:
""" """
must return user from storage by username must return user from storage by username
""" """
mapping_auth._users[user.username] = user mapping._users[user.username] = user
assert mapping_auth.get_user(user.username) == user assert mapping.get_user(user.username) == user
def test_get_user_normalized(mapping_auth: Mapping, user: User) -> None: def test_get_user_normalized(mapping: Mapping, user: User) -> None:
""" """
must return user from storage by username case-insensitive must return user from storage by username case-insensitive
""" """
mapping_auth._users[user.username] = user mapping._users[user.username] = user
assert mapping_auth.get_user(user.username.upper()) == user assert mapping.get_user(user.username.upper()) == user
def test_get_user_unknown(mapping_auth: Mapping, user: User) -> None: def test_get_user_unknown(mapping: Mapping, user: User) -> None:
""" """
must return None in case if no user found must return None in case if no user found
""" """
assert mapping_auth.get_user(user.username) is None assert mapping.get_user(user.username) is None
async def test_known_username(mapping_auth: Mapping, user: User) -> None: async def test_known_username(mapping: Mapping, user: User) -> None:
""" """
must allow only known users must allow only known users
""" """
mapping_auth._users[user.username] = user mapping._users[user.username] = user
assert await mapping_auth.known_username(user.username) assert await mapping.known_username(user.username)
assert not await mapping_auth.known_username(user.password) assert not await mapping.known_username(None)
assert not await mapping.known_username(user.password)
async def test_verify_access(mapping_auth: Mapping, user: User) -> None: async def test_verify_access(mapping: Mapping, user: User) -> None:
""" """
must verify user access must verify user access
""" """
mapping_auth._users[user.username] = user mapping._users[user.username] = user
assert await mapping_auth.verify_access(user.username, user.access, None) assert await mapping.verify_access(user.username, user.access, None)
assert not await mapping_auth.verify_access(user.username, UserAccess.Write, None) assert not await mapping.verify_access(user.username, UserAccess.Write, None)

View File

@ -0,0 +1,98 @@
import aioauth_client
import pytest
from pytest_mock import MockerFixture
from ahriman.core.auth.oauth import OAuth
from ahriman.core.exceptions import InvalidOption
def test_auth_control(oauth: OAuth) -> None:
"""
must return a control for authorization
"""
assert oauth.auth_control
assert "<a" in oauth.auth_control # I think it should be a link
def test_get_provider() -> None:
"""
must return valid provider type
"""
assert OAuth.get_provider("OAuth2Client") == aioauth_client.OAuth2Client
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
def test_get_provider_not_a_type() -> None:
"""
must raise an exception if attribute is not a type
"""
with pytest.raises(InvalidOption):
OAuth.get_provider("__version__")
def test_get_provider_invalid_type() -> None:
"""
must raise an exception if attribute is not an OAuth2 client
"""
with pytest.raises(InvalidOption):
OAuth.get_provider("User")
with pytest.raises(InvalidOption):
OAuth.get_provider("OAuth1Client")
def test_get_client(oauth: OAuth) -> None:
"""
must return valid OAuth2 client
"""
client = oauth.get_client()
assert isinstance(client, aioauth_client.GoogleClient)
assert client.client_id == oauth.client_id
assert client.client_secret == oauth.client_secret
def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must generate valid OAuth authorization URL
"""
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
oauth.get_oauth_url()
authorize_url_mock.assert_called_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return authorized user ID
"""
access_token_mock = mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
return_value=(aioauth_client.User(email="email"), ""))
email = await oauth.get_oauth_username("code")
access_token_mock.assert_called_with("code", redirect_uri=oauth.redirect_uri)
user_info_mock.assert_called_once()
assert email == "email"
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return None in case of OAuth request error (get_access_token)
"""
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception())
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
email = await oauth.get_oauth_username("code")
assert email is None
user_info_mock.assert_not_called()
async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixture) -> None:
"""
must return None in case of OAuth request error (user_info)
"""
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception())
email = await oauth.get_oauth_username("code")
assert email is None

View File

@ -10,6 +10,7 @@ def test_from_option(user: User) -> None:
# default is read access # default is read access
user.access = UserAccess.Write user.access = UserAccess.Write
assert User.from_option(user.username, user.password) != user assert User.from_option(user.username, user.password) != user
assert User.from_option(user.username, user.password, user.access) == user
def test_from_option_empty() -> None: def test_from_option_empty() -> None:
@ -32,6 +33,26 @@ def test_check_credentials_hash_password(user: User) -> None:
assert not user.check_credentials(user.password, "salt") assert not user.check_credentials(user.password, "salt")
def test_check_credentials_empty_hash(user: User) -> None:
"""
must reject any authorization if the hash is invalid
"""
current_password = user.password
assert not user.check_credentials(current_password, "salt")
user.password = ""
assert not user.check_credentials(current_password, "salt")
def test_hash_password_empty_hash(user: User) -> None:
"""
must return empty string after hash in case if password not set
"""
user.password = ""
assert user.hash_password("salt") == ""
user.password = None
assert user.hash_password("salt") == ""
def test_generate_password() -> None: def test_generate_password() -> None:
""" """
must generate password with specified length must generate password with specified length

View File

@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add") add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"]}) response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"]})
assert response.status == 200 assert response.ok
add_mock.assert_called_with(["ahriman"], True) add_mock.assert_called_with(["ahriman"], True)
@ -20,7 +20,7 @@ async def test_post_now(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add") add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"], "build_now": False}) response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"], "build_now": False})
assert response.status == 200 assert response.ok
add_mock.assert_called_with(["ahriman"], False) add_mock.assert_called_with(["ahriman"], False)
@ -42,5 +42,5 @@ async def test_post_update(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add") add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/update", json={"packages": ["ahriman"]}) response = await client.post("/service-api/v1/update", json={"packages": ["ahriman"]})
assert response.status == 200 assert response.ok
add_mock.assert_called_with(["ahriman"], True) add_mock.assert_called_with(["ahriman"], True)

View File

@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_remove") add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_remove")
response = await client.post("/service-api/v1/remove", json={"packages": ["ahriman"]}) response = await client.post("/service-api/v1/remove", json={"packages": ["ahriman"]})
assert response.status == 200 assert response.ok
add_mock.assert_called_with(["ahriman"]) add_mock.assert_called_with(["ahriman"])

View File

@ -11,7 +11,7 @@ async def test_get(client: TestClient, aur_package_ahriman: aur.Package, mocker:
mocker.patch("aur.search", return_value=[aur_package_ahriman]) mocker.patch("aur.search", return_value=[aur_package_ahriman])
response = await client.get("/service-api/v1/search", params={"for": "ahriman"}) response = await client.get("/service-api/v1/search", params={"for": "ahriman"})
assert response.status == 200 assert response.ok
assert await response.json() == ["ahriman"] assert await response.json() == ["ahriman"]
@ -33,7 +33,7 @@ async def test_get_join(client: TestClient, mocker: MockerFixture) -> None:
search_mock = mocker.patch("aur.search") search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ahriman"), ("for", "maybe")]) response = await client.get("/service-api/v1/search", params=[("for", "ahriman"), ("for", "maybe")])
assert response.status == 200 assert response.ok
search_mock.assert_called_with("ahriman maybe") search_mock.assert_called_with("ahriman maybe")
@ -44,7 +44,7 @@ async def test_get_join_filter(client: TestClient, mocker: MockerFixture) -> Non
search_mock = mocker.patch("aur.search") search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ah"), ("for", "maybe")]) response = await client.get("/service-api/v1/search", params=[("for", "ah"), ("for", "maybe")])
assert response.status == 200 assert response.ok
search_mock.assert_called_with("maybe") search_mock.assert_called_with("maybe")

View File

@ -11,7 +11,7 @@ async def test_get(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman") response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json()) status = BuildStatus.from_json(await response.json())
assert response.status == 200 assert response.ok
assert status.status == BuildStatusEnum.Unknown assert status.status == BuildStatusEnum.Unknown
@ -26,7 +26,7 @@ async def test_post(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman") response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json()) status = BuildStatus.from_json(await response.json())
assert response.status == 200 assert response.ok
assert status.status == BuildStatusEnum.Success assert status.status == BuildStatusEnum.Success

View File

@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()}) json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}") response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200 assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()] packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages assert packages
@ -45,7 +45,7 @@ async def test_delete(client: TestClient, package_ahriman: Package, package_pyth
assert response.status == 404 assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}") response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
assert response.status == 200 assert response.ok
async def test_delete_unknown(client: TestClient, package_ahriman: Package, package_python_schedule: Package) -> None: async def test_delete_unknown(client: TestClient, package_ahriman: Package, package_python_schedule: Package) -> None:
@ -62,7 +62,7 @@ async def test_delete_unknown(client: TestClient, package_ahriman: Package, pack
assert response.status == 404 assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}") response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
assert response.status == 200 assert response.ok
async def test_post(client: TestClient, package_ahriman: Package) -> None: async def test_post(client: TestClient, package_ahriman: Package) -> None:
@ -75,7 +75,7 @@ async def test_post(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204 assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}") response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200 assert response.ok
async def test_post_exception(client: TestClient, package_ahriman: Package) -> None: async def test_post_exception(client: TestClient, package_ahriman: Package) -> None:
@ -100,7 +100,7 @@ async def test_post_light(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204 assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}") response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
assert response.status == 200 assert response.ok
statuses = { statuses = {
Package.from_json(item["package"]).base: BuildStatus.from_json(item["status"]) Package.from_json(item["package"]).base: BuildStatus.from_json(item["status"])
for item in await response.json() for item in await response.json()

View File

@ -15,7 +15,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()}) json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get("/status-api/v1/packages") response = await client.get("/status-api/v1/packages")
assert response.status == 200 assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()] packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages assert packages

View File

@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package) -> None:
json={"status": BuildStatusEnum.Success.value, "package": package_ahriman.view()}) json={"status": BuildStatusEnum.Success.value, "package": package_ahriman.view()})
response = await client.get("/status-api/v1/status") response = await client.get("/status-api/v1/status")
assert response.status == 200 assert response.ok
json = await response.json() json = await response.json()
assert json["version"] == version.__version__ assert json["version"] == version.__version__

View File

@ -6,7 +6,7 @@ async def test_get(client_with_auth: TestClient) -> None:
must generate status page correctly (/) must generate status page correctly (/)
""" """
response = await client_with_auth.get("/") response = await client_with_auth.get("/")
assert response.status == 200 assert response.ok
assert await response.text() assert await response.text()
@ -15,7 +15,7 @@ async def test_get_index(client_with_auth: TestClient) -> None:
must generate status page correctly (/index.html) must generate status page correctly (/index.html)
""" """
response = await client_with_auth.get("/index.html") response = await client_with_auth.get("/index.html")
assert response.status == 200 assert response.ok
assert await response.text() assert await response.text()
@ -24,7 +24,7 @@ async def test_get_without_auth(client: TestClient) -> None:
must use dummy authorized_userid function in case if no security library installed must use dummy authorized_userid function in case if no security library installed
""" """
response = await client.get("/") response = await client.get("/")
assert response.status == 200 assert response.ok
assert await response.text() assert await response.text()
@ -33,4 +33,4 @@ async def test_get_static(client: TestClient) -> None:
must return static files must return static files
""" """
response = await client.get("/static/favicon.ico") response = await client.get("/static/favicon.ico")
assert response.status == 200 assert response.ok

View File

@ -1,9 +1,75 @@
from aiohttp.test_utils import TestClient from aiohttp.test_utils import TestClient
from pytest_mock import MockerFixture from pytest_mock import MockerFixture
from unittest.mock import MagicMock
from ahriman.core.auth.oauth import OAuth
from ahriman.models.user import User from ahriman.models.user import User
async def test_get_default_validator(client_with_auth: TestClient) -> None:
"""
must return 405 in case if no OAuth enabled
"""
get_response = await client_with_auth.get("/user-api/v1/login")
assert get_response.status == 405
async def test_get_redirect_to_oauth(client_with_auth: TestClient) -> None:
"""
must redirect to OAuth service provider in case if no code is supplied
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_url.return_value = "https://example.com"
get_response = await client_with_auth.get("/user-api/v1/login")
assert get_response.ok
oauth.get_oauth_url.assert_called_once()
async def test_get_redirect_to_oauth_empty_code(client_with_auth: TestClient) -> None:
"""
must redirect to OAuth service provider in case if empty code is supplied
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_url.return_value = "https://example.com"
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": ""})
assert get_response.ok
oauth.get_oauth_url.assert_called_once()
async def test_get(client_with_auth: TestClient, mocker: MockerFixture) -> None:
"""
must login user correctly from OAuth
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True
oauth.enabled = False # lol
remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
assert get_response.ok
oauth.get_oauth_username.assert_called_with("code")
oauth.known_username.assert_called_with("user")
remember_mock.assert_called_once()
async def test_get_unauthorized(client_with_auth: TestClient, mocker: MockerFixture) -> None:
"""
must return unauthorized from OAuth
"""
oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
oauth.known_username.return_value = False
remember_mock = mocker.patch("aiohttp_security.remember")
get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
assert get_response.status == 401
remember_mock.assert_not_called()
async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None: async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:
""" """
must login user correctly must login user correctly
@ -12,10 +78,10 @@ async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixt
remember_mock = mocker.patch("aiohttp_security.remember") remember_mock = mocker.patch("aiohttp_security.remember")
post_response = await client_with_auth.post("/user-api/v1/login", json=payload) post_response = await client_with_auth.post("/user-api/v1/login", json=payload)
assert post_response.status == 200 assert post_response.ok
post_response = await client_with_auth.post("/user-api/v1/login", data=payload) post_response = await client_with_auth.post("/user-api/v1/login", data=payload)
assert post_response.status == 200 assert post_response.ok
remember_mock.assert_called() remember_mock.assert_called()
@ -26,7 +92,7 @@ async def test_post_skip(client: TestClient, user: User) -> None:
""" """
payload = {"username": user.username, "password": user.password} payload = {"username": user.username, "password": user.password}
post_response = await client.post("/user-api/v1/login", json=payload) post_response = await client.post("/user-api/v1/login", json=payload)
assert post_response.status == 200 assert post_response.ok
async def test_post_unauthorized(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None: async def test_post_unauthorized(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:

View File

@ -11,7 +11,7 @@ async def test_post(client_with_auth: TestClient, mocker: MockerFixture) -> None
forget_mock = mocker.patch("aiohttp_security.forget") forget_mock = mocker.patch("aiohttp_security.forget")
post_response = await client_with_auth.post("/user-api/v1/logout") post_response = await client_with_auth.post("/user-api/v1/logout")
assert post_response.status == 200 assert post_response.ok
forget_mock.assert_called_once() forget_mock.assert_called_once()
@ -32,4 +32,4 @@ async def test_post_disabled(client: TestClient) -> None:
must raise exception if auth is disabled must raise exception if auth is disabled
""" """
post_response = await client.post("/user-api/v1/logout") post_response = await client.post("/user-api/v1/logout")
assert post_response.status == 200 assert post_response.ok

View File

@ -10,6 +10,10 @@ root = /
[auth] [auth]
allow_read_only = no allow_read_only = no
client_id = client_id
client_secret = client_secret
oauth_provider = GoogleClient
oauth_scopes = https://www.googleapis.com/auth/userinfo.email
salt = salt salt = salt
[build] [build]