diff --git a/src/main/cogs/booru.py b/src/main/cogs/booru.py index 75f66d8..f7e6893 100644 --- a/src/main/cogs/booru.py +++ b/src/main/cogs/booru.py @@ -134,8 +134,14 @@ class MsG: if len(pool_request) > 1: for pool in pool_request: pools.append(pool['name']) - match = await ctx.send('**Multiple pools found for `{}`.** Type the number of the correct match.\n```\n{}```\n`0` or `cancel`'.format(' '.join(query), '\n'.join(['{} {}'.format(c, elem) for c, elem in enumerate(pools, 1)]))) - selection = await self.bot.wait_for('message', check=on_message, timeout=60) + match = await ctx.send('**Multiple pools found for `{}`.** Type the number of the correct match\n```\n{}```'.format(' '.join(query), '\n'.join(['{} {}'.format(c, elem) for c, elem in enumerate(pools, 1)]))) + + await ctx.message.add_reaction('🛑') + done, pending = await asyncio.wait([self.bot.wait_for('reaction_add', check=on_reaction, timeout=60), + self.bot.wait_for('reaction_remove', check=on_reaction, timeout=60), self.bot.wait_for('message', check=on_message, timeout=60)], return_when=asyncio.FIRST_COMPLETED) + for future in done: + selection = future.result() + await match.delete() tempool = [pool for pool in pool_request if pool['name'] == pools[int(selection.content) - 1]][0] await selection.delete() @@ -147,9 +153,8 @@ class MsG: await ctx.send(f'**{tempool["name"]}**\nhttps://e621.net/pool/show/{tempool["id"]}') await ctx.message.add_reaction('✅') - except exc.Abort: - await ctx.send('**Search aborted**', delete_after=10) - await ctx.message.add_reaction('\N{CROSS MARK}') + except exc.Abort as e: + await e.message.edit(content='**Search aborted**', delete_after=10) @commands.command(name='getimage', aliases=['geti', 'gi']) @checks.del_ctx() @@ -479,45 +484,56 @@ class MsG: return args async def _get_pool(self, ctx, *, destination, booru='e621', query=[]): - def on_message(msg): - if (msg.content.isdigit() and int(msg.content) == 0) or msg.content.lower() == 'cancel' and msg.author is ctx.author and msg.channel is ctx.channel: - raise exc.Abort - elif msg.content.isdigit(): - if int(msg.content) <= len(pools) and int(msg.content) > 0 and msg.author is ctx.author and msg.channel is ctx.channel: - return True + def on_reaction(reaction, user): + if reaction.emoji == '🛑' and reaction.message.id == ctx.message.id and user is ctx.author: + raise exc.Abort(match) return False + def on_message(msg): + return msg.content.isdigit() and int(msg.content) <= len(pools) and int(msg.content) > 0 and msg.author is ctx.author and msg.channel is ctx.channel + posts = {} pool = {} - pools = [] - pool_request = await u.fetch('https://{}.net/pool/index.json'.format(booru), params={'query': ' '.join(query)}, json=True) - if len(pool_request) > 1: - for pool in pool_request: - pools.append(pool['name']) - match = await ctx.send('**Multiple pools found for `{}`.** Type the number of the correct match.\n```\n{}```\n`0` or `cancel`'.format(' '.join(query), '\n'.join(['{} {}'.format(c, elem) for c, elem in enumerate(pools, 1)]))) - selection = await self.bot.wait_for('message', check=on_message, timeout=60) - await match.delete() - tempool = [pool for pool in pool_request if pool['name'] - == pools[int(selection.content) - 1]][0] - await selection.delete() - pool = {'name': tempool['name'], 'id': tempool['id']} + try: + pools = [] + pool_request = await u.fetch('https://{}.net/pool/index.json'.format(booru), params={'query': ' '.join(query)}, json=True) + if len(pool_request) > 1: + for pool in pool_request: + pools.append(pool['name']) + match = await ctx.send('**Multiple pools found for `{}`.** Type the number of the correct match.\n```\n{}```'.format(' '.join(query), '\n'.join(['{} {}'.format(c, elem) for c, elem in enumerate(pools, 1)]))) - await destination.trigger_typing() - elif pool_request: - tempool = pool_request[0] - pool = {'name': pool_request[0]['name'], 'id': pool_request[0]['id']} - else: - raise exc.NotFound + await ctx.message.add_reaction('🛑') + done, pending = await asyncio.wait([self.bot.wait_for('reaction_add', check=on_reaction, timeout=60), + self.bot.wait_for('reaction_remove', check=on_reaction, timeout=60), self.bot.wait_for('message', check=on_message, timeout=60)], return_when=asyncio.FIRST_COMPLETED) + for future in done: + selection = future.result() - page = 1 - while len(posts) < tempool['post_count']: - posts_request = await u.fetch('https://{}.net/pool/show.json'.format(booru), params={'id': tempool['id'], 'page': page}, json=True) - for post in posts_request['posts']: - posts[post['id']] = {'artist': ', '.join(post['artist']), 'url': post['file_url']} - page += 1 + await match.delete() + tempool = [pool for pool in pool_request if pool['name'] + == pools[int(selection.content) - 1]][0] + await selection.delete() + pool = {'name': tempool['name'], 'id': tempool['id']} - return pool, posts + await destination.trigger_typing() + elif pool_request: + tempool = pool_request[0] + pool = {'name': pool_request[0]['name'], 'id': pool_request[0]['id']} + else: + raise exc.NotFound + + page = 1 + while len(posts) < tempool['post_count']: + posts_request = await u.fetch('https://{}.net/pool/show.json'.format(booru), params={'id': tempool['id'], 'page': page}, json=True) + for post in posts_request['posts']: + posts[post['id']] = {'artist': ', '.join(post['artist']), 'url': post['file_url']} + page += 1 + + return pool, posts + + except exc.Abort as e: + await e.message.edit(content='**Search aborted**') + raise exc.Continue # Messy code that checks image limit and tags in blacklists async def _get_posts(self, ctx, *, booru='e621', tags=[], limit=1, previous={}):