160 lines
6.9 KiB
Python
160 lines
6.9 KiB
Python
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
import json
|
|
import tomllib
|
|
import urllib.error
|
|
import urllib.request
|
|
|
|
from .api import Api, ApiAction, ApiVersion
|
|
from .objects import (
|
|
ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User)
|
|
|
|
|
|
# Response headers as a mapping of header name -> header value.
HTTPHeaders = dict[str, str]
|
|
|
|
|
|
@dataclass
|
|
class _RateLimit:
|
|
remaining: int
|
|
next_reset: float
|
|
|
|
|
|
class DiscordManager:
|
|
@staticmethod
|
|
def _get_code_version() -> str:
|
|
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
|
|
if not pyproject_path.exists():
|
|
raise RuntimeError('Cannot current bot version')
|
|
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
|
|
|
def __init__(self, bot_token: str, guild_id: int) -> None:
|
|
self.guild_id = guild_id
|
|
self._bot_token = bot_token
|
|
|
|
self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0)
|
|
self.version = self._get_code_version()
|
|
|
|
def _update_rate_limit(self, headers: HTTPHeaders):
|
|
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
|
if header_key not in headers:
|
|
print(f'Warning: no "{header_key}" found in headers')
|
|
return
|
|
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
|
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
|
|
|
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
|
expected_code: int = 200) -> tuple[
|
|
HTTPHeaders, dict]:
|
|
timeout = 3
|
|
min_api_version = 9
|
|
|
|
if api_action == ApiAction.POST:
|
|
raise NotImplementedError
|
|
if api_version.value < min_api_version:
|
|
print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})')
|
|
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
|
request = urllib.request.Request(url)
|
|
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
|
request.add_header('Accept', 'application/json')
|
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
try:
|
|
with urllib.request.urlopen(request, timeout=timeout) as response:
|
|
if response.status != expected_code:
|
|
raise RuntimeError(
|
|
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
|
return dict(response.getheaders()), json.loads(response.read().decode())
|
|
except urllib.error.HTTPError as error:
|
|
raise RuntimeError(
|
|
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
except urllib.error.URLError as error:
|
|
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
|
|
|
@staticmethod
|
|
def _parse_overwrite(info: dict) -> Overwrite:
|
|
return Overwrite(
|
|
id=int(info['id']),
|
|
type=OverwriteType(info['type']),
|
|
allow=Permissions(int(info['allow'])),
|
|
deny=Permissions(int(info['deny']))
|
|
)
|
|
|
|
def _parse_channel_category(self, info: dict) -> ChannelCategory:
|
|
parent_id: str | None = info.get('parent_id')
|
|
return ChannelCategory(
|
|
id=int(info['id']),
|
|
guild_id=int(info['guild_id']),
|
|
position=int(info['position']),
|
|
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
|
|
name=info.get('name'),
|
|
parent_id=int(parent_id) if parent_id is not None else None,
|
|
flags=ChannelFlags(info['flags']),
|
|
)
|
|
|
|
def _parse_text_channel(self, info: dict) -> TextChannel:
|
|
parent_id: str | None = info.get('parent_id')
|
|
last_message_id: str | None = info.get('last_message_id')
|
|
last_pin_timestamp: str | None = info.get('last_pin_timestamp')
|
|
return TextChannel(
|
|
id=int(info['id']),
|
|
guild_id=int(info['guild_id']),
|
|
position=int(info['position']),
|
|
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
|
|
name=info.get('name'),
|
|
topic=info.get('topic'),
|
|
nsfw=info['nsfw'],
|
|
last_message_id=int(last_message_id) if last_message_id is not None else None,
|
|
rate_limit_per_user=int(info['rate_limit_per_user']),
|
|
parent_id=int(parent_id) if parent_id is not None else None,
|
|
last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None),
|
|
flags=ChannelFlags(info['flags']),
|
|
)
|
|
|
|
@staticmethod
|
|
def _parse_user(info: dict) -> User:
|
|
return User(
|
|
id=int(info['id']),
|
|
username=info['username'],
|
|
discriminator=info['discriminator'],
|
|
global_name=info.get('global_name')
|
|
)
|
|
|
|
def _parse_message(self, info: dict) -> Message:
|
|
edited_timestamp: str | None = info.get('edited_timestamp')
|
|
return Message(
|
|
id=int(info['id']),
|
|
channel_id=int(info['channel_id']),
|
|
author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User(
|
|
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
|
|
content=info['content'],
|
|
timestamp=datetime.fromisoformat(info['timestamp']),
|
|
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None
|
|
)
|
|
|
|
def delete_message(self, message: Message):
|
|
try:
|
|
headers, _ = self._send_request(
|
|
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
|
self._update_rate_limit(headers)
|
|
print(f'Message {message.id} deleted')
|
|
except RuntimeError as error:
|
|
print(error)
|
|
|
|
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
|
headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
|
self._update_rate_limit(headers)
|
|
categories: list[ChannelCategory] = []
|
|
text_channels: list[TextChannel] = []
|
|
for channel in channels:
|
|
channel_type = ChannelType(channel['type'])
|
|
match channel_type:
|
|
case ChannelType.GUILD_CATEGORY:
|
|
categories.append(self._parse_channel_category(channel))
|
|
case ChannelType.GUILD_TEXT:
|
|
text_channels.append(self._parse_text_channel(channel))
|
|
return categories, text_channels
|
|
|
|
def list_text_channel_messages(self, channel: TextChannel) -> list:
|
|
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
|
|
self._update_rate_limit(headers)
|
|
return [self._parse_message(m) for m in messages]
|