快速开始

备注

阅读本篇前,请确保已按照 安装教程 准备好昇腾环境及 Diffusers !

本示例以文生图 Diffusers 库中文生图任务为样例,展示如何进行文生图模型 stable-diffusion-xl-base-1.0 的基于 LoRA 的微调及动态合并 LoRA 的推理。

文生图

模型及数据集下载

  1. 请提前下载 stabilityai/stable-diffusion-xl-base-1.0 模型至自定义路径

  2. 请提前下载 madebyollin/sdxl-vae-fp16-fix 模型至自定义路径

  3. 请提前下载 reach-vb/pokemon-blip-captions 数据集至自定义路径

基于 LoRA 的微调

进入 Diffusers 项目目录,新建并执行以下脚本:

备注

请根据 模型及数据集下载 中模型及数据集的实际缓存路径指定 stable-diffusion-xl-base-1.0 模型缓存路径 MODEL_NAME,sdxl-vae-fp16-fix 模型缓存路径 VAE_NAME 和。

 1export MODEL_NAME="./models_ckpt/stable-diffusion-xl-base-1.0/"
 2export VAE_NAME="./ckpt/sdxl-vae-fp16-fix"
 3export TRAIN_DIR="~/diffusers/data/pokemon-blip-captions/pokemon"
 4
 5python3  ./examples/text_to_image/train_text_to_image_lora_sdxl.py \
 6    --pretrained_model_name_or_path=$MODEL_NAME \
 7    --pretrained_vae_model_name_or_path=$VAE_NAME \
 8    --dataset_name=$DATASET_NAME --caption_column="text" \
 9    --resolution=1024 \
10    --random_flip \
11    --train_batch_size=1 \
12    --num_train_epochs=2 \
13    --checkpointing_steps=500 \
14    --learning_rate=1e-04 \
15    --lr_scheduler="constant" \
16    --lr_warmup_steps=0 \
17    --mixed_precision="no" \
18    --seed=42 \
19    --output_dir="sd-pokemon-model-lora-sdxl" \
20    --validation_prompt="cute dragon creature"

微调过程无报错,并且终端显示 Steps: 100% 的进度条说明微调成功。

动态合并 LoRA 的推理

备注

请根据 模型及数据集下载 中模型实际缓存路径指定 model_path

根据 基于 LoRA 的微调 中指定的 LoRA 模型路径 output_dir 指定 lora_model_path

[可选] 修改 prompt 可使得生成图像改变

 1from diffusers import DiffusionPipeline
 2import torch
 3
 4lora_model_path = "path/to/sd-pokemon-model-lora-sdxl/checkpoint-800/"
 5model_path = "./models_ckpt/stable-diffusion-xl-base-1.0/"
 6pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
 7
 8# 将模型放到 NPU 上
 9pipe.to("npu")
10
11# 加载 LoRA 权重
12pipe.load_lora_weights(lora_model_path)
13# 输入 prompt
14prompt = "Sylveon Pokemon with elegant features, magical design, \
15        light purple aura, extremely detailed and intricate markings, \
16        photo realistic, unreal engine, octane render"
17# 推理
18image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
19
20image.save("pokemon-finetuned-inference-generation.png")

微调过程无报错,并且终端显示 Loading pipeline components...: 100% 的进度条说明微调成功。 查看当前目录下保存的 pokemon-finetuned-inference-generation.png 图像,可根据 prompt 生成内容相关的图像说明推理成功。