From 15efdb52414b6c7562437c79dcc11241d7f9353f Mon Sep 17 00:00:00 2001 From: Dylan Dizon Date: Tue, 6 Nov 2018 02:57:59 -0500 Subject: [PATCH] WIP migration to remove alias dict from blacklisting logic in favor of adding all to a single context-aware blacklist, prompting the user to verify and remove unwanted aliases --- src/cogs/booru.py | 143 ++++++++++++++++++++++++++++++----------- src/misc/exceptions.py | 3 + 2 files changed, 108 insertions(+), 38 deletions(-) diff --git a/src/cogs/booru.py b/src/cogs/booru.py index a90019f..6bc8517 100644 --- a/src/cogs/booru.py +++ b/src/cogs/booru.py @@ -664,7 +664,7 @@ class MsG: # Creates temp blacklist based on context for bl in (self.blacklists['global_blacklist'], self.blacklists['guild_blacklist'].get(guild.id, {}).get(ctx.channel.id, set()), self.blacklists['user_blacklist'].get(ctx.author.id, set())): for tag in bl: - blacklist.update([tag] + list(self.aliases[tag])) + blacklist.add(tag) # Checks for, assigns, and removes first order in tags if possible order = [tag for tag in tags if 'order:' in tag] if order: @@ -1405,6 +1405,19 @@ class MsG: await ctx.send('**Invalid blacklist**') await ctx.message.add_reaction('\N{CROSS MARK}') + @_get_blacklist.command(name='alias', aliases=['aliases']) + async def __get_blacklist_aliases(self, ctx, *args): + guild = ctx.guild if isinstance( + ctx.guild, d.Guild) else ctx.channel + + blacklist = set() + # Creates temp blacklist based on context + for bl in (self.blacklists['global_blacklist'], self.blacklists['guild_blacklist'].get(guild.id, {}).get(ctx.channel.id, set()), self.blacklists['user_blacklist'].get(ctx.author.id, set())): + for tag in bl: + blacklist.update([tag] + list(self.aliases[tag])) + + await ctx.send(f'**Contextual blacklist aliases:**\n```\n{formatter.tostring(blacklist)}```') + @_get_blacklist.command(name='global', aliases=['gl', 'g'], brief='Get current global blacklist', description='Get current global blacklist\n\nThis applies to all booru commands, in accordance with Discord\'s ToS agreement\n\nExample:\n\{p\}bl get global') async def __get_global_blacklist(self, ctx, *args): dest = u.get_kwargs(ctx, args)['destination'] @@ -1464,27 +1477,89 @@ class MsG: await ctx.send('**Invalid blacklist**') await ctx.message.add_reaction('\N{CROSS MARK}') + async def _aliases(self, ctx, tags, blacklist): + def on_reaction(reaction, user): + if user is ctx.author and reaction.message.channel is ctx.message.channel: + if reaction.emoji == '\N{HEAVY MINUS SIGN}': + raise exc.Remove + if reaction.emoji == '\N{THUMBS DOWN SIGN}': + raise exc.Continue + elif reaction.emoji == '\N{THUMBS UP SIGN}': + return True + else: + return False + + def on_message(msg): + if msg.author is ctx.message.author and msg.channel is ctx.message.channel: + if msg.content == '0': + raise exc.Abort + return True + return False + + aliases = set() + + try: + for tag in tags: + aliases.add(tag) + alias_request = await u.fetch('https://e621.net/tag_alias/index.json', params={'aliased_to': tag, 'approved': 'true'}, json=True) + if alias_request: + for dic in alias_request: + aliases.add(dic['name']) + + message = await ctx.send(f'**Also add aliases?**```\n{formatter.tostring(aliases)}```') + await message.add_reaction('\N{THUMBS DOWN SIGN}') + await message.add_reaction('\N{HEAVY MINUS SIGN}') + await message.add_reaction('\N{THUMBS UP SIGN}') + + try: + await self.bot.wait_for('reaction_add', check=on_reaction, timeout=7 * 60) + + except exc.Remove: + await message.edit(content=f'**Also add aliases?**```\n{formatter.tostring(aliases)}```\nType the tag(s) to remove or `0` to abort:') + await message.remove_reaction('\N{HEAVY MINUS SIGN}', self.bot.user) + await message.remove_reaction('\N{HEAVY MINUS SIGN}', ctx.author) + response = await self.bot.wait_for('message', check=on_message, timeout=7 * 60) + + for tag in response.content.split(' '): + if tag in aliases: + aliases.remove(tag) + + await message.edit(content=f'**Also add aliases?**```\n{formatter.tostring(aliases)}```\nConfirm or deny changes') + await self.bot.wait_for('reaction_add', check=on_reaction, timeout=7 * 60) + + blacklist.update(aliases) + + await message.delete() + + return aliases + + except exc.Continue: + await message.delete() + + return tags + except exc.Abort: + await message.delete() + + raise exc.Abort + @_add_tags.command(name='global', aliases=['gl', 'g']) @cmds.is_owner() async def __add_global_tags(self, ctx, *args): kwargs = u.get_kwargs(ctx, args) dest, tags = kwargs['destination'], kwargs['remaining'] - await dest.trigger_typing() + try: + await dest.trigger_typing() - self.blacklists['global_blacklist'].update(tags) - for tag in tags: - alias_request = await u.fetch('https://e621.net/tag_alias/index.json', params={'aliased_to': tag, 'approved': 'true'}, json=True) - if alias_request: - for dic in alias_request: - self.aliases.setdefault(tag, set()).add(dic['name']) - else: - self.aliases.setdefault(tag, set()) - u.dump(self.blacklists, 'cogs/blacklists.pkl') - u.dump(self.aliases, 'cogs/aliases.pkl') + tags = await self._aliases(dest, tags, self.blacklists['global_blacklist']) + + u.dump(self.blacklists, 'cogs/blacklists.pkl') await dest.send('**Added to global blacklist:**\n```\n{}```'.format(formatter.tostring(tags))) + except exc.Abort: + await dest.send('**Aborted**') + @_add_tags.command(name='channel', aliases=['ch', 'c'], brief='@manage_channel@ Add tag(s) to the current channel blacklist (requires manage_channel)', description='Add tag(s) to the current channel blacklist ') @cmds.has_permissions(manage_channels=True) async def __add_channel_tags(self, ctx, *args): @@ -1494,42 +1569,34 @@ class MsG: guild = ctx.guild if isinstance( ctx.guild, d.Guild) else ctx.channel - await dest.trigger_typing() + try: + await dest.trigger_typing() - self.blacklists['guild_blacklist'].setdefault( - guild.id, {}).setdefault(ctx.channel.id, set()).update(tags) - for tag in tags: - alias_request = await u.fetch('https://e621.net/tag_alias/index.json', params={'aliased_to': tag, 'approved': 'true'}, json=True) - if alias_request: - for dic in alias_request: - self.aliases.setdefault(tag, set()).add(dic['name']) - else: - self.aliases.setdefault(tag, set()) - u.dump(self.blacklists, 'cogs/blacklists.pkl') - u.dump(self.aliases, 'cogs/aliases.pkl') + tags = await self._aliases(dest, tags, self.blacklists['guild_blacklist'].setdefault(guild.id, {}).setdefault(ctx.channel.id, set())) - await dest.send('**Added to** {} **blacklist:**\n```\n{}```'.format(ctx.channel.mention, formatter.tostring(tags)), delete_after=5) + u.dump(self.blacklists, 'cogs/blacklists.pkl') + + await dest.send('**Added to** {} **blacklist:**\n```\n{}```'.format(ctx.channel.mention, formatter.tostring(tags))) + + except exc.Abort: + await dest.send('**Aborted**') @_add_tags.command(name='me', aliases=['m']) async def __add_user_tags(self, ctx, *args): kwargs = u.get_kwargs(ctx, args) dest, tags = kwargs['destination'], kwargs['remaining'] - await dest.trigger_typing() + try: + await dest.trigger_typing() - self.blacklists['user_blacklist'].setdefault( - ctx.author.id, set()).update(tags) - for tag in tags: - alias_request = await u.fetch('https://e621.net/tag_alias/index.json', params={'aliased_to': tag, 'approved': 'true'}, json=True) - if alias_request: - for dic in alias_request: - self.aliases.setdefault(tag, set()).add(dic['name']) - else: - self.aliases.setdefault(tag, set()) - u.dump(self.blacklists, 'cogs/blacklists.pkl') - u.dump(self.aliases, 'cogs/aliases.pkl') + tags = await self._aliases(dest, tags, self.blacklists['user_blacklist'].setdefault(ctx.author.id, set())) - await dest.send('{} **added to their blacklist:**\n```\n{}```'.format(ctx.author.mention, formatter.tostring(tags)), delete_after=5) + u.dump(self.blacklists, 'cogs/blacklists.pkl') + + await dest.send('{} **added to their blacklist:**\n```\n{}```'.format(ctx.author.mention, formatter.tostring(tags))) + + except exc.Abort: + await dest.send('**Aborted**') @blacklist.group(name='remove', aliases=['rm', 'r']) async def _remove_tags(self, ctx): diff --git a/src/misc/exceptions.py b/src/misc/exceptions.py index 965ebe9..ac58f01 100644 --- a/src/misc/exceptions.py +++ b/src/misc/exceptions.py @@ -10,6 +10,9 @@ async def send_error(ctx, error): # class NSFW(errext.CheckFailure): # pass +class Remove(Exception): + pass + class SizeError(Exception): pass