Code refactored

This commit is contained in:
BreadTube 2025-09-29 18:49:49 +09:00 committed by Corentin
commit d5b3436aec
3 changed files with 247 additions and 213 deletions

View file

@ -1,51 +1,30 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum from enum import Enum
import logging import logging
import operator import operator
from pathlib import Path from pathlib import Path
import json
import random
import re import re
import time import time
import tomllib import tomllib
from typing import Any from typing import Any
import urllib.error
import urllib.request
import traceback import traceback
from .api import Api, ApiAction, ApiVersion
from .config import Config from .config import Config
from .discord_manager import DiscordManager
from .logger import create_logger from .logger import create_logger
from .objects import ( from .objects import (ChannelCategory, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType, OverwriteType, Permissions, Role, TextChannel)
Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
from .youtube_manager import YoutubeManager from .youtube_manager import YoutubeManager
from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions
class ApiEncoder(json.JSONEncoder): class Bot:
def default(self, o):
if is_dataclass(o):
return asdict(o) # type: ignore
if isinstance(o, Enum):
return o.value
return super().default(o)
class DiscordManager:
DEFAULT_MESSAGE_LIST_LIMIT = 50 DEFAULT_MESSAGE_LIST_LIMIT = 50
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n' INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.') 'You can upload a new one to update the configuration.')
MAX_DOWNLOAD_SIZE: int = 50_000 MAX_DOWNLOAD_SIZE: int = 50_000
MIN_API_VERSION = 9
@dataclass
class RateLimit:
remaining: int
next_reset: float
class Task(Enum): class Task(Enum):
DELETE_MESSAGES = 1 DELETE_MESSAGES = 1
@ -60,20 +39,20 @@ class DiscordManager:
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version'] return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None, def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
log_level: int = logging.INFO) -> None: log_level: int = logging.INFO):
self.config = config or Config() self.config = config or Config()
self.guild_id = guild_id self.guild_id = guild_id
self._bot_token = bot_token
self.logger = create_logger('breadtube', log_level, stdout=True) self.logger = create_logger('breadtube', log_level, stdout=True)
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version() self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = [] self.discord_manager = DiscordManager(bot_token=bot_token, bot_version=self.version, logger=self.logger)
self.tasks: list[tuple[Bot.Task, float, Any]] = []
self.logger.info('Retrieving bot user') self.logger.info('Retrieving bot user')
self.bot_user = self.get_current_user() self.bot_user = self.discord_manager.get_current_user(request_timeout=self.config.request_timeout)
self.logger.info('Retrieving guild roles before init') self.logger.info('Retrieving guild roles before init')
self.guild_roles: list[Role] = self.list_roles() self.guild_roles: list[Role] = self.discord_manager.list_roles(
self.guild_id, request_timeout=self.config.request_timeout)
bot_role: Role | None = None bot_role: Role | None = None
everyone_role: Role | None = None everyone_role: Role | None = None
for role in self.guild_roles: for role in self.guild_roles:
@ -88,7 +67,8 @@ class DiscordManager:
self.bot_role: Role = bot_role self.bot_role: Role = bot_role
self.everyone_role: Role = everyone_role self.everyone_role: Role = everyone_role
categories, text_channel = self.list_channels() categories, text_channel = self.discord_manager.list_channels(
self.guild_id, request_timeout=self.config.request_timeout)
self.guild_text_channels: list[TextChannel] = text_channel self.guild_text_channels: list[TextChannel] = text_channel
self.guild_categories: list[ChannelCategory] = categories self.guild_categories: list[ChannelCategory] = categories
self.init_message: Message | None = None self.init_message: Message | None = None
@ -111,84 +91,6 @@ class DiscordManager:
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger) self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
self.logger.info('Bot initialized') self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
self.logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
if api_version.value < self.MIN_API_VERSION:
self.logger.warning(
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
api_version, self.MIN_API_VERSION)
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
self.logger.debug('Discord API Request: %s %s', api_action.value, url)
boundary: str = ''
if upload_files:
boundary = f'{random.randbytes(16).hex()}'
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
'Content-Type: application/json\r\n\r\n'.encode() + data
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Accept', 'application/json')
if upload_files:
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
else:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read()
headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def init_bot_channel(self) -> TextChannel | None: def init_bot_channel(self) -> TextChannel | None:
for channel in self.guild_text_channels: for channel in self.guild_text_channels:
if channel.name == self.config.bot_channel: if channel.name == self.config.bot_channel:
@ -203,14 +105,15 @@ class DiscordManager:
return channel return channel
self.logger.info('Creating breadtube bot channel') self.logger.info('Creating breadtube bot channel')
return self.create_text_channel({ return self.discord_manager.create_text_channel(
self.guild_id, {
'name': self.config.bot_channel, 'name': self.config.bot_channel,
'permission_overwrites': [ 'permission_overwrites': [
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE, Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
deny=Permissions.VIEW_CHANNEL), deny=Permissions.VIEW_CHANNEL),
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL, Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
deny=Permissions.NONE)] deny=Permissions.NONE)]},
}) request_timeout=self.config.request_timeout)
def _get_bot_channel_messages(self) -> list[Message]: def _get_bot_channel_messages(self) -> list[Message]:
messages_id_delete_task: set[int] = set() messages_id_delete_task: set[int] = set()
@ -221,7 +124,8 @@ class DiscordManager:
last_message_id: int | None = None last_message_id: int | None = None
messages: list[Message] = [] messages: list[Message] = []
while True: while True:
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) message_batch = self.discord_manager.list_text_channel_messages(
self.bot_channel, request_timeout=self.config.request_timeout, after_id=last_message_id)
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task]) messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT: if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break break
@ -255,7 +159,8 @@ class DiscordManager:
has_error = False has_error = False
for attachment in message.attachments: for attachment in message.attachments:
try: try:
_, content = self._download_attachment(attachment) _, content = self.discord_manager.download_attachment(
attachment, request_timeout=self.config.request_timeout)
if new_config is None and content.startswith(b'config'): if new_config is None and content.startswith(b'config'):
try: try:
self.config = Config.from_str(content.decode()) self.config = Config.from_str(content.decode())
@ -269,7 +174,7 @@ class DiscordManager:
SubscriptionHelper.update_subscriptions( SubscriptionHelper.update_subscriptions(
new=subscriptions, previous=self._yt_subscriptions) new=subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = subscriptions self._yt_subscriptions = subscriptions
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None)) self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
except RuntimeError as error: except RuntimeError as error:
self.logger.error('Invalid init subscriptions file: %s', error) self.logger.error('Invalid init subscriptions file: %s', error)
has_error = True has_error = True
@ -288,7 +193,8 @@ class DiscordManager:
immediate_delete[message.id] = message immediate_delete[message.id] = message
continue continue
try: try:
_, content = self._download_attachment(attachment) _, content = self.discord_manager.download_attachment(
attachment, request_timeout=self.config.request_timeout)
if content.startswith(b'config') and new_config is None: if content.startswith(b'config') and new_config is None:
try: try:
config = Config.from_str(content.decode()) config = Config.from_str(content.decode())
@ -299,14 +205,14 @@ class DiscordManager:
continue continue
except RuntimeError as error: except RuntimeError as error:
self.logger.info('Invalid config file: %s', error) self.logger.info('Invalid config file: %s', error)
bot_message = self.create_message(self.bot_channel, { bot_message = self.discord_manager.create_message(self.bot_channel, {
'content': str(error), 'content': str(error),
'message_reference': MessageReference( 'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT, type=MessageReferenceType.DEFAULT,
message_id=message.id, message_id=message.id,
channel_id=self.bot_channel.id, channel_id=self.bot_channel.id,
guild_id=None, guild_id=None,
fail_if_not_exists=None)}) fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
delayed_delete[bot_message.id] = bot_message delayed_delete[bot_message.id] = bot_message
delayed_delete[message.id] = message delayed_delete[message.id] = message
continue continue
@ -320,14 +226,14 @@ class DiscordManager:
continue continue
except RuntimeError as error: except RuntimeError as error:
self.logger.info('Invalid subscriptions file: %s', error) self.logger.info('Invalid subscriptions file: %s', error)
bot_message = self.create_message(self.bot_channel, { bot_message = self.discord_manager.create_message(self.bot_channel, {
'content': str(error), 'content': str(error),
'message_reference': MessageReference( 'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT, type=MessageReferenceType.DEFAULT,
message_id=message.id, message_id=message.id,
channel_id=self.bot_channel.id, channel_id=self.bot_channel.id,
guild_id=None, guild_id=None,
fail_if_not_exists=None)}) fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
delayed_delete[bot_message.id] = bot_message delayed_delete[bot_message.id] = bot_message
delayed_delete[message.id] = message delayed_delete[message.id] = message
continue continue
@ -342,7 +248,7 @@ class DiscordManager:
self.logger.info('Loading subscriptions: %s', new_subscriptions) self.logger.info('Loading subscriptions: %s', new_subscriptions)
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions) SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = new_subscriptions self._yt_subscriptions = new_subscriptions
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None)) self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
# New init message is needed, previous need to be deleted # New init message is needed, previous need to be deleted
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None: if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
@ -352,8 +258,8 @@ class DiscordManager:
# Init message absent or deleted # Init message absent or deleted
if self.init_message is None: if self.init_message is None:
assert self.config is not None assert self.config is not None
self.init_message = self.create_message( self.init_message = self.discord_manager.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE}, self.bot_channel, {'content': self.INIT_MESSAGE}, request_timeout=self.config.request_timeout,
upload_files=[ upload_files=[
('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()), ('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()),
('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions)) ('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions))
@ -361,18 +267,19 @@ class DiscordManager:
for message in immediate_delete.values(): for message in immediate_delete.values():
try: try:
self.delete_message(message) self.discord_manager.delete_message(message, request_timeout=self.config.request_timeout)
except RuntimeError as error: except RuntimeError as error:
self.logger.error('Error deleting after bot channel scan (immediate): %s', error) self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
if delayed_delete: if delayed_delete:
self.tasks.append(( self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES, Bot.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration, time.time() + self.config.bot_message_duration,
list(delayed_delete.values()))) list(delayed_delete.values())))
def _init_subs(self): def _init_subs(self):
categories, text_channel = self.list_channels() categories, text_channel = self.discord_manager.list_channels(
self.guild_id, request_timeout=self.config.request_timeout)
self.guild_text_channels = text_channel self.guild_text_channels = text_channel
self.guild_categories = categories self.guild_categories = categories
@ -404,7 +311,8 @@ class DiscordManager:
break break
if selected_category is None: if selected_category is None:
selected_category = category_ranges[-1][2] selected_category = category_ranges[-1][2]
sub_channel = self.create_text_channel({ sub_channel = self.discord_manager.create_text_channel(
self.guild_id, {
'name': discord_name, 'name': discord_name,
'parent_id': selected_category.id, 'parent_id': selected_category.id,
'permission_overwrites': [ 'permission_overwrites': [
@ -412,8 +320,8 @@ class DiscordManager:
deny=Permissions.SEND_MESSAGES), deny=Permissions.SEND_MESSAGES),
Overwrite(self.bot_role.id, OverwriteType.ROLE, Overwrite(self.bot_role.id, OverwriteType.ROLE,
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES, allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
deny=Permissions.NONE)] deny=Permissions.NONE)]},
}) request_timeout=self.config.request_timeout)
if sub_info.channel_info is None: if sub_info.channel_info is None:
_, channel_info = self.yt_manager.request_channel_info( _, channel_info = self.yt_manager.request_channel_info(
sub_info.channel_id, request_timeout=self.config.request_timeout) sub_info.channel_id, request_timeout=self.config.request_timeout)
@ -422,7 +330,8 @@ class DiscordManager:
continue continue
sub_info.channel_info = channel_info.items[0].snippet sub_info.channel_info = channel_info.items[0].snippet
channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}' channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}'
_ = self.create_message(sub_channel, {'content': channel_url}) _ = self.discord_manager.create_message(
sub_channel, {'content': channel_url}, request_timeout=self.config.request_timeout)
sub_info.last_update = time.time() sub_info.last_update = time.time()
def run(self): def run(self):
@ -436,7 +345,7 @@ class DiscordManager:
if sleep_time > 0: if sleep_time > 0:
time.sleep(sleep_time) time.sleep(sleep_time)
match task_type: match task_type:
case DiscordManager.Task.DELETE_MESSAGES: case Bot.Task.DELETE_MESSAGES:
if not isinstance(task_params, list): if not isinstance(task_params, list):
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params) self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
elif not task_params: elif not task_params:
@ -446,11 +355,12 @@ class DiscordManager:
else: else:
for message in task_params: for message in task_params:
try: try:
self.delete_message(message) self.discord_manager.delete_message(
message, request_timeout=self.config.request_timeout)
except Exception as error: except Exception as error:
self.logger.error('Error deleting message %s: %s -> %s', self.logger.error('Error deleting message %s: %s -> %s',
message, error, traceback.format_exc().replace('\n', ' | ')) message, error, traceback.format_exc().replace('\n', ' | '))
case DiscordManager.Task.SCAN_BOT_CHANNEL: case Bot.Task.SCAN_BOT_CHANNEL:
try: try:
self._scan_bot_channel() self._scan_bot_channel()
except Exception as error: except Exception as error:
@ -458,73 +368,10 @@ class DiscordManager:
error, traceback.format_exc().replace('\n', ' | ')) error, traceback.format_exc().replace('\n', ' | '))
self.tasks.append(( self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.INIT_SUBS: case Bot.Task.INIT_SUBS:
try: try:
self._init_subs() self._init_subs()
except Exception as error: except Exception as error:
self.logger.error('Error initializing subscriptions : %s -> %s', self.logger.error('Error initializing subscriptions : %s -> %s',
error, traceback.format_exc().replace('\n', ' | ')) error, traceback.format_exc().replace('\n', ' | '))
time.sleep(1) time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
_, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
expected_code=201)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
_, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
upload_files=upload_files)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message):
_, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
def get_current_user(self) -> User:
_, user_info = self._send_request(*Api.User.get_current())
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_roles(self) -> list[Role]:
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
_, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id),
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
return [Message.from_dict(m) for m in messages_info or []]

View file

@ -0,0 +1,187 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import json
import random
import time
import urllib.error
import urllib.request
from .api import Api, ApiAction, ApiVersion
from .objects import Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, Role, TextChannel, User
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import logging
class ApiEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o) # type: ignore
if isinstance(o, Enum):
return o.value
return super().default(o)
class DiscordManager:
MIN_API_VERSION = 9
@dataclass
class RateLimit:
remaining: int
next_reset: float
def __init__(self, bot_token: str, bot_version: str, logger: logging.Logger):
self._bot_token = bot_token
self._logger = logger
self._version = bot_version
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
self._logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self._logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self._logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(
self, api_action: ApiAction, endpoint: str, request_timeout: float, data: bytes | None = None,
expected_code: int = 200, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
api_version: ApiVersion = ApiVersion.V10) -> tuple[HTTPHeaders, dict | list | None]:
if api_version.value < self.MIN_API_VERSION:
self._logger.warning(
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
api_version, self.MIN_API_VERSION)
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
self._logger.debug('Discord API Request: %s %s', api_action.value, url)
boundary: str = ''
if upload_files:
boundary = f'{random.randbytes(16).hex()}'
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
'Content-Type: application/json\r\n\r\n'.encode() + data
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
request.add_header('User-Agent', f'BreadTube (v{self._version})')
request.add_header('Accept', 'application/json')
if upload_files:
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
else:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read()
headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def download_attachment(self, attachment: Attachment, request_timeout: float,
expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self._version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def create_text_channel(
self, guild_id: int, params: Api.Guild.CreateTextChannelParams, request_timeout: float) -> TextChannel:
_, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=guild_id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, request_timeout: float,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
_, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message, request_timeout: float):
_, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), request_timeout=request_timeout,
expected_code=204)
def get_current_user(self, request_timeout: float) -> User:
_, user_info = self._send_request(*Api.User.get_current(), request_timeout=request_timeout)
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self, guild_id: int, request_timeout: float) -> tuple[list[ChannelCategory], list[TextChannel]]:
_, channels_info = self._send_request(*Api.Guild.list_guilds(guild_id), request_timeout=request_timeout)
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_roles(self, guild_id: int, request_timeout: float) -> list[Role]:
_, roles_info = self._send_request(*Api.Guild.list_roles(guild_id), request_timeout=request_timeout)
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(
self, channel: TextChannel, request_timeout: float, limit: int | None = None, before_id: int | None = None,
after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
_, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
return [Message.from_dict(m) for m in messages_info or []]

View file

@ -2,7 +2,7 @@ from argparse import ArgumentParser
import logging import logging
from pathlib import Path from pathlib import Path
from breadtube_bot.manager import DiscordManager from breadtube_bot.bot import Bot
def main(): def main():
@ -17,7 +17,7 @@ def main():
bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip() bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip()
yt_api_key = Path('data/google_api_key.txt').read_text(encoding='utf-8').strip() yt_api_key = Path('data/google_api_key.txt').read_text(encoding='utf-8').strip()
manager = DiscordManager(bot_token=bot_token, guild_id=guild_id, yt_api_key=yt_api_key, manager = Bot(bot_token=bot_token, guild_id=guild_id, yt_api_key=yt_api_key,
log_level=logging.DEBUG if debug_mode else logging.INFO) log_level=logging.DEBUG if debug_mode else logging.INFO)
try: try:
manager.run() manager.run()