API调用ComfyUI模板高效文生图

2024-08-10 15:27:11 浏览数 (1)

一、基础环境

环境搭建参考《ComfyUI搭建文生图》,并在 ComfyUI 设置中开启 Dev Mode(开发者模式),以便通过“Save (API Format)”导出下文脚本所需的 API 格式工作流文件(workflow_api.json)。

ComfyUI API

二、本地化运行脚本编写

代码语言:python复制
# -- utf-8 ---
# https://www.bilibili.com/read/cv33202530/
# https://www.wehelpwin.com/article/5317
import json
import websocket
import uuid
import urllib.request
import urllib.parse
import random


# 显示图片
def show_gif(fname):
    """Return an IPython ``HTML`` object that renders the image file *fname* inline.

    The file is embedded as a base64 ``data:`` URI.  The MIME type is guessed
    from the file extension (generate_clip writes ``.png`` files), falling back
    to ``image/gif`` -- the previously hard-coded value -- when the extension
    is unknown or not an image type.

    Args:
        fname: path of the image file to embed.

    Returns:
        ``IPython.display.HTML`` wrapping an ``<img>`` tag; it only renders
        when displayed by IPython/Jupyter.
    """
    import base64
    import mimetypes
    from IPython import display

    mime, _ = mimetypes.guess_type(fname)
    if mime is None or not mime.startswith('image/'):
        mime = 'image/gif'
    with open(fname, 'rb') as fd:
        b64 = base64.b64encode(fd.read()).decode('ascii')
    return display.HTML(f'<img src="data:{mime};base64,{b64}" />')

# 向服务器队列发送提示词
def queue_prompt(textPrompt):
    """Submit a workflow to the ComfyUI server's ``/prompt`` endpoint.

    Args:
        textPrompt: the whole API-format workflow dict (despite the name, this
            is not just the prompt text -- get_images passes the full workflow).

    Returns:
        The decoded JSON reply; get_images reads its ``prompt_id`` key.

    Uses the module-level ``server_address`` and ``client_id``; the same
    ``client_id`` is used for the websocket connection in generate_clip.
    """
    payload = json.dumps({"prompt": textPrompt, "client_id": client_id}).encode('utf-8')
    req = urllib.request.Request("http://{}/prompt".format(server_address), data=payload)
    # Close the HTTP response explicitly instead of leaking the connection.
    with urllib.request.urlopen(req) as response:
        return json.loads(response.read())
    
# 获取生成图片
def get_image(fileName, subFolder, folder_type):
    """Download one output file from the server's ``/view`` endpoint.

    Used for both image and video outputs by get_images.  Reads the
    module-level ``server_address``.

    Args:
        fileName: ``filename`` field from the history entry.
        subFolder: ``subfolder`` field from the history entry.
        folder_type: ``type`` field from the history entry.

    Returns:
        The raw response body as bytes.
    """
    query = urllib.parse.urlencode(
        {"filename": fileName, "subfolder": subFolder, "type": folder_type}
    )
    view_url = f"http://{server_address}/view?{query}"
    with urllib.request.urlopen(view_url) as resp:
        payload = resp.read()
    return payload

# 获取历史记录
def get_history(prompt_id):
    """Fetch ``/history/<prompt_id>`` from the server and return the decoded JSON.

    get_images indexes the result by *prompt_id* to reach that run's record.
    """
    history_url = f"http://{server_address}/history/{prompt_id}"
    with urllib.request.urlopen(history_url) as resp:
        raw = resp.read()
    return json.loads(raw)

# 获取图片,监听WebSocket消息
def _wait_for_prompt(ws, prompt_id):
    """Block on *ws* until an 'executing' message with node None arrives for *prompt_id*."""
    while True:
        out = ws.recv()
        # Non-text frames (presumably binary previews) are skipped.
        if not isinstance(out, str):
            continue
        message = json.loads(out)
        if message['type'] != 'executing':
            continue
        data = message['data']
        # node None for our prompt id is treated as "run finished".
        if data['node'] is None and data['prompt_id'] == prompt_id:
            print('执行完成')
            return


def get_images(ws, prompt):
    """Queue *prompt*, wait for it to finish, then download every output file.

    Args:
        ws: websocket already connected with this script's ``client_id``.
        prompt: API-format workflow dict to submit.

    Returns:
        ``{node_id: [bytes, ...]}`` for every output node that listed
        ``images`` and/or ``videos`` in its history entry.  Files from both
        keys are collected (images first), so neither overwrites the other.
        Nodes that produced no files are omitted.
    """
    prompt_id = queue_prompt(prompt)['prompt_id']
    print('prompt: {}'.format(prompt))
    print('prompt_id:{}'.format(prompt_id))
    _wait_for_prompt(ws, prompt_id)

    history = get_history(prompt_id)[prompt_id]
    print(history)
    output_images = {}
    # Single pass over the output nodes: each file is downloaded exactly once.
    for node_id, node_output in history['outputs'].items():
        files = []
        for key in ('images', 'videos'):  # image branch, then video branch
            for item in node_output.get(key, []):
                files.append(get_image(item['filename'], item['subfolder'], item['type']))
        if files:
            output_images[node_id] = files
    print('获取图片完成:{}'.format(output_images))
    return output_images

# 解析comfyUI 工作流并获取图片
def parse_worflow(ws, prompt, seed, workflowfile):
    """Load an API-format workflow, inject prompt text and seed, and run it.

    Args:
        ws: connected websocket, passed through to get_images.
        prompt: text written into node "6"'s ``text`` input (the node id is
            hard-coded; presumably the positive CLIPTextEncode of the exported
            workflow -- confirm against your export).
        seed: value written into every node input named ``seed`` or
            ``noise_seed`` that holds a literal (not a link); ``None`` keeps
            the seeds stored in the file.
        workflowfile: path to the API-format workflow JSON.

    Returns:
        The ``{node_id: [bytes, ...]}`` mapping returned by get_images.
    """
    print('workflowfile:{}'.format(workflowfile))
    # Read the file and close it before the (long) network run starts.
    with open(workflowfile, 'r', encoding="utf-8") as workflow_api_txt2gif_file:
        prompt_data = json.load(workflow_api_txt2gif_file)

    # Set the text prompt.
    prompt_data["6"]["inputs"]["text"] = prompt

    # Apply the caller's seed so the run matches the seed that generate_clip
    # puts into the output file name.
    if seed is not None:
        for node in prompt_data.values():
            inputs = node.get("inputs", {}) if isinstance(node, dict) else {}
            for key in ("seed", "noise_seed"):
                # A list value is a link [node_id, output_index]; leave links alone.
                if key in inputs and not isinstance(inputs[key], list):
                    inputs[key] = seed
    return get_images(ws, prompt_data)

# 生成图像并显示
def generate_clip(prompt, seed, workflowfile, idx, output_dir='/mnt/d/aigc_result'):
    """Run the workflow for *prompt* and write every returned file to disk.

    Opens a websocket to ``server_address`` with ``client_id`` (module-level
    names), runs the workflow through parse_worflow, then writes each result
    to ``<output_dir>/<idx>_<seed>_<YYYYMMDDHHMMSS>.png``.  When one run
    returns several files, the 2nd, 3rd, ... get an extra ``_1``, ``_2``, ...
    suffix so files saved within the same second do not overwrite each other.

    Args:
        prompt: prompt text injected into the workflow.
        seed: seed passed to parse_worflow and embedded in the file names.
        workflowfile: path to the API-format workflow JSON.
        idx: caller-chosen prefix for the file names.
        output_dir: target directory (created if missing); defaults to the
            previously hard-coded path.
    """
    import os
    from datetime import datetime

    print('seed:' + str(seed))
    ws = websocket.WebSocket()
    ws.connect("ws://{}/ws?clientId={}".format(server_address, client_id))
    try:
        images = parse_worflow(ws, prompt, seed, workflowfile)
    finally:
        # Release the websocket even if the run fails.
        ws.close()

    os.makedirs(output_dir, exist_ok=True)
    # One timestamp (YYYYMMDDHHMMSS) per run; the counter keeps names unique.
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    count = 0
    for files in images.values():
        for image_data in files:
            suffix = '' if count == 0 else '_{}'.format(count)
            count += 1
            gif_location = "{}/{}_{}_{}{}.png".format(output_dir, idx, seed, timestamp, suffix)
            print('GIF_LOCATION:' + gif_location)
            with open(gif_location, "wb") as binary_file:
                binary_file.write(image_data)
            # Read the file back only after the with-block has flushed and closed it.
            show_gif(gif_location)
            print("{} DONE!!!".format(gif_location))
                

if __name__ == "__main__":
    # ComfyUI server as host:port (no scheme); read as a module global by the
    # HTTP/websocket helpers above.
    COMFYUI_ENDPOINT = 'localhost:8188'
    server_address = COMFYUI_ENDPOINT
    # Fresh client id per run; also read as a module global.
    client_id = str(uuid.uuid4())

    # Working-directory settings; nothing in this script reads them.
    WORKING_DIR = 'output'
    SageMaker_ComfyUI = WORKING_DIR

    # Workflow file and generation inputs for this example run.
    workflowfile = '/mnt/d/code/aigc/workflow_api.json'
    seed, prompt = 15465856, 'Leopards hunt on the grassland'
    generate_clip(prompt, seed, workflowfile, 1)

三、生产部署

代码语言:python代码运行次数:0复制
1、ComfyUI 源代码保持不变;
2、新建一个类似 main_v2.py 的文件,采用 Flask 或 FastAPI 实现服务代码(参考 ComfyUI 的 server.py),并直接引用 ComfyUI 模块中的方法,例如:

# -*- coding: utf-8 -*-
"""
To start this server, run

$ waitress-serve --port=1230 --call agc_server:create_app

or specify the init sd model:

$ COMMANDLINE_ARGS="--ckpt models/Stable-diffusion/novelaifinal-pruned.ckpt" waitress-serve --port=1230 --call agc_server:create_app
"""
import datetime
import logging
import os
from flask import Flask, request,jsonify
from flask_cors import CORS
import json
# from sdutils import parse_image_data, encode_pil_to_base64
import cv2 as cv
from PIL import Image
import shutil 
import numpy as np
from process_base import *


# Flask application object: the route below is registered on it and
# create_app() returns it for WSGI servers such as waitress.
app = Flask(__name__)
# Enable cross-origin requests (with credentials) for the app via flask_cors.
CORS(app, supports_credentials=True)

# print(f"aisticker_server: {datetime.datetime.now()}")
# logging.basicConfig(
#     filename=None,
#     level=logging.INFO,
#     format='%(asctime)s.%(msecs)03d:%(levelname)s:%(message)s',
#     datefmt = '%m/%d/%Y %H:%M:%S')
# logging.info('aisticker_server level: info')
# logging.debug('aisticker_server level: debug')


def update_params_lora(params, loras):
    """Chain LoraLoader nodes between the checkpoint loader and its consumers.

    Each entry of *loras* becomes the ``inputs`` of a new ``LoraLoader`` node
    added to ``params["prompt"]`` under ids "50000", "50001", ... .  The first
    LoRA takes MODEL/CLIP from the ``CheckpointLoaderSimple`` node, each later
    one from the previous LoRA.  Finally every ``CLIPTextEncode`` node's
    ``clip`` input and every ``KSampler`` node's ``model`` input are rewired
    to the last LoRA.

    Args:
        params: dict whose ``"prompt"`` key holds an API-format workflow
            (node id -> {"inputs": ..., "class_type": ...}); modified in place.
        loras: list of dicts holding the remaining LoraLoader inputs
            (presumably ``lora_name`` / ``strength_model`` / ``strength_clip``
            -- confirm against the client).  Empty or ``None`` means no change.
            The dicts themselves are not modified.

    Returns:
        The same *params* object.
    """
    if not loras:
        return params

    prompt = params["prompt"]
    start_index = 50000  # high ids to stay clear of the workflow's own node ids
    start_node, _ = get_prompt_item(prompt, "CheckpointLoaderSimple")

    for offset, lora in enumerate(loras):
        # Consecutive ids; previously `start_index = 1` sent the 2nd LoRA to
        # node "1", overwriting an existing workflow node.
        node_id = str(start_index + offset)
        inputs = dict(lora)  # copy: do not mutate the caller's dict
        inputs["model"] = [start_node, 0]
        inputs["clip"] = [start_node, 1]
        prompt[node_id] = {"inputs": inputs, "class_type": "LoraLoader"}
        start_node = node_id

    # LoraLoader outputs are (MODEL, CLIP): wire explicit output indices rather
    # than keeping whatever index the old source used.
    text_nodes, _ = get_prompt_items(prompt, "CLIPTextEncode")
    for key in text_nodes:
        prompt[key]["inputs"]["clip"] = [start_node, 1]

    sampler_nodes, _ = get_prompt_items(prompt, "KSampler")
    for key in sampler_nodes:
        prompt[key]["inputs"]["model"] = [start_node, 0]

    return params


class AISticker:
    """Run a ComfyUI workflow in-process and return two designated outputs.

    ``startup``, ``execution``, ``check_prompt_seed``, ``demo_prompt_process``,
    ``demo_prompt_worker`` and ``get_prompt_item_with_title`` are not defined
    in this file; presumably they come from ``from process_base import *``.
    """

    def __init__(self):
        # startup() yields a (server, queue) pair -- presumably the ComfyUI
        # prompt server and its queue; confirm in process_base.
        server, queue = startup()
        self.server = server
        self.q = queue
        self.e = execution.PromptExecutor(server)

    def forward(self, params):
        """Execute ``params["prompt"]`` and return ``(image, removed_bg_image)``.

        The two results are looked up via the nodes titled "output_images" and
        "remove_bg_images": the node id feeding each one's ``images`` input
        selects the entry in the worker's result, of which ``[0][0]`` is
        returned.
        """
        params = check_prompt_seed(params)
        demo_prompt_process(self.server, params)
        worker_result = demo_prompt_worker(self.q, self.server, self.e)

        workflow = params["prompt"]
        _, image_node = get_prompt_item_with_title(workflow, "output_images")
        image_source = image_node["inputs"]["images"][0]
        _, no_bg_node = get_prompt_item_with_title(workflow, "remove_bg_images")
        no_bg_source = no_bg_node["inputs"]["images"][0]

        image = worker_result[image_source][0][0]
        no_bg_image = worker_result[no_bg_source][0][0]
        return image, no_bg_image


# def base64_to_str(img_list):
#     for i in range(len(img_list)):
#         img_list[i] = str(img_list[i], "utf-8")
#     return img_list


# Created once at import time, so importing this module runs startup() and
# builds the PromptExecutor.  Every request to algo_aisticker shares this
# instance -- NOTE(review): confirm concurrent requests on one executor are
# safe, or serve single-threaded.
sticker = AISticker()


@app.route('/sdapi/v1/aisticker', methods=['POST'])
def algo_aisticker():
    """Run the sticker workflow for one request and return the saved result.

    Expected JSON body (as parsed below)::

        {"params": [<API-format workflow dict>, {"sd_loras": [<LoraLoader inputs>, ...]}]}

    The second element and its ``sd_loras`` key are optional (no LoRAs).

    Returns:
        ``({"errno": 0, "outputs": <save_image result>}, 200)`` on success;
        ``({"errno": 1, "outputs": "", "msg": ...}, 400)`` when the body is
        not a JSON object or ``params`` holds no workflow.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising.
    params = request.get_json(silent=True)
    print(params)
    if not isinstance(params, dict):
        return jsonify({"errno": 1, "outputs": "", "msg": "request body must be a JSON object"}), 400

    json_params = params.get('params', [])
    if not isinstance(json_params, list) or not json_params:
        return jsonify({"errno": 1, "outputs": "", "msg": "'params' must be a non-empty list"}), 400

    prompt_params = {"prompt": json_params[0]}
    lora_params = json_params[1] if len(json_params) > 1 else {}
    loras = lora_params.get("sd_loras", []) if isinstance(lora_params, dict) else []
    logging.info(params)
    params = update_params_lora(prompt_params, loras)
    logging.info(params)

    outputs, remove_bg_outputs = sticker.forward(params)
    # Lazy %-style args: the first argument to logging.info is the format string.
    logging.info("outputs=%s remove_bg_outputs=%s", outputs, remove_bg_outputs)
    result = save_image(remove_bg_outputs)
    return jsonify({"errno": 0, "outputs": result}), 200


def _parse_command_line():
    from argparse import ArgumentParser, RawDescriptionHelpFormatter

    parser = ArgumentParser(epilog="""
测试novelai
==============
""",
                            formatter_class=RawDescriptionHelpFormatter)
    parser.add_argument("-p", "--port", default=1230, type=int, help="Specify the port")
    parser.add_argument("--no-half-vae", dest="no_half_vae", action="store_true", default=False)
    parser.add_argument("--ckpt", type=str)
    parser.add_argument('--xformers', dest='xformers', action='store_true', default=False)
    parser.add_argument('--lora-dir', type=str)
    parser.add_argument('--package-version', type=int, default=0, help="{0, 1, 2} 0 for official server, 1 for debug server, 2 for webui")
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--raise-all", dest="raise_all", action='store_true', default=False)
    return parser.parse_args()


def create_app():
    """Application factory used by ``waitress-serve --call ...:create_app``.

    Returns the module-level Flask ``app`` (route already registered); no new
    app is built per call.
    """
    flask_app = app
    return flask_app


if __name__ == '__main__':
    # Flask's built-in server, on the port from the command line.
    cli_args = _parse_command_line()
    print(cli_args)
    app.run(port=cli_args.port)

0 人点赞