Quick Start

Note

Before running the examples below, you need a working PyTorch-NPU environment. For installation instructions, see the Installation Guide.

In general, to use an NPU for training and inference, your code only needs the following changes:

  1. Import the torch_npu extension package: import torch_npu

  2. Move the model and the model inputs onto the NPU

device = torch.device("npu")
model = model.to(device)
input = input.to(device)
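
Putting these changes together, here is a minimal end-to-end sketch (assuming a working PyTorch-NPU environment; the tiny linear model is arbitrary and only for illustration):

import torch
import torch_npu  # registers the NPU backend with PyTorch

device = torch.device("npu:0" if torch.npu.is_available() else "cpu")

model = torch.nn.Linear(4, 2)      # any nn.Module works the same way
model = model.to(device)           # move the parameters onto the NPU

x = torch.randn(8, 4).to(device)   # move the input onto the same device
with torch.no_grad():
    y = model(x)                   # the forward pass runs on the NPU
print(y.device)                    # e.g. npu:0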

The following examples demonstrate how to use an NPU for training and inference tasks:

1. Single-Card Training

The following code trains a model on the CIFAR10 dataset on an NPU (excerpted from the PyTorch tutorials); note the NPU-specific changes, which are called out in the comments.

  1"""
  2Training an image classifier
  3----------------------------
  4
  5We will do the following steps in order:
  6
  71. Load and normalize the CIFAR10 training and test datasets using
  8``torchvision``
  91. Define a Convolutional Neural Network
 102. Define a loss function
 113. Train the network on the training data
 124. Test the network on the test data
 13
 145. Load and normalize CIFAR10
 15^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 16
 17Using ``torchvision``, it’s extremely easy to load CIFAR10.
 18"""
 19import torch
 20# 引入torch-npu包
 21import torch_npu
 22
 23# 定义device
 24device = torch.device('npu:0' if torch.npu.is_available() else 'cpu')
 25print(device)
 26
 27import torchvision
 28import torchvision.transforms as transforms
 29
 30########################################################################
 31# The output of torchvision datasets are PILImage images of range [0, 1].
 32# We transform them to Tensors of normalized range [-1, 1].
 33transform = transforms.Compose(
 34    [transforms.ToTensor(),
 35    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
 36
 37batch_size = 4
 38
 39trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
 40                                        download=True, transform=transform)
 41trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
 42                                        shuffle=True, num_workers=2)
 43
 44testset = torchvision.datasets.CIFAR10(root='./data', train=False,
 45                                    download=True, transform=transform)
 46testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
 47                                        shuffle=False, num_workers=2)
 48
 49classes = ('plane', 'car', 'bird', 'cat',
 50        'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
 51
 52########################################################################
 53# 2. Define a Convolutional Neural Network
 54# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 55# Copy the neural network from the Neural Networks section before and modify it to
 56# take 3-channel images (instead of 1-channel images as it was defined).
 57import torch.nn as nn
 58import torch.nn.functional as F
 59
 60
 61class Net(nn.Module):
 62    def __init__(self):
 63        super().__init__()
 64        self.conv1 = nn.Conv2d(3, 6, 5)
 65        self.pool = nn.MaxPool2d(2, 2)
 66        self.conv2 = nn.Conv2d(6, 16, 5)
 67        self.fc1 = nn.Linear(16 * 5 * 5, 120)
 68        self.fc2 = nn.Linear(120, 84)
 69        self.fc3 = nn.Linear(84, 10)
 70
 71    def forward(self, x):
 72        x = self.pool(F.relu(self.conv1(x)))
 73        x = self.pool(F.relu(self.conv2(x)))
 74        x = torch.flatten(x, 1) # flatten all dimensions except batch
 75        x = F.relu(self.fc1(x))
 76        x = F.relu(self.fc2(x))
 77        x = self.fc3(x)
 78        return x
 79
 80net = Net()
 81
 82# 将模型加载到NPU上
 83net.to(device)
 84
 85########################################################################
 86# 3. Define a Loss function and optimizer
 87# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 88# Let's use a Classification Cross-Entropy loss and SGD with momentum.
 89import torch.optim as optim
 90
 91criterion = nn.CrossEntropyLoss()
 92optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
 93
 94########################################################################
 95# 4. Train the network
 96# ^^^^^^^^^^^^^^^^^^^^
 97#
 98# This is when things start to get interesting.
 99# We simply have to loop over our data iterator, and feed the inputs to the
100# network and optimize.
101
102for epoch in range(2):  # loop over the dataset multiple times
103
104    running_loss = 0.0
105    for i, data in enumerate(trainloader, 0):
106        # get the inputs; data is a list of [inputs, labels]
107        # 将input数据发送到NPU上
108        inputs, labels = data[0].to(device), data[1].to(device)
109
110        # zero the parameter gradients
111        optimizer.zero_grad()
112
113        # forward + backward + optimize
114        outputs = net(inputs)
115        loss = criterion(outputs, labels)
116        loss.backward()
117        optimizer.step()
118
119        # print statistics
120        running_loss += loss.item()
121        if i % 2000 == 1999:    # print every 2000 mini-batches
122            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
123            running_loss = 0.0
124
125print('Finished Training')
126
127########################################################################
128# 5. Test the network on the test data
129# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
130#
131# We have trained the network for 2 passes over the training dataset.
132# But we need to check if the network has learnt anything at all.
133#
134# We will check this by predicting the class label that the neural network
135# outputs, and checking it against the ground-truth. If the prediction is
136# correct, we add the sample to the list of correct predictions.
137#
138# Let us look at how the network performs on the whole dataset.
139correct = 0
140total = 0
141# since we're not training, we don't need to calculate the gradients for our outputs
142with torch.no_grad():
143    for data in testloader:
144        # 将input数据发送到NPU上
145        images, labels = data[0].to(device), data[1].to(device)
146        # calculate outputs by running images through the network
147        outputs = net(images)
148        # the class with the highest energy is what we choose as prediction
149        _, predicted = torch.max(outputs.data, 1)
150        total += labels.size(0)
151        correct += (predicted == labels).sum().item()
152
153print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
154########################################################################
155# That looks way better than chance, which is 10% accuracy (randomly picking
156# a class out of 10 classes).
157# Seems like the network learnt something.
158#
159# Hmmm, what are the classes that performed well, and the classes that did
160# not perform well:
161
162# prepare to count predictions for each class
163correct_pred = {classname: 0 for classname in classes}
164total_pred = {classname: 0 for classname in classes}
165
166# again no gradients needed
167with torch.no_grad():
168    for data in testloader:
169        # 将input数据发送到NPU上
170        images, labels = data[0].to(device), data[1].to(device)
171        outputs = net(images)
172        _, predictions = torch.max(outputs, 1)
173        # collect the correct predictions for each class
174        for label, prediction in zip(labels, predictions):
175            if label == prediction:
176                correct_pred[classes[label]] += 1
177            total_pred[classes[label]] += 1
178
179
180# print accuracy for each class
181for classname, correct_count in correct_pred.items():
182    accuracy = 100 * float(correct_count) / total_pred[classname]
183    print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
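
After training finishes, the weights can be checkpointed and later restored onto the NPU with the usual PyTorch APIs (a minimal sketch; the checkpoint path is arbitrary):

PATH = './cifar_net.pth'              # arbitrary checkpoint path
torch.save(net.state_dict(), PATH)

# Later, rebuild the model, load the weights, and move it back to the NPU.
net = Net()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.to(device)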

2. Multi-Card Parallel Training with DeepSpeed

The following code trains a model on the CIFAR10 dataset across multiple NPU cards with DeepSpeed (from DeepSpeed Examples). Since DeepSpeed v0.12.6, the code automatically detects the NPU and trains on it without any modification.

import argparse
import os

import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer


def add_argument():
    parser = argparse.ArgumentParser(description="CIFAR")

    # For train.
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # For mixed precision training.
    parser.add_argument(
        "--dtype",
        default="fp16",
        type=str,
        choices=["bf16", "fp16", "fp32"],
        help="Datatype used for training",
    )

    # For ZeRO Optimization.
    parser.add_argument(
        "--stage",
        default=0,
        type=int,
        choices=[0, 1, 2, 3],
        help="ZeRO optimization stage",
    )

    # For MoE (Mixture of Experts).
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        default=[
            1,
        ],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    # Include DeepSpeed configuration arguments.
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    return args


def create_moe_param_groups(model):
    """Create separate parameter groups for each expert."""
    parameters = {"params": [p for p in model.parameters()], "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(parameters)


def get_ds_config(args):
    """Get the DeepSpeed configuration dictionary."""
    ds_config = {
        "train_batch_size": 16,
        "steps_per_print": 2000,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7,
            },
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000,
            },
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "bf16": {"enabled": args.dtype == "bf16"},
        "fp16": {
            "enabled": args.dtype == "fp16",
            "fp16_master_weights_and_grads": False,
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 15,
        },
        "wall_clock_breakdown": False,
        "zero_optimization": {
            "stage": args.stage,
            "allgather_partitions": True,
            "reduce_scatter": True,
            "allgather_bucket_size": 50000000,
            "reduce_bucket_size": 50000000,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "cpu_offload": False,
        },
    }
    return ds_config


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.moe = args.moe
        if self.moe:
            fc3 = nn.Linear(84, 84)
            self.moe_layer_list = []
            for n_e in args.num_experts:
                # Create moe layers based on the number of experts.
                self.moe_layer_list.append(
                    deepspeed.moe.layer.MoE(
                        hidden_size=84,
                        expert=fc3,
                        num_experts=n_e,
                        ep_size=args.ep_world_size,
                        use_residual=args.mlp_type == "residual",
                        k=args.top_k,
                        min_capacity=args.min_capacity,
                        noisy_gate_policy=args.noisy_gate_policy,
                    )
                )
            self.moe_layer_list = nn.ModuleList(self.moe_layer_list)
            self.fc4 = nn.Linear(84, 10)
        else:
            self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        if self.moe:
            for layer in self.moe_layer_list:
                x, _, _ = layer(x)
            x = self.fc4(x)
        else:
            x = self.fc3(x)
        return x


def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data.
        test_batch_size (int): the test batch size.

    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            if target_dtype is not None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs.data, 1)
            # Count the total accuracy.
            total += labels.size(0)
            correct += (predicted == labels.to(local_device)).sum().item()

            # Count the accuracy per class.
            batch_correct = (predicted == labels.to(local_device)).squeeze()
            for i in range(test_batch_size):
                label = labels[i]
                class_correct[label] += batch_correct[i].item()
                class_total[label] += 1

    if model_engine.local_rank == 0:
        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for i in range(10):
            print(
                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
            )


def main(args):
    # Initialize DeepSpeed distributed backend.
    deepspeed.init_distributed()
    _local_rank = int(os.environ.get("LOCAL_RANK"))
    get_accelerator().set_device(_local_rank)

    ########################################################################
    # Step1. Data Preparation.
    #
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1].
    #
    # Note:
    #     If running on Windows and you get a BrokenPipeError, try setting
    #     the num_worker of torch.utils.data.DataLoader() to 0.
    ########################################################################
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    if torch.distributed.get_rank() != 0:
        # Might be downloading cifar data, let rank 0 download first.
        torch.distributed.barrier()

    # Load or download cifar data.
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )

    if torch.distributed.get_rank() == 0:
        # Cifar data is downloaded, indicate other ranks can proceed.
        torch.distributed.barrier()

    ########################################################################
    # Step 2. Define the network with DeepSpeed.
    #
    # First, we define a Convolution Neural Network.
    # Then, we define the DeepSpeed configuration dictionary and use it to
    # initialize the DeepSpeed engine.
    ########################################################################
    net = Net(args)

    # Get list of parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, net.parameters())

    # If using MoE, create separate param groups for each expert.
    if args.moe_param_group:
        parameters = create_moe_param_groups(net)

    # Initialize DeepSpeed to use the following features.
    #   1) Distributed model.
    #   2) Distributed data loader.
    #   3) DeepSpeed optimizer.
    ds_config = get_ds_config(args)
    model_engine, optimizer, trainloader, __ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=parameters,
        training_data=trainset,
        config=ds_config,
    )

    # Get the local device name (str) and local rank (int).
    local_device = get_accelerator().device_name(model_engine.local_rank)
    local_rank = model_engine.local_rank

    # For float32, target_dtype will be None so no datatype conversion needed.
    target_dtype = None
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
    elif model_engine.fp16_enabled():
        target_dtype = torch.half

    # Define the Classification Cross-Entropy loss function.
    criterion = nn.CrossEntropyLoss()

    ########################################################################
    # Step 3. Train the network.
    #
    # This is when things start to get interesting.
    # We simply have to loop over our data iterator, and feed the inputs to the
    # network and optimize. (DeepSpeed handles the distributed details for us!)
    ########################################################################

    for epoch in range(args.epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            # Get the inputs. ``data`` is a list of [inputs, labels].
            inputs, labels = data[0].to(local_device), data[1].to(local_device)

            # Try to convert to target_dtype if needed.
            if target_dtype is not None:
                inputs = inputs.to(target_dtype)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

            # Print statistics
            running_loss += loss.item()
            if local_rank == 0 and i % args.log_interval == (
                args.log_interval - 1
            ):  # Print every log_interval mini-batches.
                print(
                    f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
                )
                running_loss = 0.0
    print("Finished Training")

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)


if __name__ == "__main__":
    args = add_argument()
    main(args)
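
Because DeepSpeed routes all device handling through its accelerator abstraction, you can check that the NPU is picked up before launching a full run (a minimal sketch, assuming torch_npu and DeepSpeed >= 0.12.6 are installed):

import torch
import torch_npu  # registers the NPU backend
from deepspeed.accelerator import get_accelerator

acc = get_accelerator()
print(acc.device_name())    # expected to print "npu" when an NPU is detected
print(acc.device_name(0))   # expected to print "npu:0"
print(acc.device_count())   # number of visible NPU cards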

3. Fine-tuning a Model with Transformers

The following code fine-tunes an LLM with Transformers (from the transformers examples). Since transformers version xxx and accelerate v0.21.0, the code automatically detects the NPU and runs on it without any modification.

#!/usr/bin/env python
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import evaluate
import torch
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
            trust_remote_code=model_args.trust_remote_code,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (
            data_args.train_file.split(".")[-1]
            if data_args.train_file is not None
            else data_args.validation_file.split(".")[-1]
        )
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params")

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )
    if hasattr(config, "max_position_embeddings"):
        max_pos_embeddings = config.max_position_embeddings
    else:
        # Define a default value if the attribute is missing in the config.
        max_pos_embeddings = 1024

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, and if the total_length < block_size  we exclude this batch and return an empty dict.
        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/process#map

    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {block_size}",
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            if isinstance(logits, tuple):
                # Depending on the model and config, logits may contain extra tensors,
                # like past_key_values, but logits always come first
                logits = logits[0]
            return logits.argmax(dim=-1)

        metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

        def compute_metrics(eval_preds):
            preds, labels = eval_preds
            # preds have the same shape as the labels, after the argmax(-1) has been calculated
            # by preprocess_logits_for_metrics but we need to shift the labels
            labels = labels[:, 1:].reshape(-1)
            preds = preds[:, :-1].reshape(-1)
            return metric.compute(predictions=preds, references=labels)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        processing_class=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_xla_available()
        else None,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()

The script can then be launched as follows:

python run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm
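
To confirm that the Trainer will place the model on the NPU before starting a long run, you can inspect the device resolved by TrainingArguments (a minimal sketch, assuming torch_npu and a sufficiently recent accelerate are installed):

import torch
import torch_npu  # registers the NPU backend
from transformers import TrainingArguments

args = TrainingArguments(output_dir="/tmp/test-clm")
print(args.device)  # expected to show an npu device (e.g. npu:0) on an NPU machine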

4. Fine-tuning a Model with Diffusers

The following code fine-tunes a text-to-image model with Diffusers (from the diffusers examples). Since diffusers v0.27.0, the code automatically detects the NPU and runs on it without any modification; a short device check is shown first, followed by the training script.
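
As a quick check that the underlying accelerate library resolves its device to the NPU (a minimal sketch, assuming accelerate and torch_npu are installed):

import torch
import torch_npu  # registers the NPU backend
from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.device)  # expected to show an npu device on an NPU machine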

   1#!/usr/bin/env python
   2# coding=utf-8
   3# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
   4#
   5# Licensed under the Apache License, Version 2.0 (the "License");
   6# you may not use this file except in compliance with the License.
   7# You may obtain a copy of the License at
   8#
   9#     http://www.apache.org/licenses/LICENSE-2.0
  10#
  11# Unless required by applicable law or agreed to in writing, software
  12# distributed under the License is distributed on an "AS IS" BASIS,
  13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14# See the License for the specific language governing permissions and
  15# limitations under the License.
  16
  17import argparse
  18import logging
  19import math
  20import os
  21import random
  22import shutil
  23from contextlib import nullcontext
  24from pathlib import Path
  25
  26import accelerate
  27import datasets
  28import numpy as np
  29import torch
  30import torch.nn.functional as F
  31import torch.utils.checkpoint
  32import transformers
  33from accelerate import Accelerator
  34from accelerate.logging import get_logger
  35from accelerate.state import AcceleratorState
  36from accelerate.utils import ProjectConfiguration, set_seed
  37from datasets import load_dataset
  38from huggingface_hub import create_repo, upload_folder
  39from packaging import version
  40from torchvision import transforms
  41from tqdm.auto import tqdm
  42from transformers import CLIPTextModel, CLIPTokenizer
  43from transformers.utils import ContextManagers
  44
  45import diffusers
  46from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
  47from diffusers.optimization import get_scheduler
  48from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
  49from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
  50from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
  51from diffusers.utils.import_utils import is_xformers_available
  52from diffusers.utils.torch_utils import is_compiled_module
  53
  54
  55if is_wandb_available():
  56    import wandb
  57
  58
  59# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
  60check_min_version("0.34.0.dev0")
  61
  62logger = get_logger(__name__, log_level="INFO")
  63
  64DATASET_NAME_MAPPING = {
  65    "lambdalabs/naruto-blip-captions": ("image", "text"),
  66}
  67
  68
  69def save_model_card(
  70    args,
  71    repo_id: str,
  72    images: list = None,
  73    repo_folder: str = None,
  74):
  75    img_str = ""
  76    if len(images) > 0:
  77        image_grid = make_image_grid(images, 1, len(args.validation_prompts))
  78        image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
  79        img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
  80
  81    model_description = f"""
  82# Text-to-image finetuning - {repo_id}
  83
  84This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
  85{img_str}
  86
  87## Pipeline usage
  88
  89You can use the pipeline like so:
  90
  91```python
  92from diffusers import DiffusionPipeline
  93import torch
  94
  95pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
  96prompt = "{args.validation_prompts[0]}"
  97image = pipeline(prompt).images[0]
  98image.save("my_image.png")
  99```
 100
 101## Training info
 102
 103These are the key hyperparameters used during training:
 104
 105* Epochs: {args.num_train_epochs}
 106* Learning rate: {args.learning_rate}
 107* Batch size: {args.train_batch_size}
 108* Gradient accumulation steps: {args.gradient_accumulation_steps}
 109* Image resolution: {args.resolution}
 110* Mixed-precision: {args.mixed_precision}
 111
 112"""
 113    wandb_info = ""
 114    # define wandb_run_url up front so the name is always bound, even when wandb is unavailable
 115    wandb_run_url = None
 116    if is_wandb_available() and wandb.run is not None:
 117        wandb_run_url = wandb.run.url
 118
 119    if wandb_run_url is not None:
 120        wandb_info = f"""
 121More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
 122"""
 123
 124    model_description += wandb_info
 125
 126    model_card = load_or_create_model_card(
 127        repo_id_or_path=repo_id,
 128        from_training=True,
 129        license="creativeml-openrail-m",
 130        base_model=args.pretrained_model_name_or_path,
 131        model_description=model_description,
 132        inference=True,
 133    )
 134
 135    tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
 136    model_card = populate_model_card(model_card, tags=tags)
 137
 138    model_card.save(os.path.join(repo_folder, "README.md"))
 139
 140
 141def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
 142    logger.info("Running validation... ")
 143
 144    pipeline = StableDiffusionPipeline.from_pretrained(
 145        args.pretrained_model_name_or_path,
 146        vae=accelerator.unwrap_model(vae),
 147        text_encoder=accelerator.unwrap_model(text_encoder),
 148        tokenizer=tokenizer,
 149        unet=accelerator.unwrap_model(unet),
 150        safety_checker=None,
 151        revision=args.revision,
 152        variant=args.variant,
 153        torch_dtype=weight_dtype,
 154    )
 155    pipeline = pipeline.to(accelerator.device)
 156    pipeline.set_progress_bar_config(disable=True)
 157
 158    if args.enable_xformers_memory_efficient_attention:
 159        pipeline.enable_xformers_memory_efficient_attention()
 160
 161    if args.seed is None:
 162        generator = None
 163    else:
 164        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
 165
 166    images = []
 167    for i in range(len(args.validation_prompts)):
 168        if torch.backends.mps.is_available():
 169            autocast_ctx = nullcontext()
 170        else:
 171            autocast_ctx = torch.autocast(accelerator.device.type)
 172
 173        with autocast_ctx:
 174            image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
 175
 176        images.append(image)
 177
 178    for tracker in accelerator.trackers:
 179        if tracker.name == "tensorboard":
 180            np_images = np.stack([np.asarray(img) for img in images])
 181            tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
 182        elif tracker.name == "wandb":
 183            tracker.log(
 184                {
 185                    "validation": [
 186                        wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
 187                        for i, image in enumerate(images)
 188                    ]
 189                }
 190            )
 191        else:
 192            logger.warning(f"image logging not implemented for {tracker.name}")
 193
 194    del pipeline
 195    torch.cuda.empty_cache()
 196
 197    return images
 198
 199
 200def parse_args():
 201    parser = argparse.ArgumentParser(description="Simple example of a training script.")
 202    parser.add_argument(
 203        "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
 204    )
 205    parser.add_argument(
 206        "--pretrained_model_name_or_path",
 207        type=str,
 208        default=None,
 209        required=True,
 210        help="Path to pretrained model or model identifier from huggingface.co/models.",
 211    )
 212    parser.add_argument(
 213        "--revision",
 214        type=str,
 215        default=None,
 216        required=False,
 217        help="Revision of pretrained model identifier from huggingface.co/models.",
 218    )
 219    parser.add_argument(
 220        "--variant",
 221        type=str,
 222        default=None,
 223        help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
 224    )
 225    parser.add_argument(
 226        "--dataset_name",
 227        type=str,
 228        default=None,
 229        help=(
 230            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
 231            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
 232            " or to a folder containing files that 🤗 Datasets can understand."
 233        ),
 234    )
 235    parser.add_argument(
 236        "--dataset_config_name",
 237        type=str,
 238        default=None,
 239        help="The config of the Dataset, leave as None if there's only one config.",
 240    )
 241    parser.add_argument(
 242        "--train_data_dir",
 243        type=str,
 244        default=None,
 245        help=(
 246            "A folder containing the training data. Folder contents must follow the structure described in"
 247            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
 248            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
 249        ),
 250    )
 251    parser.add_argument(
 252        "--image_column", type=str, default="image", help="The column of the dataset containing an image."
 253    )
 254    parser.add_argument(
 255        "--caption_column",
 256        type=str,
 257        default="text",
 258        help="The column of the dataset containing a caption or a list of captions.",
 259    )
 260    parser.add_argument(
 261        "--max_train_samples",
 262        type=int,
 263        default=None,
 264        help=(
 265            "For debugging purposes or quicker training, truncate the number of training examples to this "
 266            "value if set."
 267        ),
 268    )
 269    parser.add_argument(
 270        "--validation_prompts",
 271        type=str,
 272        default=None,
 273        nargs="+",
 274        help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
 275    )
 276    parser.add_argument(
 277        "--output_dir",
 278        type=str,
 279        default="sd-model-finetuned",
 280        help="The output directory where the model predictions and checkpoints will be written.",
 281    )
 282    parser.add_argument(
 283        "--cache_dir",
 284        type=str,
 285        default=None,
 286        help="The directory where the downloaded models and datasets will be stored.",
 287    )
 288    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
 289    parser.add_argument(
 290        "--resolution",
 291        type=int,
 292        default=512,
 293        help=(
 294            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
 295            " resolution"
 296        ),
 297    )
 298    parser.add_argument(
 299        "--center_crop",
 300        default=False,
 301        action="store_true",
 302        help=(
 303            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
 304            " cropped. The images will be resized to the resolution first before cropping."
 305        ),
 306    )
 307    parser.add_argument(
 308        "--random_flip",
 309        action="store_true",
 310        help="Whether to randomly flip images horizontally.",
 311    )
 312    parser.add_argument(
 313        "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
 314    )
 315    parser.add_argument("--num_train_epochs", type=int, default=100)
 316    parser.add_argument(
 317        "--max_train_steps",
 318        type=int,
 319        default=None,
 320        help="Total number of training steps to perform.  If provided, overrides num_train_epochs.",
 321    )
 322    parser.add_argument(
 323        "--gradient_accumulation_steps",
 324        type=int,
 325        default=1,
 326        help="Number of update steps to accumulate before performing a backward/update pass.",
 327    )
 328    parser.add_argument(
 329        "--gradient_checkpointing",
 330        action="store_true",
 331        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
 332    )
 333    parser.add_argument(
 334        "--learning_rate",
 335        type=float,
 336        default=1e-4,
 337        help="Initial learning rate (after the potential warmup period) to use.",
 338    )
 339    parser.add_argument(
 340        "--scale_lr",
 341        action="store_true",
 342        default=False,
 343        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
 344    )
 345    parser.add_argument(
 346        "--lr_scheduler",
 347        type=str,
 348        default="constant",
 349        help=(
 350            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
 351            ' "constant", "constant_with_warmup"]'
 352        ),
 353    )
 354    parser.add_argument(
 355        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
 356    )
 357    parser.add_argument(
 358        "--snr_gamma",
 359        type=float,
 360        default=None,
 361        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
 362        "More details here: https://arxiv.org/abs/2303.09556.",
 363    )
 364    parser.add_argument(
 365        "--dream_training",
 366        action="store_true",
 367        help=(
 368            "Use the DREAM training method, which makes training more efficient and accurate at the "
 369            "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
 370        ),
 371    )
 372    parser.add_argument(
 373        "--dream_detail_preservation",
 374        type=float,
 375        default=1.0,
 376        help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
 377    )
 378    parser.add_argument(
 379        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
 380    )
 381    parser.add_argument(
 382        "--allow_tf32",
 383        action="store_true",
 384        help=(
 385            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
 386            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
 387        ),
 388    )
 389    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
 390    parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
 391    parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
 392    parser.add_argument(
 393        "--non_ema_revision",
 394        type=str,
 395        default=None,
 396        required=False,
 397        help=(
 398            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
 399            " remote repository specified with --pretrained_model_name_or_path."
 400        ),
 401    )
 402    parser.add_argument(
 403        "--dataloader_num_workers",
 404        type=int,
 405        default=0,
 406        help=(
 407            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
 408        ),
 409    )
 410    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
 411    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
 412    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
 413    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
 414    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
 415    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
 416    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
 417    parser.add_argument(
 418        "--prediction_type",
 419        type=str,
 420        default=None,
 421        help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
 422    )
 423    parser.add_argument(
 424        "--hub_model_id",
 425        type=str,
 426        default=None,
 427        help="The name of the repository to keep in sync with the local `output_dir`.",
 428    )
 429    parser.add_argument(
 430        "--logging_dir",
 431        type=str,
 432        default="logs",
 433        help=(
 434            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
 435            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
 436        ),
 437    )
 438    parser.add_argument(
 439        "--mixed_precision",
 440        type=str,
 441        default=None,
 442        choices=["no", "fp16", "bf16"],
 443        help=(
 444            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
 445            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
 446            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
 447        ),
 448    )
 449    parser.add_argument(
 450        "--report_to",
 451        type=str,
 452        default="tensorboard",
 453        help=(
 454            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
 455            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
 456        ),
 457    )
 458    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
 459    parser.add_argument(
 460        "--checkpointing_steps",
 461        type=int,
 462        default=500,
 463        help=(
 464            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
 465            " training using `--resume_from_checkpoint`."
 466        ),
 467    )
 468    parser.add_argument(
 469        "--checkpoints_total_limit",
 470        type=int,
 471        default=None,
 472        help=("Max number of checkpoints to store."),
 473    )
 474    parser.add_argument(
 475        "--resume_from_checkpoint",
 476        type=str,
 477        default=None,
 478        help=(
 479            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
 480            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
 481        ),
 482    )
 483    parser.add_argument(
 484        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
 485    )
 486    parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
 487    parser.add_argument(
 488        "--validation_epochs",
 489        type=int,
 490        default=5,
 491        help="Run validation every X epochs.",
 492    )
 493    parser.add_argument(
 494        "--tracker_project_name",
 495        type=str,
 496        default="text2image-fine-tune",
 497        help=(
 498            "The `project_name` argument passed to Accelerator.init_trackers. For"
 499            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
 500        ),
 501    )
 502    parser.add_argument(
 503        "--image_interpolation_mode",
 504        type=str,
 505        default="lanczos",
 506        choices=[
 507            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
 508        ],
 509        help="The image interpolation method to use for resizing images.",
 510    )
 511
 512    args = parser.parse_args()
 513    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 514    if env_local_rank != -1 and env_local_rank != args.local_rank:
 515        args.local_rank = env_local_rank
 516
 517    # Sanity checks
 518    if args.dataset_name is None and args.train_data_dir is None:
 519        raise ValueError("Need either a dataset name or a training folder.")
 520
 521    # default to using the same revision for the non-ema model if not specified
 522    if args.non_ema_revision is None:
 523        args.non_ema_revision = args.revision
 524
 525    return args
 526
 527
 528def main():
 529    args = parse_args()
 530
 531    if args.report_to == "wandb" and args.hub_token is not None:
 532        raise ValueError(
 533            "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
 534            " Please use `huggingface-cli login` to authenticate with the Hub."
 535        )
 536
 537    if args.non_ema_revision is not None:
 538        deprecate(
 539            "non_ema_revision!=None",
 540            "0.15.0",
 541            message=(
 542                "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
 543                " use `--variant=non_ema` instead."
 544            ),
 545        )
 546    logging_dir = os.path.join(args.output_dir, args.logging_dir)
 547
 548    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
 549
 550    accelerator = Accelerator(
 551        gradient_accumulation_steps=args.gradient_accumulation_steps,
 552        mixed_precision=args.mixed_precision,
 553        log_with=args.report_to,
 554        project_config=accelerator_project_config,
 555    )
 556
 557    # Disable AMP for MPS.
 558    if torch.backends.mps.is_available():
 559        accelerator.native_amp = False
 560
 561    # Make one log on every process with the configuration for debugging.
 562    logging.basicConfig(
 563        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
 564        datefmt="%m/%d/%Y %H:%M:%S",
 565        level=logging.INFO,
 566    )
 567    logger.info(accelerator.state, main_process_only=False)
 568    if accelerator.is_local_main_process:
 569        datasets.utils.logging.set_verbosity_warning()
 570        transformers.utils.logging.set_verbosity_warning()
 571        diffusers.utils.logging.set_verbosity_info()
 572    else:
 573        datasets.utils.logging.set_verbosity_error()
 574        transformers.utils.logging.set_verbosity_error()
 575        diffusers.utils.logging.set_verbosity_error()
 576
 577    # If passed along, set the training seed now.
 578    if args.seed is not None:
 579        set_seed(args.seed)
 580
 581    # Handle the repository creation
 582    if accelerator.is_main_process:
 583        if args.output_dir is not None:
 584            os.makedirs(args.output_dir, exist_ok=True)
 585
 586        if args.push_to_hub:
 587            repo_id = create_repo(
 588                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
 589            ).repo_id
 590
 591    # Load scheduler, tokenizer and models.
 592    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 593    tokenizer = CLIPTokenizer.from_pretrained(
 594        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
 595    )
 596
 597    def deepspeed_zero_init_disabled_context_manager():
 598        """
 599        Returns either a context list that includes one that will disable zero.Init, or an empty context list.
 600        """
 601        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
 602        if deepspeed_plugin is None:
 603            return []
 604
 605        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
 606
 607    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
 608    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
 609    # will try to assign the same optimizer with the same weights to all models during
 610    # `deepspeed.initialize`, which of course doesn't work.
 611    #
 612    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
 613    # frozen models from being partitioned during `zero.Init` which gets called during
 614    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
 615    # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
 616    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
 617        text_encoder = CLIPTextModel.from_pretrained(
 618            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
 619        )
 620        vae = AutoencoderKL.from_pretrained(
 621            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
 622        )
 623
 624    unet = UNet2DConditionModel.from_pretrained(
 625        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
 626    )
 627
 628    # Freeze vae and text_encoder and set unet to trainable
 629    vae.requires_grad_(False)
 630    text_encoder.requires_grad_(False)
 631    unet.train()
 632
 633    # Create EMA for the unet.
 634    if args.use_ema:
 635        ema_unet = UNet2DConditionModel.from_pretrained(
 636            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
 637        )
 638        ema_unet = EMAModel(
 639            ema_unet.parameters(),
 640            model_cls=UNet2DConditionModel,
 641            model_config=ema_unet.config,
 642            foreach=args.foreach_ema,
 643        )
 644
 645    if args.enable_xformers_memory_efficient_attention:
 646        if is_xformers_available():
 647            import xformers
 648
 649            xformers_version = version.parse(xformers.__version__)
 650            if xformers_version == version.parse("0.0.16"):
 651                logger.warning(
 652                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
 653                )
 654            unet.enable_xformers_memory_efficient_attention()
 655        else:
 656            raise ValueError("xformers is not available. Make sure it is installed correctly")
 657
 658    # `accelerate` 0.16.0 will have better support for customized saving
 659    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
 660        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
 661        def save_model_hook(models, weights, output_dir):
 662            if accelerator.is_main_process:
 663                if args.use_ema:
 664                    ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
 665
 666                for i, model in enumerate(models):
 667                    model.save_pretrained(os.path.join(output_dir, "unet"))
 668
 669                    # make sure to pop weight so that corresponding model is not saved again
 670                    weights.pop()
 671
 672        def load_model_hook(models, input_dir):
 673            if args.use_ema:
 674                load_model = EMAModel.from_pretrained(
 675                    os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
 676                )
 677                ema_unet.load_state_dict(load_model.state_dict())
 678                if args.offload_ema:
 679                    ema_unet.pin_memory()
 680                else:
 681                    ema_unet.to(accelerator.device)
 682                del load_model
 683
 684            for _ in range(len(models)):
 685                # pop models so that they are not loaded again
 686                model = models.pop()
 687
 688                # load diffusers style into model
 689                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
 690                model.register_to_config(**load_model.config)
 691
 692                model.load_state_dict(load_model.state_dict())
 693                del load_model
 694
 695        accelerator.register_save_state_pre_hook(save_model_hook)
 696        accelerator.register_load_state_pre_hook(load_model_hook)
 697
 698    if args.gradient_checkpointing:
 699        unet.enable_gradient_checkpointing()
 700
 701    # Enable TF32 for faster training on Ampere GPUs,
 702    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
 703    if args.allow_tf32:
 704        torch.backends.cuda.matmul.allow_tf32 = True
 705
 706    if args.scale_lr:
 707        args.learning_rate = (
 708            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
 709        )
 710
 711    # Initialize the optimizer
 712    if args.use_8bit_adam:
 713        try:
 714            import bitsandbytes as bnb
 715        except ImportError:
 716            raise ImportError(
 717                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
 718            )
 719
 720        optimizer_cls = bnb.optim.AdamW8bit
 721    else:
 722        optimizer_cls = torch.optim.AdamW
 723
 724    optimizer = optimizer_cls(
 725        unet.parameters(),
 726        lr=args.learning_rate,
 727        betas=(args.adam_beta1, args.adam_beta2),
 728        weight_decay=args.adam_weight_decay,
 729        eps=args.adam_epsilon,
 730    )
 731
 732    # Get the datasets: you can either provide your own training and evaluation files (see below)
 733    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
 734
 735    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
 736    # download the dataset.
 737    if args.dataset_name is not None:
 738        # Downloading and loading a dataset from the hub.
 739        dataset = load_dataset(
 740            args.dataset_name,
 741            args.dataset_config_name,
 742            cache_dir=args.cache_dir,
 743            data_dir=args.train_data_dir,
 744        )
 745    else:
 746        data_files = {}
 747        if args.train_data_dir is not None:
 748            data_files["train"] = os.path.join(args.train_data_dir, "**")
 749        dataset = load_dataset(
 750            "imagefolder",
 751            data_files=data_files,
 752            cache_dir=args.cache_dir,
 753        )
 754        # See more about loading custom images at
 755        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
 756
 757    # Preprocessing the datasets.
 758    # We need to tokenize inputs and targets.
 759    column_names = dataset["train"].column_names
 760
 761    # Get the column names for input/target.
 762    dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
 763    if args.image_column is None:
 764        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
 765    else:
 766        image_column = args.image_column
 767        if image_column not in column_names:
 768            raise ValueError(
 769                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
 770            )
 771    if args.caption_column is None:
 772        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
 773    else:
 774        caption_column = args.caption_column
 775        if caption_column not in column_names:
 776            raise ValueError(
 777                f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
 778            )
 779
 780    # Preprocessing the datasets.
 781    # We need to tokenize input captions and transform the images.
 782    def tokenize_captions(examples, is_train=True):
 783        captions = []
 784        for caption in examples[caption_column]:
 785            if isinstance(caption, str):
 786                captions.append(caption)
 787            elif isinstance(caption, (list, np.ndarray)):
 788                # take a random caption if there are multiple
 789                captions.append(random.choice(caption) if is_train else caption[0])
 790            else:
 791                raise ValueError(
 792                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
 793                )
 794        inputs = tokenizer(
 795            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
 796        )
 797        return inputs.input_ids
 798
 799    # Get the specified interpolation method from the args
 800    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
 801
 802    # Raise an error if the interpolation method is invalid
 803    if interpolation is None:
 804        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
 805
 806    # Data preprocessing transformations
 807    train_transforms = transforms.Compose(
 808        [
 809            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method
 810            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
 811            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
 812            transforms.ToTensor(),
 813            transforms.Normalize([0.5], [0.5]),
 814        ]
 815    )
 816
 817    def preprocess_train(examples):
 818        images = [image.convert("RGB") for image in examples[image_column]]
 819        examples["pixel_values"] = [train_transforms(image) for image in images]
 820        examples["input_ids"] = tokenize_captions(examples)
 821        return examples
 822
 823    with accelerator.main_process_first():
 824        if args.max_train_samples is not None:
 825            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
 826        # Set the training transforms
 827        train_dataset = dataset["train"].with_transform(preprocess_train)
 828
 829    def collate_fn(examples):
 830        pixel_values = torch.stack([example["pixel_values"] for example in examples])
 831        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 832        input_ids = torch.stack([example["input_ids"] for example in examples])
 833        return {"pixel_values": pixel_values, "input_ids": input_ids}
 834
 835    # DataLoaders creation:
 836    train_dataloader = torch.utils.data.DataLoader(
 837        train_dataset,
 838        shuffle=True,
 839        collate_fn=collate_fn,
 840        batch_size=args.train_batch_size,
 841        num_workers=args.dataloader_num_workers,
 842    )
 843
 844    # Scheduler and math around the number of training steps.
 845    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
 846    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
 847    if args.max_train_steps is None:
 848        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
 849        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
 850        num_training_steps_for_scheduler = (
 851            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
 852        )
 853    else:
 854        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 855
 856    lr_scheduler = get_scheduler(
 857        args.lr_scheduler,
 858        optimizer=optimizer,
 859        num_warmup_steps=num_warmup_steps_for_scheduler,
 860        num_training_steps=num_training_steps_for_scheduler,
 861    )
 862
 863    # Prepare everything with our `accelerator`.
 864    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
 865        unet, optimizer, train_dataloader, lr_scheduler
 866    )
 867
 868    if args.use_ema:
 869        if args.offload_ema:
 870            ema_unet.pin_memory()
 871        else:
 872            ema_unet.to(accelerator.device)
 873
 874    # For mixed precision training we cast all non-trainable weights (vae and text_encoder) to half-precision
 875    # as these weights are only used for inference, keeping weights in full precision is not required.
 876    weight_dtype = torch.float32
 877    if accelerator.mixed_precision == "fp16":
 878        weight_dtype = torch.float16
 879        args.mixed_precision = accelerator.mixed_precision
 880    elif accelerator.mixed_precision == "bf16":
 881        weight_dtype = torch.bfloat16
 882        args.mixed_precision = accelerator.mixed_precision
 883
 884    # Move text_encoder and vae to the accelerator device and cast to weight_dtype
 885    text_encoder.to(accelerator.device, dtype=weight_dtype)
 886    vae.to(accelerator.device, dtype=weight_dtype)
 887
 888    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
 889    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
 890    if args.max_train_steps is None:
 891        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 892        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
 893            logger.warning(
 894                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
 895                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
 896                f"This inconsistency may result in the learning rate scheduler not functioning properly."
 897            )
 898    # Afterwards we recalculate our number of training epochs
 899    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 900
 901    # We need to initialize the trackers we use, and also store our configuration.
 902    # The trackers initialize automatically on the main process.
 903    if accelerator.is_main_process:
 904        tracker_config = dict(vars(args))
 905        tracker_config.pop("validation_prompts")
 906        accelerator.init_trackers(args.tracker_project_name, tracker_config)
 907
 908    # Function for unwrapping if model was compiled with `torch.compile`.
 909    def unwrap_model(model):
 910        model = accelerator.unwrap_model(model)
 911        model = model._orig_mod if is_compiled_module(model) else model
 912        return model
 913
 914    # Train!
 915    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 916
 917    logger.info("***** Running training *****")
 918    logger.info(f"  Num examples = {len(train_dataset)}")
 919    logger.info(f"  Num Epochs = {args.num_train_epochs}")
 920    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
 921    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
 922    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
 923    logger.info(f"  Total optimization steps = {args.max_train_steps}")
 924    global_step = 0
 925    first_epoch = 0
 926
 927    # Potentially load in the weights and states from a previous save
 928    if args.resume_from_checkpoint:
 929        if args.resume_from_checkpoint != "latest":
 930            path = os.path.basename(args.resume_from_checkpoint)
 931        else:
 932            # Get the most recent checkpoint
 933            dirs = os.listdir(args.output_dir)
 934            dirs = [d for d in dirs if d.startswith("checkpoint")]
 935            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
 936            path = dirs[-1] if len(dirs) > 0 else None
 937
 938        if path is None:
 939            accelerator.print(
 940                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
 941            )
 942            args.resume_from_checkpoint = None
 943            initial_global_step = 0
 944        else:
 945            accelerator.print(f"Resuming from checkpoint {path}")
 946            accelerator.load_state(os.path.join(args.output_dir, path))
 947            global_step = int(path.split("-")[1])
 948
 949            initial_global_step = global_step
 950            first_epoch = global_step // num_update_steps_per_epoch
 951
 952    else:
 953        initial_global_step = 0
 954
 955    progress_bar = tqdm(
 956        range(0, args.max_train_steps),
 957        initial=initial_global_step,
 958        desc="Steps",
 959        # Only show the progress bar once on each machine.
 960        disable=not accelerator.is_local_main_process,
 961    )
 962
 963    for epoch in range(first_epoch, args.num_train_epochs):
 964        train_loss = 0.0
 965        for step, batch in enumerate(train_dataloader):
 966            with accelerator.accumulate(unet):
 967                # Convert images to latent space
 968                latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
 969                latents = latents * vae.config.scaling_factor
 970
 971                # Sample noise that we'll add to the latents
 972                noise = torch.randn_like(latents)
 973                if args.noise_offset:
 974                    # https://www.crosslabs.org//blog/diffusion-with-offset-noise
 975                    noise += args.noise_offset * torch.randn(
 976                        (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
 977                    )
 978                if args.input_perturbation:
 979                    new_noise = noise + args.input_perturbation * torch.randn_like(noise)
 980                bsz = latents.shape[0]
 981                # Sample a random timestep for each image
 982                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
 983                timesteps = timesteps.long()
 984
 985                # Add noise to the latents according to the noise magnitude at each timestep
 986                # (this is the forward diffusion process)
 987                if args.input_perturbation:
 988                    noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
 989                else:
 990                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 991
 992                # Get the text embedding for conditioning
 993                encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
 994
 995                # Get the target for loss depending on the prediction type
 996                if args.prediction_type is not None:
 997                    # set prediction_type of scheduler if defined
 998                    noise_scheduler.register_to_config(prediction_type=args.prediction_type)
 999
1000                if noise_scheduler.config.prediction_type == "epsilon":
1001                    target = noise
1002                elif noise_scheduler.config.prediction_type == "v_prediction":
1003                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
1004                else:
1005                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1006
1007                if args.dream_training:
1008                    noisy_latents, target = compute_dream_and_update_latents(
1009                        unet,
1010                        noise_scheduler,
1011                        timesteps,
1012                        noise,
1013                        noisy_latents,
1014                        target,
1015                        encoder_hidden_states,
1016                        args.dream_detail_preservation,
1017                    )
1018
1019                # Predict the noise residual and compute loss
1020                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
1021
1022                if args.snr_gamma is None:
1023                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1024                else:
1025                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1026                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1027                    # This is discussed in Section 4.2 of the same paper.
1028                    snr = compute_snr(noise_scheduler, timesteps)
1029                    mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1030                        dim=1
1031                    )[0]
1032                    if noise_scheduler.config.prediction_type == "epsilon":
1033                        mse_loss_weights = mse_loss_weights / snr
1034                    elif noise_scheduler.config.prediction_type == "v_prediction":
1035                        mse_loss_weights = mse_loss_weights / (snr + 1)
1036
1037                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1038                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1039                    loss = loss.mean()
1040
1041                # Gather the losses across all processes for logging (if we use distributed training).
1042                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1043                train_loss += avg_loss.item() / args.gradient_accumulation_steps
1044
1045                # Backpropagate
1046                accelerator.backward(loss)
1047                if accelerator.sync_gradients:
1048                    accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1049                optimizer.step()
1050                lr_scheduler.step()
1051                optimizer.zero_grad()
1052
1053            # Checks if the accelerator has performed an optimization step behind the scenes
1054            if accelerator.sync_gradients:
1055                if args.use_ema:
1056                    if args.offload_ema:
1057                        ema_unet.to(device="cuda", non_blocking=True)
1058                    ema_unet.step(unet.parameters())
1059                    if args.offload_ema:
1060                        ema_unet.to(device="cpu", non_blocking=True)
1061                progress_bar.update(1)
1062                global_step += 1
1063                accelerator.log({"train_loss": train_loss}, step=global_step)
1064                train_loss = 0.0
1065
1066                if global_step % args.checkpointing_steps == 0:
1067                    if accelerator.is_main_process:
1068                        # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1069                        if args.checkpoints_total_limit is not None:
1070                            checkpoints = os.listdir(args.output_dir)
1071                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1072                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1073
1074                            # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1075                            if len(checkpoints) >= args.checkpoints_total_limit:
1076                                num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1077                                removing_checkpoints = checkpoints[0:num_to_remove]
1078
1079                                logger.info(
1080                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1081                                )
1082                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1083
1084                                for removing_checkpoint in removing_checkpoints:
1085                                    removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1086                                    shutil.rmtree(removing_checkpoint)
1087
1088                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1089                        accelerator.save_state(save_path)
1090                        logger.info(f"Saved state to {save_path}")
1091
1092            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1093            progress_bar.set_postfix(**logs)
1094
1095            if global_step >= args.max_train_steps:
1096                break
1097
1098        if accelerator.is_main_process:
1099            if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
1100                if args.use_ema:
1101                    # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1102                    ema_unet.store(unet.parameters())
1103                    ema_unet.copy_to(unet.parameters())
1104                log_validation(
1105                    vae,
1106                    text_encoder,
1107                    tokenizer,
1108                    unet,
1109                    args,
1110                    accelerator,
1111                    weight_dtype,
1112                    global_step,
1113                )
1114                if args.use_ema:
1115                    # Switch back to the original UNet parameters.
1116                    ema_unet.restore(unet.parameters())
1117
1118    # Create the pipeline using the trained modules and save it.
1119    accelerator.wait_for_everyone()
1120    if accelerator.is_main_process:
1121        unet = unwrap_model(unet)
1122        if args.use_ema:
1123            ema_unet.copy_to(unet.parameters())
1124
1125        pipeline = StableDiffusionPipeline.from_pretrained(
1126            args.pretrained_model_name_or_path,
1127            text_encoder=text_encoder,
1128            vae=vae,
1129            unet=unet,
1130            revision=args.revision,
1131            variant=args.variant,
1132        )
1133        pipeline.save_pretrained(args.output_dir)
1134
1135        # Run a final round of inference.
1136        images = []
1137        if args.validation_prompts is not None:
1138            logger.info("Running inference for collecting generated images...")
1139            pipeline = pipeline.to(accelerator.device)
1140            pipeline.torch_dtype = weight_dtype
1141            pipeline.set_progress_bar_config(disable=True)
1142
1143            if args.enable_xformers_memory_efficient_attention:
1144                pipeline.enable_xformers_memory_efficient_attention()
1145
1146            if args.seed is None:
1147                generator = None
1148            else:
1149                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1150
1151            for i in range(len(args.validation_prompts)):
1152                with torch.autocast("cuda"):
1153                    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1154                images.append(image)
1155
1156        if args.push_to_hub:
1157            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1158            upload_folder(
1159                repo_id=repo_id,
1160                folder_path=args.output_dir,
1161                commit_message="End of training",
1162                ignore_patterns=["step_*", "epoch_*"],
1163            )
1164
1165    accelerator.end_training()
1166
1167
1168if __name__ == "__main__":
1169    main()
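
After saving the script as train_text_to_image.py, training can be launched through accelerate. On an NPU machine, recent versions of accelerate should detect the Ascend device automatically once torch_npu is installed (you can also run accelerate config beforehand to customize the launch settings). An example launch command, using the naruto-blip-captions dataset: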
 1export MODEL_NAME="CompVis/stable-diffusion-v1-4"
 2export DATASET_NAME="lambdalabs/naruto-blip-captions"
 3
 4accelerate launch --mixed_precision="fp16" train_text_to_image.py \
 5--pretrained_model_name_or_path=$MODEL_NAME \
 6--dataset_name=$DATASET_NAME \
 7--use_ema \
 8--resolution=512 --center_crop --random_flip \
 9--train_batch_size=1 \
10--gradient_accumulation_steps=4 \
11--gradient_checkpointing \
12--max_train_steps=15000 \
13--learning_rate=1e-05 \
14--max_grad_norm=1 \
15--lr_scheduler="constant" --lr_warmup_steps=0 \
16--output_dir="sd-pokemon-model"
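
After training completes, the fine-tuned pipeline is saved to --output_dir and can be loaded like any other diffusers pipeline. The following is a minimal inference sketch, not part of the training script: it assumes torch_npu is installed, that --output_dir was "sd-naruto-model" as in the command above, and the prompt "yoda" is only an example.

 1import torch
 2# torch_npu registers the NPU backend with PyTorch
 3import torch_npu
 4
 5from diffusers import StableDiffusionPipeline
 6
 7# Path below is the --output_dir used in the launch command above
 8pipeline = StableDiffusionPipeline.from_pretrained("sd-naruto-model", torch_dtype=torch.float16)
 9# Move the pipeline to the NPU, mirroring the single-card example earlier in this guide
10pipeline = pipeline.to("npu")
11
12image = pipeline(prompt="yoda").images[0]
13image.save("yoda-naruto.png")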