feat: implement CSRF protection

This commit is contained in:
2026-02-17 03:16:13 +02:00
parent 3b43861bcf
commit 431b1a7150
7 changed files with 109 additions and 25 deletions

View File

@@ -22,6 +22,11 @@ try:
except ImportError: except ImportError:
aiohttp_security = None # type: ignore[assignment] aiohttp_security = None # type: ignore[assignment]
try:
import aiohttp_session
except ImportError:
aiohttp_session = None # type: ignore[assignment]
from typing import Any from typing import Any
@@ -50,7 +55,7 @@ async def check_authorized(*args: Any, **kwargs: Any) -> Any:
Args: Args:
*args(Any): argument list as provided by check_authorized function *args(Any): argument list as provided by check_authorized function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by check_authorized function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise
@@ -66,7 +71,7 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
Args: Args:
*args(Any): argument list as provided by forget function *args(Any): argument list as provided by forget function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by forget function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise
@@ -76,13 +81,29 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
return None return None
async def get_session(*args: Any, **kwargs: Any) -> Any:
"""
handle aiohttp session methods
Args:
*args(Any): argument list as provided by get_session function
**kwargs(Any): named argument list as provided by get_session function
Returns:
Any: empty dictionary in case if no aiohttp_session module found and function call otherwise
"""
if aiohttp_session is not None:
return await aiohttp_session.get_session(*args, **kwargs)
return {}
async def remember(*args: Any, **kwargs: Any) -> Any: async def remember(*args: Any, **kwargs: Any) -> Any:
""" """
handle disabled auth handle disabled auth
Args: Args:
*args(Any): argument list as provided by remember function *args(Any): argument list as provided by remember function
**kwargs(Any): named argument list as provided by authorized_userid function **kwargs(Any): named argument list as provided by remember function
Returns: Returns:
Any: ``None`` in case if no aiohttp_security module found and function call otherwise Any: ``None`` in case if no aiohttp_security module found and function call otherwise

View File

@@ -19,6 +19,8 @@
# #
import aioauth_client import aioauth_client
from typing import Any
from ahriman.core.auth.mapping import Mapping from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration from ahriman.core.configuration import Configuration
from ahriman.core.database import SQLite from ahriman.core.database import SQLite
@@ -53,7 +55,7 @@ class OAuth(Mapping):
self.client_secret = configuration.get("auth", "client_secret") self.client_secret = configuration.get("auth", "client_secret")
# in order to use OAuth feature the service must be publicity available # in order to use OAuth feature the service must be publicity available
# thus we expect that address is set # thus we expect that address is set
self.redirect_uri = f"""{configuration.get("web", "address")}/api/v1/login""" self.redirect_uri = f"{configuration.get("web", "address")}/api/v1/login"
self.provider = self.get_provider(configuration.get("auth", "oauth_provider")) self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
# it is list, but we will have to convert to string it anyway # it is list, but we will have to convert to string it anyway
self.scopes = configuration.get("auth", "oauth_scopes") self.scopes = configuration.get("auth", "oauth_scopes")
@@ -102,27 +104,35 @@ class OAuth(Mapping):
""" """
return self.provider(client_id=self.client_id, client_secret=self.client_secret) return self.provider(client_id=self.client_id, client_secret=self.client_secret)
def get_oauth_url(self) -> str: def get_oauth_url(self, state: str) -> str:
""" """
get authorization URI for the specified settings get authorization URI for the specified settings
Args:
state(str): CSRF token to pass to OAuth2 provider
Returns: Returns:
str: authorization URI as a string str: authorization URI as a string
""" """
client = self.get_client() client = self.get_client()
uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri) uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri, state=state)
return uri return uri
async def get_oauth_username(self, code: str) -> str | None: async def get_oauth_username(self, code: str, state: str | None, session: dict[str, Any]) -> str | None:
""" """
extract OAuth username from remote extract OAuth username from remote
Args: Args:
code(str): authorization code provided by external service code(str): authorization code provided by external service
state(str | None): CSRF token returned by external service
session(dict[str, Any]): current session instance
Returns: Returns:
str | None: username as is in OAuth provider str | None: username as is in OAuth provider
""" """
if state is None or state != session.get("state"):
return None
try: try:
client = self.get_client() client = self.get_client()
access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri) access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)

View File

@@ -28,3 +28,6 @@ class OAuth2Schema(Schema):
code = fields.String(metadata={ code = fields.String(metadata={
"description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated", "description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated",
}) })
state = fields.String(metadata={
"description": "CSRF token returned by OAuth2 provider",
})

View File

@@ -18,9 +18,10 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
from secrets import token_urlsafe
from typing import ClassVar from typing import ClassVar
from ahriman.core.auth.helpers import remember from ahriman.core.auth.helpers import get_session, remember
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
from ahriman.web.apispec.decorators import apidocs from ahriman.web.apispec.decorators import apidocs
from ahriman.web.schemas import LoginSchema, OAuth2Schema from ahriman.web.schemas import LoginSchema, OAuth2Schema
@@ -68,15 +69,18 @@ class LoginView(BaseView):
raise HTTPMethodNotAllowed(self.request.method, ["POST"]) raise HTTPMethodNotAllowed(self.request.method, ["POST"])
oauth_provider = self.validator oauth_provider = self.validator
if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway if not isinstance(oauth_provider, OAuth):
raise HTTPMethodNotAllowed(self.request.method, ["POST"]) raise HTTPMethodNotAllowed(self.request.method, ["POST"])
session = await get_session(self.request)
code = self.request.query.get("code") code = self.request.query.get("code")
if not code: if not code:
raise HTTPFound(oauth_provider.get_oauth_url()) state = session["state"] = token_urlsafe()
raise HTTPFound(oauth_provider.get_oauth_url(state))
response = HTTPFound("/") response = HTTPFound("/")
identity = await oauth_provider.get_oauth_username(code) identity = await oauth_provider.get_oauth_username(code, self.request.query.get("state"), session)
if identity is not None and await self.validator.known_username(identity): if identity is not None and await self.validator.known_username(identity):
await remember(self.request, response, identity) await remember(self.request, response, identity)
raise response raise response

View File

@@ -13,6 +13,13 @@ def test_import_aiohttp_security() -> None:
assert helpers.aiohttp_security assert helpers.aiohttp_security
def test_import_aiohttp_session() -> None:
"""
must import aiohttp_session correctly
"""
assert helpers.aiohttp_session
async def test_authorized_userid_dummy(mocker: MockerFixture) -> None: async def test_authorized_userid_dummy(mocker: MockerFixture) -> None:
""" """
must not call authorized_userid from library if not enabled must not call authorized_userid from library if not enabled
@@ -55,6 +62,23 @@ async def test_forget_dummy(mocker: MockerFixture) -> None:
await helpers.forget() await helpers.forget()
async def test_get_session_dummy(mocker: MockerFixture) -> None:
"""
must return empty dict if no aiohttp_session module found
"""
mocker.patch.object(helpers, "aiohttp_session", None)
assert await helpers.get_session() == {}
async def test_get_session_library(mocker: MockerFixture) -> None:
"""
must call get_session from library if enabled
"""
get_session_mock = mocker.patch("aiohttp_session.get_session")
await helpers.get_session()
get_session_mock.assert_called_once_with()
async def test_forget_library(mocker: MockerFixture) -> None: async def test_forget_library(mocker: MockerFixture) -> None:
""" """
must call forget from library if enabled must call forget from library if enabled
@@ -88,3 +112,12 @@ def test_import_aiohttp_security_missing(mocker: MockerFixture) -> None:
mocker.patch.dict(sys.modules, {"aiohttp_security": None}) mocker.patch.dict(sys.modules, {"aiohttp_security": None})
importlib.reload(helpers) importlib.reload(helpers)
assert helpers.aiohttp_security is None assert helpers.aiohttp_security is None
def test_import_aiohttp_session_missing(mocker: MockerFixture) -> None:
"""
must set missing flag if no aiohttp_session module found
"""
mocker.patch.dict(sys.modules, {"aiohttp_session": None})
importlib.reload(helpers)
assert helpers.aiohttp_session is None

View File

@@ -57,8 +57,8 @@ def test_get_oauth_url(oauth: OAuth, mocker: MockerFixture) -> None:
must generate valid OAuth authorization URL must generate valid OAuth authorization URL
""" """
authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url") authorize_url_mock = mocker.patch("aioauth_client.GoogleClient.get_authorize_url")
oauth.get_oauth_url() oauth.get_oauth_url(state="state")
authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri) authorize_url_mock.assert_called_once_with(scope=oauth.scopes, redirect_uri=oauth.redirect_uri, state="state")
async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -69,10 +69,9 @@ async def test_get_oauth_username(oauth: OAuth, mocker: MockerFixture) -> None:
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info", user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info",
return_value=(aioauth_client.User(email="email"), "")) return_value=(aioauth_client.User(email="email"), ""))
email = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "email"
access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri) access_token_mock.assert_called_once_with("code", redirect_uri=oauth.redirect_uri)
user_info_mock.assert_called_once_with() user_info_mock.assert_called_once_with()
assert email == "email"
async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -82,8 +81,7 @@ async def test_get_oauth_username_empty_email(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), "")) mocker.patch("aioauth_client.GoogleClient.user_info", return_value=(aioauth_client.User(username="username"), ""))
username = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) == "username"
assert username == "username"
async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None: async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixture) -> None:
@@ -93,8 +91,7 @@ async def test_get_oauth_username_exception_1(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception) mocker.patch("aioauth_client.GoogleClient.get_access_token", side_effect=Exception)
user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info") user_info_mock = mocker.patch("aioauth_client.GoogleClient.user_info")
email = await oauth.get_oauth_username("code") assert await oauth.get_oauth_username("code", state="state", session={"state": "state"}) is None
assert email is None
user_info_mock.assert_not_called() user_info_mock.assert_not_called()
@@ -105,5 +102,19 @@ async def test_get_oauth_username_exception_2(oauth: OAuth, mocker: MockerFixtur
mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", "")) mocker.patch("aioauth_client.GoogleClient.get_access_token", return_value=("token", ""))
mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception) mocker.patch("aioauth_client.GoogleClient.user_info", side_effect=Exception)
email = await oauth.get_oauth_username("code") username = await oauth.get_oauth_username("code", state="state", session={"state": "state"})
assert email is None assert username is None
async def test_get_oauth_username_csrf_missing(oauth: OAuth) -> None:
"""
must return None if CSRF state is missing
"""
assert await oauth.get_oauth_username("code", state=None, session={"state": "state"}) is None
async def test_get_oauth_username_csrf_mismatch(oauth: OAuth) -> None:
"""
must return None if CSRF state does not match session
"""
assert await oauth.get_oauth_username("code", state="wrong", session={"state": "state"}) is None

View File

@@ -54,7 +54,7 @@ async def test_get_redirect_to_oauth(client_with_oauth_auth: TestClient) -> None
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok assert response.ok
oauth.get_oauth_url.assert_called_once_with() oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None: async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClient) -> None:
@@ -69,13 +69,15 @@ async def test_get_redirect_to_oauth_empty_code(client_with_oauth_auth: TestClie
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False) response = await client_with_oauth_auth.get("/api/v1/login", params=payload, allow_redirects=False)
assert response.ok assert response.ok
oauth.get_oauth_url.assert_called_once_with() oauth.get_oauth_url.assert_called_once_with(pytest.helpers.anyvar(str))
async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None: async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) -> None:
""" """
must log in user correctly from OAuth must log in user correctly from OAuth
""" """
session = {"state": "state"}
mocker.patch("ahriman.web.views.v1.user.login.get_session", return_value=session)
oauth = client_with_oauth_auth.app[AuthKey] oauth = client_with_oauth_auth.app[AuthKey]
oauth.get_oauth_username.return_value = "user" oauth.get_oauth_username.return_value = "user"
oauth.known_username.return_value = True oauth.known_username.return_value = True
@@ -84,12 +86,12 @@ async def test_get(client_with_oauth_auth: TestClient, mocker: MockerFixture) ->
remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember") remember_mock = mocker.patch("ahriman.web.views.v1.user.login.remember")
request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring") request_schema = pytest.helpers.schema_request(LoginView.get, location="querystring")
payload = {"code": "code"} payload = {"code": "code", "state": "state"}
assert not request_schema.validate(payload) assert not request_schema.validate(payload)
response = await client_with_oauth_auth.get("/api/v1/login", params=payload) response = await client_with_oauth_auth.get("/api/v1/login", params=payload)
assert response.ok assert response.ok
oauth.get_oauth_username.assert_called_once_with("code") oauth.get_oauth_username.assert_called_once_with("code", "state", session)
oauth.known_username.assert_called_once_with("user") oauth.known_username.assert_called_once_with("user")
remember_mock.assert_called_once_with( remember_mock.assert_called_once_with(
pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int)) pytest.helpers.anyvar(int), pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))