Config scan from bot channel implementation
This commit is contained in:
parent
72edbe6599
commit
157e8c1b17
6 changed files with 453 additions and 34 deletions
20
bot.py
20
bot.py
|
|
@ -1,13 +1,27 @@
|
|||
from argparse import ArgumentParser
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from breadtube_bot.manager import DiscordManager
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser('BreadTube-bot')
|
||||
parser.add_argument('--guild', type=int, default=1306964577812086824, help='Guild id to manage')
|
||||
parser.add_argument('--debug', action='store_true', default=False, help='Run in debug mode (for logs)')
|
||||
arguments = parser.parse_args()
|
||||
|
||||
debug_mode: bool = arguments.debug
|
||||
guild_id: int = arguments.guild
|
||||
del arguments
|
||||
|
||||
bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip()
|
||||
guild_id = 1306964577812086824
|
||||
manager = DiscordManager(bot_token=bot_token, guild_id=guild_id)
|
||||
print(manager.rate_limit)
|
||||
manager = DiscordManager(
|
||||
bot_token=bot_token, guild_id=guild_id, log_level=logging.DEBUG if debug_mode else logging.INFO)
|
||||
try:
|
||||
manager.run()
|
||||
except KeyboardInterrupt:
|
||||
print('\r ') # noqa: T201
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from enum import Enum
|
||||
from typing import TypedDict
|
||||
|
||||
from breadtube_bot.objects import Overwrite
|
||||
from breadtube_bot.objects import MessageReference, Overwrite
|
||||
|
||||
|
||||
class ApiVersion(Enum):
|
||||
|
|
@ -116,7 +116,7 @@ class Api:
|
|||
tts: bool # true if this is a TTS message
|
||||
# embeds: list[Embeded] # Up to 10 rich embeds (up to 6000 characters)
|
||||
# allowed_mentions: MentionObject # Allowed mentions for the message
|
||||
# message_reference: MessageReference # Include to make your message a reply or a forward
|
||||
message_reference: MessageReference # Include to make your message a reply or a forward
|
||||
# components: list[MessageComponent] # Components to include with the message
|
||||
sticker_ids: list[int] # IDs of up to 3 stickers in the server to send in the message
|
||||
# files[n]: FileContents # Contents of the file being sent. See Uploading Files
|
||||
|
|
@ -136,7 +136,11 @@ class Api:
|
|||
return ApiAction.DELETE, f'/channels/{channel_id}/messages/{message_id}'
|
||||
|
||||
@staticmethod
|
||||
def list_by_channel(channel_id: int, limit: int | None = None) -> tuple[ApiAction, str]:
|
||||
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
||||
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
||||
def list_by_channel(channel_id: int) -> tuple[ApiAction, str]:
|
||||
return ApiAction.GET, f'/channels/{channel_id}/messages'
|
||||
|
||||
class ListMessageParams(TypedDict, total=False):
|
||||
around: int # Get messages around this message ID
|
||||
before: int # Get messages before this message ID
|
||||
after: int # Get messages after this message ID
|
||||
limit: int # Max number of messages to return (1-100), default=50
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations as _annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
|
||||
|
|
@ -5,17 +7,33 @@ from dataclasses import asdict, dataclass
|
|||
class Config:
|
||||
bot_channel: str = 'breadtube-bot'
|
||||
bot_role: str = 'BreadTube'
|
||||
bot_channel_scan_interval: float = 30.
|
||||
bot_channel_init_retries: int = 3
|
||||
bot_message_duration: float = 150.
|
||||
request_timeout: float = 3.
|
||||
|
||||
def to_str(self) -> str:
|
||||
return '\n'.join(['config', *[f'{k}={v}' for k, v in asdict(self).items()]])
|
||||
|
||||
def from_str(self, text: str):
|
||||
@staticmethod
|
||||
def from_str(text: str) -> Config:
|
||||
annotations = Config.__annotations__
|
||||
global_types = globals()['__builtins__']
|
||||
config = Config()
|
||||
lines = text.strip().splitlines()
|
||||
if not lines:
|
||||
raise RuntimeError('Config cannot load: empty input')
|
||||
raise RuntimeError('Cannot load config: empty input')
|
||||
if lines[0] != 'config':
|
||||
raise RuntimeError('Config cannot load: first line is not "config"')
|
||||
for line in lines[1:]:
|
||||
raise RuntimeError('Cannot load config: first line is not "config"')
|
||||
config_dict = {}
|
||||
for line_number, line in enumerate(lines[1:]):
|
||||
key, value = line.split('=', maxsplit=1)
|
||||
setattr(self, key, self.__annotations__[key](value))
|
||||
if key not in annotations:
|
||||
raise RuntimeError(f'Invalid config: invalid key {key} at line {line_number + 1}')
|
||||
if key in config_dict:
|
||||
raise RuntimeError(f'Invalid config: duplicated key {key} at line {line_number + 1}')
|
||||
config_dict[key] = value
|
||||
|
||||
for key, value in config_dict.items():
|
||||
setattr(config, key, global_types[annotations[key]](value))
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -1,11 +1,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
import operator
|
||||
from pathlib import Path
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import tomllib
|
||||
from typing import Any
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
|
|
@ -13,7 +17,8 @@ from .api import Api, ApiAction, ApiVersion
|
|||
from .config import Config
|
||||
from .logger import create_logger
|
||||
from .objects import (
|
||||
ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel)
|
||||
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
|
||||
OverwriteType, Permissions, Role, TextChannel)
|
||||
|
||||
|
||||
HTTPHeaders = dict[str, str]
|
||||
|
|
@ -29,11 +34,20 @@ class ApiEncoder(json.JSONEncoder):
|
|||
|
||||
|
||||
class DiscordManager:
|
||||
MAX_CONFIG_SIZE: int = 50_000
|
||||
DEFAULT_MESSAGE_LIST_LIMIT = 50
|
||||
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
|
||||
'You can upload a new one to update the configuration.')
|
||||
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
remaining: int
|
||||
next_reset: float
|
||||
|
||||
class Task(Enum):
|
||||
SCAN_BOT_CHANNEL = 1
|
||||
DELETE_MESSAGES = 2
|
||||
|
||||
@staticmethod
|
||||
def _get_code_version() -> str:
|
||||
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
|
||||
|
|
@ -50,8 +64,12 @@ class DiscordManager:
|
|||
|
||||
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
||||
self.version = self._get_code_version()
|
||||
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
||||
|
||||
self.logger.info('Retrieving guild roles before init')
|
||||
self.guild_roles: list = self.list_roles()
|
||||
self.bot_channel: TextChannel | None = None
|
||||
self.init_message: Message | None = None
|
||||
for _ in range(self.config.bot_channel_init_retries):
|
||||
while not self.init_bot_channel():
|
||||
time.sleep(10)
|
||||
|
|
@ -60,6 +78,10 @@ class DiscordManager:
|
|||
self.logger.info('Bot init OK')
|
||||
break
|
||||
raise RuntimeError("Couldn't initialize bot channel/role/permission")
|
||||
|
||||
self._scan_bot_channel()
|
||||
self.tasks.append((
|
||||
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
||||
self.logger.info('Bot initialized')
|
||||
|
||||
def _update_rate_limit(self, headers: HTTPHeaders):
|
||||
|
|
@ -69,11 +91,11 @@ class DiscordManager:
|
|||
return
|
||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
||||
|
||||
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
||||
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
|
||||
timeout = 3
|
||||
min_api_version = 9
|
||||
|
||||
if api_version.value < min_api_version:
|
||||
|
|
@ -90,7 +112,8 @@ class DiscordManager:
|
|||
+ f'\r\n--{boundary}'.encode()) if data else b''
|
||||
for file_index, (name, mime, content) in enumerate(upload_files):
|
||||
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
||||
f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
||||
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||
data += f'\r\n--{boundary}--'.encode()
|
||||
request = urllib.request.Request(url, data=data)
|
||||
request.method = api_action.value
|
||||
|
|
@ -102,7 +125,7 @@ class DiscordManager:
|
|||
request.add_header('Content-Type', 'application/json')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||
|
|
@ -114,6 +137,23 @@ class DiscordManager:
|
|||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||
|
||||
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
||||
request = urllib.request.Request(attachment.url)
|
||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
||||
return dict(response.getheaders()), response.read()
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error downloading attachment ({attachment}):'
|
||||
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
||||
|
||||
def init_bot_channel(self) -> bool:
|
||||
_, text_channel = self.list_channels()
|
||||
breadtube_role: Role | None = None
|
||||
|
|
@ -130,25 +170,20 @@ class DiscordManager:
|
|||
self.logger.info('No everyone role found')
|
||||
return False
|
||||
|
||||
breadtube_channel: TextChannel | None = None
|
||||
for channel in text_channel:
|
||||
if channel.name == self.config.bot_channel:
|
||||
breadtube_channel = channel
|
||||
self.bot_channel = channel
|
||||
self.logger.info('Found breadtube bot channel')
|
||||
for perm in breadtube_channel.permission_overwrites:
|
||||
for perm in self.bot_channel.permission_overwrites:
|
||||
if perm.id == breadtube_role.id:
|
||||
if not perm.allow | Permissions.VIEW_CHANNEL:
|
||||
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
|
||||
return False
|
||||
self.logger.info('BreadTube channel permission OK')
|
||||
break
|
||||
messages = self.list_text_channel_messages(breadtube_channel)
|
||||
for message in messages:
|
||||
self.logger.debug('Deleting message: %s', message)
|
||||
self.delete_message(message)
|
||||
break
|
||||
else:
|
||||
breadtube_channel = self.create_text_channel({
|
||||
self.bot_channel = self.create_text_channel({
|
||||
'name': self.config.bot_channel,
|
||||
'permission_overwrites': [
|
||||
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
|
|
@ -157,13 +192,152 @@ class DiscordManager:
|
|||
deny=Permissions.NONE)]
|
||||
})
|
||||
self.logger.info('Created breadtube bot channel')
|
||||
|
||||
self.create_message(
|
||||
breadtube_channel,
|
||||
{'content': 'This is the current configuration used, upload a new one to update the configuration'},
|
||||
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
||||
return True
|
||||
|
||||
def _scan_bot_channel(self):
|
||||
if self.bot_channel is None:
|
||||
self.logger.error('Cannot scan bot channel: bot channel is None')
|
||||
return []
|
||||
|
||||
last_message_id: int | None = None
|
||||
while True:
|
||||
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
||||
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
||||
break
|
||||
last_message_id = messages[-1].id
|
||||
messages = sorted(messages, key=lambda x: x.timestamp)
|
||||
|
||||
self.init_message = None
|
||||
new_config: Config | None = None
|
||||
messages_to_delete: list[Message] = []
|
||||
for message in messages:
|
||||
# Skip message to be deleted
|
||||
skip = True
|
||||
for task_type, _, task_params in self.tasks:
|
||||
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
|
||||
m.id == message.id for m in messages_to_delete)):
|
||||
self.logger.debug('Skipping message already marked to be deleted')
|
||||
break
|
||||
else:
|
||||
skip = False
|
||||
if skip:
|
||||
continue
|
||||
|
||||
delete_message = True
|
||||
for attachment in message.attachments:
|
||||
if attachment.size < self.MAX_CONFIG_SIZE:
|
||||
try:
|
||||
_, content = self._download_attachment(attachment)
|
||||
if content.startswith(b'config'):
|
||||
try:
|
||||
config = Config.from_str(content.decode())
|
||||
if config != self.config:
|
||||
new_config = config
|
||||
elif message.content == self.INIT_MESSAGE:
|
||||
if self.init_message is not None:
|
||||
self.logger.debug('Deleting duplicated init message')
|
||||
try:
|
||||
self.delete_message(self.init_message)
|
||||
except RuntimeError as error:
|
||||
self.logger.error('Error deleting init_message while scanning: %s', error)
|
||||
self.init_message = message
|
||||
delete_message = False
|
||||
break
|
||||
except RuntimeError as error:
|
||||
self.logger.info('Invalid config file: %s', error)
|
||||
messages_to_delete.extend([
|
||||
self.create_message(self.bot_channel, {
|
||||
'content': str(error),
|
||||
'message_reference': MessageReference(
|
||||
type=MessageReferenceType.DEFAULT,
|
||||
message_id=message.id,
|
||||
channel_id=self.bot_channel.id,
|
||||
guild_id=None,
|
||||
fail_if_not_exists=None)}),
|
||||
message])
|
||||
delete_message = False
|
||||
break
|
||||
except Exception as error:
|
||||
self.logger.error('Error downloading attachment: %s', error)
|
||||
messages_to_delete.extend([
|
||||
self.create_message(self.bot_channel, {
|
||||
'content': str(error),
|
||||
'message_reference': MessageReference(
|
||||
type=MessageReferenceType.DEFAULT,
|
||||
message_id=message.id,
|
||||
channel_id=self.bot_channel.id,
|
||||
guild_id=None,
|
||||
fail_if_not_exists=None)}),
|
||||
message])
|
||||
delete_message = False
|
||||
break
|
||||
if delete_message:
|
||||
if any(m.id == message.id for m in messages_to_delete):
|
||||
self.logger.warning(
|
||||
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
|
||||
else:
|
||||
self.logger.debug('Deleting message: %s', message)
|
||||
try:
|
||||
self.delete_message(message)
|
||||
except RuntimeError as error:
|
||||
self.logger.error('Error deleting after scanned: %s', error)
|
||||
|
||||
if new_config is not None:
|
||||
self.logger.info('Loading new config: %s', new_config)
|
||||
self.config = new_config
|
||||
if self.init_message is not None:
|
||||
self.delete_message(self.init_message)
|
||||
self.init_message = None
|
||||
|
||||
if self.init_message is None:
|
||||
self.init_message = self.create_message(
|
||||
self.bot_channel, {'content': self.INIT_MESSAGE},
|
||||
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
||||
|
||||
if messages_to_delete:
|
||||
self.tasks.append((
|
||||
DiscordManager.Task.DELETE_MESSAGES,
|
||||
time.time() + self.config.bot_message_duration,
|
||||
messages_to_delete))
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
if self.tasks:
|
||||
self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True)
|
||||
task_type, task_time, task_params = self.tasks.pop()
|
||||
sleep_time = task_time - time.time()
|
||||
self.logger.debug(
|
||||
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
match task_type:
|
||||
case DiscordManager.Task.SCAN_BOT_CHANNEL:
|
||||
try:
|
||||
self._scan_bot_channel()
|
||||
except Exception as error:
|
||||
self.logger.error('Error scanning bot channel: %s', error)
|
||||
self.tasks.append((
|
||||
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
||||
case DiscordManager.Task.DELETE_MESSAGES:
|
||||
if not isinstance(task_params, list):
|
||||
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
|
||||
elif not task_params:
|
||||
self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params)
|
||||
elif any(not isinstance(v, Message) for v in task_params):
|
||||
self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params)
|
||||
else:
|
||||
for message in task_params:
|
||||
try:
|
||||
self.delete_message(message)
|
||||
except Exception as error:
|
||||
self.logger.error('Error deleting message %s: %s', message, error)
|
||||
if self.rate_limit.remaining <= 1:
|
||||
sleep_time = self.rate_limit.next_reset - time.time()
|
||||
if sleep_time > 0:
|
||||
self.logger.debug('Rate limit: sleeping %.03f second')
|
||||
time.sleep(sleep_time)
|
||||
time.sleep(1)
|
||||
|
||||
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
||||
headers, channel_info = self._send_request(
|
||||
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||
|
|
@ -210,7 +384,19 @@ class DiscordManager:
|
|||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||
return [Role.from_dict(r) for r in roles_info]
|
||||
|
||||
def list_text_channel_messages(self, channel: TextChannel) -> list[Message]:
|
||||
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
|
||||
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
|
||||
after_id: int | None = None) -> list[Message]:
|
||||
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
||||
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
||||
params: Api.Message.ListMessageParams = {}
|
||||
if limit is not None:
|
||||
params['limit'] = limit
|
||||
if before_id is not None:
|
||||
params['before'] = before_id
|
||||
if after_id is not None:
|
||||
params['after'] = after_id
|
||||
headers, messages_info = self._send_request(
|
||||
*Api.Message.list_by_channel(channel.id),
|
||||
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
||||
self._update_rate_limit(headers)
|
||||
return [Message.from_dict(m) for m in messages or []]
|
||||
return [Message.from_dict(m) for m in messages_info or []]
|
||||
|
|
|
|||
|
|
@ -16,11 +16,16 @@ class FileMime(Enum):
|
|||
TEXT_HTML = 'text/html'
|
||||
TEXT_MARKDOWN = 'text/markdown'
|
||||
TEXT_PLAIN = 'text/plain'
|
||||
UNKNOWN = 'application/unknown'
|
||||
VIDEO_MP4 = 'video/mp4'
|
||||
VIDEO_MPEG = 'video/mpeg'
|
||||
VIDEO_WEBM = 'video/webm'
|
||||
ZIP = 'application/zip'
|
||||
|
||||
@classmethod
|
||||
def _missing_(cls, value): # noqa: ARG003
|
||||
return FileMime.UNKNOWN
|
||||
|
||||
|
||||
class ChannelType(Enum):
|
||||
GUILD_TEXT = 0
|
||||
|
|
@ -256,6 +261,144 @@ class User: # TODO : complete attributes
|
|||
global_name=info.get('global_name'))
|
||||
|
||||
|
||||
class AttachmentFlags(IntFlag):
|
||||
IS_REMIX = 1 << 2 # this attachment has been edited using the remix feature on mobile
|
||||
|
||||
|
||||
@dataclass
|
||||
class Attachment:
|
||||
id: int # attachment id
|
||||
filename: str # name of file attached
|
||||
title: str | None # the title of the file
|
||||
description: str | None # description for the file (max 1024 characters)
|
||||
content_type: str | None # the attachment's media type
|
||||
size: int # size of file in bytes
|
||||
url: str # source url of file
|
||||
proxy_url: str # a proxied url of file
|
||||
height: int | None # height of file (if image)
|
||||
width: int | None # width of file (if image)
|
||||
ephemeral: bool | None # whether this attachment is ephemeral
|
||||
duration_secs: float | None # the duration of the audio file (currently for voice messages)
|
||||
waveform: str | None # base64 encoded bytearray representing a sampled waveform (currently for voice messages)
|
||||
flags: int | None # attachment flags combined as a bitfield
|
||||
|
||||
@staticmethod
|
||||
def from_dict(info: dict) -> Attachment:
|
||||
height = info.get('height')
|
||||
width = info.get('width')
|
||||
duraction_secs = info.get('duration_secs')
|
||||
flags = info.get('flags')
|
||||
return Attachment(
|
||||
id=int(info['id']),
|
||||
filename=info['filename'],
|
||||
title=info.get('title'),
|
||||
description=info.get('description'),
|
||||
content_type=info.get('content_type'),
|
||||
size=int(info['size']),
|
||||
url=info['url'],
|
||||
proxy_url=info['proxy_url'],
|
||||
height=int(height) if height is not None else None,
|
||||
width=int(width) if width is not None else None,
|
||||
ephemeral=info.get('ephemeral'),
|
||||
duration_secs=float(duraction_secs) if duraction_secs is not None else None,
|
||||
waveform=info.get('waveform'),
|
||||
flags=AttachmentFlags(int(flags)) if flags is not None else None,
|
||||
)
|
||||
|
||||
|
||||
class MessageType(Enum):
|
||||
DEFAULT = 0
|
||||
RECIPIENT_ADD = 1
|
||||
RECIPIENT_REMOVE = 2
|
||||
CALL = 3
|
||||
CHANNEL_NAME_CHANGE = 4
|
||||
CHANNEL_ICON_CHANGE = 5
|
||||
CHANNEL_PINNED_MESSAGE = 6
|
||||
USER_JOIN = 7
|
||||
GUILD_BOOST = 8
|
||||
GUILD_BOOST_TIER_1 = 9
|
||||
GUILD_BOOST_TIER_2 = 10
|
||||
GUILD_BOOST_TIER_3 = 11
|
||||
CHANNEL_FOLLOW_ADD = 12
|
||||
GUILD_DISCOVERY_DISQUALIFIED = 14
|
||||
GUILD_DISCOVERY_REQUALIFIED = 15
|
||||
GUILD_DISCOVERY_GRACE_PERIOD_INITIAL_WARNING = 16
|
||||
GUILD_DISCOVERY_GRACE_PERIOD_FINAL_WARNING = 17
|
||||
THREAD_CREATED = 18
|
||||
REPLY = 19
|
||||
CHAT_INPUT_COMMAND = 20
|
||||
THREAD_STARTER_MESSAGE = 21
|
||||
GUILD_INVITE_REMINDER = 22
|
||||
CONTEXT_MENU_COMMAND = 23
|
||||
AUTO_MODERATION_ACTION = 24
|
||||
ROLE_SUBSCRIPTION_PURCHASE = 25
|
||||
INTERACTION_PREMIUM_UPSELL = 26
|
||||
STAGE_START = 27
|
||||
STAGE_END = 28
|
||||
STAGE_SPEAKER = 29
|
||||
STAGE_TOPIC = 31
|
||||
GUILD_APPLICATION_PREMIUM_SUBSCRIPTION = 32
|
||||
GUILD_INCIDENT_ALERT_MODE_ENABLED = 36
|
||||
GUILD_INCIDENT_ALERT_MODE_DISABLED = 37
|
||||
GUILD_INCIDENT_REPORT_RAID = 38
|
||||
GUILD_INCIDENT_REPORT_FALSE_ALARM = 39
|
||||
PURCHASE_NOTIFICATION = 44
|
||||
POLL_RESULT = 46
|
||||
|
||||
|
||||
NON_DELETABLE_MESSAGE_TYPES = [
|
||||
MessageType.RECIPIENT_ADD,
|
||||
MessageType.RECIPIENT_REMOVE,
|
||||
MessageType.CALL,
|
||||
MessageType.CHANNEL_NAME_CHANGE,
|
||||
MessageType.CHANNEL_ICON_CHANGE,
|
||||
MessageType.THREAD_STARTER_MESSAGE]
|
||||
|
||||
|
||||
class MessageFlags(IntFlag):
|
||||
NONE = 0
|
||||
CROSSPOSTED = 1 << 0 # this message has been published to subscribed channels (via Channel Following)
|
||||
IS_CROSSPOST = 1 << 1 # this message originated from a message in another channel (via Channel Following)
|
||||
SUPPRESS_EMBEDS = 1 << 2 # do not include any embeds when serializing this message
|
||||
SOURCE_MESSAGE_DELETED = 1 << 3 # the source message for this crosspost has been deleted (via Channel Following)
|
||||
URGENT = 1 << 4 # this message came from the urgent message system
|
||||
HAS_THREAD = 1 << 5 # this message has an associated thread, with the same id as the message
|
||||
EPHEMERAL = 1 << 6 # this message is only visible to the user who invoked the Interaction
|
||||
LOADING = 1 << 7 # this message is an Interaction Response and the bot is "thinking"
|
||||
# this message failed to mention some roles and add their members to the thread
|
||||
FAILED_TO_MENTION_SOME_ROLES_IN_THREAD = 1 << 8
|
||||
SUPPRESS_NOTIFICATIONS = 1 << 12 # this message will not trigger push and desktop notifications
|
||||
IS_VOICE_MESSAGE = 1 << 13 # this message is a voice message
|
||||
HAS_SNAPSHOT = 1 << 14 # this message has a snapshot (via Message Forwarding)
|
||||
IS_COMPONENTS_V2 = 1 << 15 # allows you to create fully component-driven messages
|
||||
|
||||
|
||||
class MessageReferenceType(Enum):
|
||||
DEFAULT = 0 # A standard reference used by replies.
|
||||
FORWARD = 1 # Reference used to point to a message at a point in time.
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageReference:
|
||||
type: MessageReferenceType | None # type of reference.
|
||||
message_id: int | None # id of the originating message
|
||||
channel_id: int | None # id of the originating message's channel
|
||||
guild_id: int | None # id of the originating message's guild
|
||||
# when sending, whether to error if the referenced message doesn't exist
|
||||
# instead of sending as a normal (non-reply) message, default true
|
||||
fail_if_not_exists: bool | None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(info: dict) -> MessageReference:
|
||||
ref_type: int | None = info.get('type')
|
||||
return MessageReference(
|
||||
type=MessageReferenceType(ref_type) if ref_type is not None else None,
|
||||
message_id=info.get('message_id'),
|
||||
channel_id=info.get('channel_id'),
|
||||
guild_id=info.get('guild_id'),
|
||||
fail_if_not_exists=info.get('fail_if_not_exists'))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Message: # TODO : complete attributes
|
||||
id: int
|
||||
|
|
@ -264,10 +407,16 @@ class Message: # TODO : complete attributes
|
|||
content: str
|
||||
timestamp: datetime
|
||||
edited_timestamp: datetime | None
|
||||
attachments: list[Attachment]
|
||||
type: MessageType
|
||||
flags: MessageFlags | None
|
||||
message_reference: MessageReference | None
|
||||
|
||||
@staticmethod
|
||||
def from_dict(info: dict) -> Message:
|
||||
edited_timestamp: str | None = info.get('edited_timestamp')
|
||||
flags: int | None = info.get('flags')
|
||||
message_reference = info.get('message_reference')
|
||||
return Message(
|
||||
id=int(info['id']),
|
||||
channel_id=int(info['channel_id']),
|
||||
|
|
@ -275,7 +424,11 @@ class Message: # TODO : complete attributes
|
|||
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
|
||||
content=info['content'],
|
||||
timestamp=datetime.fromisoformat(info['timestamp']),
|
||||
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None)
|
||||
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,
|
||||
attachments=[Attachment.from_dict(a) for a in info['attachments']],
|
||||
type=MessageType(info['type']),
|
||||
flags=MessageFlags(flags) if flags is not None else None,
|
||||
message_reference=MessageReference.from_dict(message_reference) if message_reference is not None else None)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
44
tests/test_config.py
Normal file
44
tests/test_config.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
from breadtube_bot.config import Config
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def test_empty():
|
||||
with pytest.raises(RuntimeError, match='Cannot load config: empty input'):
|
||||
Config.from_str('')
|
||||
|
||||
|
||||
def test_wrong_header():
|
||||
with pytest.raises(RuntimeError, match='Cannot load config: first line is not "config"'):
|
||||
Config.from_str('connfig')
|
||||
|
||||
|
||||
def test_lacking_key():
|
||||
expected_config = Config(bot_role='test-role', request_timeout=5.5)
|
||||
config = Config.from_str(
|
||||
'config\n'
|
||||
f'request_timeout={expected_config.request_timeout}\n'
|
||||
f'bot_role={expected_config.bot_role}')
|
||||
assert config == expected_config
|
||||
|
||||
|
||||
def test_wrong_key():
|
||||
with pytest.raises(RuntimeError, match=('Invalid config: invalid key bot_channel_init at line 2')):
|
||||
Config.from_str('config\nrequest_timeout=3.\nbot_channel_init=2')
|
||||
|
||||
|
||||
def test_duplicated_key():
|
||||
with pytest.raises(RuntimeError, match=('Invalid config: duplicated key request_timeout at line 3')):
|
||||
Config.from_str('config\nrequest_timeout=3.\nbot_channel_init_retries=2\nrequest_timeout=5.')
|
||||
|
||||
|
||||
def test_correct_config():
|
||||
expected_config = Config(bot_channel='test-channel', bot_role='test-role', bot_channel_init_retries=298,
|
||||
request_timeout=5.5)
|
||||
config = Config.from_str(
|
||||
'config\n'
|
||||
f'request_timeout={expected_config.request_timeout}\n'
|
||||
f'bot_channel_init_retries={expected_config.bot_channel_init_retries}\n'
|
||||
f'bot_channel={expected_config.bot_channel}\n'
|
||||
f'bot_role={expected_config.bot_role}')
|
||||
assert config == expected_config
|
||||
Loading…
Add table
Add a link
Reference in a new issue