mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-04-01 06:03:39 +00:00
feat: implement CSRF protection
This commit is contained in:
@@ -18,9 +18,10 @@
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
#
|
||||
from aiohttp.web import HTTPBadRequest, HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
|
||||
from secrets import token_urlsafe
|
||||
from typing import ClassVar
|
||||
|
||||
from ahriman.core.auth.helpers import remember
|
||||
from ahriman.core.auth.helpers import get_session, remember
|
||||
from ahriman.models.user_access import UserAccess
|
||||
from ahriman.web.apispec.decorators import apidocs
|
||||
from ahriman.web.schemas import LoginSchema, OAuth2Schema
|
||||
@@ -68,15 +69,18 @@ class LoginView(BaseView):
|
||||
raise HTTPMethodNotAllowed(self.request.method, ["POST"])
|
||||
|
||||
oauth_provider = self.validator
|
||||
if not isinstance(oauth_provider, OAuth): # there is actually property, but mypy does not like it anyway
|
||||
if not isinstance(oauth_provider, OAuth):
|
||||
raise HTTPMethodNotAllowed(self.request.method, ["POST"])
|
||||
|
||||
session = await get_session(self.request)
|
||||
|
||||
code = self.request.query.get("code")
|
||||
if not code:
|
||||
raise HTTPFound(oauth_provider.get_oauth_url())
|
||||
state = session["state"] = token_urlsafe()
|
||||
raise HTTPFound(oauth_provider.get_oauth_url(state))
|
||||
|
||||
response = HTTPFound("/")
|
||||
identity = await oauth_provider.get_oauth_username(code)
|
||||
identity = await oauth_provider.get_oauth_username(code, self.request.query.get("state"), session)
|
||||
if identity is not None and await self.validator.known_username(identity):
|
||||
await remember(self.request, response, identity)
|
||||
raise response
|
||||
|
||||
Reference in New Issue
Block a user