breadtube-bot/breadtube_bot/manager.py
2025-10-04 16:51:08 +09:00

401 lines
20 KiB
Python

from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import logging
import operator
from pathlib import Path
import json
import random
import time
import tomllib
from typing import Any
import urllib.error
import urllib.request
import traceback
from breadtube_bot.youtube_api import YoutubeManager
from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType,
Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
class ApiEncoder(json.JSONEncoder):
    """JSON encoder able to serialize the dataclasses and enums used in API payloads."""

    def default(self, o):
        """Return a JSON-compatible replacement for ``o``.

        Dataclasses are converted with ``asdict`` and enum members are replaced by their value.
        Any other object is handed to the base encoder.
        """
        if is_dataclass(o):
            encoded = asdict(o)  # type: ignore
        elif isinstance(o, Enum):
            encoded = o.value
        else:
            encoded = super().default(o)
        return encoded
class DiscordManager:
# Attachments whose reported size is not below this value are never downloaded as a configuration file.
MAX_CONFIG_SIZE: int = 50_000
# Batch size used as reference when paginating channel messages: a shorter batch ends the scan.
DEFAULT_MESSAGE_LIST_LIMIT = 50
# Content of the message posted by the bot together with its current configuration file.
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.')
@dataclass
class RateLimit:
    """Rate-limit state taken from the headers of the last API response."""

    # Value of the ``x-ratelimit-remaining`` header.
    remaining: int
    # Value of the ``x-ratelimit-reset`` header, compared against ``time.time()``.
    next_reset: float
@dataclass
class YoutTubeChannel:
    """Record holding the data kept about one YouTube channel."""

    # Channel name.
    name: str
    # Channel identifier (presumably the YouTube channel id - confirm against the caller).
    channel_id: str
    # NOTE(review): looks like a timestamp of the last refresh - confirm unit and reference.
    last_update: float
class Task(Enum):
    """Kinds of work that can be queued in the manager task list."""

    # Re-scan the bot channel (config uploads, stale messages).
    SCAN_BOT_CHANNEL = 1
    # Delete the messages given as task parameters.
    DELETE_MESSAGES = 2
@staticmethod
def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
if not pyproject_path.exists():
raise RuntimeError('Cannot current bot version')
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
             log_level: int = logging.INFO) -> None:
    """Connect to the guild, prepare the bot channel and schedule the first channel scan.

    Args:
        bot_token: token sent in the ``Authorization`` header of every request.
        guild_id: identifier of the guild the bot works on.
        yt_api_key: key handed over to the ``YoutubeManager``.
        config: initial configuration; a default ``Config`` is used when omitted.
        log_level: level given to the ``breadtube`` logger.

    Raises:
        RuntimeError: if the bot channel cannot be initialized within the configured number of
            attempts, or if one of the API requests fails.
    """
    self.config = config or Config()
    self.guild_id = guild_id
    self._bot_token = bot_token
    self.logger = create_logger('breadtube', log_level, stdout=True)
    self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
    self.version = self._get_code_version()
    self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []

    self.logger.info('Retrieving bot user')
    self.bot_user = self.get_current_user()
    self.logger.info('Retrieving guild roles before init')
    self.guild_roles: list = self.list_roles()

    self.bot_channel: TextChannel | None = None
    self.init_message: Message | None = None
    # Always try at least once, even with a zero/negative retry setting.
    attempts = max(1, self.config.bot_channel_init_retries)
    for attempt in range(1, attempts + 1):
        if self.init_bot_channel():
            self.logger.info('Bot init OK')
            break
        if attempt < attempts:
            time.sleep(10)
            # init_bot_channel() reads self.guild_roles: refresh it, otherwise a role created in
            # the meantime would never be seen by the next attempt.
            self.guild_roles = self.list_roles()
    else:
        raise RuntimeError("Couldn't initialize bot channel/role/permission")

    self._scan_bot_channel()
    self.tasks.append((
        self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))

    self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
    self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders):
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
if header_key not in headers:
self.logger.info('Warning: no "%s" found in headers', header_key)
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
                  data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
                  expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
    """Send one request to the Discord API and decode the JSON answer.

    Args:
        api_action: its ``value`` is used as the HTTP method.
        endpoint: path appended to ``https://discord.com/api/v<version>``.
        api_version: API version to target.
        data: request body (JSON encoded). When ``upload_files`` is given it becomes the
            ``payload_json`` part of the multipart body.
        upload_files: ``(filename, mime, content)`` tuples sent as ``files[n]`` multipart parts.
        expected_code: status code the response must have.

    Returns:
        The response headers and the decoded body (``None`` when the body is empty).

    Raises:
        RuntimeError: on unexpected status code, HTTP error, URL error or timeout.
    """
    min_api_version = 9
    if api_version.value < min_api_version:
        # %d needs the numeric value, not the enum member itself.
        self.logger.warning(
            'Warning: using deprecated API version %d (minimum non deprecated is %d)',
            api_version.value, min_api_version)
    url = f'https://discord.com/api/v{api_version.value}{endpoint}'

    boundary: str = ''
    if upload_files:
        boundary = random.randbytes(16).hex()
        # Every part starts with its own delimiter line; parts are separated by a single CRLF so
        # that exactly one boundary stands between two consecutive parts.
        parts: list[bytes] = []
        if data:
            parts.append(
                (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
                 'Content-Type: application/json\r\n\r\n').encode() + data)
        for file_index, (name, mime, content) in enumerate(upload_files):
            parts.append(
                (f'--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
                 f' filename="{name}"\r\nContent-Length: {len(content)}'
                 f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content)
        parts.append(f'--{boundary}--'.encode())
        data = b'\r\n'.join(parts)

    request = urllib.request.Request(url, data=data)
    request.method = api_action.value
    request.add_header('User-Agent', f'BreadTube (v{self.version})')
    request.add_header('Accept', 'application/json')
    if upload_files:
        request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
    else:
        request.add_header('Content-Type', 'application/json')
    request.add_header('Authorization', f'Bot {self._bot_token}')

    try:
        with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
            if response.status != expected_code:
                raise RuntimeError(
                    f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
            body = response.read()
            headers = dict(response.getheaders())
            # May sleep until the rate limit resets before handing the result back.
            self._update_rate_limit(headers)
            return headers, json.loads(body.decode()) if body else None
    except urllib.error.HTTPError as error:
        raise RuntimeError(
            f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
    except urllib.error.URLError as error:
        raise RuntimeError(f'URL error calling API ({url}): {error}') from error
    except TimeoutError as error:
        raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
    """Download the content of a message attachment.

    Args:
        attachment: attachment whose ``url`` is fetched.
        expected_code: status code the response must have.

    Returns:
        The response headers and the raw content.

    Raises:
        RuntimeError: on unexpected status code, HTTP error, URL error or timeout.
    """
    request = urllib.request.Request(attachment.url)
    request.add_header('User-Agent', f'BreadTube (v{self.version})')
    request.add_header('Authorization', f'Bot {self._bot_token}')
    try:
        with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
            if response.status != expected_code:
                raise RuntimeError(
                    f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
            return dict(response.getheaders()), response.read()
    except urllib.error.HTTPError as error:
        raise RuntimeError(
            f'HTTP error downloading attachment ({attachment}):'
            f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
    except urllib.error.URLError as error:
        raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
    except TimeoutError as error:
        # Same handling as in _send_request: a timeout while reading is not wrapped in URLError.
        raise RuntimeError(f'Timeout downloading attachment ({attachment}): {error}') from error
def init_bot_channel(self) -> bool:
    """Find the bot channel, or create it when it does not exist yet.

    The roles are taken from ``self.guild_roles``; the channels are listed from the API.

    Returns:
        ``False`` when the bot role or the ``@everyone`` role is missing, or when the bot role
        overwrite of an existing bot channel does not allow viewing it. ``True`` otherwise, with
        ``self.bot_channel`` set.
    """
    _, text_channel = self.list_channels()

    breadtube_role: Role | None = None
    everyone_role: Role | None = None
    for role in self.guild_roles:
        if role.name == self.config.bot_role:
            breadtube_role = role
        elif role.name == '@everyone':
            everyone_role = role
    if breadtube_role is None:
        self.logger.info('No BreadTube role found')
        return False
    if everyone_role is None:
        self.logger.info('No everyone role found')
        return False

    for channel in text_channel:
        if channel.name == self.config.bot_channel:
            self.bot_channel = channel
            self.logger.info('Found breadtube bot channel')
            for perm in self.bot_channel.permission_overwrites:
                if perm.id == breadtube_role.id:
                    # Bitwise AND tests the flag; an OR would be truthy whatever is allowed.
                    if not perm.allow & Permissions.VIEW_CHANNEL:
                        self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
                        return False
                    self.logger.info('BreadTube channel permission OK')
                    break
            break
    else:
        # No channel with the configured name: create one hidden from everyone but the bot role.
        self.bot_channel = self.create_text_channel({
            'name': self.config.bot_channel,
            'permission_overwrites': [
                Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
                          deny=Permissions.VIEW_CHANNEL),
                Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
                          deny=Permissions.NONE)]
        })
        self.logger.info('Created breadtube bot channel')
    return True
def _scan_bot_channel(self):
    """Scan the bot channel for configuration uploads and clean up the other messages.

    Messages already queued in a ``DELETE_MESSAGES`` task are ignored. Among the remaining ones:

    * a config attachment posted by the bot itself is kept as the init message (and loaded when it
      differs from the current configuration);
    * a config attachment posted by someone else replaces the current configuration;
    * an invalid config gets an error reply, both being deleted later through a delayed task;
    * everything else is deleted right away.

    A new init message is posted when none was found.
    """
    if self.bot_channel is None:
        self.logger.error('Cannot scan bot channel: bot channel is None')
        return []

    messages_id_delete_task: set[int] = set()
    for task_type, _, task_params in self.tasks:
        if task_type == self.Task.DELETE_MESSAGES:
            messages_id_delete_task.update(message.id for message in task_params)

    # Page backwards from the most recent messages. The smallest id of a batch is the cursor of
    # the next request, and a batch bringing nothing new ends the scan so that the loop always
    # terminates.
    seen_ids: set[int] = set()
    messages: list[Message] = []
    last_message_id: int | None = None
    while True:
        message_batch = self.list_text_channel_messages(
            self.bot_channel, limit=self.DEFAULT_MESSAGE_LIST_LIMIT, before_id=last_message_id)
        new_messages = [m for m in message_batch if m.id not in seen_ids]
        seen_ids.update(m.id for m in new_messages)
        messages.extend(m for m in new_messages if m.id not in messages_id_delete_task)
        if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT or not new_messages:
            break
        last_message_id = min(m.id for m in message_batch)

    self.init_message = None
    new_config: Config | None = None
    delayed_delete: dict[int, Message] = {}
    immediate_delete: dict[int, Message] = {}
    for message in messages:
        if message.id in delayed_delete:
            self.logger.debug('Skipping message already marked to be deleted')
            continue
        # Only look for a config until either an init message or a new config has been found.
        if self.init_message is None and new_config is None and len(message.attachments) == 1:
            attachment = message.attachments[0]
            if attachment.size < self.MAX_CONFIG_SIZE:
                try:
                    _, content = self._download_attachment(attachment)
                    if content.startswith(b'config'):
                        try:
                            config = Config.from_str(content.decode())
                            if message.author.id == self.bot_user.id:  # keep using current config
                                self.logger.debug('Found previous init message')
                                self.init_message = message
                                if config != self.config:  # First scan will need to load config
                                    self.config = config
                                continue
                            if config != self.config:  # New config to update to
                                new_config = config
                                self.logger.debug('Marking new config message for immediate deletion: %s', message)
                                immediate_delete[message.id] = message
                                continue
                        except RuntimeError as error:
                            self.logger.info('Invalid config file: %s', error)
                            bot_message = self.create_message(self.bot_channel, {
                                'content': str(error),
                                'message_reference': MessageReference(
                                    type=MessageReferenceType.DEFAULT,
                                    message_id=message.id,
                                    channel_id=self.bot_channel.id,
                                    guild_id=None,
                                    fail_if_not_exists=None)})
                            # Leave the error visible for a while before removing both messages.
                            delayed_delete[bot_message.id] = bot_message
                            delayed_delete[message.id] = message
                            continue
                except Exception as error:
                    self.logger.error('Error downloading attachment: %s', error)
        self.logger.debug('Marking message for immediate deletion: %s', message)
        immediate_delete[message.id] = message

    if new_config is not None:
        self.logger.info('Loading config: %s', new_config)
        self.config = new_config
    if self.init_message is None:
        assert self.config is not None
        self.init_message = self.create_message(
            self.bot_channel, {'content': self.INIT_MESSAGE},
            upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])

    for message in immediate_delete.values():
        try:
            self.delete_message(message)
        except RuntimeError as error:
            self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
    if delayed_delete:
        self.tasks.append((
            DiscordManager.Task.DELETE_MESSAGES,
            time.time() + self.config.bot_message_duration,
            list(delayed_delete.values())))
def run(self):
    """Run the task loop forever, always executing the task with the earliest deadline."""
    while True:
        if self.tasks:
            # Latest deadline first: the next task to run ends up last and can simply be popped.
            self.tasks.sort(key=operator.itemgetter(1), reverse=True)
            kind, due_time, payload = self.tasks.pop()
            delay = due_time - time.time()
            self.logger.debug(
                'Next task %s at %.03f (sleeping for %.03fs) : %s', kind, due_time, delay, payload)
            if delay > 0:
                time.sleep(delay)

            if kind == DiscordManager.Task.SCAN_BOT_CHANNEL:
                try:
                    self._scan_bot_channel()
                except Exception as error:
                    self.logger.error('Error scanning bot channel: %s -> %s',
                                      error, traceback.format_exc().replace('\n', ' | '))
                # The scan reschedules itself whether it succeeded or not.
                next_scan = time.time() + self.config.bot_channel_scan_interval
                self.tasks.append((self.Task.SCAN_BOT_CHANNEL, next_scan, None))
            elif kind == DiscordManager.Task.DELETE_MESSAGES:
                if not isinstance(payload, list):
                    self.logger.error('Wrong task params for DELETE_MESSAGES: %s', payload)
                elif not payload:
                    self.logger.error('Empty params for DELETE_MESSAGES: %s', payload)
                elif not all(isinstance(item, Message) for item in payload):
                    self.logger.error('All params not int for DELETE_MESSAGES: %s', payload)
                else:
                    for message in payload:
                        try:
                            self.delete_message(message)
                        except Exception as error:
                            self.logger.error('Error deleting message %s: %s -> %s',
                                              message, error, traceback.format_exc().replace('\n', ' | '))
        time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
    """Create a text channel in the guild.

    Raises:
        RuntimeError: if the request fails or the answer is not a JSON object.
    """
    api_call = Api.Guild.create_channel(guild_id=self.guild_id)
    payload = json.dumps(params, cls=ApiEncoder).encode()
    _, channel_info = self._send_request(*api_call, data=payload, expected_code=201)
    if isinstance(channel_info, dict):
        return TextChannel.from_dict(channel_info)
    raise RuntimeError(f'Error creating channel with params (no info): {params}')
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
                   upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
    """Post a message, optionally with attached files, in ``channel``.

    Raises:
        RuntimeError: if the request fails or the answer is not a JSON object.
    """
    api_call = Api.Message.create(channel_id=channel.id)
    payload = json.dumps(params, cls=ApiEncoder).encode()
    _, message_info = self._send_request(*api_call, data=payload, upload_files=upload_files)
    if isinstance(message_info, dict):
        return Message.from_dict(message_info)
    raise RuntimeError(f'Error creating message with params (no info): {params}')
def delete_message(self, message: Message):
    """Delete ``message`` from its channel; a 204 answer is expected."""
    api_call = Api.Message.delete(channel_id=message.channel_id, message_id=message.id)
    self._send_request(*api_call, expected_code=204)
def get_current_user(self) -> User:
    """Return the user the bot token belongs to.

    Raises:
        RuntimeError: if the request fails or the answer is not a JSON object.
    """
    api_call = Api.User.get_current()
    _, user_info = self._send_request(*api_call)
    if isinstance(user_info, dict):
        return User.from_dict(user_info)
    raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
for channel_info in channels_info:
channel_type = ChannelType(channel_info['type'])
match channel_type:
case ChannelType.GUILD_CATEGORY:
categories.append(ChannelCategory.from_dict(channel_info))
case ChannelType.GUILD_TEXT:
text_channels.append(TextChannel.from_dict(channel_info))
return categories, text_channels
def list_roles(self) -> list[Role]:
    """Return the roles defined in the guild.

    Raises:
        RuntimeError: if the request fails or the answer is not a JSON array.
    """
    api_call = Api.Guild.list_roles(self.guild_id)
    _, roles_info = self._send_request(*api_call)
    if isinstance(roles_info, list):
        return [Role.from_dict(role_info) for role_info in roles_info]
    raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
                               after_id: int | None = None) -> list[Message]:
    """List messages of a text channel.

    The filters are sent as query-string parameters of the request.

    Args:
        channel: channel to read.
        limit: maximum number of messages to return, between 1 and 100.
        before_id: only return messages older than this message id.
        after_id: only return messages newer than this message id.

    Returns:
        The messages of the answer, empty when the answer has no body.

    Raises:
        RuntimeError: if ``limit`` is out of range or the request fails.
    """
    from urllib.parse import urlencode

    if limit is not None and not (0 < limit <= 100):  # noqa: PLR2004
        raise RuntimeError('Cannot list messages by channel with limit outside [1, 100] range')
    params: Api.Message.ListMessageParams = {}
    if limit is not None:
        params['limit'] = limit
    if before_id is not None:
        params['before'] = before_id
    if after_id is not None:
        params['after'] = after_id

    # The first two items map to the (api_action, endpoint) positional parameters of _send_request.
    api_action, endpoint, *extra = Api.Message.list_by_channel(channel.id)
    if params:
        endpoint = f'{endpoint}?{urlencode(params)}'
    # _send_request already updates the rate limit from the response headers.
    _, messages_info = self._send_request(api_action, endpoint, *extra)
    return [Message.from_dict(m) for m in messages_info or []]