Code refactored
This commit is contained in:
parent
b80e4f7745
commit
d5b3436aec
3 changed files with 247 additions and 213 deletions
|
|
@ -1,51 +1,30 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
import operator
|
||||
from pathlib import Path
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
import tomllib
|
||||
from typing import Any
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
import traceback
|
||||
|
||||
|
||||
from .api import Api, ApiAction, ApiVersion
|
||||
from .config import Config
|
||||
from .discord_manager import DiscordManager
|
||||
from .logger import create_logger
|
||||
from .objects import (
|
||||
Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType,
|
||||
Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
|
||||
from .objects import (ChannelCategory, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
|
||||
OverwriteType, Permissions, Role, TextChannel)
|
||||
from .youtube_manager import YoutubeManager
|
||||
from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions
|
||||
|
||||
|
||||
class ApiEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if is_dataclass(o):
|
||||
return asdict(o) # type: ignore
|
||||
if isinstance(o, Enum):
|
||||
return o.value
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class DiscordManager:
|
||||
class Bot:
|
||||
DEFAULT_MESSAGE_LIST_LIMIT = 50
|
||||
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
|
||||
'You can upload a new one to update the configuration.')
|
||||
MAX_DOWNLOAD_SIZE: int = 50_000
|
||||
MIN_API_VERSION = 9
|
||||
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
remaining: int
|
||||
next_reset: float
|
||||
|
||||
class Task(Enum):
|
||||
DELETE_MESSAGES = 1
|
||||
|
|
@ -60,20 +39,20 @@ class DiscordManager:
|
|||
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
||||
|
||||
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
|
||||
log_level: int = logging.INFO) -> None:
|
||||
log_level: int = logging.INFO):
|
||||
self.config = config or Config()
|
||||
self.guild_id = guild_id
|
||||
self._bot_token = bot_token
|
||||
self.logger = create_logger('breadtube', log_level, stdout=True)
|
||||
|
||||
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
||||
self.version = self._get_code_version()
|
||||
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
||||
self.discord_manager = DiscordManager(bot_token=bot_token, bot_version=self.version, logger=self.logger)
|
||||
self.tasks: list[tuple[Bot.Task, float, Any]] = []
|
||||
|
||||
self.logger.info('Retrieving bot user')
|
||||
self.bot_user = self.get_current_user()
|
||||
self.bot_user = self.discord_manager.get_current_user(request_timeout=self.config.request_timeout)
|
||||
self.logger.info('Retrieving guild roles before init')
|
||||
self.guild_roles: list[Role] = self.list_roles()
|
||||
self.guild_roles: list[Role] = self.discord_manager.list_roles(
|
||||
self.guild_id, request_timeout=self.config.request_timeout)
|
||||
bot_role: Role | None = None
|
||||
everyone_role: Role | None = None
|
||||
for role in self.guild_roles:
|
||||
|
|
@ -88,7 +67,8 @@ class DiscordManager:
|
|||
self.bot_role: Role = bot_role
|
||||
self.everyone_role: Role = everyone_role
|
||||
|
||||
categories, text_channel = self.list_channels()
|
||||
categories, text_channel = self.discord_manager.list_channels(
|
||||
self.guild_id, request_timeout=self.config.request_timeout)
|
||||
self.guild_text_channels: list[TextChannel] = text_channel
|
||||
self.guild_categories: list[ChannelCategory] = categories
|
||||
self.init_message: Message | None = None
|
||||
|
|
@ -111,84 +91,6 @@ class DiscordManager:
|
|||
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
|
||||
self.logger.info('Bot initialized')
|
||||
|
||||
def _update_rate_limit(self, headers: HTTPHeaders):
|
||||
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
||||
if header_key not in headers:
|
||||
self.logger.info('Warning: no "%s" found in headers', header_key)
|
||||
return
|
||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
||||
if self.rate_limit.remaining <= 0:
|
||||
sleep_time = self.rate_limit.next_reset - time.time()
|
||||
if sleep_time > 0:
|
||||
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
||||
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
|
||||
if api_version.value < self.MIN_API_VERSION:
|
||||
self.logger.warning(
|
||||
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
||||
api_version, self.MIN_API_VERSION)
|
||||
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
||||
self.logger.debug('Discord API Request: %s %s', api_action.value, url)
|
||||
|
||||
boundary: str = ''
|
||||
if upload_files:
|
||||
boundary = f'{random.randbytes(16).hex()}'
|
||||
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
||||
'Content-Type: application/json\r\n\r\n'.encode() + data
|
||||
+ f'\r\n--{boundary}'.encode()) if data else b''
|
||||
for file_index, (name, mime, content) in enumerate(upload_files):
|
||||
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
||||
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
||||
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||
data += f'\r\n--{boundary}--'.encode()
|
||||
request = urllib.request.Request(url, data=data)
|
||||
request.method = api_action.value
|
||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
||||
request.add_header('Accept', 'application/json')
|
||||
if upload_files:
|
||||
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
||||
else:
|
||||
request.add_header('Content-Type', 'application/json')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||
body = response.read()
|
||||
headers = dict(response.getheaders())
|
||||
self._update_rate_limit(headers)
|
||||
return headers, json.loads(body.decode()) if body else None
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||
except TimeoutError as error:
|
||||
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
||||
|
||||
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
||||
request = urllib.request.Request(attachment.url)
|
||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
||||
return dict(response.getheaders()), response.read()
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error downloading attachment ({attachment}):'
|
||||
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
||||
|
||||
def init_bot_channel(self) -> TextChannel | None:
|
||||
for channel in self.guild_text_channels:
|
||||
if channel.name == self.config.bot_channel:
|
||||
|
|
@ -203,14 +105,15 @@ class DiscordManager:
|
|||
return channel
|
||||
|
||||
self.logger.info('Creating breadtube bot channel')
|
||||
return self.create_text_channel({
|
||||
'name': self.config.bot_channel,
|
||||
'permission_overwrites': [
|
||||
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
deny=Permissions.VIEW_CHANNEL),
|
||||
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
||||
deny=Permissions.NONE)]
|
||||
})
|
||||
return self.discord_manager.create_text_channel(
|
||||
self.guild_id, {
|
||||
'name': self.config.bot_channel,
|
||||
'permission_overwrites': [
|
||||
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
deny=Permissions.VIEW_CHANNEL),
|
||||
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
||||
deny=Permissions.NONE)]},
|
||||
request_timeout=self.config.request_timeout)
|
||||
|
||||
def _get_bot_channel_messages(self) -> list[Message]:
|
||||
messages_id_delete_task: set[int] = set()
|
||||
|
|
@ -221,7 +124,8 @@ class DiscordManager:
|
|||
last_message_id: int | None = None
|
||||
messages: list[Message] = []
|
||||
while True:
|
||||
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
||||
message_batch = self.discord_manager.list_text_channel_messages(
|
||||
self.bot_channel, request_timeout=self.config.request_timeout, after_id=last_message_id)
|
||||
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
|
||||
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
||||
break
|
||||
|
|
@ -255,7 +159,8 @@ class DiscordManager:
|
|||
has_error = False
|
||||
for attachment in message.attachments:
|
||||
try:
|
||||
_, content = self._download_attachment(attachment)
|
||||
_, content = self.discord_manager.download_attachment(
|
||||
attachment, request_timeout=self.config.request_timeout)
|
||||
if new_config is None and content.startswith(b'config'):
|
||||
try:
|
||||
self.config = Config.from_str(content.decode())
|
||||
|
|
@ -269,7 +174,7 @@ class DiscordManager:
|
|||
SubscriptionHelper.update_subscriptions(
|
||||
new=subscriptions, previous=self._yt_subscriptions)
|
||||
self._yt_subscriptions = subscriptions
|
||||
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
|
||||
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
|
||||
except RuntimeError as error:
|
||||
self.logger.error('Invalid init subscriptions file: %s', error)
|
||||
has_error = True
|
||||
|
|
@ -288,7 +193,8 @@ class DiscordManager:
|
|||
immediate_delete[message.id] = message
|
||||
continue
|
||||
try:
|
||||
_, content = self._download_attachment(attachment)
|
||||
_, content = self.discord_manager.download_attachment(
|
||||
attachment, request_timeout=self.config.request_timeout)
|
||||
if content.startswith(b'config') and new_config is None:
|
||||
try:
|
||||
config = Config.from_str(content.decode())
|
||||
|
|
@ -299,14 +205,14 @@ class DiscordManager:
|
|||
continue
|
||||
except RuntimeError as error:
|
||||
self.logger.info('Invalid config file: %s', error)
|
||||
bot_message = self.create_message(self.bot_channel, {
|
||||
bot_message = self.discord_manager.create_message(self.bot_channel, {
|
||||
'content': str(error),
|
||||
'message_reference': MessageReference(
|
||||
type=MessageReferenceType.DEFAULT,
|
||||
message_id=message.id,
|
||||
channel_id=self.bot_channel.id,
|
||||
guild_id=None,
|
||||
fail_if_not_exists=None)})
|
||||
fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
|
||||
delayed_delete[bot_message.id] = bot_message
|
||||
delayed_delete[message.id] = message
|
||||
continue
|
||||
|
|
@ -320,14 +226,14 @@ class DiscordManager:
|
|||
continue
|
||||
except RuntimeError as error:
|
||||
self.logger.info('Invalid subscriptions file: %s', error)
|
||||
bot_message = self.create_message(self.bot_channel, {
|
||||
bot_message = self.discord_manager.create_message(self.bot_channel, {
|
||||
'content': str(error),
|
||||
'message_reference': MessageReference(
|
||||
type=MessageReferenceType.DEFAULT,
|
||||
message_id=message.id,
|
||||
channel_id=self.bot_channel.id,
|
||||
guild_id=None,
|
||||
fail_if_not_exists=None)})
|
||||
fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
|
||||
delayed_delete[bot_message.id] = bot_message
|
||||
delayed_delete[message.id] = message
|
||||
continue
|
||||
|
|
@ -342,7 +248,7 @@ class DiscordManager:
|
|||
self.logger.info('Loading subscriptions: %s', new_subscriptions)
|
||||
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
|
||||
self._yt_subscriptions = new_subscriptions
|
||||
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
|
||||
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
|
||||
|
||||
# New init message is needed, previous need to be deleted
|
||||
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
|
||||
|
|
@ -352,8 +258,8 @@ class DiscordManager:
|
|||
# Init message absent or deleted
|
||||
if self.init_message is None:
|
||||
assert self.config is not None
|
||||
self.init_message = self.create_message(
|
||||
self.bot_channel, {'content': self.INIT_MESSAGE},
|
||||
self.init_message = self.discord_manager.create_message(
|
||||
self.bot_channel, {'content': self.INIT_MESSAGE}, request_timeout=self.config.request_timeout,
|
||||
upload_files=[
|
||||
('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()),
|
||||
('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions))
|
||||
|
|
@ -361,18 +267,19 @@ class DiscordManager:
|
|||
|
||||
for message in immediate_delete.values():
|
||||
try:
|
||||
self.delete_message(message)
|
||||
self.discord_manager.delete_message(message, request_timeout=self.config.request_timeout)
|
||||
except RuntimeError as error:
|
||||
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
|
||||
|
||||
if delayed_delete:
|
||||
self.tasks.append((
|
||||
DiscordManager.Task.DELETE_MESSAGES,
|
||||
Bot.Task.DELETE_MESSAGES,
|
||||
time.time() + self.config.bot_message_duration,
|
||||
list(delayed_delete.values())))
|
||||
|
||||
def _init_subs(self):
|
||||
categories, text_channel = self.list_channels()
|
||||
categories, text_channel = self.discord_manager.list_channels(
|
||||
self.guild_id, request_timeout=self.config.request_timeout)
|
||||
self.guild_text_channels = text_channel
|
||||
self.guild_categories = categories
|
||||
|
||||
|
|
@ -404,16 +311,17 @@ class DiscordManager:
|
|||
break
|
||||
if selected_category is None:
|
||||
selected_category = category_ranges[-1][2]
|
||||
sub_channel = self.create_text_channel({
|
||||
'name': discord_name,
|
||||
'parent_id': selected_category.id,
|
||||
'permission_overwrites': [
|
||||
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
deny=Permissions.SEND_MESSAGES),
|
||||
Overwrite(self.bot_role.id, OverwriteType.ROLE,
|
||||
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
|
||||
deny=Permissions.NONE)]
|
||||
})
|
||||
sub_channel = self.discord_manager.create_text_channel(
|
||||
self.guild_id, {
|
||||
'name': discord_name,
|
||||
'parent_id': selected_category.id,
|
||||
'permission_overwrites': [
|
||||
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
deny=Permissions.SEND_MESSAGES),
|
||||
Overwrite(self.bot_role.id, OverwriteType.ROLE,
|
||||
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
|
||||
deny=Permissions.NONE)]},
|
||||
request_timeout=self.config.request_timeout)
|
||||
if sub_info.channel_info is None:
|
||||
_, channel_info = self.yt_manager.request_channel_info(
|
||||
sub_info.channel_id, request_timeout=self.config.request_timeout)
|
||||
|
|
@ -422,7 +330,8 @@ class DiscordManager:
|
|||
continue
|
||||
sub_info.channel_info = channel_info.items[0].snippet
|
||||
channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}'
|
||||
_ = self.create_message(sub_channel, {'content': channel_url})
|
||||
_ = self.discord_manager.create_message(
|
||||
sub_channel, {'content': channel_url}, request_timeout=self.config.request_timeout)
|
||||
sub_info.last_update = time.time()
|
||||
|
||||
def run(self):
|
||||
|
|
@ -436,7 +345,7 @@ class DiscordManager:
|
|||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
match task_type:
|
||||
case DiscordManager.Task.DELETE_MESSAGES:
|
||||
case Bot.Task.DELETE_MESSAGES:
|
||||
if not isinstance(task_params, list):
|
||||
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
|
||||
elif not task_params:
|
||||
|
|
@ -446,11 +355,12 @@ class DiscordManager:
|
|||
else:
|
||||
for message in task_params:
|
||||
try:
|
||||
self.delete_message(message)
|
||||
self.discord_manager.delete_message(
|
||||
message, request_timeout=self.config.request_timeout)
|
||||
except Exception as error:
|
||||
self.logger.error('Error deleting message %s: %s -> %s',
|
||||
message, error, traceback.format_exc().replace('\n', ' | '))
|
||||
case DiscordManager.Task.SCAN_BOT_CHANNEL:
|
||||
case Bot.Task.SCAN_BOT_CHANNEL:
|
||||
try:
|
||||
self._scan_bot_channel()
|
||||
except Exception as error:
|
||||
|
|
@ -458,73 +368,10 @@ class DiscordManager:
|
|||
error, traceback.format_exc().replace('\n', ' | '))
|
||||
self.tasks.append((
|
||||
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
||||
case DiscordManager.Task.INIT_SUBS:
|
||||
case Bot.Task.INIT_SUBS:
|
||||
try:
|
||||
self._init_subs()
|
||||
except Exception as error:
|
||||
self.logger.error('Error initializing subscriptions : %s -> %s',
|
||||
error, traceback.format_exc().replace('\n', ' | '))
|
||||
time.sleep(1)
|
||||
|
||||
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
||||
_, channel_info = self._send_request(
|
||||
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||
expected_code=201)
|
||||
if not isinstance(channel_info, dict):
|
||||
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
||||
return TextChannel.from_dict(channel_info)
|
||||
|
||||
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
||||
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
||||
_, message_info = self._send_request(
|
||||
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||
upload_files=upload_files)
|
||||
if not isinstance(message_info, dict):
|
||||
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
||||
return Message.from_dict(message_info)
|
||||
|
||||
def delete_message(self, message: Message):
|
||||
_, _ = self._send_request(
|
||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
||||
|
||||
def get_current_user(self) -> User:
|
||||
_, user_info = self._send_request(*Api.User.get_current())
|
||||
if not isinstance(user_info, dict):
|
||||
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
||||
return User.from_dict(user_info)
|
||||
|
||||
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
||||
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
||||
categories: list[ChannelCategory] = []
|
||||
text_channels: list[TextChannel] = []
|
||||
if channels_info is not None:
|
||||
for channel_info in channels_info:
|
||||
channel_type = ChannelType(channel_info['type'])
|
||||
match channel_type:
|
||||
case ChannelType.GUILD_CATEGORY:
|
||||
categories.append(ChannelCategory.from_dict(channel_info))
|
||||
case ChannelType.GUILD_TEXT:
|
||||
text_channels.append(TextChannel.from_dict(channel_info))
|
||||
return categories, text_channels
|
||||
|
||||
def list_roles(self) -> list[Role]:
|
||||
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
||||
if not isinstance(roles_info, list):
|
||||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||
return [Role.from_dict(r) for r in roles_info]
|
||||
|
||||
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
|
||||
after_id: int | None = None) -> list[Message]:
|
||||
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
||||
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
||||
params: Api.Message.ListMessageParams = {}
|
||||
if limit is not None:
|
||||
params['limit'] = limit
|
||||
if before_id is not None:
|
||||
params['before'] = before_id
|
||||
if after_id is not None:
|
||||
params['after'] = after_id
|
||||
_, messages_info = self._send_request(
|
||||
*Api.Message.list_by_channel(channel.id),
|
||||
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
||||
return [Message.from_dict(m) for m in messages_info or []]
|
||||
187
breadtube_bot/discord_manager.py
Normal file
187
breadtube_bot/discord_manager.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
from .api import Api, ApiAction, ApiVersion
|
||||
from .objects import Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, Role, TextChannel, User
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import logging
|
||||
|
||||
|
||||
class ApiEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if is_dataclass(o):
|
||||
return asdict(o) # type: ignore
|
||||
if isinstance(o, Enum):
|
||||
return o.value
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class DiscordManager:
|
||||
MIN_API_VERSION = 9
|
||||
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
remaining: int
|
||||
next_reset: float
|
||||
|
||||
def __init__(self, bot_token: str, bot_version: str, logger: logging.Logger):
|
||||
self._bot_token = bot_token
|
||||
self._logger = logger
|
||||
self._version = bot_version
|
||||
|
||||
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
||||
|
||||
def _update_rate_limit(self, headers: HTTPHeaders):
|
||||
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
||||
if header_key not in headers:
|
||||
self._logger.info('Warning: no "%s" found in headers', header_key)
|
||||
return
|
||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||
self._logger.debug('Updated rate limit: %s', self.rate_limit)
|
||||
if self.rate_limit.remaining <= 0:
|
||||
sleep_time = self.rate_limit.next_reset - time.time()
|
||||
if sleep_time > 0:
|
||||
self._logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
||||
time.sleep(sleep_time)
|
||||
|
||||
def _send_request(
|
||||
self, api_action: ApiAction, endpoint: str, request_timeout: float, data: bytes | None = None,
|
||||
expected_code: int = 200, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||
api_version: ApiVersion = ApiVersion.V10) -> tuple[HTTPHeaders, dict | list | None]:
|
||||
if api_version.value < self.MIN_API_VERSION:
|
||||
self._logger.warning(
|
||||
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
||||
api_version, self.MIN_API_VERSION)
|
||||
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
||||
self._logger.debug('Discord API Request: %s %s', api_action.value, url)
|
||||
|
||||
boundary: str = ''
|
||||
if upload_files:
|
||||
boundary = f'{random.randbytes(16).hex()}'
|
||||
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
||||
'Content-Type: application/json\r\n\r\n'.encode() + data
|
||||
+ f'\r\n--{boundary}'.encode()) if data else b''
|
||||
for file_index, (name, mime, content) in enumerate(upload_files):
|
||||
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
||||
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
||||
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||
data += f'\r\n--{boundary}--'.encode()
|
||||
request = urllib.request.Request(url, data=data)
|
||||
request.method = api_action.value
|
||||
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
||||
request.add_header('Accept', 'application/json')
|
||||
if upload_files:
|
||||
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
||||
else:
|
||||
request.add_header('Content-Type', 'application/json')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||
body = response.read()
|
||||
headers = dict(response.getheaders())
|
||||
self._update_rate_limit(headers)
|
||||
return headers, json.loads(body.decode()) if body else None
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||
except TimeoutError as error:
|
||||
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
||||
|
||||
def download_attachment(self, attachment: Attachment, request_timeout: float,
|
||||
expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
||||
request = urllib.request.Request(attachment.url)
|
||||
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
||||
return dict(response.getheaders()), response.read()
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error downloading attachment ({attachment}):'
|
||||
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
||||
|
||||
def create_text_channel(
|
||||
self, guild_id: int, params: Api.Guild.CreateTextChannelParams, request_timeout: float) -> TextChannel:
|
||||
_, channel_info = self._send_request(
|
||||
*Api.Guild.create_channel(guild_id=guild_id), request_timeout=request_timeout,
|
||||
data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201)
|
||||
if not isinstance(channel_info, dict):
|
||||
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
||||
return TextChannel.from_dict(channel_info)
|
||||
|
||||
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, request_timeout: float,
|
||||
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
||||
_, message_info = self._send_request(
|
||||
*Api.Message.create(channel_id=channel.id), request_timeout=request_timeout,
|
||||
data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files)
|
||||
if not isinstance(message_info, dict):
|
||||
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
||||
return Message.from_dict(message_info)
|
||||
|
||||
def delete_message(self, message: Message, request_timeout: float):
|
||||
_, _ = self._send_request(
|
||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), request_timeout=request_timeout,
|
||||
expected_code=204)
|
||||
|
||||
def get_current_user(self, request_timeout: float) -> User:
|
||||
_, user_info = self._send_request(*Api.User.get_current(), request_timeout=request_timeout)
|
||||
if not isinstance(user_info, dict):
|
||||
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
||||
return User.from_dict(user_info)
|
||||
|
||||
def list_channels(self, guild_id: int, request_timeout: float) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
||||
_, channels_info = self._send_request(*Api.Guild.list_guilds(guild_id), request_timeout=request_timeout)
|
||||
categories: list[ChannelCategory] = []
|
||||
text_channels: list[TextChannel] = []
|
||||
if channels_info is not None:
|
||||
for channel_info in channels_info:
|
||||
channel_type = ChannelType(channel_info['type'])
|
||||
match channel_type:
|
||||
case ChannelType.GUILD_CATEGORY:
|
||||
categories.append(ChannelCategory.from_dict(channel_info))
|
||||
case ChannelType.GUILD_TEXT:
|
||||
text_channels.append(TextChannel.from_dict(channel_info))
|
||||
return categories, text_channels
|
||||
|
||||
def list_roles(self, guild_id: int, request_timeout: float) -> list[Role]:
|
||||
_, roles_info = self._send_request(*Api.Guild.list_roles(guild_id), request_timeout=request_timeout)
|
||||
if not isinstance(roles_info, list):
|
||||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||
return [Role.from_dict(r) for r in roles_info]
|
||||
|
||||
def list_text_channel_messages(
|
||||
self, channel: TextChannel, request_timeout: float, limit: int | None = None, before_id: int | None = None,
|
||||
after_id: int | None = None) -> list[Message]:
|
||||
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
||||
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
||||
params: Api.Message.ListMessageParams = {}
|
||||
if limit is not None:
|
||||
params['limit'] = limit
|
||||
if before_id is not None:
|
||||
params['before'] = before_id
|
||||
if after_id is not None:
|
||||
params['after'] = after_id
|
||||
_, messages_info = self._send_request(
|
||||
*Api.Message.list_by_channel(channel.id), request_timeout=request_timeout,
|
||||
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
||||
return [Message.from_dict(m) for m in messages_info or []]
|
||||
Loading…
Add table
Add a link
Reference in a new issue