weeabot/modules/diffuseapi.py
Benjamyn Love 6446cdb86b Push me
And then just touch me
So I can get my
Satisfaction
2022-12-11 19:17:35 +11:00

142 lines
4.2 KiB
Python

import aiohttp
from base64 import b64decode
from sys import exit
from pprint import pprint
class DiffuseAPI():
def __init__(self, url, styles, nsfw_enabled=True, num_steps=28):
self.nsfw_enabled = nsfw_enabled
self.num_steps = num_steps
self.url = url
self.styles = styles
self.seed = -1
self.width = 512
self.height = 1024
self.cfg_scale = 12
def set_steps(self, steps):
try:
new_steps = int(steps)
if 0 > new_steps < 50:
self.num_steps = new_steps
return True
return False
except:
return False
def set_seed(self, seed):
try:
new_seed = int(seed)
self.seed = new_seed
return True
except:
return False
def set_cfg_scale(self, scale):
try:
new_scale = int(scale)
if 0 > new_scale < 30:
self.cfg_scale = new_scale
return True
return False
except:
return False
def set_styles(self, styles):
if type(styles) == list:
self.styles = styles
return True
return False
def set_orientation(self, orientation):
new_orientation = str(orientation)
match new_orientation:
case "portrait":
self.width = 512
self.height = 1024
return True
case "landscape":
self.width = 1024
self.height = 512
return True
case "square":
self.width = 512
self.height = 512
return True
case _:
return False
def get_orientation(self):
if self.width == 512 and self.height == 512:
return ("square", self.width, self.height)
elif self.width == 1024:
return ("landscape", self.width, self.height)
return ("portrait", self.width, self.height)
def set_nsfw_filter(self, filter_state):
if type(filter_state) == bool:
self.nsfw_enabled = filter_state
return True
return False
def get_nsfw_filter(self):
return self.nsfw_enabled
async def generate_image(self, prompt, neg_prompt=""):
payload = {
"prompt": prompt,
"styles": self.styles,
"steps": self.num_steps,
"seed": self.seed,
"n_iter": 1,
"height": self.height,
"width": self.width,
"negative_prompts": neg_prompt,
"cfg_scale": self.cfg_scale
}
settings = {
"filter_nsfw": not self.nsfw_enabled,
"enable_pnginfo": False
}
override_payload = {
"override_settings": settings
}
payload.update(override_payload)
async with aiohttp.ClientSession(self.url) as session:
async with session.head('/') as alive:
if alive.status != 200:
return None
async with session.post("/sdapi/v1/txt2img", json=payload) as image_json:
image_data = await image_json.json()
return image_data["images"][0]
async def generate_upscale(self, image):
payload = {
"resize_mode": 0,
"show_extras_results": True,
"gfpgan_visibility": 0,
"codeformer_visibility": 0,
"codeformer_weight": 0,
"upscaling_resize": 4,
"upscaling_resize_w": 512,
"upscaling_resize_h": 1024,
"upscaling_crop": True,
"upscaler_1": "R-ESRGAN 4x+ Anime6B",
"upscaler_2": "None",
"extras_upscaler_2_visibility": 0,
"upscale_first": False,
"image": image
}
async with aiohttp.ClientSession(self.url) as session:
async with session.head('/') as alive:
if alive.status != 200:
return None
async with session.post("/sdapi/v1/extra-single-image", json=payload) as image_json:
image_data = await image_json.json()
return image_data["image"]