mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-02-24 21:59:48 +00:00
feat: implement CSRF protection
This commit is contained in:
@@ -13,6 +13,13 @@ def test_import_aiohttp_security() -> None:
|
||||
assert helpers.aiohttp_security
|
||||
|
||||
|
||||
def test_import_aiohttp_session() -> None:
|
||||
"""
|
||||
must import aiohttp_session correctly
|
||||
"""
|
||||
assert helpers.aiohttp_session
|
||||
|
||||
|
||||
async def test_authorized_userid_dummy(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must not call authorized_userid from library if not enabled
|
||||
@@ -55,6 +62,23 @@ async def test_forget_dummy(mocker: MockerFixture) -> None:
|
||||
await helpers.forget()
|
||||
|
||||
|
||||
async def test_get_session_dummy(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must return empty dict if no aiohttp_session module found
|
||||
"""
|
||||
mocker.patch.object(helpers, "aiohttp_session", None)
|
||||
assert await helpers.get_session() == {}
|
||||
|
||||
|
||||
async def test_get_session_library(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must call get_session from library if enabled
|
||||
"""
|
||||
get_session_mock = mocker.patch("aiohttp_session.get_session")
|
||||
await helpers.get_session()
|
||||
get_session_mock.assert_called_once_with()
|
||||
|
||||
|
||||
async def test_forget_library(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must call forget from library if enabled
|
||||
@@ -88,3 +112,12 @@ def test_import_aiohttp_security_missing(mocker: MockerFixture) -> None:
|
||||
mocker.patch.dict(sys.modules, {"aiohttp_security": None})
|
||||
importlib.reload(helpers)
|
||||
assert helpers.aiohttp_security is None
|
||||
|
||||
|
||||
def test_import_aiohttp_session_missing(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must set missing flag if no aiohttp_session module found
|
||||
"""
|
||||
mocker.patch.dict(sys.modules, {"aiohttp_session": None})
|
||||
importlib.reload(helpers)
|
||||
assert helpers.aiohttp_session is None
|
||||
|
||||
@@ -57,8 +57,8 @@ def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
must generate valid OAuth authorization URL
|
||||
"""
|
||||
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
|
||||
oauth.get_oauth_url()
|
||||
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri)
|
||||
oauth.get_oauth_url(state="state")
|
||||
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri, state="state")
|
||||
|
||||
|
||||
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
@@ -69,10 +69,9 @@ async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
|
||||
return_value=(aioauth_client.User(email="email"), ""))
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "email"
|
||||
access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri)
|
||||
user_info_mock.assert_called_once_with()
|
||||
assert email == "email"
|
||||
|
||||
|
||||
async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
@@ -82,8 +81,7 @@ async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixtur
|
||||
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
|
||||
mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), ""))
|
||||
|
||||
username = await oauth.get_oauth_username("code")
|
||||
assert username == "username"
|
||||
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "username"
|
||||
|
||||
|
||||
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
|
||||
@@ -93,8 +91,7 @@ async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixtur
|
||||
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception)
|
||||
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
assert email is None
|
||||
assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) is None
|
||||
user_info_mock.assert_not_called()
|
||||
|
||||
|
||||
@@ -105,5 +102,19 @@ async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixtur
|
||||
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
|
||||
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception)
|
||||
|
||||
email = await oauth.get_oauth_username("code")
|
||||
assert email is None
|
||||
username = await oauth.get_oauth_username("code", state="state", session={"state": "state"})
|
||||
assert username is None
|
||||
|
||||
|
||||
async def test_get_oauth_username_csrf_missing(oauth: OAuth) -> None:
|
||||
"""
|
||||
must return None if CSRF state is missing
|
||||
"""
|
||||
assert await oauth.get_oauth_username("code", state=None, session={"state": "state"}) is None
|
||||
|
||||
|
||||
async def test_get_oauth_username_csrf_mismatch(oauth: OAuth) -> None:
|
||||
"""
|
||||
must return None if CSRF state does not match session
|
||||
"""
|
||||
assert await oauth.get_oauth_username("code", state="wrong", session={"state": "state"}) is None
|
||||
|
||||
@@ -54,7 +54,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None
|
||||
assert not request_schema.validate(payload)
|
||||
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
|
||||
assert response.ok
|
||||
oauth.get_oauth_url.assert_called_once_with()
|
||||
oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
|
||||
|
||||
|
||||
async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None:
|
||||
@@ -69,13 +69,15 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie
|
||||
assert not request_schema.validate(payload)
|
||||
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
|
||||
assert response.ok
|
||||
oauth.get_oauth_url.assert_called_once_with()
|
||||
oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
|
||||
|
||||
|
||||
async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None:
|
||||
"""
|
||||
must log in user correctly from OAuth
|
||||
"""
|
||||
session = {"state": "state"}
|
||||
mocker.patch("ahriman.web.views.v1.user.login.get_session", return_value=session)
|
||||
oauth = client_with_oauth_auth.app[AuthKey]
|
||||
oauth.get_oauth_username.return_value = "user"
|
||||
oauth.known_username.return_value = True
|
||||
@@ -84,12 +86,12 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) ->
|
||||
remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember")
|
||||
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
|
||||
|
||||
payload = {"code": "code"}
|
||||
payload = {"code": "code", "state": "state"}
|
||||
assert not request_schema.validate(payload)
|
||||
response = await client_with_oauth_auth.get("/api/v1/login", params=payload)
|
||||
|
||||
assert response.ok
|
||||
oauth.get_oauth_username.assert_called_once_with("code")
|
||||
oauth.get_oauth_username.assert_called_once_with("code", "state", session)
|
||||
oauth.known_username.assert_called_once_with("user")
|
||||
remember_mock.assert_called_once_with(
|
||||
pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))
|
||||
|
||||
Reference in New Issue
Block a user