diff --git a/dbot/db.py b/dbot/db.py index ea7a51e..0430050 100644 --- a/dbot/db.py +++ b/dbot/db.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional, Union +from typing import Dict, Optional, Union, List, Tuple from sqlite_utils import Database from sqlite_utils.db import NotFoundError @@ -26,12 +26,23 @@ class DiscordDB: if "rss_feed" not in self._db.table_names(): self._db.create_table("rss_feed", {"id": int, "timestamp": float}, pk="id", not_null={"id", "timestamp"}) + if "notification_channel" not in self._db.table_names(): + self._db.create_table( + "notification_channel", {"id": int, "guild_id": int, "channel_id": int, "kind": str}, pk="id", not_null={"id", "guild_id", "channel_id"}, defaults={"kind": "epic"} + ) + self._db["notification_channel"].create_index(("guild_id", "channel_id", "kind"), unique=True) + self._db.create_table("role_mapping", {"id": int, "channel_id": int, "division": int, "role_id": int}, pk="id", not_null={"id", "channel_id", "division", "role_id"}) + self._db["role_mapping"].add_foreign_key("channel_id", "notification_channel", "channel_id") + self._db["role_mapping"].create_index(("channel_id", "division"), unique=True) + self._db.vacuum() self.member = self._db.table("member") self.player = self._db.table("player") self.epic = self._db.table("epic") self.rss_feed = self._db.table("rss_feed") + self.channel = self._db.table("notification_channel") + self.role_mapping = self._db.table("role_mapping") # Player methods @@ -138,14 +149,50 @@ class DiscordDB: return True return False + # RSS Event Methods + def get_rss_feed_timestamp(self, country_id: int) -> float: + """Get latest processed RSS Feed event's timestamp for country + + :param country_id: int Country ID + :return: timestamp + """ try: return self.rss_feed.get(country_id)["timestamp"] except NotFoundError: return 0 - def set_rss_feed_timestamp(self, country_id: int, timestamp: float): + def set_rss_feed_timestamp(self, country_id: int, timestamp: float) -> None: + """Set latest processed RSS Feed event's timestamp for country + + :param country_id: int Country ID + :param timestamp: float UNIX timestamp + """ if self.get_rss_feed_timestamp(country_id): self.rss_feed.update(country_id, {"timestamp": timestamp}) else: self.rss_feed.insert({"id": country_id, "timestamp": timestamp}) + + # RSS Event Methods + + def add_notification_channel(self, guild_id: int, channel_id: int, kind: str) -> bool: + if channel_id in self.get_kind_notification_channel_ids(kind): + return False + self.channel.insert({"guild_id": guild_id, "channel_id": channel_id, "kind": kind}) + return True + + def get_kind_notification_channel_ids(self, kind: str) -> List[int]: + return [row["channel_id"] for row in self.channel.rows_where("kind = ?", [kind])] + + def add_role_mapping_entry(self, channel_id: int, division: int, role_id: int) -> bool: + if division not in (1, 2, 3, 4, 11): + return False + try: + row = next(self.role_mapping.rows_where("channel_id = ? and division = ?", [channel_id, division])) + self.role_mapping.update(row["id"], {"channel_id": channel_id, "division": division, "role_id": role_id}) + except StopIteration: + self.role_mapping.insert({"channel_id": channel_id, "division": division, "role_id": role_id}) + return True + + def get_notification_channel_and_role_ids_for_division(self, division: int) -> List[Tuple[int, int]]: + return [(row["channel_id"], row["role_id"]) for row in self.role_mapping.rows_where("division = ?", (division,))] diff --git a/dbot/discord_bot.py b/dbot/discord_bot.py index 58472b6..fed8e49 100644 --- a/dbot/discord_bot.py +++ b/dbot/discord_bot.py @@ -144,8 +144,7 @@ class MyClient(discord.Client): logger.debug(kind.format.format(**dict(match.groupdict(), **{"current_country": country.name}))) is_latvia = country.id == 71 has_latvia = any("Latvia" in v for v in values.values()) - is_defender = kind.name == "Region ecured" and country.name in values["defender"] - if is_latvia and has_latvia and is_defender: + if is_latvia or has_latvia: text = kind.format.format(**dict(match.groupdict(), **{"current_country": country.name})) title = kind.name else: @@ -164,7 +163,8 @@ class MyClient(discord.Client): embed.set_footer(text=f"{entry_datetime.strftime('%F %T')} (eRepublik time)") logger.debug(f"Message sent: {text}") - await self.get_channel(DEFAULT_CHANNEL_ID).send(embed=embed) + for channel_id in DB.get_kind_notification_channel_ids("events"): + await self.get_channel(channel_id).send(embed=embed) await asyncio.sleep((self.timestamp // 300 + 1) * 300 - self.timestamp) except Exception as e: @@ -178,8 +178,6 @@ class MyClient(discord.Client): async def report_epics(self): await self.wait_until_ready() - roles = [role for role in self.get_guild(300297668553605131).roles if role.name in MENTION_MAPPING.values()] - role_mapping = {role.name: role.mention for role in roles} while not self.is_closed(): try: r = get_battle_page() @@ -205,7 +203,13 @@ class MyClient(discord.Client): f"Round time {s_to_human(self.timestamp - battle['start'])} " f"https://www.erepublik.com/en/military/battlefield/{battle['id']}" ) - await self.get_channel(DEFAULT_CHANNEL_ID).send(f"{role_mapping[MENTION_MAPPING[div['div']]]}", embed=embed) + notified_channels = [] + for channel_id, role_id in DB.get_notification_channel_and_role_ids_for_division(div["div"]): + await self.get_channel(channel_id).send(f"<@&{role_id}> epic", embed=embed) + notified_channels.append(channel_id) + for channel_id in DB.get_kind_notification_channel_ids("epic"): + if channel_id not in notified_channels: + await self.get_channel(channel_id).send(embed=embed) DB.add_epic(div.get("id")) sleep_seconds = r.get("last_updated") + 60 - self.timestamp @@ -234,6 +238,49 @@ async def on_ready(): logger.info("------") +@bot.command() +async def notify(ctx, kind: str): + if ctx.author.guild_permissions.administrator: + guild_id = ctx.guild.id + channel_id = ctx.channel.id + if kind == "epic": + if DB.add_notification_channel(guild_id, channel_id, kind): + await ctx.send("I will notify about epics in this channel!") + await ctx.send( + "If You want for me to also add division mentions write:\n" + "`!set_division d1 @role_to_mention`\n" + "`!set_division d2 @role_to_mention`\n" + "`!set_division d3 @role_to_mention`\n" + "`!set_division d4 @role_to_mention`\n" + "`!set_division air @role_to_mention`" + ) + elif kind == "events": + DB.add_notification_channel(guild_id, channel_id, kind) + await ctx.send("I will notify about eLatvia's events in this channel!") + else: + await ctx.send(f"Unknown {kind=}") + else: + return await ctx.send("This command is only available for server administrators") + + +@bot.command() +async def set_division(ctx, division: str, role_mention): + if not ctx.author.guild_permissions.administrator: + return await ctx.send("This command is only available for server administrators") + if ctx.channel.id not in DB.get_kind_notification_channel_ids("epic"): + return await ctx.send("This command is only available from registered channels!") + div_map = dict(D1=1, D2=3, D3=3, D4=4, Air=11) + + if division.title() not in div_map: + return await ctx.send(f"Unknown {division=}! Available divisions {', '.join(div_map.keys())}") + for role in ctx.guild.roles: + if role.mention == role_mention: + DB.add_role_mapping_entry(ctx.channel.id, div_map[division.title()], role.id) + return await ctx.send(f"Success! For {division.title()} epics I will mention <@&{role.id}>") + else: + await ctx.send(f"Unable to find the role You mentioned...") + + @bot.command() async def exit(ctx): if ctx.author.id == ADMIN_ID: diff --git a/dbot/tests.py b/dbot/tests.py index 1c72c40..4b96bef 100644 --- a/dbot/tests.py +++ b/dbot/tests.py @@ -56,5 +56,5 @@ class TestRegexes(unittest.TestCase): self.assertTrue(event.format) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/logger.py b/logger.py deleted file mode 100644 index dbff743..0000000 --- a/logger.py +++ /dev/null @@ -1,39 +0,0 @@ -import datetime -import json -import logging -import os -import sys -from json import JSONDecodeError -from typing import Union -import time - -APP_NAME = "discord_bot" - -os.chdir(os.path.abspath(os.path.dirname(sys.argv[0]))) - -logger = logging.getLogger(APP_NAME) -formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") - -file_logger = logging.FileHandler(f"./logging.log", "w") -file_logger.setLevel(logging.WARNING) -file_logger.setFormatter(formatter) -logger.addHandler(file_logger) - -stream_logger = logging.StreamHandler() -stream_logger.setLevel(logging.INFO) -stream_logger.setFormatter(formatter) -logger.addHandler(stream_logger) - -logger.setLevel(logging.INFO) - - -def main(): - logger.info('Info message') - logger.debug('Debug message') - logger.warning('Warning message') - logger.error('Error message') - logger.critical('Critical message') - - -if __name__ == "__main__": - main()