mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-03-14 05:53:39 +00:00
feat: implement CSRF protection
This commit is contained in:
@@ -22,6 +22,11 @@ try:
|
||||
except ImportError:
|
||||
aiohttp_security = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
import aiohttp_session
|
||||
except ImportError:
|
||||
aiohttp_session = None # type: ignore[assignment]
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
@@ -50,7 +55,7 @@ async def check_authorized(*args: Any, **kwargs: Any) -> Any:
|
||||
|
||||
Args:
|
||||
*args(Any): argument list as provided by check_authorized function
|
||||
**kwargs(Any): named argument list as provided by authorized_userid function
|
||||
**kwargs(Any): named argument list as provided by check_authorized function
|
||||
|
||||
Returns:
|
||||
Any: ``None`` in case if no aiohttp_security module found and function call otherwise
|
||||
@@ -66,7 +71,7 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
|
||||
|
||||
Args:
|
||||
*args(Any): argument list as provided by forget function
|
||||
**kwargs(Any): named argument list as provided by authorized_userid function
|
||||
**kwargs(Any): named argument list as provided by forget function
|
||||
|
||||
Returns:
|
||||
Any: ``None`` in case if no aiohttp_security module found and function call otherwise
|
||||
@@ -76,13 +81,29 @@ async def forget(*args: Any, **kwargs: Any) -> Any:
|
||||
return None
|
||||
|
||||
|
||||
async def get_session(*args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
handle aiohttp session methods
|
||||
|
||||
Args:
|
||||
*args(Any): argument list as provided by get_session function
|
||||
**kwargs(Any): named argument list as provided by get_session function
|
||||
|
||||
Returns:
|
||||
Any: empty dictionary in case if no aiohttp_session module found and function call otherwise
|
||||
"""
|
||||
if aiohttp_session is not None:
|
||||
return await aiohttp_session.get_session(*args, **kwargs)
|
||||
return {}
|
||||
|
||||
|
||||
async def remember(*args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
handle disabled auth
|
||||
|
||||
Args:
|
||||
*args(Any): argument list as provided by remember function
|
||||
**kwargs(Any): named argument list as provided by authorized_userid function
|
||||
**kwargs(Any): named argument list as provided by remember function
|
||||
|
||||
Returns:
|
||||
Any: ``None`` in case if no aiohttp_security module found and function call otherwise
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
#
|
||||
import aioauth_client
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ahriman.core.auth.mapping import Mapping
|
||||
from ahriman.core.configuration import Configuration
|
||||
from ahriman.core.database import SQLite
|
||||
@@ -53,7 +55,7 @@ class OAuth(Mapping):
|
||||
self.client_secret = configuration.get("auth", "client_secret")
|
||||
# in order to use OAuth feature the service must be publicity available
|
||||
# thus we expect that address is set
|
||||
self.redirect_uri = f"""{configuration.get("web", "address")}/api/v1/login"""
|
||||
self.redirect_uri = f"{configuration.get("web", "address")}/api/v1/login"
|
||||
self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
|
||||
# it is list, but we will have to convert to string it anyway
|
||||
self.scopes = configuration.get("auth", "oauth_scopes")
|
||||
@@ -102,27 +104,35 @@ class OAuth(Mapping):
|
||||
"""
|
||||
return self.provider(client_id=self.client_id, client_secret=self.client_secret)
|
||||
|
||||
def get_oauth_url(self) -> str:
|
||||
def get_oauth_url(self, state: str) -> str:
|
||||
"""
|
||||
get authorization URI for the specified settings
|
||||
|
||||
Args:
|
||||
state(str): CSRF token to pass to OAuth2 provider
|
||||
|
||||
Returns:
|
||||
str: authorization URI as a string
|
||||
"""
|
||||
client = self.get_client()
|
||||
uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri)
|
||||
uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri, state=state)
|
||||
return uri
|
||||
|
||||
async def get_oauth_username(self, code: str) -> str | None:
|
||||
async def get_oauth_username(self, code: str, state: str | None, session: dict[str, Any]) -> str | None:
|
||||
"""
|
||||
extract OAuth username from remote
|
||||
|
||||
Args:
|
||||
code(str): authorization code provided by external service
|
||||
state(str | None): CSRF token returned by external service
|
||||
session(dict[str, Any]): current session instance
|
||||
|
||||
Returns:
|
||||
str | None: username as is in OAuth provider
|
||||
"""
|
||||
if state is None or state != session.get("state"):
|
||||
return None
|
||||
|
||||
try:
|
||||
client = self.get_client()
|
||||
access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)
|
||||
|
||||
@@ -28,3 +28,6 @@ class OAuth2Schema(Schema):
|
||||
code = fields.String(metadata={
|
||||
"description": "OAuth2 authorization code. In case if not set, the redirect to provider will be initiated",
|
||||
})
|
||||
state = fields.String(metadata={
|
||||
"description": "CSRF token returned by OAuth2 provider",
|
||||
})
|
||||
|
||||
@@ -18,9 +18,10 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
|
||||
from secrets import token_urlsafe
|
||||
from typing import ClassVar
|
||||
|
||||
from ahriman.core.auth.helpers import remember
|
||||
from ahriman.core.auth.helpers import get_session, remember
|
||||
from ahriman.models.user_access import UserAccess
|
||||
from ahriman.web.apispec.decorators import apidocs
|
||||
from ahriman.web.schemas import LoginSchema, OAuth2Schema
|
||||
@@ -68,15 +69,18 @@ class LoginView(BaseView):
|
||||
raise HTTPMethodNotAllowed(self.request.method, ["POST"])
|
||||
|
||||
oauth_provider = self.validator
|
||||
if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway
|
||||
if not isinstance(oauth_provider, OAuth):
|
||||
raise HTTPMethodNotAllowed(self.request.method, ["POST"])
|
||||
|
||||
session = await get_session(self.request)
|
||||
|
||||
code = self.request.query.get("code")
|
||||
if not code:
|
||||
raise HTTPFound(oauth_provider.get_oauth_url())
|
||||
state = session["state"] = token_urlsafe()
|
||||
raise HTTPFound(oauth_provider.get_oauth_url(state))
|
||||
|
||||
response = HTTPFound("/")
|
||||
identity = await oauth_provider.get_oauth_username(code)
|
||||
identity = await oauth_provider.get_oauth_username(code, self.request.query.get("state"), session)
|
||||
if identity is not None and await self.validator.known_username(identity):
|
||||
await remember(self.request, response, identity)
|
||||
raise response
|
||||
|
||||
Reference in New Issue
Block a user