diff --git a/src/cogs/booru.py b/src/cogs/booru.py index d9aa140..1b7be22 100644 --- a/src/cogs/booru.py +++ b/src/cogs/booru.py @@ -1571,31 +1571,30 @@ class MsG: await ctx.send('**Invalid blacklist**') await ctx.message.add_reaction('\N{CROSS MARK}') - def _remove(self, tags, lst): - removed = [] - skipped = [] + def _remove(self, remove, lst): + removed = set() - if tags: + if remove: if type(lst) is set: - temp = set() - for tag in tags: - if tag not in tags: - temp.add(tag) - else: - removed.append(tag) + for tag in remove: + with suppress(KeyError): + lst.remove(tag) + removed.add(tag) else: - temp = {} - for k, v in lst.items(): - temp[k] = set() - for tag in v: - if tag not in tags: - temp[k].add(tag) - else: - removed.append(tag) - lst.update(temp) + temp = copy.deepcopy(lst) + for k in temp.keys(): + if k in remove: + with suppress(KeyError): + del lst[k] + removed.add(k) + else: + lst[k] = set([tag for tag in lst[k] if tag not in remove]) + lst = temp + removed.update([tag for k, v in lst.items() for tag in v if tag in remove]) + u.dump(self.blacklists, 'cogs/blacklists.pkl') - return removed, skipped + return removed @remove_tags.command( name='global', @@ -1609,7 +1608,7 @@ class MsG: default = set() if lst == 'blacklist' else {} async with ctx.channel.typing(): - removed, skipped = self._remove( + removed = self._remove( tags, self.blacklists['global'].get(lst, default)) @@ -1630,7 +1629,7 @@ class MsG: default = set() if lst == 'blacklist' else {} async with ctx.channel.typing(): - removed, skipped = self._remove( + removed = self._remove( tags, self.blacklists['channel'].get(ctx.channel.id, {}).get(lst, default)) @@ -1650,7 +1649,7 @@ class MsG: default = set() if lst == 'blacklist' else {} async with ctx.channel.typing(): - removed, skipped = self._remove( + removed = self._remove( tags, self.blacklists['user'].get(ctx.author.id, {}).get(lst, default))