Config scan from bot channel implementation

This commit is contained in:
BreadTube 2025-09-23 22:48:35 +09:00 committed by Corentin
commit 157e8c1b17
6 changed files with 453 additions and 34 deletions

View file

@ -1,11 +1,15 @@
from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass
from enum import Enum
import logging
import operator
from pathlib import Path
import json
import random
import time
import tomllib
from typing import Any
import urllib.error
import urllib.request
@ -13,7 +17,8 @@ from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel)
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
OverwriteType, Permissions, Role, TextChannel)
HTTPHeaders = dict[str, str]
@ -29,11 +34,20 @@ class ApiEncoder(json.JSONEncoder):
class DiscordManager:
MAX_CONFIG_SIZE: int = 50_000
DEFAULT_MESSAGE_LIST_LIMIT = 50
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.')
@dataclass
class RateLimit:
remaining: int
next_reset: float
class Task(Enum):
SCAN_BOT_CHANNEL = 1
DELETE_MESSAGES = 2
@staticmethod
def _get_code_version() -> str:
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
@ -50,8 +64,12 @@ class DiscordManager:
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
self.logger.info('Retrieving guild roles before init')
self.guild_roles: list = self.list_roles()
self.bot_channel: TextChannel | None = None
self.init_message: Message | None = None
for _ in range(self.config.bot_channel_init_retries):
while not self.init_bot_channel():
time.sleep(10)
@ -60,6 +78,10 @@ class DiscordManager:
self.logger.info('Bot init OK')
break
raise RuntimeError("Couldn't initialize bot channel/role/permission")
self._scan_bot_channel()
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
self.logger.info('Bot initialized')
def _update_rate_limit(self, headers: HTTPHeaders):
@ -69,11 +91,11 @@ class DiscordManager:
return
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
timeout = 3
min_api_version = 9
if api_version.value < min_api_version:
@ -90,7 +112,8 @@ class DiscordManager:
+ f'\r\n--{boundary}'.encode()) if data else b''
for file_index, (name, mime, content) in enumerate(upload_files):
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
f' filename="{name}"\r\nContent-Length: {len(content)}'
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
data += f'\r\n--{boundary}--'.encode()
request = urllib.request.Request(url, data=data)
request.method = api_action.value
@ -102,7 +125,7 @@ class DiscordManager:
request.add_header('Content-Type', 'application/json')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=timeout) as response:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
@ -114,6 +137,23 @@ class DiscordManager:
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
request.add_header('User-Agent', f'BreadTube (v{self.version})')
request.add_header('Authorization', f'Bot {self._bot_token}')
try:
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
if response.status != expected_code:
raise RuntimeError(
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
return dict(response.getheaders()), response.read()
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error downloading attachment ({attachment}):'
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
def init_bot_channel(self) -> bool:
_, text_channel = self.list_channels()
breadtube_role: Role | None = None
@ -130,25 +170,20 @@ class DiscordManager:
self.logger.info('No everyone role found')
return False
breadtube_channel: TextChannel | None = None
for channel in text_channel:
if channel.name == self.config.bot_channel:
breadtube_channel = channel
self.bot_channel = channel
self.logger.info('Found breadtube bot channel')
for perm in breadtube_channel.permission_overwrites:
for perm in self.bot_channel.permission_overwrites:
if perm.id == breadtube_role.id:
if not perm.allow | Permissions.VIEW_CHANNEL:
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
return False
self.logger.info('BreadTube channel permission OK')
break
messages = self.list_text_channel_messages(breadtube_channel)
for message in messages:
self.logger.debug('Deleting message: %s', message)
self.delete_message(message)
break
else:
breadtube_channel = self.create_text_channel({
self.bot_channel = self.create_text_channel({
'name': self.config.bot_channel,
'permission_overwrites': [
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
@ -157,13 +192,152 @@ class DiscordManager:
deny=Permissions.NONE)]
})
self.logger.info('Created breadtube bot channel')
self.create_message(
breadtube_channel,
{'content': 'This is the current configuration used, upload a new one to update the configuration'},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
return True
def _scan_bot_channel(self):
if self.bot_channel is None:
self.logger.error('Cannot scan bot channel: bot channel is None')
return []
last_message_id: int | None = None
while True:
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break
last_message_id = messages[-1].id
messages = sorted(messages, key=lambda x: x.timestamp)
self.init_message = None
new_config: Config | None = None
messages_to_delete: list[Message] = []
for message in messages:
# Skip message to be deleted
skip = True
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
m.id == message.id for m in messages_to_delete)):
self.logger.debug('Skipping message already marked to be deleted')
break
else:
skip = False
if skip:
continue
delete_message = True
for attachment in message.attachments:
if attachment.size < self.MAX_CONFIG_SIZE:
try:
_, content = self._download_attachment(attachment)
if content.startswith(b'config'):
try:
config = Config.from_str(content.decode())
if config != self.config:
new_config = config
elif message.content == self.INIT_MESSAGE:
if self.init_message is not None:
self.logger.debug('Deleting duplicated init message')
try:
self.delete_message(self.init_message)
except RuntimeError as error:
self.logger.error('Error deleting init_message while scanning: %s', error)
self.init_message = message
delete_message = False
break
except RuntimeError as error:
self.logger.info('Invalid config file: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
except Exception as error:
self.logger.error('Error downloading attachment: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
if delete_message:
if any(m.id == message.id for m in messages_to_delete):
self.logger.warning(
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
else:
self.logger.debug('Deleting message: %s', message)
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after scanned: %s', error)
if new_config is not None:
self.logger.info('Loading new config: %s', new_config)
self.config = new_config
if self.init_message is not None:
self.delete_message(self.init_message)
self.init_message = None
if self.init_message is None:
self.init_message = self.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
if messages_to_delete:
self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration,
messages_to_delete))
def run(self):
while True:
if self.tasks:
self.tasks = sorted(self.tasks, key=operator.itemgetter(1), reverse=True)
task_type, task_time, task_params = self.tasks.pop()
sleep_time = task_time - time.time()
self.logger.debug(
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time)
if sleep_time > 0:
time.sleep(sleep_time)
match task_type:
case DiscordManager.Task.SCAN_BOT_CHANNEL:
try:
self._scan_bot_channel()
except Exception as error:
self.logger.error('Error scanning bot channel: %s', error)
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.DELETE_MESSAGES:
if not isinstance(task_params, list):
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
elif not task_params:
self.logger.error('Empty params for DELETE_MESSAGES: %s', task_params)
elif any(not isinstance(v, Message) for v in task_params):
self.logger.error('All params not int for DELETE_MESSAGES: %s', task_params)
else:
for message in task_params:
try:
self.delete_message(message)
except Exception as error:
self.logger.error('Error deleting message %s: %s', message, error)
if self.rate_limit.remaining <= 1:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second')
time.sleep(sleep_time)
time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
headers, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
@ -210,7 +384,19 @@ class DiscordManager:
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]
def list_text_channel_messages(self, channel: TextChannel) -> list[Message]:
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
after_id: int | None = None) -> list[Message]:
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
params: Api.Message.ListMessageParams = {}
if limit is not None:
params['limit'] = limit
if before_id is not None:
params['before'] = before_id
if after_id is not None:
params['after'] = after_id
headers, messages_info = self._send_request(
*Api.Message.list_by_channel(channel.id),
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
self._update_rate_limit(headers)
return [Message.from_dict(m) for m in messages or []]
return [Message.from_dict(m) for m in messages_info or []]