diff --git a/breadtube_bot/api.py b/breadtube_bot/api.py index 6277246..eceac7a 100644 --- a/breadtube_bot/api.py +++ b/breadtube_bot/api.py @@ -144,3 +144,8 @@ class Api: before: int # Get messages before this message ID after: int # Get messages after this message ID limit: int # Max number of messages to return (1-100), default=50 + + class User: + @staticmethod + def get_current() -> tuple[ApiAction, str]: + return ApiAction.GET, '/users/@me' diff --git a/breadtube_bot/manager.py b/breadtube_bot/manager.py index 9f34b5e..ad38d4d 100644 --- a/breadtube_bot/manager.py +++ b/breadtube_bot/manager.py @@ -12,13 +12,14 @@ import tomllib from typing import Any import urllib.error import urllib.request +import traceback from .api import Api, ApiAction, ApiVersion from .config import Config from .logger import create_logger from .objects import ( Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite, - OverwriteType, Permissions, Role, TextChannel) + OverwriteType, Permissions, Role, TextChannel, User) HTTPHeaders = dict[str, str] @@ -66,6 +67,8 @@ class DiscordManager: self.version = self._get_code_version() self.tasks: list[tuple[DiscordManager.Task, float, Any]] = [] + self.logger.info('Retrieving bot user') + self.bot_user = self.get_current_user() self.logger.info('Retrieving guild roles before init') self.guild_roles: list = self.list_roles() self.bot_channel: TextChannel | None = None @@ -92,6 +95,11 @@ class DiscordManager: self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) self.logger.debug('Updated rate limit: %s', self.rate_limit) + if self.rate_limit.remaining <= 0: + sleep_time = self.rate_limit.next_reset - time.time() + if sleep_time > 0: + self.logger.debug('Rate limit: sleeping %.03f second', sleep_time) + time.sleep(sleep_time) def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, @@ -130,12 +138,16 @@ class DiscordManager: raise RuntimeError( f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') body = response.read() - return dict(response.getheaders()), json.loads(body.decode()) if body else None + headers = dict(response.getheaders()) + self._update_rate_limit(headers) + return headers, json.loads(body.decode()) if body else None except urllib.error.HTTPError as error: raise RuntimeError( f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error except urllib.error.URLError as error: raise RuntimeError(f'URL error calling API ({url}): {error}') from error + except TimeoutError as error: + raise RuntimeError(f'Timeout calling API ({url}): {error}') from error def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]: request = urllib.request.Request(attachment.url) @@ -199,106 +211,87 @@ class DiscordManager: self.logger.error('Cannot scan bot channel: bot channel is None') return [] + messages_id_delete_task: set[int] = set() + for task_type, _, task_params in self.tasks: + if task_type == self.Task.DELETE_MESSAGES: + messages_id_delete_task.update(message.id for message in task_params) + last_message_id: int | None = None + messages: list[Message] = [] while True: - messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) - if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT: + message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) + messages.extend([m for m in message_batch if m.id not in messages_id_delete_task]) + if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT: break - last_message_id = messages[-1].id - messages = sorted(messages, key=lambda x: x.timestamp) + last_message_id = message_batch[-1].id self.init_message = None new_config: Config | None = None - messages_to_delete: list[Message] = [] + delayed_delete: dict[int, Message] = {} + immediate_delete: dict[int, Message] = {} for message in messages: - # Skip message to be deleted - skip = True - for task_type, _, task_params in self.tasks: - if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any( - m.id == message.id for m in messages_to_delete)): - self.logger.debug('Skipping message already marked to be deleted') - break - else: - skip = False - if skip: + if message.id in delayed_delete: + self.logger.debug('Skipping message already marked to be deleted') continue - delete_message = True - for attachment in message.attachments: + if self.init_message is None and new_config is None and len(message.attachments) == 1: + attachment = message.attachments[0] if attachment.size < self.MAX_CONFIG_SIZE: try: _, content = self._download_attachment(attachment) if content.startswith(b'config'): try: config = Config.from_str(content.decode()) - if config != self.config: - new_config = config - elif message.content == self.INIT_MESSAGE: - if self.init_message is not None: - self.logger.debug('Deleting duplicated init message') - try: - self.delete_message(self.init_message) - except RuntimeError as error: - self.logger.error('Error deleting init_message while scanning: %s', error) + if message.author.id == self.bot_user.id: # keep using current config + self.logger.debug('Found previous init message') self.init_message = message - delete_message = False - break + if config != self.config: # First scan qill need to load config + self.config = config + continue + if config != self.config: # New config to update to + new_config = config + self.logger.debug('Marking new config message for immediate deletion: %s', message) + immediate_delete[message.id] = message + continue except RuntimeError as error: self.logger.info('Invalid config file: %s', error) - messages_to_delete.extend([ - self.create_message(self.bot_channel, { + bot_message = self.create_message(self.bot_channel, { 'content': str(error), 'message_reference': MessageReference( type=MessageReferenceType.DEFAULT, message_id=message.id, channel_id=self.bot_channel.id, guild_id=None, - fail_if_not_exists=None)}), - message]) - delete_message = False - break + fail_if_not_exists=None)}) + delayed_delete[bot_message.id] = bot_message + delayed_delete[message.id] = message + continue except Exception as error: self.logger.error('Error downloading attachment: %s', error) - messages_to_delete.extend([ - self.create_message(self.bot_channel, { - 'content': str(error), - 'message_reference': MessageReference( - type=MessageReferenceType.DEFAULT, - message_id=message.id, - channel_id=self.bot_channel.id, - guild_id=None, - fail_if_not_exists=None)}), - message]) - delete_message = False - break - if delete_message: - if any(m.id == message.id for m in messages_to_delete): - self.logger.warning( - 'Warning wrongly trying to delete message id %d while marked to be deleted', message.id) - else: - self.logger.debug('Deleting message: %s', message) - try: - self.delete_message(message) - except RuntimeError as error: - self.logger.error('Error deleting after scanned: %s', error) + self.logger.debug('Marking message for immediate deletion: %s', message) + immediate_delete[message.id] = message if new_config is not None: - self.logger.info('Loading new config: %s', new_config) + self.logger.info('Loading config: %s', new_config) self.config = new_config - if self.init_message is not None: - self.delete_message(self.init_message) - self.init_message = None if self.init_message is None: + assert self.config is not None self.init_message = self.create_message( self.bot_channel, {'content': self.INIT_MESSAGE}, upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) - if messages_to_delete: + for message in immediate_delete.values(): + try: + self.delete_message(message) + except RuntimeError as error: + self.logger.error('Error deleting after bot channel scan (immediate): %s', error) + + if delayed_delete: self.tasks.append(( DiscordManager.Task.DELETE_MESSAGES, time.time() + self.config.bot_message_duration, - messages_to_delete)) + list(delayed_delete.values()))) def run(self): while True: @@ -307,7 +300,7 @@ class DiscordManager: task_type, task_time, task_params = self.tasks.pop() sleep_time = task_time - time.time() self.logger.debug( - 'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time) + 'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params) if sleep_time > 0: time.sleep(sleep_time) match task_type: @@ -315,7 +308,8 @@ class DiscordManager: try: self._scan_bot_channel() except Exception as error: - self.logger.error('Error scanning bot channel: %s', error) + self.logger.error('Error scanning bot channel: %s -> %s', + error, traceback.format_exc().replace('\n', ' | ')) self.tasks.append(( self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) case DiscordManager.Task.DELETE_MESSAGES: @@ -330,41 +324,39 @@ class DiscordManager: try: self.delete_message(message) except Exception as error: - self.logger.error('Error deleting message %s: %s', message, error) - if self.rate_limit.remaining <= 1: - sleep_time = self.rate_limit.next_reset - time.time() - if sleep_time > 0: - self.logger.debug('Rate limit: sleeping %.03f second') - time.sleep(sleep_time) + self.logger.error('Error deleting message %s: %s -> %s', + message, error, traceback.format_exc().replace('\n', ' | ')) time.sleep(1) def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: - headers, channel_info = self._send_request( + _, channel_info = self._send_request( *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201) - self._update_rate_limit(headers) if not isinstance(channel_info, dict): raise RuntimeError(f'Error creating channel with params (no info): {params}') return TextChannel.from_dict(channel_info) def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message: - headers, message_info = self._send_request( + _, message_info = self._send_request( *Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files) - self._update_rate_limit(headers) if not isinstance(message_info, dict): raise RuntimeError(f'Error creating message with params (no info): {params}') return Message.from_dict(message_info) def delete_message(self, message: Message): - headers, _ = self._send_request( + _, _ = self._send_request( *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) - self._update_rate_limit(headers) + + def get_current_user(self) -> User: + _, user_info = self._send_request(*Api.User.get_current()) + if not isinstance(user_info, dict): + raise RuntimeError(f'Error getting current user (not a dict): {user_info}') + return User.from_dict(user_info) def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]: - headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id)) - self._update_rate_limit(headers) + _, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id)) categories: list[ChannelCategory] = [] text_channels: list[TextChannel] = [] if channels_info is not None: @@ -378,8 +370,7 @@ class DiscordManager: return categories, text_channels def list_roles(self) -> list[Role]: - headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id)) - self._update_rate_limit(headers) + _, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id)) if not isinstance(roles_info, list): raise RuntimeError(f'Error listing roles (not a list): {roles_info}') return [Role.from_dict(r) for r in roles_info] diff --git a/breadtube_bot/objects.py b/breadtube_bot/objects.py index 2cead75..a44f8fe 100644 --- a/breadtube_bot/objects.py +++ b/breadtube_bot/objects.py @@ -251,6 +251,7 @@ class User: # TODO : complete attributes username: str discriminator: str global_name: str | None + bot: bool @staticmethod def from_dict(info: dict) -> User: @@ -258,7 +259,8 @@ class User: # TODO : complete attributes id=int(info['id']), username=info['username'], discriminator=info['discriminator'], - global_name=info.get('global_name')) + global_name=info.get('global_name'), + bot=info.get('bot', False)) class AttachmentFlags(IntFlag): @@ -421,7 +423,7 @@ class Message: # TODO : complete attributes id=int(info['id']), channel_id=int(info['channel_id']), author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User( - id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), + id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None, bot=False)), content=info['content'], timestamp=datetime.fromisoformat(info['timestamp']), edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,