Simplify and fix config scanning
This commit is contained in:
parent
157e8c1b17
commit
486cb82773
3 changed files with 82 additions and 84 deletions
|
|
@ -144,3 +144,8 @@ class Api:
|
||||||
before: int # Get messages before this message ID
|
before: int # Get messages before this message ID
|
||||||
after: int # Get messages after this message ID
|
after: int # Get messages after this message ID
|
||||||
limit: int # Max number of messages to return (1-100), default=50
|
limit: int # Max number of messages to return (1-100), default=50
|
||||||
|
|
||||||
|
class User:
|
||||||
|
@staticmethod
|
||||||
|
def get_current() -> tuple[ApiAction, str]:
|
||||||
|
return ApiAction.GET, '/users/@me'
|
||||||
|
|
|
||||||
|
|
@ -12,13 +12,14 @@ import tomllib
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import urllib.error
|
import urllib.error
|
||||||
import urllib.request
|
import urllib.request
|
||||||
|
import traceback
|
||||||
|
|
||||||
from .api import Api, ApiAction, ApiVersion
|
from .api import Api, ApiAction, ApiVersion
|
||||||
from .config import Config
|
from .config import Config
|
||||||
from .logger import create_logger
|
from .logger import create_logger
|
||||||
from .objects import (
|
from .objects import (
|
||||||
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
|
Attachment, ChannelCategory, ChannelType, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
|
||||||
OverwriteType, Permissions, Role, TextChannel)
|
OverwriteType, Permissions, Role, TextChannel, User)
|
||||||
|
|
||||||
|
|
||||||
HTTPHeaders = dict[str, str]
|
HTTPHeaders = dict[str, str]
|
||||||
|
|
@ -66,6 +67,8 @@ class DiscordManager:
|
||||||
self.version = self._get_code_version()
|
self.version = self._get_code_version()
|
||||||
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
||||||
|
|
||||||
|
self.logger.info('Retrieving bot user')
|
||||||
|
self.bot_user = self.get_current_user()
|
||||||
self.logger.info('Retrieving guild roles before init')
|
self.logger.info('Retrieving guild roles before init')
|
||||||
self.guild_roles: list = self.list_roles()
|
self.guild_roles: list = self.list_roles()
|
||||||
self.bot_channel: TextChannel | None = None
|
self.bot_channel: TextChannel | None = None
|
||||||
|
|
@ -92,6 +95,11 @@ class DiscordManager:
|
||||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||||
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
||||||
|
if self.rate_limit.remaining <= 0:
|
||||||
|
sleep_time = self.rate_limit.next_reset - time.time()
|
||||||
|
if sleep_time > 0:
|
||||||
|
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
||||||
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||||
|
|
@ -130,12 +138,16 @@ class DiscordManager:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||||
body = response.read()
|
body = response.read()
|
||||||
return dict(response.getheaders()), json.loads(body.decode()) if body else None
|
headers = dict(response.getheaders())
|
||||||
|
self._update_rate_limit(headers)
|
||||||
|
return headers, json.loads(body.decode()) if body else None
|
||||||
except urllib.error.HTTPError as error:
|
except urllib.error.HTTPError as error:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||||
except urllib.error.URLError as error:
|
except urllib.error.URLError as error:
|
||||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||||
|
except TimeoutError as error:
|
||||||
|
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
||||||
|
|
||||||
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
||||||
request = urllib.request.Request(attachment.url)
|
request = urllib.request.Request(attachment.url)
|
||||||
|
|
@ -199,106 +211,87 @@ class DiscordManager:
|
||||||
self.logger.error('Cannot scan bot channel: bot channel is None')
|
self.logger.error('Cannot scan bot channel: bot channel is None')
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
messages_id_delete_task: set[int] = set()
|
||||||
|
for task_type, _, task_params in self.tasks:
|
||||||
|
if task_type == self.Task.DELETE_MESSAGES:
|
||||||
|
messages_id_delete_task.update(message.id for message in task_params)
|
||||||
|
|
||||||
last_message_id: int | None = None
|
last_message_id: int | None = None
|
||||||
|
messages: list[Message] = []
|
||||||
while True:
|
while True:
|
||||||
messages = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
||||||
if len(messages) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
|
||||||
|
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
||||||
break
|
break
|
||||||
last_message_id = messages[-1].id
|
last_message_id = message_batch[-1].id
|
||||||
messages = sorted(messages, key=lambda x: x.timestamp)
|
|
||||||
|
|
||||||
self.init_message = None
|
self.init_message = None
|
||||||
new_config: Config | None = None
|
new_config: Config | None = None
|
||||||
messages_to_delete: list[Message] = []
|
delayed_delete: dict[int, Message] = {}
|
||||||
|
immediate_delete: dict[int, Message] = {}
|
||||||
for message in messages:
|
for message in messages:
|
||||||
# Skip message to be deleted
|
if message.id in delayed_delete:
|
||||||
skip = True
|
self.logger.debug('Skipping message already marked to be deleted')
|
||||||
for task_type, _, task_params in self.tasks:
|
|
||||||
if task_type == self.Task.DELETE_MESSAGES and (any(m.id == message.id for m in task_params) or any(
|
|
||||||
m.id == message.id for m in messages_to_delete)):
|
|
||||||
self.logger.debug('Skipping message already marked to be deleted')
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
skip = False
|
|
||||||
if skip:
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
delete_message = True
|
if self.init_message is None and new_config is None and len(message.attachments) == 1:
|
||||||
for attachment in message.attachments:
|
attachment = message.attachments[0]
|
||||||
if attachment.size < self.MAX_CONFIG_SIZE:
|
if attachment.size < self.MAX_CONFIG_SIZE:
|
||||||
try:
|
try:
|
||||||
_, content = self._download_attachment(attachment)
|
_, content = self._download_attachment(attachment)
|
||||||
if content.startswith(b'config'):
|
if content.startswith(b'config'):
|
||||||
try:
|
try:
|
||||||
config = Config.from_str(content.decode())
|
config = Config.from_str(content.decode())
|
||||||
if config != self.config:
|
if message.author.id == self.bot_user.id: # keep using current config
|
||||||
new_config = config
|
self.logger.debug('Found previous init message')
|
||||||
elif message.content == self.INIT_MESSAGE:
|
|
||||||
if self.init_message is not None:
|
|
||||||
self.logger.debug('Deleting duplicated init message')
|
|
||||||
try:
|
|
||||||
self.delete_message(self.init_message)
|
|
||||||
except RuntimeError as error:
|
|
||||||
self.logger.error('Error deleting init_message while scanning: %s', error)
|
|
||||||
self.init_message = message
|
self.init_message = message
|
||||||
delete_message = False
|
if config != self.config: # First scan qill need to load config
|
||||||
break
|
self.config = config
|
||||||
|
continue
|
||||||
|
if config != self.config: # New config to update to
|
||||||
|
new_config = config
|
||||||
|
self.logger.debug('Marking new config message for immediate deletion: %s', message)
|
||||||
|
immediate_delete[message.id] = message
|
||||||
|
continue
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
self.logger.info('Invalid config file: %s', error)
|
self.logger.info('Invalid config file: %s', error)
|
||||||
messages_to_delete.extend([
|
bot_message = self.create_message(self.bot_channel, {
|
||||||
self.create_message(self.bot_channel, {
|
|
||||||
'content': str(error),
|
'content': str(error),
|
||||||
'message_reference': MessageReference(
|
'message_reference': MessageReference(
|
||||||
type=MessageReferenceType.DEFAULT,
|
type=MessageReferenceType.DEFAULT,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
channel_id=self.bot_channel.id,
|
channel_id=self.bot_channel.id,
|
||||||
guild_id=None,
|
guild_id=None,
|
||||||
fail_if_not_exists=None)}),
|
fail_if_not_exists=None)})
|
||||||
message])
|
delayed_delete[bot_message.id] = bot_message
|
||||||
delete_message = False
|
delayed_delete[message.id] = message
|
||||||
break
|
continue
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.logger.error('Error downloading attachment: %s', error)
|
self.logger.error('Error downloading attachment: %s', error)
|
||||||
messages_to_delete.extend([
|
self.logger.debug('Marking message for immediate deletion: %s', message)
|
||||||
self.create_message(self.bot_channel, {
|
immediate_delete[message.id] = message
|
||||||
'content': str(error),
|
|
||||||
'message_reference': MessageReference(
|
|
||||||
type=MessageReferenceType.DEFAULT,
|
|
||||||
message_id=message.id,
|
|
||||||
channel_id=self.bot_channel.id,
|
|
||||||
guild_id=None,
|
|
||||||
fail_if_not_exists=None)}),
|
|
||||||
message])
|
|
||||||
delete_message = False
|
|
||||||
break
|
|
||||||
if delete_message:
|
|
||||||
if any(m.id == message.id for m in messages_to_delete):
|
|
||||||
self.logger.warning(
|
|
||||||
'Warning wrongly trying to delete message id %d while marked to be deleted', message.id)
|
|
||||||
else:
|
|
||||||
self.logger.debug('Deleting message: %s', message)
|
|
||||||
try:
|
|
||||||
self.delete_message(message)
|
|
||||||
except RuntimeError as error:
|
|
||||||
self.logger.error('Error deleting after scanned: %s', error)
|
|
||||||
|
|
||||||
if new_config is not None:
|
if new_config is not None:
|
||||||
self.logger.info('Loading new config: %s', new_config)
|
self.logger.info('Loading config: %s', new_config)
|
||||||
self.config = new_config
|
self.config = new_config
|
||||||
if self.init_message is not None:
|
|
||||||
self.delete_message(self.init_message)
|
|
||||||
self.init_message = None
|
|
||||||
|
|
||||||
if self.init_message is None:
|
if self.init_message is None:
|
||||||
|
assert self.config is not None
|
||||||
self.init_message = self.create_message(
|
self.init_message = self.create_message(
|
||||||
self.bot_channel, {'content': self.INIT_MESSAGE},
|
self.bot_channel, {'content': self.INIT_MESSAGE},
|
||||||
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
||||||
|
|
||||||
if messages_to_delete:
|
for message in immediate_delete.values():
|
||||||
|
try:
|
||||||
|
self.delete_message(message)
|
||||||
|
except RuntimeError as error:
|
||||||
|
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
|
||||||
|
|
||||||
|
if delayed_delete:
|
||||||
self.tasks.append((
|
self.tasks.append((
|
||||||
DiscordManager.Task.DELETE_MESSAGES,
|
DiscordManager.Task.DELETE_MESSAGES,
|
||||||
time.time() + self.config.bot_message_duration,
|
time.time() + self.config.bot_message_duration,
|
||||||
messages_to_delete))
|
list(delayed_delete.values())))
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -307,7 +300,7 @@ class DiscordManager:
|
||||||
task_type, task_time, task_params = self.tasks.pop()
|
task_type, task_time, task_params = self.tasks.pop()
|
||||||
sleep_time = task_time - time.time()
|
sleep_time = task_time - time.time()
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
'Next task %s at %.03f (%s), sleeping %.03fs', task_type, task_time, task_params, sleep_time)
|
'Next task %s at %.03f (sleeping for %.03fs) : %s', task_type, task_time, sleep_time, task_params)
|
||||||
if sleep_time > 0:
|
if sleep_time > 0:
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
match task_type:
|
match task_type:
|
||||||
|
|
@ -315,7 +308,8 @@ class DiscordManager:
|
||||||
try:
|
try:
|
||||||
self._scan_bot_channel()
|
self._scan_bot_channel()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.logger.error('Error scanning bot channel: %s', error)
|
self.logger.error('Error scanning bot channel: %s -> %s',
|
||||||
|
error, traceback.format_exc().replace('\n', ' | '))
|
||||||
self.tasks.append((
|
self.tasks.append((
|
||||||
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
||||||
case DiscordManager.Task.DELETE_MESSAGES:
|
case DiscordManager.Task.DELETE_MESSAGES:
|
||||||
|
|
@ -330,41 +324,39 @@ class DiscordManager:
|
||||||
try:
|
try:
|
||||||
self.delete_message(message)
|
self.delete_message(message)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.logger.error('Error deleting message %s: %s', message, error)
|
self.logger.error('Error deleting message %s: %s -> %s',
|
||||||
if self.rate_limit.remaining <= 1:
|
message, error, traceback.format_exc().replace('\n', ' | '))
|
||||||
sleep_time = self.rate_limit.next_reset - time.time()
|
|
||||||
if sleep_time > 0:
|
|
||||||
self.logger.debug('Rate limit: sleeping %.03f second')
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
||||||
headers, channel_info = self._send_request(
|
_, channel_info = self._send_request(
|
||||||
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||||
expected_code=201)
|
expected_code=201)
|
||||||
self._update_rate_limit(headers)
|
|
||||||
if not isinstance(channel_info, dict):
|
if not isinstance(channel_info, dict):
|
||||||
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
||||||
return TextChannel.from_dict(channel_info)
|
return TextChannel.from_dict(channel_info)
|
||||||
|
|
||||||
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
||||||
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
||||||
headers, message_info = self._send_request(
|
_, message_info = self._send_request(
|
||||||
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||||
upload_files=upload_files)
|
upload_files=upload_files)
|
||||||
self._update_rate_limit(headers)
|
|
||||||
if not isinstance(message_info, dict):
|
if not isinstance(message_info, dict):
|
||||||
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
||||||
return Message.from_dict(message_info)
|
return Message.from_dict(message_info)
|
||||||
|
|
||||||
def delete_message(self, message: Message):
|
def delete_message(self, message: Message):
|
||||||
headers, _ = self._send_request(
|
_, _ = self._send_request(
|
||||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
||||||
self._update_rate_limit(headers)
|
|
||||||
|
def get_current_user(self) -> User:
|
||||||
|
_, user_info = self._send_request(*Api.User.get_current())
|
||||||
|
if not isinstance(user_info, dict):
|
||||||
|
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
||||||
|
return User.from_dict(user_info)
|
||||||
|
|
||||||
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
||||||
headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
||||||
self._update_rate_limit(headers)
|
|
||||||
categories: list[ChannelCategory] = []
|
categories: list[ChannelCategory] = []
|
||||||
text_channels: list[TextChannel] = []
|
text_channels: list[TextChannel] = []
|
||||||
if channels_info is not None:
|
if channels_info is not None:
|
||||||
|
|
@ -378,8 +370,7 @@ class DiscordManager:
|
||||||
return categories, text_channels
|
return categories, text_channels
|
||||||
|
|
||||||
def list_roles(self) -> list[Role]:
|
def list_roles(self) -> list[Role]:
|
||||||
headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
||||||
self._update_rate_limit(headers)
|
|
||||||
if not isinstance(roles_info, list):
|
if not isinstance(roles_info, list):
|
||||||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||||
return [Role.from_dict(r) for r in roles_info]
|
return [Role.from_dict(r) for r in roles_info]
|
||||||
|
|
|
||||||
|
|
@ -251,6 +251,7 @@ class User: # TODO : complete attributes
|
||||||
username: str
|
username: str
|
||||||
discriminator: str
|
discriminator: str
|
||||||
global_name: str | None
|
global_name: str | None
|
||||||
|
bot: bool
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_dict(info: dict) -> User:
|
def from_dict(info: dict) -> User:
|
||||||
|
|
@ -258,7 +259,8 @@ class User: # TODO : complete attributes
|
||||||
id=int(info['id']),
|
id=int(info['id']),
|
||||||
username=info['username'],
|
username=info['username'],
|
||||||
discriminator=info['discriminator'],
|
discriminator=info['discriminator'],
|
||||||
global_name=info.get('global_name'))
|
global_name=info.get('global_name'),
|
||||||
|
bot=info.get('bot', False))
|
||||||
|
|
||||||
|
|
||||||
class AttachmentFlags(IntFlag):
|
class AttachmentFlags(IntFlag):
|
||||||
|
|
@ -421,7 +423,7 @@ class Message: # TODO : complete attributes
|
||||||
id=int(info['id']),
|
id=int(info['id']),
|
||||||
channel_id=int(info['channel_id']),
|
channel_id=int(info['channel_id']),
|
||||||
author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User(
|
author=(User.from_dict(info['author']) if info.get('webhook_id') is None else User(
|
||||||
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
|
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None, bot=False)),
|
||||||
content=info['content'],
|
content=info['content'],
|
||||||
timestamp=datetime.fromisoformat(info['timestamp']),
|
timestamp=datetime.fromisoformat(info['timestamp']),
|
||||||
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,
|
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue