Refresh older sub first + save subscriptions

This commit is contained in:
BreadTube 2025-09-30 23:01:30 +09:00 committed by Corentin
commit 693564bb04
4 changed files with 128 additions and 85 deletions

View file

@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
from enum import Enum from enum import Enum
import http.client
import json
import logging import logging
import operator import operator
from pathlib import Path from pathlib import Path
import http.client
import re import re
import time import time
import tomllib import tomllib
@ -12,7 +13,7 @@ from typing import Any, TYPE_CHECKING
import traceback import traceback
from .config import Config from .config import Config
from .discord_manager import DiscordManager from .discord_manager import ApiEncoder, DiscordManager
from .logger import create_logger from .logger import create_logger
from .objects import (ChannelCategory, FileMime, Message, MessageReference, MessageReferenceType, Overwrite, from .objects import (ChannelCategory, FileMime, Message, MessageReference, MessageReferenceType, Overwrite,
OverwriteType, Permissions, Role, TextChannel) OverwriteType, Permissions, Role, TextChannel)
@ -24,19 +25,20 @@ if TYPE_CHECKING:
class Bot: class Bot:
DEFAULT_MESSAGE_LIST_LIMIT = 50 DEFAULT_MESSAGE_LIST_LIMIT: int = 50
DISCORD_NAME_REGEX = r'([^a-z])' DISCORD_NAME_REGEX: str = r'([^a-z])'
INIT_MESSAGE = ('Bot initialized.\nThis is the current configuration used.\n' INIT_MESSAGE: str = ('Bot initialized.\nThis is the current configuration used.\n'
'You can upload a new one to update the configuration.') 'You can upload a new one to update the configuration.')
MAX_DOWNLOAD_SIZE: int = 50_000 MAX_DOWNLOAD_SIZE: int = 50_000
SUBS_LIST_MIN_SIZE = 50 SUBS_LIST_MIN_SIZE: int = 50
SUBS_LIST_SHORTS_RATIO = 5 SUBS_LIST_SHORTS_RATIO: int = 5
SUBS_LIST_VIDEO_RATIO = 2 SUBS_LIST_VIDEO_RATIO: int = 2
SUBS_SAVE_PATH: Path = Path('/tmp/breadtube-bot_subs.json')
class Task(Enum): class Task(Enum):
DELETE_MESSAGES = 1 DELETE_MESSAGES = 1
SCAN_BOT_CHANNEL = 2 SCAN_BOT_CHANNEL = 2
INIT_SUBS = 3 REFRESH_SUBS = 3
@staticmethod @staticmethod
def _get_code_version() -> str: def _get_code_version() -> str:
@ -90,12 +92,15 @@ class Bot:
raise RuntimeError("Couldn't initialize bot channel/role/permission") raise RuntimeError("Couldn't initialize bot channel/role/permission")
self.bot_channel: TextChannel = bot_channel self.bot_channel: TextChannel = bot_channel
self._yt_subscriptions: Subscriptions = {} self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger)
self._yt_subscriptions: Subscriptions = {
name: SubscriptionInfo.from_dict(info) for name, info in json.loads(
self.SUBS_SAVE_PATH.read_text(encoding='utf-8')).items()} if self.SUBS_SAVE_PATH.exists() else {}
self._scan_bot_channel() self._scan_bot_channel()
self.tasks.append(( self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
self.tasks = list(filter(lambda t: t[0] != Bot.Task.REFRESH_SUBS, self.tasks))
self.yt_manager = YoutubeManager(api_key=yt_api_key, logger=self.logger) self.tasks.append((Bot.Task.REFRESH_SUBS, time.time() + 1, None))
self.logger.info('Bot initialized') self.logger.info('Bot initialized')
def init_bot_channel(self) -> TextChannel | None: def init_bot_channel(self) -> TextChannel | None:
@ -140,6 +145,7 @@ class Bot:
return messages return messages
def _scan_bot_channel(self): # noqa: PLR0915 def _scan_bot_channel(self): # noqa: PLR0915
self.logger.info('Starting scanning bot channel')
messages = self._get_all_channel_messages(self.bot_channel) messages = self._get_all_channel_messages(self.bot_channel)
init_message_found = False init_message_found = False
new_config: Config | None = None new_config: Config | None = None
@ -181,7 +187,8 @@ class Bot:
SubscriptionHelper.update_subscriptions( SubscriptionHelper.update_subscriptions(
new=subscriptions, previous=self._yt_subscriptions) new=subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = subscriptions self._yt_subscriptions = subscriptions
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None)) self.tasks = list(filter(lambda t: t[0] != Bot.Task.REFRESH_SUBS, self.tasks))
self.tasks.append((Bot.Task.REFRESH_SUBS, time.time() + 1, None))
except RuntimeError as error: except RuntimeError as error:
self.logger.error('Invalid init subscriptions file: %s', error) self.logger.error('Invalid init subscriptions file: %s', error)
has_error = True has_error = True
@ -255,7 +262,8 @@ class Bot:
self.logger.info('Loading subscriptions') self.logger.info('Loading subscriptions')
SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions) SubscriptionHelper.update_subscriptions(new=new_subscriptions, previous=self._yt_subscriptions)
self._yt_subscriptions = new_subscriptions self._yt_subscriptions = new_subscriptions
self.tasks.append((Bot.Task.INIT_SUBS, time.time() + 1, None)) self.tasks = list(filter(lambda t: t[0] != Bot.Task.REFRESH_SUBS, self.tasks))
self.tasks.append((Bot.Task.REFRESH_SUBS, time.time() + 1, None))
# New init message is needed, previous need to be deleted # New init message is needed, previous need to be deleted
if (new_config is not None or new_subscriptions is not None) and self.init_message is not None: if (new_config is not None or new_subscriptions is not None) and self.init_message is not None:
@ -283,6 +291,7 @@ class Bot:
Bot.Task.DELETE_MESSAGES, Bot.Task.DELETE_MESSAGES,
time.time() + self.config.bot_message_duration, time.time() + self.config.bot_message_duration,
list(delayed_delete.values()))) list(delayed_delete.values())))
self.logger.info('Bot channel scanned')
def _get_subscription_channel(self, subscription: SubscriptionInfo, channel_dict: dict[str, TextChannel], def _get_subscription_channel(self, subscription: SubscriptionInfo, channel_dict: dict[str, TextChannel],
category_ranges: list[tuple[int, int, ChannelCategory]]) -> TextChannel: category_ranges: list[tuple[int, int, ChannelCategory]]) -> TextChannel:
@ -340,43 +349,23 @@ class Bot:
def _video_message_content(video: SearchResultItem) -> str: def _video_message_content(video: SearchResultItem) -> str:
return f'https://www.youtube.com/video/{video.id.videoId}' return f'https://www.youtube.com/video/{video.id.videoId}'
def _init_subs(self): def _refresh_sub(self, subscription: SubscriptionInfo, channel_dict: dict[str, TextChannel],
self.logger.info('Initialize all subs') category_ranges: list[tuple[int, int, ChannelCategory]]):
categories, text_channel = self.discord_manager.list_channels(
self.guild_id, request_timeout=self.config.request_timeout)
self.guild_text_channels = text_channel
self.guild_categories = categories
channel_dict: dict[str, TextChannel] = {c.name or '': c for c in self.guild_text_channels}
unmanaged_categories: set[str] = set(self.config.unmanaged_categories.split(','))
category_ranges: list[tuple[int, int, ChannelCategory]] = []
for category in self.guild_categories:
if category.name in unmanaged_categories:
self.logger.debug('Skipping unmanaged category: %s', category.name)
continue
range_info = (category.name or '').split('-')
if len(range_info) != 2: # noqa: PLR2004
self.logger.warning('Cannot compute range for category: %s', category.name)
continue
category_ranges.append((ord(range_info[0].lower()), ord(range_info[1].lower()), category))
category_ranges = sorted(category_ranges, key=operator.itemgetter(0))
for sub_info in self._yt_subscriptions.values():
try: try:
sub_channel = self._get_subscription_channel(sub_info, channel_dict, category_ranges) sub_channel = self._get_subscription_channel(subscription, channel_dict, category_ranges)
except RuntimeError as error: except RuntimeError as error:
self.logger.error(error) self.logger.error(error)
continue return
if sub_info.channel_info is None: if subscription.channel_info is None:
_, channel_info = self.yt_manager.request_channel_info( _, channel_info = self.yt_manager.request_channel_info(
sub_info.channel_id, request_timeout=self.config.request_timeout) subscription.channel_id, request_timeout=self.config.request_timeout)
if not channel_info.items: if not channel_info.items:
raise RuntimeError('No channel info return from YouTube API for channel: %s', sub_channel.name) raise RuntimeError('No channel info return from YouTube API for channel: %s', sub_channel.name)
sub_info.channel_info = channel_info.items[0].snippet subscription.channel_info = channel_info.items[0].snippet
self._refresh_subscription(sub_info) self._refresh_subscription(subscription)
sub_init_message = f'https://www.youtube.com/{sub_info.channel_info.customUrl}' sub_init_message = f'https://www.youtube.com/{subscription.channel_info.customUrl}'
sub_messages = self._get_all_channel_messages(sub_channel) sub_messages = self._get_all_channel_messages(sub_channel)
if not sub_messages or sub_messages[-1].content != sub_init_message: if not sub_messages or sub_messages[-1].content != sub_init_message:
self.logger.debug('Clearing sub channel: %s', sub_channel.name) self.logger.debug('Clearing sub channel: %s', sub_channel.name)
@ -386,7 +375,7 @@ class Bot:
sub_channel, {'content': sub_init_message}, request_timeout=self.config.request_timeout) sub_channel, {'content': sub_init_message}, request_timeout=self.config.request_timeout)
else: else:
messages = list(reversed(sub_messages[:-1][:self.config.youtube_channel_video_count])) messages = list(reversed(sub_messages[:-1][:self.config.youtube_channel_video_count]))
yt_videos = list(reversed(sub_info.video_list[:self.config.youtube_channel_video_count])) yt_videos = list(reversed(subscription.video_list[:self.config.youtube_channel_video_count]))
immediate_delete: dict[int, Message] = { immediate_delete: dict[int, Message] = {
m.id: m for m in sub_messages[self.config.youtube_channel_video_count:-1]} m.id: m for m in sub_messages[self.config.youtube_channel_video_count:-1]}
last_matching_index = 0 last_matching_index = 0
@ -422,7 +411,37 @@ class Bot:
sub_channel, {'content': self._video_message_content(video)}, sub_channel, {'content': self._video_message_content(video)},
request_timeout=self.config.request_timeout) request_timeout=self.config.request_timeout)
sub_info.last_update = time.time() subscription.last_update = time.time()
def _refresh_subs(self):
    """Refresh all known subscriptions, least recently updated first.

    Re-lists the guild channels, rebuilds the channel lookup and the
    alphabetical category ranges, then calls ``_refresh_sub`` for each
    subscription in ascending ``last_update`` order.

    A category is only used when its name is not listed in
    ``config.unmanaged_categories`` and has the form ``<char>-<char>``;
    any other name is skipped with a warning.

    A ``TimeoutError`` raised while refreshing a subscription is logged and
    stops the loop; the remaining subscriptions are not refreshed in this
    pass. Any other exception propagates to the caller.
    """
    self.logger.info('Start refreshing subs')
    categories, text_channel = self.discord_manager.list_channels(
        self.guild_id, request_timeout=self.config.request_timeout)
    self.guild_text_channels = text_channel
    self.guild_categories = categories
    # Channels without a name are all keyed under the empty string.
    channel_dict: dict[str, TextChannel] = {c.name or '': c for c in self.guild_text_channels}

    unmanaged_categories: set[str] = set(self.config.unmanaged_categories.split(','))
    category_ranges: list[tuple[int, int, ChannelCategory]] = []
    for category in self.guild_categories:
        if category.name in unmanaged_categories:
            self.logger.debug('Skipping unmanaged category: %s', category.name)
            continue
        bounds = [bound.lower() for bound in (category.name or '').split('-')]
        # ord() only accepts a single character: a name such as "foo-bar"
        # would raise TypeError and abort the whole refresh, so treat it
        # like any other category whose range cannot be computed.
        if len(bounds) != 2 or any(len(bound) != 1 for bound in bounds):  # noqa: PLR2004
            self.logger.warning('Cannot compute range for category: %s', category.name)
            continue
        category_ranges.append((ord(bounds[0]), ord(bounds[1]), category))
    category_ranges = sorted(category_ranges, key=operator.itemgetter(0))

    # Oldest refresh first, so subscriptions skipped after a timeout are
    # not starved by the ones that were refreshed successfully.
    sorted_subs = sorted(self._yt_subscriptions.values(), key=operator.attrgetter('last_update'))
    for sub_info in sorted_subs:
        try:
            self._refresh_sub(sub_info, channel_dict, category_ranges)
        except TimeoutError as error:
            self.logger.error('Timeout error refreshing subcription: %s', error)
            break
    self.logger.info('Subs refreshed')
def run(self): def run(self):
while self.tasks: while self.tasks:
@ -455,13 +474,17 @@ class Bot:
except Exception as error: except Exception as error:
self.logger.error('Error scanning bot channel: %s -> %s', self.logger.error('Error scanning bot channel: %s -> %s',
error, traceback.format_exc().replace('\n', ' | ')) error, traceback.format_exc().replace('\n', ' | '))
self.tasks = list(filter(lambda t: t[0] != Bot.Task.SCAN_BOT_CHANNEL, self.tasks))
self.tasks.append(( self.tasks.append((
self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None)) self.Task.SCAN_BOT_CHANNEL, time.time() + self.config.bot_channel_scan_interval, None))
case Bot.Task.INIT_SUBS: case Bot.Task.REFRESH_SUBS:
try: try:
self._init_subs() self._refresh_subs()
except Exception as error: except Exception as error:
self.logger.error('Error initializing subscriptions : %s -> %s', self.logger.error('Error initializing subscriptions : %s -> %s',
error, traceback.format_exc().replace('\n', ' | ')) error, traceback.format_exc().replace('\n', ' | '))
self.SUBS_SAVE_PATH.write_text(
json.dumps(self._yt_subscriptions, cls=ApiEncoder, ensure_ascii=False), encoding='utf-8')
self.tasks = list(filter(lambda t: t[0] != Bot.Task.REFRESH_SUBS, self.tasks))
self.tasks.append(( self.tasks.append((
self.Task.INIT_SUBS, time.time() + self.config.youtube_channel_refresh_interval, None)) self.Task.REFRESH_SUBS, time.time() + self.config.youtube_channel_refresh_interval, None))

View file

@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import asdict, dataclass, is_dataclass from dataclasses import asdict, dataclass, is_dataclass
from datetime import datetime
from enum import Enum from enum import Enum
import json import json
import random import random
@ -22,6 +23,8 @@ class ApiEncoder(json.JSONEncoder):
return asdict(o) # type: ignore return asdict(o) # type: ignore
if isinstance(o, Enum): if isinstance(o, Enum):
return o.value return o.value
if isinstance(o, datetime):
return o.isoformat()
return super().default(o) return super().default(o)

View file

@ -17,7 +17,7 @@ if TYPE_CHECKING:
class YoutubeManager: class YoutubeManager:
DEFAULT_DAILY_REQUESTS = 10_000 DEFAULT_DAILY_POINTS = 10_000
SHORTS_CHECK_STATUS = 303 SHORTS_CHECK_STATUS = 303
@dataclass @dataclass
@ -28,12 +28,12 @@ class YoutubeManager:
def __init__(self, api_key: str, logger: logging.Logger): def __init__(self, api_key: str, logger: logging.Logger):
self._api_key = api_key self._api_key = api_key
self._logger = logger self._logger = logger
self.rate_limit = self.RateLimit(remaining=self.DEFAULT_DAILY_REQUESTS, next_reset=time.time() + 24 * 3600) self.rate_limit = self.RateLimit(remaining=self.DEFAULT_DAILY_POINTS, next_reset=time.time() + 24 * 3600)
def _request(self, url: str, request_timeout: float, expected_status: int = 200) -> tuple[HTTPHeaders, dict]: def _request(self, url: str, request_timeout: float, expected_status: int = 200) -> tuple[HTTPHeaders, dict]:
if time.time() >= self.rate_limit.next_reset: if time.time() >= self.rate_limit.next_reset:
self.rate_limit.next_reset = time.time() + 24 * 3600 self.rate_limit.next_reset = time.time() + 24 * 3600
self.rate_limit.remaining = self.DEFAULT_DAILY_REQUESTS self.rate_limit.remaining = self.DEFAULT_DAILY_POINTS
elif self.rate_limit.remaining <= 0: elif self.rate_limit.remaining <= 0:
sleep_time = time.time() - self.rate_limit.next_reset sleep_time = time.time() - self.rate_limit.next_reset
self._logger.debug('No more remaining in Youtube RateLimit : sleeping for %.03fs', sleep_time) self._logger.debug('No more remaining in Youtube RateLimit : sleeping for %.03fs', sleep_time)
@ -60,7 +60,9 @@ class YoutubeManager:
def is_shorts(self, connection: http.client.HTTPConnection, video_id: str) -> bool: def is_shorts(self, connection: http.client.HTTPConnection, video_id: str) -> bool:
try: try:
connection.request('GET', f'/shorts/{video_id}') endpoint = f'/shorts/{video_id}'
self._logger.debug('YoutubeManager: Checking for shorts: %s', endpoint)
connection.request('GET', endpoint)
response = connection.getresponse() response = connection.getresponse()
response.read() response.read()
return response.status != self.SHORTS_CHECK_STATUS return response.status != self.SHORTS_CHECK_STATUS
@ -71,6 +73,7 @@ class YoutubeManager:
HTTPHeaders, ChannelResult]: HTTPHeaders, ChannelResult]:
url = ('https://www.googleapis.com/youtube/v3/channels?part=snippet' url = ('https://www.googleapis.com/youtube/v3/channels?part=snippet'
f'&id={channel_id}&key={self._api_key}') f'&id={channel_id}&key={self._api_key}')
self._logger.debug('YoutubeManager: request channel info for channel %s', channel_id)
headers, info = self._request(url=url, request_timeout=request_timeout) headers, info = self._request(url=url, request_timeout=request_timeout)
return headers, ChannelResult.from_dict(info) return headers, ChannelResult.from_dict(info)
@ -78,5 +81,6 @@ class YoutubeManager:
HTTPHeaders, SearchResult]: HTTPHeaders, SearchResult]:
url = (f'https://www.googleapis.com/youtube/v3/search?part=snippet&channelId={channel_id}' url = (f'https://www.googleapis.com/youtube/v3/search?part=snippet&channelId={channel_id}'
f'&maxResults={max_results}&order=date&type=video&key={self._api_key}') f'&maxResults={max_results}&order=date&type=video&key={self._api_key}')
self._logger.debug('YoutubeManager: request channel videos for channel %s', channel_id)
headers, info = self._request(url=url, request_timeout=request_timeout) headers, info = self._request(url=url, request_timeout=request_timeout)
return headers, SearchResult.from_dict(info) return headers, SearchResult.from_dict(info)

View file

@ -1,3 +1,5 @@
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from .youtube_objects import ChannelSnippet, SearchResultItem from .youtube_objects import ChannelSnippet, SearchResultItem
@ -12,6 +14,17 @@ class SubscriptionInfo:
shorts_list: list[SearchResultItem] = field(default_factory=list) shorts_list: list[SearchResultItem] = field(default_factory=list)
video_list: list[SearchResultItem] = field(default_factory=list) video_list: list[SearchResultItem] = field(default_factory=list)
@staticmethod
def from_dict(info: dict) -> SubscriptionInfo:
    """Build a ``SubscriptionInfo`` from its plain-dict form.

    ``info`` must contain the keys ``name``, ``channel_id``, ``last_update``,
    ``shorts_list`` and ``video_list`` (a missing one raises ``KeyError``).
    ``channel_info`` is optional: when absent or ``None`` the resulting
    object has no channel snippet.
    """
    raw_snippet: dict | None = info.get('channel_info')
    name = info['name']
    channel_id = info['channel_id']
    last_update = info['last_update']
    if raw_snippet is None:
        snippet = None
    else:
        snippet = ChannelSnippet.from_dict(raw_snippet)
    shorts = [SearchResultItem.from_dict(item) for item in info['shorts_list']]
    videos = [SearchResultItem.from_dict(item) for item in info['video_list']]
    return SubscriptionInfo(
        name=name,
        channel_id=channel_id,
        last_update=last_update,
        channel_info=snippet,
        shorts_list=shorts,
        video_list=videos)
Subscriptions = dict[str, SubscriptionInfo] Subscriptions = dict[str, SubscriptionInfo]