mirror of
				https://github.com/arcan1s/ahriman.git
				synced 2025-11-04 07:43:42 +00:00 
			
		
		
		
	Compare commits
	
		
			4 Commits
		
	
	
		
			2.7.0
			...
			a8c40a6b87
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| a8c40a6b87 | |||
| a274f91677 | |||
| 13faf66bdb | |||
| 4fb9335df9 | 
@ -1 +1 @@
 | 
			
		||||
skips: ['B101', 'B105', 'B106', 'B404']
 | 
			
		||||
skips: ['B101', 'B104', 'B105', 'B106', 'B404']
 | 
			
		||||
 | 
			
		||||
@ -196,14 +196,6 @@ ahriman.models.user\_access module
 | 
			
		||||
   :no-undoc-members:
 | 
			
		||||
   :show-inheritance:
 | 
			
		||||
 | 
			
		||||
ahriman.models.user\_identity module
 | 
			
		||||
------------------------------------
 | 
			
		||||
 | 
			
		||||
.. automodule:: ahriman.models.user_identity
 | 
			
		||||
   :members:
 | 
			
		||||
   :no-undoc-members:
 | 
			
		||||
   :show-inheritance:
 | 
			
		||||
 | 
			
		||||
Module contents
 | 
			
		||||
---------------
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -50,6 +50,7 @@ Base authorization settings. ``OAuth`` provider requires ``aioauth-client`` libr
 | 
			
		||||
* ``allow_read_only`` - allow requesting status APIs without authorization, boolean, required.
 | 
			
		||||
* ``client_id`` - OAuth2 application client ID, string, required in case if ``oauth`` is used.
 | 
			
		||||
* ``client_secret`` - OAuth2 application client secret key, string, required in case if ``oauth`` is used.
 | 
			
		||||
* ``cookie_secret_key`` - secret key which will be used for cookies encryption, string, optional. It must be 32 url-safe base64-encoded bytes and can be generated as following ``base64.urlsafe_b64encode(os.urandom(32)).decode("utf8")``. If not set, it will be generated automatically; note, however, that in this case, all sessions will be automatically expired during restart.
 | 
			
		||||
* ``max_age`` - parameter which controls both cookie expiration and token expiration inside the service, integer, optional, default is 7 days.
 | 
			
		||||
* ``oauth_provider`` - OAuth2 provider class name as is in ``aioauth-client`` (e.g. ``GoogleClient``, ``GithubClient`` etc), string, required in case if ``oauth`` is used.
 | 
			
		||||
* ``oauth_scopes`` - scopes list for OAuth2 provider, which will allow retrieving user email (which is used for checking user permissions), e.g. ``https://www.googleapis.com/auth/userinfo.email`` for ``GoogleClient`` or ``user:email`` for ``GithubClient``, space separated list of strings, required in case if ``oauth`` is used.
 | 
			
		||||
@ -68,7 +69,7 @@ Build related configuration. Group name can refer to architecture, e.g. ``build:
 | 
			
		||||
* ``makepkg_flags`` - additional flags passed to ``makepkg`` command, space separated list of strings, optional.
 | 
			
		||||
* ``makechrootpkg_flags`` - additional flags passed to ``makechrootpkg`` command, space separated list of strings, optional.
 | 
			
		||||
* ``triggers`` - list of ``ahriman.core.triggers.Trigger`` class implementation (e.g. ``ahriman.core.report.ReportTrigger ahriman.core.upload.UploadTrigger``) which will be loaded and run at the end of processing, space separated list of strings, optional. You can also specify triggers by their paths, e.g. ``/usr/lib/python3.10/site-packages/ahriman/core/report/report.py.ReportTrigger``. Triggers are run in the order of mention.
 | 
			
		||||
* ``vcs_allowed_age`` - maximal age in seconds of the VCS packages before their version will be updated with its remote source, int, optional, default ``0``.
 | 
			
		||||
* ``vcs_allowed_age`` - maximal age in seconds of the VCS packages before their version will be updated with its remote source, int, optional, default ``604800``.
 | 
			
		||||
 | 
			
		||||
``repository`` group
 | 
			
		||||
--------------------
 | 
			
		||||
 | 
			
		||||
@ -52,7 +52,7 @@ class Validate(Handler):
 | 
			
		||||
            unsafe(bool): if set no user check will be performed before path creation
 | 
			
		||||
        """
 | 
			
		||||
        schema = Validate.schema(architecture, configuration)
 | 
			
		||||
        validator = Validator(instance=configuration, schema=schema)
 | 
			
		||||
        validator = Validator(configuration=configuration, schema=schema)
 | 
			
		||||
 | 
			
		||||
        if validator.validate(configuration.dump()):
 | 
			
		||||
            return  # no errors found
 | 
			
		||||
 | 
			
		||||
@ -56,7 +56,7 @@ class Configuration(configparser.RawConfigParser):
 | 
			
		||||
        architecture according to the merge rules. Moreover, the architecture names will be removed from section names.
 | 
			
		||||
 | 
			
		||||
        In order to get current settings, the ``check_loaded`` method can be used. This method will raise an
 | 
			
		||||
        ``InitializeException`` in case if configuration was not yet loaded::
 | 
			
		||||
        ``InitializeError`` in case if configuration was not yet loaded::
 | 
			
		||||
 | 
			
		||||
            >>> path, architecture = configuration.check_loaded()
 | 
			
		||||
    """
 | 
			
		||||
@ -165,7 +165,7 @@ class Configuration(configparser.RawConfigParser):
 | 
			
		||||
            Tuple[Path, str]: configuration root path and architecture if loaded
 | 
			
		||||
 | 
			
		||||
        Raises:
 | 
			
		||||
            InitializeException: in case if architecture and/or path are not set
 | 
			
		||||
            InitializeError: in case if architecture and/or path are not set
 | 
			
		||||
        """
 | 
			
		||||
        if self.path is None or self.architecture is None:
 | 
			
		||||
            raise InitializeError("Configuration path and/or architecture are not set")
 | 
			
		||||
 | 
			
		||||
@ -64,6 +64,7 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
            "mirror": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
                "required": True,
 | 
			
		||||
                "is_url": [],
 | 
			
		||||
            },
 | 
			
		||||
            "repositories": {
 | 
			
		||||
                "type": "list",
 | 
			
		||||
@ -109,9 +110,15 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
            "client_secret": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
            },
 | 
			
		||||
            "cookie_secret_key": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
                "minlength": 32,
 | 
			
		||||
                "maxlength": 64,  # we cannot verify maxlength, because base64 representation might be longer than bytes
 | 
			
		||||
            },
 | 
			
		||||
            "max_age": {
 | 
			
		||||
                "type": "integer",
 | 
			
		||||
                "coerce": "integer",
 | 
			
		||||
                "min": 0,
 | 
			
		||||
            },
 | 
			
		||||
            "oauth_provider": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
@ -159,6 +166,7 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
            "vcs_allowed_age": {
 | 
			
		||||
                "type": "integer",
 | 
			
		||||
                "coerce": "integer",
 | 
			
		||||
                "min": 0,
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
@ -201,6 +209,7 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
        "schema": {
 | 
			
		||||
            "address": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
                "is_url": ["http", "https"],
 | 
			
		||||
            },
 | 
			
		||||
            "debug": {
 | 
			
		||||
                "type": "boolean",
 | 
			
		||||
@ -217,9 +226,11 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
            },
 | 
			
		||||
            "host": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
                "is_ip_address": ["localhost"],
 | 
			
		||||
            },
 | 
			
		||||
            "index_url": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
                "is_url": ["http", "https"],
 | 
			
		||||
            },
 | 
			
		||||
            "password": {
 | 
			
		||||
                "type": "string",
 | 
			
		||||
@ -255,44 +266,4 @@ CONFIGURATION_SCHEMA: ConfigurationSchema = {
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "remote-pull": {
 | 
			
		||||
        "type": "dict",
 | 
			
		||||
        "schema": {
 | 
			
		||||
            "target": {
 | 
			
		||||
                "type": "list",
 | 
			
		||||
                "coerce": "list",
 | 
			
		||||
                "schema": {"type": "string"},
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "remote-push": {
 | 
			
		||||
        "type": "dict",
 | 
			
		||||
        "schema": {
 | 
			
		||||
            "target": {
 | 
			
		||||
                "type": "list",
 | 
			
		||||
                "coerce": "list",
 | 
			
		||||
                "schema": {"type": "string"},
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "report": {
 | 
			
		||||
        "type": "dict",
 | 
			
		||||
        "schema": {
 | 
			
		||||
            "target": {
 | 
			
		||||
                "type": "list",
 | 
			
		||||
                "coerce": "list",
 | 
			
		||||
                "schema": {"type": "string"},
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
    "upload": {
 | 
			
		||||
        "type": "dict",
 | 
			
		||||
        "schema": {
 | 
			
		||||
            "target": {
 | 
			
		||||
                "type": "list",
 | 
			
		||||
                "coerce": "list",
 | 
			
		||||
                "schema": {"type": "string"},
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
    },
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -17,9 +17,12 @@
 | 
			
		||||
# You should have received a copy of the GNU General Public License
 | 
			
		||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
#
 | 
			
		||||
import ipaddress
 | 
			
		||||
 | 
			
		||||
from cerberus import TypeDefinition, Validator as RootValidator  # type: ignore
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from typing import Any, List
 | 
			
		||||
from urllib.parse import urlparse
 | 
			
		||||
 | 
			
		||||
from ahriman.core.configuration import Configuration
 | 
			
		||||
 | 
			
		||||
@ -29,7 +32,7 @@ class Validator(RootValidator):  # type: ignore
 | 
			
		||||
    class which defines custom validation methods for the service configuration
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        instance(Configuration): configuration instance
 | 
			
		||||
        configuration(Configuration): configuration instance
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    types_mapping = RootValidator.types_mapping.copy()
 | 
			
		||||
@ -40,12 +43,12 @@ class Validator(RootValidator):  # type: ignore
 | 
			
		||||
        default constructor
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            instance(Configuration): configuration instance used for extraction
 | 
			
		||||
            configuration(Configuration): configuration instance used for extraction
 | 
			
		||||
            *args(Any): positional arguments to be passed to base validator
 | 
			
		||||
            **kwargs(): keyword arguments to be passed to base validator
 | 
			
		||||
        """
 | 
			
		||||
        RootValidator.__init__(self, *args, **kwargs)
 | 
			
		||||
        self.instance: Configuration = kwargs["instance"]
 | 
			
		||||
        self.configuration: Configuration = kwargs["configuration"]
 | 
			
		||||
 | 
			
		||||
    def _normalize_coerce_absolute_path(self, value: str) -> Path:
 | 
			
		||||
        """
 | 
			
		||||
@ -57,7 +60,7 @@ class Validator(RootValidator):  # type: ignore
 | 
			
		||||
        Returns:
 | 
			
		||||
            Path: value converted to path instance according to configuration rules
 | 
			
		||||
        """
 | 
			
		||||
        converted: Path = self.instance.converters["path"](value)
 | 
			
		||||
        converted: Path = self.configuration.converters["path"](value)
 | 
			
		||||
        return converted
 | 
			
		||||
 | 
			
		||||
    def _normalize_coerce_boolean(self, value: str) -> bool:
 | 
			
		||||
@ -71,7 +74,7 @@ class Validator(RootValidator):  # type: ignore
 | 
			
		||||
            bool: value converted to boolean according to configuration rules
 | 
			
		||||
        """
 | 
			
		||||
        # pylint: disable=protected-access
 | 
			
		||||
        converted: bool = self.instance._convert_to_boolean(value)  # type: ignore
 | 
			
		||||
        converted: bool = self.configuration._convert_to_boolean(value)  # type: ignore
 | 
			
		||||
        return converted
 | 
			
		||||
 | 
			
		||||
    def _normalize_coerce_integer(self, value: str) -> int:
 | 
			
		||||
@ -97,9 +100,50 @@ class Validator(RootValidator):  # type: ignore
 | 
			
		||||
        Returns:
 | 
			
		||||
            List[str]: value converted to string list instance according to configuration rules
 | 
			
		||||
        """
 | 
			
		||||
        converted: List[str] = self.instance.converters["list"](value)
 | 
			
		||||
        converted: List[str] = self.configuration.converters["list"](value)
 | 
			
		||||
        return converted
 | 
			
		||||
 | 
			
		||||
    def _validate_is_ip_address(self, constraint: List[str], field: str, value: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        check if the specified value is valid ip address
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            constraint(List[str]): optional list of allowed special words (e.g. ``localhost``)
 | 
			
		||||
            field(str): field name to be checked
 | 
			
		||||
            value(Path): value to be checked
 | 
			
		||||
 | 
			
		||||
        Examples:
 | 
			
		||||
            The rule's arguments are validated against this schema:
 | 
			
		||||
            {"type": "list", "schema": {"type": "string"}}
 | 
			
		||||
        """
 | 
			
		||||
        if value in constraint:
 | 
			
		||||
            return
 | 
			
		||||
        try:
 | 
			
		||||
            ipaddress.ip_address(value)
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            self._error(field, f"Value {value} must be valid IP address")
 | 
			
		||||
 | 
			
		||||
    def _validate_is_url(self, constraint: List[str], field: str, value: str) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        check if the specified value is a valid url
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            constraint(List[str]): optional list of supported schemas. If empty, no schema validation will be performed
 | 
			
		||||
            field(str): field name to be checked
 | 
			
		||||
            value(str): value to be checked
 | 
			
		||||
 | 
			
		||||
        Examples:
 | 
			
		||||
            The rule's arguments are validated against this schema:
 | 
			
		||||
            {"type": "list", "schema": {"type": "string"}}
 | 
			
		||||
        """
 | 
			
		||||
        url = urlparse(value)  # it probably will never rise exceptions on parse
 | 
			
		||||
        if not url.scheme:
 | 
			
		||||
            self._error(field, f"Url scheme is not set for {value}")
 | 
			
		||||
        if not url.netloc and url.scheme not in ("file",):
 | 
			
		||||
            self._error(field, f"Location must be set for url {value} of scheme {url.scheme}")
 | 
			
		||||
        if constraint and url.scheme not in constraint:
 | 
			
		||||
            self._error(field, f"Url {value} scheme must be one of {constraint}")
 | 
			
		||||
 | 
			
		||||
    def _validate_path_exists(self, constraint: bool, field: str, value: Path) -> None:
 | 
			
		||||
        """
 | 
			
		||||
        check if paths exists
 | 
			
		||||
 | 
			
		||||
@ -33,6 +33,16 @@ class RemotePullTrigger(Trigger):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    CONFIGURATION_SCHEMA = {
 | 
			
		||||
        "remote-pull": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
                "target": {
 | 
			
		||||
                    "type": "list",
 | 
			
		||||
                    "coerce": "list",
 | 
			
		||||
                    "schema": {"type": "string"},
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        "gitremote": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
 | 
			
		||||
@ -38,6 +38,16 @@ class RemotePushTrigger(Trigger):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    CONFIGURATION_SCHEMA = {
 | 
			
		||||
        "remote-push": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
                "target": {
 | 
			
		||||
                    "type": "list",
 | 
			
		||||
                    "coerce": "list",
 | 
			
		||||
                    "schema": {"type": "string"},
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        "gitremote": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
 | 
			
		||||
@ -35,6 +35,16 @@ class ReportTrigger(Trigger):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    CONFIGURATION_SCHEMA = {
 | 
			
		||||
        "report": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
                "target": {
 | 
			
		||||
                    "type": "list",
 | 
			
		||||
                    "coerce": "list",
 | 
			
		||||
                    "schema": {"type": "string"},
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        "console": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
@ -62,6 +72,7 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                },
 | 
			
		||||
                "homepage": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "is_url": ["http", "https"],
 | 
			
		||||
                },
 | 
			
		||||
                "host": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
@ -70,6 +81,7 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                "link_path": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "required": True,
 | 
			
		||||
                    "is_url": [],
 | 
			
		||||
                },
 | 
			
		||||
                "no_empty_report": {
 | 
			
		||||
                    "type": "boolean",
 | 
			
		||||
@ -82,6 +94,8 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                    "type": "integer",
 | 
			
		||||
                    "coerce": "integer",
 | 
			
		||||
                    "required": True,
 | 
			
		||||
                    "min": 0,
 | 
			
		||||
                    "max": 65535,
 | 
			
		||||
                },
 | 
			
		||||
                "receivers": {
 | 
			
		||||
                    "type": "list",
 | 
			
		||||
@ -118,10 +132,12 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                },
 | 
			
		||||
                "homepage": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "is_url": ["http", "https"],
 | 
			
		||||
                },
 | 
			
		||||
                "link_path": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "required": True,
 | 
			
		||||
                    "is_url": [],
 | 
			
		||||
                },
 | 
			
		||||
                "path": {
 | 
			
		||||
                    "type": "path",
 | 
			
		||||
@ -153,10 +169,12 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                },
 | 
			
		||||
                "homepage": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "is_url": ["http", "https"],
 | 
			
		||||
                },
 | 
			
		||||
                "link_path": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
                    "required": True,
 | 
			
		||||
                    "is_url": [],
 | 
			
		||||
                },
 | 
			
		||||
                "template_path": {
 | 
			
		||||
                    "type": "path",
 | 
			
		||||
@ -171,6 +189,7 @@ class ReportTrigger(Trigger):
 | 
			
		||||
                "timeout": {
 | 
			
		||||
                    "type": "integer",
 | 
			
		||||
                    "coerce": "integer",
 | 
			
		||||
                    "min": 0,
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
 | 
			
		||||
@ -35,6 +35,16 @@ class UploadTrigger(Trigger):
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    CONFIGURATION_SCHEMA = {
 | 
			
		||||
        "upload": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
                "target": {
 | 
			
		||||
                    "type": "list",
 | 
			
		||||
                    "coerce": "list",
 | 
			
		||||
                    "schema": {"type": "string"},
 | 
			
		||||
                },
 | 
			
		||||
            },
 | 
			
		||||
        },
 | 
			
		||||
        "github": {
 | 
			
		||||
            "type": "dict",
 | 
			
		||||
            "schema": {
 | 
			
		||||
@ -57,6 +67,7 @@ class UploadTrigger(Trigger):
 | 
			
		||||
                "timeout": {
 | 
			
		||||
                    "type": "integer",
 | 
			
		||||
                    "coerce": "integer",
 | 
			
		||||
                    "min": 0,
 | 
			
		||||
                },
 | 
			
		||||
                "username": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
@ -101,6 +112,7 @@ class UploadTrigger(Trigger):
 | 
			
		||||
                "chunk_size": {
 | 
			
		||||
                    "type": "integer",
 | 
			
		||||
                    "coerce": "integer",
 | 
			
		||||
                    "min": 0,
 | 
			
		||||
                },
 | 
			
		||||
                "region": {
 | 
			
		||||
                    "type": "string",
 | 
			
		||||
 | 
			
		||||
@ -1,102 +0,0 @@
 | 
			
		||||
#
 | 
			
		||||
# Copyright (c) 2021-2023 ahriman team.
 | 
			
		||||
#
 | 
			
		||||
# This file is part of ahriman
 | 
			
		||||
# (see https://github.com/arcan1s/ahriman).
 | 
			
		||||
#
 | 
			
		||||
# This program is free software: you can redistribute it and/or modify
 | 
			
		||||
# it under the terms of the GNU General Public License as published by
 | 
			
		||||
# the Free Software Foundation, either version 3 of the License, or
 | 
			
		||||
# (at your option) any later version.
 | 
			
		||||
#
 | 
			
		||||
# This program is distributed in the hope that it will be useful,
 | 
			
		||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
 | 
			
		||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 | 
			
		||||
# GNU General Public License for more details.
 | 
			
		||||
#
 | 
			
		||||
# You should have received a copy of the GNU General Public License
 | 
			
		||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
#
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import Optional, Type
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@dataclass(frozen=True)
 | 
			
		||||
class UserIdentity:
 | 
			
		||||
    """
 | 
			
		||||
    user identity used inside web service
 | 
			
		||||
 | 
			
		||||
    Attributes:
 | 
			
		||||
        username(str): username
 | 
			
		||||
        expire_at(int): identity expiration timestamp
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    username: str
 | 
			
		||||
    expire_at: int
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_identity(cls: Type[UserIdentity], identity: str) -> Optional[UserIdentity]:
 | 
			
		||||
        """
 | 
			
		||||
        parse identity into object
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            identity(str): identity from session data
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Optional[UserIdentity]: user identity object if it can be parsed and not expired and None otherwise
 | 
			
		||||
        """
 | 
			
		||||
        try:
 | 
			
		||||
            username, expire_at = identity.split()
 | 
			
		||||
            user = cls(username, int(expire_at))
 | 
			
		||||
            return None if user.is_expired() else user
 | 
			
		||||
        except ValueError:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def from_username(cls: Type[UserIdentity], username: Optional[str], max_age: int) -> Optional[UserIdentity]:
 | 
			
		||||
        """
 | 
			
		||||
        generate identity from username
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            username(Optional[str]): username
 | 
			
		||||
            max_age(int): time to expire, seconds
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            Optional[UserIdentity]: constructed identity object
 | 
			
		||||
        """
 | 
			
		||||
        return cls(username, cls.expire_when(max_age)) if username is not None else None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    def expire_when(max_age: int) -> int:
 | 
			
		||||
        """
 | 
			
		||||
        generate expiration time using delta
 | 
			
		||||
 | 
			
		||||
        Args:
 | 
			
		||||
            max_age(int): time delta to generate. Must be usually TTE
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            int: expiration timestamp
 | 
			
		||||
        """
 | 
			
		||||
        return int(time.time()) + max_age
 | 
			
		||||
 | 
			
		||||
    def is_expired(self) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
        compare timestamp with current timestamp and return True in case if identity is expired
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            bool: True in case if identity is expired and False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        return self.expire_when(0) > self.expire_at
 | 
			
		||||
 | 
			
		||||
    def to_identity(self) -> str:
 | 
			
		||||
        """
 | 
			
		||||
        convert object to identity representation
 | 
			
		||||
 | 
			
		||||
        Returns:
 | 
			
		||||
            str: web service identity
 | 
			
		||||
        """
 | 
			
		||||
        return f"{self.username} {self.expire_at}"
 | 
			
		||||
@ -18,7 +18,6 @@
 | 
			
		||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
 | 
			
		||||
#
 | 
			
		||||
import aiohttp_security  # type: ignore
 | 
			
		||||
import base64
 | 
			
		||||
import socket
 | 
			
		||||
import types
 | 
			
		||||
 | 
			
		||||
@ -32,12 +31,12 @@ from cryptography import fernet
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from ahriman.core.auth import Auth
 | 
			
		||||
from ahriman.core.configuration import Configuration
 | 
			
		||||
from ahriman.models.user_access import UserAccess
 | 
			
		||||
from ahriman.models.user_identity import UserIdentity
 | 
			
		||||
from ahriman.web.middlewares import HandlerType, MiddlewareType
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__all__ = ["AuthorizationPolicy", "auth_handler", "setup_auth"]
 | 
			
		||||
__all__ = ["AuthorizationPolicy", "auth_handler", "cookie_secret_key", "setup_auth"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy):  # type: ignore
 | 
			
		||||
@ -67,10 +66,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy):  # type
 | 
			
		||||
        Returns:
 | 
			
		||||
            Optional[str]: user identity (username) in case if user exists and None otherwise
 | 
			
		||||
        """
 | 
			
		||||
        user = UserIdentity.from_identity(identity)
 | 
			
		||||
        if user is None:
 | 
			
		||||
            return None
 | 
			
		||||
        return user.username if await self.validator.known_username(user.username) else None
 | 
			
		||||
        return identity if await self.validator.known_username(identity) else None
 | 
			
		||||
 | 
			
		||||
    async def permits(self, identity: str, permission: UserAccess, context: Optional[str] = None) -> bool:
 | 
			
		||||
        """
 | 
			
		||||
@ -84,10 +80,7 @@ class AuthorizationPolicy(aiohttp_security.AbstractAuthorizationPolicy):  # type
 | 
			
		||||
        Returns:
 | 
			
		||||
            bool: True in case if user is allowed to perform this request and False otherwise
 | 
			
		||||
        """
 | 
			
		||||
        user = UserIdentity.from_identity(identity)
 | 
			
		||||
        if user is None:
 | 
			
		||||
            return False
 | 
			
		||||
        return await self.validator.verify_access(user.username, permission, context)
 | 
			
		||||
        return await self.validator.verify_access(identity, permission, context)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def auth_handler(allow_read_only: bool) -> MiddlewareType:
 | 
			
		||||
@ -125,19 +118,36 @@ def auth_handler(allow_read_only: bool) -> MiddlewareType:
 | 
			
		||||
    return handle
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def setup_auth(application: web.Application, validator: Auth) -> web.Application:
 | 
			
		||||
def cookie_secret_key(configuration: Configuration) -> fernet.Fernet:
 | 
			
		||||
    """
 | 
			
		||||
    extract cookie secret key from configuration if set or generate new one
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        configuration(Configuration): configuration instance
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        fernet.Fernet: fernet key instance
 | 
			
		||||
    """
 | 
			
		||||
    if (secret_key := configuration.get("auth", "cookie_secret_key", fallback=None)) is not None:
 | 
			
		||||
        return fernet.Fernet(secret_key)
 | 
			
		||||
 | 
			
		||||
    secret_key = fernet.Fernet.generate_key()
 | 
			
		||||
    return fernet.Fernet(secret_key)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def setup_auth(application: web.Application, configuration: Configuration, validator: Auth) -> web.Application:
 | 
			
		||||
    """
 | 
			
		||||
    setup authorization policies for the application
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        application(web.Application): web application instance
 | 
			
		||||
        configuration(Configuration): configuration instance
 | 
			
		||||
        validator(Auth): authorization module instance
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        web.Application: configured web application
 | 
			
		||||
    """
 | 
			
		||||
    fernet_key = fernet.Fernet.generate_key()
 | 
			
		||||
    secret_key = base64.urlsafe_b64decode(fernet_key)
 | 
			
		||||
    secret_key = cookie_secret_key(configuration)
 | 
			
		||||
    storage = EncryptedCookieStorage(secret_key, cookie_name="API_SESSION", max_age=validator.max_age)
 | 
			
		||||
    setup_session(application, storage)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -21,7 +21,6 @@ from aiohttp.web import HTTPFound, HTTPMethodNotAllowed, HTTPUnauthorized
 | 
			
		||||
 | 
			
		||||
from ahriman.core.auth.helpers import remember
 | 
			
		||||
from ahriman.models.user_access import UserAccess
 | 
			
		||||
from ahriman.models.user_identity import UserIdentity
 | 
			
		||||
from ahriman.web.views.base import BaseView
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -64,10 +63,9 @@ class LoginView(BaseView):
 | 
			
		||||
            raise HTTPFound(oauth_provider.get_oauth_url())
 | 
			
		||||
 | 
			
		||||
        response = HTTPFound("/")
 | 
			
		||||
        username = await oauth_provider.get_oauth_username(code)
 | 
			
		||||
        identity = UserIdentity.from_username(username, self.validator.max_age)
 | 
			
		||||
        if identity is not None and await self.validator.known_username(username):
 | 
			
		||||
            await remember(self.request, response, identity.to_identity())
 | 
			
		||||
        identity = await oauth_provider.get_oauth_username(code)
 | 
			
		||||
        if identity is not None and await self.validator.known_username(identity):
 | 
			
		||||
            await remember(self.request, response, identity)
 | 
			
		||||
            raise response
 | 
			
		||||
 | 
			
		||||
        raise HTTPUnauthorized()
 | 
			
		||||
@ -111,12 +109,11 @@ class LoginView(BaseView):
 | 
			
		||||
                302: Found
 | 
			
		||||
        """
 | 
			
		||||
        data = await self.extract_data()
 | 
			
		||||
        username = data.get("username")
 | 
			
		||||
        identity = data.get("username")
 | 
			
		||||
 | 
			
		||||
        response = HTTPFound("/")
 | 
			
		||||
        identity = UserIdentity.from_username(username, self.validator.max_age)
 | 
			
		||||
        if identity is not None and await self.validator.check_credentials(username, data.get("password")):
 | 
			
		||||
            await remember(self.request, response, identity.to_identity())
 | 
			
		||||
        if identity is not None and await self.validator.check_credentials(identity, data.get("password")):
 | 
			
		||||
            await remember(self.request, response, identity)
 | 
			
		||||
            raise response
 | 
			
		||||
 | 
			
		||||
        raise HTTPUnauthorized()
 | 
			
		||||
 | 
			
		||||
@ -90,7 +90,7 @@ async def on_startup(application: web.Application) -> None:
 | 
			
		||||
        application(web.Application): web application instance
 | 
			
		||||
 | 
			
		||||
    Raises:
 | 
			
		||||
        InitializeException: in case if matched could not be loaded
 | 
			
		||||
        InitializeError: in case if matched could not be loaded
 | 
			
		||||
    """
 | 
			
		||||
    application.logger.info("server started")
 | 
			
		||||
    try:
 | 
			
		||||
@ -168,6 +168,6 @@ def setup_service(architecture: str, configuration: Configuration, spawner: Spaw
 | 
			
		||||
    validator = application["validator"] = Auth.load(configuration, database)
 | 
			
		||||
    if validator.enabled:
 | 
			
		||||
        from ahriman.web.middlewares.auth_handler import setup_auth
 | 
			
		||||
        setup_auth(application, validator)
 | 
			
		||||
        setup_auth(application, configuration, validator)
 | 
			
		||||
 | 
			
		||||
    return application
 | 
			
		||||
 | 
			
		||||
@ -62,9 +62,11 @@ def test_schema(configuration: Configuration) -> None:
 | 
			
		||||
    assert schema.pop("email")
 | 
			
		||||
    assert schema.pop("github")
 | 
			
		||||
    assert schema.pop("html")
 | 
			
		||||
    assert schema.pop("report")
 | 
			
		||||
    assert schema.pop("rsync")
 | 
			
		||||
    assert schema.pop("s3")
 | 
			
		||||
    assert schema.pop("telegram")
 | 
			
		||||
    assert schema.pop("upload")
 | 
			
		||||
 | 
			
		||||
    assert schema == CONFIGURATION_SCHEMA
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -16,4 +16,4 @@ def validator(configuration: Configuration) -> Validator:
 | 
			
		||||
    Returns:
 | 
			
		||||
        Validator: validator test instance
 | 
			
		||||
    """
 | 
			
		||||
    return Validator(instance=configuration, schema=CONFIGURATION_SCHEMA)
 | 
			
		||||
    return Validator(configuration=configuration, schema=CONFIGURATION_SCHEMA)
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from pytest_mock import MockerFixture
 | 
			
		||||
from unittest.mock import MagicMock
 | 
			
		||||
from unittest.mock import MagicMock, call as MockCall
 | 
			
		||||
 | 
			
		||||
from ahriman.core.configuration.validator import Validator
 | 
			
		||||
 | 
			
		||||
@ -18,7 +18,7 @@ def test_normalize_coerce_absolute_path(validator: Validator) -> None:
 | 
			
		||||
    must convert string value to path by using configuration converters
 | 
			
		||||
    """
 | 
			
		||||
    convert_mock = MagicMock()
 | 
			
		||||
    validator.instance.converters["path"] = convert_mock
 | 
			
		||||
    validator.configuration.converters["path"] = convert_mock
 | 
			
		||||
 | 
			
		||||
    validator._normalize_coerce_absolute_path("value")
 | 
			
		||||
    convert_mock.assert_called_once_with("value")
 | 
			
		||||
@ -46,12 +46,56 @@ def test_normalize_coerce_list(validator: Validator) -> None:
 | 
			
		||||
    must convert string value to list by using configuration converters
 | 
			
		||||
    """
 | 
			
		||||
    convert_mock = MagicMock()
 | 
			
		||||
    validator.instance.converters["list"] = convert_mock
 | 
			
		||||
    validator.configuration.converters["list"] = convert_mock
 | 
			
		||||
 | 
			
		||||
    validator._normalize_coerce_list("value")
 | 
			
		||||
    convert_mock.assert_called_once_with("value")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_validate_is_ip_address(validator: Validator, mocker: MockerFixture) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must validate addresses correctly
 | 
			
		||||
    """
 | 
			
		||||
    error_mock = mocker.patch("ahriman.core.configuration.validator.Validator._error")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_ip_address(["localhost"], "field", "localhost")
 | 
			
		||||
    validator._validate_is_ip_address([], "field", "localhost")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_ip_address([], "field", "127.0.0.1")
 | 
			
		||||
    validator._validate_is_ip_address([], "field", "::")
 | 
			
		||||
    validator._validate_is_ip_address([], "field", "0.0.0.0")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_ip_address([], "field", "random string")
 | 
			
		||||
 | 
			
		||||
    error_mock.assert_has_calls([
 | 
			
		||||
        MockCall("field", "Value localhost must be valid IP address"),
 | 
			
		||||
        MockCall("field", "Value random string must be valid IP address"),
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_validate_is_url(validator: Validator, mocker: MockerFixture) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must validate url correctly
 | 
			
		||||
    """
 | 
			
		||||
    error_mock = mocker.patch("ahriman.core.configuration.validator.Validator._error")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_url([], "field", "http://example.com")
 | 
			
		||||
    validator._validate_is_url([], "field", "https://example.com")
 | 
			
		||||
    validator._validate_is_url([], "field", "file:///tmp")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_url(["http", "https"], "field", "file:///tmp")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_url([], "field", "http:///path")
 | 
			
		||||
 | 
			
		||||
    validator._validate_is_url([], "field", "random string")
 | 
			
		||||
 | 
			
		||||
    error_mock.assert_has_calls([
 | 
			
		||||
        MockCall("field", "Url file:///tmp scheme must be one of ['http', 'https']"),
 | 
			
		||||
        MockCall("field", "Location must be set for url http:///path of scheme http"),
 | 
			
		||||
        MockCall("field", "Url scheme is not set for random string"),
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_validate_path_exists(validator: Validator, mocker: MockerFixture) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must validate that paths exists
 | 
			
		||||
@ -67,4 +111,6 @@ def test_validate_path_exists(validator: Validator, mocker: MockerFixture) -> No
 | 
			
		||||
    mocker.patch("pathlib.Path.exists", return_value=True)
 | 
			
		||||
    validator._validate_path_exists(True, "field", Path("3"))
 | 
			
		||||
 | 
			
		||||
    error_mock.assert_called_once_with("field", "Path 2 must exist")
 | 
			
		||||
    error_mock.assert_has_calls([
 | 
			
		||||
        MockCall("field", "Path 2 must exist"),
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
@ -13,7 +13,6 @@ from ahriman.models.package import Package
 | 
			
		||||
from ahriman.models.package_description import PackageDescription
 | 
			
		||||
from ahriman.models.package_source import PackageSource
 | 
			
		||||
from ahriman.models.remote_source import RemoteSource
 | 
			
		||||
from ahriman.models.user_identity import UserIdentity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
@ -149,14 +148,3 @@ def pyalpm_package_description_ahriman(package_description_ahriman: PackageDescr
 | 
			
		||||
    type(mock).provides = PropertyMock(return_value=package_description_ahriman.provides)
 | 
			
		||||
    type(mock).url = PropertyMock(return_value=package_description_ahriman.url)
 | 
			
		||||
    return mock
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
def user_identity() -> UserIdentity:
 | 
			
		||||
    """
 | 
			
		||||
    identity fixture
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        UserIdentity: user identity test instance
 | 
			
		||||
    """
 | 
			
		||||
    return UserIdentity("username", int(time.time()) + 30)
 | 
			
		||||
 | 
			
		||||
@ -1,64 +0,0 @@
 | 
			
		||||
from ahriman.models.user_identity import UserIdentity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_from_identity(user_identity: UserIdentity) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must construct identity object from string
 | 
			
		||||
    """
 | 
			
		||||
    identity = UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}")
 | 
			
		||||
    assert identity == user_identity
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_from_identity_expired(user_identity: UserIdentity) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must construct None from expired identity
 | 
			
		||||
    """
 | 
			
		||||
    user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
 | 
			
		||||
    assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_from_identity_no_split() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must construct None from invalid string
 | 
			
		||||
    """
 | 
			
		||||
    assert UserIdentity.from_identity("username") is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_from_identity_not_int() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must construct None from invalid timestamp
 | 
			
		||||
    """
 | 
			
		||||
    assert UserIdentity.from_identity("username timestamp") is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_from_username() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must construct identity from username
 | 
			
		||||
    """
 | 
			
		||||
    identity = UserIdentity.from_username("username", 0)
 | 
			
		||||
    assert identity.username == "username"
 | 
			
		||||
    # we want to check timestamp too, but later
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_expire_when() -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must return correct expiration time
 | 
			
		||||
    """
 | 
			
		||||
    assert UserIdentity.expire_when(-1) < UserIdentity.expire_when(0) < UserIdentity.expire_when(1)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_is_expired(user_identity: UserIdentity) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must return expired flag for expired identities
 | 
			
		||||
    """
 | 
			
		||||
    assert not user_identity.is_expired()
 | 
			
		||||
 | 
			
		||||
    user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
 | 
			
		||||
    assert user_identity.is_expired()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_to_identity(user_identity: UserIdentity) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must return correct identity string
 | 
			
		||||
    """
 | 
			
		||||
    assert user_identity == UserIdentity.from_identity(user_identity.to_identity())
 | 
			
		||||
@ -3,27 +3,15 @@ import socket
 | 
			
		||||
 | 
			
		||||
from aiohttp import web
 | 
			
		||||
from aiohttp.test_utils import TestClient
 | 
			
		||||
from cryptography import fernet
 | 
			
		||||
from pytest_mock import MockerFixture
 | 
			
		||||
from unittest.mock import AsyncMock
 | 
			
		||||
from unittest.mock import AsyncMock, call as MockCall
 | 
			
		||||
 | 
			
		||||
from ahriman.core.auth import Auth
 | 
			
		||||
from ahriman.core.configuration import Configuration
 | 
			
		||||
from ahriman.models.user import User
 | 
			
		||||
from ahriman.models.user_access import UserAccess
 | 
			
		||||
from ahriman.models.user_identity import UserIdentity
 | 
			
		||||
from ahriman.web.middlewares.auth_handler import auth_handler, AuthorizationPolicy, setup_auth
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _identity(username: str) -> str:
 | 
			
		||||
    """
 | 
			
		||||
    generate identity from user
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        username(str): name of the user
 | 
			
		||||
 | 
			
		||||
    Returns:
 | 
			
		||||
        str: user identity string
 | 
			
		||||
    """
 | 
			
		||||
    return f"{username} {UserIdentity.expire_when(60)}"
 | 
			
		||||
from ahriman.web.middlewares.auth_handler import AuthorizationPolicy, auth_handler, cookie_secret_key, setup_auth
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user: User, mocker: MockerFixture) -> None:
 | 
			
		||||
@ -31,14 +19,14 @@ async def test_authorized_userid(authorization_policy: AuthorizationPolicy, user
 | 
			
		||||
    must return authorized user id
 | 
			
		||||
    """
 | 
			
		||||
    mocker.patch("ahriman.core.database.SQLite.user_get", return_value=user)
 | 
			
		||||
    assert await authorization_policy.authorized_userid(_identity(user.username)) == user.username
 | 
			
		||||
    assert await authorization_policy.authorized_userid(user.username) == user.username
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_authorized_userid_unknown(authorization_policy: AuthorizationPolicy, user: User) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must not allow unknown user id for authorization
 | 
			
		||||
    """
 | 
			
		||||
    assert await authorization_policy.authorized_userid(_identity("somerandomname")) is None
 | 
			
		||||
    assert await authorization_policy.authorized_userid("somerandomname") is None
 | 
			
		||||
    assert await authorization_policy.authorized_userid("somerandomname") is None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -49,11 +37,13 @@ async def test_permits(authorization_policy: AuthorizationPolicy, user: User) ->
 | 
			
		||||
    authorization_policy.validator = AsyncMock()
 | 
			
		||||
    authorization_policy.validator.verify_access.side_effect = lambda username, *args: username == user.username
 | 
			
		||||
 | 
			
		||||
    assert await authorization_policy.permits(_identity(user.username), user.access, "/endpoint")
 | 
			
		||||
    authorization_policy.validator.verify_access.assert_called_once_with(user.username, user.access, "/endpoint")
 | 
			
		||||
    assert await authorization_policy.permits(user.username, user.access, "/endpoint")
 | 
			
		||||
    assert not await authorization_policy.permits("somerandomname", user.access, "/endpoint")
 | 
			
		||||
 | 
			
		||||
    assert not await authorization_policy.permits(_identity("somerandomname"), user.access, "/endpoint")
 | 
			
		||||
    assert not await authorization_policy.permits(user.username, user.access, "/endpoint")
 | 
			
		||||
    authorization_policy.validator.verify_access.assert_has_calls([
 | 
			
		||||
        MockCall(user.username, user.access, "/endpoint"),
 | 
			
		||||
        MockCall("somerandomname", user.access, "/endpoint"),
 | 
			
		||||
    ])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_auth_handler_unix_socket(client_with_auth: TestClient, mocker: MockerFixture) -> None:
 | 
			
		||||
@ -175,11 +165,28 @@ async def test_auth_handler_write(mocker: MockerFixture) -> None:
 | 
			
		||||
        check_permission_mock.assert_called_once_with(aiohttp_request, UserAccess.Full, aiohttp_request.path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_setup_auth(application_with_auth: web.Application, auth: Auth, mocker: MockerFixture) -> None:
 | 
			
		||||
def test_cookie_secret_key(configuration: Configuration) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must generate fernet key
 | 
			
		||||
    """
 | 
			
		||||
    secret_key = cookie_secret_key(configuration)
 | 
			
		||||
    assert isinstance(secret_key, fernet.Fernet)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_cookie_secret_key_cached(configuration: Configuration) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must use cookie key as set by configuration
 | 
			
		||||
    """
 | 
			
		||||
    configuration.set_option("auth", "cookie_secret_key", fernet.Fernet.generate_key().decode("utf8"))
 | 
			
		||||
    assert cookie_secret_key(configuration) is not None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_setup_auth(application_with_auth: web.Application, configuration: Configuration, auth: Auth,
 | 
			
		||||
                    mocker: MockerFixture) -> None:
 | 
			
		||||
    """
 | 
			
		||||
    must set up authorization
 | 
			
		||||
    """
 | 
			
		||||
    setup_mock = mocker.patch("aiohttp_security.setup")
 | 
			
		||||
    application = setup_auth(application_with_auth, auth)
 | 
			
		||||
    application = setup_auth(application_with_auth, configuration, auth)
 | 
			
		||||
    assert application.get("validator") is not None
 | 
			
		||||
    setup_mock.assert_called_once_with(application_with_auth, pytest.helpers.anyvar(int), pytest.helpers.anyvar(int))
 | 
			
		||||
 | 
			
		||||
		Reference in New Issue
	
	Block a user