feat: implement CSRF protection

This commit is contained in:
2026-02-17 03:16:13 +02:00
parent 536d040a6a
commit ba80a91d95
7 changed files with 109 additions and 25 deletions

View File

@@ -13,6 +13,13 @@ def test_import_aiohttp_security() -> None:
assert helpers.aiohttp_security
def test_import_aiohttp_session() -> None:
"""
must import aiohttp_session correctly
"""
assert helpers.aiohttp_session
async def test_authorized_userid_dummy(mocker: MockerFixture) -> None:
"""
must not call authorized_userid from library if not enabled
@@ -55,6 +62,23 @@ async def test_forget_dummy(mocker: MockerFixture) -> None:
await helpers.forget()
async def test_get_session_dummy(mocker: MockerFixture) -> None:
"""
must return empty dict if no aiohttp_session module found
"""
mocker.patch.object(helpers, "aiohttp_session", None)
assert await helpers.get_session() == {}
async def test_get_session_library(mocker: MockerFixture) -> None:
"""
must call get_session from library if enabled
"""
get_session_mock = mocker.patch("aiohttp_session.get_session")
await helpers.get_session()
get_session_mock.assert_called_once_with()
async def test_forget_library(mocker: MockerFixture) -> None:
"""
must call forget from library if enabled
@@ -88,3 +112,12 @@ def test_import_aiohttp_security_missing(mocker: MockerFixture) -> None:
mocker.patch.dict(sys.modules, {"aiohttp_security": None})
importlib.reload(helpers)
assert helpers.aiohttp_security is None
def test_import_aiohttp_session_missing(mocker: MockerFixture) -> None:
"""
must set missing flag if no aiohttp_session module found
"""
mocker.patch.dict(sys.modules, {"aiohttp_session": None})
importlib.reload(helpers)
assert helpers.aiohttp_session is None

View File

@@ -57,8 +57,8 @@ def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
must generate valid OAuth authorization URL
"""
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
oauth.get_oauth_url()
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
oauth.get_oauth_url(state="state")
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri, state="state")
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -69,10 +69,9 @@ async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
return_value=(aioauth_client.User(email="email"), ""))
email = await oauth.get_oauth_username("code")
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "email"
access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri)
user_info_mock.assert_called_once_with()
assert email == "email"
async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -82,8 +81,7 @@ async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), ""))
username = await oauth.get_oauth_username("code")
assert username == "username"
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "username"
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -93,8 +91,7 @@ async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception)
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
email = await oauth.get_oauth_username("code")
assert email is None
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) is None
user_info_mock.assert_not_called()
@@ -105,5 +102,19 @@ async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception)
email = await oauth.get_oauth_username("code")
assert email is None
username = await oauth.get_oauth_username("code", state="state", session={"state": "state"})
assert username is None
async def test_get_oauth_username_csrf_missing(oauth: OAuth) -> None:
"""
must return None if CSRF state is missing
"""
assert await oauth.get_oauth_username("code", state=None, session={"state": "state"}) is None
async def test_get_oauth_username_csrf_mismatch(oauth: OAuth) -> None:
"""
must return None if CSRF state does not match session
"""
assert await oauth.get_oauth_username("code", state="wrong", session={"state": "state"}) is None

View File

@@ -54,7 +54,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None
assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok
oauth.get_oauth_url.assert_called_once_with()
oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None:
@@ -69,13 +69,15 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie
assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok
oauth.get_oauth_url.assert_called_once_with()
oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None:
"""
must log in user correctly from OAuth
"""
session = {"state": "state"}
mocker.patch("ahriman.web.views.v1.user.login.get_session", return_value=session)
oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True
@@ -84,12 +86,12 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) ->
remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember")
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
payload = {"code": "code"}
payload = {"code": "code", "state": "state"}
assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload)
assert response.ok
oauth.get_oauth_username.assert_called_once_with("code")
oauth.get_oauth_username.assert_called_once_with("code", "state", session)
oauth.known_username.assert_called_once_with("user")
remember_mock.assert_called_once_with(
pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))