Simplify and fix config scanning

This commit is contained in:
BreadTube 2025-09-26 04:00:55 +09:00 committed by Corentin
commit 486cb82773
3 changed files with 82 additions and 84 deletions

View file

@ -144,3 +144,8 @@ class Api:
before: int # Get messages before this message ID before: int # Get messages before this message ID
after: int # Get messages after this message ID after: int # Get messages after this message ID
limit: int # Max number of messages to return (1-100), default=50 limit: int # Max number of messages to return (1-100), default=50
class User:
@staticmethod
def get_current() -> tuple[ApiAction, str]:
return ApiAction.GET, '/users/@me'

View file

@ -12,13 +12,14 @@ import tomllib
from typing import Any from typing import Any
import urllib.error import urllib.error
import urllib.request import urllib.request
import traceback
from .api import Api, ApiAction, ApiVersion from .api import Api, ApiAction, ApiVersion
from .config import Config from .config import Config
from .logger import create_logger from .logger import create_logger
from .objects import ( from .objects import (
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite, Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
OverwriteType, Permissions, Role, TextChannel) OverwriteType, Permissions, Role, TextChannel, User)
HTTPHeaders = dict[str, str] HTTPHeaders = dict[str, str]
@ -66,6 +67,8 @@ class DiscordManager:
self.version = self._get_code_version() self.version = self._get_code_version()
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = [] self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
self.logger.info('Retrieving bot user')
self.bot_user = self.get_current_user()
self.logger.info('Retrieving guild roles before init') self.logger.info('Retrieving guild roles before init')
self.guild_roles: list = self.list_roles() self.guild_roles: list = self.list_roles()
self.bot_channel: TextChannel | None = None self.bot_channel: TextChannel | None = None
@ -92,6 +95,11 @@ class DiscordManager:
self.rate_limit.remaining = int(headers['x-ratelimit-remaining']) self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
self.rate_limit.next_reset = float(headers['x-ratelimit-reset']) self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
self.logger.debug('Updated rate limit: %s', self.rate_limit) self.logger.debug('Updated rate limit: %s', self.rate_limit)
if self.rate_limit.remaining <= 0:
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
time.sleep(sleep_time)
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10, def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None, data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
@ -130,12 +138,16 @@ class DiscordManager:
raise RuntimeError( raise RuntimeError(
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}') f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
body = response.read() body = response.read()
return dict(response.getheaders()), json.loads(body.decode()) if body else None headers = dict(response.getheaders())
self._update_rate_limit(headers)
return headers, json.loads(body.decode()) if body else None
except urllib.error.HTTPError as error: except urllib.error.HTTPError as error:
raise RuntimeError( raise RuntimeError(
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
except urllib.error.URLError as error: except urllib.error.URLError as error:
raise RuntimeError(f'URL error calling API ({url}): {error}') from error raise RuntimeError(f'URL error calling API ({url}): {error}') from error
except TimeoutError as error:
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]: def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
request = urllib.request.Request(attachment.url) request = urllib.request.Request(attachment.url)
@ -199,106 +211,87 @@ class DiscordManager:
self.logger.error('Cannot scan bot channel: bot channel is None') self.logger.error('Cannot scan bot channel: bot channel is None')
return [] return []
messages_id_delete_task: set[int] = set()
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES:
messages_id_delete_task.update(message.id for message in task_params)
last_message_id: int | None = None last_message_id: int | None = None
messages: list[Message] = []
while True: while True:
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id) message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT: messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
break break
last_message_id = messages[-1].id last_message_id = message_batch[-1].id
messages = sorted(messages, key=lambda x: x.timestamp)
self.init_message = None self.init_message = None
new_config: Config | None = None new_config: Config | None = None
messages_to_delete: list[Message] = [] delayed_delete: dict[int, Message] = {}
immediate_delete: dict[int, Message] = {}
for message in messages: for message in messages:
# Skip message to be deleted if message.id in delayed_delete:
skip = True
for task_type, _, task_params in self.tasks:
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
m.id == message.id for m in messages_to_delete)):
self.logger.debug('Skipping message already marked to be deleted') self.logger.debug('Skipping message already marked to be deleted')
break
else:
skip = False
if skip:
continue continue
delete_message = True if self.init_message is None and new_config is None and len(message.attachments) == 1:
for attachment in message.attachments: attachment = message.attachments[0]
if attachment.size < self.MAX_CONFIG_SIZE: if attachment.size < self.MAX_CONFIG_SIZE:
try: try:
_, content = self._download_attachment(attachment) _, content = self._download_attachment(attachment)
if content.startswith(b'config'): if content.startswith(b'config'):
try: try:
config = Config.from_str(content.decode()) config = Config.from_str(content.decode())
if config != self.config: if message.author.id == self.bot_user.id: # keep using current config
new_config = config self.logger.debug('Found previous init message')
elif message.content == self.INIT_MESSAGE:
if self.init_message is not None:
self.logger.debug('Deleting duplicated init message')
try:
self.delete_message(self.init_message)
except RuntimeError as error:
self.logger.error('Error deleting init_message while scanning: %s', error)
self.init_message = message self.init_message = message
delete_message = False if config != self.config: # First scan qill need to load config
break self.config = config
continue
if config != self.config: # New config to update to
new_config = config
self.logger.debug('Marking new config message for immediate deletion: %s', message)
immediate_delete[message.id] = message
continue
except RuntimeError as error: except RuntimeError as error:
self.logger.info('Invalid config file: %s', error) self.logger.info('Invalid config file: %s', error)
messages_to_delete.extend([ bot_message = self.create_message(self.bot_channel, {
self.create_message(self.bot_channel, {
'content': str(error), 'content': str(error),
'message_reference': MessageReference( 'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT, type=MessageReferenceType.DEFAULT,
message_id=message.id, message_id=message.id,
channel_id=self.bot_channel.id, channel_id=self.bot_channel.id,
guild_id=None, guild_id=None,
fail_if_not_exists=None)}), fail_if_not_exists=None)})
message]) delayed_delete[bot_message.id] = bot_message
delete_message = False delayed_delete[message.id] = message
break continue
except Exception as error: except Exception as error:
self.logger.error('Error downloading attachment: %s', error) self.logger.error('Error downloading attachment: %s', error)
messages_to_delete.extend([ self.logger.debug('Marking message for immediate deletion: %s', message)
self.create_message(self.bot_channel, { immediate_delete[message.id] = message
'content': str(error),
'message_reference': MessageReference(
type=MessageReferenceType.DEFAULT,
message_id=message.id,
channel_id=self.bot_channel.id,
guild_id=None,
fail_if_not_exists=None)}),
message])
delete_message = False
break
if delete_message:
if any(m.id == message.id for m in messages_to_delete):
self.logger.warning(
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
else:
self.logger.debug('Deleting message: %s', message)
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after scanned: %s', error)
if new_config is not None: if new_config is not None:
self.logger.info('Loading new config: %s', new_config) self.logger.info('Loading config: %s', new_config)
self.config = new_config self.config = new_config
if self.init_message is not None:
self.delete_message(self.init_message)
self.init_message = None
if self.init_message is None: if self.init_message is None:
assert self.config is not None
self.init_message = self.create_message( self.init_message = self.create_message(
self.bot_channel, {'content': self.INIT_MESSAGE}, self.bot_channel, {'content': self.INIT_MESSAGE},
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())]) upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
if messages_to_delete: for message in immediate_delete.values():
try:
self.delete_message(message)
except RuntimeError as error:
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
if delayed_delete:
self.tasks.append(( self.tasks.append((
DiscordManager.Task.DELETE_MESSAGES, DiscordManager.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration, time.time() + self.config.bot_message_duration,
messages_to_delete)) list(delayed_delete.values())))
def run(self): def run(self):
while True: while True:
@ -307,7 +300,7 @@ class DiscordManager:
task_type, task_time, task_params = self.tasks.pop() task_type, task_time, task_params = self.tasks.pop()
sleep_time = task_time - time.time() sleep_time = task_time - time.time()
self.logger.debug( self.logger.debug(
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time) 'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params)
if sleep_time > 0: if sleep_time > 0:
time.sleep(sleep_time) time.sleep(sleep_time)
match task_type: match task_type:
@ -315,7 +308,8 @@ class DiscordManager:
try: try:
self._scan_bot_channel() self._scan_bot_channel()
except Exception as error: except Exception as error:
self.logger.error('Error scanning bot channel: %s', error) self.logger.error('Error scanning bot channel: %s -> %s',
error, traceback.format_exc().replace('\n', ' | '))
self.tasks.append(( self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case DiscordManager.Task.DELETE_MESSAGES: case DiscordManager.Task.DELETE_MESSAGES:
@ -330,41 +324,39 @@ class DiscordManager:
try: try:
self.delete_message(message) self.delete_message(message)
except Exception as error: except Exception as error:
self.logger.error('Error deleting message %s: %s', message, error) self.logger.error('Error deleting message %s: %s -> %s',
if self.rate_limit.remaining <= 1: message, error, traceback.format_exc().replace('\n', ' | '))
sleep_time = self.rate_limit.next_reset - time.time()
if sleep_time > 0:
self.logger.debug('Rate limit: sleeping %.03f second')
time.sleep(sleep_time)
time.sleep(1) time.sleep(1)
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel: def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
headers, channel_info = self._send_request( _, channel_info = self._send_request(
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(), *Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
expected_code=201) expected_code=201)
self._update_rate_limit(headers)
if not isinstance(channel_info, dict): if not isinstance(channel_info, dict):
raise RuntimeError(f'Error creating channel with params (no info): {params}') raise RuntimeError(f'Error creating channel with params (no info): {params}')
return TextChannel.from_dict(channel_info) return TextChannel.from_dict(channel_info)
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message: upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
headers, message_info = self._send_request( _, message_info = self._send_request(
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(), *Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
upload_files=upload_files) upload_files=upload_files)
self._update_rate_limit(headers)
if not isinstance(message_info, dict): if not isinstance(message_info, dict):
raise RuntimeError(f'Error creating message with params (no info): {params}') raise RuntimeError(f'Error creating message with params (no info): {params}')
return Message.from_dict(message_info) return Message.from_dict(message_info)
def delete_message(self, message: Message): def delete_message(self, message: Message):
headers, _ = self._send_request( _, _ = self._send_request(
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204) *Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
self._update_rate_limit(headers)
def get_current_user(self) -> User:
_, user_info = self._send_request(*Api.User.get_current())
if not isinstance(user_info, dict):
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
return User.from_dict(user_info)
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]: def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id)) _, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
self._update_rate_limit(headers)
categories: list[ChannelCategory] = [] categories: list[ChannelCategory] = []
text_channels: list[TextChannel] = [] text_channels: list[TextChannel] = []
if channels_info is not None: if channels_info is not None:
@ -378,8 +370,7 @@ class DiscordManager:
return categories, text_channels return categories, text_channels
def list_roles(self) -> list[Role]: def list_roles(self) -> list[Role]:
headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id)) _, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
self._update_rate_limit(headers)
if not isinstance(roles_info, list): if not isinstance(roles_info, list):
raise RuntimeError(f'Error listing roles (not a list): {roles_info}') raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
return [Role.from_dict(r) for r in roles_info] return [Role.from_dict(r) for r in roles_info]

View file

@ -251,6 +251,7 @@ class User: # TODO : complete attributes
username: str username: str
discriminator: str discriminator: str
global_name: str | None global_name: str | None
bot: bool
@staticmethod @staticmethod
def from_dict(info: dict) -> User: def from_dict(info: dict) -> User:
@ -258,7 +259,8 @@ class User: # TODO : complete attributes
id=int(info['id']), id=int(info['id']),
username=info['username'], username=info['username'],
discriminator=info['discriminator'], discriminator=info['discriminator'],
global_name=info.get('global_name')) global_name=info.get('global_name'),
bot=info.get('bot', False))
class AttachmentFlags(IntFlag): class AttachmentFlags(IntFlag):
@ -421,7 +423,7 @@ class Message: # TODO : complete attributes
id=int(info['id']), id=int(info['id']),
channel_id=int(info['channel_id']), channel_id=int(info['channel_id']),
author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User( author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User(
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)), id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None, bot=False)),
content=info['content'], content=info['content'],
timestamp=datetime.fromisoformat(info['timestamp']), timestamp=datetime.fromisoformat(info['timestamp']),
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None, edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,