Simplify and fix config scanning

This commit is contained in:
BreadTube 2025-09-26 04:00:55 +09:00 committed by Corentin
commit 486cb82773
3 changed files with 82 additions and 84 deletions

View file

@ -144,3 +144,8 @@ class Api:
before: int # Get messages before this message ID
after: int # Get messages after this message ID
limit: int # Max number of messages to return (1-100), default=50
class User:
@staticmethod
def get_current() -> tuple[ApiAction, str]:
return ApiAction.GET, '/users/@me'

View file

@ -12,13 +12,14 @@ import tomllib
from typing import Any
import urllib.error
import urllib.request
import traceback
from .api import Api, ApiAction, ApiVersion
from .config import Config
from .logger import create_logger
from .objects import (
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
OverwriteType, Permissions, Role, TextChannel)
OverwriteType, Permissions, Role, TextChannel, User)
HTTPHeaders = dict[str, str]
@ -66,6 +67,8 @@ class DiscordManager:
self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
self.logger.info('Retrieving bot user')
self.bot_user = self.get_current_user()
self.logger.info('Retrieving guild roles before init')
self.guild_roles: list = self.list_roles()
self.bot_channel: TextChannel | None = None
@ -92,6 +95,11 @@ class DiscordManager:
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
@ -130,12 +138,16 @@ class DiscordManager:
raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read()
return dict(response.getheaders()), json.loads(body.decode()) if body else None
headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error:
raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url)
@ -199,106 +211,87 @@ class DiscordManager:
self.logger.error('Cannot scan bot channel: bot channel is None')
return []
messages_id_delete_task: set[int] = set()
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES:
messages_id_delete_task.update(message.id for message in task_params)
last_message_id: int | None = None
messages: list[Message] = []
while True:
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT:
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break
last_message_id = messages[-1].id
messages = sorted(messages, key=lambda x: x.timestamp)
last_message_id = message_batch[-1].id
self.init_message = None
new_config: Config | None = None
messages_to_delete: list[Message] = []
delayed_delete: dict[int, Message] = {}
immediate_delete: dict[int, Message] = {}
for message in messages:
# Skip message to be deleted
skip = True
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
m.id == message.id for m in messages_to_delete)):
self.logger.debug('Skipping message already marked to be deleted')
break
else:
skip = False
if skip:
if message.id in delayed_delete:
self.logger.debug('Skipping message already marked to be deleted')
continue
delete_message = True
for attachment in message.attachments:
if self.init_message is None and new_config is None and len(message.attachments) == 1:
attachment = message.attachments[0]
if attachment.size < self.MAX_CONFIG_SIZE:
try:
_, content = self._download_attachment(attachment)
if content.startswith(b'config'):
try:
config = Config.from_str(content.decode())
if config != self.config:
new_config = config
elif message.content == self.INIT_MESSAGE:
if self.init_message is not None:
self.logger.debug('Deleting duplicated init message')
try:
self.delete_message(self.init_message)
except RuntimeError as error:
self.logger.error('Error deleting init_message while scanning: %s', error)
if message.author.id == self.bot_user.id: # keep using current config
self.logger.debug('Found previous init message')
self.init_message = message
delete_message = False
break
if config != self.config: # First scan qill need to load config
self.config = config
continue
if config != self.config: # New config to update to
new_config = config
self.logger.debug('Marking new config message for immediate deletion: %s', message)
immediate_delete[message.id] = message
continue
except RuntimeError as error:
self.logger.info('Invalid config file: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
bot_message = self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
fail_if_not_exists=None)})
delayed_delete[bot_message.id] = bot_message
delayed_delete[message.id] = message
continue
except Exception as error:
self.logger.error('Error downloading attachment: %s', error)
messages_to_delete.extend([
self.create_message(self.bot_channel, {
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
if delete_message:
if any(m.id == message.id for m in messages_to_delete):
self.logger.warning(
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
else:
self.logger.debug('Deleting message: %s', message)
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after scanned: %s', error)
self.logger.debug('Marking message for immediate deletion: %s', message)
immediate_delete[message.id] = message
if new_config is not None:
self.logger.info('Loading new config: %s', new_config)
self.logger.info('Loading config: %s', new_config)
self.config = new_config
if self.init_message is not None:
self.delete_message(self.init_message)
self.init_message = None
if self.init_message is None:
assert self.config is not None
self.init_message = self.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
if messages_to_delete:
for message in immediate_delete.values():
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
if delayed_delete:
self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration,
messages_to_delete))
list(delayed_delete.values())))
def run(self):
while True:
@ -307,7 +300,7 @@ class DiscordManager:
task_type, task_time, task_params = self.tasks.pop()
sleep_time = task_time - time.time()
self.logger.debug(
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time)
'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params)
if sleep_time > 0:
time.sleep(sleep_time)
match task_type:
@ -315,7 +308,8 @@ class DiscordManager:
try:
self._scan_bot_channel()
except Exception as error:
self.logger.error('Error scanning bot channel: %s', error)
self.logger.error('Error scanning bot channel: %s -> %s',
error, traceback.format_exc().replace('\n', ' | '))
self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.DELETE_MESSAGES:
@ -330,41 +324,39 @@ class DiscordManager:
try:
self.delete_message(message)
except Exception as error:
self.logger.error('Error deleting message %s: %s', message, error)
if self.rate_limit.remaining <= 1:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second')
time.sleep(sleep_time)
self.logger.error('Error deleting message %s: %s -> %s',
message, error, traceback.format_exc().replace('\n', ' | '))
time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
headers, channel_info = self._send_request(
_, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
expected_code=201)
self._update_rate_limit(headers)
if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
headers, message_info = self._send_request(
_, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
upload_files=upload_files)
self._update_rate_limit(headers)
if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info)
def delete_message(self, message: Message):
headers, _ = self._send_request(
_, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
self._update_rate_limit(headers)
def get_current_user(self) -> User:
_, user_info = self._send_request(*Api.User.get_current())
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
self._update_rate_limit(headers)
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = []
if channels_info is not None:
@ -378,8 +370,7 @@ class DiscordManager:
return categories, text_channels
def list_roles(self) -> list[Role]:
headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
self._update_rate_limit(headers)
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info]

View file

@ -251,6 +251,7 @@ class User: # TODO : complete attributes
username: str
discriminator: str
global_name: str | None
bot: bool
@staticmethod
def from_dict(info: dict) -> User:
@ -258,7 +259,8 @@ class User: # TODO : complete attributes
id=int(info['id']),
username=info['username'],
discriminator=info['discriminator'],
global_name=info.get('global_name'))
global_name=info.get('global_name'),
bot=info.get('bot', False))
class AttachmentFlags(IntFlag):
@ -421,7 +423,7 @@ class Message: # TODO : complete attributes
id=int(info['id']),
channel_id=int(info['channel_id']),
author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User(
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None, bot=False)),
content=info['content'],
timestamp=datetime.fromisoformat(info['timestamp']),
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,