Code refactored

This commit is contained in:
BreadTube 2025-09-29 18:49:49 +09:00 committed by Corentin
commit d5b3436aec
3 changed files with 247 additions and 213 deletions

View file

@ -1,530 +0,0 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import logging
import operator
from pathlib import Path
import json
import random
import re
import time
import tomllib
from typing import Any
import urllib.error
import urllib.request
import traceback
from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType,
Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
from .youtube_manager import YoutubeManager
from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions
class ApiEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o) # type: ignore
if isinstance(o, Enum):
return o.value
return super().default(o)
class DiscordManager:
DEFAULT_MESSAGE_LIST_LIMIT = 50
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.')
MAX_DOWNLOAD_SIZE: int = 50_000
MIN_API_VERSION = 9
@dataclass
class RateLimit:
remaining: int
next_reset: float
class Task(Enum):
DELETE_MESSAGES = 1
SCAN_BOT_CHANNEL = 2
INIT_SUBS = 3
@staticmethod
def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
if not pyproject_path.exists():
raise RuntimeError('Cannot current bot version')
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
log_level: int = logging.INFO) -> None:
self.config = config or Config()
self.guild_id = guild_id
self._bot_token = bot_token
self.logger = create_logger('breadtube', log_level, stdout=True)
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
self.logger.info('Retrieving bot user')
self.bot_user = self.get_current_user()
self.logger.info('Retrieving guild roles before init')
self.guild_roles: list[Role] = self.list_roles()
bot_role: Role | None = None
everyone_role: Role | None = None
for role in self.guild_roles:
if role.name == self.config.bot_role:
bot_role = role
elif role.name == '@everyone':
everyone_role = role
if bot_role is None:
raise RuntimeError('No BreadTube role found')
if everyone_role is None:
raise RuntimeError('No everyone role found')
self.bot_role: Role = bot_role
self.everyone_role: Role = everyone_role
categories, text_channel = self.list_channels()
self.guild_text_channels: list[TextChannel] = text_channel
self.guild_categories: list[ChannelCategory] = categories
self.init_message: Message | None = None
bot_channel: TextChannel | None = None
for _ in range(self.config.bot_channel_init_retries):
bot_channel = self.init_bot_channel()
if bot_channel is not None:
break
time.sleep(5)
if bot_channel is None:
raise RuntimeError("Couldn't initialize bot channel/role/permission")
self.bot_channel: TextChannel = bot_channel
self._yt_subscriptions: Subscriptions = {}
self._scan_bot_channel()
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
self.logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
if api_version.value < self.MIN_API_VERSION:
self.logger.warning(
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
api_version, self.MIN_API_VERSION)
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
self.logger.debug('Discord API Request: %s %s', api_action.value, url)
boundary: str = ''
if upload_files:
boundary = f'{random.randbytes(16).hex()}'
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
'Content-Type: application/json\r\n\r\n'.encode() + data
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Accept', 'application/json')
if upload_files:
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
else:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read()
headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def init_bot_channel(self) -> TextChannel | None:
for channel in self.guild_text_channels:
if channel.name == self.config.bot_channel:
self.logger.info('Found breadtube bot channel')
for perm in channel.permission_overwrites:
if perm.id == self.bot_role.id:
if not perm.allow | Permissions.VIEW_CHANNEL:
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
return None
self.logger.info('BreadTube channel permission OK')
break
return channel
self.logger.info('Creating breadtube bot channel')
return self.create_text_channel({
'name': self.config.bot_channel,
'permission_overwrites': [
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
deny=Permissions.VIEW_CHANNEL),
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
deny=Permissions.NONE)]
})
def _get_bot_channel_messages(self) -> list[Message]:
messages_id_delete_task: set[int] = set()
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES:
messages_id_delete_task.update(message.id for message in task_params)
last_message_id: int | None = None
messages: list[Message] = []
while True:
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break
last_message_id = message_batch[-1].id
return messages
def _scan_bot_channel(self): # noqa: PLR0915
messages = self._get_bot_channel_messages()
init_message_found = False
new_config: Config | None = None
new_subscriptions: Subscriptions | None = None
delayed_delete: dict[int, Message] = {}
immediate_delete: dict[int, Message] = {}
for message in messages:
if init_message_found:
self.logger.debug('Marking message for immediate deletion (init found): %s', message)
immediate_delete[message.id] = message
continue
if len(message.attachments) <= 0:
self.logger.debug('Marking message for immediate deletion (no attachment): %s', message)
immediate_delete[message.id] = message
continue
if message.author.id == self.bot_user.id:
self.logger.debug('Found init message')
# If same init message: nothing to do
if self.init_message is not None and message.id == self.init_message.id:
continue
# Loading init message content
has_error = False
for attachment in message.attachments:
try:
_, content = self._download_attachment(attachment)
if new_config is None and content.startswith(b'config'):
try:
self.config = Config.from_str(content.decode())
except RuntimeError as error:
self.logger.error('Cannot load config from init message: %s', error)
has_error = True
if new_subscriptions is None and content.startswith(SUBSCRIPTION_FILE_COLUMNS[0]):
try:
subscriptions = SubscriptionHelper.read_text(content)
if set(subscriptions.keys()) != set(self._yt_subscriptions.keys()):
SubscriptionHelper.update_subscriptions(
new=subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = subscriptions
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
except RuntimeError as error:
self.logger.error('Invalid init subscriptions file: %s', error)
has_error = True
except Exception as error:
self.logger.error('Error downloading attachment: %s', error)
has_error = True
if not has_error:
self.init_message = message
init_message_found = True
continue
self.logger.debug('Reading attachment')
attachment = message.attachments[0]
if attachment.size > self.MAX_DOWNLOAD_SIZE:
self.logger.debug('Marking message for immediate deletion (attachment too big): %s', message)
immediate_delete[message.id] = message
continue
try:
_, content = self._download_attachment(attachment)
if content.startswith(b'config') and new_config is None:
try:
config = Config.from_str(content.decode())
if config != self.config: # New config to update to
new_config = config
self.logger.debug('Marking new config message for immediate deletion: %s', message)
immediate_delete[message.id] = message
continue
except RuntimeError as error:
self.logger.info('Invalid config file: %s', error)
bot_message = self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)})
delayed_delete[bot_message.id] = bot_message
delayed_delete[message.id] = message
continue
if content.startswith(SUBSCRIPTION_FILE_COLUMNS[0]):
try:
subscriptions = SubscriptionHelper.read_text(content)
if set(subscriptions.keys()) != set(self._yt_subscriptions.keys()):
new_subscriptions = subscriptions
self.logger.debug('Marking new subscriptions message for immediate deletion: %s', message)
immediate_delete[message.id] = message
continue
except RuntimeError as error:
self.logger.info('Invalid subscriptions file: %s', error)
bot_message = self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)})
delayed_delete[bot_message.id] = bot_message
delayed_delete[message.id] = message
continue
except Exception as error:
self.logger.error('Error downloading attachment: %s', error)
if new_config is not None:
self.logger.info('Loading config: %s', new_config)
self.config = new_config
if new_subscriptions is not None:
self.logger.info('Loading subscriptions: %s', new_subscriptions)
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = new_subscriptions
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
# New init message is needed, previous need to be deleted
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
immediate_delete[self.init_message.id] = self.init_message
self.init_message = None
# Init message absent or deleted
if self.init_message is None:
assert self.config is not None
self.init_message = self.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE},
upload_files=[
('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()),
('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions))
])
for message in immediate_delete.values():
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
if delayed_delete:
self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration,
list(delayed_delete.values())))
def _init_subs(self):
categories, text_channel = self.list_channels()
self.guild_text_channels = text_channel
self.guild_categories = categories
channel_dict: dict[str, TextChannel] = {c.name or '': c for c in self.guild_text_channels}
unmanaged_categories: set[str] = set(self.config.unmanaged_categories.split(','))
category_ranges: list[tuple[int, int, ChannelCategory]] = []
for category in self.guild_categories:
if category.name in unmanaged_categories:
self.logger.debug('Skipping unmanaged category: %s', category.name)
continue
range_info = (category.name or '').split('-')
if len(range_info) != 2: # noqa: PLR2004
self.logger.warning('Cannot compute range for category: %s', category.name)
continue
category_ranges.append((ord(range_info[0].lower()), ord(range_info[1].lower()), category))
category_ranges = sorted(category_ranges, key=operator.itemgetter(0))
name_regex = r'([^a-z])'
for sub_info in self._yt_subscriptions.values():
discord_name = sub_info.name.lower()
discord_name = re.sub(name_regex, '-', discord_name)
category_value = ord(discord_name[0])
sub_channel: TextChannel | None = channel_dict.get(discord_name)
if sub_channel is None:
selected_category: ChannelCategory | None = None
for start_range, stop_range, category in category_ranges:
if start_range <= category_value <= stop_range:
selected_category = category
break
if selected_category is None:
selected_category = category_ranges[-1][2]
sub_channel = self.create_text_channel({
'name': discord_name,
'parent_id': selected_category.id,
'permission_overwrites': [
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
deny=Permissions.SEND_MESSAGES),
Overwrite(self.bot_role.id, OverwriteType.ROLE,
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
deny=Permissions.NONE)]
})
if sub_info.channel_info is None:
_, channel_info = self.yt_manager.request_channel_info(
sub_info.channel_id, request_timeout=self.config.request_timeout)
if not channel_info.items:
self.logger.error('No channel info return from YouTube API for channel: %s', discord_name)
continue
sub_info.channel_info = channel_info.items[0].snippet
channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}'
_ = self.create_message(sub_channel, {'content': channel_url})
sub_info.last_update = time.time()
def run(self):
while True:
if self.tasks:
self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True)
task_type, task_time, task_params = self.tasks.pop()
sleep_time = task_time - time.time()
self.logger.debug(
'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params)
if sleep_time > 0:
time.sleep(sleep_time)
match task_type:
case DiscordManager.Task.DELETE_MESSAGES:
if not isinstance(task_params, list):
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
elif not task_params:
self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params)
elif any(not isinstance(v, Message) for v in task_params):
self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params)
else:
for message in task_params:
try:
self.delete_message(message)
except Exception as error:
self.logger.error('Error deleting message %s: %s -> %s',
message, error, traceback.format_exc().replace('\n', ' | '))
case DiscordManager.Task.SCAN_BOT_CHANNEL:
try:
self._scan_bot_channel()
except Exception as error:
self.logger.error('Error scanning bot channel: %s -> %s',
error, traceback.format_exc().replace('\n', ' | '))
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.INIT_SUBS:
try:
self._init_subs()
except Exception as error:
self.logger.error('Error initializing subscriptions : %s -> %s',
error, traceback.format_exc().replace('\n', ' | '))
time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
_, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
expected_code=201)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
_, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
upload_files=upload_files)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message):
_, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
def get_current_user(self) -> User:
_, user_info = self._send_request(*Api.User.get_current())
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_roles(self) -> list[Role]:
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
_, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id),
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
return [Message.from_dict(m) for m in messages_info or []]