#
# Copyright (c) 2021-2025 ahriman team.
#
# This file is part of ahriman
# (see https://github.com/arcan1s/ahriman).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
#
import aioauth_client

from ahriman.core.auth.mapping import Mapping
from ahriman.core.configuration import Configuration
from ahriman.core.database import SQLite
from ahriman.core.exceptions import OptionError
from ahriman.models.auth_settings import AuthSettings


class OAuth(Mapping):
    """
    User authorization implementation via OAuth. It is required to create application first and put application
    credentials.

    Attributes:
        client_id(str): application client id
        client_secret(str): application client secret key
        icon(str): icon to be used in login control
        provider(aioauth_client.OAuth2Client): provider class, should be one of aioauth-client provided classes
        redirect_uri(str): redirect URI registered in provider
        scopes(str): list of scopes required by the application
    """

    def __init__(self, configuration: Configuration, database: SQLite,
                 provider: AuthSettings = AuthSettings.OAuth) -> None:
        """
        Args:
            configuration(Configuration): configuration instance
            database(SQLite): database instance
            provider(AuthSettings, optional): authorization type definition
                (Default value = AuthSettings.OAuth)

        Raises:
            OptionError: in case if the configured OAuth provider name is not a valid OAuth2 client class
        """
        Mapping.__init__(self, configuration, database, provider)
        self.client_id = configuration.get("auth", "client_id")
        self.client_secret = configuration.get("auth", "client_secret")
        # in order to use OAuth feature the service must be publicly available
        # thus we expect that address is set
        self.redirect_uri = f"""{configuration.get("web", "address")}/api/v1/login"""
        self.provider = self.get_provider(configuration.get("auth", "oauth_provider"))
        # it is list, but we will have to convert to string it anyway
        self.scopes = configuration.get("auth", "oauth_scopes")
        self.icon = configuration.get("auth", "oauth_icon", fallback="google")

    @property
    def auth_control(self) -> str:
        """
        get authorization html control

        Returns:
            str: login control as html code to insert
        """
        # NOTE(review): the returned string carries no markup and does not reference ``self.icon``;
        # presumably the surrounding tags were lost at some point - confirm against the web templates
        return f""" login"""

    @staticmethod
    def get_provider(name: str) -> type[aioauth_client.OAuth2Client]:
        """
        load OAuth2 provider by name

        Args:
            name(str): name of the provider. Must be valid class defined in aioauth-client library

        Returns:
            type[aioauth_client.OAuth2Client]: loaded provider type

        Raises:
            OptionError: in case if invalid OAuth provider name supplied
        """
        # use a default value, so an unknown name is reported as OptionError (as documented) instead of
        # leaking AttributeError; None is rejected by issubclass below with TypeError
        candidate = getattr(aioauth_client, name, None)
        try:
            is_oauth2_client = issubclass(candidate, aioauth_client.OAuth2Client)
        except TypeError:  # attribute is missing or it is not a class at all (e.g. random string)
            is_oauth2_client = False
        if not is_oauth2_client:
            raise OptionError(name)
        provider: type[aioauth_client.OAuth2Client] = candidate
        return provider

    def get_client(self) -> aioauth_client.OAuth2Client:
        """
        load client from parameters

        Returns:
            aioauth_client.OAuth2Client: generated client according to current settings
        """
        return self.provider(client_id=self.client_id, client_secret=self.client_secret)

    def get_oauth_url(self) -> str:
        """
        get authorization URI for the specified settings

        Returns:
            str: authorization URI as a string
        """
        client = self.get_client()
        uri: str = client.get_authorize_url(scope=self.scopes, redirect_uri=self.redirect_uri)
        return uri

    async def get_oauth_username(self, code: str) -> str | None:
        """
        extract OAuth username from remote

        Args:
            code(str): authorization code provided by external service

        Returns:
            str | None: username as is in OAuth provider (email if set, username otherwise), or ``None`` in case
            if any exception has been raised during the exchange (the exception is logged)
        """
        try:
            client = self.get_client()
            access_token, _ = await client.get_access_token(code, redirect_uri=self.redirect_uri)
            # the client has to carry the token in order to be able to request user info
            client.access_token = access_token

            user, _ = await client.user_info()
            username: str = user.email or user.username  # type: ignore[attr-defined]
            return username
        except Exception:
            # deliberate best-effort: any failure is treated as "no user" and only logged
            self.logger.exception("got exception while performing request")
            return None