frozen dataclasses

This commit is contained in:
Evgenii Alekseev 2022-07-26 03:26:56 +03:00
parent c5fbccd519
commit 9d016f51b5
31 changed files with 89 additions and 120 deletions

View File

@ -1 +1 @@
skips: ['B101', 'B105', 'B404'] skips: ['B101', 'B105', 'B106', 'B404']

View File

@ -149,7 +149,7 @@ class Users(Handler):
Returns: Returns:
User: built user descriptor User: built user descriptor
""" """
user = User(args.username, args.password, args.role) password = args.password
if user.password is None: if password is None:
user.password = getpass.getpass() password = getpass.getpass()
return user return User(username=args.username, password=password, access=args.role)

View File

@ -83,7 +83,7 @@ class Migrations(LazyLogging):
module = import_module(f"{__name__}.{module_name}") module = import_module(f"{__name__}.{module_name}")
steps: List[str] = getattr(module, "steps", []) steps: List[str] = getattr(module, "steps", [])
self.logger.debug("found migration %s at index %s with steps count %s", module_name, index, len(steps)) self.logger.debug("found migration %s at index %s with steps count %s", module_name, index, len(steps))
migrations.append(Migration(index, module_name, steps)) migrations.append(Migration(index=index, name=module_name, steps=steps))
return migrations return migrations
@ -97,7 +97,7 @@ class Migrations(LazyLogging):
migrations = self.migrations() migrations = self.migrations()
current_version = self.user_version() current_version = self.user_version()
expected_version = len(migrations) expected_version = len(migrations)
result = MigrationResult(current_version, expected_version) result = MigrationResult(old_version=current_version, new_version=expected_version)
if not result.is_outdated: if not result.is_outdated:
self.logger.info("no migrations required") self.logger.info("no migrations required")

View File

@ -58,7 +58,7 @@ class AuthOperations(Operations):
def run(connection: Connection) -> List[User]: def run(connection: Connection) -> List[User]:
return [ return [
User(cursor["username"], cursor["password"], UserAccess(cursor["access"])) User(username=cursor["username"], password=cursor["password"], access=UserAccess(cursor["access"]))
for cursor in connection.execute( for cursor in connection.execute(
""" """
select * from users select * from users

View File

@ -154,7 +154,11 @@ class PackageOperations(Operations):
Dict[str, Package]: map of the package base to its descriptor (without packages themselves) Dict[str, Package]: map of the package base to its descriptor (without packages themselves)
""" """
return { return {
row["package_base"]: Package(row["package_base"], row["version"], RemoteSource.from_json(row), {}) row["package_base"]: Package(
base=row["package_base"],
version=row["version"],
remote=RemoteSource.from_json(row),
packages={})
for row in connection.execute("""select * from package_bases""") for row in connection.execute("""select * from package_bases""")
} }

View File

@ -80,7 +80,7 @@ class Client:
Returns: Returns:
InternalStatus: current internal (web) service status InternalStatus: current internal (web) service status
""" """
return InternalStatus(BuildStatus()) return InternalStatus(status=BuildStatus())
def remove(self, base: str) -> None: def remove(self, base: str) -> None:
""" """

View File

@ -189,7 +189,7 @@ class WebClient(Client, LazyLogging):
self.logger.exception("could not get web service status: %s", exception_response_text(e)) self.logger.exception("could not get web service status: %s", exception_response_text(e))
except Exception: except Exception:
self.logger.exception("could not get web service status") self.logger.exception("could not get web service status")
return InternalStatus(BuildStatus()) return InternalStatus(status=BuildStatus())
def remove(self, base: str) -> None: def remove(self, base: str) -> None:
""" """

View File

@ -29,7 +29,7 @@ from typing import Any, Callable, Dict, List, Optional, Type
from ahriman.core.util import filter_json, full_version from ahriman.core.util import filter_json, full_version
@dataclass @dataclass(frozen=True, kw_only=True)
class AURPackage: class AURPackage:
""" """
AUR package descriptor AUR package descriptor

View File

@ -47,7 +47,7 @@ class BuildStatusEnum(str, Enum):
Success = "success" Success = "success"
@dataclass @dataclass(frozen=True)
class BuildStatus: class BuildStatus:
""" """
build status holder build status holder
@ -64,7 +64,7 @@ class BuildStatus:
""" """
convert status to enum type convert status to enum type
""" """
self.status = BuildStatusEnum(self.status) object.__setattr__(self, "status", BuildStatusEnum(self.status))
@classmethod @classmethod
def from_json(cls: Type[BuildStatus], dump: Dict[str, Any]) -> BuildStatus: def from_json(cls: Type[BuildStatus], dump: Dict[str, Any]) -> BuildStatus:

View File

@ -27,7 +27,7 @@ from ahriman.models.build_status import BuildStatus
from ahriman.models.package import Package from ahriman.models.package import Package
@dataclass @dataclass(frozen=True, kw_only=True)
class Counters: class Counters:
""" """
package counters package counters

View File

@ -26,7 +26,7 @@ from ahriman.models.build_status import BuildStatus
from ahriman.models.counters import Counters from ahriman.models.counters import Counters
@dataclass @dataclass(frozen=True, kw_only=True)
class InternalStatus: class InternalStatus:
""" """
internal server status internal server status

View File

@ -21,7 +21,7 @@ from dataclasses import dataclass
from typing import List from typing import List
@dataclass @dataclass(frozen=True, kw_only=True)
class Migration: class Migration:
""" """
migration implementation migration implementation

View File

@ -22,7 +22,7 @@ from dataclasses import dataclass
from ahriman.core.exceptions import MigrationError from ahriman.core.exceptions import MigrationError
@dataclass @dataclass(frozen=True, kw_only=True)
class MigrationResult: class MigrationResult:
""" """
migration result implementation model migration result implementation model

View File

@ -38,7 +38,7 @@ from ahriman.models.remote_source import RemoteSource
from ahriman.models.repository_paths import RepositoryPaths from ahriman.models.repository_paths import RepositoryPaths
@dataclass @dataclass(kw_only=True)
class Package(LazyLogging): class Package(LazyLogging):
""" """
package properties representation package properties representation
@ -147,7 +147,7 @@ class Package(LazyLogging):
""" """
package = pacman.handle.load_pkg(str(path)) package = pacman.handle.load_pkg(str(path))
description = PackageDescription.from_package(package, path) description = PackageDescription.from_package(package, path)
return cls(package.base, package.version, remote, {package.name: description}) return cls(base=package.base, version=package.version, remote=remote, packages={package.name: description})
@classmethod @classmethod
def from_aur(cls: Type[Package], name: str, pacman: Pacman) -> Package: def from_aur(cls: Type[Package], name: str, pacman: Pacman) -> Package:
@ -163,7 +163,11 @@ class Package(LazyLogging):
""" """
package = AUR.info(name, pacman=pacman) package = AUR.info(name, pacman=pacman)
remote = RemoteSource.from_source(PackageSource.AUR, package.package_base, package.repository) remote = RemoteSource.from_source(PackageSource.AUR, package.package_base, package.repository)
return cls(package.package_base, package.version, remote, {package.name: PackageDescription()}) return cls(
base=package.package_base,
version=package.version,
remote=remote,
packages={package.name: PackageDescription()})
@classmethod @classmethod
def from_build(cls: Type[Package], path: Path) -> Package: def from_build(cls: Type[Package], path: Path) -> Package:
@ -186,7 +190,7 @@ class Package(LazyLogging):
packages = {key: PackageDescription() for key in srcinfo["packages"]} packages = {key: PackageDescription() for key in srcinfo["packages"]}
version = full_version(srcinfo.get("epoch"), srcinfo["pkgver"], srcinfo["pkgrel"]) version = full_version(srcinfo.get("epoch"), srcinfo["pkgver"], srcinfo["pkgrel"])
return cls(srcinfo["pkgbase"], version, None, packages) return cls(base=srcinfo["pkgbase"], version=version, remote=None, packages=packages)
@classmethod @classmethod
def from_json(cls: Type[Package], dump: Dict[str, Any]) -> Package: def from_json(cls: Type[Package], dump: Dict[str, Any]) -> Package:
@ -204,11 +208,7 @@ class Package(LazyLogging):
for key, value in dump.get("packages", {}).items() for key, value in dump.get("packages", {}).items()
} }
remote = dump.get("remote", {}) remote = dump.get("remote", {})
return cls( return cls(base=dump["base"], version=dump["version"], remote=RemoteSource.from_json(remote), packages=packages)
base=dump["base"],
version=dump["version"],
remote=RemoteSource.from_json(remote),
packages=packages)
@classmethod @classmethod
def from_official(cls: Type[Package], name: str, pacman: Pacman, use_syncdb: bool = True) -> Package: def from_official(cls: Type[Package], name: str, pacman: Pacman, use_syncdb: bool = True) -> Package:
@ -225,7 +225,11 @@ class Package(LazyLogging):
""" """
package = OfficialSyncdb.info(name, pacman=pacman) if use_syncdb else Official.info(name, pacman=pacman) package = OfficialSyncdb.info(name, pacman=pacman) if use_syncdb else Official.info(name, pacman=pacman)
remote = RemoteSource.from_source(PackageSource.Repository, package.package_base, package.repository) remote = RemoteSource.from_source(PackageSource.Repository, package.package_base, package.repository)
return cls(package.package_base, package.version, remote, {package.name: PackageDescription()}) return cls(
base=package.package_base,
version=package.version,
remote=remote,
packages={package.name: PackageDescription()})
@staticmethod @staticmethod
def dependencies(path: Path) -> Set[str]: def dependencies(path: Path) -> Set[str]:

View File

@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Type
from ahriman.core.util import filter_json from ahriman.core.util import filter_json
@dataclass @dataclass(kw_only=True)
class PackageDescription: class PackageDescription:
""" """
package specific properties package specific properties

View File

@ -17,11 +17,11 @@
# You should have received a copy of the GNU General Public License # You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>. # along with this program. If not, see <http://www.gnu.org/licenses/>.
# #
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any from typing import Any
@dataclass @dataclass(frozen=True)
class Property: class Property:
""" """
holder of object properties descriptor holder of object properties descriptor
@ -34,4 +34,4 @@ class Property:
name: str name: str
value: Any value: Any
is_required: bool = False is_required: bool = field(default=False, kw_only=True)

View File

@ -27,7 +27,7 @@ from ahriman.core.util import filter_json
from ahriman.models.package_source import PackageSource from ahriman.models.package_source import PackageSource
@dataclass @dataclass(frozen=True, kw_only=True)
class RemoteSource: class RemoteSource:
""" """
remote package source properties remote package source properties
@ -50,7 +50,7 @@ class RemoteSource:
""" """
convert source to enum type convert source to enum type
""" """
self.source = PackageSource(self.source) object.__setattr__(self, "source", PackageSource(self.source))
@property @property
def pkgbuild_dir(self) -> Path: def pkgbuild_dir(self) -> Path:

View File

@ -29,7 +29,7 @@ from typing import Set, Tuple, Type
from ahriman.core.exceptions import InvalidPath from ahriman.core.exceptions import InvalidPath
@dataclass @dataclass(frozen=True)
class RepositoryPaths: class RepositoryPaths:
""" """
repository paths holder. For the most operations with paths you want to use this object repository paths holder. For the most operations with paths you want to use this object

View File

@ -19,7 +19,7 @@
# #
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass, replace
from typing import Optional, Type from typing import Optional, Type
from passlib.pwd import genword as generate_password # type: ignore from passlib.pwd import genword as generate_password # type: ignore
from passlib.handlers.sha2_crypt import sha512_crypt # type: ignore from passlib.handlers.sha2_crypt import sha512_crypt # type: ignore
@ -27,7 +27,7 @@ from passlib.handlers.sha2_crypt import sha512_crypt # type: ignore
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
@dataclass @dataclass(frozen=True, kw_only=True)
class User: class User:
""" """
authorized web user model authorized web user model
@ -82,7 +82,7 @@ class User:
""" """
if username is None or password is None: if username is None or password is None:
return None return None
return cls(username, password, access) return cls(username=username, password=password, access=access)
@staticmethod @staticmethod
def generate_password(length: int) -> str: def generate_password(length: int) -> str:
@ -130,7 +130,7 @@ class User:
# when we do not store any password here # when we do not store any password here
return self return self
password_hash: str = self._HASHER.hash(self.password + salt) password_hash: str = self._HASHER.hash(self.password + salt)
return User(self.username, password_hash, self.access) return replace(self, password=password_hash)
def verify_access(self, required: UserAccess) -> bool: def verify_access(self, required: UserAccess) -> bool:
""" """

View File

@ -25,7 +25,7 @@ from dataclasses import dataclass
from typing import Optional, Type from typing import Optional, Type
@dataclass @dataclass(frozen=True)
class UserIdentity: class UserIdentity:
""" """
user identity used inside web service user identity used inside web service

View File

@ -38,7 +38,7 @@ def test_run(args: argparse.Namespace, configuration: Configuration, database: S
must run command must run command
""" """
args = _default_args(args) args = _default_args(args)
user = User(args.username, args.password, args.role) user = User(username=args.username, password=args.password, access=args.role)
mocker.patch("ahriman.core.database.SQLite.load", return_value=database) mocker.patch("ahriman.core.database.SQLite.load", return_value=database)
mocker.patch("ahriman.models.user.User.hash_password", return_value=user) mocker.patch("ahriman.models.user.User.hash_password", return_value=user)
get_auth_configuration_mock = mocker.patch("ahriman.application.handlers.Users.configuration_get") get_auth_configuration_mock = mocker.patch("ahriman.application.handlers.Users.configuration_get")

View File

@ -426,7 +426,7 @@ def user() -> User:
Returns: Returns:
User: user descriptor instance User: user descriptor instance
""" """
return User("user", "pa55w0rd", UserAccess.Reporter) return User(username="user", password="pa55w0rd", access=UserAccess.Reporter)
@pytest.fixture @pytest.fixture

View File

@ -44,7 +44,7 @@ def test_run(migrations: Migrations, mocker: MockerFixture) -> None:
cursor = MagicMock() cursor = MagicMock()
mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0) mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0)
mocker.patch("ahriman.core.database.migrations.Migrations.migrations", mocker.patch("ahriman.core.database.migrations.Migrations.migrations",
return_value=[Migration(0, "test", ["select 1"])]) return_value=[Migration(index=0, name="test", steps=["select 1"])])
migrations.connection.cursor.return_value = cursor migrations.connection.cursor.return_value = cursor
validate_mock = mocker.patch("ahriman.models.migration_result.MigrationResult.validate") validate_mock = mocker.patch("ahriman.models.migration_result.MigrationResult.validate")
migrate_data_mock = mocker.patch("ahriman.core.database.migrations.migrate_data") migrate_data_mock = mocker.patch("ahriman.core.database.migrations.migrate_data")
@ -58,7 +58,8 @@ def test_run(migrations: Migrations, mocker: MockerFixture) -> None:
mock.call("commit"), mock.call("commit"),
]) ])
cursor.close.assert_called_once_with() cursor.close.assert_called_once_with()
migrate_data_mock.assert_called_once_with(MigrationResult(0, 1), migrations.connection, migrations.configuration) migrate_data_mock.assert_called_once_with(
MigrationResult(old_version=0, new_version=1), migrations.connection, migrations.configuration)
def test_run_migration_exception(migrations: Migrations, mocker: MockerFixture) -> None: def test_run_migration_exception(migrations: Migrations, mocker: MockerFixture) -> None:
@ -69,7 +70,7 @@ def test_run_migration_exception(migrations: Migrations, mocker: MockerFixture)
mocker.patch("logging.Logger.info", side_effect=Exception()) mocker.patch("logging.Logger.info", side_effect=Exception())
mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0) mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0)
mocker.patch("ahriman.core.database.migrations.Migrations.migrations", mocker.patch("ahriman.core.database.migrations.Migrations.migrations",
return_value=[Migration(0, "test", ["select 1"])]) return_value=[Migration(index=0, name="test", steps=["select 1"])])
mocker.patch("ahriman.models.migration_result.MigrationResult.validate") mocker.patch("ahriman.models.migration_result.MigrationResult.validate")
migrations.connection.cursor.return_value = cursor migrations.connection.cursor.return_value = cursor
@ -90,7 +91,7 @@ def test_run_sql_exception(migrations: Migrations, mocker: MockerFixture) -> Non
cursor.execute.side_effect = Exception() cursor.execute.side_effect = Exception()
mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0) mocker.patch("ahriman.core.database.migrations.Migrations.user_version", return_value=0)
mocker.patch("ahriman.core.database.migrations.Migrations.migrations", mocker.patch("ahriman.core.database.migrations.Migrations.migrations",
return_value=[Migration(0, "test", ["select 1"])]) return_value=[Migration(index=0, name="test", steps=["select 1"])])
mocker.patch("ahriman.models.migration_result.MigrationResult.validate") mocker.patch("ahriman.models.migration_result.MigrationResult.validate")
migrations.connection.cursor.return_value = cursor migrations.connection.cursor.return_value = cursor

View File

@ -16,21 +16,21 @@ def test_user_list(database: SQLite, user: User) -> None:
must return all users must return all users
""" """
database.user_update(user) database.user_update(user)
database.user_update(User(user.password, user.username, user.access)) database.user_update(User(username=user.password, password=user.username, access=user.access))
users = database.user_list(None, None) users = database.user_list(None, None)
assert len(users) == 2 assert len(users) == 2
assert user in users assert user in users
assert User(user.password, user.username, user.access) in users assert User(username=user.password, password=user.username, access=user.access) in users
def test_user_list_filter_by_username(database: SQLite) -> None: def test_user_list_filter_by_username(database: SQLite) -> None:
""" """
must return users filtered by its id must return users filtered by its id
""" """
first = User("1", "", UserAccess.Read) first = User(username="1", password="", access=UserAccess.Read)
second = User("2", "", UserAccess.Full) second = User(username="2", password="", access=UserAccess.Full)
third = User("3", "", UserAccess.Read) third = User(username="3", password="", access=UserAccess.Read)
database.user_update(first) database.user_update(first)
database.user_update(second) database.user_update(second)
@ -45,9 +45,9 @@ def test_user_list_filter_by_access(database: SQLite) -> None:
""" """
must return users filtered by its access must return users filtered by its access
""" """
first = User("1", "", UserAccess.Read) first = User(username="1", password="", access=UserAccess.Read)
second = User("2", "", UserAccess.Full) second = User(username="2", password="", access=UserAccess.Full)
third = User("3", "", UserAccess.Read) third = User(username="3", password="", access=UserAccess.Read)
database.user_update(first) database.user_update(first)
database.user_update(second) database.user_update(second)
@ -63,9 +63,9 @@ def test_user_list_filter_by_username_access(database: SQLite) -> None:
""" """
must return users filtered by its access and username must return users filtered by its access and username
""" """
first = User("1", "", UserAccess.Read) first = User(username="1", password="", access=UserAccess.Read)
second = User("2", "", UserAccess.Full) second = User(username="2", password="", access=UserAccess.Full)
third = User("3", "", UserAccess.Read) third = User(username="3", password="", access=UserAccess.Read)
database.user_update(first) database.user_update(first)
database.user_update(second) database.user_update(second)
@ -91,7 +91,6 @@ def test_user_update(database: SQLite, user: User) -> None:
database.user_update(user) database.user_update(user)
assert database.user_get(user.username) == user assert database.user_get(user.username) == user
new_user = user.hash_password("salt") new_user = User(username=user.username, password=user.hash_password("salt").password, access=UserAccess.Full)
new_user.access = UserAccess.Full
database.user_update(new_user) database.user_update(new_user)
assert database.user_get(new_user.username) == new_user assert database.user_get(new_user.username) == new_user

View File

@ -51,9 +51,8 @@ def test_get_internal(client: Client) -> None:
""" """
must return dummy status for web service must return dummy status for web service
""" """
expected = InternalStatus(BuildStatus())
actual = client.get_internal() actual = client.get_internal()
actual.status.timestamp = expected.status.timestamp expected = InternalStatus(status=BuildStatus(timestamp=actual.status.timestamp))
assert actual == expected assert actual == expected

View File

@ -164,8 +164,9 @@ def test_get_internal(web_client: WebClient, mocker: MockerFixture) -> None:
""" """
must return web service status must return web service status
""" """
status = InternalStatus(status=BuildStatus(), architecture="x86_64")
response_obj = Response() response_obj = Response()
response_obj._content = json.dumps(InternalStatus(BuildStatus(), architecture="x86_64").view()).encode("utf8") response_obj._content = json.dumps(status.view()).encode("utf8")
response_obj.status_code = 200 response_obj.status_code = 200
requests_mock = mocker.patch("requests.Session.get", return_value=response_obj) requests_mock = mocker.patch("requests.Session.get", return_value=response_obj)

View File

@ -61,11 +61,11 @@ def test_from_pacman(pyalpm_package_ahriman: pyalpm.Package, aur_package_ahriman
""" """
model = AURPackage.from_pacman(pyalpm_package_ahriman) model = AURPackage.from_pacman(pyalpm_package_ahriman)
# some fields are missing so we are changing them # some fields are missing so we are changing them
model.id = aur_package_ahriman.id object.__setattr__(model, "id", aur_package_ahriman.id)
model.package_base_id = aur_package_ahriman.package_base_id object.__setattr__(model, "package_base_id", aur_package_ahriman.package_base_id)
model.first_submitted = aur_package_ahriman.first_submitted object.__setattr__(model, "first_submitted", aur_package_ahriman.first_submitted)
model.url_path = aur_package_ahriman.url_path object.__setattr__(model, "url_path", aur_package_ahriman.url_path)
model.maintainer = aur_package_ahriman.maintainer object.__setattr__(model, "maintainer", aur_package_ahriman.maintainer)
assert model == aur_package_ahriman assert model == aur_package_ahriman

View File

@ -1,4 +1,3 @@
import datetime
import time import time
from ahriman.models.build_status import BuildStatus, BuildStatusEnum from ahriman.models.build_status import BuildStatus, BuildStatusEnum
@ -53,43 +52,3 @@ def test_build_status_eq(build_status_failed: BuildStatus) -> None:
""" """
other = BuildStatus.from_json(build_status_failed.view()) other = BuildStatus.from_json(build_status_failed.view())
assert other == build_status_failed assert other == build_status_failed
def test_build_status_eq_self(build_status_failed: BuildStatus) -> None:
"""
must be equal itself
"""
assert build_status_failed == build_status_failed
def test_build_status_ne_by_status(build_status_failed: BuildStatus) -> None:
"""
must be not equal by status
"""
other = BuildStatus.from_json(build_status_failed.view())
other.status = BuildStatusEnum.Success
assert build_status_failed != other
def test_build_status_ne_by_timestamp(build_status_failed: BuildStatus) -> None:
"""
must be not equal by timestamp
"""
other = BuildStatus.from_json(build_status_failed.view())
other.timestamp = datetime.datetime.utcnow().timestamp()
assert build_status_failed != other
def test_build_status_ne_other(build_status_failed: BuildStatus) -> None:
"""
must be not equal to random object
"""
assert build_status_failed != object()
def test_build_status_repr(build_status_failed: BuildStatus) -> None:
"""
must return string in __repr__ function
"""
assert build_status_failed.__repr__()
assert isinstance(build_status_failed.__repr__(), str)

View File

@ -69,7 +69,7 @@ def test_chown(repository_paths: RepositoryPaths, mocker: MockerFixture) -> None
""" """
must correctly set owner for the directory must correctly set owner for the directory
""" """
repository_paths.owner = _get_owner(repository_paths.root, same=False) object.__setattr__(repository_paths, "owner", _get_owner(repository_paths.root, same=False))
mocker.patch.object(RepositoryPaths, "root_owner", (42, 42)) mocker.patch.object(RepositoryPaths, "root_owner", (42, 42))
chown_mock = mocker.patch("os.chown") chown_mock = mocker.patch("os.chown")
@ -82,7 +82,7 @@ def test_chown_parent(repository_paths: RepositoryPaths, mocker: MockerFixture)
""" """
must correctly set owner for the directory including parents must correctly set owner for the directory including parents
""" """
repository_paths.owner = _get_owner(repository_paths.root, same=False) object.__setattr__(repository_paths, "owner", _get_owner(repository_paths.root, same=False))
mocker.patch.object(RepositoryPaths, "root_owner", (42, 42)) mocker.patch.object(RepositoryPaths, "root_owner", (42, 42))
chown_mock = mocker.patch("os.chown") chown_mock = mocker.patch("os.chown")
@ -98,7 +98,7 @@ def test_chown_skip(repository_paths: RepositoryPaths, mocker: MockerFixture) ->
""" """
must skip ownership set in case if it is same as root must skip ownership set in case if it is same as root
""" """
repository_paths.owner = _get_owner(repository_paths.root, same=True) object.__setattr__(repository_paths, "owner", _get_owner(repository_paths.root, same=True))
mocker.patch.object(RepositoryPaths, "root_owner", (42, 42)) mocker.patch.object(RepositoryPaths, "root_owner", (42, 42))
chown_mock = mocker.patch("os.chown") chown_mock = mocker.patch("os.chown")

View File

@ -1,3 +1,5 @@
from dataclasses import replace
from ahriman.models.user import User from ahriman.models.user import User
from ahriman.models.user_access import UserAccess from ahriman.models.user_access import UserAccess
@ -6,10 +8,10 @@ def test_from_option(user: User) -> None:
""" """
must generate user from options must generate user from options
""" """
user.access = UserAccess.Read user = replace(user, access=UserAccess.Read)
assert User.from_option(user.username, user.password) == user assert User.from_option(user.username, user.password) == user
# default is read access # default is read access
user.access = UserAccess.Full user = replace(user, access=UserAccess.Full)
assert User.from_option(user.username, user.password) != user assert User.from_option(user.username, user.password) != user
assert User.from_option(user.username, user.password, user.access) == user assert User.from_option(user.username, user.password, user.access) == user
@ -40,7 +42,7 @@ def test_check_credentials_empty_hash(user: User) -> None:
""" """
current_password = user.password current_password = user.password
assert not user.check_credentials(current_password, "salt") assert not user.check_credentials(current_password, "salt")
user.password = "" user = replace(user, password="")
assert not user.check_credentials(current_password, "salt") assert not user.check_credentials(current_password, "salt")
@ -48,9 +50,9 @@ def test_hash_password_empty_hash(user: User) -> None:
""" """
must return empty string after hash in case if password not set must return empty string after hash in case if password not set
""" """
user.password = "" user = replace(user, password="")
assert user.hash_password("salt") == user assert user.hash_password("salt") == user
user.password = None user = replace(user, password=None)
assert user.hash_password("salt") == user assert user.hash_password("salt") == user
@ -71,7 +73,7 @@ def test_verify_access_read(user: User) -> None:
""" """
user with read access must be able to only request read user with read access must be able to only request read
""" """
user.access = UserAccess.Read user = replace(user, access=UserAccess.Read)
assert user.verify_access(UserAccess.Read) assert user.verify_access(UserAccess.Read)
assert not user.verify_access(UserAccess.Full) assert not user.verify_access(UserAccess.Full)
@ -80,7 +82,7 @@ def test_verify_access_write(user: User) -> None:
""" """
user with write access must be able to do anything user with write access must be able to do anything
""" """
user.access = UserAccess.Full user = replace(user, access=UserAccess.Full)
assert user.verify_access(UserAccess.Read) assert user.verify_access(UserAccess.Read)
assert user.verify_access(UserAccess.Full) assert user.verify_access(UserAccess.Full)

View File

@ -13,7 +13,7 @@ def test_from_identity_expired(user_identity: UserIdentity) -> None:
""" """
must construct None from expired identity must construct None from expired identity
""" """
user_identity.expire_at -= 60 user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None assert UserIdentity.from_identity(f"{user_identity.username} {user_identity.expire_at}") is None
@ -53,7 +53,7 @@ def test_is_expired(user_identity: UserIdentity) -> None:
""" """
assert not user_identity.is_expired() assert not user_identity.is_expired()
user_identity.expire_at -= 60 user_identity = UserIdentity(username=user_identity.username, expire_at=user_identity.expire_at - 60)
assert user_identity.is_expired() assert user_identity.is_expired()