Merge branch 'discord-bot-addition' into stable-diffuse-api
This commit is contained in:
commit
8a3c228f92
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,3 +1,4 @@
|
||||
.vscode/
|
||||
venv/
|
||||
*.secret
|
||||
__pycache__/
|
||||
0
modules/__init__.py
Normal file
0
modules/__init__.py
Normal file
121
modules/diffuseapi.py
Normal file
121
modules/diffuseapi.py
Normal file
@ -0,0 +1,121 @@
|
||||
import aiohttp
|
||||
|
||||
|
||||
class DiffuseAPI():
|
||||
def __init__(self, url, styles, nsfw_enabled=True, num_steps=28):
|
||||
|
||||
self.nsfw_enabled = nsfw_enabled
|
||||
self.num_steps = num_steps
|
||||
self.url = url
|
||||
self.styles = styles
|
||||
self.seed = -1
|
||||
self.width = 512
|
||||
self.height = 1024
|
||||
self.cfg_scale = 12
|
||||
|
||||
def set_steps(self, steps):
|
||||
try:
|
||||
new_steps = int(steps)
|
||||
if 0 > new_steps < 50:
|
||||
self.num_steps = new_steps
|
||||
return True
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
|
||||
def set_seed(self, seed):
|
||||
try:
|
||||
new_seed = int(seed)
|
||||
self.seed = new_seed
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def set_cfg_scale(self, scale):
|
||||
try:
|
||||
new_scale = int(scale)
|
||||
if 0 > new_scale < 30:
|
||||
self.cfg_scale = new_scale
|
||||
return True
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
|
||||
def set_styles(self, styles):
|
||||
if type(styles) == list:
|
||||
self.styles = styles
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_orientation(self, orientation):
|
||||
new_orientation = str(orientation)
|
||||
match new_orientation:
|
||||
case "portrait":
|
||||
self.width = 512
|
||||
self.height = 1024
|
||||
return True
|
||||
case "landscape":
|
||||
self.width = 1024
|
||||
self.height = 512
|
||||
return True
|
||||
case "square":
|
||||
self.width = 512
|
||||
self.height = 512
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
||||
def get_orientation(self):
|
||||
if self.width == 512 and self.height == 512:
|
||||
return ("square", self.width, self.height)
|
||||
elif self.width == 1024:
|
||||
return ("landscape", self.width, self.height)
|
||||
return ("portrait", self.width, self.height)
|
||||
|
||||
def set_nsfw_filter(self, filter_state):
|
||||
if type(filter_state) == bool:
|
||||
self.nsfw_enabled = filter_state
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_nsfw_filter(self):
|
||||
return self.nsfw_enabled
|
||||
|
||||
async def generate_image(self, prompt, neg_prompt=""):
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"styles": self.styles,
|
||||
"steps": self.num_steps,
|
||||
"seed": self.seed,
|
||||
"n_iter": 1,
|
||||
"height": self.height,
|
||||
"width": self.width,
|
||||
"negative_prompts": neg_prompt,
|
||||
"cfg_scale": self.cfg_scale
|
||||
}
|
||||
|
||||
settings = {
|
||||
"filter_nsfw": not self.nsfw_enabled
|
||||
}
|
||||
|
||||
override_payload = {
|
||||
"override_settings": settings
|
||||
}
|
||||
|
||||
payload.update(override_payload)
|
||||
|
||||
sess = aiohttp.ClientSession(self.url)
|
||||
alive = await sess.head('/')
|
||||
|
||||
if alive.status != 200:
|
||||
alive.close()
|
||||
await sess.close()
|
||||
return None
|
||||
request = await sess.post("/sdapi/v1/txt2img", json=payload)
|
||||
try:
|
||||
req_json = await request.json()
|
||||
request.close()
|
||||
await sess.close()
|
||||
return req_json["images"][0]
|
||||
except:
|
||||
return None
|
||||
41
modules/diffuseapi.py_ref
Normal file
41
modules/diffuseapi.py_ref
Normal file
@ -0,0 +1,41 @@
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
class DiffuseAPI:
|
||||
def __init__(self, url, nsfw_enabled, styles, steps):
|
||||
self.url = url
|
||||
self.nsfw_enabled = nsfw_enabled
|
||||
self.styles = styles
|
||||
self.steps = steps
|
||||
self.seed = -1
|
||||
|
||||
def _generate_payload(self):
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"steps": self.steps,
|
||||
"seed": self.seed,
|
||||
"styles": self.styles,
|
||||
"height": 1024,
|
||||
"width": 512
|
||||
}
|
||||
settings = {
|
||||
"filter_nsfw": not self.nsfw_enabled,
|
||||
"samples_save": True,
|
||||
}
|
||||
|
||||
override_payload = {
|
||||
"override_settings": settings
|
||||
}
|
||||
payload.update(override_payload)
|
||||
return payload
|
||||
|
||||
async def generate_image(self, prompt, negative_prompt=""):
|
||||
async with aiohttp.ClientSession(self.url) as session:
|
||||
payload = self._generate_payload()
|
||||
payload.update({"prompt": prompt, "negative_prompt": negative_prompt})
|
||||
print(payload)
|
||||
async with session.post('/sdapi/v1/txt2img', json=payload) as image_handler:
|
||||
image_data = await image_handler.json()
|
||||
|
||||
print(image_data)
|
||||
18
weeeabot.py
18
weeeabot.py
@ -3,6 +3,7 @@ from discord.ext import commands
|
||||
import aiohttp
|
||||
from base64 import b64decode
|
||||
from sys import exit
|
||||
from modules import diffuseapi
|
||||
|
||||
try:
|
||||
with open("token.secret", 'r') as f:
|
||||
@ -14,6 +15,8 @@ except FileNotFoundError:
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
|
||||
api = diffuseapi.DiffuseAPI("http://localhost:7860", ["default"], True, 28)
|
||||
|
||||
# client = discord.Client(intents=intents)
|
||||
|
||||
class Settings:
|
||||
@ -156,4 +159,19 @@ async def test(ctx):
|
||||
main_settings.styles = ["Bot"]
|
||||
await ctx.reply("Set to main api")
|
||||
|
||||
@bot.command()
|
||||
async def test2(ctx):
|
||||
image_data = await api.generate_image(prompt="cute bot doing bot things")
|
||||
if image_data is None:
|
||||
return
|
||||
with open('/tmp/image.png', 'wb') as f:
|
||||
f.write(b64decode(image_data))
|
||||
|
||||
embed = discord.Embed()
|
||||
upload_file = discord.File("/tmp/image.png", filename="image.png")
|
||||
embed.set_image(url="attachment://image.png")
|
||||
# embed.title = prompt
|
||||
|
||||
await ctx.reply("", file=upload_file, embed=embed)
|
||||
|
||||
bot.run(discord_client_token)
|
||||
Loading…
x
Reference in New Issue
Block a user