From 89008e5350d5d61bed4b2091452b3dbf0aebcf6a Mon Sep 17 00:00:00 2001 From: Evgenii Alekseev Date: Thu, 19 Feb 2026 02:01:56 +0200 Subject: [PATCH] fix: use context manager for selector and smtp session --- src/ahriman/core/report/email.py | 25 ++++++++++----- src/ahriman/core/utils.py | 42 +++++++++++++------------ tests/ahriman/core/report/test_email.py | 30 ++++++++++++++++-- 3 files changed, 67 insertions(+), 30 deletions(-) diff --git a/src/ahriman/core/report/email.py b/src/ahriman/core/report/email.py index 253c1a62..3b2c731d 100644 --- a/src/ahriman/core/report/email.py +++ b/src/ahriman/core/report/email.py @@ -74,6 +74,18 @@ class Email(Report, JinjaTemplate): self.ssl = SmtpSSLSettings.from_option(configuration.get(section, "ssl", fallback="disabled")) self.user = configuration.get(section, "user", fallback=None) + @property + def _smtp_session(self) -> type[smtplib.SMTP]: + """ + build SMTP session based on configuration settings + + Returns: + type[smtplib.SMTP]: SMTP or SMTP_SSL session depending on whether SSL is enabled or not + """ + if self.ssl == SmtpSSLSettings.SSL: + return smtplib.SMTP_SSL + return smtplib.SMTP + def _send(self, text: str, attachment: dict[str, str]) -> None: """ send email callback @@ -93,16 +105,13 @@ class Email(Report, JinjaTemplate): attach.add_header("Content-Disposition", "attachment", filename=filename) message.attach(attach) - if self.ssl != SmtpSSLSettings.SSL: - session = smtplib.SMTP(self.host, self.port) + with self._smtp_session(self.host, self.port) as session: if self.ssl == SmtpSSLSettings.STARTTLS: session.starttls() - else: - session = smtplib.SMTP_SSL(self.host, self.port) - if self.user is not None and self.password is not None: - session.login(self.user, self.password) - session.sendmail(self.sender, self.receivers, message.as_string()) - session.quit() + + if self.user is not None and self.password is not None: + session.login(self.user, self.password) + session.sendmail(self.sender, self.receivers, message.as_string()) def generate(self, packages: list[Package], result: Result) -> None: """ diff --git a/src/ahriman/core/utils.py b/src/ahriman/core/utils.py index c9e5df3b..6ac48ce8 100644 --- a/src/ahriman/core/utils.py +++ b/src/ahriman/core/utils.py @@ -164,6 +164,11 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st if key in ("PATH",) # whitelisted variables only } | environment + result: dict[str, list[str]] = { + "stdout": [], + "stderr": [], + } + with subprocess.Popen(args, cwd=cwd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, user=user, env=full_environment, text=True, encoding="utf8", errors="backslashreplace", bufsize=1) as process: @@ -172,30 +177,27 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st input_channel.write(input_data) input_channel.close() - selector = selectors.DefaultSelector() - selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout") - selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr") + with selectors.DefaultSelector() as selector: + selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout") + selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr") - result: dict[str, list[str]] = { - "stdout": [], - "stderr": [], - } - while selector.get_map(): # while there are unread selectors, keep reading - for key_data, output in poll(selector): - result[key_data].append(output) - - stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any - stderr = "\n".join(result["stderr"]).rstrip("\n") + while selector.get_map(): # while there are unread selectors, keep reading + for key_data, output in poll(selector): + result[key_data].append(output) status_code = process.wait() - if status_code != 0: - if isinstance(exception, Exception): - raise exception - if callable(exception): - raise exception(status_code, list(args), stdout, stderr) - raise CalledProcessError(status_code, list(args), stderr) - return stdout + stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any + stderr = "\n".join(result["stderr"]).rstrip("\n") + + if status_code != 0: + if isinstance(exception, Exception): + raise exception + if callable(exception): + raise exception(status_code, list(args), stdout, stderr) + raise CalledProcessError(status_code, list(args), stderr) + + return stdout def check_user(root: Path, *, unsafe: bool) -> None: diff --git a/tests/ahriman/core/report/test_email.py b/tests/ahriman/core/report/test_email.py index 52e2b7e1..0a9bc4e7 100644 --- a/tests/ahriman/core/report/test_email.py +++ b/tests/ahriman/core/report/test_email.py @@ -1,3 +1,5 @@ +import smtplib + import pytest from pytest_mock import MockerFixture @@ -6,6 +8,7 @@ from ahriman.core.configuration import Configuration from ahriman.core.report.email import Email from ahriman.models.package import Package from ahriman.models.result import Result +from ahriman.models.smtp_ssl_settings import SmtpSSLSettings def test_template(configuration: Configuration) -> None: @@ -37,17 +40,36 @@ def test_template_full(configuration: Configuration) -> None: assert Email(repository_id, configuration, "email").template_full == root.parent / template +def test_smtp_session(email: Email) -> None: + """ + must build normal SMTP session if SSL is disabled + """ + email.ssl = SmtpSSLSettings.Disabled + assert email._smtp_session == smtplib.SMTP + + email.ssl = SmtpSSLSettings.STARTTLS + assert email._smtp_session == smtplib.SMTP + + +def test_smtp_session_ssl(email: Email) -> None: + """ + must build SMTP_SSL session if SSL is enabled + """ + email.ssl = SmtpSSLSettings.SSL + assert email._smtp_session == smtplib.SMTP_SSL + + def test_send(email: Email, mocker: MockerFixture) -> None: """ must send an email with attachment """ smtp_mock = mocker.patch("smtplib.SMTP") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value email._send("a text", {"attachment.html": "an attachment"}) smtp_mock.return_value.starttls.assert_not_called() smtp_mock.return_value.login.assert_not_called() smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int)) - smtp_mock.return_value.quit.assert_called_once_with() def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None: @@ -57,6 +79,7 @@ def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None: configuration.set_option("email", "user", "username") configuration.set_option("email", "password", "password") smtp_mock = mocker.patch("smtplib.SMTP") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value _, repository_id = configuration.check_loaded() email = Email(repository_id, configuration, "email") @@ -70,6 +93,7 @@ def test_send_auth_no_password(configuration: Configuration, mocker: MockerFixtu """ configuration.set_option("email", "user", "username") smtp_mock = mocker.patch("smtplib.SMTP") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value _, repository_id = configuration.check_loaded() email = Email(repository_id, configuration, "email") @@ -83,6 +107,7 @@ def test_send_auth_no_user(configuration: Configuration, mocker: MockerFixture) """ configuration.set_option("email", "password", "password") smtp_mock = mocker.patch("smtplib.SMTP") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value _, repository_id = configuration.check_loaded() email = Email(repository_id, configuration, "email") @@ -96,6 +121,7 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No """ configuration.set_option("email", "ssl", "ssl") smtp_mock = mocker.patch("smtplib.SMTP_SSL") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value _, repository_id = configuration.check_loaded() email = Email(repository_id, configuration, "email") @@ -103,7 +129,6 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No smtp_mock.return_value.starttls.assert_not_called() smtp_mock.return_value.login.assert_not_called() smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int)) - smtp_mock.return_value.quit.assert_called_once_with() def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> None: @@ -112,6 +137,7 @@ def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> N """ configuration.set_option("email", "ssl", "starttls") smtp_mock = mocker.patch("smtplib.SMTP") + smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value _, repository_id = configuration.check_loaded() email = Email(repository_id, configuration, "email")