diff --git a/.gitignore b/.gitignore index b34f717..f260bf9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .vscode/ venv/ *.secret +__pycache__/ \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/modules/diffuseapi.py b/modules/diffuseapi.py new file mode 100644 index 0000000..e34b8da --- /dev/null +++ b/modules/diffuseapi.py @@ -0,0 +1,121 @@ +import aiohttp + + +class DiffuseAPI(): + def __init__(self, url, styles, nsfw_enabled=True, num_steps=28): + + self.nsfw_enabled = nsfw_enabled + self.num_steps = num_steps + self.url = url + self.styles = styles + self.seed = -1 + self.width = 512 + self.height = 1024 + self.cfg_scale = 12 + + def set_steps(self, steps): + try: + new_steps = int(steps) + if 0 > new_steps < 50: + self.num_steps = new_steps + return True + return False + except: + return False + + def set_seed(self, seed): + try: + new_seed = int(seed) + self.seed = new_seed + return True + except: + return False + + def set_cfg_scale(self, scale): + try: + new_scale = int(scale) + if 0 > new_scale < 30: + self.cfg_scale = new_scale + return True + return False + except: + return False + + def set_styles(self, styles): + if type(styles) == list: + self.styles = styles + return True + return False + + def set_orientation(self, orientation): + new_orientation = str(orientation) + match new_orientation: + case "portrait": + self.width = 512 + self.height = 1024 + return True + case "landscape": + self.width = 1024 + self.height = 512 + return True + case "square": + self.width = 512 + self.height = 512 + return True + case _: + return False + + def get_orientation(self): + if self.width == 512 and self.height == 512: + return ("square", self.width, self.height) + elif self.width == 1024: + return ("landscape", self.width, self.height) + return ("portrait", self.width, self.height) + + def set_nsfw_filter(self, filter_state): + if type(filter_state) == bool: + self.nsfw_enabled = filter_state + return True + return False + + def get_nsfw_filter(self): + return self.nsfw_enabled + + async def generate_image(self, prompt, neg_prompt=""): + payload = { + "prompt": prompt, + "styles": self.styles, + "steps": self.num_steps, + "seed": self.seed, + "n_iter": 1, + "height": self.height, + "width": self.width, + "negative_prompts": neg_prompt, + "cfg_scale": self.cfg_scale + } + + settings = { + "filter_nsfw": not self.nsfw_enabled + } + + override_payload = { + "override_settings": settings + } + + payload.update(override_payload) + + sess = aiohttp.ClientSession(self.url) + alive = await sess.head('/') + + if alive.status != 200: + alive.close() + await sess.close() + return None + request = await sess.post("/sdapi/v1/txt2img", json=payload) + try: + req_json = await request.json() + request.close() + await sess.close() + return req_json["images"][0] + except: + return None diff --git a/modules/diffuseapi.py_ref b/modules/diffuseapi.py_ref new file mode 100644 index 0000000..0fa1c52 --- /dev/null +++ b/modules/diffuseapi.py_ref @@ -0,0 +1,41 @@ +import aiohttp +import asyncio + +class DiffuseAPI: + def __init__(self, url, nsfw_enabled, styles, steps): + self.url = url + self.nsfw_enabled = nsfw_enabled + self.styles = styles + self.steps = steps + self.seed = -1 + + def _generate_payload(self): + payload = { + "prompt": "", + "negative_prompt": "", + "steps": self.steps, + "seed": self.seed, + "styles": self.styles, + "height": 1024, + "width": 512 + } + settings = { + "filter_nsfw": not self.nsfw_enabled, + "samples_save": True, + } + + override_payload = { + "override_settings": settings + } + payload.update(override_payload) + return payload + + async def generate_image(self, prompt, negative_prompt=""): + async with aiohttp.ClientSession(self.url) as session: + payload = self._generate_payload() + payload.update({"prompt": prompt, "negative_prompt": negative_prompt}) + print(payload) + async with session.post('/sdapi/v1/txt2img', json=payload) as image_handler: + image_data = await image_handler.json() + + print(image_data) diff --git a/weeeabot.py b/weeeabot.py index 7b04ed9..b339e73 100644 --- a/weeeabot.py +++ b/weeeabot.py @@ -3,6 +3,7 @@ from discord.ext import commands import aiohttp from base64 import b64decode from sys import exit +from modules import diffuseapi try: with open("token.secret", 'r') as f: @@ -14,6 +15,8 @@ except FileNotFoundError: intents = discord.Intents.default() intents.message_content = True +api = diffuseapi.DiffuseAPI("http://localhost:7860", ["default"], True, 28) + # client = discord.Client(intents=intents) class Settings: @@ -156,4 +159,19 @@ async def test(ctx): main_settings.styles = ["Bot"] await ctx.reply("Set to main api") +@bot.command() +async def test2(ctx): + image_data = await api.generate_image(prompt="cute bot doing bot things") + if image_data is None: + return + with open('/tmp/image.png', 'wb') as f: + f.write(b64decode(image_data)) + + embed = discord.Embed() + upload_file = discord.File("/tmp/image.png", filename="image.png") + embed.set_image(url="attachment://image.png") + # embed.title = prompt + + await ctx.reply("", file=upload_file, embed=embed) + bot.run(discord_client_token) \ No newline at end of file