Config scan from bot channel implementation

This commit is contained in:
BreadTube 2025-09-23 22:48:35 +09:00 committed by Corentin
commit 157e8c1b17
6 changed files with 453 additions and 34 deletions

20
bot.py
View file

@ -1,13 +1,27 @@
from argparse import ArgumentParser
import logging
from pathlib import Path from pathlib import Path
from breadtube_bot.manager import DiscordManager from breadtube_bot.manager import DiscordManager
def main(): def main():
parser = ArgumentParser('BreadTube-bot')
parser.add_argument('--guild', type=int, default=1306964577812086824, help='Guild id to manage')
parser.add_argument('--debug', action='store_true', default=False, help='Run in debug mode (for logs)')
arguments = parser.parse_args()
debug_mode: bool = arguments.debug
guild_id: int = arguments.guild
del arguments
bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip() bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip()
guild_id = 1306964577812086824 manager = DiscordManager(
manager = DiscordManager(bot_token=bot_token, guild_id=guild_id) bot_token=bot_token, guild_id=guild_id, log_level=logging.DEBUG if debug_mode else logging.INFO)
print(manager.rate_limit) try:
manager.run()
except KeyboardInterrupt:
print('\r ') # noqa: T201
if __name__ == '__main__': if __name__ == '__main__':

View file

@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import TypedDict from typing import TypedDict
from breadtube_bot.objects import Overwrite from breadtube_bot.objects import MessageReference, Overwrite
class ApiVersion(Enum): class ApiVersion(Enum):
@ -116,7 +116,7 @@ class Api:
tts: bool # true if this is a TTS message tts: bool # true if this is a TTS message
# embeds: list[Embeded] # Up to 10 rich embeds (up to 6000 characters) # embeds: list[Embeded] # Up to 10 rich embeds (up to 6000 characters)
# allowed_mentions: MentionObject # Allowed mentions for the message # allowed_mentions: MentionObject # Allowed mentions for the message
# message_reference: MessageReference # Include to make your message a reply or a forward message_reference: MessageReference # Include to make your message a reply or a forward
# components: list[MessageComponent] # Components to include with the message # components: list[MessageComponent] # Components to include with the message
sticker_ids: list[int] # IDs of up to 3 stickers in the server to send in the message sticker_ids: list[int] # IDs of up to 3 stickers in the server to send in the message
# files[n]: FileContents # Contents of the file being sent. See Uploading Files # files[n]: FileContents # Contents of the file being sent. See Uploading Files
@ -136,7 +136,11 @@ class Api:
return ApiAction.DELETE, f'/channels/{channel_id}/messages/{message_id}' return ApiAction.DELETE, f'/channels/{channel_id}/messages/{message_id}'
@staticmethod @staticmethod
def list_by_channel(channel_id: int, limit: int | None = None) -> tuple[ApiAction, str]: def list_by_channel(channel_id: int) -> tuple[ApiAction, str]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
return ApiAction.GET, f'/channels/{channel_id}/messages' return ApiAction.GET, f'/channels/{channel_id}/messages'
class ListMessageParams(TypedDict, total=False):
around: int # Get messages around this message ID
before: int # Get messages before this message ID
after: int # Get messages after this message ID
limit: int # Max number of messages to return (1-100), default=50

View file

@ -1,3 +1,5 @@
from __future__ import annotations as _annotations
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
@ -5,17 +7,33 @@ from dataclasses import asdict, dataclass
class Config: class Config:
bot_channel: str = 'breadtube-bot' bot_channel: str = 'breadtube-bot'
bot_role: str = 'BreadTube' bot_role: str = 'BreadTube'
bot_channel_scan_interval: float = 30.
bot_channel_init_retries: int = 3 bot_channel_init_retries: int = 3
bot_message_duration: float = 150.
request_timeout: float = 3.
def to_str(self) -> str: def to_str(self) -> str:
return '\n'.join(['config', *[f'{k}={v}' for k, v in asdict(self).items()]]) return '\n'.join(['config', *[f'{k}={v}' for k, v in asdict(self).items()]])
def from_str(self, text: str): @staticmethod
def from_str(text: str) -> Config:
annotations = Config.__annotations__
global_types = globals()['__builtins__']
config = Config()
lines = text.strip().splitlines() lines = text.strip().splitlines()
if not lines: if not lines:
raise RuntimeError('Config cannot load: empty input') raise RuntimeError('Cannot load config: empty input')
if lines[0] != 'config': if lines[0] != 'config':
raise RuntimeError('Config cannot load: first line is not "config"') raise RuntimeError('Cannot load config: first line is not "config"')
for line in lines[1:]: config_dict = {}
for line_number, line in enumerate(lines[1:]):
key, value = line.split('=', maxsplit=1) key, value = line.split('=', maxsplit=1)
setattr(self, key, self.__annotations__[key](value)) if key not in annotations:
raise RuntimeError(f'Invalid config: invalid key {key} at line {line_number + 1}')
if key in config_dict:
raise RuntimeError(f'Invalid config: duplicated key {key} at line {line_number + 1}')
config_dict[key] = value
for key, value in config_dict.items():
setattr(config, key, global_types[annotations[key]](value))
return config

View file

@ -1,11 +1,15 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum from enum import Enum
import logging import logging
import operator
from pathlib import Path from pathlib import Path
import json import json
import random import random
import time import time
import tomllib import tomllib
from typing import Any
import urllib.error import urllib.error
import urllib.request import urllib.request
@ -13,7 +17,8 @@ from .api import Api, ApiAction, ApiVersion
from .config import Config from .config import Config
from .logger import create_logger from .logger import create_logger
from .objects import ( from .objects import (
ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel) Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
OverwriteType, Permissions, Role, TextChannel)
HTTPHeaders = dict[str, str] HTTPHeaders = dict[str, str]
@ -29,11 +34,20 @@ class ApiEncoder(json.JSONEncoder):
class DiscordManager: class DiscordManager:
MAX_CONFIG_SIZE: int = 50_000
DEFAULT_MESSAGE_LIST_LIMIT = 50
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.')
@dataclass @dataclass
class RateLimit: class RateLimit:
remaining: int remaining: int
next_reset: float next_reset: float
class Task(Enum):
SCAN_BOT_CHANNEL = 1
DELETE_MESSAGES = 2
@staticmethod @staticmethod
def _get_code_version() -> str: def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml' pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
@ -50,8 +64,12 @@ class DiscordManager:
self.rate_limit = self.RateLimit(remaining=1, next_reset=0) self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version() self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
self.logger.info('Retrieving guild roles before init')
self.guild_roles: list = self.list_roles() self.guild_roles: list = self.list_roles()
self.bot_channel: TextChannel | None = None
self.init_message: Message | None = None
for _ in range(self.config.bot_channel_init_retries): for _ in range(self.config.bot_channel_init_retries):
while not self.init_bot_channel(): while not self.init_bot_channel():
time.sleep(10) time.sleep(10)
@ -60,6 +78,10 @@ class DiscordManager:
self.logger.info('Bot init OK') self.logger.info('Bot init OK')
break break
raise RuntimeError("Couldn't initialize bot channel/role/permission") raise RuntimeError("Couldn't initialize bot channel/role/permission")
self._scan_bot_channel()
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
self.logger.info('Bot initialized') self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders): def _update_rate_limit(self, headers: HTTPHeaders):
@ -69,11 +91,11 @@ class DiscordManager:
return return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]: expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
timeout = 3
min_api_version = 9 min_api_version = 9
if api_version.value < min_api_version: if api_version.value < min_api_version:
@ -90,7 +112,8 @@ class DiscordManager:
+ f'\r\n--{boundary}'.encode()) if data else b'' + f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files): for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";' data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode() data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data) request = urllib.request.Request(url, data=data)
request.method = api_action.value request.method = api_action.value
@ -102,7 +125,7 @@ class DiscordManager:
request.add_header('Content-Type', 'application/json') request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}') request.add_header('Authorization', f'Bot {self._bot_token}')
try: try:
with urllib.request.urlopen(request, timeout=timeout) as response: with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code: if response.status != expected_code:
raise RuntimeError( raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
@ -114,6 +137,23 @@ class DiscordManager:
except urllib.error.URLError as error: except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error raise RuntimeError(f'URL error calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def init_bot_channel(self) -> bool: def init_bot_channel(self) -> bool:
_, text_channel = self.list_channels() _, text_channel = self.list_channels()
breadtube_role: Role | None = None breadtube_role: Role | None = None
@ -130,25 +170,20 @@ class DiscordManager:
self.logger.info('No everyone role found') self.logger.info('No everyone role found')
return False return False
breadtube_channel: TextChannel | None = None
for channel in text_channel: for channel in text_channel:
if channel.name == self.config.bot_channel: if channel.name == self.config.bot_channel:
breadtube_channel = channel self.bot_channel = channel
self.logger.info('Found breadtube bot channel') self.logger.info('Found breadtube bot channel')
for perm in breadtube_channel.permission_overwrites: for perm in self.bot_channel.permission_overwrites:
if perm.id == breadtube_role.id: if perm.id == breadtube_role.id:
if not perm.allow | Permissions.VIEW_CHANNEL: if not perm.allow | Permissions.VIEW_CHANNEL:
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing') self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
return False return False
self.logger.info('BreadTube channel permission OK') self.logger.info('BreadTube channel permission OK')
break break
messages = self.list_text_channel_messages(breadtube_channel)
for message in messages:
self.logger.debug('Deleting message: %s', message)
self.delete_message(message)
break break
else: else:
breadtube_channel = self.create_text_channel({ self.bot_channel = self.create_text_channel({
'name': self.config.bot_channel, 'name': self.config.bot_channel,
'permission_overwrites': [ 'permission_overwrites': [
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE, Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
@ -157,13 +192,152 @@ class DiscordManager:
deny=Permissions.NONE)] deny=Permissions.NONE)]
}) })
self.logger.info('Created breadtube bot channel') self.logger.info('Created breadtube bot channel')
self.create_message(
breadtube_channel,
{'content': 'This is the current configuration used, upload a new one to update the configuration'},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
return True return True
def _scan_bot_channel(self):
if self.bot_channel is None:
self.logger.error('Cannot scan bot channel: bot channel is None')
return []
last_message_id: int | None = None
while True:
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break
last_message_id = messages[-1].id
messages = sorted(messages, key=lambda x: x.timestamp)
self.init_message = None
new_config: Config | None = None
messages_to_delete: list[Message] = []
for message in messages:
# Skip message to be deleted
skip = True
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
m.id == message.id for m in messages_to_delete)):
self.logger.debug('Skipping message already marked to be deleted')
break
else:
skip = False
if skip:
continue
delete_message = True
for attachment in message.attachments:
if attachment.size < self.MAX_CONFIG_SIZE:
try:
_, content = self._download_attachment(attachment)
if content.startswith(b'config'):
try:
config = Config.from_str(content.decode())
if config != self.config:
new_config = config
elif message.content == self.INIT_MESSAGE:
if self.init_message is not None:
self.logger.debug('Deleting duplicated init message')
try:
self.delete_message(self.init_message)
except RuntimeError as error:
self.logger.error('Error deleting init_message while scanning: %s', error)
self.init_message = message
delete_message = False
break
except RuntimeError as error:
self.logger.info('Invalid config file: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
except Exception as error:
self.logger.error('Error downloading attachment: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
if delete_message:
if any(m.id == message.id for m in messages_to_delete):
self.logger.warning(
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
else:
self.logger.debug('Deleting message: %s', message)
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after scanned: %s', error)
if new_config is not None:
self.logger.info('Loading new config: %s', new_config)
self.config = new_config
if self.init_message is not None:
self.delete_message(self.init_message)
self.init_message = None
if self.init_message is None:
self.init_message = self.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
if messages_to_delete:
self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration,
messages_to_delete))
def run(self):
while True:
if self.tasks:
self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True)
task_type, task_time, task_params = self.tasks.pop()
sleep_time = task_time - time.time()
self.logger.debug(
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time)
if sleep_time > 0:
time.sleep(sleep_time)
match task_type:
case DiscordManager.Task.SCAN_BOT_CHANNEL:
try:
self._scan_bot_channel()
except Exception as error:
self.logger.error('Error scanning bot channel: %s', error)
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.DELETE_MESSAGES:
if not isinstance(task_params, list):
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
elif not task_params:
self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params)
elif any(not isinstance(v, Message) for v in task_params):
self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params)
else:
for message in task_params:
try:
self.delete_message(message)
except Exception as error:
self.logger.error('Error deleting message %s: %s', message, error)
if self.rate_limit.remaining <= 1:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second')
time.sleep(sleep_time)
time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
headers, channel_info = self._send_request( headers, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
@ -210,7 +384,19 @@ class DiscordManager:
raise RuntimeError(f'Error listing roles (not a list): {roles_info}') raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info] return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(self, channel: TextChannel) -> list[Message]: def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id)) after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
headers, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id),
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
self._update_rate_limit(headers) self._update_rate_limit(headers)
return [Message.from_dict(m) for m in messages or []] return [Message.from_dict(m) for m in messages_info or []]

View file

@ -16,11 +16,16 @@ class FileMime(Enum):
TEXT_HTML = 'text/html' TEXT_HTML = 'text/html'
TEXT_MARKDOWN = 'text/markdown' TEXT_MARKDOWN = 'text/markdown'
TEXT_PLAIN = 'text/plain' TEXT_PLAIN = 'text/plain'
UNKNOWN = 'application/unknown'
VIDEO_MP4 = 'video/mp4' VIDEO_MP4 = 'video/mp4'
VIDEO_MPEG = 'video/mpeg' VIDEO_MPEG = 'video/mpeg'
VIDEO_WEBM = 'video/webm' VIDEO_WEBM = 'video/webm'
ZIP = 'application/zip' ZIP = 'application/zip'
@classmethod
def _missing_(cls, value): # noqa: ARG003
return FileMime.UNKNOWN
class ChannelType(Enum): class ChannelType(Enum):
GUILD_TEXT = 0 GUILD_TEXT = 0
@ -256,6 +261,144 @@ class User: # TODO : complete attributes
global_name=info.get('global_name')) global_name=info.get('global_name'))
class AttachmentFlags(IntFlag):
IS_REMIX = 1 << 2 # this attachment has been edited using the remix feature on mobile
@dataclass
class Attachment:
id: int # attachment id
filename: str # name of file attached
title: str | None # the title of the file
description: str | None # description for the file (max 1024 characters)
content_type: str | None # the attachment's media type
size: int # size of file in bytes
url: str # source url of file
proxy_url: str # a proxied url of file
height: int | None # height of file (if image)
width: int | None # width of file (if image)
ephemeral: bool | None # whether this attachment is ephemeral
duration_secs: float | None # the duration of the audio file (currently for voice messages)
waveform: str | None # base64 encoded bytearray representing a sampled waveform (currently for voice messages)
flags: int | None # attachment flags combined as a bitfield
@staticmethod
def from_dict(info: dict) -> Attachment:
height = info.get('height')
width = info.get('width')
duraction_secs = info.get('duration_secs')
flags = info.get('flags')
return Attachment(
id=int(info['id']),
filename=info['filename'],
title=info.get('title'),
description=info.get('description'),
content_type=info.get('content_type'),
size=int(info['size']),
url=info['url'],
proxy_url=info['proxy_url'],
height=int(height) if height is not None else None,
width=int(width) if width is not None else None,
ephemeral=info.get('ephemeral'),
duration_secs=float(duraction_secs) if duraction_secs is not None else None,
waveform=info.get('waveform'),
flags=AttachmentFlags(int(flags)) if flags is not None else None,
)
class MessageType(Enum):
DEFAULT = 0
RECIPIENT_ADD = 1
RECIPIENT_REMOVE = 2
CALL = 3
CHANNEL_NAME_CHANGE = 4
CHANNEL_ICON_CHANGE = 5
CHANNEL_PINNED_MESSAGE = 6
USER_JOIN = 7
GUILD_BOOST = 8
GUILD_BOOST_TIER_1 = 9
GUILD_BOOST_TIER_2 = 10
GUILD_BOOST_TIER_3 = 11
CHANNEL_FOLLOW_ADD = 12
GUILD_DISCOVERY_DISQUALIFIED = 14
GUILD_DISCOVERY_REQUALIFIED = 15
GUILD_DISCOVERY_GRACE_PERIOD_INITIAL_WARNING = 16
GUILD_DISCOVERY_GRACE_PERIOD_FINAL_WARNING = 17
THREAD_CREATED = 18
REPLY = 19
CHAT_INPUT_COMMAND = 20
THREAD_STARTER_MESSAGE = 21
GUILD_INVITE_REMINDER = 22
CONTEXT_MENU_COMMAND = 23
AUTO_MODERATION_ACTION = 24
ROLE_SUBSCRIPTION_PURCHASE = 25
INTERACTION_PREMIUM_UPSELL = 26
STAGE_START = 27
STAGE_END = 28
STAGE_SPEAKER = 29
STAGE_TOPIC = 31
GUILD_APPLICATION_PREMIUM_SUBSCRIPTION = 32
GUILD_INCIDENT_ALERT_MODE_ENABLED = 36
GUILD_INCIDENT_ALERT_MODE_DISABLED = 37
GUILD_INCIDENT_REPORT_RAID = 38
GUILD_INCIDENT_REPORT_FALSE_ALARM = 39
PURCHASE_NOTIFICATION = 44
POLL_RESULT = 46
NON_DELETABLE_MESSAGE_TYPES = [
MessageType.RECIPIENT_ADD,
MessageType.RECIPIENT_REMOVE,
MessageType.CALL,
MessageType.CHANNEL_NAME_CHANGE,
MessageType.CHANNEL_ICON_CHANGE,
MessageType.THREAD_STARTER_MESSAGE]
class MessageFlags(IntFlag):
NONE = 0
CROSSPOSTED = 1 << 0 # this message has been published to subscribed channels (via Channel Following)
IS_CROSSPOST = 1 << 1 # this message originated from a message in another channel (via Channel Following)
SUPPRESS_EMBEDS = 1 << 2 # do not include any embeds when serializing this message
SOURCE_MESSAGE_DELETED = 1 << 3 # the source message for this crosspost has been deleted (via Channel Following)
URGENT = 1 << 4 # this message came from the urgent message system
HAS_THREAD = 1 << 5 # this message has an associated thread, with the same id as the message
EPHEMERAL = 1 << 6 # this message is only visible to the user who invoked the Interaction
LOADING = 1 << 7 # this message is an Interaction Response and the bot is "thinking"
# this message failed to mention some roles and add their members to the thread
FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8
SUPPRESS_NOTIFICATIONS = 1 << 12 # this message will not trigger push and desktop notifications
IS_VOICE_MESSAGE = 1 << 13 # this message is a voice message
HAS_SNAPSHOT = 1 << 14 # this message has a snapshot (via Message Forwarding)
IS_COMPONENTS_V2 = 1 << 15 # allows you to create fully component-driven messages
class MessageReferenceType(Enum):
DEFAULT = 0 # A standard reference used by replies.
FORWARD = 1 # Reference used to point to a message at a point in time.
@dataclass
class MessageReference:
type: MessageReferenceType | None # type of reference.
message_id: int | None # id of the originating message
channel_id: int | None # id of the originating message's channel
guild_id: int | None # id of the originating message's guild
# when sending, whether to error if the referenced message doesn't exist
# instead of sending as a normal (non-reply) message, default true
fail_if_not_exists: bool | None
@staticmethod
def from_dict(info: dict) -> MessageReference:
ref_type: int | None = info.get('type')
return MessageReference(
type=MessageReferenceType(ref_type) if ref_type is not None else None,
message_id=info.get('message_id'),
channel_id=info.get('channel_id'),
guild_id=info.get('guild_id'),
fail_if_not_exists=info.get('fail_if_not_exists'))
@dataclass @dataclass
class Message: # TODO : complete attributes class Message: # TODO : complete attributes
id: int id: int
@ -264,10 +407,16 @@ class Message: # TODO : complete attributes
content: str content: str
timestamp: datetime timestamp: datetime
edited_timestamp: datetime | None edited_timestamp: datetime | None
attachments: list[Attachment]
type: MessageType
flags: MessageFlags | None
message_reference: MessageReference | None
@staticmethod @staticmethod
def from_dict(info: dict) -> Message: def from_dict(info: dict) -> Message:
edited_timestamp: str | None = info.get('edited_timestamp') edited_timestamp: str | None = info.get('edited_timestamp')
flags: int | None = info.get('flags')
message_reference = info.get('message_reference')
return Message( return Message(
id=int(info['id']), id=int(info['id']),
channel_id=int(info['channel_id']), channel_id=int(info['channel_id']),
@ -275,7 +424,11 @@ class Message: # TODO : complete attributes
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
content=info['content'], content=info['content'],
timestamp=datetime.fromisoformat(info['timestamp']), timestamp=datetime.fromisoformat(info['timestamp']),
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None) edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,
attachments=[Attachment.from_dict(a) for a in info['attachments']],
type=MessageType(info['type']),
flags=MessageFlags(flags) if flags is not None else None,
message_reference=MessageReference.from_dict(message_reference) if message_reference is not None else None)
@dataclass @dataclass

44
tests/test_config.py Normal file
View file

@ -0,0 +1,44 @@
from breadtube_bot.config import Config
import pytest
def test_empty():
with pytest.raises(RuntimeError, match='Cannot load config: empty input'):
Config.from_str('')
def test_wrong_header():
with pytest.raises(RuntimeError, match='Cannot load config: first line is not "config"'):
Config.from_str('connfig')
def test_lacking_key():
expected_config = Config(bot_role='test-role', request_timeout=5.5)
config = Config.from_str(
'config\n'
f'request_timeout={expected_config.request_timeout}\n'
f'bot_role={expected_config.bot_role}')
assert config == expected_config
def test_wrong_key():
with pytest.raises(RuntimeError, match=('Invalid config: invalid key bot_channel_init at line 2')):
Config.from_str('config\nrequest_timeout=3.\nbot_channel_init=2')
def test_duplicated_key():
with pytest.raises(RuntimeError, match=('Invalid config: duplicated key request_timeout at line 3')):
Config.from_str('config\nrequest_timeout=3.\nbot_channel_init_retries=2\nrequest_timeout=5.')
def test_correct_config():
expected_config = Config(bot_channel='test-channel', bot_role='test-role', bot_channel_init_retries=298,
request_timeout=5.5)
config = Config.from_str(
'config\n'
f'request_timeout={expected_config.request_timeout}\n'
f'bot_channel_init_retries={expected_config.bot_channel_init_retries}\n'
f'bot_channel={expected_config.bot_channel}\n'
f'bot_role={expected_config.bot_role}')
assert config == expected_config