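Quick Start

Note

Before running the examples below, make sure the PyTorch-NPU environment is installed. For installation instructions, see the Installation Guide.

In general, using an NPU for training or inference requires the following changes to your code:

  1. Import the torch_npu extension package: import torch_npu

  2. Move the model and its inputs to the NPU: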

device = torch.device("npu")
model = model.to(device)
input = input.to(device)
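As a minimal sketch of the two steps above, the helper below picks a device string and falls back to the CPU when no NPU is usable. The function name `select_device_name` and the fallback policy are illustrative assumptions, not part of torch_npu; the only library calls it relies on are `import torch_npu` and `torch.npu.is_available()`, both of which appear in the examples in this guide.

```python
import importlib.util


def select_device_name(preferred_index: int = 0) -> str:
    """Return 'npu:<index>' when an NPU is reported as available, else 'cpu'.

    Hypothetical helper for illustration only.
    """
    # Without both packages installed there is nothing to probe.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    if importlib.util.find_spec("torch_npu") is None:
        return "cpu"

    import torch
    import torch_npu  # noqa: F401  (importing it registers the "npu" backend)

    return f"npu:{preferred_index}" if torch.npu.is_available() else "cpu"


print(select_device_name())
```

The returned string can be passed to `torch.device(...)`, and the resulting device to `model.to(device)` and `input.to(device)` as shown above.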

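The following examples show how to use an NPU for training and inference tasks:

1. Single-device training

The following code trains a model on the CIFAR-10 dataset on an NPU (excerpted from the PyTorch tutorials). Pay attention to the NPU-specific lines, which are marked with comments.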

"""
Training an image classifier
----------------------------

We will do the following steps in order:

1. Load and normalize the CIFAR10 training and test datasets using
   ``torchvision``
2. Define a Convolutional Neural Network
3. Define a loss function
4. Train the network on the training data
5. Test the network on the test data

1. Load and normalize CIFAR10
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Using ``torchvision``, it's extremely easy to load CIFAR10.
"""
import torch
# Import the torch_npu package
import torch_npu

# Define the device
device = torch.device('npu:0' if torch.npu.is_available() else 'cpu')
print(device)

import torchvision
import torchvision.transforms as transforms

########################################################################
# The output of torchvision datasets are PILImage images of range [0, 1].
# We transform them to Tensors of normalized range [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 4

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

########################################################################
# 2. Define a Convolutional Neural Network
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Copy the neural network from the Neural Networks section before and modify it to
# take 3-channel images (instead of 1-channel images as it was defined).
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()

# Move the model to the NPU
net.to(device)

########################################################################
# 3. Define a Loss function and optimizer
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Let's use a Classification Cross-Entropy loss and SGD with momentum.
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

########################################################################
# 4. Train the network
# ^^^^^^^^^^^^^^^^^^^^
#
# This is when things start to get interesting.
# We simply have to loop over our data iterator, and feed the inputs to the
# network and optimize.

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        # Send the input data to the NPU
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

########################################################################
# 5. Test the network on the test data
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# We have trained the network for 2 passes over the training dataset.
# But we need to check if the network has learnt anything at all.
#
# We will check this by predicting the class label that the neural network
# outputs, and checking it against the ground-truth. If the prediction is
# correct, we add the sample to the list of correct predictions.
#
# Let us look at how the network performs on the whole dataset.
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
    for data in testloader:
        # Send the input data to the NPU
        images, labels = data[0].to(device), data[1].to(device)
        # calculate outputs by running images through the network
        outputs = net(images)
        # the class with the highest energy is what we choose as prediction
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
########################################################################
# That looks way better than chance, which is 10% accuracy (randomly picking
# a class out of 10 classes).
# Seems like the network learnt something.
#
# Hmmm, what are the classes that performed well, and the classes that did
# not perform well:

# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}

# again no gradients needed
with torch.no_grad():
    for data in testloader:
        # Send the input data to the NPU
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predictions = torch.max(outputs, 1)
        # collect the correct predictions for each class
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1


# print accuracy for each class
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
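The `nn.Linear(16 * 5 * 5, 120)` layer in `Net` above depends on the spatial size of the feature map after the two convolution and pooling stages. The sketch below reproduces that arithmetic for 32x32 CIFAR-10 images using the standard output-size formula for unpadded, undilated convolution and pooling layers. The helper name `out_size` is an illustrative assumption and not a PyTorch function.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (dilation of 1 assumed)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                   # CIFAR-10 images are 32x32
size = out_size(size, kernel=5)             # conv1: nn.Conv2d(3, 6, 5)   -> 28
size = out_size(size, kernel=2, stride=2)   # pool:  nn.MaxPool2d(2, 2)   -> 14
size = out_size(size, kernel=5)             # conv2: nn.Conv2d(6, 16, 5)  -> 10
size = out_size(size, kernel=2, stride=2)   # pool:  nn.MaxPool2d(2, 2)   -> 5

channels = 16                               # output channels of conv2
flattened = channels * size * size
print(flattened)                            # 400, i.e. 16 * 5 * 5
```

If the input resolution or any kernel size changes, the first argument of `fc1` has to change with it; recomputing it this way avoids a shape mismatch at the first forward pass.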

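2. Multi-device parallel training with DeepSpeed

The following code uses DeepSpeed to train a model on the CIFAR-10 dataset across multiple NPUs (from DeepSpeed Examples). Since DeepSpeed v0.12.6, the code needs no modification: DeepSpeed detects the NPU automatically and trains on it.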

import argparse
import os

import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer


def add_argument():
    parser = argparse.ArgumentParser(description="CIFAR")

    # For train.
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # For mixed precision training.
    parser.add_argument(
        "--dtype",
        default="fp16",
        type=str,
        choices=["bf16", "fp16", "fp32"],
        help="Datatype used for training",
    )

    # For ZeRO Optimization.
    parser.add_argument(
        "--stage",
        default=0,
        type=int,
        choices=[0, 1, 2, 3],
        help="ZeRO optimization stage",
    )

    # For MoE (Mixture of Experts).
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        default=[
            1,
        ],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    # Include DeepSpeed configuration arguments.
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    return args


def create_moe_param_groups(model):
    """Create separate parameter groups for each expert."""
    parameters = {"params": [p for p in model.parameters()], "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(parameters)


def get_ds_config(args):
    """Get the DeepSpeed configuration dictionary."""
    ds_config = {
        "train_batch_size": 16,
        "steps_per_print": 2000,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7,
            },
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000,
            },
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "bf16": {"enabled": args.dtype == "bf16"},
        "fp16": {
            "enabled": args.dtype == "fp16",
            "fp16_master_weights_and_grads": False,
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 15,
        },
        "wall_clock_breakdown": False,
        "zero_optimization": {
            "stage": args.stage,
            "allgather_partitions": True,
            "reduce_scatter": True,
            "allgather_bucket_size": 50000000,
            "reduce_bucket_size": 50000000,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "cpu_offload": False,
        },
    }
    return ds_config


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.moe = args.moe
        if self.moe:
            fc3 = nn.Linear(84, 84)
            self.moe_layer_list = []
            for n_e in args.num_experts:
                # Create moe layers based on the number of experts.
                self.moe_layer_list.append(
                    deepspeed.moe.layer.MoE(
                        hidden_size=84,
                        expert=fc3,
                        num_experts=n_e,
                        ep_size=args.ep_world_size,
                        use_residual=args.mlp_type == "residual",
                        k=args.top_k,
                        min_capacity=args.min_capacity,
                        noisy_gate_policy=args.noisy_gate_policy,
                    )
                )
            self.moe_layer_list = nn.ModuleList(self.moe_layer_list)
            self.fc4 = nn.Linear(84, 10)
        else:
            self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        if self.moe:
            for layer in self.moe_layer_list:
                x, _, _ = layer(x)
            x = self.fc4(x)
        else:
            x = self.fc3(x)
        return x


def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data.
        test_batch_size (int): the test batch size.

    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            if target_dtype is not None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs.data, 1)
            # Count the total accuracy.
            total += labels.size(0)
            correct += (predicted == labels.to(local_device)).sum().item()

            # Count the accuracy per class.
            batch_correct = (predicted == labels.to(local_device)).squeeze()
            for i in range(test_batch_size):
                label = labels[i]
                class_correct[label] += batch_correct[i].item()
                class_total[label] += 1

    if model_engine.local_rank == 0:
        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for i in range(10):
            print(
                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
            )


def main(args):
    # Initialize DeepSpeed distributed backend.
    deepspeed.init_distributed()
    _local_rank = int(os.environ.get("LOCAL_RANK"))
    get_accelerator().set_device(_local_rank)

    ########################################################################
    # Step 1. Data Preparation.
    #
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1].
    #
    # Note:
    #     If running on Windows and you get a BrokenPipeError, try setting
    #     the num_workers of torch.utils.data.DataLoader() to 0.
    ########################################################################
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    if torch.distributed.get_rank() != 0:
        # Might be downloading cifar data, let rank 0 download first.
        torch.distributed.barrier()

    # Load or download cifar data.
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )

    if torch.distributed.get_rank() == 0:
        # Cifar data is downloaded, indicate other ranks can proceed.
        torch.distributed.barrier()

    ########################################################################
    # Step 2. Define the network with DeepSpeed.
    #
    # First, we define a Convolutional Neural Network.
    # Then, we define the DeepSpeed configuration dictionary and use it to
    # initialize the DeepSpeed engine.
    ########################################################################
    net = Net(args)

    # Get list of parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, net.parameters())

    # If using MoE, create separate param groups for each expert.
    if args.moe_param_group:
        parameters = create_moe_param_groups(net)

    # Initialize DeepSpeed to use the following features.
    #   1) Distributed model.
    #   2) Distributed data loader.
    #   3) DeepSpeed optimizer.
    ds_config = get_ds_config(args)
    model_engine, optimizer, trainloader, __ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=parameters,
        training_data=trainset,
        config=ds_config,
    )

    # Get the local device name (str) and local rank (int).
    local_device = get_accelerator().device_name(model_engine.local_rank)
    local_rank = model_engine.local_rank

    # For float32, target_dtype will be None so no datatype conversion needed.
    target_dtype = None
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
    elif model_engine.fp16_enabled():
        target_dtype = torch.half

    # Define the Classification Cross-Entropy loss function.
    criterion = nn.CrossEntropyLoss()

    ########################################################################
    # Step 3. Train the network.
    #
    # This is when things start to get interesting.
    # We simply have to loop over our data iterator, and feed the inputs to the
    # network and optimize. (DeepSpeed handles the distributed details for us!)
    ########################################################################

    for epoch in range(args.epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            # Get the inputs. ``data`` is a list of [inputs, labels].
            inputs, labels = data[0].to(local_device), data[1].to(local_device)

            # Try to convert to target_dtype if needed.
            if target_dtype is not None:
                inputs = inputs.to(target_dtype)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

            # Print statistics
            running_loss += loss.item()
            if local_rank == 0 and i % args.log_interval == (
                args.log_interval - 1
            ):  # Print every log_interval mini-batches.
                print(
                    f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
                )
                running_loss = 0.0
    print("Finished Training")

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)


if __name__ == "__main__":
    args = add_argument()
    main(args)
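The `"train_batch_size": 16` entry in `get_ds_config` above is a global value. DeepSpeed's configuration documentation describes it as the product of the per-device micro-batch size, the number of gradient accumulation steps, and the number of devices. The sketch below works through that relation for a few NPU counts, assuming gradient accumulation is left at 1 because the config above does not set it. The helper name `micro_batch_per_device` is hypothetical and not a DeepSpeed API.

```python
def micro_batch_per_device(
    train_batch_size: int, world_size: int, grad_accum_steps: int = 1
) -> int:
    """Per-device micro-batch size implied by a global train_batch_size.

    Uses: train_batch_size = micro_batch * grad_accum_steps * world_size.
    """
    denominator = world_size * grad_accum_steps
    if train_batch_size % denominator != 0:
        raise ValueError(
            f"train_batch_size={train_batch_size} is not divisible by "
            f"world_size * grad_accum_steps = {denominator}"
        )
    return train_batch_size // denominator


# Global batch of 16, as in the configuration above.
for world_size in (1, 2, 4, 8):
    per_device = micro_batch_per_device(16, world_size)
    print(f"{world_size} NPU(s): {per_device} samples per device per step")
```

A device count that does not divide 16 (for example 3) makes the relation unsatisfiable, so the global batch size should be chosen with the intended number of NPUs in mind.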

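3. Fine-tuning a model with Transformers

The following code fine-tunes an LLM with Transformers (from the Transformers examples). Since Transformers version xxx and Accelerate 0.21.0, the code needs no modification: the NPU is detected automatically and used for training.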

#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import evaluate
import torch
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.48.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")
279
280    # Detecting last checkpoint.
281    last_checkpoint = None
282    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
283        last_checkpoint = get_last_checkpoint(training_args.output_dir)
284        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
285            raise ValueError(
286                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
287                "Use --overwrite_output_dir to overcome."
288            )
289        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
290            logger.info(
291                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
292                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
293            )
294
295    # Set seed before initializing model.
296    set_seed(training_args.seed)
297
298    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
299    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
300    # (the dataset will be downloaded automatically from the datasets Hub).
301    #
302    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
303    # 'text' is found. You can easily tweak this behavior (see below).
304    #
305    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
306    # download the dataset.
307    if data_args.dataset_name is not None:
308        # Downloading and loading a dataset from the hub.
309        raw_datasets = load_dataset(
310            data_args.dataset_name,
311            data_args.dataset_config_name,
312            cache_dir=model_args.cache_dir,
313            token=model_args.token,
314            streaming=data_args.streaming,
315            trust_remote_code=model_args.trust_remote_code,
316        )
317        if "validation" not in raw_datasets.keys():
318            raw_datasets["validation"] = load_dataset(
319                data_args.dataset_name,
320                data_args.dataset_config_name,
321                split=f"train[:{data_args.validation_split_percentage}%]",
322                cache_dir=model_args.cache_dir,
323                token=model_args.token,
324                streaming=data_args.streaming,
325                trust_remote_code=model_args.trust_remote_code,
326            )
327            raw_datasets["train"] = load_dataset(
328                data_args.dataset_name,
329                data_args.dataset_config_name,
330                split=f"train[{data_args.validation_split_percentage}%:]",
331                cache_dir=model_args.cache_dir,
332                token=model_args.token,
333                streaming=data_args.streaming,
334                trust_remote_code=model_args.trust_remote_code,
335            )
336    else:
337        data_files = {}
338        dataset_args = {}
339        if data_args.train_file is not None:
340            data_files["train"] = data_args.train_file
341        if data_args.validation_file is not None:
342            data_files["validation"] = data_args.validation_file
343        extension = (
344            data_args.train_file.split(".")[-1]
345            if data_args.train_file is not None
346            else data_args.validation_file.split(".")[-1]
347        )
348        if extension == "txt":
349            extension = "text"
350            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
351        raw_datasets = load_dataset(
352            extension,
353            data_files=data_files,
354            cache_dir=model_args.cache_dir,
355            token=model_args.token,
356            **dataset_args,
357        )
358        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
359        if "validation" not in raw_datasets.keys():
360            raw_datasets["validation"] = load_dataset(
361                extension,
362                data_files=data_files,
363                split=f"train[:{data_args.validation_split_percentage}%]",
364                cache_dir=model_args.cache_dir,
365                token=model_args.token,
366                **dataset_args,
367            )
368            raw_datasets["train"] = load_dataset(
369                extension,
370                data_files=data_files,
371                split=f"train[{data_args.validation_split_percentage}%:]",
372                cache_dir=model_args.cache_dir,
373                token=model_args.token,
374                **dataset_args,
375            )
376
377    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
378    # https://huggingface.co/docs/datasets/loading_datasets.
379
380    # Load pretrained model and tokenizer
381    #
382    # Distributed training:
383    # The .from_pretrained methods guarantee that only one local process can concurrently
384    # download model & vocab.
385
386    config_kwargs = {
387        "cache_dir": model_args.cache_dir,
388        "revision": model_args.model_revision,
389        "token": model_args.token,
390        "trust_remote_code": model_args.trust_remote_code,
391    }
392    if model_args.config_name:
393        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
394    elif model_args.model_name_or_path:
395        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
396    else:
397        config = CONFIG_MAPPING[model_args.model_type]()
398        logger.warning("You are instantiating a new config instance from scratch.")
399        if model_args.config_overrides is not None:
400            logger.info(f"Overriding config: {model_args.config_overrides}")
401            config.update_from_string(model_args.config_overrides)
402            logger.info(f"New config: {config}")
403
404    tokenizer_kwargs = {
405        "cache_dir": model_args.cache_dir,
406        "use_fast": model_args.use_fast_tokenizer,
407        "revision": model_args.model_revision,
408        "token": model_args.token,
409        "trust_remote_code": model_args.trust_remote_code,
410    }
411    if model_args.tokenizer_name:
412        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
413    elif model_args.model_name_or_path:
414        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
415    else:
416        raise ValueError(
417            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
418            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
419        )
420
421    if model_args.model_name_or_path:
422        torch_dtype = (
423            model_args.torch_dtype
424            if model_args.torch_dtype in ["auto", None]
425            else getattr(torch, model_args.torch_dtype)
426        )
427        model = AutoModelForCausalLM.from_pretrained(
428            model_args.model_name_or_path,
429            from_tf=bool(".ckpt" in model_args.model_name_or_path),
430            config=config,
431            cache_dir=model_args.cache_dir,
432            revision=model_args.model_revision,
433            token=model_args.token,
434            trust_remote_code=model_args.trust_remote_code,
435            torch_dtype=torch_dtype,
436            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
437        )
438    else:
439        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
440        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
441        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
442
443    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
444    # on a small vocab and want a smaller embedding size, remove this test.
445    embedding_size = model.get_input_embeddings().weight.shape[0]
446    if len(tokenizer) > embedding_size:
447        model.resize_token_embeddings(len(tokenizer))
448
449    # Preprocessing the datasets.
450    # First we tokenize all the texts.
451    if training_args.do_train:
452        column_names = list(raw_datasets["train"].features)
453    else:
454        column_names = list(raw_datasets["validation"].features)
455    text_column_name = "text" if "text" in column_names else column_names[0]
456
457    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
458    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
459
460    def tokenize_function(examples):
461        with CaptureLogger(tok_logger) as cl:
462            output = tokenizer(examples[text_column_name])
463        # clm input could be much much longer than block_size
464        if "Token indices sequence length is longer than the" in cl.out:
465            tok_logger.warning(
466                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
467                " before being passed to the model."
468            )
469        return output
470
471    with training_args.main_process_first(desc="dataset map tokenization"):
472        if not data_args.streaming:
473            tokenized_datasets = raw_datasets.map(
474                tokenize_function,
475                batched=True,
476                num_proc=data_args.preprocessing_num_workers,
477                remove_columns=column_names,
478                load_from_cache_file=not data_args.overwrite_cache,
479                desc="Running tokenizer on dataset",
480            )
481        else:
482            tokenized_datasets = raw_datasets.map(
483                tokenize_function,
484                batched=True,
485                remove_columns=column_names,
486            )
487    if hasattr(config, "max_position_embeddings"):
488        max_pos_embeddings = config.max_position_embeddings
489    else:
490        # Define a default value if the attribute is missing in the config.
491        max_pos_embeddings = 1024
492
493    if data_args.block_size is None:
494        block_size = tokenizer.model_max_length
495        if block_size > max_pos_embeddings:
496            logger.warning(
497                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
498                f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
499            )
500            if max_pos_embeddings > 0:
501                block_size = min(1024, max_pos_embeddings)
502            else:
503                block_size = 1024
504    else:
505        if data_args.block_size > tokenizer.model_max_length:
506            logger.warning(
507                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
508                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
509            )
510        block_size = min(data_args.block_size, tokenizer.model_max_length)
511
512    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
513    def group_texts(examples):
514        # Concatenate all texts.
515        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
516        total_length = len(concatenated_examples[list(examples.keys())[0]])
517        # We drop the small remainder; if total_length < block_size, we exclude this batch and return an empty dict.
518        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
519        total_length = (total_length // block_size) * block_size
520        # Split by chunks of max_len.
521        result = {
522            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
523            for k, t in concatenated_examples.items()
524        }
525        result["labels"] = result["input_ids"].copy()
526        return result
527
528    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
529    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
530    # to preprocess.
531    #
532    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
533    # https://huggingface.co/docs/datasets/process#map
534
535    with training_args.main_process_first(desc="grouping texts together"):
536        if not data_args.streaming:
537            lm_datasets = tokenized_datasets.map(
538                group_texts,
539                batched=True,
540                num_proc=data_args.preprocessing_num_workers,
541                load_from_cache_file=not data_args.overwrite_cache,
542                desc=f"Grouping texts in chunks of {block_size}",
543            )
544        else:
545            lm_datasets = tokenized_datasets.map(
546                group_texts,
547                batched=True,
548            )
549
550    if training_args.do_train:
551        if "train" not in tokenized_datasets:
552            raise ValueError("--do_train requires a train dataset")
553        train_dataset = lm_datasets["train"]
554        if data_args.max_train_samples is not None:
555            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
556            train_dataset = train_dataset.select(range(max_train_samples))
557
558    if training_args.do_eval:
559        if "validation" not in tokenized_datasets:
560            raise ValueError("--do_eval requires a validation dataset")
561        eval_dataset = lm_datasets["validation"]
562        if data_args.max_eval_samples is not None:
563            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
564            eval_dataset = eval_dataset.select(range(max_eval_samples))
565
566        def preprocess_logits_for_metrics(logits, labels):
567            if isinstance(logits, tuple):
568                # Depending on the model and config, logits may contain extra tensors,
569                # like past_key_values, but logits always come first
570                logits = logits[0]
571            return logits.argmax(dim=-1)
572
573        metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
574
575        def compute_metrics(eval_preds):
576            preds, labels = eval_preds
577            # preds have the same shape as the labels, after the argmax(-1) has been calculated
578            # by preprocess_logits_for_metrics but we need to shift the labels
579            labels = labels[:, 1:].reshape(-1)
580            preds = preds[:, :-1].reshape(-1)
581            return metric.compute(predictions=preds, references=labels)
582
583    # Initialize our Trainer
584    trainer = Trainer(
585        model=model,
586        args=training_args,
587        train_dataset=train_dataset if training_args.do_train else None,
588        eval_dataset=eval_dataset if training_args.do_eval else None,
589        processing_class=tokenizer,
590        # Data collator will default to DataCollatorWithPadding, so we change it.
591        data_collator=default_data_collator,
592        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
593        preprocess_logits_for_metrics=preprocess_logits_for_metrics
594        if training_args.do_eval and not is_torch_xla_available()
595        else None,
596    )
597
598    # Training
599    if training_args.do_train:
600        checkpoint = None
601        if training_args.resume_from_checkpoint is not None:
602            checkpoint = training_args.resume_from_checkpoint
603        elif last_checkpoint is not None:
604            checkpoint = last_checkpoint
605        train_result = trainer.train(resume_from_checkpoint=checkpoint)
606        trainer.save_model()  # Saves the tokenizer too for easy upload
607
608        metrics = train_result.metrics
609
610        max_train_samples = (
611            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
612        )
613        metrics["train_samples"] = min(max_train_samples, len(train_dataset))
614
615        trainer.log_metrics("train", metrics)
616        trainer.save_metrics("train", metrics)
617        trainer.save_state()
618
619    # Evaluation
620    if training_args.do_eval:
621        logger.info("*** Evaluate ***")
622
623        metrics = trainer.evaluate()
624
625        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
626        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
627        try:
628            perplexity = math.exp(metrics["eval_loss"])
629        except OverflowError:
630            perplexity = float("inf")
631        metrics["perplexity"] = perplexity
632
633        trainer.log_metrics("eval", metrics)
634        trainer.save_metrics("eval", metrics)
635
636    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
637    if data_args.dataset_name is not None:
638        kwargs["dataset_tags"] = data_args.dataset_name
639        if data_args.dataset_config_name is not None:
640            kwargs["dataset_args"] = data_args.dataset_config_name
641            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
642        else:
643            kwargs["dataset"] = data_args.dataset_name
644
645    if training_args.push_to_hub:
646        trainer.push_to_hub(**kwargs)
647    else:
648        trainer.create_model_card(**kwargs)
649
650
651def _mp_fn(index):
652    # For xla_spawn (TPUs)
653    main()
654
655
656if __name__ == "__main__":
657    main()
python run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm
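
The data-processing step in run_clm.py concatenates all tokenized texts and cuts them into blocks of `block_size`, dropping the remainder that does not fill a whole block. The following is a minimal standalone sketch of that step on a toy batch. It is an illustration, not the script itself: `block_size` is passed as a parameter instead of being captured from the enclosing scope, and the toy token IDs are made up.

```python
from itertools import chain


def group_texts(examples, block_size):
    """Concatenate tokenized samples and split them into fixed-size blocks."""
    # Flatten every column (e.g. input_ids, attention_mask) into one long list.
    concatenated = {k: list(chain(*examples[k])) for k in examples}
    total_length = len(concatenated[next(iter(examples))])
    # Drop the trailing remainder that does not fill a whole block.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
    # As in the script, the labels start out as a copy of the inputs.
    result["labels"] = result["input_ids"].copy()
    return result


# Toy batch: three tokenized samples with 3, 2 and 4 tokens (9 in total).
batch = {
    "input_ids": [[1, 2, 3], [4, 5], [6, 7, 8, 9]],
    "attention_mask": [[1, 1, 1], [1, 1], [1, 1, 1, 1]],
}
grouped = group_texts(batch, block_size=4)
print(grouped["input_ids"])  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```

With nine tokens and a block size of four, two full blocks are produced and the ninth token is discarded. This is why the script's comments note that a remainder is thrown away for every batch passed to `map`.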

4. Fine-tuning a model with Diffusers

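The following script fine-tunes a text-to-image model with Diffusers (taken from the diffusers examples). Since diffusers v0.27.0, the script detects the NPU automatically and trains on it without any code changes.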
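
Before launching the script, it may be worth confirming that the installed diffusers version meets the 0.27.0 threshold mentioned above. The following is a minimal sketch using only the standard library. The helper names `parse_version` and `meets_minimum` are hypothetical, and the comparison looks only at the leading numeric parts, so pre-release suffixes such as `.dev0` are ignored.

```python
from importlib import metadata


def parse_version(text):
    """Turn '0.27.2' or '0.32.0.dev0' into a 3-tuple of leading numeric parts."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as 'dev0'
        parts.append(int(piece))
    # Pad short versions such as '1.0' so tuples compare consistently.
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts[:3])


def meets_minimum(installed, minimum="0.27.0"):
    """Return True if `installed` is at least `minimum` (numeric parts only)."""
    return parse_version(installed) >= parse_version(minimum)


try:
    installed = metadata.version("diffusers")
    print(f"diffusers {installed}: new enough = {meets_minimum(installed)}")
except metadata.PackageNotFoundError:
    print("diffusers is not installed")
```

The full training script follows.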

   1#!/usr/bin/env python
   2# coding=utf-8
   3# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
   4#
   5# Licensed under the Apache License, Version 2.0 (the "License");
   6# you may not use this file except in compliance with the License.
   7# You may obtain a copy of the License at
   8#
   9#     http://www.apache.org/licenses/LICENSE-2.0
  10#
  11# Unless required by applicable law or agreed to in writing, software
  12# distributed under the License is distributed on an "AS IS" BASIS,
  13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14# See the License for the specific language governing permissions and
  15# limitations under the License.
  16
  17import argparse
  18import logging
  19import math
  20import os
  21import random
  22import shutil
  23from contextlib import nullcontext
  24from pathlib import Path
  25
  26import accelerate
  27import datasets
  28import numpy as np
  29import torch
  30import torch.nn.functional as F
  31import torch.utils.checkpoint
  32import transformers
  33from accelerate import Accelerator
  34from accelerate.logging import get_logger
  35from accelerate.state import AcceleratorState
  36from accelerate.utils import ProjectConfiguration, set_seed
  37from datasets import load_dataset
  38from huggingface_hub import create_repo, upload_folder
  39from packaging import version
  40from torchvision import transforms
  41from tqdm.auto import tqdm
  42from transformers import CLIPTextModel, CLIPTokenizer
  43from transformers.utils import ContextManagers
  44
  45import diffusers
  46from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  47from diffusers.optimization import get_scheduler
  48from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
  49from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
  50from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
  51from diffusers.utils.import_utils import is_xformers_available
  52from diffusers.utils.torch_utils import is_compiled_module
  53
  54
  55if is_wandb_available():
  56    import wandb
  57
  58
  59# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
  60check_min_version("0.32.0.dev0")
  61
  62logger = get_logger(__name__, log_level="INFO")
  63
  64DATASET_NAME_MAPPING = {
  65    "lambdalabs/naruto-blip-captions": ("image", "text"),
  66}
  67
  68
  69def save_model_card(
  70    args,
  71    repo_id: str,
  72    images: list = None,
  73    repo_folder: str = None,
  74):
  75    img_str = ""
  76    if images:  # also covers the default images=None
  77        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
  78        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
  79        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
  80
  81    model_description = f"""
  82# Text-to-image finetuning - {repo_id}
  83
  84This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
  85{img_str}
  86
  87## Pipeline usage
  88
  89You can use the pipeline like so:
  90
  91```python
  92from diffusers import DiffusionPipeline
  93import torch
  94
  95pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
  96prompt = "{args.validation_prompts[0]}"
  97image = pipeline(prompt).images[0]
  98image.save("my_image.png")
  99```
 100
 101## Training info
 102
 103These are the key hyperparameters used during training:
 104
 105* Epochs: {args.num_train_epochs}
 106* Learning rate: {args.learning_rate}
 107* Batch size: {args.train_batch_size}
 108* Gradient accumulation steps: {args.gradient_accumulation_steps}
 109* Image resolution: {args.resolution}
 110* Mixed-precision: {args.mixed_precision}
 111
 112"""
 113    wandb_info = ""
 114    wandb_run_url = None  # defined up front so the check below never hits an unbound name
 115    if is_wandb_available():
 116        if wandb.run is not None:
 117            wandb_run_url = wandb.run.url
 118
 119    if wandb_run_url is not None:
 120        wandb_info = f"""
 121More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
 122"""
 123
 124    model_description += wandb_info
 125
 126    model_card = load_or_create_model_card(
 127        repo_id_or_path=repo_id,
 128        from_training=True,
 129        license="creativeml-openrail-m",
 130        base_model=args.pretrained_model_name_or_path,
 131        model_description=model_description,
 132        inference=True,
 133    )
 134
 135    tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
 136    model_card = populate_model_card(model_card, tags=tags)
 137
 138    model_card.save(os.path.join(repo_folder, "README.md"))
 139
 140
 141def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
 142    logger.info("Running validation... ")
 143
 144    pipeline = StableDiffusionPipeline.from_pretrained(
 145        args.pretrained_model_name_or_path,
 146        vae=accelerator.unwrap_model(vae),
 147        text_encoder=accelerator.unwrap_model(text_encoder),
 148        tokenizer=tokenizer,
 149        unet=accelerator.unwrap_model(unet),
 150        safety_checker=None,
 151        revision=args.revision,
 152        variant=args.variant,
 153        torch_dtype=weight_dtype,
 154    )
 155    pipeline = pipeline.to(accelerator.device)
 156    pipeline.set_progress_bar_config(disable=True)
 157
 158    if args.enable_xformers_memory_efficient_attention:
 159        pipeline.enable_xformers_memory_efficient_attention()
 160
 161    if args.seed is None:
 162        generator = None
 163    else:
 164        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 165
 166    images = []
 167    for i in range(len(args.validation_prompts)):
 168        if torch.backends.mps.is_available():
 169            autocast_ctx = nullcontext()
 170        else:
 171            autocast_ctx = torch.autocast(accelerator.device.type)
 172
 173        with autocast_ctx:
 174            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
 175
 176        images.append(image)
 177
 178    for tracker in accelerator.trackers:
 179        if tracker.name == "tensorboard":
 180            np_images = np.stack([np.asarray(img) for img in images])
 181            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
 182        elif tracker.name == "wandb":
 183            tracker.log(
 184                {
 185                    "validation": [
 186                        wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
 187                        for i, image in enumerate(images)
 188                    ]
 189                }
 190            )
 191        else:
 192            logger.warning(f"image logging not implemented for {tracker.name}")
 193
 194    del pipeline
 195    torch.cuda.empty_cache()
 196
 197    return images
 198
 199
 200def parse_args():
 201    parser = argparse.ArgumentParser(description="Simple example of a training script.")
 202    parser.add_argument(
 203        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
 204    )
 205    parser.add_argument(
 206        "--pretrained_model_name_or_path",
 207        type=str,
 208        default=None,
 209        required=True,
 210        help="Path to pretrained model or model identifier from huggingface.co/models.",
 211    )
 212    parser.add_argument(
 213        "--revision",
 214        type=str,
 215        default=None,
 216        required=False,
 217        help="Revision of pretrained model identifier from huggingface.co/models.",
 218    )
 219    parser.add_argument(
 220        "--variant",
 221        type=str,
 222        default=None,
 223        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. 'fp16'.",
 224    )
 225    parser.add_argument(
 226        "--dataset_name",
 227        type=str,
 228        default=None,
 229        help=(
 230            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
 231            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
 232            " or to a folder containing files that 🤗 Datasets can understand."
 233        ),
 234    )
 235    parser.add_argument(
 236        "--dataset_config_name",
 237        type=str,
 238        default=None,
 239        help="The config of the Dataset, leave as None if there's only one config.",
 240    )
 241    parser.add_argument(
 242        "--train_data_dir",
 243        type=str,
 244        default=None,
 245        help=(
 246            "A folder containing the training data. Folder contents must follow the structure described in"
 247            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
 248            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
 249        ),
 250    )
 251    parser.add_argument(
 252        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
 253    )
 254    parser.add_argument(
 255        "--caption_column",
 256        type=str,
 257        default="text",
 258        help="The column of the dataset containing a caption or a list of captions.",
 259    )
 260    parser.add_argument(
 261        "--max_train_samples",
 262        type=int,
 263        default=None,
 264        help=(
 265            "For debugging purposes or quicker training, truncate the number of training examples to this "
 266            "value if set."
 267        ),
 268    )
 269    parser.add_argument(
 270        "--validation_prompts",
 271        type=str,
 272        default=None,
 273        nargs="+",
 274        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
 275    )
 276    parser.add_argument(
 277        "--output_dir",
 278        type=str,
 279        default="sd-model-finetuned",
 280        help="The output directory where the model predictions and checkpoints will be written.",
 281    )
 282    parser.add_argument(
 283        "--cache_dir",
 284        type=str,
 285        default=None,
 286        help="The directory where the downloaded models and datasets will be stored.",
 287    )
 288    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
 289    parser.add_argument(
 290        "--resolution",
 291        type=int,
 292        default=512,
 293        help=(
 294            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
 295            " resolution"
 296        ),
 297    )
 298    parser.add_argument(
 299        "--center_crop",
 300        default=False,
 301        action="store_true",
 302        help=(
 303            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
 304            " cropped. The images will be resized to the resolution first before cropping."
 305        ),
 306    )
 307    parser.add_argument(
 308        "--random_flip",
 309        action="store_true",
  310        help="Whether to randomly flip images horizontally.",
 311    )
 312    parser.add_argument(
 313        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
 314    )
 315    parser.add_argument("--num_train_epochs", type=int, default=100)
 316    parser.add_argument(
 317        "--max_train_steps",
 318        type=int,
 319        default=None,
 320        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
 321    )
 322    parser.add_argument(
 323        "--gradient_accumulation_steps",
 324        type=int,
 325        default=1,
  326        help="Number of update steps to accumulate before performing a backward/update pass.",
 327    )
 328    parser.add_argument(
 329        "--gradient_checkpointing",
 330        action="store_true",
 331        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
 332    )
 333    parser.add_argument(
 334        "--learning_rate",
 335        type=float,
 336        default=1e-4,
 337        help="Initial learning rate (after the potential warmup period) to use.",
 338    )
 339    parser.add_argument(
 340        "--scale_lr",
 341        action="store_true",
 342        default=False,
 343        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
 344    )
 345    parser.add_argument(
 346        "--lr_scheduler",
 347        type=str,
 348        default="constant",
 349        help=(
 350            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
 351            ' "constant", "constant_with_warmup"]'
 352        ),
 353    )
 354    parser.add_argument(
 355        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
 356    )
 357    parser.add_argument(
 358        "--snr_gamma",
 359        type=float,
 360        default=None,
 361        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
 362        "More details here: https://arxiv.org/abs/2303.09556.",
 363    )
 364    parser.add_argument(
 365        "--dream_training",
 366        action="store_true",
 367        help=(
  368            "Use the DREAM training method, which makes training more efficient and accurate at the "
  369            "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
 370        ),
 371    )
 372    parser.add_argument(
 373        "--dream_detail_preservation",
 374        type=float,
 375        default=1.0,
 376        help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
 377    )
 378    parser.add_argument(
 379        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
 380    )
 381    parser.add_argument(
 382        "--allow_tf32",
 383        action="store_true",
 384        help=(
 385            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
 386            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
 387        ),
 388    )
 389    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
 390    parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
 391    parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
 392    parser.add_argument(
 393        "--non_ema_revision",
 394        type=str,
 395        default=None,
 396        required=False,
 397        help=(
 398            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
 399            " remote repository specified with --pretrained_model_name_or_path."
 400        ),
 401    )
 402    parser.add_argument(
 403        "--dataloader_num_workers",
 404        type=int,
 405        default=0,
 406        help=(
 407            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
 408        ),
 409    )
 410    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
 411    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
 412    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
 413    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
 414    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
 415    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
 416    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
 417    parser.add_argument(
 418        "--prediction_type",
 419        type=str,
 420        default=None,
 421        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
 422    )
 423    parser.add_argument(
 424        "--hub_model_id",
 425        type=str,
 426        default=None,
 427        help="The name of the repository to keep in sync with the local `output_dir`.",
 428    )
 429    parser.add_argument(
 430        "--logging_dir",
 431        type=str,
 432        default="logs",
 433        help=(
 434            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
 435            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
 436        ),
 437    )
 438    parser.add_argument(
 439        "--mixed_precision",
 440        type=str,
 441        default=None,
 442        choices=["no", "fp16", "bf16"],
 443        help=(
 444            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
  445            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
 446            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
 447        ),
 448    )
 449    parser.add_argument(
 450        "--report_to",
 451        type=str,
 452        default="tensorboard",
 453        help=(
 454            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
 455            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
 456        ),
 457    )
 458    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 459    parser.add_argument(
 460        "--checkpointing_steps",
 461        type=int,
 462        default=500,
 463        help=(
 464            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
 465            " training using `--resume_from_checkpoint`."
 466        ),
 467    )
 468    parser.add_argument(
 469        "--checkpoints_total_limit",
 470        type=int,
 471        default=None,
 472        help=("Max number of checkpoints to store."),
 473    )
 474    parser.add_argument(
 475        "--resume_from_checkpoint",
 476        type=str,
 477        default=None,
 478        help=(
 479            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
 480            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
 481        ),
 482    )
 483    parser.add_argument(
 484        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
 485    )
 486    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
 487    parser.add_argument(
 488        "--validation_epochs",
 489        type=int,
 490        default=5,
 491        help="Run validation every X epochs.",
 492    )
 493    parser.add_argument(
 494        "--tracker_project_name",
 495        type=str,
 496        default="text2image-fine-tune",
 497        help=(
  498            "The `project_name` argument passed to Accelerator.init_trackers. For"
  499            " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
 500        ),
 501    )
 502
 503    args = parser.parse_args()
 504    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 505    if env_local_rank != -1 and env_local_rank != args.local_rank:
 506        args.local_rank = env_local_rank
 507
 508    # Sanity checks
 509    if args.dataset_name is None and args.train_data_dir is None:
 510        raise ValueError("Need either a dataset name or a training folder.")
 511
 512    # default to using the same revision for the non-ema model if not specified
 513    if args.non_ema_revision is None:
 514        args.non_ema_revision = args.revision
 515
 516    return args
 517
 518
 519def main():
 520    args = parse_args()
 521
 522    if args.report_to == "wandb" and args.hub_token is not None:
 523        raise ValueError(
 524            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
 525            " Please use `huggingface-cli login` to authenticate with the Hub."
 526        )
 527
 528    if args.non_ema_revision is not None:
 529        deprecate(
 530            "non_ema_revision!=None",
 531            "0.15.0",
 532            message=(
 533                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
 534                " use `--variant=non_ema` instead."
 535            ),
 536        )
 537    logging_dir = os.path.join(args.output_dir, args.logging_dir)
 538
 539    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
 540
 541    accelerator = Accelerator(
 542        gradient_accumulation_steps=args.gradient_accumulation_steps,
 543        mixed_precision=args.mixed_precision,
 544        log_with=args.report_to,
 545        project_config=accelerator_project_config,
 546    )
 547
 548    # Disable AMP for MPS.
 549    if torch.backends.mps.is_available():
 550        accelerator.native_amp = False
 551
 552    # Make one log on every process with the configuration for debugging.
 553    logging.basicConfig(
 554        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 555        datefmt="%m/%d/%Y %H:%M:%S",
 556        level=logging.INFO,
 557    )
 558    logger.info(accelerator.state, main_process_only=False)
 559    if accelerator.is_local_main_process:
 560        datasets.utils.logging.set_verbosity_warning()
 561        transformers.utils.logging.set_verbosity_warning()
 562        diffusers.utils.logging.set_verbosity_info()
 563    else:
 564        datasets.utils.logging.set_verbosity_error()
 565        transformers.utils.logging.set_verbosity_error()
 566        diffusers.utils.logging.set_verbosity_error()
 567
 568    # If passed along, set the training seed now.
 569    if args.seed is not None:
 570        set_seed(args.seed)
 571
 572    # Handle the repository creation
 573    if accelerator.is_main_process:
 574        if args.output_dir is not None:
 575            os.makedirs(args.output_dir, exist_ok=True)
 576
 577        if args.push_to_hub:
 578            repo_id = create_repo(
 579                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
 580            ).repo_id
 581
 582    # Load scheduler, tokenizer and models.
 583    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 584    tokenizer = CLIPTokenizer.from_pretrained(
 585        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
 586    )
 587
 588    def deepspeed_zero_init_disabled_context_manager():
 589        """
 590        returns either a context list that includes one that will disable zero.Init or an empty context list
 591        """
 592        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
 593        if deepspeed_plugin is None:
 594            return []
 595
 596        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
 597
 598    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
 599    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
 600    # will try to assign the same optimizer with the same weights to all models during
 601    # `deepspeed.initialize`, which of course doesn't work.
 602    #
 603    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
 604    # frozen models from being partitioned during `zero.Init` which gets called during
  605    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
  606    # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
 607    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
 608        text_encoder = CLIPTextModel.from_pretrained(
 609            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
 610        )
 611        vae = AutoencoderKL.from_pretrained(
 612            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
 613        )
 614
 615    unet = UNet2DConditionModel.from_pretrained(
 616        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
 617    )
 618
 619    # Freeze vae and text_encoder and set unet to trainable
 620    vae.requires_grad_(False)
 621    text_encoder.requires_grad_(False)
 622    unet.train()
 623
 624    # Create EMA for the unet.
 625    if args.use_ema:
 626        ema_unet = UNet2DConditionModel.from_pretrained(
 627            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
 628        )
 629        ema_unet = EMAModel(
 630            ema_unet.parameters(),
 631            model_cls=UNet2DConditionModel,
 632            model_config=ema_unet.config,
 633            foreach=args.foreach_ema,
 634        )
 635
 636    if args.enable_xformers_memory_efficient_attention:
 637        if is_xformers_available():
 638            import xformers
 639
 640            xformers_version = version.parse(xformers.__version__)
 641            if xformers_version == version.parse("0.0.16"):
 642                logger.warning(
 643                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
 644                )
 645            unet.enable_xformers_memory_efficient_attention()
 646        else:
 647            raise ValueError("xformers is not available. Make sure it is installed correctly")
 648
 649    # `accelerate` 0.16.0 will have better support for customized saving
 650    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
 651        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
 652        def save_model_hook(models, weights, output_dir):
 653            if accelerator.is_main_process:
 654                if args.use_ema:
 655                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
 656
 657                for i, model in enumerate(models):
 658                    model.save_pretrained(os.path.join(output_dir, "unet"))
 659
 660                    # make sure to pop weight so that corresponding model is not saved again
 661                    weights.pop()
 662
 663        def load_model_hook(models, input_dir):
 664            if args.use_ema:
 665                load_model = EMAModel.from_pretrained(
 666                    os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
 667                )
 668                ema_unet.load_state_dict(load_model.state_dict())
 669                if args.offload_ema:
 670                    ema_unet.pin_memory()
 671                else:
 672                    ema_unet.to(accelerator.device)
 673                del load_model
 674
 675            for _ in range(len(models)):
 676                # pop models so that they are not loaded again
 677                model = models.pop()
 678
 679                # load diffusers style into model
 680                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
 681                model.register_to_config(**load_model.config)
 682
 683                model.load_state_dict(load_model.state_dict())
 684                del load_model
 685
 686        accelerator.register_save_state_pre_hook(save_model_hook)
 687        accelerator.register_load_state_pre_hook(load_model_hook)
 688
 689    if args.gradient_checkpointing:
 690        unet.enable_gradient_checkpointing()
 691
 692    # Enable TF32 for faster training on Ampere GPUs,
 693    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
 694    if args.allow_tf32:
 695        torch.backends.cuda.matmul.allow_tf32 = True
 696
 697    if args.scale_lr:
 698        args.learning_rate = (
 699            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
 700        )
 701
 702    # Initialize the optimizer
 703    if args.use_8bit_adam:
 704        try:
 705            import bitsandbytes as bnb
 706        except ImportError:
 707            raise ImportError(
 708                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
 709            )
 710
 711        optimizer_cls = bnb.optim.AdamW8bit
 712    else:
 713        optimizer_cls = torch.optim.AdamW
 714
 715    optimizer = optimizer_cls(
 716        unet.parameters(),
 717        lr=args.learning_rate,
 718        betas=(args.adam_beta1, args.adam_beta2),
 719        weight_decay=args.adam_weight_decay,
 720        eps=args.adam_epsilon,
 721    )
 722
 723    # Get the datasets: you can either provide your own training and evaluation files (see below)
 724    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 725
 726    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
 727    # download the dataset.
 728    if args.dataset_name is not None:
 729        # Downloading and loading a dataset from the hub.
 730        dataset = load_dataset(
 731            args.dataset_name,
 732            args.dataset_config_name,
 733            cache_dir=args.cache_dir,
 734            data_dir=args.train_data_dir,
 735        )
 736    else:
 737        data_files = {}
 738        if args.train_data_dir is not None:
 739            data_files["train"] = os.path.join(args.train_data_dir, "**")
 740        dataset = load_dataset(
 741            "imagefolder",
 742            data_files=data_files,
 743            cache_dir=args.cache_dir,
 744        )
 745        # See more about loading custom images at
 746        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
 747
 748    # Preprocessing the datasets.
 749    # We need to tokenize inputs and targets.
 750    column_names = dataset["train"].column_names
 751
  752    # Get the column names for input/target.
 753    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
 754    if args.image_column is None:
 755        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
 756    else:
 757        image_column = args.image_column
 758        if image_column not in column_names:
 759            raise ValueError(
  760                f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
 761            )
 762    if args.caption_column is None:
 763        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
 764    else:
 765        caption_column = args.caption_column
 766        if caption_column not in column_names:
 767            raise ValueError(
  768                f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
 769            )
 770
 771    # Preprocessing the datasets.
 772    # We need to tokenize input captions and transform the images.
 773    def tokenize_captions(examples, is_train=True):
 774        captions = []
 775        for caption in examples[caption_column]:
 776            if isinstance(caption, str):
 777                captions.append(caption)
 778            elif isinstance(caption, (list, np.ndarray)):
 779                # take a random caption if there are multiple
 780                captions.append(random.choice(caption) if is_train else caption[0])
 781            else:
 782                raise ValueError(
 783                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
 784                )
 785        inputs = tokenizer(
 786            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
 787        )
 788        return inputs.input_ids
 789
 790    # Preprocessing the datasets.
 791    train_transforms = transforms.Compose(
 792        [
 793            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
 794            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
 795            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
 796            transforms.ToTensor(),
 797            transforms.Normalize([0.5], [0.5]),
 798        ]
 799    )
 800
 801    def preprocess_train(examples):
 802        images = [image.convert("RGB") for image in examples[image_column]]
 803        examples["pixel_values"] = [train_transforms(image) for image in images]
 804        examples["input_ids"] = tokenize_captions(examples)
 805        return examples
 806
 807    with accelerator.main_process_first():
 808        if args.max_train_samples is not None:
 809            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
 810        # Set the training transforms
 811        train_dataset = dataset["train"].with_transform(preprocess_train)
 812
 813    def collate_fn(examples):
 814        pixel_values = torch.stack([example["pixel_values"] for example in examples])
 815        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 816        input_ids = torch.stack([example["input_ids"] for example in examples])
 817        return {"pixel_values": pixel_values, "input_ids": input_ids}
 818
 819    # DataLoaders creation:
 820    train_dataloader = torch.utils.data.DataLoader(
 821        train_dataset,
 822        shuffle=True,
 823        collate_fn=collate_fn,
 824        batch_size=args.train_batch_size,
 825        num_workers=args.dataloader_num_workers,
 826    )
 827
 828    # Scheduler and math around the number of training steps.
 829    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
 830    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
 831    if args.max_train_steps is None:
 832        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
 833        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
 834        num_training_steps_for_scheduler = (
 835            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
 836        )
 837    else:
 838        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 839
 840    lr_scheduler = get_scheduler(
 841        args.lr_scheduler,
 842        optimizer=optimizer,
 843        num_warmup_steps=num_warmup_steps_for_scheduler,
 844        num_training_steps=num_training_steps_for_scheduler,
 845    )
 846
 847    # Prepare everything with our `accelerator`.
 848    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 849        unet, optimizer, train_dataloader, lr_scheduler
 850    )
 851
 852    if args.use_ema:
 853        if args.offload_ema:
 854            ema_unet.pin_memory()
 855        else:
 856            ema_unet.to(accelerator.device)
 857
 858    # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
 859    # as these weights are only used for inference, keeping weights in full precision is not required.
 860    weight_dtype = torch.float32
 861    if accelerator.mixed_precision == "fp16":
 862        weight_dtype = torch.float16
 863        args.mixed_precision = accelerator.mixed_precision
 864    elif accelerator.mixed_precision == "bf16":
 865        weight_dtype = torch.bfloat16
 866        args.mixed_precision = accelerator.mixed_precision
 867
  868    # Move text_encoder and vae to the accelerator device and cast to weight_dtype
 869    text_encoder.to(accelerator.device, dtype=weight_dtype)
 870    vae.to(accelerator.device, dtype=weight_dtype)
 871
 872    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 873    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 874    if args.max_train_steps is None:
 875        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 876        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
 877            logger.warning(
 878                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
 879                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
 880                f"This inconsistency may result in the learning rate scheduler not functioning properly."
 881            )
 882    # Afterwards we recalculate our number of training epochs
 883    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 884
 885    # We need to initialize the trackers we use, and also store our configuration.
  886    # The trackers are initialized automatically on the main process.
 887    if accelerator.is_main_process:
 888        tracker_config = dict(vars(args))
 889        tracker_config.pop("validation_prompts")
 890        accelerator.init_trackers(args.tracker_project_name, tracker_config)
 891
 892    # Function for unwrapping if model was compiled with `torch.compile`.
 893    def unwrap_model(model):
 894        model = accelerator.unwrap_model(model)
 895        model = model._orig_mod if is_compiled_module(model) else model
 896        return model
 897
 898    # Train!
 899    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 900
 901    logger.info("***** Running training *****")
 902    logger.info(f"  Num examples = {len(train_dataset)}")
 903    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 904    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
 905    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
 906    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 907    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 908    global_step = 0
 909    first_epoch = 0
 910
 911    # Potentially load in the weights and states from a previous save
 912    if args.resume_from_checkpoint:
 913        if args.resume_from_checkpoint != "latest":
 914            path = os.path.basename(args.resume_from_checkpoint)
 915        else:
 916            # Get the most recent checkpoint
 917            dirs = os.listdir(args.output_dir)
 918            dirs = [d for d in dirs if d.startswith("checkpoint")]
 919            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
 920            path = dirs[-1] if len(dirs) > 0 else None
 921
 922        if path is None:
 923            accelerator.print(
 924                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
 925            )
 926            args.resume_from_checkpoint = None
 927            initial_global_step = 0
 928        else:
 929            accelerator.print(f"Resuming from checkpoint {path}")
 930            accelerator.load_state(os.path.join(args.output_dir, path))
 931            global_step = int(path.split("-")[1])
 932
 933            initial_global_step = global_step
 934            first_epoch = global_step // num_update_steps_per_epoch
 935
 936    else:
 937        initial_global_step = 0
 938
 939    progress_bar = tqdm(
 940        range(0, args.max_train_steps),
 941        initial=initial_global_step,
 942        desc="Steps",
 943        # Only show the progress bar once on each machine.
 944        disable=not accelerator.is_local_main_process,
 945    )
 946
    for epoch in range(first_epoch, args.num_train_epochs):
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                if args.noise_offset:
                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
                    noise += args.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
                    )
                if args.input_perturbation:
                    new_noise = noise + args.input_perturbation * torch.randn_like(noise)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                if args.input_perturbation:
                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
                else:
                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]

                # Get the target for loss depending on the prediction type
                if args.prediction_type is not None:
                    # set prediction_type of scheduler if defined
                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)

                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.dream_training:
                    noisy_latents, target = compute_dream_and_update_latents(
                        unet,
                        noise_scheduler,
                        timesteps,
                        noise,
                        noisy_latents,
                        target,
                        encoder_hidden_states,
                        args.dream_detail_preservation,
                    )

                # Predict the noise residual and compute loss
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

                if args.snr_gamma is None:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                else:
                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                    # This is discussed in Section 4.2 of the same paper.
                    snr = compute_snr(noise_scheduler, timesteps)
                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
                        dim=1
                    )[0]
                    if noise_scheduler.config.prediction_type == "epsilon":
                        mse_loss_weights = mse_loss_weights / snr
                    elif noise_scheduler.config.prediction_type == "v_prediction":
                        mse_loss_weights = mse_loss_weights / (snr + 1)

                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
                    loss = loss.mean()

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                # Backpropagate
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    if args.offload_ema:
                        # Use the accelerator's device instead of a hard-coded "cuda" so this also runs on NPU.
                        ema_unet.to(device=accelerator.device, non_blocking=True)
                    ema_unet.step(unet.parameters())
                    if args.offload_ema:
                        ema_unet.to(device="cpu", non_blocking=True)
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                        if args.checkpoints_total_limit is not None:
                            checkpoints = os.listdir(args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                            if len(checkpoints) >= args.checkpoints_total_limit:
                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= args.max_train_steps:
                break

        if accelerator.is_main_process:
            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
                if args.use_ema:
                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
                    ema_unet.store(unet.parameters())
                    ema_unet.copy_to(unet.parameters())
                log_validation(
                    vae,
                    text_encoder,
                    tokenizer,
                    unet,
                    args,
                    accelerator,
                    weight_dtype,
                    global_step,
                )
                if args.use_ema:
                    # Switch back to the original UNet parameters.
                    ema_unet.restore(unet.parameters())

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unwrap_model(unet)
        if args.use_ema:
            ema_unet.copy_to(unet.parameters())

        pipeline = StableDiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path,
            text_encoder=text_encoder,
            vae=vae,
            unet=unet,
            revision=args.revision,
            variant=args.variant,
        )
        pipeline.save_pretrained(args.output_dir)

        # Run a final round of inference.
        images = []
        if args.validation_prompts is not None:
            logger.info("Running inference for collecting generated images...")
            pipeline = pipeline.to(accelerator.device)
            pipeline.torch_dtype = weight_dtype
            pipeline.set_progress_bar_config(disable=True)

            if args.enable_xformers_memory_efficient_attention:
                pipeline.enable_xformers_memory_efficient_attention()

            if args.seed is None:
                generator = None
            else:
                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

            for i in range(len(args.validation_prompts)):
                # Autocast on the accelerator's device type rather than a hard-coded "cuda".
                with torch.autocast(accelerator.device.type):
                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
                images.append(image)

        if args.push_to_hub:
            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )

    accelerator.end_training()


if __name__ == "__main__":
    main()
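The checkpoint handling in the script above has two parts that are easy to get wrong: pruning old `checkpoint-<step>` directories before a new save, and recovering the training position from a checkpoint name on resume. The following is a minimal standalone sketch of that logic, using only the standard library. The helper names `checkpoints_to_remove` and `resume_position` are hypothetical and do not appear in the script; the step values and the limit of 2 are made up for illustration.

```python
import os
import shutil
import tempfile


def checkpoints_to_remove(names, total_limit):
    """Return the oldest checkpoint names to delete so that, after one more
    save, at most `total_limit` checkpoints remain (hypothetical helper)."""
    checkpoints = [d for d in names if d.startswith("checkpoint")]
    # Sort by the numeric step, not lexicographically ("checkpoint-1000" > "checkpoint-500").
    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
    if len(checkpoints) < total_limit:
        return []
    num_to_remove = len(checkpoints) - total_limit + 1
    return checkpoints[:num_to_remove]


def resume_position(checkpoint_name, num_update_steps_per_epoch):
    """Recover (global_step, first_epoch) from a name such as 'checkpoint-1500'."""
    global_step = int(checkpoint_name.split("-")[1])
    return global_step, global_step // num_update_steps_per_epoch


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as output_dir:
        for step in (500, 1000, 1500):
            os.makedirs(os.path.join(output_dir, f"checkpoint-{step}"))
        for name in checkpoints_to_remove(os.listdir(output_dir), total_limit=2):
            shutil.rmtree(os.path.join(output_dir, name))
        print(sorted(os.listdir(output_dir)))  # ['checkpoint-1500']
        print(resume_position("checkpoint-1500", num_update_steps_per_epoch=400))  # (1500, 3)
```

With a limit of 2 and three existing checkpoints, two are removed, which leaves room for the checkpoint that is about to be written.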
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--use_ema \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--max_train_steps=15000 \
--learning_rate=1e-05 \
--max_grad_norm=1 \
--lr_scheduler="constant" --lr_warmup_steps=0 \
--output_dir="sd-pokemon-model"
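The launch flags above interact: `--train_batch_size`, `--gradient_accumulation_steps`, and the number of processes together determine how many samples contribute to each optimizer update, and `--max_train_steps` counts optimizer updates rather than batches. The sketch below works through that arithmetic. The function names are hypothetical, the dataset size of 1000 is an assumed placeholder rather than the real size of the dataset, and the per-process sharding is an approximation of how the dataloader is split.

```python
import math


def effective_batch_size(train_batch_size, gradient_accumulation_steps, num_processes=1):
    """Samples that contribute to one optimizer update (hypothetical helper)."""
    return train_batch_size * gradient_accumulation_steps * num_processes


def epochs_needed(max_train_steps, dataset_size, train_batch_size,
                  gradient_accumulation_steps, num_processes=1):
    """Epochs required to reach `max_train_steps` optimizer updates.

    Assumes each process sees roughly dataset_size / num_processes samples.
    """
    batches_per_epoch = math.ceil(dataset_size / (train_batch_size * num_processes))
    updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    return math.ceil(max_train_steps / updates_per_epoch)


if __name__ == "__main__":
    # Values taken from the launch command; a single process is assumed.
    batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)
    print(batch)          # 4
    print(15000 * batch)  # 60000 samples processed over the whole run
    # dataset_size=1000 is a placeholder, not a measured value.
    print(epochs_needed(15000, dataset_size=1000, train_batch_size=1,
                        gradient_accumulation_steps=4))  # 60
```

Running on several devices raises the effective batch size proportionally, so the learning rate or `--max_train_steps` may need adjusting to match a single-card run.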