Bot config and channel init
This commit is contained in:
parent
8ca93c1bab
commit
72edbe6599
7 changed files with 499 additions and 109 deletions
|
|
@ -1,26 +1,39 @@
|
|||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from dataclasses import asdict, dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import tomllib
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
|
||||
from .api import Api, ApiAction, ApiVersion
|
||||
from .config import Config
|
||||
from .logger import create_logger
|
||||
from .objects import (
|
||||
ChannelCategory, ChannelFlags, ChannelType, Message, Overwrite, OverwriteType, Permissions, TextChannel, User)
|
||||
ChannelCategory, ChannelType, FileMime, Message, Overwrite, OverwriteType, Permissions, Role, TextChannel)
|
||||
|
||||
|
||||
HTTPHeaders = dict[str, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class _RateLimit:
|
||||
remaining: int
|
||||
next_reset: float
|
||||
class ApiEncoder(json.JSONEncoder):
|
||||
def default(self, o):
|
||||
if is_dataclass(o):
|
||||
return asdict(o) # type: ignore
|
||||
if isinstance(o, Enum):
|
||||
return o.value
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class DiscordManager:
|
||||
@dataclass
|
||||
class RateLimit:
|
||||
remaining: int
|
||||
next_reset: float
|
||||
|
||||
@staticmethod
|
||||
def _get_code_version() -> str:
|
||||
pyproject_path = Path(__file__).parents[1] / 'pyproject.toml'
|
||||
|
|
@ -28,133 +41,176 @@ class DiscordManager:
|
|||
raise RuntimeError('Cannot current bot version')
|
||||
return tomllib.loads(pyproject_path.read_text(encoding='utf-8'))['project']['version']
|
||||
|
||||
def __init__(self, bot_token: str, guild_id: int) -> None:
|
||||
def __init__(self, bot_token: str, guild_id: int, config: Config | None = None,
|
||||
log_level: int = logging.INFO) -> None:
|
||||
self.config = config or Config()
|
||||
self.guild_id = guild_id
|
||||
self._bot_token = bot_token
|
||||
self.logger = create_logger('breadtube', log_level, stdout=True)
|
||||
|
||||
self.rate_limit: _RateLimit = _RateLimit(remaining=1, next_reset=0)
|
||||
self.rate_limit = self.RateLimit(remaining=1, next_reset=0)
|
||||
self.version = self._get_code_version()
|
||||
|
||||
self.guild_roles: list = self.list_roles()
|
||||
for _ in range(self.config.bot_channel_init_retries):
|
||||
while not self.init_bot_channel():
|
||||
time.sleep(10)
|
||||
break
|
||||
else:
|
||||
self.logger.info('Bot init OK')
|
||||
break
|
||||
raise RuntimeError("Couldn't initialize bot channel/role/permission")
|
||||
self.logger.info('Bot initialized')
|
||||
|
||||
def _update_rate_limit(self, headers: HTTPHeaders):
|
||||
for header_key in ['x-ratelimit-remaining', 'x-ratelimit-reset']:
|
||||
if header_key not in headers:
|
||||
print(f'Warning: no "{header_key}" found in headers')
|
||||
self.logger.info('Warning: no "%s" found in headers', header_key)
|
||||
return
|
||||
self.rate_limit.remaining = int(headers['x-ratelimit-remaining'])
|
||||
self.rate_limit.next_reset = float(headers['x-ratelimit-reset'])
|
||||
|
||||
def _send_request(self, api_action: ApiAction, endpoint: str, api_version: ApiVersion = ApiVersion.V10,
|
||||
expected_code: int = 200) -> tuple[
|
||||
HTTPHeaders, dict]:
|
||||
data: bytes | None = None, upload_files: list[tuple[str, FileMime, bytes]] | None = None,
|
||||
expected_code: int = 200) -> tuple[HTTPHeaders, dict | list | None]:
|
||||
timeout = 3
|
||||
min_api_version = 9
|
||||
|
||||
if api_action == ApiAction.POST:
|
||||
raise NotImplementedError
|
||||
if api_version.value < min_api_version:
|
||||
print(f'Warning: using deprecated API version {api_version} (minimum non deprecated is {min_api_version})')
|
||||
self.logger.warning(
|
||||
'Warning: using deprecated API version %d (minimum non deprecated is %d)',
|
||||
api_version, min_api_version)
|
||||
url = f'https://discord.com/api/v{api_version.value}{endpoint}'
|
||||
request = urllib.request.Request(url)
|
||||
|
||||
boundary: str = ''
|
||||
if upload_files:
|
||||
boundary = f'{random.randbytes(16).hex()}'
|
||||
data = (f'--{boundary}\r\nContent-Disposition: form-data; name="payload_json"\r\n'
|
||||
'Content-Type: application/json\r\n\r\n'.encode() + data
|
||||
+ f'\r\n--{boundary}'.encode()) if data else b''
|
||||
for file_index, (name, mime, content) in enumerate(upload_files):
|
||||
data += (f'\r\n--{boundary}\r\nContent-Disposition: form-data; name="files[{file_index}]";'
|
||||
f' filename="{name}"\r\nContent-Type: {mime.value}\r\n\r\n').encode() + content
|
||||
data += f'\r\n--{boundary}--'.encode()
|
||||
request = urllib.request.Request(url, data=data)
|
||||
request.method = api_action.value
|
||||
request.add_header('User-Agent', f'BreadTube (v{self.version})')
|
||||
request.add_header('Accept', 'application/json')
|
||||
if upload_files:
|
||||
request.add_header('Content-Type', f'multipart/form-data; boundary={boundary}')
|
||||
else:
|
||||
request.add_header('Content-Type', 'application/json')
|
||||
request.add_header('Authorization', f'Bot {self._bot_token}')
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=timeout) as response:
|
||||
if response.status != expected_code:
|
||||
raise RuntimeError(
|
||||
f'Unexpected code {response.status} (expected: {expected_code}) -> {response.read().decode()}')
|
||||
return dict(response.getheaders()), json.loads(response.read().decode())
|
||||
body = response.read()
|
||||
return dict(response.getheaders()), json.loads(body.decode()) if body else None
|
||||
except urllib.error.HTTPError as error:
|
||||
raise RuntimeError(
|
||||
f'HTTP error calling API ({url}): {error}:\nHeaders:\n{error.headers}Body:\n{error.read()}') from error
|
||||
except urllib.error.URLError as error:
|
||||
raise RuntimeError(f'URL error calling API ({url}): {error}') from error
|
||||
|
||||
@staticmethod
|
||||
def _parse_overwrite(info: dict) -> Overwrite:
|
||||
return Overwrite(
|
||||
id=int(info['id']),
|
||||
type=OverwriteType(info['type']),
|
||||
allow=Permissions(int(info['allow'])),
|
||||
deny=Permissions(int(info['deny']))
|
||||
)
|
||||
def init_bot_channel(self) -> bool:
|
||||
_, text_channel = self.list_channels()
|
||||
breadtube_role: Role | None = None
|
||||
everyone_role: Role | None = None
|
||||
for role in self.guild_roles:
|
||||
if role.name == self.config.bot_role:
|
||||
breadtube_role = role
|
||||
elif role.name == '@everyone':
|
||||
everyone_role = role
|
||||
if breadtube_role is None:
|
||||
self.logger.info('No BreadTube role found')
|
||||
return False
|
||||
if everyone_role is None:
|
||||
self.logger.info('No everyone role found')
|
||||
return False
|
||||
|
||||
def _parse_channel_category(self, info: dict) -> ChannelCategory:
|
||||
parent_id: str | None = info.get('parent_id')
|
||||
return ChannelCategory(
|
||||
id=int(info['id']),
|
||||
guild_id=int(info['guild_id']),
|
||||
position=int(info['position']),
|
||||
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
|
||||
name=info.get('name'),
|
||||
parent_id=int(parent_id) if parent_id is not None else None,
|
||||
flags=ChannelFlags(info['flags']),
|
||||
)
|
||||
breadtube_channel: TextChannel | None = None
|
||||
for channel in text_channel:
|
||||
if channel.name == self.config.bot_channel:
|
||||
breadtube_channel = channel
|
||||
self.logger.info('Found breadtube bot channel')
|
||||
for perm in breadtube_channel.permission_overwrites:
|
||||
if perm.id == breadtube_role.id:
|
||||
if not perm.allow | Permissions.VIEW_CHANNEL:
|
||||
self.logger.info('BreadTube bot cannot view BreadTube channel: permission missing')
|
||||
return False
|
||||
self.logger.info('BreadTube channel permission OK')
|
||||
break
|
||||
messages = self.list_text_channel_messages(breadtube_channel)
|
||||
for message in messages:
|
||||
self.logger.debug('Deleting message: %s', message)
|
||||
self.delete_message(message)
|
||||
break
|
||||
else:
|
||||
breadtube_channel = self.create_text_channel({
|
||||
'name': self.config.bot_channel,
|
||||
'permission_overwrites': [
|
||||
Overwrite(everyone_role.id, OverwriteType.ROLE, allow=Permissions.NONE,
|
||||
deny=Permissions.VIEW_CHANNEL),
|
||||
Overwrite(breadtube_role.id, OverwriteType.ROLE, allow=Permissions.VIEW_CHANNEL,
|
||||
deny=Permissions.NONE)]
|
||||
})
|
||||
self.logger.info('Created breadtube bot channel')
|
||||
|
||||
def _parse_text_channel(self, info: dict) -> TextChannel:
|
||||
parent_id: str | None = info.get('parent_id')
|
||||
last_message_id: str | None = info.get('last_message_id')
|
||||
last_pin_timestamp: str | None = info.get('last_pin_timestamp')
|
||||
return TextChannel(
|
||||
id=int(info['id']),
|
||||
guild_id=int(info['guild_id']),
|
||||
position=int(info['position']),
|
||||
permission_overwrites=[self._parse_overwrite(o) for o in info['permission_overwrites']],
|
||||
name=info.get('name'),
|
||||
topic=info.get('topic'),
|
||||
nsfw=info['nsfw'],
|
||||
last_message_id=int(last_message_id) if last_message_id is not None else None,
|
||||
rate_limit_per_user=int(info['rate_limit_per_user']),
|
||||
parent_id=int(parent_id) if parent_id is not None else None,
|
||||
last_pin_timestamp=(datetime.fromisoformat(last_pin_timestamp) if last_pin_timestamp is not None else None),
|
||||
flags=ChannelFlags(info['flags']),
|
||||
)
|
||||
self.create_message(
|
||||
breadtube_channel,
|
||||
{'content': 'This is the current configuration used, upload a new one to update the configuration'},
|
||||
upload_files=[('config.txt', FileMime.TEXT_PLAIN, self.config.to_str().encode())])
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _parse_user(info: dict) -> User:
|
||||
return User(
|
||||
id=int(info['id']),
|
||||
username=info['username'],
|
||||
discriminator=info['discriminator'],
|
||||
global_name=info.get('global_name')
|
||||
)
|
||||
def create_text_channel(self, params: Api.Guild.CreateTextChannelParams) -> TextChannel:
|
||||
headers, channel_info = self._send_request(
|
||||
*Api.Guild.create_channel(guild_id=self.guild_id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||
expected_code=201)
|
||||
self._update_rate_limit(headers)
|
||||
if not isinstance(channel_info, dict):
|
||||
raise RuntimeError(f'Error creating channel with params (no info): {params}')
|
||||
return TextChannel.from_dict(channel_info)
|
||||
|
||||
def _parse_message(self, info: dict) -> Message:
|
||||
edited_timestamp: str | None = info.get('edited_timestamp')
|
||||
return Message(
|
||||
id=int(info['id']),
|
||||
channel_id=int(info['channel_id']),
|
||||
author=(self._parse_user(info['author']) if info.get('webhook_id') is None else User(
|
||||
id=info['webhook_id'], username='webhook', discriminator='webhook', global_name=None)),
|
||||
content=info['content'],
|
||||
timestamp=datetime.fromisoformat(info['timestamp']),
|
||||
edited_timestamp=datetime.fromisoformat(edited_timestamp) if edited_timestamp is not None else None
|
||||
)
|
||||
def create_message(self, channel: TextChannel, params: Api.Message.CreateParams,
|
||||
upload_files: list[tuple[str, FileMime, bytes]] | None = None) -> Message:
|
||||
headers, message_info = self._send_request(
|
||||
*Api.Message.create(channel_id=channel.id), data=json.dumps(params, cls=ApiEncoder).encode(),
|
||||
upload_files=upload_files)
|
||||
self._update_rate_limit(headers)
|
||||
if not isinstance(message_info, dict):
|
||||
raise RuntimeError(f'Error creating message with params (no info): {params}')
|
||||
return Message.from_dict(message_info)
|
||||
|
||||
def delete_message(self, message: Message):
|
||||
try:
|
||||
headers, _ = self._send_request(
|
||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
||||
self._update_rate_limit(headers)
|
||||
print(f'Message {message.id} deleted')
|
||||
except RuntimeError as error:
|
||||
print(error)
|
||||
headers, _ = self._send_request(
|
||||
*Api.Message.delete(channel_id=message.channel_id, message_id=message.id), expected_code=204)
|
||||
self._update_rate_limit(headers)
|
||||
|
||||
def list_channels(self) -> tuple[list[ChannelCategory], list[TextChannel]]:
|
||||
headers, channels = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
||||
headers, channels_info = self._send_request(*Api.Guild.list_guilds(self.guild_id))
|
||||
self._update_rate_limit(headers)
|
||||
categories: list[ChannelCategory] = []
|
||||
text_channels: list[TextChannel] = []
|
||||
for channel in channels:
|
||||
channel_type = ChannelType(channel['type'])
|
||||
match channel_type:
|
||||
case ChannelType.GUILD_CATEGORY:
|
||||
categories.append(self._parse_channel_category(channel))
|
||||
case ChannelType.GUILD_TEXT:
|
||||
text_channels.append(self._parse_text_channel(channel))
|
||||
if channels_info is not None:
|
||||
for channel_info in channels_info:
|
||||
channel_type = ChannelType(channel_info['type'])
|
||||
match channel_type:
|
||||
case ChannelType.GUILD_CATEGORY:
|
||||
categories.append(ChannelCategory.from_dict(channel_info))
|
||||
case ChannelType.GUILD_TEXT:
|
||||
text_channels.append(TextChannel.from_dict(channel_info))
|
||||
return categories, text_channels
|
||||
|
||||
def list_text_channel_messages(self, channel: TextChannel) -> list:
|
||||
def list_roles(self) -> list[Role]:
|
||||
headers, roles_info = self._send_request(*Api.Guild.list_roles(self.guild_id))
|
||||
self._update_rate_limit(headers)
|
||||
if not isinstance(roles_info, list):
|
||||
raise RuntimeError(f'Error listing roles (not a list): {roles_info}')
|
||||
return [Role.from_dict(r) for r in roles_info]
|
||||
|
||||
def list_text_channel_messages(self, channel: TextChannel) -> list[Message]:
|
||||
headers, messages = self._send_request(*Api.Message.list_by_channel(channel.id))
|
||||
self._update_rate_limit(headers)
|
||||
return [self._parse_message(m) for m in messages]
|
||||
return [Message.from_dict(m) for m in messages or []]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue