Code refactored

This commit is contained in:
BreadTube 2025-09-29 18:49:49 +09:00 committed by Corentin
commit d5b3436aec
3 changed files with 247 additions and 213 deletions

View file

@ -0,0 +1,187 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import json
import random
import time
import urllib.error
import urllib.request
from .api import Api, ApiAction, ApiVersion
from .objects import Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, Role, TextChannel, User
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import logging
class ApiEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o) # type: ignore
if isinstance(o, Enum):
return o.value
return super().default(o)
class DiscordManager:
MIN_API_VERSION = 9
@dataclass
class RateLimit:
remaining: int
next_reset: float
def __init__(self, bot_token: str, bot_version: str, logger: logging.Logger):
self._bot_token = bot_token
self._logger = logger
self._version = bot_version
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
self._logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self._logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self._logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(
self, api_action: ApiAction, endpoint: str, request_timeout: float, data: bytes | None = None,
expected_code: int = 200, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
api_version: ApiVersion = ApiVersion.V10) -> tuple[HTTPHeaders, dict | list | None]:
if api_version.value < self.MIN_API_VERSION:
self._logger.warning(
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
api_version, self.MIN_API_VERSION)
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
self._logger.debug('Discord API Request: %s %s', api_action.value, url)
boundary: str = ''
if upload_files:
boundary = f'{random.randbytes(16).hex()}'
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
'Content-Type: application/json\r\n\r\n'.encode() + data
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
request.add_header('User-Agent', f'BreadTube (v{self._version})')
request.add_header('Accept', 'application/json')
if upload_files:
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
else:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read()
headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def download_attachment(self, attachment: Attachment, request_timeout: float,
expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self._version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def create_text_channel(
self, guild_id: int, params: Api.Guild.CreateTextChannelParams, request_timeout: float) -> TextChannel:
_, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=guild_id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, request_timeout: float,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
_, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message, request_timeout: float):
_, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), request_timeout=request_timeout,
expected_code=204)
def get_current_user(self, request_timeout: float) -> User:
_, user_info = self._send_request(*Api.User.get_current(), request_timeout=request_timeout)
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self, guild_id: int, request_timeout: float) -> tuple[list[ChannelCategory], list[TextChannel]]:
_, channels_info = self._send_request(*Api.Guild.list_guilds(guild_id), request_timeout=request_timeout)
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_roles(self, guild_id: int, request_timeout: float) -> list[Role]:
_, roles_info = self._send_request(*Api.Guild.list_roles(guild_id), request_timeout=request_timeout)
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(
self, channel: TextChannel, request_timeout: float, limit: int | None = None, before_id: int | None = None,
after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
_, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id), request_timeout=request_timeout,
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
return [Message.from_dict(m) for m in messages_info or []]