from __future__ import annotations from dataclasses import asdict, dataclass, is_dataclass from enum import Enum import logging import operator from pathlib import Path import json import random import time import tomllib from typing import Any import urllib.error import urllib.request from .api import Api, ApiAction, ApiVersion from .config import Config from .logger import create_logger from .objects import ( Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite, OverwriteType, Permissions, Role, TextChannel) HTTPHeaders = dict[str, str] class ApiEncoder(json.JSONEncoder): def default(self, o): if is_dataclass(o): return asdict(o) # type: ignore if isinstance(o, Enum): return o.value return super().default(o) class DiscordManager: MAX_CONFIG_SIZE: int = 50_000 DEFAULT_MESSAGE_LIST_LIMIT = 50 INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n' 'You can upload a new one to update the configuration.') @dataclass class RateLimit: remaining: int next_reset: float class Task(Enum): SCAN_BOT_CHANNEL = 1 DELETE_MESSAGES = 2 @staticmethod def _get_code_version() -> str: pyproject_path = Path(__file__).parents[1] / 'pyproject.toml' if not pyproject_path.exists(): raise RuntimeError('Cannot current bot version') return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version'] def __init__(self, bot_token: str, guild_id: int, config: Config | None = None, log_level: int = logging.INFO) -> None: self.config = config or Config() self.guild_id = guild_id self._bot_token = bot_token self.logger = create_logger('breadtube', log_level, stdout=True) self.rate_limit = self.RateLimit(remaining=1, next_reset=0) self.version = self._get_code_version() self.tasks: list[tuple[DiscordManager.Task, float, Any]] = [] self.logger.info('Retrieving guild roles before init') self.guild_roles: list = self.list_roles() self.bot_channel: TextChannel | None = None self.init_message: Message | None = None for _ in range(self.config.bot_channel_init_retries): while not self.init_bot_channel(): time.sleep(10) break else: self.logger.info('Bot init OK') break raise RuntimeError("Couldn't initialize bot channel/role/permission") self._scan_bot_channel() self.tasks.append(( self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.logger.info('Bot initialized') def _update_rate_limit(self, headers: HTTPHeaders): for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']: if header_key not in headers: self.logger.info('Warning: no "%s" found in headers', header_key) return self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) self.logger.debug('Updated rate limit: %s', self.rate_limit) def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]: min_api_version = 9 if api_version.value < min_api_version: self.logger.warning( 'Warning: using deprecated API version %d (minimum non deprecated is %d)', api_version, min_api_version) url = f'https://discord.com/api/v{api_version.value}{endpoint}' boundary: str = '' if upload_files: boundary = f'{random.randbytes(16).hex()}' data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n' 'Content-Type: application/json\r\n\r\n'.encode() + data + f'\r\n--{boundary}'.encode()) if data else b'' for file_index, (name, mime, content) in enumerate(upload_files): data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";' f' filename="{name}"\r\nContent-Length: {len(content)}' f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content data += f'\r\n--{boundary}--'.encode() request = urllib.request.Request(url, data=data) request.method = api_action.value request.add_header('User-Agent', f'BreadTube (v{self.version})') request.add_header('Accept', 'application/json') if upload_files: request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}') else: request.add_header('Content-Type', 'application/json') request.add_header('Authorization', f'Bot {self._bot_token}') try: with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: if response.status != expected_code: raise RuntimeError( f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') body = response.read() return dict(response.getheaders()), json.loads(body.decode()) if body else None except urllib.error.HTTPError as error: raise RuntimeError( f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error except urllib.error.URLError as error: raise RuntimeError(f'URL error calling API ({url}): {error}') from error def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]: request = urllib.request.Request(attachment.url) request.add_header('User-Agent', f'BreadTube (v{self.version})') request.add_header('Authorization', f'Bot {self._bot_token}') try: with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response: if response.status != expected_code: raise RuntimeError( f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})') return dict(response.getheaders()), response.read() except urllib.error.HTTPError as error: raise RuntimeError( f'HTTP error downloading attachment ({attachment}):' f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error except urllib.error.URLError as error: raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error def init_bot_channel(self) -> bool: _, text_channel = self.list_channels() breadtube_role: Role | None = None everyone_role: Role | None = None for role in self.guild_roles: if role.name == self.config.bot_role: breadtube_role = role elif role.name == '@everyone': everyone_role = role if breadtube_role is None: self.logger.info('No BreadTube role found') return False if everyone_role is None: self.logger.info('No everyone role found') return False for channel in text_channel: if channel.name == self.config.bot_channel: self.bot_channel = channel self.logger.info('Found breadtube bot channel') for perm in self.bot_channel.permission_overwrites: if perm.id == breadtube_role.id: if not perm.allow | Permissions.VIEW_CHANNEL: self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing') return False self.logger.info('BreadTube channel permission OK') break break else: self.bot_channel = self.create_text_channel({ 'name': self.config.bot_channel, 'permission_overwrites': [ Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE, deny=Permissions.VIEW_CHANNEL), Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL, deny=Permissions.NONE)] }) self.logger.info('Created breadtube bot channel') return True def _scan_bot_channel(self): if self.bot_channel is None: self.logger.error('Cannot scan bot channel: bot channel is None') return [] last_message_id: int | None = None while True: messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT: break last_message_id = messages[-1].id messages = sorted(messages, key=lambda x: x.timestamp) self.init_message = None new_config: Config | None = None messages_to_delete: list[Message] = [] for message in messages: # Skip message to be deleted skip = True for task_type, _, task_params in self.tasks: if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any( m.id == message.id for m in messages_to_delete)): self.logger.debug('Skipping message already marked to be deleted') break else: skip = False if skip: continue delete_message = True for attachment in message.attachments: if attachment.size < self.MAX_CONFIG_SIZE: try: _, content = self._download_attachment(attachment) if content.startswith(b'config'): try: config = Config.from_str(content.decode()) if config != self.config: new_config = config elif message.content == self.INIT_MESSAGE: if self.init_message is not None: self.logger.debug('Deleting duplicated init message') try: self.delete_message(self.init_message) except RuntimeError as error: self.logger.error('Error deleting init_message while scanning: %s', error) self.init_message = message delete_message = False break except RuntimeError as error: self.logger.info('Invalid config file: %s', error) messages_to_delete.extend([ self.create_message(self.bot_channel, { 'content': str(error), 'message_reference': MessageReference( type=MessageReferenceType.DEFAULT, message_id=message.id, channel_id=self.bot_channel.id, guild_id=None, fail_if_not_exists=None)}), message]) delete_message = False break except Exception as error: self.logger.error('Error downloading attachment: %s', error) messages_to_delete.extend([ self.create_message(self.bot_channel, { 'content': str(error), 'message_reference': MessageReference( type=MessageReferenceType.DEFAULT, message_id=message.id, channel_id=self.bot_channel.id, guild_id=None, fail_if_not_exists=None)}), message]) delete_message = False break if delete_message: if any(m.id == message.id for m in messages_to_delete): self.logger.warning( 'Warning wrongly trying to delete message id %d while marked to be deleted', message.id) else: self.logger.debug('Deleting message: %s', message) try: self.delete_message(message) except RuntimeError as error: self.logger.error('Error deleting after scanned: %s', error) if new_config is not None: self.logger.info('Loading new config: %s', new_config) self.config = new_config if self.init_message is not None: self.delete_message(self.init_message) self.init_message = None if self.init_message is None: self.init_message = self.create_message( self.bot_channel, {'content': self.INIT_MESSAGE}, upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) if messages_to_delete: self.tasks.append(( DiscordManager.Task.DELETE_MESSAGES, time.time() + self.config.bot_message_duration, messages_to_delete)) def run(self): while True: if self.tasks: self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True) task_type, task_time, task_params = self.tasks.pop() sleep_time = task_time - time.time() self.logger.debug( 'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time) if sleep_time > 0: time.sleep(sleep_time) match task_type: case DiscordManager.Task.SCAN_BOT_CHANNEL: try: self._scan_bot_channel() except Exception as error: self.logger.error('Error scanning bot channel: %s', error) self.tasks.append(( self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) case DiscordManager.Task.DELETE_MESSAGES: if not isinstance(task_params, list): self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params) elif not task_params: self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params) elif any(not isinstance(v, Message) for v in task_params): self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params) else: for message in task_params: try: self.delete_message(message) except Exception as error: self.logger.error('Error deleting message %s: %s', message, error) if self.rate_limit.remaining <= 1: sleep_time = self.rate_limit.next_reset - time.time() if sleep_time > 0: self.logger.debug('Rate limit: sleeping %.03f second') time.sleep(sleep_time) time.sleep(1) def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: headers, channel_info = self._send_request( *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201) self._update_rate_limit(headers) if not isinstance(channel_info, dict): raise RuntimeError(f'Error creating channel with params (no info): {params}') return TextChannel.from_dict(channel_info) def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message: headers, message_info = self._send_request( *Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files) self._update_rate_limit(headers) if not isinstance(message_info, dict): raise RuntimeError(f'Error creating message with params (no info): {params}') return Message.from_dict(message_info) def delete_message(self, message: Message): headers, _ = self._send_request( *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) self._update_rate_limit(headers) def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]: headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id)) self._update_rate_limit(headers) categories: list[ChannelCategory] = [] text_channels: list[TextChannel] = [] if channels_info is not None: for channel_info in channels_info: channel_type = ChannelType(channel_info['type']) match channel_type: case ChannelType.GUILD_CATEGORY: categories.append(ChannelCategory.from_dict(channel_info)) case ChannelType.GUILD_TEXT: text_channels.append(TextChannel.from_dict(channel_info)) return categories, text_channels def list_roles(self) -> list[Role]: headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id)) self._update_rate_limit(headers) if not isinstance(roles_info, list): raise RuntimeError(f'Error listing roles (not a list): {roles_info}') return [Role.from_dict(r) for r in roles_info] def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None, after_id: int | None = None) -> list[Message]: if limit is not None and not (0 < limit <= 100): # noqa: PLR2004 raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range') params: Api.Message.ListMessageParams = {} if limit is not None: params['limit'] = limit if before_id is not None: params['before'] = before_id if after_id is not None: params['after'] = after_id headers, messages_info = self._send_request( *Api.Message.list_by_channel(channel.id), data=json.dumps(params, cls=ApiEncoder).encode() if params else None) self._update_rate_limit(headers) return [Message.from_dict(m) for m in messages_info or []]