Code refactored
This commit is contained in:
parent
b80e4f7745
commit
d5b3436aec
3 changed files with 247 additions and 213 deletions
|
|
@ -1,51 +1,30 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, is_dataclass
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import json
|
|
||||||
import random
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import tomllib
|
import tomllib
|
||||||
from typing import Any
|
from typing import Any
|
||||||
import urllib.error
|
|
||||||
import urllib.request
|
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|
||||||
from .api import Api, ApiAction, ApiVersion
|
|
||||||
from .config import Config
|
from .config import Config
|
||||||
|
from .discord_manager import DiscordManager
|
||||||
from .logger import create_logger
|
from .logger import create_logger
|
||||||
from .objects import (
|
from .objects import (ChannelCategory, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
|
||||||
Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, MessageReference, MessageReferenceType,
|
OverwriteType, Permissions, Role, TextChannel)
|
||||||
Overwrite, OverwriteType, Permissions, Role, TextChannel, User)
|
|
||||||
from .youtube_manager import YoutubeManager
|
from .youtube_manager import YoutubeManager
|
||||||
from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions
|
from .youtube_subscription import SUBSCRIPTION_FILE_COLUMNS, SubscriptionHelper, Subscriptions
|
||||||
|
|
||||||
|
|
||||||
class ApiEncoder(json.JSONEncoder):
|
class Bot:
|
||||||
def default(self, o):
|
|
||||||
if is_dataclass(o):
|
|
||||||
return asdict(o) # type: ignore
|
|
||||||
if isinstance(o, Enum):
|
|
||||||
return o.value
|
|
||||||
return super().default(o)
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordManager:
|
|
||||||
DEFAULT_MESSAGE_LIST_LIMIT = 50
|
DEFAULT_MESSAGE_LIST_LIMIT = 50
|
||||||
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
|
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n'
|
||||||
'You can upload a new one to update the configuration.')
|
'You can upload a new one to update the configuration.')
|
||||||
MAX_DOWNLOAD_SIZE: int = 50_000
|
MAX_DOWNLOAD_SIZE: int = 50_000
|
||||||
MIN_API_VERSION = 9
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class RateLimit:
|
|
||||||
remaining: int
|
|
||||||
next_reset: float
|
|
||||||
|
|
||||||
class Task(Enum):
|
class Task(Enum):
|
||||||
DELETE_MESSAGES = 1
|
DELETE_MESSAGES = 1
|
||||||
|
|
@ -60,20 +39,20 @@ class DiscordManager:
|
||||||
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
||||||
|
|
||||||
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
|
def __init__(self, bot_token: str, guild_id: int, yt_api_key: str, config: Config | None = None,
|
||||||
log_level: int = logging.INFO) -> None:
|
log_level: int = logging.INFO):
|
||||||
self.config = config or Config()
|
self.config = config or Config()
|
||||||
self.guild_id = guild_id
|
self.guild_id = guild_id
|
||||||
self._bot_token = bot_token
|
|
||||||
self.logger = create_logger('breadtube', log_level, stdout=True)
|
self.logger = create_logger('breadtube', log_level, stdout=True)
|
||||||
|
|
||||||
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
|
||||||
self.version = self._get_code_version()
|
self.version = self._get_code_version()
|
||||||
self.tasks: list[tuple[DiscordManager.Task, float, Any]] = []
|
self.discord_manager = DiscordManager(bot_token=bot_token, bot_version=self.version, logger=self.logger)
|
||||||
|
self.tasks: list[tuple[Bot.Task, float, Any]] = []
|
||||||
|
|
||||||
self.logger.info('Retrieving bot user')
|
self.logger.info('Retrieving bot user')
|
||||||
self.bot_user = self.get_current_user()
|
self.bot_user = self.discord_manager.get_current_user(request_timeout=self.config.request_timeout)
|
||||||
self.logger.info('Retrieving guild roles before init')
|
self.logger.info('Retrieving guild roles before init')
|
||||||
self.guild_roles: list[Role] = self.list_roles()
|
self.guild_roles: list[Role] = self.discord_manager.list_roles(
|
||||||
|
self.guild_id, request_timeout=self.config.request_timeout)
|
||||||
bot_role: Role | None = None
|
bot_role: Role | None = None
|
||||||
everyone_role: Role | None = None
|
everyone_role: Role | None = None
|
||||||
for role in self.guild_roles:
|
for role in self.guild_roles:
|
||||||
|
|
@ -88,7 +67,8 @@ class DiscordManager:
|
||||||
self.bot_role: Role = bot_role
|
self.bot_role: Role = bot_role
|
||||||
self.everyone_role: Role = everyone_role
|
self.everyone_role: Role = everyone_role
|
||||||
|
|
||||||
categories, text_channel = self.list_channels()
|
categories, text_channel = self.discord_manager.list_channels(
|
||||||
|
self.guild_id, request_timeout=self.config.request_timeout)
|
||||||
self.guild_text_channels: list[TextChannel] = text_channel
|
self.guild_text_channels: list[TextChannel] = text_channel
|
||||||
self.guild_categories: list[ChannelCategory] = categories
|
self.guild_categories: list[ChannelCategory] = categories
|
||||||
self.init_message: Message | None = None
|
self.init_message: Message | None = None
|
||||||
|
|
@ -111,84 +91,6 @@ class DiscordManager:
|
||||||
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
|
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
|
||||||
self.logger.info('Bot initialized')
|
self.logger.info('Bot initialized')
|
||||||
|
|
||||||
def _update_rate_limit(self, headers: HTTPHeaders):
|
|
||||||
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
|
||||||
if header_key not in headers:
|
|
||||||
self.logger.info('Warning: no "%s" found in headers', header_key)
|
|
||||||
return
|
|
||||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
|
||||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
|
||||||
self.logger.debug('Updated rate limit: %s', self.rate_limit)
|
|
||||||
if self.rate_limit.remaining <= 0:
|
|
||||||
sleep_time = self.rate_limit.next_reset - time.time()
|
|
||||||
if sleep_time > 0:
|
|
||||||
self.logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
|
|
||||||
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
|
||||||
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
|
||||||
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
|
|
||||||
if api_version.value < self.MIN_API_VERSION:
|
|
||||||
self.logger.warning(
|
|
||||||
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
|
||||||
api_version, self.MIN_API_VERSION)
|
|
||||||
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
|
||||||
self.logger.debug('Discord API Request: %s %s', api_action.value, url)
|
|
||||||
|
|
||||||
boundary: str = ''
|
|
||||||
if upload_files:
|
|
||||||
boundary = f'{random.randbytes(16).hex()}'
|
|
||||||
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
|
||||||
'Content-Type: application/json\r\n\r\n'.encode() + data
|
|
||||||
+ f'\r\n--{boundary}'.encode()) if data else b''
|
|
||||||
for file_index, (name, mime, content) in enumerate(upload_files):
|
|
||||||
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
|
||||||
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
|
||||||
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
|
||||||
data += f'\r\n--{boundary}--'.encode()
|
|
||||||
request = urllib.request.Request(url, data=data)
|
|
||||||
request.method = api_action.value
|
|
||||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
|
||||||
request.add_header('Accept', 'application/json')
|
|
||||||
if upload_files:
|
|
||||||
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
|
||||||
else:
|
|
||||||
request.add_header('Content-Type', 'application/json')
|
|
||||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
|
||||||
if response.status != expected_code:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
|
||||||
body = response.read()
|
|
||||||
headers = dict(response.getheaders())
|
|
||||||
self._update_rate_limit(headers)
|
|
||||||
return headers, json.loads(body.decode()) if body else None
|
|
||||||
except urllib.error.HTTPError as error:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
||||||
except urllib.error.URLError as error:
|
|
||||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
|
||||||
except TimeoutError as error:
|
|
||||||
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
|
||||||
|
|
||||||
def _download_attachment(self, attachment: Attachment, expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
|
||||||
request = urllib.request.Request(attachment.url)
|
|
||||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
|
||||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
|
||||||
try:
|
|
||||||
with urllib.request.urlopen(request, timeout=self.config.request_timeout) as response:
|
|
||||||
if response.status != expected_code:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
|
||||||
return dict(response.getheaders()), response.read()
|
|
||||||
except urllib.error.HTTPError as error:
|
|
||||||
raise RuntimeError(
|
|
||||||
f'HTTP error downloading attachment ({attachment}):'
|
|
||||||
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
|
||||||
except urllib.error.URLError as error:
|
|
||||||
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
|
||||||
|
|
||||||
def init_bot_channel(self) -> TextChannel | None:
|
def init_bot_channel(self) -> TextChannel | None:
|
||||||
for channel in self.guild_text_channels:
|
for channel in self.guild_text_channels:
|
||||||
if channel.name == self.config.bot_channel:
|
if channel.name == self.config.bot_channel:
|
||||||
|
|
@ -203,14 +105,15 @@ class DiscordManager:
|
||||||
return channel
|
return channel
|
||||||
|
|
||||||
self.logger.info('Creating breadtube bot channel')
|
self.logger.info('Creating breadtube bot channel')
|
||||||
return self.create_text_channel({
|
return self.discord_manager.create_text_channel(
|
||||||
|
self.guild_id, {
|
||||||
'name': self.config.bot_channel,
|
'name': self.config.bot_channel,
|
||||||
'permission_overwrites': [
|
'permission_overwrites': [
|
||||||
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
Overwrite(self.everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||||
deny=Permissions.VIEW_CHANNEL),
|
deny=Permissions.VIEW_CHANNEL),
|
||||||
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
Overwrite(self.bot_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
||||||
deny=Permissions.NONE)]
|
deny=Permissions.NONE)]},
|
||||||
})
|
request_timeout=self.config.request_timeout)
|
||||||
|
|
||||||
def _get_bot_channel_messages(self) -> list[Message]:
|
def _get_bot_channel_messages(self) -> list[Message]:
|
||||||
messages_id_delete_task: set[int] = set()
|
messages_id_delete_task: set[int] = set()
|
||||||
|
|
@ -221,7 +124,8 @@ class DiscordManager:
|
||||||
last_message_id: int | None = None
|
last_message_id: int | None = None
|
||||||
messages: list[Message] = []
|
messages: list[Message] = []
|
||||||
while True:
|
while True:
|
||||||
message_batch = self.list_text_channel_messages(self.bot_channel, after_id=last_message_id)
|
message_batch = self.discord_manager.list_text_channel_messages(
|
||||||
|
self.bot_channel, request_timeout=self.config.request_timeout, after_id=last_message_id)
|
||||||
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
|
messages.extend([m for m in message_batch if m.id not in messages_id_delete_task])
|
||||||
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
if len(message_batch) < self.DEFAULT_MESSAGE_LIST_LIMIT:
|
||||||
break
|
break
|
||||||
|
|
@ -255,7 +159,8 @@ class DiscordManager:
|
||||||
has_error = False
|
has_error = False
|
||||||
for attachment in message.attachments:
|
for attachment in message.attachments:
|
||||||
try:
|
try:
|
||||||
_, content = self._download_attachment(attachment)
|
_, content = self.discord_manager.download_attachment(
|
||||||
|
attachment, request_timeout=self.config.request_timeout)
|
||||||
if new_config is None and content.startswith(b'config'):
|
if new_config is None and content.startswith(b'config'):
|
||||||
try:
|
try:
|
||||||
self.config = Config.from_str(content.decode())
|
self.config = Config.from_str(content.decode())
|
||||||
|
|
@ -269,7 +174,7 @@ class DiscordManager:
|
||||||
SubscriptionHelper.update_subscriptions(
|
SubscriptionHelper.update_subscriptions(
|
||||||
new=subscriptions, previous=self._yt_subscriptions)
|
new=subscriptions, previous=self._yt_subscriptions)
|
||||||
self._yt_subscriptions = subscriptions
|
self._yt_subscriptions = subscriptions
|
||||||
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
|
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
self.logger.error('Invalid init subscriptions file: %s', error)
|
self.logger.error('Invalid init subscriptions file: %s', error)
|
||||||
has_error = True
|
has_error = True
|
||||||
|
|
@ -288,7 +193,8 @@ class DiscordManager:
|
||||||
immediate_delete[message.id] = message
|
immediate_delete[message.id] = message
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
_, content = self._download_attachment(attachment)
|
_, content = self.discord_manager.download_attachment(
|
||||||
|
attachment, request_timeout=self.config.request_timeout)
|
||||||
if content.startswith(b'config') and new_config is None:
|
if content.startswith(b'config') and new_config is None:
|
||||||
try:
|
try:
|
||||||
config = Config.from_str(content.decode())
|
config = Config.from_str(content.decode())
|
||||||
|
|
@ -299,14 +205,14 @@ class DiscordManager:
|
||||||
continue
|
continue
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
self.logger.info('Invalid config file: %s', error)
|
self.logger.info('Invalid config file: %s', error)
|
||||||
bot_message = self.create_message(self.bot_channel, {
|
bot_message = self.discord_manager.create_message(self.bot_channel, {
|
||||||
'content': str(error),
|
'content': str(error),
|
||||||
'message_reference': MessageReference(
|
'message_reference': MessageReference(
|
||||||
type=MessageReferenceType.DEFAULT,
|
type=MessageReferenceType.DEFAULT,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
channel_id=self.bot_channel.id,
|
channel_id=self.bot_channel.id,
|
||||||
guild_id=None,
|
guild_id=None,
|
||||||
fail_if_not_exists=None)})
|
fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
|
||||||
delayed_delete[bot_message.id] = bot_message
|
delayed_delete[bot_message.id] = bot_message
|
||||||
delayed_delete[message.id] = message
|
delayed_delete[message.id] = message
|
||||||
continue
|
continue
|
||||||
|
|
@ -320,14 +226,14 @@ class DiscordManager:
|
||||||
continue
|
continue
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
self.logger.info('Invalid subscriptions file: %s', error)
|
self.logger.info('Invalid subscriptions file: %s', error)
|
||||||
bot_message = self.create_message(self.bot_channel, {
|
bot_message = self.discord_manager.create_message(self.bot_channel, {
|
||||||
'content': str(error),
|
'content': str(error),
|
||||||
'message_reference': MessageReference(
|
'message_reference': MessageReference(
|
||||||
type=MessageReferenceType.DEFAULT,
|
type=MessageReferenceType.DEFAULT,
|
||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
channel_id=self.bot_channel.id,
|
channel_id=self.bot_channel.id,
|
||||||
guild_id=None,
|
guild_id=None,
|
||||||
fail_if_not_exists=None)})
|
fail_if_not_exists=None)}, request_timeout=self.config.request_timeout)
|
||||||
delayed_delete[bot_message.id] = bot_message
|
delayed_delete[bot_message.id] = bot_message
|
||||||
delayed_delete[message.id] = message
|
delayed_delete[message.id] = message
|
||||||
continue
|
continue
|
||||||
|
|
@ -342,7 +248,7 @@ class DiscordManager:
|
||||||
self.logger.info('Loading subscriptions: %s', new_subscriptions)
|
self.logger.info('Loading subscriptions: %s', new_subscriptions)
|
||||||
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
|
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
|
||||||
self._yt_subscriptions = new_subscriptions
|
self._yt_subscriptions = new_subscriptions
|
||||||
self.tasks.append((DiscordManager.Task.INIT_SUBS, time.time() + 1, None))
|
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None))
|
||||||
|
|
||||||
# New init message is needed, previous need to be deleted
|
# New init message is needed, previous need to be deleted
|
||||||
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
|
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
|
||||||
|
|
@ -352,8 +258,8 @@ class DiscordManager:
|
||||||
# Init message absent or deleted
|
# Init message absent or deleted
|
||||||
if self.init_message is None:
|
if self.init_message is None:
|
||||||
assert self.config is not None
|
assert self.config is not None
|
||||||
self.init_message = self.create_message(
|
self.init_message = self.discord_manager.create_message(
|
||||||
self.bot_channel, {'content': self.INIT_MESSAGE},
|
self.bot_channel, {'content': self.INIT_MESSAGE}, request_timeout=self.config.request_timeout,
|
||||||
upload_files=[
|
upload_files=[
|
||||||
('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()),
|
('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode()),
|
||||||
('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions))
|
('subscriptions.csv', FileMime.TEXT_CSV, SubscriptionHelper.generate_text(self._yt_subscriptions))
|
||||||
|
|
@ -361,18 +267,19 @@ class DiscordManager:
|
||||||
|
|
||||||
for message in immediate_delete.values():
|
for message in immediate_delete.values():
|
||||||
try:
|
try:
|
||||||
self.delete_message(message)
|
self.discord_manager.delete_message(message, request_timeout=self.config.request_timeout)
|
||||||
except RuntimeError as error:
|
except RuntimeError as error:
|
||||||
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
|
self.logger.error('Error deleting after bot channel scan (immediate): %s', error)
|
||||||
|
|
||||||
if delayed_delete:
|
if delayed_delete:
|
||||||
self.tasks.append((
|
self.tasks.append((
|
||||||
DiscordManager.Task.DELETE_MESSAGES,
|
Bot.Task.DELETE_MESSAGES,
|
||||||
time.time() + self.config.bot_message_duration,
|
time.time() + self.config.bot_message_duration,
|
||||||
list(delayed_delete.values())))
|
list(delayed_delete.values())))
|
||||||
|
|
||||||
def _init_subs(self):
|
def _init_subs(self):
|
||||||
categories, text_channel = self.list_channels()
|
categories, text_channel = self.discord_manager.list_channels(
|
||||||
|
self.guild_id, request_timeout=self.config.request_timeout)
|
||||||
self.guild_text_channels = text_channel
|
self.guild_text_channels = text_channel
|
||||||
self.guild_categories = categories
|
self.guild_categories = categories
|
||||||
|
|
||||||
|
|
@ -404,7 +311,8 @@ class DiscordManager:
|
||||||
break
|
break
|
||||||
if selected_category is None:
|
if selected_category is None:
|
||||||
selected_category = category_ranges[-1][2]
|
selected_category = category_ranges[-1][2]
|
||||||
sub_channel = self.create_text_channel({
|
sub_channel = self.discord_manager.create_text_channel(
|
||||||
|
self.guild_id, {
|
||||||
'name': discord_name,
|
'name': discord_name,
|
||||||
'parent_id': selected_category.id,
|
'parent_id': selected_category.id,
|
||||||
'permission_overwrites': [
|
'permission_overwrites': [
|
||||||
|
|
@ -412,8 +320,8 @@ class DiscordManager:
|
||||||
deny=Permissions.SEND_MESSAGES),
|
deny=Permissions.SEND_MESSAGES),
|
||||||
Overwrite(self.bot_role.id, OverwriteType.ROLE,
|
Overwrite(self.bot_role.id, OverwriteType.ROLE,
|
||||||
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
|
allow=Permissions.VIEW_CHANNEL | Permissions.SEND_MESSAGES,
|
||||||
deny=Permissions.NONE)]
|
deny=Permissions.NONE)]},
|
||||||
})
|
request_timeout=self.config.request_timeout)
|
||||||
if sub_info.channel_info is None:
|
if sub_info.channel_info is None:
|
||||||
_, channel_info = self.yt_manager.request_channel_info(
|
_, channel_info = self.yt_manager.request_channel_info(
|
||||||
sub_info.channel_id, request_timeout=self.config.request_timeout)
|
sub_info.channel_id, request_timeout=self.config.request_timeout)
|
||||||
|
|
@ -422,7 +330,8 @@ class DiscordManager:
|
||||||
continue
|
continue
|
||||||
sub_info.channel_info = channel_info.items[0].snippet
|
sub_info.channel_info = channel_info.items[0].snippet
|
||||||
channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}'
|
channel_url = f'https://www.youtube.com/{sub_info.channel_info.customUrl}'
|
||||||
_ = self.create_message(sub_channel, {'content': channel_url})
|
_ = self.discord_manager.create_message(
|
||||||
|
sub_channel, {'content': channel_url}, request_timeout=self.config.request_timeout)
|
||||||
sub_info.last_update = time.time()
|
sub_info.last_update = time.time()
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
|
|
@ -436,7 +345,7 @@ class DiscordManager:
|
||||||
if sleep_time > 0:
|
if sleep_time > 0:
|
||||||
time.sleep(sleep_time)
|
time.sleep(sleep_time)
|
||||||
match task_type:
|
match task_type:
|
||||||
case DiscordManager.Task.DELETE_MESSAGES:
|
case Bot.Task.DELETE_MESSAGES:
|
||||||
if not isinstance(task_params, list):
|
if not isinstance(task_params, list):
|
||||||
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
|
self.logger.error('Wrong task params for DELETE_MESSAGES: %s', task_params)
|
||||||
elif not task_params:
|
elif not task_params:
|
||||||
|
|
@ -446,11 +355,12 @@ class DiscordManager:
|
||||||
else:
|
else:
|
||||||
for message in task_params:
|
for message in task_params:
|
||||||
try:
|
try:
|
||||||
self.delete_message(message)
|
self.discord_manager.delete_message(
|
||||||
|
message, request_timeout=self.config.request_timeout)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.logger.error('Error deleting message %s: %s -> %s',
|
self.logger.error('Error deleting message %s: %s -> %s',
|
||||||
message, error, traceback.format_exc().replace('\n', ' | '))
|
message, error, traceback.format_exc().replace('\n', ' | '))
|
||||||
case DiscordManager.Task.SCAN_BOT_CHANNEL:
|
case Bot.Task.SCAN_BOT_CHANNEL:
|
||||||
try:
|
try:
|
||||||
self._scan_bot_channel()
|
self._scan_bot_channel()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
|
|
@ -458,73 +368,10 @@ class DiscordManager:
|
||||||
error, traceback.format_exc().replace('\n', ' | '))
|
error, traceback.format_exc().replace('\n', ' | '))
|
||||||
self.tasks.append((
|
self.tasks.append((
|
||||||
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
|
||||||
case DiscordManager.Task.INIT_SUBS:
|
case Bot.Task.INIT_SUBS:
|
||||||
try:
|
try:
|
||||||
self._init_subs()
|
self._init_subs()
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
self.logger.error('Error initializing subscriptions : %s -> %s',
|
self.logger.error('Error initializing subscriptions : %s -> %s',
|
||||||
error, traceback.format_exc().replace('\n', ' | '))
|
error, traceback.format_exc().replace('\n', ' | '))
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
|
||||||
_, channel_info = self._send_request(
|
|
||||||
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
|
||||||
expected_code=201)
|
|
||||||
if not isinstance(channel_info, dict):
|
|
||||||
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
|
||||||
return TextChannel.from_dict(channel_info)
|
|
||||||
|
|
||||||
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
|
||||||
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
|
||||||
_, message_info = self._send_request(
|
|
||||||
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
|
||||||
upload_files=upload_files)
|
|
||||||
if not isinstance(message_info, dict):
|
|
||||||
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
|
||||||
return Message.from_dict(message_info)
|
|
||||||
|
|
||||||
def delete_message(self, message: Message):
|
|
||||||
_, _ = self._send_request(
|
|
||||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
|
||||||
|
|
||||||
def get_current_user(self) -> User:
|
|
||||||
_, user_info = self._send_request(*Api.User.get_current())
|
|
||||||
if not isinstance(user_info, dict):
|
|
||||||
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
|
||||||
return User.from_dict(user_info)
|
|
||||||
|
|
||||||
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
|
||||||
_, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
|
||||||
categories: list[ChannelCategory] = []
|
|
||||||
text_channels: list[TextChannel] = []
|
|
||||||
if channels_info is not None:
|
|
||||||
for channel_info in channels_info:
|
|
||||||
channel_type = ChannelType(channel_info['type'])
|
|
||||||
match channel_type:
|
|
||||||
case ChannelType.GUILD_CATEGORY:
|
|
||||||
categories.append(ChannelCategory.from_dict(channel_info))
|
|
||||||
case ChannelType.GUILD_TEXT:
|
|
||||||
text_channels.append(TextChannel.from_dict(channel_info))
|
|
||||||
return categories, text_channels
|
|
||||||
|
|
||||||
def list_roles(self) -> list[Role]:
|
|
||||||
_, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
|
||||||
if not isinstance(roles_info, list):
|
|
||||||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
|
||||||
return [Role.from_dict(r) for r in roles_info]
|
|
||||||
|
|
||||||
def list_text_channel_messages(self, channel: TextChannel, limit: int | None = None, before_id: int | None = None,
|
|
||||||
after_id: int | None = None) -> list[Message]:
|
|
||||||
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
|
||||||
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
|
||||||
params: Api.Message.ListMessageParams = {}
|
|
||||||
if limit is not None:
|
|
||||||
params['limit'] = limit
|
|
||||||
if before_id is not None:
|
|
||||||
params['before'] = before_id
|
|
||||||
if after_id is not None:
|
|
||||||
params['after'] = after_id
|
|
||||||
_, messages_info = self._send_request(
|
|
||||||
*Api.Message.list_by_channel(channel.id),
|
|
||||||
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
|
||||||
return [Message.from_dict(m) for m in messages_info or []]
|
|
||||||
187
breadtube_bot/discord_manager.py
Normal file
187
breadtube_bot/discord_manager.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import asdict, dataclass, is_dataclass
|
||||||
|
from enum import Enum
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
|
||||||
|
from .api import Api, ApiAction, ApiVersion
|
||||||
|
from .objects import Attachment, ChannelCategory, ChannelType, FileMime, HTTPHeaders, Message, Role, TextChannel, User
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
|
class ApiEncoder(json.JSONEncoder):
|
||||||
|
def default(self, o):
|
||||||
|
if is_dataclass(o):
|
||||||
|
return asdict(o) # type: ignore
|
||||||
|
if isinstance(o, Enum):
|
||||||
|
return o.value
|
||||||
|
return super().default(o)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordManager:
|
||||||
|
MIN_API_VERSION = 9
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RateLimit:
|
||||||
|
remaining: int
|
||||||
|
next_reset: float
|
||||||
|
|
||||||
|
def __init__(self, bot_token: str, bot_version: str, logger: logging.Logger):
|
||||||
|
self._bot_token = bot_token
|
||||||
|
self._logger = logger
|
||||||
|
self._version = bot_version
|
||||||
|
|
||||||
|
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
||||||
|
|
||||||
|
def _update_rate_limit(self, headers: HTTPHeaders):
|
||||||
|
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
||||||
|
if header_key not in headers:
|
||||||
|
self._logger.info('Warning: no "%s" found in headers', header_key)
|
||||||
|
return
|
||||||
|
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||||
|
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||||
|
self._logger.debug('Updated rate limit: %s', self.rate_limit)
|
||||||
|
if self.rate_limit.remaining <= 0:
|
||||||
|
sleep_time = self.rate_limit.next_reset - time.time()
|
||||||
|
if sleep_time > 0:
|
||||||
|
self._logger.debug('Rate limit: sleeping %.03f second', sleep_time)
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
def _send_request(
|
||||||
|
self, api_action: ApiAction, endpoint: str, request_timeout: float, data: bytes | None = None,
|
||||||
|
expected_code: int = 200, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||||
|
api_version: ApiVersion = ApiVersion.V10) -> tuple[HTTPHeaders, dict | list | None]:
|
||||||
|
if api_version.value < self.MIN_API_VERSION:
|
||||||
|
self._logger.warning(
|
||||||
|
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
||||||
|
api_version, self.MIN_API_VERSION)
|
||||||
|
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
||||||
|
self._logger.debug('Discord API Request: %s %s', api_action.value, url)
|
||||||
|
|
||||||
|
boundary: str = ''
|
||||||
|
if upload_files:
|
||||||
|
boundary = f'{random.randbytes(16).hex()}'
|
||||||
|
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
||||||
|
'Content-Type: application/json\r\n\r\n'.encode() + data
|
||||||
|
+ f'\r\n--{boundary}'.encode()) if data else b''
|
||||||
|
for file_index, (name, mime, content) in enumerate(upload_files):
|
||||||
|
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
||||||
|
f' filename="{name}"\r\nContent-Length: {len(content)}'
|
||||||
|
f'\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||||
|
data += f'\r\n--{boundary}--'.encode()
|
||||||
|
request = urllib.request.Request(url, data=data)
|
||||||
|
request.method = api_action.value
|
||||||
|
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
||||||
|
request.add_header('Accept', 'application/json')
|
||||||
|
if upload_files:
|
||||||
|
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
||||||
|
else:
|
||||||
|
request.add_header('Content-Type', 'application/json')
|
||||||
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
||||||
|
if response.status != expected_code:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||||
|
body = response.read()
|
||||||
|
headers = dict(response.getheaders())
|
||||||
|
self._update_rate_limit(headers)
|
||||||
|
return headers, json.loads(body.decode()) if body else None
|
||||||
|
except urllib.error.HTTPError as error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||||
|
except urllib.error.URLError as error:
|
||||||
|
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||||
|
except TimeoutError as error:
|
||||||
|
raise RuntimeError(f'Timeout calling API ({url}): {error}') from error
|
||||||
|
|
||||||
|
def download_attachment(self, attachment: Attachment, request_timeout: float,
|
||||||
|
expected_code: int = 200) -> tuple[HTTPHeaders, bytes]:
|
||||||
|
request = urllib.request.Request(attachment.url)
|
||||||
|
request.add_header('User-Agent', f'BreadTube (v{self._version})')
|
||||||
|
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(request, timeout=request_timeout) as response:
|
||||||
|
if response.status != expected_code:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'Unexpected code while downloading attachment {response.status} (expected: {expected_code})')
|
||||||
|
return dict(response.getheaders()), response.read()
|
||||||
|
except urllib.error.HTTPError as error:
|
||||||
|
raise RuntimeError(
|
||||||
|
f'HTTP error downloading attachment ({attachment}):'
|
||||||
|
f' {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||||
|
except urllib.error.URLError as error:
|
||||||
|
raise RuntimeError(f'URL error downloading attachment ({attachment}): {error}') from error
|
||||||
|
|
||||||
|
def create_text_channel(
|
||||||
|
self, guild_id: int, params: Api.Guild.CreateTextChannelParams, request_timeout: float) -> TextChannel:
|
||||||
|
_, channel_info = self._send_request(
|
||||||
|
*Api.Guild.create_channel(guild_id=guild_id), request_timeout=request_timeout,
|
||||||
|
data=json.dumps(params, cls=ApiEncoder).encode(), expected_code=201)
|
||||||
|
if not isinstance(channel_info, dict):
|
||||||
|
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
||||||
|
return TextChannel.from_dict(channel_info)
|
||||||
|
|
||||||
|
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams, request_timeout: float,
|
||||||
|
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
||||||
|
_, message_info = self._send_request(
|
||||||
|
*Api.Message.create(channel_id=channel.id), request_timeout=request_timeout,
|
||||||
|
data=json.dumps(params, cls=ApiEncoder).encode(), upload_files=upload_files)
|
||||||
|
if not isinstance(message_info, dict):
|
||||||
|
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
||||||
|
return Message.from_dict(message_info)
|
||||||
|
|
||||||
|
def delete_message(self, message: Message, request_timeout: float):
|
||||||
|
_, _ = self._send_request(
|
||||||
|
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), request_timeout=request_timeout,
|
||||||
|
expected_code=204)
|
||||||
|
|
||||||
|
def get_current_user(self, request_timeout: float) -> User:
|
||||||
|
_, user_info = self._send_request(*Api.User.get_current(), request_timeout=request_timeout)
|
||||||
|
if not isinstance(user_info, dict):
|
||||||
|
raise RuntimeError(f'Error getting current user (not a dict): {user_info}')
|
||||||
|
return User.from_dict(user_info)
|
||||||
|
|
||||||
|
def list_channels(self, guild_id: int, request_timeout: float) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
||||||
|
_, channels_info = self._send_request(*Api.Guild.list_guilds(guild_id), request_timeout=request_timeout)
|
||||||
|
categories: list[ChannelCategory] = []
|
||||||
|
text_channels: list[TextChannel] = []
|
||||||
|
if channels_info is not None:
|
||||||
|
for channel_info in channels_info:
|
||||||
|
channel_type = ChannelType(channel_info['type'])
|
||||||
|
match channel_type:
|
||||||
|
case ChannelType.GUILD_CATEGORY:
|
||||||
|
categories.append(ChannelCategory.from_dict(channel_info))
|
||||||
|
case ChannelType.GUILD_TEXT:
|
||||||
|
text_channels.append(TextChannel.from_dict(channel_info))
|
||||||
|
return categories, text_channels
|
||||||
|
|
||||||
|
def list_roles(self, guild_id: int, request_timeout: float) -> list[Role]:
|
||||||
|
_, roles_info = self._send_request(*Api.Guild.list_roles(guild_id), request_timeout=request_timeout)
|
||||||
|
if not isinstance(roles_info, list):
|
||||||
|
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||||
|
return [Role.from_dict(r) for r in roles_info]
|
||||||
|
|
||||||
|
def list_text_channel_messages(
|
||||||
|
self, channel: TextChannel, request_timeout: float, limit: int | None = None, before_id: int | None = None,
|
||||||
|
after_id: int | None = None) -> list[Message]:
|
||||||
|
if limit is not None and not (0 < limit <= 100): # noqa: PLR2004
|
||||||
|
raise RuntimeError('Cannot list messages by channel with limit outside [0, 100] range')
|
||||||
|
params: Api.Message.ListMessageParams = {}
|
||||||
|
if limit is not None:
|
||||||
|
params['limit'] = limit
|
||||||
|
if before_id is not None:
|
||||||
|
params['before'] = before_id
|
||||||
|
if after_id is not None:
|
||||||
|
params['after'] = after_id
|
||||||
|
_, messages_info = self._send_request(
|
||||||
|
*Api.Message.list_by_channel(channel.id), request_timeout=request_timeout,
|
||||||
|
data=json.dumps(params, cls=ApiEncoder).encode() if params else None)
|
||||||
|
return [Message.from_dict(m) for m in messages_info or []]
|
||||||
|
|
@ -2,7 +2,7 @@ from argparse import ArgumentParser
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from breadtube_bot.manager import DiscordManager
|
from breadtube_bot.bot import Bot
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
@ -17,7 +17,7 @@ def main():
|
||||||
|
|
||||||
bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip()
|
bot_token = Path('data/discord_bot_token.txt').read_text(encoding='utf-8').strip()
|
||||||
yt_api_key = Path('data/google_api_key.txt').read_text(encoding='utf-8').strip()
|
yt_api_key = Path('data/google_api_key.txt').read_text(encoding='utf-8').strip()
|
||||||
manager = DiscordManager(bot_token=bot_token, guild_id=guild_id, yt_api_key=yt_api_key,
|
manager = Bot(bot_token=bot_token, guild_id=guild_id, yt_api_key=yt_api_key,
|
||||||
log_level=logging.DEBUG if debug_mode else logging.INFO)
|
log_level=logging.DEBUG if debug_mode else logging.INFO)
|
||||||
try:
|
try:
|
||||||
manager.run()
|
manager.run()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue