Bot config and channel init

This commit is contained in:
BreadTube 2025-09-23 04:50:23 +09:00 committed by Corentin
commit 72edbe6599
7 changed files with 499 additions and 109 deletions

View file

@ -1,26 +1,39 @@
from dataclasses import dataclass
from datetime import datetime
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import logging
from pathlib import Path
import json
import random
import time
import tomllib
import urllib.error
import urllib.request
from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User)
ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel)
HTTPHeaders = dict[str, str]
@dataclass
class _RateLimit:
remaining: int
next_reset: float
class ApiEncoder(json.JSONEncoder):
def default(self, o):
if is_dataclass(o):
return asdict(o) # type: ignore
if isinstance(o, Enum):
return o.value
return super().default(o)
class DiscordManager:
@dataclass
class RateLimit:
remaining: int
next_reset: float
@staticmethod
def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
@ -28,133 +41,176 @@ class DiscordManager:
raise RuntimeError('Cannot current bot version')
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
def __init__(self, bot_token: str, guild_id: int) -> None:
def __init__(self, bot_token: str, guild_id: int, config: Config | None = None,
log_level: int = logging.INFO) -> None:
self.config = config or Config()
self.guild_id = guild_id
self._bot_token = bot_token
self.logger = create_logger('breadtube', log_level, stdout=True)
self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0)
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version()
self.guild_roles: list = self.list_roles()
for _ in range(self.config.bot_channel_init_retries):
while not self.init_bot_channel():
time.sleep(10)
break
else:
self.logger.info('Bot init OK')
break
raise RuntimeError("Couldn't initialize bot channel/role/permission")
self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
print(f'Warning: no "{header_key}" found in headers')
self.logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
expected_code: int = 200) -> tuple[
HTTPHeaders, dict]:
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
timeout = 3
min_api_version = 9
if api_action == ApiAction.POST:
raise NotImplementedError
if api_version.value < min_api_version:
print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})')
self.logger.warning(
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
api_version, min_api_version)
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
request = urllib.request.Request(url)
boundary: str = ''
if upload_files:
boundary = f'{random.randbytes(16).hex()}'
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
'Content-Type: application/json\r\n\r\n'.encode() + data
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Accept', 'application/json')
if upload_files:
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
else:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
return dict(response.getheaders()), json.loads(response.read().decode())
body = response.read()
return dict(response.getheaders()), json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
@staticmethod
def _parse_overwrite(info: dict) -> Overwrite:
return Overwrite(
id=int(info['id']),
type=OverwriteType(info['type']),
allow=Permissions(int(info['allow'])),
deny=Permissions(int(info['deny']))
)
def init_bot_channel(self) -> bool:
_, text_channel = self.list_channels()
breadtube_role: Role | None = None
everyone_role: Role | None = None
for role in self.guild_roles:
if role.name == self.config.bot_role:
breadtube_role = role
elif role.name == '@everyone':
everyone_role = role
if breadtube_role is None:
self.logger.info('No BreadTube role found')
return False
if everyone_role is None:
self.logger.info('No everyone role found')
return False
def _parse_channel_category(self, info: dict) -> ChannelCategory:
parent_id: str | None = info.get('parent_id')
return ChannelCategory(
id=int(info['id']),
guild_id=int(info['guild_id']),
position=int(info['position']),
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
name=info.get('name'),
parent_id=int(parent_id) if parent_id is not None else None,
flags=ChannelFlags(info['flags']),
)
breadtube_channel: TextChannel | None = None
for channel in text_channel:
if channel.name == self.config.bot_channel:
breadtube_channel = channel
self.logger.info('Found breadtube bot channel')
for perm in breadtube_channel.permission_overwrites:
if perm.id == breadtube_role.id:
if not perm.allow | Permissions.VIEW_CHANNEL:
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
return False
self.logger.info('BreadTube channel permission OK')
break
messages = self.list_text_channel_messages(breadtube_channel)
for message in messages:
self.logger.debug('Deleting message: %s', message)
self.delete_message(message)
break
else:
breadtube_channel = self.create_text_channel({
'name': self.config.bot_channel,
'permission_overwrites': [
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
deny=Permissions.VIEW_CHANNEL),
Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
deny=Permissions.NONE)]
})
self.logger.info('Created breadtube bot channel')
def _parse_text_channel(self, info: dict) -> TextChannel:
parent_id: str | None = info.get('parent_id')
last_message_id: str | None = info.get('last_message_id')
last_pin_timestamp: str | None = info.get('last_pin_timestamp')
return TextChannel(
id=int(info['id']),
guild_id=int(info['guild_id']),
position=int(info['position']),
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
name=info.get('name'),
topic=info.get('topic'),
nsfw=info['nsfw'],
last_message_id=int(last_message_id) if last_message_id is not None else None,
rate_limit_per_user=int(info['rate_limit_per_user']),
parent_id=int(parent_id) if parent_id is not None else None,
last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None),
flags=ChannelFlags(info['flags']),
)
self.create_message(
breadtube_channel,
{'content': 'This is the current configuration used, upload a new one to update the configuration'},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
return True
@staticmethod
def _parse_user(info: dict) -> User:
return User(
id=int(info['id']),
username=info['username'],
discriminator=info['discriminator'],
global_name=info.get('global_name')
)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
headers, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
expected_code=201)
self._update_rate_limit(headers)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def _parse_message(self, info: dict) -> Message:
edited_timestamp: str | None = info.get('edited_timestamp')
return Message(
id=int(info['id']),
channel_id=int(info['channel_id']),
author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User(
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
content=info['content'],
timestamp=datetime.fromisoformat(info['timestamp']),
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None
)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
headers, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
upload_files=upload_files)
self._update_rate_limit(headers)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message):
try:
headers, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
self._update_rate_limit(headers)
print(f'Message {message.id} deleted')
except RuntimeError as error:
print(error)
headers, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
self._update_rate_limit(headers)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id))
headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
self._update_rate_limit(headers)
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
for channel in channels:
channel_type = ChannelType(channel['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(self._parse_channel_category(channel))
case ChannelType.GUILD_TEXT:
text_channels.append(self._parse_text_channel(channel))
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_text_channel_messages(self, channel: TextChannel) -> list:
def list_roles(self) -> list[Role]:
headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
self._update_rate_limit(headers)
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(self, channel: TextChannel) -> list[Message]:
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
self._update_rate_limit(headers)
return [self._parse_message(m) for m in messages]
return [Message.from_dict(m) for m in messages or []]