From ba80a91d95f7d1ba70d688d847a73552d29c9c9b Mon Sep 17 00:00:00 2001 From: Evgenii Alekseev Date: Tue, 17 Feb 2026 03:16:13 +0200 Subject: [PATCH] feat: implement CSRF protection --- src/ahriman/core/auth/helpers.py | 27 +++++++++++++-- src/ahriman/core/auth/oauth.py | 18 +++++++--- src/ahriman/web/schemas/oauth2_schema.py | 3 ++ src/ahriman/web/views/v1/user/login.py | 12 ++++--- tests/ahriman/core/auth/test_helpers.py | 33 +++++++++++++++++++ tests/ahriman/core/auth/test_oauth.py | 31 +++++++++++------ .../views/v1/user/test_view_v1_user_login.py | 10 +++--- 7 files changed, 109 insertions(+), 25 deletions(-) diff --git a/src/ahriman/core/auth/helpers.py b/src/ahriman/core/auth/helpers.py index 54daf28c..d528fa01 100644 --- a/src/ahriman/core/auth/helpers.py +++ b/src/ahriman/core/auth/helpers.py @@ -22,6 +22,11 @@ try: except ImportError: aiohttp_security = None # type: ignore[assignment] +try: + import aiohttp_session +except ImportError: + aiohttp_session = None # type: ignore[assignment] + from typing import Any @@ -50,7 +55,7 @@ async def check_authorized(*args: Any, **kwargs: Any) -> Any: Args: *args(Any): argument list as provided by check_authorized function - **kwargs(Any): named argument list as provided by authorized_userid function + **kwargs(Any): named argument list as provided by check_authorized function Returns: Any: ``None`` in case if no aiohttp_security module found and function call otherwise @@ -66,7 +71,7 @@ async def forget(*args: Any, **kwargs: Any) -> Any: Args: *args(Any): argument list as provided by forget function - **kwargs(Any): named argument list as provided by authorized_userid function + **kwargs(Any): named argument list as provided by forget function Returns: Any: ``None`` in case if no aiohttp_security module found and function call otherwise @@ -76,13 +81,29 @@ async def forget(*args: Any, **kwargs: Any) -> Any: return None +async def get_session(*args: Any, **kwargs: Any) -> Any: + """ + handle aiohttp session methods + + Args: + *args(Any): argument list as provided by get_session function + **kwargs(Any): named argument list as provided by get_session function + + Returns: + Any: empty dictionary in case if no aiohttp_session module found and function call otherwise + """ + if aiohttp_session is not None: + return await aiohttp_session.get_session(*args, **kwargs) + return {} + + async def remember(*args: Any, **kwargs: Any) -> Any: """ handle disabled auth Args: *args(Any): argument list as provided by remember function - **kwargs(Any): named argument list as provided by authorized_userid function + **kwargs(Any): named argument list as provided by remember function Returns: Any: ``None`` in case if no aiohttp_security module found and function call otherwise diff --git a/src/ahriman/core/auth/oauth.py b/src/ahriman/core/auth/oauth.py index e4b64baf..288027d9 100644 --- a/src/ahriman/core/auth/oauth.py +++ b/src/ahriman/core/auth/oauth.py @@ -19,6 +19,8 @@ # import aioauth_client +from typing import Any + from ahriman.core.auth.mapping import Mapping from ahriman.core.configuration import Configuration from ahriman.core.database import SQLite @@ -53,7 +55,7 @@ class OAuth(Mapping): self.client_secret = configuration.get("auth", "client_secret") # in order to use OAuth feature the service must be publicity available # thus we expect that address is set - self.redirect_uri = f"""{configuration.get("web", "address")}/api/v1/login""" + self.redirect_uri = f"{configuration.get("web", "address")}/api/v1/login" self.provider = self.get_provider(configuration.get("auth", "oauth_provider")) # it is list, but we will have to convert to string it anyway self.scopes = configuration.get("auth", "oauth_scopes") @@ -102,27 +104,35 @@ class OAuth(Mapping): """ return self.provider(client_id=self.client_id, client_secret=self.client_secret) - def get_oauth_url(self) -> str: + def get_oauth_url(self, state: str) -> str: """ get authorization URI for the specified settings + Args: + state(str): CSRF token to pass to OAuth2 provider + Returns: str: authorization URI as a string """ client = self.get_client() - uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri) + uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri, state=state) return uri - async def get_oauth_username(self, code: str) -> str | None: + async def get_oauth_username(self, code: str, state: str | None, session: dict[str, Any]) -> str | None: """ extract OAuth username from remote Args: code(str): authorization code provided by external service + state(str | None): CSRF token returned by external service + session(dict[str, Any]): current session instance Returns: str | None: username as is in OAuth provider """ + if state is None or state != session.get("state"): + return None + try: client = self.get_client() access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri) diff --git a/src/ahriman/web/schemas/oauth2_schema.py b/src/ahriman/web/schemas/oauth2_schema.py index b403057a..c5158e2d 100644 --- a/src/ahriman/web/schemas/oauth2_schema.py +++ b/src/ahriman/web/schemas/oauth2_schema.py @@ -28,3 +28,6 @@ class OAuth2Schema(Schema): code = fields.String(metadata={ "description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated", }) + state = fields.String(metadata={ + "description": "CSRF token returned by OAuth2 provider", + }) diff --git a/src/ahriman/web/views/v1/user/login.py b/src/ahriman/web/views/v1/user/login.py index 7ac743ee..47742306 100644 --- a/src/ahriman/web/views/v1/user/login.py +++ b/src/ahriman/web/views/v1/user/login.py @@ -18,9 +18,10 @@ # along with this program. If not, see . # from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized +from secrets import token_urlsafe from typing import ClassVar -from ahriman.core.auth.helpers import remember +from ahriman.core.auth.helpers import get_session, remember from ahriman.models.user_access import UserAccess from ahriman.web.apispec.decorators import apidocs from ahriman.web.schemas import LoginSchema, OAuth2Schema @@ -68,15 +69,18 @@ class LoginView(BaseView): raise HTTPMethodNotAllowed(self.request.method, ["POST"]) oauth_provider = self.validator - if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway + if not isinstance(oauth_provider, OAuth): raise HTTPMethodNotAllowed(self.request.method, ["POST"]) + session = await get_session(self.request) + code = self.request.query.get("code") if not code: - raise HTTPFound(oauth_provider.get_oauth_url()) + state = session["state"] = token_urlsafe() + raise HTTPFound(oauth_provider.get_oauth_url(state)) response = HTTPFound("/") - identity = await oauth_provider.get_oauth_username(code) + identity = await oauth_provider.get_oauth_username(code, self.request.query.get("state"), session) if identity is not None and await self.validator.known_username(identity): await remember(self.request, response, identity) raise response diff --git a/tests/ahriman/core/auth/test_helpers.py b/tests/ahriman/core/auth/test_helpers.py index c929da11..29c9dc2e 100644 --- a/tests/ahriman/core/auth/test_helpers.py +++ b/tests/ahriman/core/auth/test_helpers.py @@ -13,6 +13,13 @@ def test_import_aiohttp_security() -> None: assert helpers.aiohttp_security +def test_import_aiohttp_session() -> None: + """ + must import aiohttp_session correctly + """ + assert helpers.aiohttp_session + + async def test_authorized_userid_dummy(mocker: MockerFixture) -> None: """ must not call authorized_userid from library if not enabled @@ -55,6 +62,23 @@ async def test_forget_dummy(mocker: MockerFixture) -> None: await helpers.forget() +async def test_get_session_dummy(mocker: MockerFixture) -> None: + """ + must return empty dict if no aiohttp_session module found + """ + mocker.patch.object(helpers, "aiohttp_session", None) + assert await helpers.get_session() == {} + + +async def test_get_session_library(mocker: MockerFixture) -> None: + """ + must call get_session from library if enabled + """ + get_session_mock = mocker.patch("aiohttp_session.get_session") + await helpers.get_session() + get_session_mock.assert_called_once_with() + + async def test_forget_library(mocker: MockerFixture) -> None: """ must call forget from library if enabled @@ -88,3 +112,12 @@ def test_import_aiohttp_security_missing(mocker: MockerFixture) -> None: mocker.patch.dict(sys.modules, {"aiohttp_security": None}) importlib.reload(helpers) assert helpers.aiohttp_security is None + + +def test_import_aiohttp_session_missing(mocker: MockerFixture) -> None: + """ + must set missing flag if no aiohttp_session module found + """ + mocker.patch.dict(sys.modules, {"aiohttp_session": None}) + importlib.reload(helpers) + assert helpers.aiohttp_session is None diff --git a/tests/ahriman/core/auth/test_oauth.py b/tests/ahriman/core/auth/test_oauth.py index 2f6ceb99..1bb67943 100644 --- a/tests/ahriman/core/auth/test_oauth.py +++ b/tests/ahriman/core/auth/test_oauth.py @@ -57,8 +57,8 @@ def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None: must generate valid OAuth authorization URL """ authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url") - oauth.get_oauth_url() - authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri) + oauth.get_oauth_url(state="state") + authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri, state="state") async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None: @@ -69,10 +69,9 @@ async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None: user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(email="email"), "")) - email = await oauth.get_oauth_username("code") + assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "email" access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri) user_info_mock.assert_called_once_with() - assert email == "email" async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None: @@ -82,8 +81,7 @@ async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixtur mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), "")) - username = await oauth.get_oauth_username("code") - assert username == "username" + assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "username" async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None: @@ -93,8 +91,7 @@ async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixtur mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception) user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info") - email = await oauth.get_oauth_username("code") - assert email is None + assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) is None user_info_mock.assert_not_called() @@ -105,5 +102,19 @@ async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixtur mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception) - email = await oauth.get_oauth_username("code") - assert email is None + username = await oauth.get_oauth_username("code", state="state", session={"state": "state"}) + assert username is None + + +async def test_get_oauth_username_csrf_missing(oauth: OAuth) -> None: + """ + must return None if CSRF state is missing + """ + assert await oauth.get_oauth_username("code", state=None, session={"state": "state"}) is None + + +async def test_get_oauth_username_csrf_mismatch(oauth: OAuth) -> None: + """ + must return None if CSRF state does not match session + """ + assert await oauth.get_oauth_username("code", state="wrong", session={"state": "state"}) is None diff --git a/tests/ahriman/web/views/v1/user/test_view_v1_user_login.py b/tests/ahriman/web/views/v1/user/test_view_v1_user_login.py index e10093a4..5ff6d26d 100644 --- a/tests/ahriman/web/views/v1/user/test_view_v1_user_login.py +++ b/tests/ahriman/web/views/v1/user/test_view_v1_user_login.py @@ -54,7 +54,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None assert not request_schema.validate(payload) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) assert response.ok - oauth.get_oauth_url.assert_called_once_with() + oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str)) async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None: @@ -69,13 +69,15 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie assert not request_schema.validate(payload) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) assert response.ok - oauth.get_oauth_url.assert_called_once_with() + oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str)) async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None: """ must log in user correctly from OAuth """ + session = {"state": "state"} + mocker.patch("ahriman.web.views.v1.user.login.get_session", return_value=session) oauth = client_with_oauth_auth.app[AuthKey] oauth.get_oauth_username.return_value = "user" oauth.known_username.return_value = True @@ -84,12 +86,12 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember") request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring") - payload = {"code": "code"} + payload = {"code": "code", "state": "state"} assert not request_schema.validate(payload) response = await client_with_oauth_auth.get("/api/v1/login", params=payload) assert response.ok - oauth.get_oauth_username.assert_called_once_with("code") + oauth.get_oauth_username.assert_called_once_with("code", "state", session) oauth.known_username.assert_called_once_with("user") remember_mock.assert_called_once_with( pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))