From 72edbe6599a1ced6b51246df62a624d253fc93da Mon Sep 17 00:00:00 2001 From: BreadTube Date: Tue, 23 Sep 2025 04:50:23 +0900 Subject: [PATCH] Bot config and channel init --- bot.py | 18 +-- breadtube_bot/api.py | 109 ++++++++++++++++++ breadtube_bot/config.py | 21 ++++ breadtube_bot/logger.py | 60 ++++++++++ breadtube_bot/manager.py | 236 ++++++++++++++++++++++++--------------- breadtube_bot/objects.py | 162 ++++++++++++++++++++++++++- pyproject.toml | 2 +- 7 files changed, 499 insertions(+), 109 deletions(-) create mode 100644 breadtube_bot/config.py create mode 100644 breadtube_bot/logger.py diff --git a/bot.py b/bot.py index 16db5a1..67d94e7 100644 --- a/bot.py +++ b/bot.py @@ -1,29 +1,13 @@ from pathlib import Path -import sys from breadtube_bot.manager import DiscordManager -from breadtube_bot.objects import TextChannel def main(): bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip() guild_id = 1306964577812086824 manager = DiscordManager(bot_token=bot_token, guild_id=guild_id) - _categories, text_channel = manager.list_channels() - breadtube_channel: TextChannel | None = None - for channel in text_channel: - if channel.name == 'breadtube-bot': - breadtube_channel = channel - break - - if breadtube_channel is None: - print('Cannot find beadtube-bot channel') - sys.exit(1) - - messages = manager.list_text_channel_messages(breadtube_channel) - for message in messages: - print(message) - manager.delete_message(message) + print(manager.rate_limit) if __name__ == '__main__': diff --git a/breadtube_bot/api.py b/breadtube_bot/api.py index 14570d1..f901ccb 100644 --- a/breadtube_bot/api.py +++ b/breadtube_bot/api.py @@ -1,4 +1,7 @@ from enum import Enum +from typing import TypedDict + +from breadtube_bot.objects import Overwrite class ApiVersion(Enum): @@ -17,11 +20,117 @@ class ApiAction(Enum): class Api: class Guild: + @staticmethod + def create_channel(guild_id: int) -> tuple[ApiAction, str]: + return ApiAction.POST, f'/guilds/{guild_id}/channels' + + class CreateChannelParams(TypedDict, total=False): + # All + # channel name (1-100 characters) + name: str + # All + # the type of channel + type: int + # Text, Announcement, Forum, Media + # channel topic (0-1024 characters) + topic: str + # Voice, Stage + # the bitrate (in bits) of the voice or stage channel; min 8000 + bitrate: int + # Voice, Stage + # the user limit of the voice channel + user_limit: int + # Text, Voice, Stage, Forum, Media + # amount of seconds a user has to wait before sending another message (0-21600); + # bots, as well as users with the permission manage_messages or manage_channel, are unaffected + rate_limit_per_user: int + # All + # sorting position of the channel (channels with the same position are sorted by id) + position: int + # All + # the channel's permission overwrites + permission_overwrites: list[dict] + # Text, Voice, Announcement, Stage, Forum, Media + # id of the parent category for a channel + parent_id: int + # Text, Voice, Announcement, Stage, Forum + # whether the channel is nsfw + nsfw: bool + # Voice, Stage + # channel voice region id of the voice or stage channel, automatic when set to null + rtc_region: str + # Voice, Stage + # the camera video quality mode of the voice channel + video_quality_mode: int + # Text, Announcement, Forum, Media + # the default duration that the clients use (not the API) for newly created threads in the channel, + # in minutes, to automatically archive the thread after recent activity + default_auto_archive_duration: int + # Forum, Media + # emoji to show in the add reaction button on a thread in a GUILD_FORUM or a GUILD_MEDIA channel + default_reaction_emoji: dict + # Forum, Media + # set of tags that can be used in a GUILD_FORUM or a GUILD_MEDIA channel + available_tags: list[dict] + # Forum, Media + # the default sort order type used to order posts in GUILD_FORUM and GUILD_MEDIA channels + default_sort_order: int + # Forum + # the default forum layout view used to display posts in GUILD_FORUM channels + default_forum_layout: int + # Text, Announcement, Forum, Media + # the initial rate_limit_per_user to set on newly created threads in a channel. + # this field is copied to the thread at creation time and does not live update. + default_thread_rate_limit_per_user: int + + class CreateTextChannelParams(TypedDict, total=False): + name: str + type: int + topic: str + rate_limit_per_user: int + position: int + permission_overwrites: list[Overwrite] + parent_id: int + nsfw: bool + default_auto_archive_duration: int + default_thread_rate_limit_per_user: int + @staticmethod def list_guilds(guild_id: int) -> tuple[ApiAction, str]: return ApiAction.GET, f'/guilds/{guild_id}/channels' + @staticmethod + def list_roles(guild_id: int) -> tuple[ApiAction, str]: + return ApiAction.GET, f'/guilds/{guild_id}/roles' + class Message: + @staticmethod + def create(channel_id: int) -> tuple[ApiAction, str]: + return ApiAction.POST, f'/channels/{channel_id}/messages' + + class CreateParams(TypedDict, total=False): + content: str # Message contents (up to 2000 characters) + # Can be used to verify a message was sent (up to 25 characters). + # Value will appear in the Message Create event. + nonce: int | str + tts: bool # true if this is a TTS message + # embeds: list[Embeded] # Up to 10 rich embeds (up to 6000 characters) + # allowed_mentions: MentionObject # Allowed mentions for the message + # message_reference: MessageReference # Include to make your message a reply or a forward + # components: list[MessageComponent] # Components to include with the message + sticker_ids: list[int] # IDs of up to 3 stickers in the server to send in the message + # files[n]: FileContents # Contents of the file being sent. See Uploading Files + # payload_json: str # JSON-encoded body of non-file params, only for multipart/form-data requests + # attachments: list[Attachment] # Attachment objects with filename and description. See Uploading Files + # Message flags combined as a bitfield + # (only SUPPRESS_EMBEDS, SUPPRESS_NOTIFICATIONS, IS_VOICE_MESSAGE, and IS_COMPONENTS_V2 can be set) + # flags: MessageFlags + # If true and nonce is present, it will be checked for uniqueness in the past few minutes. + # If another message was created by the same author with the same nonce, that message will be returned + # and no new message will be created. + # enforce_nonce: bool + # poll: PollRequest # A poll! + @staticmethod def delete(channel_id: int, message_id: int) -> tuple[ApiAction, str]: return ApiAction.DELETE, f'/channels/{channel_id}/messages/{message_id}' diff --git a/breadtube_bot/config.py b/breadtube_bot/config.py new file mode 100644 index 0000000..cabe266 --- /dev/null +++ b/breadtube_bot/config.py @@ -0,0 +1,21 @@ +from dataclasses import asdict, dataclass + + +@dataclass +class Config: + bot_channel: str = 'breadtube-bot' + bot_role: str = 'BreadTube' + bot_channel_init_retries: int = 3 + + def to_str(self) -> str: + return '\n'.join(['config', *[f'{k}={v}' for k, v in asdict(self).items()]]) + + def from_str(self, text: str): + lines = text.strip().splitlines() + if not lines: + raise RuntimeError('Config cannot load: empty input') + if lines[0] != 'config': + raise RuntimeError('Config cannot load: first line is not "config"') + for line in lines[1:]: + key, value = line.split('=', maxsplit=1) + setattr(self, key, self.__annotations__[key](value)) diff --git a/breadtube_bot/logger.py b/breadtube_bot/logger.py new file mode 100644 index 0000000..62fe446 --- /dev/null +++ b/breadtube_bot/logger.py @@ -0,0 +1,60 @@ +from logging import handlers +import logging +from pathlib import Path +import sys + + +class ConsoleColor: + """Simple shortcut to use colors in console""" + HEADER = '\033[95m' + BLUE = '\033[94m' + GREEN = '\033[92m' + ORANGE = '\033[93m' + RED = '\033[91m' + ENDCOLOR = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + + +class ColoredFormatter(logging.Formatter): + """Formatter changing the record during format : adds colors to levelname""" + def format(self, record): + levelno = record.levelno + if levelno == logging.ERROR: + levelname_color = ConsoleColor.RED + record.levelname + ConsoleColor.ENDCOLOR + elif levelno == logging.WARNING: + levelname_color = ConsoleColor.ORANGE + record.levelname + ConsoleColor.ENDCOLOR + elif levelno == logging.INFO: + levelname_color = ConsoleColor.GREEN + record.levelname + ConsoleColor.ENDCOLOR + elif levelno == logging.DEBUG: + levelname_color = ConsoleColor.BLUE + record.levelname + ConsoleColor.ENDCOLOR + else: + levelname_color = record.levelname + record.levelname = levelname_color + return logging.Formatter.format(self, record) + + +def create_logger(name: str, level: int, log_dir: Path | None = None, stdout=False) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(level) + + if log_dir is not None: + log_dir.mkdir(parents=True, exist_ok=True) + logger.setLevel(level) + file_log_handler = handlers.RotatingFileHandler( + log_dir / f'{name}.log', + maxBytes=500000, + backupCount=5) + file_log_handler.setLevel(level) + log_formatter = logging.Formatter('%(asctime)s %(levelname)s : %(message)s') + file_log_handler.setFormatter(log_formatter) + logger.addHandler(file_log_handler) + + if stdout: + terminal_log_handler = logging.StreamHandler(sys.stdout) + terminal_log_handler.setLevel(level) + colored_log_formatter = ColoredFormatter('%(asctime)s %(levelname)s : %(message)s') + terminal_log_handler.setFormatter(colored_log_formatter) + logger.addHandler(terminal_log_handler) + + return logger diff --git a/breadtube_bot/manager.py b/breadtube_bot/manager.py index 3ee2ec3..5a3ec21 100644 --- a/breadtube_bot/manager.py +++ b/breadtube_bot/manager.py @@ -1,26 +1,39 @@ -from dataclasses import dataclass -from datetime import datetime +from dataclasses import asdict, dataclass, is_dataclass +from enum import Enum +import logging from pathlib import Path import json +import random +import time import tomllib import urllib.error import urllib.request from .api import Api, ApiAction, ApiVersion +from .config import Config +from .logger import create_logger from .objects import ( - ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User) + ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel) HTTPHeaders = dict[str, str] -@dataclass -class _RateLimit: - remaining: int - next_reset: float +class ApiEncoder(json.JSONEncoder): + def default(self, o): + if is_dataclass(o): + return asdict(o) # type: ignore + if isinstance(o, Enum): + return o.value + return super().default(o) class DiscordManager: + @dataclass + class RateLimit: + remaining: int + next_reset: float + @staticmethod def _get_code_version() -> str: pyproject_path = Path(__file__).parents[1] / 'pyproject.toml' @@ -28,133 +41,176 @@ class DiscordManager: raise RuntimeError('Cannot current bot version') return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version'] - def __init__(self, bot_token: str, guild_id: int) -> None: + def __init__(self, bot_token: str, guild_id: int, config: Config | None = None, + log_level: int = logging.INFO) -> None: + self.config = config or Config() self.guild_id = guild_id self._bot_token = bot_token + self.logger = create_logger('breadtube', log_level, stdout=True) - self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0) + self.rate_limit = self.RateLimit(remaining=1, next_reset=0) self.version = self._get_code_version() + self.guild_roles: list = self.list_roles() + for _ in range(self.config.bot_channel_init_retries): + while not self.init_bot_channel(): + time.sleep(10) + break + else: + self.logger.info('Bot init OK') + break + raise RuntimeError("Couldn't initialize bot channel/role/permission") + self.logger.info('Bot initialized') + def _update_rate_limit(self, headers: HTTPHeaders): for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']: if header_key not in headers: - print(f'Warning: no "{header_key}" found in headers') + self.logger.info('Warning: no "%s" found in headers', header_key) return self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, - expected_code: int = 200) -> tuple[ - HTTPHeaders, dict]: + data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, + expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]: timeout = 3 min_api_version = 9 - if api_action == ApiAction.POST: - raise NotImplementedError if api_version.value < min_api_version: - print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})') + self.logger.warning( + 'Warning: using deprecated API version %d (minimum non deprecated is %d)', + api_version, min_api_version) url = f'https://discord.com/api/v{api_version.value}{endpoint}' - request = urllib.request.Request(url) + + boundary: str = '' + if upload_files: + boundary = f'{random.randbytes(16).hex()}' + data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n' + 'Content-Type: application/json\r\n\r\n'.encode() + data + + f'\r\n--{boundary}'.encode()) if data else b'' + for file_index, (name, mime, content) in enumerate(upload_files): + data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";' + f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content + data += f'\r\n--{boundary}--'.encode() + request = urllib.request.Request(url, data=data) + request.method = api_action.value request.add_header('User-Agent', f'BreadTube (v{self.version})') request.add_header('Accept', 'application/json') + if upload_files: + request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}') + else: + request.add_header('Content-Type', 'application/json') request.add_header('Authorization', f'Bot {self._bot_token}') try: with urllib.request.urlopen(request, timeout=timeout) as response: if response.status != expected_code: raise RuntimeError( f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') - return dict(response.getheaders()), json.loads(response.read().decode()) + body = response.read() + return dict(response.getheaders()), json.loads(body.decode()) if body else None except urllib.error.HTTPError as error: raise RuntimeError( f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error except urllib.error.URLError as error: raise RuntimeError(f'URL error calling API ({url}): {error}') from error - @staticmethod - def _parse_overwrite(info: dict) -> Overwrite: - return Overwrite( - id=int(info['id']), - type=OverwriteType(info['type']), - allow=Permissions(int(info['allow'])), - deny=Permissions(int(info['deny'])) - ) + def init_bot_channel(self) -> bool: + _, text_channel = self.list_channels() + breadtube_role: Role | None = None + everyone_role: Role | None = None + for role in self.guild_roles: + if role.name == self.config.bot_role: + breadtube_role = role + elif role.name == '@everyone': + everyone_role = role + if breadtube_role is None: + self.logger.info('No BreadTube role found') + return False + if everyone_role is None: + self.logger.info('No everyone role found') + return False - def _parse_channel_category(self, info: dict) -> ChannelCategory: - parent_id: str | None = info.get('parent_id') - return ChannelCategory( - id=int(info['id']), - guild_id=int(info['guild_id']), - position=int(info['position']), - permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']], - name=info.get('name'), - parent_id=int(parent_id) if parent_id is not None else None, - flags=ChannelFlags(info['flags']), - ) + breadtube_channel: TextChannel | None = None + for channel in text_channel: + if channel.name == self.config.bot_channel: + breadtube_channel = channel + self.logger.info('Found breadtube bot channel') + for perm in breadtube_channel.permission_overwrites: + if perm.id == breadtube_role.id: + if not perm.allow | Permissions.VIEW_CHANNEL: + self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing') + return False + self.logger.info('BreadTube channel permission OK') + break + messages = self.list_text_channel_messages(breadtube_channel) + for message in messages: + self.logger.debug('Deleting message: %s', message) + self.delete_message(message) + break + else: + breadtube_channel = self.create_text_channel({ + 'name': self.config.bot_channel, + 'permission_overwrites': [ + Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE, + deny=Permissions.VIEW_CHANNEL), + Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL, + deny=Permissions.NONE)] + }) + self.logger.info('Created breadtube bot channel') - def _parse_text_channel(self, info: dict) -> TextChannel: - parent_id: str | None = info.get('parent_id') - last_message_id: str | None = info.get('last_message_id') - last_pin_timestamp: str | None = info.get('last_pin_timestamp') - return TextChannel( - id=int(info['id']), - guild_id=int(info['guild_id']), - position=int(info['position']), - permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']], - name=info.get('name'), - topic=info.get('topic'), - nsfw=info['nsfw'], - last_message_id=int(last_message_id) if last_message_id is not None else None, - rate_limit_per_user=int(info['rate_limit_per_user']), - parent_id=int(parent_id) if parent_id is not None else None, - last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None), - flags=ChannelFlags(info['flags']), - ) + self.create_message( + breadtube_channel, + {'content': 'This is the current configuration used, upload a new one to update the configuration'}, + upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) + return True - @staticmethod - def _parse_user(info: dict) -> User: - return User( - id=int(info['id']), - username=info['username'], - discriminator=info['discriminator'], - global_name=info.get('global_name') - ) + def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: + headers, channel_info = self._send_request( + *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), + expected_code=201) + self._update_rate_limit(headers) + if not isinstance(channel_info, dict): + raise RuntimeError(f'Error creating channel with params (no info): {params}') + return TextChannel.from_dict(channel_info) - def _parse_message(self, info: dict) -> Message: - edited_timestamp: str | None = info.get('edited_timestamp') - return Message( - id=int(info['id']), - channel_id=int(info['channel_id']), - author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User( - id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), - content=info['content'], - timestamp=datetime.fromisoformat(info['timestamp']), - edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None - ) + def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, + upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message: + headers, message_info = self._send_request( + *Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(), + upload_files=upload_files) + self._update_rate_limit(headers) + if not isinstance(message_info, dict): + raise RuntimeError(f'Error creating message with params (no info): {params}') + return Message.from_dict(message_info) def delete_message(self, message: Message): - try: - headers, _ = self._send_request( - *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) - self._update_rate_limit(headers) - print(f'Message {message.id} deleted') - except RuntimeError as error: - print(error) + headers, _ = self._send_request( + *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) + self._update_rate_limit(headers) def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]: - headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id)) + headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id)) self._update_rate_limit(headers) categories: list[ChannelCategory] = [] text_channels: list[TextChannel] = [] - for channel in channels: - channel_type = ChannelType(channel['type']) - match channel_type: - case ChannelType.GUILD_CATEGORY: - categories.append(self._parse_channel_category(channel)) - case ChannelType.GUILD_TEXT: - text_channels.append(self._parse_text_channel(channel)) + if channels_info is not None: + for channel_info in channels_info: + channel_type = ChannelType(channel_info['type']) + match channel_type: + case ChannelType.GUILD_CATEGORY: + categories.append(ChannelCategory.from_dict(channel_info)) + case ChannelType.GUILD_TEXT: + text_channels.append(TextChannel.from_dict(channel_info)) return categories, text_channels - def list_text_channel_messages(self, channel: TextChannel) -> list: + def list_roles(self) -> list[Role]: + headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id)) + self._update_rate_limit(headers) + if not isinstance(roles_info, list): + raise RuntimeError(f'Error listing roles (not a list): {roles_info}') + return [Role.from_dict(r) for r in roles_info] + + def list_text_channel_messages(self, channel: TextChannel) -> list[Message]: headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id)) self._update_rate_limit(headers) - return [self._parse_message(m) for m in messages] + return [Message.from_dict(m) for m in messages or []] diff --git a/breadtube_bot/objects.py b/breadtube_bot/objects.py index 071301c..cb34a5c 100644 --- a/breadtube_bot/objects.py +++ b/breadtube_bot/objects.py @@ -1,8 +1,27 @@ +from __future__ import annotations + from dataclasses import dataclass from datetime import datetime from enum import Enum, IntFlag +class FileMime(Enum): + AUDIO_OGG = 'audio/ogg' + IMAGE_JPEG = 'image/jpeg' + IMAGE_PNG = 'image/png' + IMAGE_SVG = 'image/svg' + JSON = 'application/json' + PDF = 'application/pdf' + TEXT_CSV = 'text/csv' + TEXT_HTML = 'text/html' + TEXT_MARKDOWN = 'text/markdown' + TEXT_PLAIN = 'text/plain' + VIDEO_MP4 = 'video/mp4' + VIDEO_MPEG = 'video/mpeg' + VIDEO_WEBM = 'video/webm' + ZIP = 'application/zip' + + class ChannelType(Enum): GUILD_TEXT = 0 DM = 1 @@ -35,8 +54,9 @@ class OverwriteType(Enum): class Permissions(IntFlag): + NONE = 0 # Allows creation of instant invites - CREATE_INSTANT_INVITE = 1 << 0 + CREATE_INSTANT_INVITE = 1 # Allows kicking members KICK_MEMBERS = 1 << 1 # Allows banning members @@ -153,6 +173,14 @@ class Overwrite: allow: Permissions deny: Permissions + @staticmethod + def from_dict(info: dict) -> Overwrite: + return Overwrite( + id=int(info['id']), + type=OverwriteType(info['type']), + allow=Permissions(int(info['allow'])), + deny=Permissions(int(info['deny']))) + @dataclass class ChannelCategory: @@ -164,6 +192,18 @@ class ChannelCategory: parent_id: int | None flags: ChannelFlags + @staticmethod + def from_dict(info: dict) -> ChannelCategory: + parent_id: str | None = info.get('parent_id') + return ChannelCategory( + id=int(info['id']), + guild_id=int(info['guild_id']), + position=int(info['position']), + permission_overwrites=[Overwrite.from_dict(o) for o in info['permission_overwrites']], + name=info.get('name'), + parent_id=int(parent_id) if parent_id is not None else None, + flags=ChannelFlags(info['flags'])) + @dataclass class TextChannel: @@ -180,6 +220,25 @@ class TextChannel: last_pin_timestamp: datetime | None flags: ChannelFlags + @staticmethod + def from_dict(info: dict) -> TextChannel: + parent_id: str | None = info.get('parent_id') + last_message_id: str | None = info.get('last_message_id') + last_pin_timestamp: str | None = info.get('last_pin_timestamp') + return TextChannel( + id=int(info['id']), + guild_id=int(info['guild_id']), + position=int(info['position']), + permission_overwrites=[Overwrite.from_dict(o) for o in info['permission_overwrites']], + name=info.get('name'), + topic=info.get('topic'), + nsfw=info['nsfw'], + last_message_id=int(last_message_id) if last_message_id is not None else None, + rate_limit_per_user=int(info['rate_limit_per_user']), + parent_id=int(parent_id) if parent_id is not None else None, + last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None), + flags=ChannelFlags(info['flags'])) + @dataclass class User: # TODO : complete attributes @@ -188,6 +247,14 @@ class User: # TODO : complete attributes discriminator: str global_name: str | None + @staticmethod + def from_dict(info: dict) -> User: + return User( + id=int(info['id']), + username=info['username'], + discriminator=info['discriminator'], + global_name=info.get('global_name')) + @dataclass class Message: # TODO : complete attributes @@ -197,3 +264,96 @@ class Message: # TODO : complete attributes content: str timestamp: datetime edited_timestamp: datetime | None + + @staticmethod + def from_dict(info: dict) -> Message: + edited_timestamp: str | None = info.get('edited_timestamp') + return Message( + id=int(info['id']), + channel_id=int(info['channel_id']), + author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User( + id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), + content=info['content'], + timestamp=datetime.fromisoformat(info['timestamp']), + edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None) + + +@dataclass +class RoleColors: + primary_color: int + seconday_color: int | None + tertiary_color: int | None + + @staticmethod + def from_dict(info: dict) -> RoleColors: + seconday_color = info.get('secondary_color') + tertiary_color = info.get('tertiary_color') + return RoleColors( + primary_color=int(info['primary_color']), + seconday_color=int(seconday_color) if seconday_color is not None else None, + tertiary_color=int(tertiary_color) if tertiary_color is not None else None) + + +class RoleFlags(IntFlag): + NONE = 0 + # role can be selected by members in an onboarding prompt + IN_PROMPT = 1 + + +@dataclass +class RoleTags: + bot_id: int | None + intergration_id: int | None + premium_subscriber: bool + subcription_listing_id: int | None + available_for_purchase: bool + guild_connections: bool + + @staticmethod + def from_dict(info: dict) -> RoleTags: + bot_id = info.get('bot_id') + intergration_id = info.get('intergration_id') + subcription_listing_id = info.get('subcription_listing_id') + return RoleTags( + bot_id=int(bot_id) if bot_id is not None else None, + intergration_id=int(intergration_id) if intergration_id is not None else None, + premium_subscriber='premium_subscriber' in info, + subcription_listing_id=int(subcription_listing_id) if subcription_listing_id is not None else None, + available_for_purchase='available_for_purchase' in info, + guild_connections='guild_connections' in info) + + +@dataclass +class Role: + id: int # role id + name: str # role name + color: int # Deprecated integer representation of hexadecimal color code + colors: RoleColors # the role's colors + hoist: bool # if this role is pinned in the user listing + icon: str | None # role icon hash + unicode_emoji: str | None # role unicode emoji + position: int # position of this role (roles with the same position are sorted by id) + permissions: Permissions # permission bit set + managed: bool # whether this role is managed by an integration + mentionable: bool # whether this role is mentionable + tags: RoleTags | None # the tags this role has + flags: int # role flags combined as a bitfield + + @staticmethod + def from_dict(info: dict) -> Role: + tags = info.get('tags') + return Role( + id=int(info['id']), + name=info['name'], + color=int(info['color']), + colors=RoleColors.from_dict(info['colors']), + hoist=info['hoist'], + icon=info.get('icon'), + unicode_emoji=info.get('unicode_emoji'), + position=int(info['position']), + permissions=Permissions(int(info['permissions'])), + managed=info['managed'], + mentionable=info['mentionable'], + tags=RoleTags.from_dict(tags) if tags is not None else None, + flags=RoleFlags(int(info['flags'])) + ) diff --git a/pyproject.toml b/pyproject.toml index f42820a..0f5f408 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ pythonpath = ["."] preview = true select = ["A", "ARG", "B", "C", "E", "F", "FURB", "G", "I","ICN", "ISC", "PERF", "PIE", "PL", "PLE", "PTH", "Q", "RET", "RSE", "RUF", "SLF", "SIM", "T20", "TCH", "UP", "W"] -ignore = ["E275", "FURB140", "I001", "PERF203", "RET502", "RET503", "SIM105", "T201"] +ignore = ["E275", "FURB140", "I001", "PERF203", "RET502", "RET503", "SIM105"] [tool.ruff.lint.per-file-ignores] "tests/*" = ["SLF001"]