from dataclasses import dataclass from datetime import datetime from pathlib import Path import json import tomllib import urllib.error import urllib.request from .api import Api, ApiAction, ApiVersion from .objects import ( ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User) HTTPHeaders = dict[str, str] @dataclass class _RateLimit: remaining: int next_reset: float class DiscordManager: @staticmethod def _get_code_version() -> str: pyproject_path = Path(__file__).parents[1] / 'pyproject.toml' if not pyproject_path.exists(): raise RuntimeError('Cannot current bot version') return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version'] def __init__(self, bot_token: str, guild_id: int) -> None: self.guild_id = guild_id self._bot_token = bot_token self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0) self.version = self._get_code_version() def _update_rate_limit(self, headers: HTTPHeaders): for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']: if header_key not in headers: print(f'Warning: no "{header_key}" found in headers') return self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, expected_code: int = 200) -> tuple[ HTTPHeaders, dict]: timeout = 3 min_api_version = 9 if api_action == ApiAction.POST: raise NotImplementedError if api_version.value < min_api_version: print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})') url = f'https://discord.com/api/v{api_version.value}{endpoint}' request = urllib.request.Request(url) request.add_header('User-Agent', f'BreadTube (v{self.version})') request.add_header('Accept', 'application/json') request.add_header('Authorization', f'Bot {self._bot_token}') try: with urllib.request.urlopen(request, timeout=timeout) as response: if response.status != expected_code: raise RuntimeError( f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') return dict(response.getheaders()), json.loads(response.read().decode()) except urllib.error.HTTPError as error: raise RuntimeError( f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error except urllib.error.URLError as error: raise RuntimeError(f'URL error calling API ({url}): {error}') from error @staticmethod def _parse_overwrite(info: dict) -> Overwrite: return Overwrite( id=int(info['id']), type=OverwriteType(info['type']), allow=Permissions(int(info['allow'])), deny=Permissions(int(info['deny'])) ) def _parse_channel_category(self, info: dict) -> ChannelCategory: parent_id: str | None = info.get('parent_id') return ChannelCategory( id=int(info['id']), guild_id=int(info['guild_id']), position=int(info['position']), permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']], name=info.get('name'), parent_id=int(parent_id) if parent_id is not None else None, flags=ChannelFlags(info['flags']), ) def _parse_text_channel(self, info: dict) -> TextChannel: parent_id: str | None = info.get('parent_id') last_message_id: str | None = info.get('last_message_id') last_pin_timestamp: str | None = info.get('last_pin_timestamp') return TextChannel( id=int(info['id']), guild_id=int(info['guild_id']), position=int(info['position']), permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']], name=info.get('name'), topic=info.get('topic'), nsfw=info['nsfw'], last_message_id=int(last_message_id) if last_message_id is not None else None, rate_limit_per_user=int(info['rate_limit_per_user']), parent_id=int(parent_id) if parent_id is not None else None, last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None), flags=ChannelFlags(info['flags']), ) @staticmethod def _parse_user(info: dict) -> User: return User( id=int(info['id']), username=info['username'], discriminator=info['discriminator'], global_name=info.get('global_name') ) def _parse_message(self, info: dict) -> Message: edited_timestamp: str | None = info.get('edited_timestamp') return Message( id=int(info['id']), channel_id=int(info['channel_id']), author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User( id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), content=info['content'], timestamp=datetime.fromisoformat(info['timestamp']), edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None ) def delete_message(self, message: Message): try: headers, _ = self._send_request( *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) self._update_rate_limit(headers) print(f'Message {message.id} deleted') except RuntimeError as error: print(error) def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]: headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id)) self._update_rate_limit(headers) categories: list[ChannelCategory] = [] text_channels: list[TextChannel] = [] for channel in channels: channel_type = ChannelType(channel['type']) match channel_type: case ChannelType.GUILD_CATEGORY: categories.append(self._parse_channel_category(channel)) case ChannelType.GUILD_TEXT: text_channels.append(self._parse_text_channel(channel)) return categories, text_channels def list_text_channel_messages(self, channel: TextChannel) -> list: headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id)) self._update_rate_limit(headers) return [self._parse_message(m) for m in messages]