feat: allow to pass repository identifier to all database methods

This commit is contained in:
Evgenii Alekseev 2023-11-04 16:36:14 +02:00
parent 7b667c8157
commit b116e6fa07
3 changed files with 40 additions and 25 deletions

View File

@ -21,6 +21,7 @@ from sqlite3 import Connection
from ahriman.core.database.operations import Operations from ahriman.core.database.operations import Operations
from ahriman.models.package import Package from ahriman.models.package import Package
from ahriman.models.repository_id import RepositoryId
class BuildOperations(Operations): class BuildOperations(Operations):
@ -28,13 +29,16 @@ class BuildOperations(Operations):
operations for build queue functions operations for build queue functions
""" """
def build_queue_clear(self, package_base: str | None) -> None: def build_queue_clear(self, package_base: str | None, repository_id: RepositoryId | None = None) -> None:
""" """
remove packages from build queue remove packages from build queue
Args: Args:
package_base(str | None): optional filter by package base package_base(str | None): optional filter by package base
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
""" """
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> None: def run(connection: Connection) -> None:
connection.execute( connection.execute(
""" """
@ -43,36 +47,44 @@ class BuildOperations(Operations):
""", """,
{ {
"package_base": package_base, "package_base": package_base,
"repository": self._repository_id.id, "repository": repository_id.id,
}) })
return self.with_connection(run, commit=True) return self.with_connection(run, commit=True)
def build_queue_get(self) -> list[Package]: def build_queue_get(self, repository_id: RepositoryId | None = None) -> list[Package]:
""" """
retrieve packages from build queue retrieve packages from build queue
Args:
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
Return: Return:
list[Package]: list of packages to be built list[Package]: list of packages to be built
""" """
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> list[Package]: def run(connection: Connection) -> list[Package]:
return [ return [
Package.from_json(row["properties"]) Package.from_json(row["properties"])
for row in connection.execute( for row in connection.execute(
"""select properties from build_queue where repository = :repository""", """select properties from build_queue where repository = :repository""",
{"repository": self._repository_id.id} {"repository": repository_id.id}
) )
] ]
return self.with_connection(run) return self.with_connection(run)
def build_queue_insert(self, package: Package) -> None: def build_queue_insert(self, package: Package, repository_id: RepositoryId | None = None) -> None:
""" """
insert packages to build queue insert packages to build queue
Args: Args:
package(Package): package to be inserted package(Package): package to be inserted
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
""" """
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> None: def run(connection: Connection) -> None:
connection.execute( connection.execute(
""" """
@ -86,7 +98,7 @@ class BuildOperations(Operations):
{ {
"package_base": package.base, "package_base": package.base,
"properties": package.view(), "properties": package.view(),
"repository": self._repository_id.id, "repository": repository_id.id,
}) })
return self.with_connection(run, commit=True) return self.with_connection(run, commit=True)

View File

@ -302,26 +302,35 @@ class PackageOperations(Operations):
return self.with_connection(lambda connection: list(run(connection))) return self.with_connection(lambda connection: list(run(connection)))
def remote_update(self, package: Package) -> None: def remote_update(self, package: Package, repository_id: RepositoryId | None = None) -> None:
""" """
update package remote source update package remote source
Args: Args:
package(Package): package properties package(Package): package properties
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
""" """
return self.with_connection( repository_id = repository_id or self._repository_id
lambda connection: self._package_update_insert_base(connection, package, self._repository_id),
commit=True)
def remotes_get(self) -> dict[str, RemoteSource]: def run(connection: Connection) -> None:
self._package_update_insert_base(connection, package, repository_id)
return self.with_connection(run, commit=True)
def remotes_get(self, repository_id: RepositoryId | None = None) -> dict[str, RemoteSource]:
""" """
get packages remotes based on current settings get packages remotes based on current settings
Args:
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
Returns: Returns:
dict[str, RemoteSource]: map of package base to its remote sources dict[str, RemoteSource]: map of package base to its remote sources
""" """
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> dict[str, Package]: def run(connection: Connection) -> dict[str, Package]:
return self._packages_get_select_package_bases(connection, self._repository_id) return self._packages_get_select_package_bases(connection, repository_id)
return { return {
package_base: package.remote package_base: package.remote

View File

@ -19,12 +19,10 @@ def test_build_queue_insert_clear_multi(database: SQLite, package_ahriman: Packa
must clear all packages from queue for specific repository must clear all packages from queue for specific repository
""" """
database.build_queue_insert(package_ahriman) database.build_queue_insert(package_ahriman)
database._repository_id = RepositoryId("i686", database._repository_id.name) database.build_queue_insert(package_ahriman, RepositoryId("i686", database._repository_id.name))
database.build_queue_insert(package_ahriman)
database.build_queue_clear(None) database.build_queue_clear(None)
database._repository_id = RepositoryId("x86_64", database._repository_id.name) assert database.build_queue_get(RepositoryId("i686", database._repository_id.name)) == [package_ahriman]
assert database.build_queue_get() == [package_ahriman]
def test_build_queue_insert_clear_specific(database: SQLite, package_ahriman: Package, def test_build_queue_insert_clear_specific(database: SQLite, package_ahriman: Package,
@ -68,19 +66,15 @@ def test_build_queue_insert_multi(database: SQLite, package_ahriman: Package) ->
assert database.build_queue_get() == [package_ahriman] assert database.build_queue_get() == [package_ahriman]
package_ahriman.version = "2" package_ahriman.version = "2"
database._repository_id = RepositoryId("i686", database._repository_id.name) database.build_queue_insert(package_ahriman, RepositoryId("i686", database._repository_id.name))
database.build_queue_insert(package_ahriman) assert database.build_queue_get(RepositoryId("i686", database._repository_id.name)) == [package_ahriman]
assert database.build_queue_get() == [package_ahriman]
package_ahriman.version = "1" package_ahriman.version = "1"
database._repository_id = RepositoryId("x86_64", database._repository_id.name) assert database.build_queue_get(RepositoryId("x86_64", database._repository_id.name)) == [package_ahriman]
assert database.build_queue_get() == [package_ahriman]
package_ahriman.version = "3" package_ahriman.version = "3"
database._repository_id = RepositoryId(database._repository_id.architecture, "repo") database.build_queue_insert(package_ahriman, RepositoryId(database._repository_id.architecture, "repo"))
database.build_queue_insert(package_ahriman) assert database.build_queue_get(RepositoryId(database._repository_id.architecture, "repo")) == [package_ahriman]
assert database.build_queue_get() == [package_ahriman]
package_ahriman.version = "1" package_ahriman.version = "1"
database._repository_id = RepositoryId(database._repository_id.architecture, "aur-clone")
assert database.build_queue_get() == [package_ahriman] assert database.build_queue_get() == [package_ahriman]