Compare commits

...

7 Commits

19 changed files with 143 additions and 49 deletions

View File

@@ -250,6 +250,7 @@ Available options are:
Remote pull trigger Remote pull trigger
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
* ``type`` - type of the pull, string, optional, must be set to ``gitremote`` if exists.
* ``pull_url`` - URL of the remote repository from which PKGBUILDs can be pulled before build process, string, required. * ``pull_url`` - URL of the remote repository from which PKGBUILDs can be pulled before build process, string, required.
* ``pull_branch`` - branch of the remote repository from which PKGBUILDs can be pulled before build process, string, optional, default is ``master``. * ``pull_branch`` - branch of the remote repository from which PKGBUILDs can be pulled before build process, string, optional, default is ``master``.
@@ -270,6 +271,7 @@ Available options are:
Remote push trigger Remote push trigger
^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^
* ``type`` - type of the push, string, optional, must be set to ``gitremote`` if exists.
* ``commit_email`` - git commit email, string, optional, default is ``ahriman@localhost``. * ``commit_email`` - git commit email, string, optional, default is ``ahriman@localhost``.
* ``commit_user`` - git commit user, string, optional, default is ``ahriman``. * ``commit_user`` - git commit user, string, optional, default is ``ahriman``.
* ``push_url`` - URL of the remote repository to which PKGBUILDs should be pushed after build process, string, required. * ``push_url`` - URL of the remote repository to which PKGBUILDs should be pushed after build process, string, required.

View File

@@ -22,6 +22,7 @@ from collections.abc import Iterable
from ahriman.application.application.application_properties import ApplicationProperties from ahriman.application.application.application_properties import ApplicationProperties
from ahriman.application.application.workers import Updater from ahriman.application.application.workers import Updater
from ahriman.core.build_tools.sources import Sources from ahriman.core.build_tools.sources import Sources
from ahriman.core.exceptions import UnknownPackageError
from ahriman.models.package import Package from ahriman.models.package import Package
from ahriman.models.packagers import Packagers from ahriman.models.packagers import Packagers
from ahriman.models.result import Result from ahriman.models.result import Result
@@ -116,7 +117,7 @@ class ApplicationRepository(ApplicationProperties):
for single in probe.packages: for single in probe.packages:
try: try:
_ = Package.from_aur(single, None) _ = Package.from_aur(single, None)
except Exception: except UnknownPackageError:
packages.append(single) packages.append(single)
return packages return packages

View File

@@ -22,6 +22,11 @@ try:
except ImportError: except ImportError:
aiohttp_security = None # type: ignore[assignment] aiohttp_security = None # type: ignore[assignment]
try:
import aiohttp_session
except ImportError:
aiohttp_session = None # type: ignore[assignment]
from typing import Any from typing import Any
@@ -50,7 +55,7 @@ async def check_authorized(*args: Any, **kwargs: Any) -> Any:
Args: Args:
*args(Any): argument list as provided by check_authorized function *args(Any): argument list as provided by check_authorized function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by check_authorized function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise
@@ -66,7 +71,7 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
Args: Args:
*args(Any): argument list as provided by forget function *args(Any): argument list as provided by forget function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by forget function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise
@@ -76,13 +81,29 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
return None return None
async def get_session(*args: Any, **kwargs: Any) -> Any:
"""
handle aiohttp session methods
Args:
*args(Any): argument list as provided by get_session function
**kwargs(Any): named argument list as provided by get_session function
Returns:
Any: empty dictionary in case if no aiohttp_session module found and function call otherwise
"""
if aiohttp_session is not None:
return await aiohttp_session.get_session(*args, **kwargs)
return {}
async def remember(*args: Any, **kwargs: Any) -> Any: async def remember(*args: Any, **kwargs: Any) -> Any:
""" """
handle disabled auth handle disabled auth
Args: Args:
*args(Any): argument list as provided by remember function *args(Any): argument list as provided by remember function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by remember function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise

View File

@@ -19,6 +19,8 @@
# #
import aioauth_client import aioauth_client
from typing import Any
from ahriman.core.auth.mapping import Mapping from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.database import SQLite from ahriman.core.database import SQLite
@@ -53,7 +55,7 @@ class OAuth(Mapping):
self.client_secret = configuration.get("auth", "client_secret") self.client_secret = configuration.get("auth", "client_secret")
# in order to use OAuth feature the service must be publicity available # in order to use OAuth feature the service must be publicity available
# thus we expect that address is set # thus we expect that address is set
self.redirect_uri = f"""{configuration.get("web", "address")}/api/v1/login""" self.redirect_uri = f"{configuration.get("web", "address")}/api/v1/login"
self.provider = self.get_provider(configuration.get("auth", "oauth_provider")) self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
# it is list, but we will have to convert to string it anyway # it is list, but we will have to convert to string it anyway
self.scopes = configuration.get("auth", "oauth_scopes") self.scopes = configuration.get("auth", "oauth_scopes")
@@ -102,27 +104,35 @@ class OAuth(Mapping):
""" """
return self.provider(client_id=self.client_id, client_secret=self.client_secret) return self.provider(client_id=self.client_id, client_secret=self.client_secret)
def get_oauth_url(self) -> str: def get_oauth_url(self, state: str) -> str:
""" """
get authorization URI for the specified settings get authorization URI for the specified settings
Args:
state(str): CSRF token to pass to OAuth2 provider
Returns: Returns:
str: authorization URI as a string str: authorization URI as a string
""" """
client = self.get_client() client = self.get_client()
uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri) uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri, state=state)
return uri return uri
async def get_oauth_username(self, code: str) -> str | None: async def get_oauth_username(self, code: str, state: str | None, session: dict[str, Any]) -> str | None:
""" """
extract OAuth username from remote extract OAuth username from remote
Args: Args:
code(str): authorization code provided by external service code(str): authorization code provided by external service
state(str | None): CSRF token returned by external service
session(dict[str, Any]): current session instance
Returns: Returns:
str | None: username as is in OAuth provider str | None: username as is in OAuth provider
""" """
if state is None or state != session.get("state"):
return None
try: try:
client = self.get_client() client = self.get_client()
access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri) access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)

View File

@@ -141,14 +141,15 @@ class LogsOperations(Operations):
connection.execute( connection.execute(
""" """
delete from logs delete from logs
where (package_base, version, repository, process_id) not in ( where repository = :repository
select package_base, version, repository, process_id from logs and (package_base, version, repository, process_id) not in (
where (package_base, version, repository, created) in ( select package_base, version, repository, process_id from logs
select package_base, version, repository, max(created) from logs where (package_base, version, repository, created) in (
where repository = :repository select package_base, version, repository, max(created) from logs
group by package_base, version, repository where repository = :repository
group by package_base, version, repository
)
) )
)
""", """,
{ {
"repository": repository_id.id, "repository": repository_id.id,

View File

@@ -48,6 +48,10 @@ class RemotePullTrigger(Trigger):
"gitremote": { "gitremote": {
"type": "dict", "type": "dict",
"schema": { "schema": {
"type": {
"type": "string",
"allowed": ["gitremote"],
},
"pull_url": { "pull_url": {
"type": "string", "type": "string",
"required": True, "required": True,
@@ -60,7 +64,6 @@ class RemotePullTrigger(Trigger):
}, },
}, },
} }
CONFIGURATION_SCHEMA_FALLBACK = "gitremote"
def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None: def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None:
""" """
@@ -89,7 +92,6 @@ class RemotePullTrigger(Trigger):
trigger action which will be called at the start of the application trigger action which will be called at the start of the application
""" """
for target in self.targets: for target in self.targets:
section, _ = self.configuration.gettype( section, _ = self.configuration.gettype(target, self.repository_id, fallback="gitremote")
target, self.repository_id, fallback=self.CONFIGURATION_SCHEMA_FALLBACK)
runner = RemotePull(self.repository_id, self.configuration, section) runner = RemotePull(self.repository_id, self.configuration, section)
runner.run() runner.run()

View File

@@ -52,6 +52,10 @@ class RemotePushTrigger(Trigger):
"gitremote": { "gitremote": {
"type": "dict", "type": "dict",
"schema": { "schema": {
"type": {
"type": "string",
"allowed": ["gitremote"],
},
"commit_email": { "commit_email": {
"type": "string", "type": "string",
"empty": False, "empty": False,
@@ -72,7 +76,6 @@ class RemotePushTrigger(Trigger):
}, },
}, },
} }
CONFIGURATION_SCHEMA_FALLBACK = "gitremote"
def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None: def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None:
""" """
@@ -111,7 +114,6 @@ class RemotePushTrigger(Trigger):
reporter = ctx.get(Client) reporter = ctx.get(Client)
for target in self.targets: for target in self.targets:
section, _ = self.configuration.gettype( section, _ = self.configuration.gettype(target, self.repository_id, fallback="gitremote")
target, self.repository_id, fallback=self.CONFIGURATION_SCHEMA_FALLBACK)
runner = RemotePush(reporter, self.configuration, section) runner = RemotePush(reporter, self.configuration, section)
runner.run(result) runner.run(result)

View File

@@ -185,8 +185,9 @@ class UpdateHandler(PackageInfo, Cleaner):
else: else:
self.reporter.set_pending(local.base) self.reporter.set_pending(local.base)
self.event(local.base, EventType.PackageOutdated, "Manual update is requested") self.event(local.base, EventType.PackageOutdated, "Manual update is requested")
self.clear_queue()
except Exception: except Exception:
self.logger.exception("could not load packages from database") self.logger.exception("could not load packages from database")
self.clear_queue()
return result return result

View File

@@ -34,8 +34,6 @@ class Trigger(LazyLogging):
Attributes: Attributes:
CONFIGURATION_SCHEMA(ConfigurationSchema): (class attribute) configuration schema template CONFIGURATION_SCHEMA(ConfigurationSchema): (class attribute) configuration schema template
CONFIGURATION_SCHEMA_FALLBACK(str | None): (class attribute) optional fallback option for defining
configuration schema type used
REQUIRES_REPOSITORY(bool): (class attribute) either trigger requires loaded repository or not REQUIRES_REPOSITORY(bool): (class attribute) either trigger requires loaded repository or not
configuration(Configuration): configuration instance configuration(Configuration): configuration instance
repository_id(RepositoryId): repository unique identifier repository_id(RepositoryId): repository unique identifier
@@ -59,7 +57,6 @@ class Trigger(LazyLogging):
""" """
CONFIGURATION_SCHEMA: ClassVar[ConfigurationSchema] = {} CONFIGURATION_SCHEMA: ClassVar[ConfigurationSchema] = {}
CONFIGURATION_SCHEMA_FALLBACK: ClassVar[str | None] = None
REQUIRES_REPOSITORY: ClassVar[bool] = True REQUIRES_REPOSITORY: ClassVar[bool] = True
def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None: def __init__(self, repository_id: RepositoryId, configuration: Configuration) -> None:

View File

@@ -329,10 +329,10 @@ def list_flatmap(source: Iterable[T], extractor: Callable[[T], Iterable[R]]) ->
Args: Args:
source(Iterable[T]): source list source(Iterable[T]): source list
extractor(Callable[[T], list[R]): property extractor extractor(Callable[[T], Iterable[R]]): property extractor
Returns: Returns:
list[T]: combined list of unique entries in properties list list[R]: combined list of unique entries in properties list
""" """
def generator() -> Iterator[R]: def generator() -> Iterator[R]:
for inner in source: for inner in source:

View File

@@ -155,7 +155,7 @@ class Package(LazyLogging):
bool: ``True`` in case if package base looks like VCS package and ``False`` otherwise bool: ``True`` in case if package base looks like VCS package and ``False`` otherwise
""" """
return self.base.endswith("-bzr") \ return self.base.endswith("-bzr") \
or self.base.endswith("-csv") \ or self.base.endswith("-cvs") \
or self.base.endswith("-darcs") \ or self.base.endswith("-darcs") \
or self.base.endswith("-git") \ or self.base.endswith("-git") \
or self.base.endswith("-hg") \ or self.base.endswith("-hg") \

View File

@@ -28,3 +28,6 @@ class OAuth2Schema(Schema):
code = fields.String(metadata={ code = fields.String(metadata={
"description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated", "description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated",
}) })
state = fields.String(metadata={
"description": "CSRF token returned by OAuth2 provider",
})

View File

@@ -18,9 +18,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
from secrets import token_urlsafe
from typing import ClassVar from typing import ClassVar
from ahriman.core.auth.helpers import remember from ahriman.core.auth.helpers import get_session, remember
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.apispec.decorators import apidocs from ahriman.web.apispec.decorators import apidocs
from ahriman.web.schemas import LoginSchema, OAuth2Schema from ahriman.web.schemas import LoginSchema, OAuth2Schema
@@ -68,15 +69,18 @@ class LoginView(BaseView):
raise HTTPMethodNotAllowed(self.request.method, ["POST"]) raise HTTPMethodNotAllowed(self.request.method, ["POST"])
oauth_provider = self.validator oauth_provider = self.validator
if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway if not isinstance(oauth_provider, OAuth):
raise HTTPMethodNotAllowed(self.request.method, ["POST"]) raise HTTPMethodNotAllowed(self.request.method, ["POST"])
session = await get_session(self.request)
code = self.request.query.get("code") code = self.request.query.get("code")
if not code: if not code:
raise HTTPFound(oauth_provider.get_oauth_url()) state = session["state"] = token_urlsafe()
raise HTTPFound(oauth_provider.get_oauth_url(state))
response = HTTPFound("/") response = HTTPFound("/")
identity = await oauth_provider.get_oauth_username(code) identity = await oauth_provider.get_oauth_username(code, self.request.query.get("state"), session)
if identity is not None and await self.validator.known_username(identity): if identity is not None and await self.validator.known_username(identity):
await remember(self.request, response, identity) await remember(self.request, response, identity)
raise response raise response

View File

@@ -5,6 +5,7 @@ from pytest_mock import MockerFixture
from unittest.mock import call as MockCall from unittest.mock import call as MockCall
from ahriman.application.application.application_repository import ApplicationRepository from ahriman.application.application.application_repository import ApplicationRepository
from ahriman.core.exceptions import UnknownPackageError
from ahriman.core.tree import Leaf, Tree from ahriman.core.tree import Leaf, Tree
from ahriman.models.changes import Changes from ahriman.models.changes import Changes
from ahriman.models.package import Package from ahriman.models.package import Package
@@ -135,7 +136,7 @@ def test_unknown_no_aur(application_repository: ApplicationRepository, package_a
must return empty list in case if there is locally stored PKGBUILD must return empty list in case if there is locally stored PKGBUILD
""" """
mocker.patch("ahriman.core.repository.repository.Repository.packages", return_value=[package_ahriman]) mocker.patch("ahriman.core.repository.repository.Repository.packages", return_value=[package_ahriman])
mocker.patch("ahriman.models.package.Package.from_aur", side_effect=Exception) mocker.patch("ahriman.models.package.Package.from_aur", side_effect=UnknownPackageError(package_ahriman.base))
mocker.patch("ahriman.models.package.Package.from_build", return_value=package_ahriman) mocker.patch("ahriman.models.package.Package.from_build", return_value=package_ahriman)
mocker.patch("pathlib.Path.is_dir", return_value=True) mocker.patch("pathlib.Path.is_dir", return_value=True)
mocker.patch("ahriman.core.build_tools.sources.Sources.has_remotes", return_value=False) mocker.patch("ahriman.core.build_tools.sources.Sources.has_remotes", return_value=False)
@@ -149,7 +150,7 @@ def test_unknown_no_aur_no_local(application_repository: ApplicationRepository,
must return list of packages missing in aur and in local storage must return list of packages missing in aur and in local storage
""" """
mocker.patch("ahriman.core.repository.repository.Repository.packages", return_value=[package_ahriman]) mocker.patch("ahriman.core.repository.repository.Repository.packages", return_value=[package_ahriman])
mocker.patch("ahriman.models.package.Package.from_aur", side_effect=Exception) mocker.patch("ahriman.models.package.Package.from_aur", side_effect=UnknownPackageError(package_ahriman.base))
mocker.patch("pathlib.Path.is_dir", return_value=False) mocker.patch("pathlib.Path.is_dir", return_value=False)
packages = application_repository.unknown() packages = application_repository.unknown()

View File

@@ -13,6 +13,13 @@ def test_import_aiohttp_security() -> None:
assert helpers.aiohttp_security assert helpers.aiohttp_security
def test_import_aiohttp_session() -> None:
"""
must import aiohttp_session correctly
"""
assert helpers.aiohttp_session
async def test_authorized_userid_dummy(mocker: MockerFixture) -> None: async def test_authorized_userid_dummy(mocker: MockerFixture) -> None:
""" """
must not call authorized_userid from library if not enabled must not call authorized_userid from library if not enabled
@@ -55,6 +62,23 @@ async def test_forget_dummy(mocker: MockerFixture) -> None:
await helpers.forget() await helpers.forget()
async def test_get_session_dummy(mocker: MockerFixture) -> None:
"""
must return empty dict if no aiohttp_session module found
"""
mocker.patch.object(helpers, "aiohttp_session", None)
assert await helpers.get_session() == {}
async def test_get_session_library(mocker: MockerFixture) -> None:
"""
must call get_session from library if enabled
"""
get_session_mock = mocker.patch("aiohttp_session.get_session")
await helpers.get_session()
get_session_mock.assert_called_once_with()
async def test_forget_library(mocker: MockerFixture) -> None: async def test_forget_library(mocker: MockerFixture) -> None:
""" """
must call forget from library if enabled must call forget from library if enabled
@@ -88,3 +112,12 @@ def test_import_aiohttp_security_missing(mocker: MockerFixture) -> None:
mocker.patch.dict(sys.modules, {"aiohttp_security": None}) mocker.patch.dict(sys.modules, {"aiohttp_security": None})
importlib.reload(helpers) importlib.reload(helpers)
assert helpers.aiohttp_security is None assert helpers.aiohttp_security is None
def test_import_aiohttp_session_missing(mocker: MockerFixture) -> None:
"""
must set missing flag if no aiohttp_session module found
"""
mocker.patch.dict(sys.modules, {"aiohttp_session": None})
importlib.reload(helpers)
assert helpers.aiohttp_session is None

View File

@@ -57,8 +57,8 @@ def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
must generate valid OAuth authorization URL must generate valid OAuth authorization URL
""" """
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url") authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
oauth.get_oauth_url() oauth.get_oauth_url(state="state")
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri) authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri, state="state")
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -69,10 +69,9 @@ async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info", user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
return_value=(aioauth_client.User(email="email"), "")) return_value=(aioauth_client.User(email="email"), ""))
email = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "email"
access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri) access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri)
user_info_mock.assert_called_once_with() user_info_mock.assert_called_once_with()
assert email == "email"
async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -82,8 +81,7 @@ async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), "")) mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), ""))
username = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "username"
assert username == "username"
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -93,8 +91,7 @@ async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception) mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception)
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info") user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
email = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) is None
assert email is None
user_info_mock.assert_not_called() user_info_mock.assert_not_called()
@@ -105,5 +102,19 @@ async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception) mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception)
email = await oauth.get_oauth_username("code") username = await oauth.get_oauth_username("code", state="state", session={"state": "state"})
assert email is None assert username is None
async def test_get_oauth_username_csrf_missing(oauth: OAuth) -> None:
"""
must return None if CSRF state is missing
"""
assert await oauth.get_oauth_username("code", state=None, session={"state": "state"}) is None
async def test_get_oauth_username_csrf_mismatch(oauth: OAuth) -> None:
"""
must return None if CSRF state does not match session
"""
assert await oauth.get_oauth_username("code", state="wrong", session={"state": "state"}) is None

View File

@@ -357,4 +357,8 @@ def test_updates_manual_with_failures(update_handler: UpdateHandler, package_ahr
""" """
mocker.patch("ahriman.core.database.SQLite.build_queue_get", side_effect=Exception) mocker.patch("ahriman.core.database.SQLite.build_queue_get", side_effect=Exception)
mocker.patch("ahriman.core.repository.update_handler.UpdateHandler.packages", return_value=[package_ahriman]) mocker.patch("ahriman.core.repository.update_handler.UpdateHandler.packages", return_value=[package_ahriman])
assert update_handler.updates_manual() == [] assert update_handler.updates_manual() == []
from ahriman.core.repository.cleaner import Cleaner
Cleaner.clear_queue.assert_not_called()

View File

@@ -70,7 +70,6 @@ def test_configuration_schema_variables() -> None:
must return empty schema must return empty schema
""" """
assert Trigger.CONFIGURATION_SCHEMA == {} assert Trigger.CONFIGURATION_SCHEMA == {}
assert Trigger.CONFIGURATION_SCHEMA_FALLBACK is None
def test_configuration_sections(configuration: Configuration) -> None: def test_configuration_sections(configuration: Configuration) -> None:

View File

@@ -54,7 +54,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok assert response.ok
oauth.get_oauth_url.assert_called_once_with() oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None: async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None:
@@ -69,13 +69,15 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok assert response.ok
oauth.get_oauth_url.assert_called_once_with() oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None: async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None:
""" """
must log in user correctly from OAuth must log in user correctly from OAuth
""" """
session = {"state": "state"}
mocker.patch("ahriman.web.views.v1.user.login.get_session", return_value=session)
oauth = client_with_oauth_auth.app[AuthKey] oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_username.return_value = "user" oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True oauth.known_username.return_value = True
@@ -84,12 +86,12 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) ->
remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember") remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember")
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring") request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
payload = {"code": "code"} payload = {"code": "code", "state": "state"}
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload) response = await client_with_oauth_auth.get("/api/v1/login", params=payload)
assert response.ok assert response.ok
oauth.get_oauth_username.assert_called_once_with("code") oauth.get_oauth_username.assert_called_once_with("code", "state", session)
oauth.known_username.assert_called_once_with("user") oauth.known_username.assert_called_once_with("user")
remember_mock.assert_called_once_with( remember_mock.assert_called_once_with(
pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int)) pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))