一、基础环境
环境搭建参考《ComfyUI搭建文生图》,并在 ComfyUI 设置中开启 Dev Mode(开启后菜单中会出现“Save (API Format)”,用于导出下文脚本使用的 workflow_api.json)。
ComfyUI API
二、本地化运行脚本编写
# -*- coding: utf-8 -*-
# https://www.bilibili.com/read/cv33202530/
# https://www.wehelpwin.com/article/5317
import json
import websocket
import uuid
import urllib.request
import urllib.parse
import random
def show_gif(fname):
    """Return an IPython ``HTML`` object embedding the file *fname* as a base64 data URI.

    The MIME type is guessed from the file extension, falling back to
    ``image/gif`` when it cannot be determined. This keeps the ``.png``
    files written by ``generate_clip`` labelled correctly in the data URI.

    Args:
        fname: Path of the image file to embed.

    Returns:
        ``IPython.display.HTML`` wrapping an ``<img>`` tag. It only renders
        when displayed by an IPython/Jupyter front end.
    """
    import base64
    import mimetypes
    from IPython import display
    mime = mimetypes.guess_type(fname)[0] or 'image/gif'
    with open(fname, 'rb') as fd:
        b64 = base64.b64encode(fd.read()).decode('ascii')
    return display.HTML(f'<img src="data:{mime};base64,{b64}" />')
def queue_prompt(textPrompt):
    """Submit a workflow to ComfyUI's ``/prompt`` endpoint.

    Reads the module globals ``server_address`` and ``client_id``. Both are
    set in the script's ``__main__`` block.

    Args:
        textPrompt: The API-format workflow dict (node id -> node). Despite
            the name, this is the whole graph, not a text prompt.

    Returns:
        The decoded JSON response. Callers read its ``prompt_id`` key.
    """
    p = {"prompt": textPrompt, "client_id": client_id}
    data = json.dumps(p).encode('utf-8')
    req = urllib.request.Request(
        "http://{}/prompt".format(server_address),
        data=data,
        headers={"Content-Type": "application/json"},
    )
    # Close the HTTP response deterministically, as the other helpers do.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
def get_image(fileName, subFolder, folder_type):
    """Download one output file from ComfyUI's ``/view`` endpoint.

    Reads the module global ``server_address``.

    Args:
        fileName: ``filename`` field of a history output entry.
        subFolder: ``subfolder`` field of the same entry.
        folder_type: ``type`` field of the same entry.

    Returns:
        The raw response body as ``bytes``.
    """
    query = urllib.parse.urlencode(
        {"filename": fileName, "subfolder": subFolder, "type": folder_type}
    )
    url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(url) as resp:
        payload = resp.read()
    return payload
def get_history(prompt_id):
    """Fetch ``/history/<prompt_id>`` from ComfyUI and return the decoded JSON.

    Reads the module global ``server_address``.
    """
    url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(url) as resp:
        body = resp.read()
    return json.loads(body)
def get_images(ws, prompt):
    """Queue *prompt*, block on *ws* until it finishes, then download its outputs.

    Args:
        ws: A connected ``websocket.WebSocket``. It should be opened with the
            same ``client_id`` that ``queue_prompt`` sends, so this client
            receives the status messages for the queued prompt.
        prompt: API-format workflow dict.

    Returns:
        dict mapping each output node id that reports ``images`` and/or
        ``videos`` to a list of raw file bytes (images first, then videos).
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print('prompt: {}'.format(prompt))
    print('prompt_id:{}'.format(prompt_id))
    output_images = {}
    while True:
        out = ws.recv()
        # Non-text frames (presumably binary previews) carry no status; skip them.
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message.get('type') != 'executing':
            continue
        data = message.get('data') or {}
        # An "executing" message with node == None for our prompt marks completion.
        if data.get('node') is None and data.get('prompt_id') == prompt_id:
            print('执行完成')
            break
    history = get_history(prompt_id)[prompt_id]
    print(history)
    # Visit each output node exactly once so every file is downloaded a single time.
    for node_id, node_output in history['outputs'].items():
        keys = [key for key in ('images', 'videos') if key in node_output]
        if not keys:
            continue
        # Images and videos of the same node are kept together rather than overwriting each other.
        output_images[node_id] = [
            get_image(item['filename'], item['subfolder'], item['type'])
            for key in keys
            for item in node_output[key]
        ]
    print('获取图片完成:{}'.format(output_images))
    return output_images
def parse_worflow(ws, prompt, seed, workflowfile):
    """Load an API-format workflow file, inject the prompt and seed, and run it.

    Args:
        ws: Connected WebSocket, passed through to ``get_images``.
        prompt: Positive text prompt. It is written to node ``"6"``, which
            is assumed to be the positive ``CLIPTextEncode`` node of the
            exported workflow (TODO confirm for other workflows).
        seed: Written to every node input named ``seed`` or ``noise_seed``
            that holds a literal value. Linked inputs (``[node_id, slot]``)
            are left untouched. ``None`` keeps the seeds stored in the file.
        workflowfile: Path to the workflow JSON exported with "Save (API Format)".

    Returns:
        Whatever ``get_images`` returns for the queued workflow.
    """
    print('workflowfile:{}'.format(workflowfile))
    with open(workflowfile, 'r', encoding="utf-8") as workflow_file:
        prompt_data = json.load(workflow_file)
    # Set the text prompt.
    prompt_data["6"]["inputs"]["text"] = prompt
    # Previously the seed argument was ignored; apply it to sampler seed inputs.
    if seed is not None:
        for node in prompt_data.values():
            inputs = node.get("inputs", {}) if isinstance(node, dict) else {}
            for key in ("seed", "noise_seed"):
                if key in inputs and not isinstance(inputs[key], list):
                    inputs[key] = seed
    return get_images(ws, prompt_data)
def generate_clip(prompt, seed, workflowfile, idx, output_dir='/mnt/d/aigc_result'):
    """Run a workflow through the ComfyUI WebSocket API and save every returned file.

    Reads the module globals ``server_address`` and ``client_id``.

    Args:
        prompt: Positive text prompt forwarded to ``parse_worflow``.
        seed: Seed forwarded to ``parse_worflow``; also embedded in file names.
        workflowfile: Path to the API-format workflow JSON.
        idx: Caller-chosen index embedded in file names.
        output_dir: Existing directory the files are written to. Defaults to
            the previously hard-coded path.

    Files are named ``<idx>_<seed>_<YYYYMMDDHHMMSS>.png``. Every file after
    the first gets a ``_<n>`` suffix so outputs produced within the same
    second do not overwrite each other.
    """
    from datetime import datetime

    print('seed:' + str(seed))
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = parse_worflow(ws, prompt, seed, workflowfile)
    finally:
        # Release the socket even when queueing or downloading fails.
        ws.close()
    saved = 0
    for files in images.values():
        for image_data in files:
            # Current time formatted as YYYYMMDDHHMMSS for the file name.
            timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
            suffix = "" if saved == 0 else "_{}".format(saved)
            GIF_LOCATION = "{}/{}_{}_{}{}.png".format(output_dir, idx, seed, timestamp, suffix)
            saved += 1
            print('GIF_LOCATION:' + GIF_LOCATION)
            with open(GIF_LOCATION, "wb") as binary_file:
                # Write the raw bytes returned by the /view endpoint.
                binary_file.write(image_data)
            # NOTE(review): show_gif's returned HTML object is discarded here, so nothing is
            # rendered unless a caller displays it; kept to preserve existing behavior.
            show_gif(GIF_LOCATION)
            print("{} DONE!!!".format(GIF_LOCATION))
if __name__ == "__main__":
    # Endpoint and client identity: module globals read by the helper functions.
    COMFYUI_ENDPOINT = 'localhost:8188'
    server_address = COMFYUI_ENDPOINT
    client_id = str(uuid.uuid4())

    # Working directory and workflow path (WORKING_DIR / SageMaker_ComfyUI are
    # not read anywhere else in this script).
    WORKING_DIR = 'output'
    SageMaker_ComfyUI = WORKING_DIR
    workflowfile = '/mnt/d/code/aigc/workflow_api.json'

    # Single demo generation.
    seed, prompt = 15465856, 'Leopards hunt on the grassland'
    generate_clip(prompt, seed, workflowfile, 1)
三、生产部署
1、ComfyUI 源代码保持不变。
2、新建一个类似 main_v2.py 的文件,采用 Flask 或 FastAPI 实现服务代码(参考 ComfyUI 的 server.py),并引用 ComfyUI 模块中的方法,如:
# -*- coding: utf-8 -*-
"""
To start this server, run
$ waitress-serve --port=1230 --call agc_server:create_app
or specify the init sd model:
$ COMMANDLINE_ARGS="--ckpt models/Stable-diffusion/novelaifinal-pruned.ckpt" waitress-serve --port=1230 --call agc_server:create_app
"""
import datetime
import logging
import os
from flask import Flask, request,jsonify
from flask_cors import CORS
import json
# from sdutils import parse_image_data, encode_pil_to_base64
import cv2 as cv
from PIL import Image
import shutil
import numpy as np
from process_base import *
# Flask application object. It is served by app.run() in __main__, or by
# waitress through create_app() (see the module docstring).
app = Flask(__name__)
# Enable CORS on all routes, with credentials allowed (flask-cors option).
CORS(app, supports_credentials=True)
# Logging setup, currently disabled. NOTE: without basicConfig the root logger
# keeps Python's default WARNING level, so the logging.info calls in this file
# produce no output.
# print(f"aisticker_server: {datetime.datetime.now()}")
# logging.basicConfig(
# filename=None,
# level=logging.INFO,
# format='%(asctime)s.%(msecs)03d:%(levelname)s:%(message)s',
# datefmt = '%m/%d/%Y %H:%M:%S')
# logging.info('aisticker_server level: info')
# logging.debug('aisticker_server level: debug')
def update_params_lora(params, loras):
    """Splice a chain of ``LoraLoader`` nodes between the checkpoint loader and its consumers.

    Each LoRA node takes ``model``/``clip`` from the previous node in the
    chain, starting at the ``CheckpointLoaderSimple`` node. Every
    ``CLIPTextEncode`` node's ``clip`` input is then re-pointed to the last
    LoRA node. Every ``KSampler`` node's ``model`` input is re-pointed the
    same way.

    Args:
        params: ``{"prompt": <API-format workflow dict>}``. It is modified in
            place and also returned.
        loras: Iterable of ``LoraLoader`` input dicts (e.g. lora name and
            strengths; schema set by the caller). Empty or ``None`` leaves
            *params* unchanged. The dicts themselves are not modified.

    Returns:
        *params*.
    """
    if not loras:
        return params
    prompt = params["prompt"]
    start_index = 50000
    # get_prompt_item / get_prompt_items come from process_base (not visible here);
    # presumably they look nodes up by class_type and return (id(s), node(s)).
    start_node, _ = get_prompt_item(prompt, "CheckpointLoaderSimple")
    for lora in loras:
        # Never reuse an id already present in the workflow, so no node is overwritten.
        while str(start_index) in prompt:
            start_index += 1
        inputs = dict(lora)
        inputs["model"] = [start_node, 0]
        inputs["clip"] = [start_node, 1]
        prompt[str(start_index)] = {"inputs": inputs, "class_type": "LoraLoader"}
        start_node = str(start_index)
        # Was `start_index = 1`, which gave the 2nd LoRA id "1" and could clobber node "1".
        start_index += 1
    next_node, _ = get_prompt_items(prompt, "CLIPTextEncode")
    for key in next_node:
        prompt[key]["inputs"]["clip"][0] = start_node
    next_node, _ = get_prompt_items(prompt, "KSampler")
    for key in next_node:
        prompt[key]["inputs"]["model"][0] = start_node
    return params
class AISticker:
    """In-process ComfyUI runner: a server/queue pair plus a PromptExecutor.

    ``startup``, ``check_prompt_seed``, ``demo_prompt_process``,
    ``demo_prompt_worker`` and ``get_prompt_item_with_title`` are not defined
    in this file. They presumably come from ``process_base`` via the star
    import; ``execution`` likewise (TODO confirm).
    """

    def __init__(self):
        self.server, self.q = startup()
        self.e = execution.PromptExecutor(self.server)

    def forward(self, params):
        """Execute *params* and return (main output, background-removed output).

        Args:
            params: ``{"prompt": <API-format workflow dict>}``. The workflow
                must contain nodes titled ``output_images`` and
                ``remove_bg_images``.

        Returns:
            A pair taken from the worker result of the nodes feeding those two
            titled nodes' ``images`` inputs (element ``[0][0]`` of each).
        """
        params = check_prompt_seed(params)
        demo_prompt_process(self.server, params)
        node_results = demo_prompt_worker(self.q, self.server, self.e)

        graph = params["prompt"]
        source_ids = []
        for title in ("output_images", "remove_bg_images"):
            _, titled_node = get_prompt_item_with_title(graph, title)
            # Follow the node's "images" link to the id of the producing node.
            source_ids.append(titled_node["inputs"]["images"][0])
        main_id, rm_bg_id = source_ids
        return node_results[main_id][0][0], node_results[rm_bg_id][0][0]
# def base64_to_str(img_list):
# for i in range(len(img_list)):
# img_list[i] = str(img_list[i], "utf-8")
# return img_list
# Module-level singleton. AISticker() calls startup() and builds the
# PromptExecutor once at import time, and every request handled by this
# process reuses this same instance.
sticker = AISticker()
@app.route('/sdapi/v1/aisticker', methods=['POST'])
def algo_aisticker():
    """Generate a sticker from a workflow posted as JSON.

    Expected body (as parsed below)::

        {"params": [<API-format workflow dict>, {"sd_loras": [<LoraLoader inputs>, ...]}]}

    ``sd_loras`` may be omitted or empty.

    Returns:
        ``({"errno": 0, "outputs": <save_image result>}, 200)`` on success.
        ``({"errno": 1, "outputs": "", "msg": ...}, 400)`` for a malformed body.
    """
    params = request.get_json(silent=True)
    print(params)
    out_params = {"errno": 0, "outputs": ""}
    json_params = params.get('params', []) if isinstance(params, dict) else []
    # Reject malformed bodies with 400 instead of an unhandled AttributeError/IndexError (500).
    if (not isinstance(json_params, list) or len(json_params) < 2
            or not isinstance(json_params[0], dict)
            or not isinstance(json_params[1], dict)):
        out_params["errno"] = 1
        out_params["msg"] = 'expected {"params": [workflow, {"sd_loras": [...]}]}'
        return jsonify(out_params), 400
    prompt_params = {"prompt": json_params[0]}
    lora_params = json_params[1]
    logging.info(params)
    params = update_params_lora(prompt_params, lora_params.get("sd_loras", []))
    logging.info(params)
    # NOTE(review): `sticker` is one shared instance; whether concurrent requests
    # (e.g. multiple waitress threads) may call forward() simultaneously is not visible here.
    outputs, remove_bg_outputs = sticker.forward(params)
    # Lazy %-args: the old call used remove_bg_outputs as a format argument for outputs.
    logging.info("outputs=%s remove_bg_outputs=%s", outputs, remove_bg_outputs)
    # save_image(outputs)
    result = save_image(remove_bg_outputs)
    out_params["outputs"] = result
    return jsonify(out_params), 200
# to_base64(output)
def _parse_command_line():
from argparse import ArgumentParser, RawDescriptionHelpFormatter
parser = ArgumentParser(epilog="""
测试novelai
==============
""",
formatter_class=RawDescriptionHelpFormatter)
parser.add_argument("-p", "--port", default=1230, type=int, help="Specify the port")
parser.add_argument("--no-half-vae", dest="no_half_vae", action="store_true", default=False)
parser.add_argument("--ckpt", type=str)
parser.add_argument('--xformers', dest='xformers', action='store_true', default=False)
parser.add_argument('--lora-dir', type=str)
parser.add_argument('--package-version', type=int, default=0, help="{0, 1, 2} 0 for official server, 1 for debug server, 2 for webui")
parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--raise-all", dest="raise_all", action='store_true', default=False)
return parser.parse_args()
def create_app():
    """Application factory for ``waitress-serve --call <module>:create_app``.

    Returns the module-level Flask ``app`` object unchanged.
    """
    flask_app = app
    return flask_app
if __name__ == '__main__':
    # Parse CLI flags; only --port is used to start Flask's built-in server.
    cli_args = _parse_command_line()
    print(cli_args)
    app.run(port=cli_args.port)