diff --git a/src/ahriman/core/database/operations/operations.py b/src/ahriman/core/database/operations/operations.py index ff6882a5..f92a5e5b 100644 --- a/src/ahriman/core/database/operations/operations.py +++ b/src/ahriman/core/database/operations/operations.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . # +import contextlib import sqlite3 from collections.abc import Callable @@ -87,10 +88,12 @@ class Operations(LazyLogging): Returns: T: result of the ``query`` call """ - with sqlite3.connect(self.path, detect_types=sqlite3.PARSE_DECLTYPES) as connection: + with contextlib.closing(sqlite3.connect(self.path, detect_types=sqlite3.PARSE_DECLTYPES)) as connection: connection.set_trace_callback(self.logger.debug) connection.row_factory = self.factory + result = query(connection) if commit: connection.commit() + return result diff --git a/tests/ahriman/core/database/operations/test_operations.py b/tests/ahriman/core/database/operations/test_operations.py index 8143ccd6..6d35404b 100644 --- a/tests/ahriman/core/database/operations/test_operations.py +++ b/tests/ahriman/core/database/operations/test_operations.py @@ -1,3 +1,4 @@ +import pytest import sqlite3 from pytest_mock import MockerFixture @@ -24,15 +25,29 @@ def test_factory(database: SQLite) -> None: def test_with_connection(database: SQLite, mocker: MockerFixture) -> None: """ - must run query inside connection + must run query inside connection and close it at the end """ connection_mock = MagicMock() connect_mock = mocker.patch("sqlite3.connect", return_value=connection_mock) database.with_connection(lambda conn: conn.execute("select 1")) connect_mock.assert_called_once_with(database.path, detect_types=sqlite3.PARSE_DECLTYPES) - connection_mock.__enter__().set_trace_callback.assert_called_once_with(database.logger.debug) - connection_mock.__enter__().commit.assert_not_called() + connection_mock.set_trace_callback.assert_called_once_with(database.logger.debug) + connection_mock.commit.assert_not_called() + connection_mock.close.assert_called_once_with() + + +def test_with_connection_close(database: SQLite, mocker: MockerFixture) -> None: + """ + must close connection on errors + """ + connection_mock = MagicMock() + connection_mock.commit.side_effect = Exception + mocker.patch("sqlite3.connect", return_value=connection_mock) + + with pytest.raises(Exception): + database.with_connection(lambda conn: conn.execute("select 1"), commit=True) + connection_mock.close.assert_called_once_with() def test_with_connection_with_commit(database: SQLite, mocker: MockerFixture) -> None: @@ -44,4 +59,4 @@ def test_with_connection_with_commit(database: SQLite, mocker: MockerFixture) -> mocker.patch("sqlite3.connect", return_value=connection_mock) database.with_connection(lambda conn: conn.execute("select 1"), commit=True) - connection_mock.__enter__().commit.assert_called_once_with() + connection_mock.commit.assert_called_once_with()