From 157e8c1b175cbc90653ec4932003fcd1ed79d82a Mon Sep 17 00:00:00 2001 From: BreadTube Date: Tue, 23 Sep 2025 22:48:35 +0900 Subject: [PATCH] Config scan from bot channel implementation --- bot.py | 20 +++- breadtube_bot/api.py | 14 ++- breadtube_bot/config.py | 28 ++++- breadtube_bot/manager.py | 226 +++++++++++++++++++++++++++++++++++---- breadtube_bot/objects.py | 155 ++++++++++++++++++++++++++- tests/test_config.py | 44 ++++++++ 6 files changed, 453 insertions(+), 34 deletions(-) create mode 100644 tests/test_config.py diff --git a/bot.py b/bot.py index 67d94e7..248ffec 100644 --- a/bot.py +++ b/bot.py @@ -1,13 +1,27 @@ +from argparse import ArgumentParser +import logging from pathlib import Path from breadtube_bot.manager import DiscordManager def main(): + parser = ArgumentParser('BreadTube-bot') + parser.add_argument('--guild', type=int, default=1306964577812086824, help='Guild id to manage') + parser.add_argument('--debug', action='store_true', default=False, help='Run in debug mode (for logs)') + arguments = parser.parse_args() + + debug_mode: bool = arguments.debug + guild_id: int = arguments.guild + del arguments + bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip() - guild_id = 1306964577812086824 - manager = DiscordManager(bot_token=bot_token, guild_id=guild_id) - print(manager.rate_limit) + manager = DiscordManager( + bot_token=bot_token, guild_id=guild_id, log_level=logging.DEBUG if debug_mode else logging.INFO) + try: + manager.run() + except KeyboardInterrupt: + print('\r ') # noqa: T201 if __name__ == '__main__': diff --git a/breadtube_bot/api.py b/breadtube_bot/api.py index f901ccb..6277246 100644 --- a/breadtube_bot/api.py +++ b/breadtube_bot/api.py @@ -1,7 +1,7 @@ from enum import Enum from typing import TypedDict -from breadtube_bot.objects import Overwrite +from breadtube_bot.objects import MessageReference, Overwrite class ApiVersion(Enum): @@ -116,7 +116,7 @@ class Api: tts: bool # true if this is a TTS message # embeds: list[Embeded] # Up to 10 rich embeds (up to 6000 characters) # allowed_mentions: MentionObject # Allowed mentions for the message - # message_reference: MessageReference # Include to make your message a reply or a forward + message_reference: MessageReference # Include to make your message a reply or a forward # components: list[MessageComponent] # Components to include with the message sticker_ids: list[int] # IDs of up to 3 stickers in the server to send in the message # files[n]: FileContents # Contents of the file being sent. See Uploading Files @@ -136,7 +136,11 @@ class Api: return ApiAction.DELETE, f'/channels/{channel_id}/messages/{message_id}' @staticmethod - def list_by_channel(channel_id: int, limit: int | None = None) -> tuple[ApiAction, str]: - if limit is not None and not (0 < limit <= 100): # noqa: PLR2004 - raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range') + def list_by_channel(channel_id: int) -> tuple[ApiAction, str]: return ApiAction.GET, f'/channels/{channel_id}/messages' + + class ListMessageParams(TypedDict, total=False): + around: int # Get messages around this message ID + before: int # Get messages before this message ID + after: int # Get messages after this message ID + limit: int # Max number of messages to return (1-100), default=50 diff --git a/breadtube_bot/config.py b/breadtube_bot/config.py index cabe266..4335475 100644 --- a/breadtube_bot/config.py +++ b/breadtube_bot/config.py @@ -1,3 +1,5 @@ +from __future__ import annotations as _annotations + from dataclasses import asdict, dataclass @@ -5,17 +7,33 @@ from dataclasses import asdict, dataclass class Config: bot_channel: str = 'breadtube-bot' bot_role: str = 'BreadTube' + bot_channel_scan_interval: float = 30. bot_channel_init_retries: int = 3 + bot_message_duration: float = 150. + request_timeout: float = 3. def to_str(self) -> str: return '\n'.join(['config', *[f'{k}={v}' for k, v in asdict(self).items()]]) - def from_str(self, text: str): + @staticmethod + def from_str(text: str) -> Config: + annotations = Config.__annotations__ + global_types = globals()['__builtins__'] + config = Config() lines = text.strip().splitlines() if not lines: - raise RuntimeError('Config cannot load: empty input') + raise RuntimeError('Cannot load config: empty input') if lines[0] != 'config': - raise RuntimeError('Config cannot load: first line is not "config"') - for line in lines[1:]: + raise RuntimeError('Cannot load config: first line is not "config"') + config_dict = {} + for line_number, line in enumerate(lines[1:]): key, value = line.split('=', maxsplit=1) - setattr(self, key, self.__annotations__[key](value)) + if key not in annotations: + raise RuntimeError(f'Invalid config: invalid key {key} at line {line_number + 1}') + if key in config_dict: + raise RuntimeError(f'Invalid config: duplicated key {key} at line {line_number + 1}') + config_dict[key] = value + + for key, value in config_dict.items(): + setattr(config, key, global_types[annotations[key]](value)) + return config diff --git a/breadtube_bot/manager.py b/breadtube_bot/manager.py index 5a3ec21..9f34b5e 100644 --- a/breadtube_bot/manager.py +++ b/breadtube_bot/manager.py @@ -1,11 +1,15 @@ +from __future__ import annotations + from dataclasses import asdict, dataclass, is_dataclass from enum import Enum import logging +import operator from pathlib import Path import json import random import time import tomllib +from typing import Any import urllib.error import urllib.request @@ -13,7 +17,8 @@ from .api import Api, ApiAction, ApiVersion from .config import Config from .logger import create_logger from .objects import ( - ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel) + Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite, + OverwriteType, Permissions, Role, TextChannel) HTTPHeaders = dict[str, str] @@ -29,11 +34,20 @@ class ApiEncoder(json.JSONEncoder): class DiscordManager: + MAX_CONFIG_SIZE: int = 50_000 + DEFAULT_MESSAGE_LIST_LIMIT = 50 + INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n' + 'You can upload a new one to update the configuration.') + @dataclass class RateLimit: remaining: int next_reset: float + class Task(Enum): + SCAN_BOT_CHANNEL = 1 + DELETE_MESSAGES = 2 + @staticmethod def _get_code_version() -> str: pyproject_path = Path(__file__).parents[1] / 'pyproject.toml' @@ -50,8 +64,12 @@ class DiscordManager: self.rate_limit = self.RateLimit(remaining=1, next_reset=0) self.version = self._get_code_version() + self.tasks: list[tuple[DiscordManager.Task, float, Any]] = [] + self.logger.info('Retrieving guild roles before init') self.guild_roles: list = self.list_roles() + self.bot_channel: TextChannel | None = None + self.init_message: Message | None = None for _ in range(self.config.bot_channel_init_retries): while not self.init_bot_channel(): time.sleep(10) @@ -60,6 +78,10 @@ class DiscordManager: self.logger.info('Bot init OK') break raise RuntimeError("Couldn't initialize bot channel/role/permission") + + self._scan_bot_channel() + self.tasks.append(( + self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.logger.info('Bot initialized') def _update_rate_limit(self, headers: HTTPHeaders): @@ -69,11 +91,11 @@ class DiscordManager: return self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) + self.logger.debug('Updated rate limit: %s', self.rate_limit) def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]: - timeout = 3 min_api_version = 9 if api_version.value < min_api_version: @@ -90,7 +112,8 @@ class DiscordManager: + f'\r\n--{boundary}'.encode()) if data else b'' for file_index, (name, mime, content) in enumerate(upload_files): data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";' - f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content + f' filename="{name}"\r\nContent-Length: {len(content)}' + f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content data += f'\r\n--{boundary}--'.encode() request = urllib.request.Request(url, data=data) request.method = api_action.value @@ -102,7 +125,7 @@ class DiscordManager: request.add_header('Content-Type', 'application/json') request.add_header('Authorization', f'Bot {self._bot_token}') try: - with urllib.request.urlopen(request, timeout=timeout) as response: + with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: if response.status != expected_code: raise RuntimeError( f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') @@ -114,6 +137,23 @@ class DiscordManager: except urllib.error.URLError as error: raise RuntimeError(f'URL error calling API ({url}): {error}') from error + def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]: + request = urllib.request.Request(attachment.url) + request.add_header('User-Agent', f'BreadTube (v{self.version})') + request.add_header('Authorization', f'Bot {self._bot_token}') + try: + with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: + if response.status != expected_code: + raise RuntimeError( + f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})') + return dict(response.getheaders()), response.read() + except urllib.error.HTTPError as error: + raise RuntimeError( + f'HTTP error downloading attachment ({attachment}):' + f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error + except urllib.error.URLError as error: + raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error + def init_bot_channel(self) -> bool: _, text_channel = self.list_channels() breadtube_role: Role | None = None @@ -130,25 +170,20 @@ class DiscordManager: self.logger.info('No everyone role found') return False - breadtube_channel: TextChannel | None = None for channel in text_channel: if channel.name == self.config.bot_channel: - breadtube_channel = channel + self.bot_channel = channel self.logger.info('Found breadtube bot channel') - for perm in breadtube_channel.permission_overwrites: + for perm in self.bot_channel.permission_overwrites: if perm.id == breadtube_role.id: if not perm.allow | Permissions.VIEW_CHANNEL: self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing') return False self.logger.info('BreadTube channel permission OK') break - messages = self.list_text_channel_messages(breadtube_channel) - for message in messages: - self.logger.debug('Deleting message: %s', message) - self.delete_message(message) break else: - breadtube_channel = self.create_text_channel({ + self.bot_channel = self.create_text_channel({ 'name': self.config.bot_channel, 'permission_overwrites': [ Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE, @@ -157,13 +192,152 @@ class DiscordManager: deny=Permissions.NONE)] }) self.logger.info('Created breadtube bot channel') - - self.create_message( - breadtube_channel, - {'content': 'This is the current configuration used, upload a new one to update the configuration'}, - upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) return True + def _scan_bot_channel(self): + if self.bot_channel is None: + self.logger.error('Cannot scan bot channel: bot channel is None') + return [] + + last_message_id: int | None = None + while True: + messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) + if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT: + break + last_message_id = messages[-1].id + messages = sorted(messages, key=lambda x: x.timestamp) + + self.init_message = None + new_config: Config | None = None + messages_to_delete: list[Message] = [] + for message in messages: + # Skip message to be deleted + skip = True + for task_type, _, task_params in self.tasks: + if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any( + m.id == message.id for m in messages_to_delete)): + self.logger.debug('Skipping message already marked to be deleted') + break + else: + skip = False + if skip: + continue + + delete_message = True + for attachment in message.attachments: + if attachment.size < self.MAX_CONFIG_SIZE: + try: + _, content = self._download_attachment(attachment) + if content.startswith(b'config'): + try: + config = Config.from_str(content.decode()) + if config != self.config: + new_config = config + elif message.content == self.INIT_MESSAGE: + if self.init_message is not None: + self.logger.debug('Deleting duplicated init message') + try: + self.delete_message(self.init_message) + except RuntimeError as error: + self.logger.error('Error deleting init_message while scanning: %s', error) + self.init_message = message + delete_message = False + break + except RuntimeError as error: + self.logger.info('Invalid config file: %s', error) + messages_to_delete.extend([ + self.create_message(self.bot_channel, { + 'content': str(error), + 'message_reference': MessageReference( + type=MessageReferenceType.DEFAULT, + message_id=message.id, + channel_id=self.bot_channel.id, + guild_id=None, + fail_if_not_exists=None)}), + message]) + delete_message = False + break + except Exception as error: + self.logger.error('Error downloading attachment: %s', error) + messages_to_delete.extend([ + self.create_message(self.bot_channel, { + 'content': str(error), + 'message_reference': MessageReference( + type=MessageReferenceType.DEFAULT, + message_id=message.id, + channel_id=self.bot_channel.id, + guild_id=None, + fail_if_not_exists=None)}), + message]) + delete_message = False + break + if delete_message: + if any(m.id == message.id for m in messages_to_delete): + self.logger.warning( + 'Warning wrongly trying to delete message id %d while marked to be deleted', message.id) + else: + self.logger.debug('Deleting message: %s', message) + try: + self.delete_message(message) + except RuntimeError as error: + self.logger.error('Error deleting after scanned: %s', error) + + if new_config is not None: + self.logger.info('Loading new config: %s', new_config) + self.config = new_config + if self.init_message is not None: + self.delete_message(self.init_message) + self.init_message = None + + if self.init_message is None: + self.init_message = self.create_message( + self.bot_channel, {'content': self.INIT_MESSAGE}, + upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) + + if messages_to_delete: + self.tasks.append(( + DiscordManager.Task.DELETE_MESSAGES, + time.time() + self.config.bot_message_duration, + messages_to_delete)) + + def run(self): + while True: + if self.tasks: + self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True) + task_type, task_time, task_params = self.tasks.pop() + sleep_time = task_time - time.time() + self.logger.debug( + 'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time) + if sleep_time > 0: + time.sleep(sleep_time) + match task_type: + case DiscordManager.Task.SCAN_BOT_CHANNEL: + try: + self._scan_bot_channel() + except Exception as error: + self.logger.error('Error scanning bot channel: %s', error) + self.tasks.append(( + self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) + case DiscordManager.Task.DELETE_MESSAGES: + if not isinstance(task_params, list): + self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params) + elif not task_params: + self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params) + elif any(not isinstance(v, Message) for v in task_params): + self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params) + else: + for message in task_params: + try: + self.delete_message(message) + except Exception as error: + self.logger.error('Error deleting message %s: %s', message, error) + if self.rate_limit.remaining <= 1: + sleep_time = self.rate_limit.next_reset - time.time() + if sleep_time > 0: + self.logger.debug('Rate limit: sleeping %.03f second') + time.sleep(sleep_time) + time.sleep(1) + def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: headers, channel_info = self._send_request( *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), @@ -210,7 +384,19 @@ class DiscordManager: raise RuntimeError(f'Error listing roles (not a list): {roles_info}') return [Role.from_dict(r) for r in roles_info] - def list_text_channel_messages(self, channel: TextChannel) -> list[Message]: - headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id)) + def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None, + after_id: int | None = None) -> list[Message]: + if limit is not None and not (0 < limit <= 100): # noqa: PLR2004 + raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range') + params: Api.Message.ListMessageParams = {} + if limit is not None: + params['limit'] = limit + if before_id is not None: + params['before'] = before_id + if after_id is not None: + params['after'] = after_id + headers, messages_info = self._send_request( + *Api.Message.list_by_channel(channel.id), + data=json.dumps(params, cls=ApiEncoder).encode() if params else None) self._update_rate_limit(headers) - return [Message.from_dict(m) for m in messages or []] + return [Message.from_dict(m) for m in messages_info or []] diff --git a/breadtube_bot/objects.py b/breadtube_bot/objects.py index cb34a5c..2cead75 100644 --- a/breadtube_bot/objects.py +++ b/breadtube_bot/objects.py @@ -16,11 +16,16 @@ class FileMime(Enum): TEXT_HTML = 'text/html' TEXT_MARKDOWN = 'text/markdown' TEXT_PLAIN = 'text/plain' + UNKNOWN = 'application/unknown' VIDEO_MP4 = 'video/mp4' VIDEO_MPEG = 'video/mpeg' VIDEO_WEBM = 'video/webm' ZIP = 'application/zip' + @classmethod + def _missing_(cls, value): # noqa: ARG003 + return FileMime.UNKNOWN + class ChannelType(Enum): GUILD_TEXT = 0 @@ -256,6 +261,144 @@ class User: # TODO : complete attributes global_name=info.get('global_name')) +class AttachmentFlags(IntFlag): + IS_REMIX = 1 << 2 # this attachment has been edited using the remix feature on mobile + + +@dataclass +class Attachment: + id: int # attachment id + filename: str # name of file attached + title: str | None # the title of the file + description: str | None # description for the file (max 1024 characters) + content_type: str | None # the attachment's media type + size: int # size of file in bytes + url: str # source url of file + proxy_url: str # a proxied url of file + height: int | None # height of file (if image) + width: int | None # width of file (if image) + ephemeral: bool | None # whether this attachment is ephemeral + duration_secs: float | None # the duration of the audio file (currently for voice messages) + waveform: str | None # base64 encoded bytearray representing a sampled waveform (currently for voice messages) + flags: int | None # attachment flags combined as a bitfield + + @staticmethod + def from_dict(info: dict) -> Attachment: + height = info.get('height') + width = info.get('width') + duraction_secs = info.get('duration_secs') + flags = info.get('flags') + return Attachment( + id=int(info['id']), + filename=info['filename'], + title=info.get('title'), + description=info.get('description'), + content_type=info.get('content_type'), + size=int(info['size']), + url=info['url'], + proxy_url=info['proxy_url'], + height=int(height) if height is not None else None, + width=int(width) if width is not None else None, + ephemeral=info.get('ephemeral'), + duration_secs=float(duraction_secs) if duraction_secs is not None else None, + waveform=info.get('waveform'), + flags=AttachmentFlags(int(flags)) if flags is not None else None, + ) + + +class MessageType(Enum): + DEFAULT = 0 + RECIPIENT_ADD = 1 + RECIPIENT_REMOVE = 2 + CALL = 3 + CHANNEL_NAME_CHANGE = 4 + CHANNEL_ICON_CHANGE = 5 + CHANNEL_PINNED_MESSAGE = 6 + USER_JOIN = 7 + GUILD_BOOST = 8 + GUILD_BOOST_TIER_1 = 9 + GUILD_BOOST_TIER_2 = 10 + GUILD_BOOST_TIER_3 = 11 + CHANNEL_FOLLOW_ADD = 12 + GUILD_DISCOVERY_DISQUALIFIED = 14 + GUILD_DISCOVERY_REQUALIFIED = 15 + GUILD_DISCOVERY_GRACE_PERIOD_INITIAL_WARNING = 16 + GUILD_DISCOVERY_GRACE_PERIOD_FINAL_WARNING = 17 + THREAD_CREATED = 18 + REPLY = 19 + CHAT_INPUT_COMMAND = 20 + THREAD_STARTER_MESSAGE = 21 + GUILD_INVITE_REMINDER = 22 + CONTEXT_MENU_COMMAND = 23 + AUTO_MODERATION_ACTION = 24 + ROLE_SUBSCRIPTION_PURCHASE = 25 + INTERACTION_PREMIUM_UPSELL = 26 + STAGE_START = 27 + STAGE_END = 28 + STAGE_SPEAKER = 29 + STAGE_TOPIC = 31 + GUILD_APPLICATION_PREMIUM_SUBSCRIPTION = 32 + GUILD_INCIDENT_ALERT_MODE_ENABLED = 36 + GUILD_INCIDENT_ALERT_MODE_DISABLED = 37 + GUILD_INCIDENT_REPORT_RAID = 38 + GUILD_INCIDENT_REPORT_FALSE_ALARM = 39 + PURCHASE_NOTIFICATION = 44 + POLL_RESULT = 46 + + +NON_DELETABLE_MESSAGE_TYPES = [ + MessageType.RECIPIENT_ADD, + MessageType.RECIPIENT_REMOVE, + MessageType.CALL, + MessageType.CHANNEL_NAME_CHANGE, + MessageType.CHANNEL_ICON_CHANGE, + MessageType.THREAD_STARTER_MESSAGE] + + +class MessageFlags(IntFlag): + NONE = 0 + CROSSPOSTED = 1 << 0 # this message has been published to subscribed channels (via Channel Following) + IS_CROSSPOST = 1 << 1 # this message originated from a message in another channel (via Channel Following) + SUPPRESS_EMBEDS = 1 << 2 # do not include any embeds when serializing this message + SOURCE_MESSAGE_DELETED = 1 << 3 # the source message for this crosspost has been deleted (via Channel Following) + URGENT = 1 << 4 # this message came from the urgent message system + HAS_THREAD = 1 << 5 # this message has an associated thread, with the same id as the message + EPHEMERAL = 1 << 6 # this message is only visible to the user who invoked the Interaction + LOADING = 1 << 7 # this message is an Interaction Response and the bot is "thinking" + # this message failed to mention some roles and add their members to the thread + FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8 + SUPPRESS_NOTIFICATIONS = 1 << 12 # this message will not trigger push and desktop notifications + IS_VOICE_MESSAGE = 1 << 13 # this message is a voice message + HAS_SNAPSHOT = 1 << 14 # this message has a snapshot (via Message Forwarding) + IS_COMPONENTS_V2 = 1 << 15 # allows you to create fully component-driven messages + + +class MessageReferenceType(Enum): + DEFAULT = 0 # A standard reference used by replies. + FORWARD = 1 # Reference used to point to a message at a point in time. + + +@dataclass +class MessageReference: + type: MessageReferenceType | None # type of reference. + message_id: int | None # id of the originating message + channel_id: int | None # id of the originating message's channel + guild_id: int | None # id of the originating message's guild + # when sending, whether to error if the referenced message doesn't exist + # instead of sending as a normal (non-reply) message, default true + fail_if_not_exists: bool | None + + @staticmethod + def from_dict(info: dict) -> MessageReference: + ref_type: int | None = info.get('type') + return MessageReference( + type=MessageReferenceType(ref_type) if ref_type is not None else None, + message_id=info.get('message_id'), + channel_id=info.get('channel_id'), + guild_id=info.get('guild_id'), + fail_if_not_exists=info.get('fail_if_not_exists')) + + @dataclass class Message: # TODO : complete attributes id: int @@ -264,10 +407,16 @@ class Message: # TODO : complete attributes content: str timestamp: datetime edited_timestamp: datetime | None + attachments: list[Attachment] + type: MessageType + flags: MessageFlags | None + message_reference: MessageReference | None @staticmethod def from_dict(info: dict) -> Message: edited_timestamp: str | None = info.get('edited_timestamp') + flags: int | None = info.get('flags') + message_reference = info.get('message_reference') return Message( id=int(info['id']), channel_id=int(info['channel_id']), @@ -275,7 +424,11 @@ class Message: # TODO : complete attributes id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), content=info['content'], timestamp=datetime.fromisoformat(info['timestamp']), - edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None) + edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None, + attachments=[Attachment.from_dict(a) for a in info['attachments']], + type=MessageType(info['type']), + flags=MessageFlags(flags) if flags is not None else None, + message_reference=MessageReference.from_dict(message_reference) if message_reference is not None else None) @dataclass diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..b20e817 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,44 @@ +from breadtube_bot.config import Config + +import pytest + + +def test_empty(): + with pytest.raises(RuntimeError, match='Cannot load config: empty input'): + Config.from_str('') + + +def test_wrong_header(): + with pytest.raises(RuntimeError, match='Cannot load config: first line is not "config"'): + Config.from_str('connfig') + + +def test_lacking_key(): + expected_config = Config(bot_role='test-role', request_timeout=5.5) + config = Config.from_str( + 'config\n' + f'request_timeout={expected_config.request_timeout}\n' + f'bot_role={expected_config.bot_role}') + assert config == expected_config + + +def test_wrong_key(): + with pytest.raises(RuntimeError, match=('Invalid config: invalid key bot_channel_init at line 2')): + Config.from_str('config\nrequest_timeout=3.\nbot_channel_init=2') + + +def test_duplicated_key(): + with pytest.raises(RuntimeError, match=('Invalid config: duplicated key request_timeout at line 3')): + Config.from_str('config\nrequest_timeout=3.\nbot_channel_init_retries=2\nrequest_timeout=5.') + + +def test_correct_config(): + expected_config = Config(bot_channel='test-channel', bot_role='test-role', bot_channel_init_retries=298, + request_timeout=5.5) + config = Config.from_str( + 'config\n' + f'request_timeout={expected_config.request_timeout}\n' + f'bot_channel_init_retries={expected_config.bot_channel_init_retries}\n' + f'bot_channel={expected_config.bot_channel}\n' + f'bot_role={expected_config.bot_role}') + assert config == expected_config