401 lines
20 KiB
Python
401 lines
20 KiB
Python
from __future__ import annotations

from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import json
import logging
import operator
from pathlib import Path
import random
import time
import tomllib
import traceback
from typing import Any
import urllib.error
import urllib.parse
import urllib.request

from breadtube_bot.youtube_api import YoutubeManager

from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
    Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType,
    Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
|
|
|
|
|
|
class ApiEncoder(json.JSONEncoder):
    """JSON encoder for Discord API payloads.

    Dataclass instances are serialized as dicts (via ``dataclasses.asdict``) and ``Enum`` members as
    their ``value``; anything else is delegated to ``json.JSONEncoder.default`` (which raises
    ``TypeError``).
    """

    def default(self, o):
        # is_dataclass() is also true for dataclass *classes*; only instances may go through asdict().
        if is_dataclass(o) and not isinstance(o, type):
            return asdict(o)
        if isinstance(o, Enum):
            return o.value
        return super().default(o)
|
|
|
|
|
|
class DiscordManager:
|
|
MAX_CONFIG_SIZE: int = 50_000
|
|
DEFAULT_MESSAGE_LIST_LIMIT = 50
|
|
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
|
|
'You can upload a new one to update the configuration.')
|
|
|
|
@dataclass
|
|
class RateLimit:
|
|
remaining: int
|
|
next_reset: float
|
|
|
|
@dataclass
|
|
class YoutTubeChannel:
|
|
name: str
|
|
channel_id: str
|
|
last_update: float
|
|
|
|
class Task(Enum):
|
|
SCAN_BOT_CHANNEL = 1
|
|
DELETE_MESSAGES = 2
|
|
|
|
@staticmethod
|
|
def _get_code_version() -> str:
|
|
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
|
|
if not pyproject_path.exists():
|
|
raise RuntimeError('Cannot current bot version')
|
|
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
|
|
|
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
|
|
log_level: int = logging.INFO) -> None:
|
|
self.config = config or Config()
|
|
self.guild_id = guild_id
|
|
self._bot_token = bot_token
|
|
self.logger = create_logger('breadtube', log_level, stdout=True)
|
|
|
|
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
|
self.version = self._get_code_version()
|
|
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
|
|
|
self.logger.info('Retrieving bot user')
|
|
self.bot_user = self.get_current_user()
|
|
self.logger.info('Retrieving guild roles before init')
|
|
self.guild_roles: list = self.list_roles()
|
|
self.bot_channel: TextChannel | None = None
|
|
self.init_message: Message | None = None
|
|
|
|
for _ in range(self.config.bot_channel_init_retries):
|
|
while not self.init_bot_channel():
|
|
time.sleep(10)
|
|
break
|
|
else:
|
|
self.logger.info('Bot init OK')
|
|
break
|
|
raise RuntimeError("Couldn't initialize bot channel/role/permission")
|
|
|
|
self._scan_bot_channel()
|
|
self.tasks.append((
|
|
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
|
|
|
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
|
|
self.logger.info('Bot initialized')
|
|
|
|
def _update_rate_limit(self, headers: HTTPHeaders):
|
|
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
|
if header_key not in headers:
|
|
self.logger.info('Warning: no "%s" found in headers', header_key)
|
|
return
|
|
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
|
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
|
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
|
if self.rate_limit.remaining <= 0:
|
|
sleep_time = self.rate_limit.next_reset - time.time()
|
|
if sleep_time > 0:
|
|
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
|
time.sleep(sleep_time)
|
|
|
|
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
|
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
|
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
|
|
min_api_version = 9
|
|
|
|
if api_version.value < min_api_version:
|
|
self.logger.warning(
|
|
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
|
api_version, min_api_version)
|
|
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
|
|
|
boundary: str = ''
|
|
if upload_files:
|
|
boundary = f'{random.randbytes(16).hex()}'
|
|
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
|
'Content-Type: application/json\r\n\r\n'.encode() + data
|
|
+ f'\r\n--{boundary}'.encode()) if data else b''
|
|
for file_index, (name, mime, content) in enumerate(upload_files):
|
|
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
|
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
|
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
|
data += f'\r\n--{boundary}--'.encode()
|
|
request = urllib.request.Request(url, data=data)
|
|
request.method = api_action.value
|
|
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
|
request.add_header('Accept', 'application/json')
|
|
if upload_files:
|
|
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
|
else:
|
|
request.add_header('Content-Type', 'application/json')
|
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
try:
|
|
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
|
if response.status != expected_code:
|
|
raise RuntimeError(
|
|
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
|
body = response.read()
|
|
headers = dict(response.getheaders())
|
|
self._update_rate_limit(headers)
|
|
return headers, json.loads(body.decode()) if body else None
|
|
except urllib.error.HTTPError as error:
|
|
raise RuntimeError(
|
|
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
except urllib.error.URLError as error:
|
|
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
|
except TimeoutError as error:
|
|
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
|
|
|
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
|
request = urllib.request.Request(attachment.url)
|
|
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
try:
|
|
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
|
if response.status != expected_code:
|
|
raise RuntimeError(
|
|
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
|
return dict(response.getheaders()), response.read()
|
|
except urllib.error.HTTPError as error:
|
|
raise RuntimeError(
|
|
f'HTTP error downloading attachment ({attachment}):'
|
|
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
except urllib.error.URLError as error:
|
|
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
|
|
|
def init_bot_channel(self) -> bool:
|
|
_, text_channel = self.list_channels()
|
|
breadtube_role: Role | None = None
|
|
everyone_role: Role | None = None
|
|
for role in self.guild_roles:
|
|
if role.name == self.config.bot_role:
|
|
breadtube_role = role
|
|
elif role.name == '@everyone':
|
|
everyone_role = role
|
|
if breadtube_role is None:
|
|
self.logger.info('No BreadTube role found')
|
|
return False
|
|
if everyone_role is None:
|
|
self.logger.info('No everyone role found')
|
|
return False
|
|
|
|
for channel in text_channel:
|
|
if channel.name == self.config.bot_channel:
|
|
self.bot_channel = channel
|
|
self.logger.info('Found breadtube bot channel')
|
|
for perm in self.bot_channel.permission_overwrites:
|
|
if perm.id == breadtube_role.id:
|
|
if not perm.allow | Permissions.VIEW_CHANNEL:
|
|
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
|
|
return False
|
|
self.logger.info('BreadTube channel permission OK')
|
|
break
|
|
break
|
|
else:
|
|
self.bot_channel = self.create_text_channel({
|
|
'name': self.config.bot_channel,
|
|
'permission_overwrites': [
|
|
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
|
deny=Permissions.VIEW_CHANNEL),
|
|
Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
|
deny=Permissions.NONE)]
|
|
})
|
|
self.logger.info('Created breadtube bot channel')
|
|
return True
|
|
|
|
def _scan_bot_channel(self):
|
|
if self.bot_channel is None:
|
|
self.logger.error('Cannot scan bot channel: bot channel is None')
|
|
return []
|
|
|
|
messages_id_delete_task: set[int] = set()
|
|
for task_type, _, task_params in self.tasks:
|
|
if task_type == self.Task.DELETE_MESSAGES:
|
|
messages_id_delete_task.update(message.id for message in task_params)
|
|
|
|
last_message_id: int | None = None
|
|
messages: list[Message] = []
|
|
while True:
|
|
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
|
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
|
|
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
|
break
|
|
last_message_id = message_batch[-1].id
|
|
|
|
self.init_message = None
|
|
new_config: Config | None = None
|
|
delayed_delete: dict[int, Message] = {}
|
|
immediate_delete: dict[int, Message] = {}
|
|
for message in messages:
|
|
if message.id in delayed_delete:
|
|
self.logger.debug('Skipping message already marked to be deleted')
|
|
continue
|
|
|
|
if self.init_message is None and new_config is None and len(message.attachments) == 1:
|
|
attachment = message.attachments[0]
|
|
if attachment.size < self.MAX_CONFIG_SIZE:
|
|
try:
|
|
_, content = self._download_attachment(attachment)
|
|
if content.startswith(b'config'):
|
|
try:
|
|
config = Config.from_str(content.decode())
|
|
if message.author.id == self.bot_user.id: # keep using current config
|
|
self.logger.debug('Found previous init message')
|
|
self.init_message = message
|
|
if config != self.config: # First scan qill need to load config
|
|
self.config = config
|
|
continue
|
|
if config != self.config: # New config to update to
|
|
new_config = config
|
|
self.logger.debug('Marking new config message for immediate deletion: %s', message)
|
|
immediate_delete[message.id] = message
|
|
continue
|
|
except RuntimeError as error:
|
|
self.logger.info('Invalid config file: %s', error)
|
|
bot_message = self.create_message(self.bot_channel, {
|
|
'content': str(error),
|
|
'message_reference': MessageReference(
|
|
type=MessageReferenceType.DEFAULT,
|
|
message_id=message.id,
|
|
channel_id=self.bot_channel.id,
|
|
guild_id=None,
|
|
fail_if_not_exists=None)})
|
|
delayed_delete[bot_message.id] = bot_message
|
|
delayed_delete[message.id] = message
|
|
continue
|
|
except Exception as error:
|
|
self.logger.error('Error downloading attachment: %s', error)
|
|
self.logger.debug('Marking message for immediate deletion: %s', message)
|
|
immediate_delete[message.id] = message
|
|
|
|
if new_config is not None:
|
|
self.logger.info('Loading config: %s', new_config)
|
|
self.config = new_config
|
|
|
|
if self.init_message is None:
|
|
assert self.config is not None
|
|
self.init_message = self.create_message(
|
|
self.bot_channel, {'content': self.INIT_MESSAGE},
|
|
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
|
|
|
for message in immediate_delete.values():
|
|
try:
|
|
self.delete_message(message)
|
|
except RuntimeError as error:
|
|
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
|
|
|
|
if delayed_delete:
|
|
self.tasks.append((
|
|
DiscordManager.Task.DELETE_MESSAGES,
|
|
time.time() + self.config.bot_message_duration,
|
|
list(delayed_delete.values())))
|
|
|
|
def run(self):
|
|
while True:
|
|
if self.tasks:
|
|
self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True)
|
|
task_type, task_time, task_params = self.tasks.pop()
|
|
sleep_time = task_time - time.time()
|
|
self.logger.debug(
|
|
'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params)
|
|
if sleep_time > 0:
|
|
time.sleep(sleep_time)
|
|
match task_type:
|
|
case DiscordManager.Task.SCAN_BOT_CHANNEL:
|
|
try:
|
|
self._scan_bot_channel()
|
|
except Exception as error:
|
|
self.logger.error('Error scanning bot channel: %s -> %s',
|
|
error, traceback.format_exc().replace('\n', ' | '))
|
|
self.tasks.append((
|
|
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
|
case DiscordManager.Task.DELETE_MESSAGES:
|
|
if not isinstance(task_params, list):
|
|
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
|
|
elif not task_params:
|
|
self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params)
|
|
elif any(not isinstance(v, Message) for v in task_params):
|
|
self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params)
|
|
else:
|
|
for message in task_params:
|
|
try:
|
|
self.delete_message(message)
|
|
except Exception as error:
|
|
self.logger.error('Error deleting message %s: %s -> %s',
|
|
message, error, traceback.format_exc().replace('\n', ' | '))
|
|
time.sleep(1)
|
|
|
|
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
|
_, channel_info = self._send_request(
|
|
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
|
expected_code=201)
|
|
if not isinstance(channel_info, dict):
|
|
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
|
return TextChannel.from_dict(channel_info)
|
|
|
|
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
|
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
|
_, message_info = self._send_request(
|
|
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
|
upload_files=upload_files)
|
|
if not isinstance(message_info, dict):
|
|
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
|
return Message.from_dict(message_info)
|
|
|
|
def delete_message(self, message: Message):
|
|
_, _ = self._send_request(
|
|
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
|
|
|
def get_current_user(self) -> User:
|
|
_, user_info = self._send_request(*Api.User.get_current())
|
|
if not isinstance(user_info, dict):
|
|
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
|
return User.from_dict(user_info)
|
|
|
|
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
|
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
|
categories: list[ChannelCategory] = []
|
|
text_channels: list[TextChannel] = []
|
|
if channels_info is not None:
|
|
for channel_info in channels_info:
|
|
channel_type = ChannelType(channel_info['type'])
|
|
match channel_type:
|
|
case ChannelType.GUILD_CATEGORY:
|
|
categories.append(ChannelCategory.from_dict(channel_info))
|
|
case ChannelType.GUILD_TEXT:
|
|
text_channels.append(TextChannel.from_dict(channel_info))
|
|
return categories, text_channels
|
|
|
|
def list_roles(self) -> list[Role]:
|
|
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
|
if not isinstance(roles_info, list):
|
|
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
|
return [Role.from_dict(r) for r in roles_info]
|
|
|
|
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
|
|
after_id: int | None = None) -> list[Message]:
|
|
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
|
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
|
params: Api.Message.ListMessageParams = {}
|
|
if limit is not None:
|
|
params['limit'] = limit
|
|
if before_id is not None:
|
|
params['before'] = before_id
|
|
if after_id is not None:
|
|
params['after'] = after_id
|
|
headers, messages_info = self._send_request(
|
|
*Api.Message.list_by_channel(channel.id),
|
|
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
|
self._update_rate_limit(headers)
|
|
return [Message.from_dict(m) for m in messages_info or []]
|