- {% if auth_enabled %}
+ {% if auth.enabled %}
{% include "build-status/login-modal.jinja2" %}
{% endif %}
diff --git a/setup.py b/setup.py
index 8e0eb5f1..e49160fc 100644
--- a/setup.py
+++ b/setup.py
@@ -106,6 +106,7 @@ setup(
"Jinja2",
"aiohttp",
"aiohttp_jinja2",
+ "aioauth-client",
"aiohttp_session",
"aiohttp_security",
"cryptography",
diff --git a/src/ahriman/core/auth/auth.py b/src/ahriman/core/auth/auth.py
index c33982e9..a3be9ef4 100644
--- a/src/ahriman/core/auth/auth.py
+++ b/src/ahriman/core/auth/auth.py
@@ -19,10 +19,14 @@
#
from __future__ import annotations
-from typing import Optional, Type
+import logging
+
+from typing import Dict, Optional, Type
from ahriman.core.configuration import Configuration
+from ahriman.core.exceptions import DuplicateUser
from ahriman.models.auth_settings import AuthSettings
+from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
@@ -45,6 +49,8 @@ class Auth:
:param configuration: configuration instance
:param provider: authorization type definition
"""
+ self.logger = logging.getLogger("http")
+
self.allow_read_only = configuration.getboolean("auth", "allow_read_only")
self.allowed_paths = set(configuration.getlist("auth", "allowed_paths"))
self.allowed_paths.update(self.ALLOWED_PATHS)
@@ -53,6 +59,17 @@ class Auth:
self.enabled = provider.is_enabled
self.max_age = configuration.getint("auth", "max_age", fallback=7 * 24 * 3600)
+ @property
+ def auth_control(self) -> str:
+ """
+ This workaround is required to make different behaviour for login interface.
+ In case of internal authentication it must provide an interface (modal form) to login with button sends POST
+ request. But for an external providers behaviour can be different: e.g. OAuth provider requires sending GET
+ request to external resource
+ :return: login control as html code to insert
+ """
+ return """
login """
+
@classmethod
def load(cls: Type[Auth], configuration: Configuration) -> Auth:
"""
@@ -62,11 +79,33 @@ class Auth:
"""
provider = AuthSettings.from_option(configuration.get("auth", "target", fallback="disabled"))
if provider == AuthSettings.Configuration:
- from ahriman.core.auth.mapping_auth import MappingAuth
- return MappingAuth(configuration)
+ from ahriman.core.auth.mapping import Mapping
+ return Mapping(configuration)
+ if provider == AuthSettings.OAuth:
+ from ahriman.core.auth.oauth import OAuth
+ return OAuth(configuration)
return cls(configuration)
- def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use
+ @staticmethod
+ def get_users(configuration: Configuration) -> Dict[str, User]:
+ """
+ load users from settings
+ :param configuration: configuration instance
+ :return: map of username to its descriptor
+ """
+ users: Dict[str, User] = {}
+ for role in UserAccess:
+ section = configuration.section_name("auth", role.value)
+ if not configuration.has_section(section):
+ continue
+ for user, password in configuration[section].items():
+ normalized_user = user.lower()
+ if normalized_user in users:
+ raise DuplicateUser(normalized_user)
+ users[normalized_user] = User(normalized_user, password, role)
+ return users
+
+ async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool: # pylint: disable=no-self-use
"""
validate user password
:param username: username
@@ -76,20 +115,20 @@ class Auth:
del username, password
return True
- def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool:
+ async def is_safe_request(self, uri: Optional[str], required: UserAccess) -> bool:
"""
check if requested path are allowed without authorization
:param uri: request uri
:param required: required access level
:return: True in case if this URI can be requested without authorization and False otherwise
"""
- if not uri:
- return False # request without context is not allowed
if required == UserAccess.Read and self.allow_read_only:
return True # in case if read right requested and allowed in options
+ if not uri:
+ return False # request without context is not allowed
return uri in self.allowed_paths or any(uri.startswith(path) for path in self.allowed_paths_groups)
- def known_username(self, username: str) -> bool: # pylint: disable=no-self-use
+ async def known_username(self, username: Optional[str]) -> bool: # pylint: disable=no-self-use
"""
check if user is known
:param username: username
@@ -98,7 +137,7 @@ class Auth:
del username
return True
- def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use
+ async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool: # pylint: disable=no-self-use
"""
validate if user has access to requested resource
:param username: username
diff --git a/src/ahriman/core/auth/mapping_auth.py b/src/ahriman/core/auth/mapping.py
similarity index 70%
rename from src/ahriman/core/auth/mapping_auth.py
rename to src/ahriman/core/auth/mapping.py
index 3a25b9d3..778e7e6f 100644
--- a/src/ahriman/core/auth/mapping_auth.py
+++ b/src/ahriman/core/auth/mapping.py
@@ -17,17 +17,16 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see
.
#
-from typing import Dict, Optional
+from typing import Optional
from ahriman.core.auth.auth import Auth
from ahriman.core.configuration import Configuration
-from ahriman.core.exceptions import DuplicateUser
from ahriman.models.auth_settings import AuthSettings
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
-class MappingAuth(Auth):
+class Mapping(Auth):
"""
user authorization based on mapping from configuration file
:ivar salt: random generated string to salt passwords
@@ -44,26 +43,7 @@ class MappingAuth(Auth):
self.salt = configuration.get("auth", "salt")
self._users = self.get_users(configuration)
- @staticmethod
- def get_users(configuration: Configuration) -> Dict[str, User]:
- """
- load users from settings
- :param configuration: configuration instance
- :return: map of username to its descriptor
- """
- users: Dict[str, User] = {}
- for role in UserAccess:
- section = configuration.section_name("auth", role.value)
- if not configuration.has_section(section):
- continue
- for user, password in configuration[section].items():
- normalized_user = user.lower()
- if normalized_user in users:
- raise DuplicateUser(normalized_user)
- users[normalized_user] = User(normalized_user, password, role)
- return users
-
- def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
+ async def check_credentials(self, username: Optional[str], password: Optional[str]) -> bool:
"""
validate user password
:param username: username
@@ -84,15 +64,15 @@ class MappingAuth(Auth):
normalized_user = username.lower()
return self._users.get(normalized_user)
- def known_username(self, username: str) -> bool:
+ async def known_username(self, username: Optional[str]) -> bool:
"""
check if user is known
:param username: username
:return: True in case if user is known and can be authorized and False otherwise
"""
- return self.get_user(username) is not None
+ return username is not None and self.get_user(username) is not None
- def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
+ async def verify_access(self, username: str, required: UserAccess, context: Optional[str]) -> bool:
"""
validate if user has access to requested resource
:param username: username
@@ -100,6 +80,5 @@ class MappingAuth(Auth):
:param context: URI request path
:return: True in case if user is allowed to do this request and False otherwise
"""
- del context
user = self.get_user(username)
return user is not None and user.verify_access(required)
diff --git a/src/ahriman/core/auth/oauth.py b/src/ahriman/core/auth/oauth.py
new file mode 100644
index 00000000..9784aa69
--- /dev/null
+++ b/src/ahriman/core/auth/oauth.py
@@ -0,0 +1,113 @@
+#
+# Copyright (c) 2021 ahriman team.
+#
+# This file is part of ahriman
+# (see https://github.com/arcan1s/ahriman).
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see
.
+#
+import aioauth_client # type: ignore
+
+from typing import Optional, Type
+
+from ahriman.core.auth.mapping import Mapping
+from ahriman.core.configuration import Configuration
+from ahriman.core.exceptions import InvalidOption
+from ahriman.models.auth_settings import AuthSettings
+
+
+class OAuth(Mapping):
+ """
+ OAuth user authorization.
+ It is required to create application first and put application credentials.
+ :ivar client_id: application client id
+ :ivar client_secret: application client secret key
+ :ivar provider: provider class, should be one of aiohttp-client provided classes
+ :ivar redirect_uri: redirect URI registered in provider
+ :ivar scopes: list of scopes required by the application
+ """
+
+ def __init__(self, configuration: Configuration, provider: AuthSettings = AuthSettings.OAuth) -> None:
+ """
+ default constructor
+ :param configuration: configuration instance
+ :param provider: authorization type definition
+ """
+ Mapping.__init__(self, configuration, provider)
+ self.client_id = configuration.get("auth", "client_id")
+ self.client_secret = configuration.get("auth", "client_secret")
+ # in order to use OAuth feature the service must be publicity available
+ # thus we expect that address is set
+ self.redirect_uri = f"""{configuration.get("web", "address")}/user-api/v1/login"""
+ self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
+ # it is list but we will have to convert to string it anyway
+ self.scopes = configuration.get("auth", "oauth_scopes")
+
+ @property
+ def auth_control(self) -> str:
+ """
+ :return: login control as html code to insert
+ """
+ return """
login """
+
+ @staticmethod
+ def get_provider(name: str) -> Type[aioauth_client.OAuth2Client]:
+ """
+ load OAuth2 provider by name
+ :param name: name of the provider. Must be valid class defined in aioauth-client library
+ :return: loaded provider type
+ """
+ provider: Type[aioauth_client.OAuth2Client] = getattr(aioauth_client, name)
+ try:
+ is_oauth2_client = issubclass(provider, aioauth_client.OAuth2Client)
+ except TypeError: # what if it is random string?
+ is_oauth2_client = False
+ if not is_oauth2_client:
+ raise InvalidOption(name)
+ return provider
+
+ def get_client(self) -> aioauth_client.OAuth2Client:
+ """
+ load client from parameters
+ :return: generated client according to current settings
+ """
+ return self.provider(client_id=self.client_id, client_secret=self.client_secret)
+
+ def get_oauth_url(self) -> str:
+ """
+ get authorization URI for the specified settings
+ :return: authorization URI as a string
+ """
+ client = self.get_client()
+ uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri)
+ return uri
+
+ async def get_oauth_username(self, code: str) -> Optional[str]:
+ """
+ extract OAuth username from remote
+ :param code: authorization code provided by external service
+ :return: username as is in OAuth provider
+ """
+ try:
+ client = self.get_client()
+ access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)
+ client.access_token = access_token
+
+ print(f"HEEELOOOO {client}")
+ user, _ = await client.user_info()
+ username: str = user.email
+ return username
+ except Exception:
+ self.logger.exception("got exception while performing request")
+ return None
diff --git a/src/ahriman/core/exceptions.py b/src/ahriman/core/exceptions.py
index 10a09989..a2daf49d 100644
--- a/src/ahriman/core/exceptions.py
+++ b/src/ahriman/core/exceptions.py
@@ -63,11 +63,12 @@ class InitializeException(Exception):
base service initialization exception
"""
- def __init__(self) -> None:
+ def __init__(self, details: str) -> None:
"""
default constructor
+ :param details: details of the exception
"""
- Exception.__init__(self, "Could not load service")
+ Exception.__init__(self, f"Could not load service: {details}")
class InvalidOption(Exception):
diff --git a/src/ahriman/core/util.py b/src/ahriman/core/util.py
index db3f460a..438aac58 100644
--- a/src/ahriman/core/util.py
+++ b/src/ahriman/core/util.py
@@ -46,12 +46,12 @@ def check_output(*args: str, exception: Optional[Exception], cwd: Optional[Path]
if logger is not None:
for line in result.splitlines():
logger.debug(line)
+ return result
except subprocess.CalledProcessError as e:
if e.output is not None and logger is not None:
for line in e.output.splitlines():
logger.debug(line)
raise exception or e
- return result
def exception_response_text(exception: requests.exceptions.HTTPError) -> str:
diff --git a/src/ahriman/models/auth_settings.py b/src/ahriman/models/auth_settings.py
index 46294e30..93b901af 100644
--- a/src/ahriman/models/auth_settings.py
+++ b/src/ahriman/models/auth_settings.py
@@ -30,10 +30,12 @@ class AuthSettings(Enum):
web authorization type
:cvar Disabled: authorization is disabled
:cvar Configuration: configuration based authorization
+ :cvar OAuth: OAuth based provider
"""
Disabled = auto()
Configuration = auto()
+ OAuth = auto()
@classmethod
def from_option(cls: Type[AuthSettings], value: str) -> AuthSettings:
@@ -46,6 +48,8 @@ class AuthSettings(Enum):
return cls.Disabled
if value.lower() in ("configuration", "mapping"):
return cls.Configuration
+ if value.lower() in ('oauth', 'oauth2'):
+ return cls.OAuth
raise InvalidOption(value)
@property
diff --git a/src/ahriman/models/counters.py b/src/ahriman/models/counters.py
index c1ffafbd..13e7e344 100644
--- a/src/ahriman/models/counters.py
+++ b/src/ahriman/models/counters.py
@@ -37,6 +37,7 @@ class Counters:
:ivar failed: packages in failed status count
:ivar success: packages in success status count
"""
+
total: int
unknown: int = 0
pending: int = 0
diff --git a/src/ahriman/models/internal_status.py b/src/ahriman/models/internal_status.py
index 8956e37c..0d452788 100644
--- a/src/ahriman/models/internal_status.py
+++ b/src/ahriman/models/internal_status.py
@@ -34,6 +34,7 @@ class InternalStatus:
:ivar repository: repository name
:ivar version: service version
"""
+
architecture: Optional[str] = None
packages: Counters = field(default=Counters(total=0))
repository: Optional[str] = None
diff --git a/src/ahriman/models/user.py b/src/ahriman/models/user.py
index ef0c58bb..977fd0bf 100644
--- a/src/ahriman/models/user.py
+++ b/src/ahriman/models/user.py
@@ -35,6 +35,7 @@ class User:
:ivar password: hashed user password with salt
:ivar access: user role
"""
+
username: str
password: str
access: UserAccess
@@ -42,16 +43,18 @@ class User:
_HASHER = sha512_crypt
@classmethod
- def from_option(cls: Type[User], username: Optional[str], password: Optional[str]) -> Optional[User]:
+ def from_option(cls: Type[User], username: Optional[str], password: Optional[str],
+ access: UserAccess = UserAccess.Read) -> Optional[User]:
"""
build user descriptor from configuration options
:param username: username
:param password: password as string
+ :param access: optional user access
:return: generated user descriptor if all options are supplied and None otherwise
"""
if username is None or password is None:
return None
- return cls(username, password, UserAccess.Read)
+ return cls(username, password, access)
@staticmethod
def generate_password(length: int) -> str:
@@ -70,7 +73,10 @@ class User:
:param salt: salt for hashed password
:return: True in case if password matches, False otherwise
"""
- verified: bool = self._HASHER.verify(password + salt, self.password)
+ try:
+ verified: bool = self._HASHER.verify(password + salt, self.password)
+ except ValueError:
+ verified = False # the absence of evidence is not the evidence of absence (c) Gin Rummy
return verified
def hash_password(self, salt: str) -> str:
@@ -79,6 +85,10 @@ class User:
:param salt: salt for hashed password
:return: hashed string to store in configuration
"""
+ if not self.password:
+ # in case of empty password we leave it empty. This feature is used by any external (like OAuth) provider
+ # when we do not store any password here
+ return ""
password_hash: str = self._HASHER.hash(self.password + salt)
return password_hash
diff --git a/src/ahriman/web/middlewares/auth_handler.py b/src/ahriman/web/middlewares/auth_handler.py
index 93a211b8..49eb9bbb 100644
--- a/src/ahriman/web/middlewares/auth_handler.py
+++ b/src/ahriman/web/middlewares/auth_handler.py
@@ -48,11 +48,11 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
async def authorized_userid(self, identity: str) -> Optional[str]:
"""
- retrieve authorized username
+ retrieve authenticated username
:param identity: username
:return: user identity (username) in case if user exists and None otherwise
"""
- return identity if self.validator.known_username(identity) else None
+ return identity if await self.validator.known_username(identity) else None
async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
"""
@@ -62,7 +62,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy): # type
:param context: URI request path
:return: True in case if user is allowed to perform this request and False otherwise
"""
- return self.validator.verify_access(identity, permission, context)
+ return await self.validator.verify_access(identity, permission, context)
def auth_handler(validator: Auth) -> MiddlewareType:
@@ -78,7 +78,7 @@ def auth_handler(validator: Auth) -> MiddlewareType:
else:
permission = UserAccess.Write
- if not validator.is_safe_request(request.path, permission):
+ if not await validator.is_safe_request(request.path, permission):
await aiohttp_security.check_permission(request, permission, request.path)
return await handler(request)
diff --git a/src/ahriman/web/routes.py b/src/ahriman/web/routes.py
index 980d2d1a..64062bc4 100644
--- a/src/ahriman/web/routes.py
+++ b/src/ahriman/web/routes.py
@@ -61,6 +61,7 @@ def setup_routes(application: Application, static_path: Path) -> None:
GET /status-api/v1/status get web service status itself
+ GET /user-api/v1/login OAuth2 handler for login
POST /user-api/v1/login login to service
POST /user-api/v1/logout logout from service
@@ -92,5 +93,6 @@ def setup_routes(application: Application, static_path: Path) -> None:
application.router.add_get("/status-api/v1/status", StatusView, allow_head=True)
+ application.router.add_get("/user-api/v1/login", LoginView)
application.router.add_post("/user-api/v1/login", LoginView)
application.router.add_post("/user-api/v1/logout", LogoutView)
diff --git a/src/ahriman/web/views/index.py b/src/ahriman/web/views/index.py
index 341ab8bf..f617fa97 100644
--- a/src/ahriman/web/views/index.py
+++ b/src/ahriman/web/views/index.py
@@ -34,9 +34,11 @@ class IndexView(BaseView):
It uses jinja2 templates for report generation, the following variables are allowed:
architecture - repository architecture, string, required
- authorized - alias for `not auth_enabled or auth_username is not None`
- auth_enabled - whether authorization is enabled by configuration or not, boolean, required
- auth_username - authorized user id if any, string. None means not authorized
+ auth - authorization descriptor, required
+ * authenticated - alias to check if user can see the page, boolean, required
+ * control - HTML to insert for login control, HTML string, required
+ * enabled - whether authorization is enabled by configuration or not, boolean, required
+ * username - authenticated username if any, string, null means not authenticated
packages - sorted list of packages properties, required
* base, string
* depends, sorted list of strings
@@ -74,24 +76,27 @@ class IndexView(BaseView):
"status_color": status.status.bootstrap_color(),
"timestamp": pretty_datetime(status.timestamp),
"version": package.version,
- "web_url": package.web_url
+ "web_url": package.web_url,
} for package, status in sorted(self.service.packages, key=lambda item: item[0].base)
]
service = {
"status": self.service.status.status.value,
"status_color": self.service.status.status.badges_color(),
- "timestamp": pretty_datetime(self.service.status.timestamp)
+ "timestamp": pretty_datetime(self.service.status.timestamp),
}
# auth block
auth_username = await authorized_userid(self.request)
- authorized = not self.validator.enabled or self.validator.allow_read_only or auth_username is not None
+ auth = {
+ "authenticated": not self.validator.enabled or self.validator.allow_read_only or auth_username is not None,
+ "control": self.validator.auth_control,
+ "enabled": self.validator.enabled,
+ "username": auth_username,
+ }
return {
"architecture": self.service.architecture,
- "authorized": authorized,
- "auth_enabled": self.validator.enabled,
- "auth_username": auth_username,
+ "auth": auth,
"packages": packages,
"repository": self.service.repository.name,
"service": service,
diff --git a/src/ahriman/web/views/user/login.py b/src/ahriman/web/views/user/login.py
index 8155e9d5..18c05dd0 100644
--- a/src/ahriman/web/views/user/login.py
+++ b/src/ahriman/web/views/user/login.py
@@ -17,7 +17,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see
.
#
-from aiohttp.web import HTTPFound, HTTPUnauthorized, Response
+from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized, Response
from ahriman.core.auth.helpers import remember
from ahriman.web.views.base import BaseView
@@ -28,6 +28,33 @@ class LoginView(BaseView):
login endpoint view
"""
+ async def get(self) -> Response:
+ """
+ OAuth2 response handler
+
+ In case if code provided it will do a request to get user email. In case if no code provided it will redirect
+ to authorization url provided by OAuth client
+
+ :return: redirect to main page
+ """
+ from ahriman.core.auth.oauth import OAuth
+
+ code = self.request.query.getone("code", default=None)
+ oauth_provider = self.validator
+ if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway
+ raise HTTPMethodNotAllowed(self.request.method, ["POST"])
+
+ if not code:
+ return HTTPFound(oauth_provider.get_oauth_url())
+
+ response = HTTPFound("/")
+ username = await oauth_provider.get_oauth_username(code)
+ if await self.validator.known_username(username):
+ await remember(self.request, response, username)
+ return response
+
+ raise HTTPUnauthorized()
+
async def post(self) -> Response:
"""
login user to service
@@ -44,7 +71,7 @@ class LoginView(BaseView):
username = data.get("username")
response = HTTPFound("/")
- if self.validator.check_credentials(username, data.get("password")):
+ if await self.validator.check_credentials(username, data.get("password")):
await remember(self.request, response, username)
return response
diff --git a/src/ahriman/web/web.py b/src/ahriman/web/web.py
index 94e56c5f..d97f65e5 100644
--- a/src/ahriman/web/web.py
+++ b/src/ahriman/web/web.py
@@ -49,8 +49,9 @@ async def on_startup(application: web.Application) -> None:
try:
application["watcher"].load()
except Exception:
- application.logger.exception("could not load packages")
- raise InitializeException()
+ message = "could not load packages"
+ application.logger.exception(message)
+ raise InitializeException(message)
def run_server(application: web.Application) -> None:
diff --git a/tests/ahriman/core/auth/conftest.py b/tests/ahriman/core/auth/conftest.py
index 4430b262..a21974f0 100644
--- a/tests/ahriman/core/auth/conftest.py
+++ b/tests/ahriman/core/auth/conftest.py
@@ -1,13 +1,26 @@
import pytest
-from ahriman.core.auth.mapping_auth import MappingAuth
+from ahriman.core.auth.mapping import Mapping
+from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration
@pytest.fixture
-def mapping_auth(configuration: Configuration) -> MappingAuth:
+def mapping(configuration: Configuration) -> Mapping:
"""
auth provider fixture
+ :param configuration: configuration fixture
:return: auth service instance
"""
- return MappingAuth(configuration)
+ return Mapping(configuration)
+
+
+@pytest.fixture
+def oauth(configuration: Configuration) -> OAuth:
+ """
+ OAuth provider fixture
+ :param configuration: configuration fixture
+ :return: OAuth2 service instance
+ """
+ configuration.set("web", "address", "https://example.com")
+ return OAuth(configuration)
diff --git a/tests/ahriman/core/auth/test_auth.py b/tests/ahriman/core/auth/test_auth.py
index 5dda1f2a..0a359066 100644
--- a/tests/ahriman/core/auth/test_auth.py
+++ b/tests/ahriman/core/auth/test_auth.py
@@ -1,10 +1,22 @@
+import pytest
+
from ahriman.core.auth.auth import Auth
-from ahriman.core.auth.mapping_auth import MappingAuth
+from ahriman.core.auth.mapping import Mapping
+from ahriman.core.auth.oauth import OAuth
from ahriman.core.configuration import Configuration
+from ahriman.core.exceptions import DuplicateUser
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
+def test_auth_control(auth: Auth) -> None:
+ """
+ must return a control for authorization
+ """
+ assert auth.auth_control
+ assert "button" in auth.auth_control # I think it should be button
+
+
def test_load_dummy(configuration: Configuration) -> None:
"""
must load dummy validator if authorization is not enabled
@@ -28,63 +40,119 @@ def test_load_mapping(configuration: Configuration) -> None:
"""
configuration.set_option("auth", "target", "configuration")
auth = Auth.load(configuration)
- assert isinstance(auth, MappingAuth)
+ assert isinstance(auth, Mapping)
-def test_check_credentials(auth: Auth, user: User) -> None:
+def test_load_oauth(configuration: Configuration) -> None:
+ """
+ must load OAuth2 validator if option set
+ """
+ configuration.set_option("auth", "target", "oauth")
+ configuration.set_option("web", "address", "https://example.com")
+ auth = Auth.load(configuration)
+ assert isinstance(auth, OAuth)
+
+
+def test_get_users(mapping: Auth, configuration: Configuration) -> None:
+ """
+ must return valid user list
+ """
+ user_write = User("user_write", "pwd_write", UserAccess.Write)
+ write_section = Configuration.section_name("auth", user_write.access.value)
+ configuration.set_option(write_section, user_write.username, user_write.password)
+ user_read = User("user_read", "pwd_read", UserAccess.Read)
+ read_section = Configuration.section_name("auth", user_read.access.value)
+ configuration.set_option(read_section, user_read.username, user_read.password)
+ user_read = User("user_read", "pwd_read", UserAccess.Read)
+ read_section = Configuration.section_name("auth", user_read.access.value)
+ configuration.set_option(read_section, user_read.username, user_read.password)
+
+ users = mapping.get_users(configuration)
+ expected = {user_write.username: user_write, user_read.username: user_read}
+ assert users == expected
+
+
+def test_get_users_normalized(mapping: Auth, configuration: Configuration) -> None:
+ """
+ must return user list with normalized usernames in keys
+ """
+ user = User("UsEr", "pwd_read", UserAccess.Read)
+ read_section = Configuration.section_name("auth", user.access.value)
+ configuration.set_option(read_section, user.username, user.password)
+
+ users = mapping.get_users(configuration)
+ expected = user.username.lower()
+ assert expected in users
+ assert users[expected].username == expected
+
+
+def test_get_users_duplicate(mapping: Auth, configuration: Configuration, user: User) -> None:
+ """
+ must raise exception on duplicate username
+ """
+ write_section = Configuration.section_name("auth", UserAccess.Write.value)
+ configuration.set_option(write_section, user.username, user.password)
+ read_section = Configuration.section_name("auth", UserAccess.Read.value)
+ configuration.set_option(read_section, user.username, user.password)
+
+ with pytest.raises(DuplicateUser):
+ mapping.get_users(configuration)
+
+
+async def test_check_credentials(auth: Auth, user: User) -> None:
"""
must pass any credentials
"""
- assert auth.check_credentials(user.username, user.password)
- assert auth.check_credentials(None, "")
- assert auth.check_credentials("", None)
- assert auth.check_credentials(None, None)
+ assert await auth.check_credentials(user.username, user.password)
+ assert await auth.check_credentials(None, "")
+ assert await auth.check_credentials("", None)
+ assert await auth.check_credentials(None, None)
-def test_is_safe_request(auth: Auth) -> None:
+async def test_is_safe_request(auth: Auth) -> None:
"""
must validate safe request
"""
# login and logout are always safe
- assert auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
- assert auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
+ assert await auth.is_safe_request("/user-api/v1/login", UserAccess.Write)
+ assert await auth.is_safe_request("/user-api/v1/logout", UserAccess.Write)
auth.allowed_paths.add("/safe")
auth.allowed_paths_groups.add("/unsafe/safe")
- assert auth.is_safe_request("/safe", UserAccess.Write)
- assert not auth.is_safe_request("/unsafe", UserAccess.Write)
- assert auth.is_safe_request("/unsafe/safe", UserAccess.Write)
- assert auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
+ assert await auth.is_safe_request("/safe", UserAccess.Write)
+ assert not await auth.is_safe_request("/unsafe", UserAccess.Write)
+ assert await auth.is_safe_request("/unsafe/safe", UserAccess.Write)
+ assert await auth.is_safe_request("/unsafe/safe/suffix", UserAccess.Write)
-def test_is_safe_request_empty(auth: Auth) -> None:
+async def test_is_safe_request_empty(auth: Auth) -> None:
"""
must not allow requests without path
"""
- assert not auth.is_safe_request(None, UserAccess.Read)
- assert not auth.is_safe_request("", UserAccess.Read)
+ assert not await auth.is_safe_request(None, UserAccess.Read)
+ assert not await auth.is_safe_request("", UserAccess.Read)
-def test_is_safe_request_read_only(auth: Auth) -> None:
+async def test_is_safe_request_read_only(auth: Auth) -> None:
"""
must allow read-only requests if it is set in settings
"""
- assert auth.is_safe_request("/", UserAccess.Read)
+ assert await auth.is_safe_request("/", UserAccess.Read)
auth.allow_read_only = True
- assert auth.is_safe_request("/unsafe", UserAccess.Read)
+ assert await auth.is_safe_request("/unsafe", UserAccess.Read)
-def test_known_username(auth: Auth, user: User) -> None:
+async def test_known_username(auth: Auth, user: User) -> None:
"""
must allow any username
"""
- assert auth.known_username(user.username)
+ assert await auth.known_username(user.username)
-def test_verify_access(auth: Auth, user: User) -> None:
+async def test_verify_access(auth: Auth, user: User) -> None:
"""
must allow any access
"""
- assert auth.verify_access(user.username, user.access, None)
- assert auth.verify_access(user.username, UserAccess.Write, None)
+ assert await auth.verify_access(user.username, user.access, None)
+ assert await auth.verify_access(user.username, UserAccess.Write, None)
diff --git a/tests/ahriman/core/auth/test_mapping.py b/tests/ahriman/core/auth/test_mapping.py
new file mode 100644
index 00000000..216aab24
--- /dev/null
+++ b/tests/ahriman/core/auth/test_mapping.py
@@ -0,0 +1,73 @@
+from ahriman.core.auth.mapping import Mapping
+from ahriman.models.user import User
+from ahriman.models.user_access import UserAccess
+
+
+async def test_check_credentials(mapping: Mapping, user: User) -> None:
+ """
+ must return true for valid credentials
+ """
+ current_password = user.password
+ user.password = user.hash_password(mapping.salt)
+ mapping._users[user.username] = user
+ assert await mapping.check_credentials(user.username, current_password)
+ # here password is hashed so it is invalid
+ assert not await mapping.check_credentials(user.username, user.password)
+
+
+async def test_check_credentials_empty(mapping: Mapping) -> None:
+ """
+ must reject on empty credentials
+ """
+ assert not await mapping.check_credentials(None, "")
+ assert not await mapping.check_credentials("", None)
+ assert not await mapping.check_credentials(None, None)
+
+
+async def test_check_credentials_unknown(mapping: Mapping, user: User) -> None:
+ """
+ must reject on unknown user
+ """
+ assert not await mapping.check_credentials(user.username, user.password)
+
+
+def test_get_user(mapping: Mapping, user: User) -> None:
+ """
+ must return user from storage by username
+ """
+ mapping._users[user.username] = user
+ assert mapping.get_user(user.username) == user
+
+
+def test_get_user_normalized(mapping: Mapping, user: User) -> None:
+ """
+ must return user from storage by username case-insensitive
+ """
+ mapping._users[user.username] = user
+ assert mapping.get_user(user.username.upper()) == user
+
+
+def test_get_user_unknown(mapping: Mapping, user: User) -> None:
+ """
+ must return None in case if no user found
+ """
+ assert mapping.get_user(user.username) is None
+
+
+async def test_known_username(mapping: Mapping, user: User) -> None:
+ """
+ must allow only known users
+ """
+ mapping._users[user.username] = user
+ assert await mapping.known_username(user.username)
+ assert not await mapping.known_username(None)
+ assert not await mapping.known_username(user.password)
+
+
+async def test_verify_access(mapping: Mapping, user: User) -> None:
+ """
+ must verify user access
+ """
+ mapping._users[user.username] = user
+ assert await mapping.verify_access(user.username, user.access, None)
+ assert not await mapping.verify_access(user.username, UserAccess.Write, None)
diff --git a/tests/ahriman/core/auth/test_mapping_auth.py b/tests/ahriman/core/auth/test_mapping_auth.py
deleted file mode 100644
index b3b11a91..00000000
--- a/tests/ahriman/core/auth/test_mapping_auth.py
+++ /dev/null
@@ -1,121 +0,0 @@
-import pytest
-
-from ahriman.core.auth.mapping_auth import MappingAuth
-from ahriman.core.configuration import Configuration
-from ahriman.core.exceptions import DuplicateUser
-from ahriman.models.user import User
-from ahriman.models.user_access import UserAccess
-
-
-def test_get_users(mapping_auth: MappingAuth, configuration: Configuration) -> None:
- """
- must return valid user list
- """
- user_write = User("user_write", "pwd_write", UserAccess.Write)
- write_section = Configuration.section_name("auth", user_write.access.value)
- configuration.set_option(write_section, user_write.username, user_write.password)
- user_read = User("user_read", "pwd_read", UserAccess.Read)
- read_section = Configuration.section_name("auth", user_read.access.value)
- configuration.set_option(read_section, user_read.username, user_read.password)
- user_read = User("user_read", "pwd_read", UserAccess.Read)
- read_section = Configuration.section_name("auth", user_read.access.value)
- configuration.set_option(read_section, user_read.username, user_read.password)
-
- users = mapping_auth.get_users(configuration)
- expected = {user_write.username: user_write, user_read.username: user_read}
- assert users == expected
-
-
-def test_get_users_normalized(mapping_auth: MappingAuth, configuration: Configuration) -> None:
- """
- must return user list with normalized usernames in keys
- """
- user = User("UsEr", "pwd_read", UserAccess.Read)
- read_section = Configuration.section_name("auth", user.access.value)
- configuration.set_option(read_section, user.username, user.password)
-
- users = mapping_auth.get_users(configuration)
- expected = user.username.lower()
- assert expected in users
- assert users[expected].username == expected
-
-
-def test_get_users_duplicate(mapping_auth: MappingAuth, configuration: Configuration, user: User) -> None:
- """
- must raise exception on duplicate username
- """
- write_section = Configuration.section_name("auth", UserAccess.Write.value)
- configuration.set_option(write_section, user.username, user.password)
- read_section = Configuration.section_name("auth", UserAccess.Read.value)
- configuration.set_option(read_section, user.username, user.password)
-
- with pytest.raises(DuplicateUser):
- mapping_auth.get_users(configuration)
-
-
-def test_check_credentials(mapping_auth: MappingAuth, user: User) -> None:
- """
- must return true for valid credentials
- """
- current_password = user.password
- user.password = user.hash_password(mapping_auth.salt)
- mapping_auth._users[user.username] = user
- assert mapping_auth.check_credentials(user.username, current_password)
- assert not mapping_auth.check_credentials(user.username, user.password) # here password is hashed so it is invalid
-
-
-def test_check_credentials_empty(mapping_auth: MappingAuth) -> None:
- """
- must reject on empty credentials
- """
- assert not mapping_auth.check_credentials(None, "")
- assert not mapping_auth.check_credentials("", None)
- assert not mapping_auth.check_credentials(None, None)
-
-
-def test_check_credentials_unknown(mapping_auth: MappingAuth, user: User) -> None:
- """
- must reject on unknown user
- """
- assert not mapping_auth.check_credentials(user.username, user.password)
-
-
-def test_get_user(mapping_auth: MappingAuth, user: User) -> None:
- """
- must return user from storage by username
- """
- mapping_auth._users[user.username] = user
- assert mapping_auth.get_user(user.username) == user
-
-
-def test_get_user_normalized(mapping_auth: MappingAuth, user: User) -> None:
- """
- must return user from storage by username case-insensitive
- """
- mapping_auth._users[user.username] = user
- assert mapping_auth.get_user(user.username.upper()) == user
-
-
-def test_get_user_unknown(mapping_auth: MappingAuth, user: User) -> None:
- """
- must return None in case if no user found
- """
- assert mapping_auth.get_user(user.username) is None
-
-
-def test_known_username(mapping_auth: MappingAuth, user: User) -> None:
- """
- must allow only known users
- """
- mapping_auth._users[user.username] = user
- assert mapping_auth.known_username(user.username)
- assert not mapping_auth.known_username(user.password)
-
-
-def test_verify_access(mapping_auth: MappingAuth, user: User) -> None:
- """
- must verify user access
- """
- mapping_auth._users[user.username] = user
- assert mapping_auth.verify_access(user.username, user.access, None)
- assert not mapping_auth.verify_access(user.username, UserAccess.Write, None)
diff --git a/tests/ahriman/core/auth/test_oauth.py b/tests/ahriman/core/auth/test_oauth.py
new file mode 100644
index 00000000..73dc2ffc
--- /dev/null
+++ b/tests/ahriman/core/auth/test_oauth.py
@@ -0,0 +1,98 @@
+import aioauth_client
+import pytest
+
+from pytest_mock import MockerFixture
+
+from ahriman.core.auth.oauth import OAuth
+from ahriman.core.exceptions import InvalidOption
+
+
+def test_auth_control(oauth: OAuth) -> None:
+ """
+ must return a control for authorization
+ """
+ assert oauth.auth_control
+ assert "
None:
+ """
+ must return valid provider type
+ """
+ assert OAuth.get_provider("OAuth2Client") == aioauth_client.OAuth2Client
+ assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
+ assert OAuth.get_provider("GoogleClient") == aioauth_client.GoogleClient
+
+
+def test_get_provider_not_a_type() -> None:
+ """
+ must raise an exception if attribute is not a type
+ """
+ with pytest.raises(InvalidOption):
+ OAuth.get_provider("__version__")
+
+
+def test_get_provider_invalid_type() -> None:
+ """
+ must raise an exception if attribute is not an OAuth2 client
+ """
+ with pytest.raises(InvalidOption):
+ OAuth.get_provider("User")
+ with pytest.raises(InvalidOption):
+ OAuth.get_provider("OAuth1Client")
+
+
+def test_get_client(oauth: OAuth) -> None:
+ """
+ must return valid OAuth2 client
+ """
+ client = oauth.get_client()
+ assert isinstance(client, aioauth_client.GoogleClient)
+ assert client.client_id == oauth.client_id
+ assert client.client_secret == oauth.client_secret
+
+
+def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
+ """
+ must generate valid OAuth authorization URL
+ """
+ authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
+ oauth.get_oauth_url()
+ authorize_url_mock.assert_called_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
+
+
+async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
+ """
+ must return authorized user ID
+ """
+ access_token_mock = mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
+ user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
+ return_value=(aioauth_client.User(email="email"), ""))
+
+ email = await oauth.get_oauth_username("code")
+ access_token_mock.assert_called_with("code", redirect_uri=oauth.redirect_uri)
+ user_info_mock.assert_called_once()
+ assert email == "email"
+
+
+async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
+ """
+ must return None in case of OAuth request error (get_access_token)
+ """
+ mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception())
+ user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
+
+ email = await oauth.get_oauth_username("code")
+ assert email is None
+ user_info_mock.assert_not_called()
+
+
+async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixture) -> None:
+ """
+ must return None in case of OAuth request error (user_info)
+ """
+ mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
+ mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception())
+
+ email = await oauth.get_oauth_username("code")
+ assert email is None
diff --git a/tests/ahriman/models/test_auth_settings.py b/tests/ahriman/models/test_auth_settings.py
index c5d6ed9b..b3816cd3 100644
--- a/tests/ahriman/models/test_auth_settings.py
+++ b/tests/ahriman/models/test_auth_settings.py
@@ -21,6 +21,10 @@ def test_from_option_valid() -> None:
assert AuthSettings.from_option("no") == AuthSettings.Disabled
assert AuthSettings.from_option("NO") == AuthSettings.Disabled
+ assert AuthSettings.from_option("oauth") == AuthSettings.OAuth
+ assert AuthSettings.from_option("OAuth") == AuthSettings.OAuth
+ assert AuthSettings.from_option("OAuth2") == AuthSettings.OAuth
+
assert AuthSettings.from_option("configuration") == AuthSettings.Configuration
assert AuthSettings.from_option("ConFigUration") == AuthSettings.Configuration
assert AuthSettings.from_option("mapping") == AuthSettings.Configuration
diff --git a/tests/ahriman/models/test_user.py b/tests/ahriman/models/test_user.py
index 6acc2f3d..be3257fa 100644
--- a/tests/ahriman/models/test_user.py
+++ b/tests/ahriman/models/test_user.py
@@ -10,6 +10,7 @@ def test_from_option(user: User) -> None:
# default is read access
user.access = UserAccess.Write
assert User.from_option(user.username, user.password) != user
+ assert User.from_option(user.username, user.password, user.access) == user
def test_from_option_empty() -> None:
@@ -32,6 +33,26 @@ def test_check_credentials_hash_password(user: User) -> None:
assert not user.check_credentials(user.password, "salt")
+def test_check_credentials_empty_hash(user: User) -> None:
+ """
+ must reject any authorization if the hash is invalid
+ """
+ current_password = user.password
+ assert not user.check_credentials(current_password, "salt")
+ user.password = ""
+ assert not user.check_credentials(current_password, "salt")
+
+
+def test_hash_password_empty_hash(user: User) -> None:
+ """
+ must return empty string after hash in case if password not set
+ """
+ user.password = ""
+ assert user.hash_password("salt") == ""
+ user.password = None
+ assert user.hash_password("salt") == ""
+
+
def test_generate_password() -> None:
"""
must generate password with specified length
diff --git a/tests/ahriman/web/middlewares/test_auth_handler.py b/tests/ahriman/web/middlewares/test_auth_handler.py
index d1789a42..21a6aebf 100644
--- a/tests/ahriman/web/middlewares/test_auth_handler.py
+++ b/tests/ahriman/web/middlewares/test_auth_handler.py
@@ -2,10 +2,9 @@ import pytest
from aiohttp import web
from pytest_mock import MockerFixture
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import AsyncMock
from ahriman.core.auth.auth import Auth
-from ahriman.core.configuration import Configuration
from ahriman.models.user import User
from ahriman.models.user_access import UserAccess
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
@@ -23,7 +22,7 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
"""
must call validator check
"""
- authorization_policy.validator = MagicMock()
+ authorization_policy.validator = AsyncMock()
authorization_policy.validator.verify_access.return_value = True
assert await authorization_policy.permits(user.username, user.access, "/endpoint")
diff --git a/tests/ahriman/web/views/service/test_views_service_add.py b/tests/ahriman/web/views/service/test_views_service_add.py
index 24c750c9..4442dcab 100644
--- a/tests/ahriman/web/views/service/test_views_service_add.py
+++ b/tests/ahriman/web/views/service/test_views_service_add.py
@@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"]})
- assert response.status == 200
+ assert response.ok
add_mock.assert_called_with(["ahriman"], True)
@@ -20,7 +20,7 @@ async def test_post_now(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/add", json={"packages": ["ahriman"], "build_now": False})
- assert response.status == 200
+ assert response.ok
add_mock.assert_called_with(["ahriman"], False)
@@ -42,5 +42,5 @@ async def test_post_update(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_add")
response = await client.post("/service-api/v1/update", json={"packages": ["ahriman"]})
- assert response.status == 200
+ assert response.ok
add_mock.assert_called_with(["ahriman"], True)
diff --git a/tests/ahriman/web/views/service/test_views_service_remove.py b/tests/ahriman/web/views/service/test_views_service_remove.py
index d7c45d80..c801c3be 100644
--- a/tests/ahriman/web/views/service/test_views_service_remove.py
+++ b/tests/ahriman/web/views/service/test_views_service_remove.py
@@ -9,7 +9,7 @@ async def test_post(client: TestClient, mocker: MockerFixture) -> None:
add_mock = mocker.patch("ahriman.core.spawn.Spawn.packages_remove")
response = await client.post("/service-api/v1/remove", json={"packages": ["ahriman"]})
- assert response.status == 200
+ assert response.ok
add_mock.assert_called_with(["ahriman"])
diff --git a/tests/ahriman/web/views/service/test_views_service_search.py b/tests/ahriman/web/views/service/test_views_service_search.py
index bfd3158d..4f248d72 100644
--- a/tests/ahriman/web/views/service/test_views_service_search.py
+++ b/tests/ahriman/web/views/service/test_views_service_search.py
@@ -11,7 +11,7 @@ async def test_get(client: TestClient, aur_package_ahriman: aur.Package, mocker:
mocker.patch("aur.search", return_value=[aur_package_ahriman])
response = await client.get("/service-api/v1/search", params={"for": "ahriman"})
- assert response.status == 200
+ assert response.ok
assert await response.json() == ["ahriman"]
@@ -33,7 +33,7 @@ async def test_get_join(client: TestClient, mocker: MockerFixture) -> None:
search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ahriman"), ("for", "maybe")])
- assert response.status == 200
+ assert response.ok
search_mock.assert_called_with("ahriman maybe")
@@ -44,7 +44,7 @@ async def test_get_join_filter(client: TestClient, mocker: MockerFixture) -> Non
search_mock = mocker.patch("aur.search")
response = await client.get("/service-api/v1/search", params=[("for", "ah"), ("for", "maybe")])
- assert response.status == 200
+ assert response.ok
search_mock.assert_called_with("maybe")
diff --git a/tests/ahriman/web/views/status/test_views_status_ahriman.py b/tests/ahriman/web/views/status/test_views_status_ahriman.py
index 489da68c..7108701b 100644
--- a/tests/ahriman/web/views/status/test_views_status_ahriman.py
+++ b/tests/ahriman/web/views/status/test_views_status_ahriman.py
@@ -11,7 +11,7 @@ async def test_get(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json())
- assert response.status == 200
+ assert response.ok
assert status.status == BuildStatusEnum.Unknown
@@ -26,7 +26,7 @@ async def test_post(client: TestClient) -> None:
response = await client.get("/status-api/v1/ahriman")
status = BuildStatus.from_json(await response.json())
- assert response.status == 200
+ assert response.ok
assert status.status == BuildStatusEnum.Success
diff --git a/tests/ahriman/web/views/status/test_views_status_package.py b/tests/ahriman/web/views/status/test_views_status_package.py
index bb760890..cdbe447e 100644
--- a/tests/ahriman/web/views/status/test_views_status_package.py
+++ b/tests/ahriman/web/views/status/test_views_status_package.py
@@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
- assert response.status == 200
+ assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages
@@ -45,7 +45,7 @@ async def test_delete(client: TestClient, package_ahriman: Package, package_pyth
assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
- assert response.status == 200
+ assert response.ok
async def test_delete_unknown(client: TestClient, package_ahriman: Package, package_python_schedule: Package) -> None:
@@ -62,7 +62,7 @@ async def test_delete_unknown(client: TestClient, package_ahriman: Package, pack
assert response.status == 404
response = await client.get(f"/status-api/v1/packages/{package_python_schedule.base}")
- assert response.status == 200
+ assert response.ok
async def test_post(client: TestClient, package_ahriman: Package) -> None:
@@ -75,7 +75,7 @@ async def test_post(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
- assert response.status == 200
+ assert response.ok
async def test_post_exception(client: TestClient, package_ahriman: Package) -> None:
@@ -100,7 +100,7 @@ async def test_post_light(client: TestClient, package_ahriman: Package) -> None:
assert post_response.status == 204
response = await client.get(f"/status-api/v1/packages/{package_ahriman.base}")
- assert response.status == 200
+ assert response.ok
statuses = {
Package.from_json(item["package"]).base: BuildStatus.from_json(item["status"])
for item in await response.json()
diff --git a/tests/ahriman/web/views/status/test_views_status_packages.py b/tests/ahriman/web/views/status/test_views_status_packages.py
index c1d18b0e..30d4a962 100644
--- a/tests/ahriman/web/views/status/test_views_status_packages.py
+++ b/tests/ahriman/web/views/status/test_views_status_packages.py
@@ -15,7 +15,7 @@ async def test_get(client: TestClient, package_ahriman: Package, package_python_
json={"status": BuildStatusEnum.Success.value, "package": package_python_schedule.view()})
response = await client.get("/status-api/v1/packages")
- assert response.status == 200
+ assert response.ok
packages = [Package.from_json(item["package"]) for item in await response.json()]
assert packages
diff --git a/tests/ahriman/web/views/status/test_views_status_status.py b/tests/ahriman/web/views/status/test_views_status_status.py
index e7e7fd1c..776f3775 100644
--- a/tests/ahriman/web/views/status/test_views_status_status.py
+++ b/tests/ahriman/web/views/status/test_views_status_status.py
@@ -14,7 +14,7 @@ async def test_get(client: TestClient, package_ahriman: Package) -> None:
json={"status": BuildStatusEnum.Success.value, "package": package_ahriman.view()})
response = await client.get("/status-api/v1/status")
- assert response.status == 200
+ assert response.ok
json = await response.json()
assert json["version"] == version.__version__
diff --git a/tests/ahriman/web/views/test_views_index.py b/tests/ahriman/web/views/test_views_index.py
index 61342667..0c9ff926 100644
--- a/tests/ahriman/web/views/test_views_index.py
+++ b/tests/ahriman/web/views/test_views_index.py
@@ -6,7 +6,7 @@ async def test_get(client_with_auth: TestClient) -> None:
must generate status page correctly (/)
"""
response = await client_with_auth.get("/")
- assert response.status == 200
+ assert response.ok
assert await response.text()
@@ -15,7 +15,7 @@ async def test_get_index(client_with_auth: TestClient) -> None:
must generate status page correctly (/index.html)
"""
response = await client_with_auth.get("/index.html")
- assert response.status == 200
+ assert response.ok
assert await response.text()
@@ -24,7 +24,7 @@ async def test_get_without_auth(client: TestClient) -> None:
must use dummy authorized_userid function in case if no security library installed
"""
response = await client.get("/")
- assert response.status == 200
+ assert response.ok
assert await response.text()
@@ -33,4 +33,4 @@ async def test_get_static(client: TestClient) -> None:
must return static files
"""
response = await client.get("/static/favicon.ico")
- assert response.status == 200
+ assert response.ok
diff --git a/tests/ahriman/web/views/user/test_views_user_login.py b/tests/ahriman/web/views/user/test_views_user_login.py
index 8be422f1..622cf565 100644
--- a/tests/ahriman/web/views/user/test_views_user_login.py
+++ b/tests/ahriman/web/views/user/test_views_user_login.py
@@ -1,9 +1,75 @@
from aiohttp.test_utils import TestClient
from pytest_mock import MockerFixture
+from unittest.mock import MagicMock
+from ahriman.core.auth.oauth import OAuth
from ahriman.models.user import User
+async def test_get_default_validator(client_with_auth: TestClient) -> None:
+ """
+ must return 405 in case if no OAuth enabled
+ """
+ get_response = await client_with_auth.get("/user-api/v1/login")
+ assert get_response.status == 405
+
+
+async def test_get_redirect_to_oauth(client_with_auth: TestClient) -> None:
+ """
+ must redirect to OAuth service provider in case if no code is supplied
+ """
+ oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
+ oauth.get_oauth_url.return_value = "https://example.com"
+
+ get_response = await client_with_auth.get("/user-api/v1/login")
+ assert get_response.ok
+ oauth.get_oauth_url.assert_called_once()
+
+
+async def test_get_redirect_to_oauth_empty_code(client_with_auth: TestClient) -> None:
+ """
+ must redirect to OAuth service provider in case if empty code is supplied
+ """
+ oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
+ oauth.get_oauth_url.return_value = "https://example.com"
+
+ get_response = await client_with_auth.get("/user-api/v1/login", params={"code": ""})
+ assert get_response.ok
+ oauth.get_oauth_url.assert_called_once()
+
+
+async def test_get(client_with_auth: TestClient, mocker: MockerFixture) -> None:
+ """
+ must login user correctly from OAuth
+ """
+ oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
+ oauth.get_oauth_username.return_value = "user"
+ oauth.known_username.return_value = True
+ oauth.enabled = False # lol
+ remember_mock = mocker.patch("aiohttp_security.remember")
+
+ get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
+
+ assert get_response.ok
+ oauth.get_oauth_username.assert_called_with("code")
+ oauth.known_username.assert_called_with("user")
+ remember_mock.assert_called_once()
+
+
+async def test_get_unauthorized(client_with_auth: TestClient, mocker: MockerFixture) -> None:
+ """
+ must return unauthorized from OAuth
+ """
+ oauth = client_with_auth.app["validator"] = MagicMock(spec=OAuth)
+ oauth.known_username.return_value = False
+ remember_mock = mocker.patch("aiohttp_security.remember")
+
+ get_response = await client_with_auth.get("/user-api/v1/login", params={"code": "code"})
+
+ assert get_response.status == 401
+ remember_mock.assert_not_called()
+
+
async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:
"""
must login user correctly
@@ -12,10 +78,10 @@ async def test_post(client_with_auth: TestClient, user: User, mocker: MockerFixt
remember_mock = mocker.patch("aiohttp_security.remember")
post_response = await client_with_auth.post("/user-api/v1/login", json=payload)
- assert post_response.status == 200
+ assert post_response.ok
post_response = await client_with_auth.post("/user-api/v1/login", data=payload)
- assert post_response.status == 200
+ assert post_response.ok
remember_mock.assert_called()
@@ -26,7 +92,7 @@ async def test_post_skip(client: TestClient, user: User) -> None:
"""
payload = {"username": user.username, "password": user.password}
post_response = await client.post("/user-api/v1/login", json=payload)
- assert post_response.status == 200
+ assert post_response.ok
async def test_post_unauthorized(client_with_auth: TestClient, user: User, mocker: MockerFixture) -> None:
diff --git a/tests/ahriman/web/views/user/test_views_user_logout.py b/tests/ahriman/web/views/user/test_views_user_logout.py
index 3d287bb0..7e316204 100644
--- a/tests/ahriman/web/views/user/test_views_user_logout.py
+++ b/tests/ahriman/web/views/user/test_views_user_logout.py
@@ -11,7 +11,7 @@ async def test_post(client_with_auth: TestClient, mocker: MockerFixture) -> None
forget_mock = mocker.patch("aiohttp_security.forget")
post_response = await client_with_auth.post("/user-api/v1/logout")
- assert post_response.status == 200
+ assert post_response.ok
forget_mock.assert_called_once()
@@ -32,4 +32,4 @@ async def test_post_disabled(client: TestClient) -> None:
must raise exception if auth is disabled
"""
post_response = await client.post("/user-api/v1/logout")
- assert post_response.status == 200
+ assert post_response.ok
diff --git a/tests/testresources/core/ahriman.ini b/tests/testresources/core/ahriman.ini
index 95343437..54ade25c 100644
--- a/tests/testresources/core/ahriman.ini
+++ b/tests/testresources/core/ahriman.ini
@@ -10,6 +10,10 @@ root = /
[auth]
allow_read_only = no
+client_id = client_id
+client_secret = client_secret
+oauth_provider = GoogleClient
+oauth_scopes = https://www.googleapis.com/auth/userinfo.email
salt = salt
[build]