mirror of
https://github.com/arcan1s/ahriman.git
synced 2026-02-24 21:59:48 +00:00
fix: use context manager for selector and smtp session
This commit is contained in:
@@ -74,6 +74,18 @@ class Email(Report, JinjaTemplate):
|
|||||||
self.ssl = SmtpSSLSettings.from_option(configuration.get(section, "ssl", fallback="disabled"))
|
self.ssl = SmtpSSLSettings.from_option(configuration.get(section, "ssl", fallback="disabled"))
|
||||||
self.user = configuration.get(section, "user", fallback=None)
|
self.user = configuration.get(section, "user", fallback=None)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _smtp_session(self) -> type[smtplib.SMTP]:
|
||||||
|
"""
|
||||||
|
build SMTP session based on configuration settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
type[smtplib.SMTP]: SMTP or SMTP_SSL session depending on whether SSL is enabled or not
|
||||||
|
"""
|
||||||
|
if self.ssl == SmtpSSLSettings.SSL:
|
||||||
|
return smtplib.SMTP_SSL
|
||||||
|
return smtplib.SMTP
|
||||||
|
|
||||||
def _send(self, text: str, attachment: dict[str, str]) -> None:
|
def _send(self, text: str, attachment: dict[str, str]) -> None:
|
||||||
"""
|
"""
|
||||||
send email callback
|
send email callback
|
||||||
@@ -93,16 +105,13 @@ class Email(Report, JinjaTemplate):
|
|||||||
attach.add_header("Content-Disposition", "attachment", filename=filename)
|
attach.add_header("Content-Disposition", "attachment", filename=filename)
|
||||||
message.attach(attach)
|
message.attach(attach)
|
||||||
|
|
||||||
if self.ssl != SmtpSSLSettings.SSL:
|
with self._smtp_session(self.host, self.port) as session:
|
||||||
session = smtplib.SMTP(self.host, self.port)
|
|
||||||
if self.ssl == SmtpSSLSettings.STARTTLS:
|
if self.ssl == SmtpSSLSettings.STARTTLS:
|
||||||
session.starttls()
|
session.starttls()
|
||||||
else:
|
|
||||||
session = smtplib.SMTP_SSL(self.host, self.port)
|
if self.user is not None and self.password is not None:
|
||||||
if self.user is not None and self.password is not None:
|
session.login(self.user, self.password)
|
||||||
session.login(self.user, self.password)
|
session.sendmail(self.sender, self.receivers, message.as_string())
|
||||||
session.sendmail(self.sender, self.receivers, message.as_string())
|
|
||||||
session.quit()
|
|
||||||
|
|
||||||
def generate(self, packages: list[Package], result: Result) -> None:
|
def generate(self, packages: list[Package], result: Result) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -164,6 +164,11 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st
|
|||||||
if key in ("PATH",) # whitelisted variables only
|
if key in ("PATH",) # whitelisted variables only
|
||||||
} | environment
|
} | environment
|
||||||
|
|
||||||
|
result: dict[str, list[str]] = {
|
||||||
|
"stdout": [],
|
||||||
|
"stderr": [],
|
||||||
|
}
|
||||||
|
|
||||||
with subprocess.Popen(args, cwd=cwd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
with subprocess.Popen(args, cwd=cwd, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
|
||||||
user=user, env=full_environment, text=True, encoding="utf8", errors="backslashreplace",
|
user=user, env=full_environment, text=True, encoding="utf8", errors="backslashreplace",
|
||||||
bufsize=1) as process:
|
bufsize=1) as process:
|
||||||
@@ -172,30 +177,27 @@ def check_output(*args: str, exception: Exception | Callable[[int, list[str], st
|
|||||||
input_channel.write(input_data)
|
input_channel.write(input_data)
|
||||||
input_channel.close()
|
input_channel.close()
|
||||||
|
|
||||||
selector = selectors.DefaultSelector()
|
with selectors.DefaultSelector() as selector:
|
||||||
selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout")
|
selector.register(get_io(process, "stdout"), selectors.EVENT_READ, data="stdout")
|
||||||
selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr")
|
selector.register(get_io(process, "stderr"), selectors.EVENT_READ, data="stderr")
|
||||||
|
|
||||||
result: dict[str, list[str]] = {
|
while selector.get_map(): # while there are unread selectors, keep reading
|
||||||
"stdout": [],
|
for key_data, output in poll(selector):
|
||||||
"stderr": [],
|
result[key_data].append(output)
|
||||||
}
|
|
||||||
while selector.get_map(): # while there are unread selectors, keep reading
|
|
||||||
for key_data, output in poll(selector):
|
|
||||||
result[key_data].append(output)
|
|
||||||
|
|
||||||
stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any
|
|
||||||
stderr = "\n".join(result["stderr"]).rstrip("\n")
|
|
||||||
|
|
||||||
status_code = process.wait()
|
status_code = process.wait()
|
||||||
if status_code != 0:
|
|
||||||
if isinstance(exception, Exception):
|
|
||||||
raise exception
|
|
||||||
if callable(exception):
|
|
||||||
raise exception(status_code, list(args), stdout, stderr)
|
|
||||||
raise CalledProcessError(status_code, list(args), stderr)
|
|
||||||
|
|
||||||
return stdout
|
stdout = "\n".join(result["stdout"]).rstrip("\n") # remove newline at the end of any
|
||||||
|
stderr = "\n".join(result["stderr"]).rstrip("\n")
|
||||||
|
|
||||||
|
if status_code != 0:
|
||||||
|
if isinstance(exception, Exception):
|
||||||
|
raise exception
|
||||||
|
if callable(exception):
|
||||||
|
raise exception(status_code, list(args), stdout, stderr)
|
||||||
|
raise CalledProcessError(status_code, list(args), stderr)
|
||||||
|
|
||||||
|
return stdout
|
||||||
|
|
||||||
|
|
||||||
def check_user(root: Path, *, unsafe: bool) -> None:
|
def check_user(root: Path, *, unsafe: bool) -> None:
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import smtplib
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
@@ -6,6 +8,7 @@ from ahriman.core.configuration import Configuration
|
|||||||
from ahriman.core.report.email import Email
|
from ahriman.core.report.email import Email
|
||||||
from ahriman.models.package import Package
|
from ahriman.models.package import Package
|
||||||
from ahriman.models.result import Result
|
from ahriman.models.result import Result
|
||||||
|
from ahriman.models.smtp_ssl_settings import SmtpSSLSettings
|
||||||
|
|
||||||
|
|
||||||
def test_template(configuration: Configuration) -> None:
|
def test_template(configuration: Configuration) -> None:
|
||||||
@@ -37,17 +40,36 @@ def test_template_full(configuration: Configuration) -> None:
|
|||||||
assert Email(repository_id, configuration, "email").template_full == root.parent / template
|
assert Email(repository_id, configuration, "email").template_full == root.parent / template
|
||||||
|
|
||||||
|
|
||||||
|
def test_smtp_session(email: Email) -> None:
|
||||||
|
"""
|
||||||
|
must build normal SMTP session if SSL is disabled
|
||||||
|
"""
|
||||||
|
email.ssl = SmtpSSLSettings.Disabled
|
||||||
|
assert email._smtp_session == smtplib.SMTP
|
||||||
|
|
||||||
|
email.ssl = SmtpSSLSettings.STARTTLS
|
||||||
|
assert email._smtp_session == smtplib.SMTP
|
||||||
|
|
||||||
|
|
||||||
|
def test_smtp_session_ssl(email: Email) -> None:
|
||||||
|
"""
|
||||||
|
must build SMTP_SSL session if SSL is enabled
|
||||||
|
"""
|
||||||
|
email.ssl = SmtpSSLSettings.SSL
|
||||||
|
assert email._smtp_session == smtplib.SMTP_SSL
|
||||||
|
|
||||||
|
|
||||||
def test_send(email: Email, mocker: MockerFixture) -> None:
|
def test_send(email: Email, mocker: MockerFixture) -> None:
|
||||||
"""
|
"""
|
||||||
must send an email with attachment
|
must send an email with attachment
|
||||||
"""
|
"""
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP")
|
smtp_mock = mocker.patch("smtplib.SMTP")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
|
|
||||||
email._send("a text", {"attachment.html": "an attachment"})
|
email._send("a text", {"attachment.html": "an attachment"})
|
||||||
smtp_mock.return_value.starttls.assert_not_called()
|
smtp_mock.return_value.starttls.assert_not_called()
|
||||||
smtp_mock.return_value.login.assert_not_called()
|
smtp_mock.return_value.login.assert_not_called()
|
||||||
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
|
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
|
||||||
smtp_mock.return_value.quit.assert_called_once_with()
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None:
|
def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None:
|
||||||
@@ -57,6 +79,7 @@ def test_send_auth(configuration: Configuration, mocker: MockerFixture) -> None:
|
|||||||
configuration.set_option("email", "user", "username")
|
configuration.set_option("email", "user", "username")
|
||||||
configuration.set_option("email", "password", "password")
|
configuration.set_option("email", "password", "password")
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP")
|
smtp_mock = mocker.patch("smtplib.SMTP")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
_, repository_id = configuration.check_loaded()
|
_, repository_id = configuration.check_loaded()
|
||||||
|
|
||||||
email = Email(repository_id, configuration, "email")
|
email = Email(repository_id, configuration, "email")
|
||||||
@@ -70,6 +93,7 @@ def test_send_auth_no_password(configuration: Configuration, mocker: MockerFixtu
|
|||||||
"""
|
"""
|
||||||
configuration.set_option("email", "user", "username")
|
configuration.set_option("email", "user", "username")
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP")
|
smtp_mock = mocker.patch("smtplib.SMTP")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
_, repository_id = configuration.check_loaded()
|
_, repository_id = configuration.check_loaded()
|
||||||
|
|
||||||
email = Email(repository_id, configuration, "email")
|
email = Email(repository_id, configuration, "email")
|
||||||
@@ -83,6 +107,7 @@ def test_send_auth_no_user(configuration: Configuration, mocker: MockerFixture)
|
|||||||
"""
|
"""
|
||||||
configuration.set_option("email", "password", "password")
|
configuration.set_option("email", "password", "password")
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP")
|
smtp_mock = mocker.patch("smtplib.SMTP")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
_, repository_id = configuration.check_loaded()
|
_, repository_id = configuration.check_loaded()
|
||||||
|
|
||||||
email = Email(repository_id, configuration, "email")
|
email = Email(repository_id, configuration, "email")
|
||||||
@@ -96,6 +121,7 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No
|
|||||||
"""
|
"""
|
||||||
configuration.set_option("email", "ssl", "ssl")
|
configuration.set_option("email", "ssl", "ssl")
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP_SSL")
|
smtp_mock = mocker.patch("smtplib.SMTP_SSL")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
_, repository_id = configuration.check_loaded()
|
_, repository_id = configuration.check_loaded()
|
||||||
|
|
||||||
email = Email(repository_id, configuration, "email")
|
email = Email(repository_id, configuration, "email")
|
||||||
@@ -103,7 +129,6 @@ def test_send_ssl_tls(configuration: Configuration, mocker: MockerFixture) -> No
|
|||||||
smtp_mock.return_value.starttls.assert_not_called()
|
smtp_mock.return_value.starttls.assert_not_called()
|
||||||
smtp_mock.return_value.login.assert_not_called()
|
smtp_mock.return_value.login.assert_not_called()
|
||||||
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
|
smtp_mock.return_value.sendmail.assert_called_once_with(email.sender, email.receivers, pytest.helpers.anyvar(int))
|
||||||
smtp_mock.return_value.quit.assert_called_once_with()
|
|
||||||
|
|
||||||
|
|
||||||
def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> None:
|
def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> None:
|
||||||
@@ -112,6 +137,7 @@ def test_send_starttls(configuration: Configuration, mocker: MockerFixture) -> N
|
|||||||
"""
|
"""
|
||||||
configuration.set_option("email", "ssl", "starttls")
|
configuration.set_option("email", "ssl", "starttls")
|
||||||
smtp_mock = mocker.patch("smtplib.SMTP")
|
smtp_mock = mocker.patch("smtplib.SMTP")
|
||||||
|
smtp_mock.return_value.__enter__.return_value = smtp_mock.return_value
|
||||||
_, repository_id = configuration.check_loaded()
|
_, repository_id = configuration.check_loaded()
|
||||||
|
|
||||||
email = Email(repository_id, configuration, "email")
|
email = Email(repository_id, configuration, "email")
|
||||||
|
|||||||
Reference in New Issue
Block a user