breadtube-bot/breadtube_bot/manager.py
2025-10-04 16:51:08 +09:00

160 lines
6.9 KiB
Python

from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
import json
import tomllib
import urllib.error
import urllib.request
from .api import Api, ApiAction, ApiVersion
from .objects import (
ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User)
# Alias for HTTP headers represented as a mapping of header name to header value (both strings).
HTTPHeaders = dict[str, str]
@dataclass
class _RateLimit:
    """Mutable snapshot of the rate-limit values taken from the latest API response headers."""

    # Parsed from the "x-ratelimit-remaining" response header.
    remaining: int
    # Parsed from the "x-ratelimit-reset" response header.
    next_reset: float
class DiscordManager:
@staticmethod
def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
if not pyproject_path.exists():
raise RuntimeError('Cannot current bot version')
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
def __init__(self, bot_token: str, guild_id: int) -> None:
self.guild_id = guild_id
self._bot_token = bot_token
self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version()
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
print(f'Warning: no "{header_key}" found in headers')
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
expected_code: int = 200) -> tuple[
HTTPHeaders, dict]:
timeout = 3
min_api_version = 9
if api_action == ApiAction.POST:
raise NotImplementedError
if api_version.value < min_api_version:
print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})')
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
request = urllib.request.Request(url)
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Accept', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
return dict(response.getheaders()), json.loads(response.read().decode())
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
@staticmethod
def _parse_overwrite(info: dict) -> Overwrite:
return Overwrite(
id=int(info['id']),
type=OverwriteType(info['type']),
allow=Permissions(int(info['allow'])),
deny=Permissions(int(info['deny']))
)
def _parse_channel_category(self, info: dict) -> ChannelCategory:
parent_id: str | None = info.get('parent_id')
return ChannelCategory(
id=int(info['id']),
guild_id=int(info['guild_id']),
position=int(info['position']),
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
name=info.get('name'),
parent_id=int(parent_id) if parent_id is not None else None,
flags=ChannelFlags(info['flags']),
)
def _parse_text_channel(self, info: dict) -> TextChannel:
parent_id: str | None = info.get('parent_id')
last_message_id: str | None = info.get('last_message_id')
last_pin_timestamp: str | None = info.get('last_pin_timestamp')
return TextChannel(
id=int(info['id']),
guild_id=int(info['guild_id']),
position=int(info['position']),
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
name=info.get('name'),
topic=info.get('topic'),
nsfw=info['nsfw'],
last_message_id=int(last_message_id) if last_message_id is not None else None,
rate_limit_per_user=int(info['rate_limit_per_user']),
parent_id=int(parent_id) if parent_id is not None else None,
last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None),
flags=ChannelFlags(info['flags']),
)
@staticmethod
def _parse_user(info: dict) -> User:
return User(
id=int(info['id']),
username=info['username'],
discriminator=info['discriminator'],
global_name=info.get('global_name')
)
def _parse_message(self, info: dict) -> Message:
edited_timestamp: str | None = info.get('edited_timestamp')
return Message(
id=int(info['id']),
channel_id=int(info['channel_id']),
author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User(
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
content=info['content'],
timestamp=datetime.fromisoformat(info['timestamp']),
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None
)
def delete_message(self, message: Message):
try:
headers, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
self._update_rate_limit(headers)
print(f'Message {message.id} deleted')
except RuntimeError as error:
print(error)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id))
self._update_rate_limit(headers)
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
for channel in channels:
channel_type = ChannelType(channel['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(self._parse_channel_category(channel))
case ChannelType.GUILD_TEXT:
text_channels.append(self._parse_text_channel(channel))
return categories, text_channels
def list_text_channel_messages(self, channel: TextChannel) -> list:
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
self._update_rate_limit(headers)
return [self._parse_message(m) for m in messages]