207 lines
10 KiB
Python
207 lines
10 KiB
Python
from __future__ import annotations

from dataclasses import asdict, dataclass, is_dataclass
from datetime import datetime
from enum import Enum
import json
import random
import time
from typing import TYPE_CHECKING
import urllib.error
import urllib.parse
import urllib.request

from .api import Api, ApiAction, ApiVersion
from .objects import Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, Role, TextChannel, User

if TYPE_CHECKING:
    import logging
|
|
|
|
|
|
class ApiEncoder(json.JSONEncoder):
    """JSON encoder that also understands dataclasses, enums and datetimes."""

    def default(self, o):
        """Convert an object the stock encoder cannot serialise.

        Dataclass instances become dicts (via ``asdict``), ``Enum`` members
        become their ``value`` and ``datetime`` objects become ISO 8601
        strings. Anything else is delegated to ``json.JSONEncoder.default``,
        which raises ``TypeError``.
        """
        # is_dataclass() is also true for dataclass *types*, which asdict()
        # rejects; only instances are converted here.
        if is_dataclass(o) and not isinstance(o, type):
            return asdict(o)
        if isinstance(o, Enum):
            return o.value
        if isinstance(o, datetime):
            return o.isoformat()
        return super().default(o)
|
|
|
|
|
|
class DiscordManager:
|
|
MIN_API_VERSION = 9
|
|
TOO_MANY_REQUEST_STATUS = 429
|
|
|
|
@dataclass
|
|
class RateLimit:
|
|
remaining: int
|
|
next_reset: float
|
|
|
|
def __init__(self, bot_token: str, bot_version: str, logger: logging.Logger):
|
|
self._bot_token = bot_token
|
|
self._logger = logger
|
|
self._version = bot_version
|
|
|
|
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
|
|
|
def _update_rate_limit(self, headers: HTTPHeaders):
|
|
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
|
if header_key not in headers:
|
|
self._logger.info('Warning: no "%s" found in headers', header_key)
|
|
return
|
|
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
|
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
|
self._logger.debug('Updated rate limit: %s', self.rate_limit)
|
|
if self.rate_limit.remaining <= 0:
|
|
sleep_time = self.rate_limit.next_reset - time.time()
|
|
if sleep_time > 0:
|
|
self._logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
|
time.sleep(sleep_time)
|
|
|
|
def _send_request(
|
|
self, api_action: ApiAction, endpoint: str, request_timeout: float, data: bytes | None = None,
|
|
expected_code: int = 200, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
|
api_version: ApiVersion = ApiVersion.V10) -> tuple[HTTPHeaders, dict | list | None]:
|
|
if api_version.value < self.MIN_API_VERSION:
|
|
self._logger.warning(
|
|
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
|
api_version, self.MIN_API_VERSION)
|
|
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
|
self._logger.debug('Discord API Request: %s %s', api_action.value, url)
|
|
|
|
boundary: str = ''
|
|
if upload_files:
|
|
boundary = f'{random.randbytes(16).hex()}'
|
|
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
|
'Content-Type: application/json\r\n\r\n'.encode() + data
|
|
+ f'\r\n--{boundary}'.encode()) if data else b''
|
|
for file_index, (name, mime, content) in enumerate(upload_files):
|
|
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
|
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
|
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
|
data += f'\r\n--{boundary}--'.encode()
|
|
request = urllib.request.Request(url, data=data)
|
|
request.method = api_action.value
|
|
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
|
request.add_header('Accept', 'application/json')
|
|
if upload_files:
|
|
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
|
else:
|
|
request.add_header('Content-Type', 'application/json')
|
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
|
|
def _request() -> tuple[int, dict, bytes | None]:
|
|
nonlocal request, request_timeout
|
|
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
|
headers = dict(response.getheaders())
|
|
return response.status, headers, response.read()
|
|
|
|
try:
|
|
body = b''
|
|
try:
|
|
status, headers, body = _request()
|
|
except urllib.error.HTTPError as error:
|
|
if error.status != self.TOO_MANY_REQUEST_STATUS:
|
|
raise error
|
|
status = error.status
|
|
headers = dict(error.headers)
|
|
|
|
self._update_rate_limit(headers)
|
|
if status == self.TOO_MANY_REQUEST_STATUS:
|
|
self._logger.warning('Warning: too many request -> retrying')
|
|
status, headers, body = _request()
|
|
self._update_rate_limit(headers)
|
|
if status != expected_code:
|
|
raise RuntimeError(f'Unexpected code {status} (expected: {expected_code}) -> {body}')
|
|
return headers, json.loads(body.decode()) if body else None
|
|
except urllib.error.HTTPError as error:
|
|
raise RuntimeError(
|
|
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
except urllib.error.URLError as error:
|
|
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
|
except TimeoutError as error:
|
|
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
|
|
|
def download_attachment(self, attachment: Attachment, request_timeout: float,
|
|
expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
|
request = urllib.request.Request(attachment.url)
|
|
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
try:
|
|
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
|
if response.status != expected_code:
|
|
raise RuntimeError(
|
|
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
|
return dict(response.getheaders()), response.read()
|
|
except urllib.error.HTTPError as error:
|
|
raise RuntimeError(
|
|
f'HTTP error downloading attachment ({attachment}):'
|
|
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
except urllib.error.URLError as error:
|
|
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
|
|
|
def create_text_channel(
|
|
self, guild_id: int, params: Api.Guild.CreateTextChannelParams, request_timeout: float) -> TextChannel:
|
|
_, channel_info = self._send_request(
|
|
*Api.Guild.create_channel(guild_id=guild_id), request_timeout=request_timeout,
|
|
data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201)
|
|
if not isinstance(channel_info, dict):
|
|
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
|
return TextChannel.from_dict(channel_info)
|
|
|
|
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, request_timeout: float,
|
|
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
|
_, message_info = self._send_request(
|
|
*Api.Message.create(channel_id=channel.id), request_timeout=request_timeout,
|
|
data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files)
|
|
if not isinstance(message_info, dict):
|
|
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
|
return Message.from_dict(message_info)
|
|
|
|
def delete_message(self, message: Message, request_timeout: float):
|
|
_, _ = self._send_request(
|
|
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), request_timeout=request_timeout,
|
|
expected_code=204)
|
|
|
|
def get_current_user(self, request_timeout: float) -> User:
|
|
_, user_info = self._send_request(*Api.User.get_current(), request_timeout=request_timeout)
|
|
if not isinstance(user_info, dict):
|
|
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
|
return User.from_dict(user_info)
|
|
|
|
def list_channels(self, guild_id: int, request_timeout: float) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
|
_, channels_info = self._send_request(*Api.Guild.list_guilds(guild_id), request_timeout=request_timeout)
|
|
categories: list[ChannelCategory] = []
|
|
text_channels: list[TextChannel] = []
|
|
if channels_info is not None:
|
|
for channel_info in channels_info:
|
|
channel_type = ChannelType(channel_info['type'])
|
|
match channel_type:
|
|
case ChannelType.GUILD_CATEGORY:
|
|
categories.append(ChannelCategory.from_dict(channel_info))
|
|
case ChannelType.GUILD_TEXT:
|
|
text_channels.append(TextChannel.from_dict(channel_info))
|
|
return categories, text_channels
|
|
|
|
def list_roles(self, guild_id: int, request_timeout: float) -> list[Role]:
|
|
_, roles_info = self._send_request(*Api.Guild.list_roles(guild_id), request_timeout=request_timeout)
|
|
if not isinstance(roles_info, list):
|
|
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
|
return [Role.from_dict(r) for r in roles_info]
|
|
|
|
def list_text_channel_messages(
|
|
self, channel: TextChannel, request_timeout: float, limit: int | None = None, before_id: int | None = None,
|
|
after_id: int | None = None) -> list[Message]:
|
|
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
|
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
|
params: Api.Message.ListMessageParams = {}
|
|
if limit is not None:
|
|
params['limit'] = limit
|
|
if before_id is not None:
|
|
params['before'] = before_id
|
|
if after_id is not None:
|
|
params['after'] = after_id
|
|
_, messages_info = self._send_request(
|
|
*Api.Message.list_by_channel(channel.id), request_timeout=request_timeout,
|
|
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
|
return [Message.from_dict(m) for m in messages_info or []]
|