如何在Stable Diffusion上Fine Tuning出自己风格的模型

2022-10-24 10:26:45 浏览数 (1)

Stable Diffusion在很多事情上都很出色,但并不是在所有事情上都很棒,并且以特定的样式或外观获得结果通常涉及大量工作“即时工程”。那么,如果您想要生成特定类型的图像,除了花很长时间制作复杂的文本提示(prompt)之外,还有另一种方法是微调(Fine Tuning)图像生成模型本身。

Fine Tuning是一种常见的做法,即把一个已经在广泛而多样的数据集上预训练过的模型,再在你特别感兴趣的数据集上再训练一下。这是深度学习的常见做法,比如在自然语言处理(NLP)的BERT模型上微调实际上就是生成专业模型的主要方式,而在图像处理领域,已被证明是从标准图像分类网络到 GAN 的各种模型都非常有效。在此示例中,我们将展示如何在 宝可梦 数据集上微调 Stable Diffusion 以创建对应的txt2img模型,该模型根据任何文本提示制作自定义 宝可梦。

以下是经过训练的模型可以产生的输出类型的一些示例,以及使用的提示(prompt):

Girl with a pearl earring, Cute Obama creature, Donald Trump, Boris Johnson, Totoro, Hello Kitty

如果您只是关注模型、代码或数据集,请参阅:

  • Lambda Diffusers
  • Captioned Pokémon dataset
  • Model weights in Diffusers format
  • Original model weights
  • Training code

如果您只想生成一些神奇宝贝,请使用此notebook或在Replicate上试用。

硬件

按照今天的标准,运行Stable Diffusion本身的要求并不高,微调模型不需要像预训练那样投入大量的计算资源。对于这个示例,我在Lambda GPU Cloud上使用 2xA6000 GPU,并运行大约 15,000 步的训练,运行大约需要 6 个小时,成本约为 10 美元。训练应该能够在单个或更低规格的 GPU 上运行(只要 VRAM 大于 24GB),但您可能需要调整批量大小和梯度累积步骤以适合您的 GPU。有关如何调整这些参数的更多详细信息,请参阅微调笔记本。

数据!

首先,我们需要一个数据集来训练。Stable Diffusion训练需要每个图像都带有对应的文本标题。如果我们为我们的数据集选择具有统一主题和风格的内容,事情会变得简单。在此,我将使用来自 FastGAN 的 宝可梦 数据集,因为它大小合适(几乎一千张图像),高分辨率,并且有非常一致的风格,但是有一个问题是,图像上没有任何的文字描述(文本标题)!

我们将使用神经网络来为我们完成艰苦的工作,而不是自己费力地为每个图片进行标注。这里用到的是一个名为BLIP的图像标注模型。模型的标注并不完美,但它们相当准确且足以满足我们的目的。

我们已将带标注 宝可梦 数据集上传到 Huggingface 以使其易于重用:lambdalabs/pokemon-blip-captions。

代码语言:javascript复制
from datasets import load_dataset
ds = load_dataset("lambdalabs/pokemon-blip-captions", split="train")
sample = ds[0]
display(sample["image"].resize((256, 256)))
print(sample["text"])
a drawing of a green pokemon with red eyesa drawing of a green pokemon with red eyes

做好准备

现在我们有一个数据集,我们需要原始模型的Stable Diffusion模型,可在此处下载,(名称为:sd-v1-4-full-ema.ckpt)接下来我们需要设置训练的代码和环境。我们将使用原始训练代码的一个分支,该分支已经过修改以使其能更友好地进行微调:justinpinkney/stable-diffusion。

Stable Diffusion 使用基于 yaml 的配置文件以及传递给main.py函数的一些额外命令行参数来启动训练。

我们创建了一个运行此微调示例的基本 yaml 配置文件。如果你想在自己的数据集上运行它应该很容易修改,你需要编辑的主要部分是数据配置,这是自定义 yaml 文件的相关摘录:

代码语言:yaml复制
data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 4
    num_workers: 4
    num_val_workers: 0 # Avoid a weird val dataloader issue
    train:
      target: ldm.data.simple.hf_dataset
      params:
        name: lambdalabs/pokemon-blip-captions
        image_transforms:
        - target: torchvision.transforms.Resize
          params:
            size: 512
            interpolation: 3
        - target: torchvision.transforms.RandomCrop
          params:
            size: 512
        - target: torchvision.transforms.RandomHorizontalFlip
    validation:
      target: ldm.data.simple.TextOnly
      params:
        captions:
        - "A pokemon with green eyes, large wings, and a hat"
        - "A cute bunny rabbit"
        - "Yoda"
        - "An epic landscape photo of a mountain"
        output_size: 512
        n_gpus: 2 # small hack to make sure we see all our samples

这部分配置基本上做了以下事情,它使用该ldm.data.simple.hf_dataset函数创建一个数据集,用于对Huggingface Hub 上名为lambdalabs/pokemon-blip-cpations的dataset进行训练,但也可以是格式正确的本地数据目录。对于validation,我们不使用“真实”数据集,而仅使用一些文本提示来评估我们的模型表现如何以及何时停止训练,我们希望训练足够多以获得良好的输出,但我们不想要它忘记原始模型中的所有“常识”。

Train

设置好配置文件后,您就可以通过运行main.py带有一些额外参数的脚本来进行训练了:

  • -t- 进行训练
  • --base configs/stable-diffusion/pokemon.yaml- 使用我们的自定义配置
  • --gpus 0,1- 使用这些 GPU
  • --scale_lr False- 按原样使用配置中的学习率
  • --num_nodes 1- 在单台机器上运行(可能有多个 GPU)
  • --check_val_every_n_epoch 10- 不要太频繁地检查验证样本
  • --finetune_from models/ldm/stable-diffusion-v1/sd-v1-4-full-ema.ckpt- 从原始Stable Diffusion上进行微调
代码语言:javascript复制
python main.py 
    -t 
    --base configs/stable-diffusion/pokemon.yaml 
    --gpus 0,1 
    --scale_lr False 
    --num_nodes 1 
    --check_val_every_n_epoch 10 
    --finetune_from sd-v1-4-full-ema.ckpt

结果

在训练过程中,结果应该被记录到日志文件夹中,你应该看到每隔一段时间从训练数据集中抽取的样本,所有的验证样本都应该被运行。开始的时候,样本看起来像正常的图像,然后开始有口袋妖怪的风格,随着训练的继续,最终与原始提示相背离。

原始风格渐进到宝可梦风格原始风格渐进到宝可梦风格

如果我们想使用该模型,我们可以像使用其他模型一般,例如使用txt2img.py脚本,只需将我们传递的检查点修改为我们的微调版本而不是原始版本:

代码语言:javascript复制
python scripts/txt2img.py 
    --prompt 'robotic cat with wings' 
    --outdir 'outputs/generated_pokemon' 
    --H 512 --W 512 
    --n_samples 4 
    --config 'configs/stable-diffusion/pokemon.yaml' 
    --ckpt 'logs/2022-09-02T06-46-25_pokemon_pokemon/checkpoints/epoch=000142.ckpt'
代码语言:javascript复制
from PIL import Image
im = Image.open("outputs/generated_pokemon/grid-0000.png").resize((1024, 256))
display(im)
print("robotic cat with wings")
“robotic cat with wings”的输出“robotic cat with wings”的输出

该模型应该与为Stable Diffusion开发的任何现有存储库或用户界面兼容,并且还可以使用简单的脚本移植到 Huggingface Diffusers 库。

如果您只想快速的了解,并nodebook中从头到尾运行此示例,请查看此处。

插入您自己的数据

如果您想使用自己的数据进行训练,那么最简单的方法是以正确的方式将其格式化为huggingface上的数据集,如果您的数据集返回imagetext列,那么您可以重新使用本文中的配置,只需将数据集名称更改为您自己的数据集地址即可.

结论

现在您知道如何在自己的数据集上训练自己的Stable Diffusion模型了!

0 人点赞