feat: allow to pass repository identifier to all database methods

This commit is contained in:
Evgenii Alekseev 2023-11-04 16:36:14 +02:00
parent 7b667c8157
commit b116e6fa07
3 changed files with 40 additions and 25 deletions

View File

@ -21,6 +21,7 @@ from sqlite3 import Connection
from ahriman.core.database.operations import Operations
from ahriman.models.package import Package
from ahriman.models.repository_id import RepositoryId
class BuildOperations(Operations):
@ -28,13 +29,16 @@ class BuildOperations(Operations):
operations for build queue functions
"""
def build_queue_clear(self, package_base: str | None) -> None:
def build_queue_clear(self, package_base: str | None, repository_id: RepositoryId | None = None) -> None:
"""
remove packages from build queue
Args:
package_base(str | None): optional filter by package base
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
"""
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> None:
connection.execute(
"""
@ -43,36 +47,44 @@ class BuildOperations(Operations):
""",
{
"package_base": package_base,
"repository": self._repository_id.id,
"repository": repository_id.id,
})
return self.with_connection(run, commit=True)
def build_queue_get(self) -> list[Package]:
def build_queue_get(self, repository_id: RepositoryId | None = None) -> list[Package]:
"""
retrieve packages from build queue
Args:
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
Return:
list[Package]: list of packages to be built
"""
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> list[Package]:
return [
Package.from_json(row["properties"])
for row in connection.execute(
"""select properties from build_queue where repository = :repository""",
{"repository": self._repository_id.id}
{"repository": repository_id.id}
)
]
return self.with_connection(run)
def build_queue_insert(self, package: Package) -> None:
def build_queue_insert(self, package: Package, repository_id: RepositoryId | None = None) -> None:
"""
insert packages to build queue
Args:
package(Package): package to be inserted
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
"""
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> None:
connection.execute(
"""
@ -86,7 +98,7 @@ class BuildOperations(Operations):
{
"package_base": package.base,
"properties": package.view(),
"repository": self._repository_id.id,
"repository": repository_id.id,
})
return self.with_connection(run, commit=True)

View File

@ -302,26 +302,35 @@ class PackageOperations(Operations):
return self.with_connection(lambda connection: list(run(connection)))
def remote_update(self, package: Package) -> None:
def remote_update(self, package: Package, repository_id: RepositoryId | None = None) -> None:
"""
update package remote source
Args:
package(Package): package properties
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
"""
return self.with_connection(
lambda connection: self._package_update_insert_base(connection, package, self._repository_id),
commit=True)
repository_id = repository_id or self._repository_id
def remotes_get(self) -> dict[str, RemoteSource]:
def run(connection: Connection) -> None:
self._package_update_insert_base(connection, package, repository_id)
return self.with_connection(run, commit=True)
def remotes_get(self, repository_id: RepositoryId | None = None) -> dict[str, RemoteSource]:
"""
get packages remotes based on current settings
Args:
repository_id(RepositoryId, optional): repository unique identifier override (Default value = None)
Returns:
dict[str, RemoteSource]: map of package base to its remote sources
"""
repository_id = repository_id or self._repository_id
def run(connection: Connection) -> dict[str, Package]:
return self._packages_get_select_package_bases(connection, self._repository_id)
return self._packages_get_select_package_bases(connection, repository_id)
return {
package_base: package.remote

View File

@ -19,12 +19,10 @@ def test_build_queue_insert_clear_multi(database: SQLite, package_ahriman: Packa
must clear all packages from queue for specific repository
"""
database.build_queue_insert(package_ahriman)
database._repository_id = RepositoryId("i686", database._repository_id.name)
database.build_queue_insert(package_ahriman)
database.build_queue_insert(package_ahriman, RepositoryId("i686", database._repository_id.name))
database.build_queue_clear(None)
database._repository_id = RepositoryId("x86_64", database._repository_id.name)
assert database.build_queue_get() == [package_ahriman]
assert database.build_queue_get(RepositoryId("i686", database._repository_id.name)) == [package_ahriman]
def test_build_queue_insert_clear_specific(database: SQLite, package_ahriman: Package,
@ -68,19 +66,15 @@ def test_build_queue_insert_multi(database: SQLite, package_ahriman: Package) ->
assert database.build_queue_get() == [package_ahriman]
package_ahriman.version = "2"
database._repository_id = RepositoryId("i686", database._repository_id.name)
database.build_queue_insert(package_ahriman)
assert database.build_queue_get() == [package_ahriman]
database.build_queue_insert(package_ahriman, RepositoryId("i686", database._repository_id.name))
assert database.build_queue_get(RepositoryId("i686", database._repository_id.name)) == [package_ahriman]
package_ahriman.version = "1"
database._repository_id = RepositoryId("x86_64", database._repository_id.name)
assert database.build_queue_get() == [package_ahriman]
assert database.build_queue_get(RepositoryId("x86_64", database._repository_id.name)) == [package_ahriman]
package_ahriman.version = "3"
database._repository_id = RepositoryId(database._repository_id.architecture, "repo")
database.build_queue_insert(package_ahriman)
assert database.build_queue_get() == [package_ahriman]
database.build_queue_insert(package_ahriman, RepositoryId(database._repository_id.architecture, "repo"))
assert database.build_queue_get(RepositoryId(database._repository_id.architecture, "repo")) == [package_ahriman]
package_ahriman.version = "1"
database._repository_id = RepositoryId(database._repository_id.architecture, "aur-clone")
assert database.build_queue_get() == [package_ahriman]