Quick Start
Note
Before running the examples below, make sure you have set up a PyTorch-NPU environment. For installation instructions, see the Installation Guide.
In general, using an NPU for training or inference requires the following changes to your code:
Import the torch_npu extension package:
import torch_npu
Move the model and the model inputs onto the NPU:
device = torch.device("npu")
model = model.to(device)
input = input.to(device)
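As a quick sanity check, the short script below (a minimal sketch, assuming the torch_npu environment from the Installation Guide) runs a small tensor computation on the NPU and falls back to the CPU when none is found:

import torch
import torch_npu  # registers the "npu" device type with PyTorch

# Fall back to the CPU when no NPU is available.
device = torch.device("npu:0" if torch.npu.is_available() else "cpu")

x = torch.randn(2, 3).to(device)  # move the input onto the NPU
y = (x + x).sum()                 # computed on the NPU when one is present
print(device, y.item())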
The following examples demonstrate how to run training and inference tasks on an NPU:
1. Single-Card Training
The following code trains a model on the CIFAR10 dataset on an NPU (excerpted from the PyTorch tutorials). Pay attention to the highlighted NPU-specific changes.
1"""
2Training an image classifier
3----------------------------
4
5We will do the following steps in order:
6
71. Load and normalize the CIFAR10 training and test datasets using
8``torchvision``
91. Define a Convolutional Neural Network
102. Define a loss function
113. Train the network on the training data
124. Test the network on the test data
13
145. Load and normalize CIFAR10
15^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
16
17Using ``torchvision``, it’s extremely easy to load CIFAR10.
18"""
19import torch
20# 引入torch-npu包
21import torch_npu
22
23# 定义device
24device = torch.device('npu:0' if torch.npu.is_available() else 'cpu')
25print(device)
26
27import torchvision
28import torchvision.transforms as transforms
29
30########################################################################
31# The output of torchvision datasets are PILImage images of range [0, 1].
32# We transform them to Tensors of normalized range [-1, 1].
33transform = transforms.Compose(
34 [transforms.ToTensor(),
35 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
37batch_size = 4
38
39trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
40 download=True, transform=transform)
41trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
42 shuffle=True, num_workers=2)
43
44testset = torchvision.datasets.CIFAR10(root='./data', train=False,
45 download=True, transform=transform)
46testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
47 shuffle=False, num_workers=2)
48
49classes = ('plane', 'car', 'bird', 'cat',
50 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
51
52########################################################################
53# 2. Define a Convolutional Neural Network
54# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55# Copy the neural network from the Neural Networks section before and modify it to
56# take 3-channel images (instead of 1-channel images as it was defined).
57import torch.nn as nn
58import torch.nn.functional as F
59
60
61class Net(nn.Module):
62 def __init__(self):
63 super().__init__()
64 self.conv1 = nn.Conv2d(3, 6, 5)
65 self.pool = nn.MaxPool2d(2, 2)
66 self.conv2 = nn.Conv2d(6, 16, 5)
67 self.fc1 = nn.Linear(16 * 5 * 5, 120)
68 self.fc2 = nn.Linear(120, 84)
69 self.fc3 = nn.Linear(84, 10)
70
71 def forward(self, x):
72 x = self.pool(F.relu(self.conv1(x)))
73 x = self.pool(F.relu(self.conv2(x)))
74 x = torch.flatten(x, 1) # flatten all dimensions except batch
75 x = F.relu(self.fc1(x))
76 x = F.relu(self.fc2(x))
77 x = self.fc3(x)
78 return x
79
80net = Net()
81
82# 将模型加载到NPU上
83net.to(device)
84
85########################################################################
86# 3. Define a Loss function and optimizer
87# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
88# Let's use a Classification Cross-Entropy loss and SGD with momentum.
89import torch.optim as optim
90
91criterion = nn.CrossEntropyLoss()
92optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
93
94########################################################################
95# 4. Train the network
96# ^^^^^^^^^^^^^^^^^^^^
97#
98# This is when things start to get interesting.
99# We simply have to loop over our data iterator, and feed the inputs to the
100# network and optimize.
101
102for epoch in range(2): # loop over the dataset multiple times
103
104 running_loss = 0.0
105 for i, data in enumerate(trainloader, 0):
106 # get the inputs; data is a list of [inputs, labels]
107 # 将input数据发送到NPU上
108 inputs, labels = data[0].to(device), data[1].to(device)
109
110 # zero the parameter gradients
111 optimizer.zero_grad()
112
113 # forward + backward + optimize
114 outputs = net(inputs)
115 loss = criterion(outputs, labels)
116 loss.backward()
117 optimizer.step()
118
119 # print statistics
120 running_loss += loss.item()
121 if i % 2000 == 1999: # print every 2000 mini-batches
122 print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
123 running_loss = 0.0
124
125print('Finished Training')
126
127########################################################################
128# 5. Test the network on the test data
129# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
130#
131# We have trained the network for 2 passes over the training dataset.
132# But we need to check if the network has learnt anything at all.
133#
134# We will check this by predicting the class label that the neural network
135# outputs, and checking it against the ground-truth. If the prediction is
136# correct, we add the sample to the list of correct predictions.
137#
138# Let us look at how the network performs on the whole dataset.
139correct = 0
140total = 0
141# since we're not training, we don't need to calculate the gradients for our outputs
142with torch.no_grad():
143 for data in testloader:
144 # 将input数据发送到NPU上
145 images, labels = data[0].to(device), data[1].to(device)
146 # calculate outputs by running images through the network
147 outputs = net(images)
148 # the class with the highest energy is what we choose as prediction
149 _, predicted = torch.max(outputs.data, 1)
150 total += labels.size(0)
151 correct += (predicted == labels).sum().item()
152
153print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
154########################################################################
155# That looks way better than chance, which is 10% accuracy (randomly picking
156# a class out of 10 classes).
157# Seems like the network learnt something.
158#
159# Hmmm, what are the classes that performed well, and the classes that did
160# not perform well:
161
162# prepare to count predictions for each class
163correct_pred = {classname: 0 for classname in classes}
164total_pred = {classname: 0 for classname in classes}
165
166# again no gradients needed
167with torch.no_grad():
168 for data in testloader:
169 # 将input数据发送到NPU上
170 images, labels = data[0].to(device), data[1].to(device)
171 outputs = net(images)
172 _, predictions = torch.max(outputs, 1)
173 # collect the correct predictions for each class
174 for label, prediction in zip(labels, predictions):
175 if label == prediction:
176 correct_pred[classes[label]] += 1
177 total_pred[classes[label]] += 1
178
179
180# print accuracy for each class
181for classname, correct_count in correct_pred.items():
182 accuracy = 100 * float(correct_count) / total_pred[classname]
183 print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
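As in the original tutorial, the trained weights can be saved and reloaded for later inference. A minimal sketch (the file path is just an example) that moves the reloaded model back onto the NPU:

PATH = './cifar_net.pth'  # example save path
torch.save(net.state_dict(), PATH)

# Reload the weights on the CPU, then move the model back onto the NPU.
net = Net()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.to(device)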
2. Multi-Card Parallel Training with DeepSpeed
The following code uses DeepSpeed to train a model on the CIFAR10 dataset across multiple NPU cards (from DeepSpeed Examples). Since DeepSpeed v0.12.6, the code needs no modification: it automatically detects the NPUs and trains on them.
import argparse
import os

import deepspeed
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from deepspeed.accelerator import get_accelerator
from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer


def add_argument():
    parser = argparse.ArgumentParser(description="CIFAR")

    # For train.
    parser.add_argument(
        "-e",
        "--epochs",
        default=30,
        type=int,
        help="number of total epochs (default: 30)",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="local rank passed from distributed launcher",
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=2000,
        help="output logging information at a given interval",
    )

    # For mixed precision training.
    parser.add_argument(
        "--dtype",
        default="fp16",
        type=str,
        choices=["bf16", "fp16", "fp32"],
        help="Datatype used for training",
    )

    # For ZeRO Optimization.
    parser.add_argument(
        "--stage",
        default=0,
        type=int,
        choices=[0, 1, 2, 3],
        help="ZeRO optimization stage used for training",
    )

    # For MoE (Mixture of Experts).
    parser.add_argument(
        "--moe",
        default=False,
        action="store_true",
        help="use deepspeed mixture of experts (moe)",
    )
    parser.add_argument(
        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
    )
    parser.add_argument(
        "--num-experts",
        type=int,
        nargs="+",
        default=[
            1,
        ],
        help="number of experts list, MoE related.",
    )
    parser.add_argument(
        "--mlp-type",
        type=str,
        default="standard",
        help="Only applicable when num-experts > 1, accepts [standard, residual]",
    )
    parser.add_argument(
        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
    )
    parser.add_argument(
        "--min-capacity",
        default=0,
        type=int,
        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
    )
    parser.add_argument(
        "--noisy-gate-policy",
        default=None,
        type=str,
        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
    )
    parser.add_argument(
        "--moe-param-group",
        default=False,
        action="store_true",
        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
    )

    # Include DeepSpeed configuration arguments.
    parser = deepspeed.add_config_arguments(parser)

    args = parser.parse_args()

    return args


def create_moe_param_groups(model):
    """Create separate parameter groups for each expert."""
    parameters = {"params": [p for p in model.parameters()], "name": "parameters"}
    return split_params_into_different_moe_groups_for_optimizer(parameters)


def get_ds_config(args):
    """Get the DeepSpeed configuration dictionary."""
    ds_config = {
        "train_batch_size": 16,
        "steps_per_print": 2000,
        "optimizer": {
            "type": "Adam",
            "params": {
                "lr": 0.001,
                "betas": [0.8, 0.999],
                "eps": 1e-8,
                "weight_decay": 3e-7,
            },
        },
        "scheduler": {
            "type": "WarmupLR",
            "params": {
                "warmup_min_lr": 0,
                "warmup_max_lr": 0.001,
                "warmup_num_steps": 1000,
            },
        },
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "bf16": {"enabled": args.dtype == "bf16"},
        "fp16": {
            "enabled": args.dtype == "fp16",
            "fp16_master_weights_and_grads": False,
            "loss_scale": 0,
            "loss_scale_window": 500,
            "hysteresis": 2,
            "min_loss_scale": 1,
            "initial_scale_power": 15,
        },
        "wall_clock_breakdown": False,
        "zero_optimization": {
            "stage": args.stage,
            "allgather_partitions": True,
            "reduce_scatter": True,
            "allgather_bucket_size": 50000000,
            "reduce_bucket_size": 50000000,
            "overlap_comm": True,
            "contiguous_gradients": True,
            "cpu_offload": False,
        },
    }
    return ds_config


class Net(nn.Module):
    def __init__(self, args):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.moe = args.moe
        if self.moe:
            fc3 = nn.Linear(84, 84)
            self.moe_layer_list = []
            for n_e in args.num_experts:
                # Create moe layers based on the number of experts.
                self.moe_layer_list.append(
                    deepspeed.moe.layer.MoE(
                        hidden_size=84,
                        expert=fc3,
                        num_experts=n_e,
                        ep_size=args.ep_world_size,
                        use_residual=args.mlp_type == "residual",
                        k=args.top_k,
                        min_capacity=args.min_capacity,
                        noisy_gate_policy=args.noisy_gate_policy,
                    )
                )
            self.moe_layer_list = nn.ModuleList(self.moe_layer_list)
            self.fc4 = nn.Linear(84, 10)
        else:
            self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        if self.moe:
            for layer in self.moe_layer_list:
                x, _, _ = layer(x)
            x = self.fc4(x)
        else:
            x = self.fc3(x)
        return x


def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
    """Test the network on the test data.

    Args:
        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
        testset (torch.utils.data.Dataset): the test dataset.
        local_device (str): the local device name.
        target_dtype (torch.dtype): the target datatype for the test data.
        test_batch_size (int): the test batch size.

    """
    # The 10 classes for CIFAR10.
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    # Define the test dataloader.
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
    )

    # For total accuracy.
    correct, total = 0, 0
    # For accuracy per class.
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))

    # Start testing.
    model_engine.eval()
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            if target_dtype != None:
                images = images.to(target_dtype)
            outputs = model_engine(images.to(local_device))
            _, predicted = torch.max(outputs.data, 1)
            # Count the total accuracy.
            total += labels.size(0)
            correct += (predicted == labels.to(local_device)).sum().item()

            # Count the accuracy per class.
            batch_correct = (predicted == labels.to(local_device)).squeeze()
            for i in range(test_batch_size):
                label = labels[i]
                class_correct[label] += batch_correct[i].item()
                class_total[label] += 1

    if model_engine.local_rank == 0:
        print(
            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
        )

        # For all classes, print the accuracy.
        for i in range(10):
            print(
                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
            )


def main(args):
    # Initialize DeepSpeed distributed backend.
    deepspeed.init_distributed()
    _local_rank = int(os.environ.get("LOCAL_RANK"))
    get_accelerator().set_device(_local_rank)

    ########################################################################
    # Step1. Data Preparation.
    #
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1].
    #
    # Note:
    #     If running on Windows and you get a BrokenPipeError, try setting
    #     the num_worker of torch.utils.data.DataLoader() to 0.
    ########################################################################
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    if torch.distributed.get_rank() != 0:
        # Might be downloading cifar data, let rank 0 download first.
        torch.distributed.barrier()

    # Load or download cifar data.
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )

    if torch.distributed.get_rank() == 0:
        # Cifar data is downloaded, indicate other ranks can proceed.
        torch.distributed.barrier()

    ########################################################################
    # Step 2. Define the network with DeepSpeed.
    #
    # First, we define a Convolution Neural Network.
    # Then, we define the DeepSpeed configuration dictionary and use it to
    # initialize the DeepSpeed engine.
    ########################################################################
    net = Net(args)

    # Get list of parameters that require gradients.
    parameters = filter(lambda p: p.requires_grad, net.parameters())

    # If using MoE, create separate param groups for each expert.
    if args.moe_param_group:
        parameters = create_moe_param_groups(net)

    # Initialize DeepSpeed to use the following features.
    #   1) Distributed model.
    #   2) Distributed data loader.
    #   3) DeepSpeed optimizer.
    ds_config = get_ds_config(args)
    model_engine, optimizer, trainloader, __ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=parameters,
        training_data=trainset,
        config=ds_config,
    )

    # Get the local device name (str) and local rank (int).
    local_device = get_accelerator().device_name(model_engine.local_rank)
    local_rank = model_engine.local_rank

    # For float32, target_dtype will be None so no datatype conversion needed.
    target_dtype = None
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
    elif model_engine.fp16_enabled():
        target_dtype = torch.half

    # Define the Classification Cross-Entropy loss function.
    criterion = nn.CrossEntropyLoss()

    ########################################################################
    # Step 3. Train the network.
    #
    # This is when things start to get interesting.
    # We simply have to loop over our data iterator, and feed the inputs to the
    # network and optimize. (DeepSpeed handles the distributed details for us!)
    ########################################################################

    for epoch in range(args.epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            # Get the inputs. ``data`` is a list of [inputs, labels].
            inputs, labels = data[0].to(local_device), data[1].to(local_device)

            # Try to convert to target_dtype if needed.
            if target_dtype != None:
                inputs = inputs.to(target_dtype)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

            # Print statistics
            running_loss += loss.item()
            if local_rank == 0 and i % args.log_interval == (
                args.log_interval - 1
            ):  # Print every log_interval mini-batches.
                print(
                    f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
                )
                running_loss = 0.0
    print("Finished Training")

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)


if __name__ == "__main__":
    args = add_argument()
    main(args)
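The script is normally started with the DeepSpeed launcher, which spawns one process per device, for example `deepspeed cifar10_deepspeed.py --deepspeed --stage=2` (assuming the script above is saved as cifar10_deepspeed.py). To confirm which accelerator DeepSpeed picked up, a quick check (a sketch, assuming DeepSpeed v0.12.6 or later):

# Print the accelerator that DeepSpeed auto-detected; on an Ascend machine
# this is expected to report an NPU device rather than a GPU or the CPU.
from deepspeed.accelerator import get_accelerator

print(get_accelerator().device_name(0))  # e.g. npu:0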
3. Fine-Tuning a Model with Transformers
The following code fine-tunes an LLM with Transformers (from the transformers examples). Since transformers version xxx and accelerate v0.21.0, the code needs no modification: it automatically detects the NPUs and runs the fine-tuning on them.
#!/usr/bin/env python
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.

Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
https://huggingface.co/models?filter=text-generation
"""
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.

import logging
import math
import os
import sys
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional

import datasets
import evaluate
import torch
from datasets import load_dataset

import transformers
from transformers import (
    CONFIG_MAPPING,
    MODEL_FOR_CAUSAL_LM_MAPPING,
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    default_data_collator,
    is_torch_xla_available,
    set_seed,
)
from transformers.testing_utils import CaptureLogger
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version


# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.52.0.dev0")

require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")

logger = logging.getLogger(__name__)


MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    model_type: Optional[str] = field(
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override some existing default config settings when a model is trained from scratch. Example: "
                "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to trust the execution of code from datasets/models defined on the Hub."
                " This option should only be set to `True` for repositories you trust and in which you have read the"
                " code, as it will execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )
    low_cpu_mem_usage: bool = field(
        default=False,
        metadata={
            "help": (
                "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
                "set True will benefit LLM loading time and RAM consumption."
            )
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
    validation_file: Optional[str] = field(
        default=None,
        metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of training examples to this "
                "value if set."
            )
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                "value if set."
            )
        },
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    keep_linebreaks: bool = field(
        default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
    )

    def __post_init__(self):
        if self.streaming:
            require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

        if self.dataset_name is None and self.train_file is None and self.validation_file is None:
            raise ValueError("Need either a dataset name or a training/validation file.")
        else:
            if self.train_file is not None:
                extension = self.train_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
            if self.validation_file is not None:
                extension = self.validation_file.split(".")[-1]
                assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_clm", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    if training_args.should_log:
        # The default of training_args.log_level is passive, so we set log level at info here to have that default.
        transformers.utils.logging.set_verbosity_info()

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
    #
    # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
    # 'text' is found. You can easily tweak this behavior (see below).
    #
    # In distributed training, the load_dataset function guarantee that only one local process can concurrently
    # download the dataset.
    if data_args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset(
            data_args.dataset_name,
            data_args.dataset_config_name,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            streaming=data_args.streaming,
            trust_remote_code=model_args.trust_remote_code,
        )
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
            raw_datasets["train"] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                streaming=data_args.streaming,
                trust_remote_code=model_args.trust_remote_code,
            )
    else:
        data_files = {}
        dataset_args = {}
        if data_args.train_file is not None:
            data_files["train"] = data_args.train_file
        if data_args.validation_file is not None:
            data_files["validation"] = data_args.validation_file
        extension = (
            data_args.train_file.split(".")[-1]
            if data_args.train_file is not None
            else data_args.validation_file.split(".")[-1]
        )
        if extension == "txt":
            extension = "text"
            dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
        raw_datasets = load_dataset(
            extension,
            data_files=data_files,
            cache_dir=model_args.cache_dir,
            token=model_args.token,
            **dataset_args,
        )
        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
        if "validation" not in raw_datasets.keys():
            raw_datasets["validation"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[:{data_args.validation_split_percentage}%]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )
            raw_datasets["train"] = load_dataset(
                extension,
                data_files=data_files,
                split=f"train[{data_args.validation_split_percentage}%:]",
                cache_dir=model_args.cache_dir,
                token=model_args.token,
                **dataset_args,
            )

    # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
    # https://huggingface.co/docs/datasets/loading_datasets.

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.

    config_kwargs = {
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.config_name:
        config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
    elif model_args.model_name_or_path:
        config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
            logger.info(f"New config: {config}")

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }
    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
    else:
        raise ValueError(
            "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
            "You can do it from another script, save it, and load it from here, using --tokenizer_name."
        )

    if model_args.model_name_or_path:
        torch_dtype = (
            model_args.torch_dtype
            if model_args.torch_dtype in ["auto", None]
            else getattr(torch, model_args.torch_dtype)
        )
        model = AutoModelForCausalLM.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
            revision=model_args.model_revision,
            token=model_args.token,
            trust_remote_code=model_args.trust_remote_code,
            torch_dtype=torch_dtype,
            low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        )
    else:
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
        n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
        logger.info(f"Training new model from scratch - Total size={n_params / 2**20:.2f}M params")

    # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
    # on a small vocab and want a smaller embedding size, remove this test.
    embedding_size = model.get_input_embeddings().weight.shape[0]
    if len(tokenizer) > embedding_size:
        model.resize_token_embeddings(len(tokenizer))

    # Preprocessing the datasets.
    # First we tokenize all the texts.
    if training_args.do_train:
        column_names = list(raw_datasets["train"].features)
    else:
        column_names = list(raw_datasets["validation"].features)
    text_column_name = "text" if "text" in column_names else column_names[0]

    # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
    tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")

    def tokenize_function(examples):
        with CaptureLogger(tok_logger) as cl:
            output = tokenizer(examples[text_column_name])
        # clm input could be much much longer than block_size
        if "Token indices sequence length is longer than the" in cl.out:
            tok_logger.warning(
                "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
                " before being passed to the model."
            )
        return output

    with training_args.main_process_first(desc="dataset map tokenization"):
        if not data_args.streaming:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                remove_columns=column_names,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="Running tokenizer on dataset",
            )
        else:
            tokenized_datasets = raw_datasets.map(
                tokenize_function,
                batched=True,
                remove_columns=column_names,
            )
    if hasattr(config, "max_position_embeddings"):
        max_pos_embeddings = config.max_position_embeddings
    else:
        # Define a default value if the attribute is missing in the config.
        max_pos_embeddings = 1024

    if data_args.block_size is None:
        block_size = tokenizer.model_max_length
        if block_size > max_pos_embeddings:
            logger.warning(
                f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
                f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
            )
            if max_pos_embeddings > 0:
                block_size = min(1024, max_pos_embeddings)
            else:
                block_size = 1024
    else:
        if data_args.block_size > tokenizer.model_max_length:
            logger.warning(
                f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
                f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
            )
        block_size = min(data_args.block_size, tokenizer.model_max_length)

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
        # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
        total_length = (total_length // block_size) * block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
    # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
    # to preprocess.
    #
    # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
    # https://huggingface.co/docs/datasets/process#map

    with training_args.main_process_first(desc="grouping texts together"):
        if not data_args.streaming:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc=f"Grouping texts in chunks of {block_size}",
            )
        else:
            lm_datasets = tokenized_datasets.map(
                group_texts,
                batched=True,
            )

    if training_args.do_train:
        if "train" not in tokenized_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = lm_datasets["train"]
        if data_args.max_train_samples is not None:
            max_train_samples = min(len(train_dataset), data_args.max_train_samples)
            train_dataset = train_dataset.select(range(max_train_samples))

    if training_args.do_eval:
        if "validation" not in tokenized_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = lm_datasets["validation"]
        if data_args.max_eval_samples is not None:
            max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        def preprocess_logits_for_metrics(logits, labels):
            if isinstance(logits, tuple):
                # Depending on the model and config, logits may contain extra tensors,
                # like past_key_values, but logits always come first
                logits = logits[0]
            return logits.argmax(dim=-1)

        metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)

        def compute_metrics(eval_preds):
            preds, labels = eval_preds
            # preds have the same shape as the labels, after the argmax(-1) has been calculated
            # by preprocess_logits_for_metrics but we need to shift the labels
            labels = labels[:, 1:].reshape(-1)
            preds = preds[:, :-1].reshape(-1)
            return metric.compute(predictions=preds, references=labels)

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        processing_class=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator,
        compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics
        if training_args.do_eval and not is_torch_xla_available()
        else None,
    )

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        trainer.save_model()  # Saves the tokenizer too for easy upload

        metrics = train_result.metrics

        max_train_samples = (
            data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
        )
        metrics["train_samples"] = min(max_train_samples, len(train_dataset))

        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluation
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        metrics = trainer.evaluate()

        max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
        metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
        try:
            perplexity = math.exp(metrics["eval_loss"])
        except OverflowError:
            perplexity = float("inf")
        metrics["perplexity"] = perplexity

        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
    if data_args.dataset_name is not None:
        kwargs["dataset_tags"] = data_args.dataset_name
        if data_args.dataset_config_name is not None:
            kwargs["dataset_args"] = data_args.dataset_config_name
            kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
        else:
            kwargs["dataset"] = data_args.dataset_name

    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
python run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm
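To verify that the Trainer will run on the NPU, you can inspect the device selected by Accelerate, which Transformers uses for device placement (a minimal sketch, assuming accelerate v0.21.0 or later and torch_npu are installed):

# Accelerate picks the best available backend; on an Ascend machine this
# should report an NPU device rather than the CPU.
from accelerate import Accelerator

print(Accelerator().device)  # e.g. device(type='npu', index=0)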
4. Fine-Tuning a Model with Diffusers
The following code fine-tunes a text-to-image model with Diffusers (from the diffusers examples). Since diffusers v0.27.0, the code needs no modification: it automatically detects the NPUs and runs on them.
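Before the full training script below, here is a minimal inference sketch showing that a Diffusers pipeline can also be placed on the NPU explicitly (the checkpoint id is only an example):

import torch
import torch_npu  # registers the "npu" device type
from diffusers import DiffusionPipeline

# Example Stable Diffusion checkpoint; any compatible checkpoint works the same way.
pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("npu")  # move the whole pipeline onto the NPU
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("npu_sample.png")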
1#!/usr/bin/env python
2# coding=utf-8
3# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18import logging
19import math
20import os
21import random
22import shutil
23from contextlib import nullcontext
24from pathlib import Path
25
26import accelerate
27import datasets
28import numpy as np
29import torch
30import torch.nn.functional as F
31import torch.utils.checkpoint
32import transformers
33from accelerate import Accelerator
34from accelerate.logging import get_logger
35from accelerate.state import AcceleratorState
36from accelerate.utils import ProjectConfiguration, set_seed
37from datasets import load_dataset
38from huggingface_hub import create_repo, upload_folder
39from packaging import version
40from torchvision import transforms
41from tqdm.auto import tqdm
42from transformers import CLIPTextModel, CLIPTokenizer
43from transformers.utils import ContextManagers
44
45import diffusers
46from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
47from diffusers.optimization import get_scheduler
48from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
49from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
50from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
51from diffusers.utils.import_utils import is_xformers_available
52from diffusers.utils.torch_utils import is_compiled_module
53
54
55if is_wandb_available():
56 import wandb
57
58
59# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60check_min_version("0.34.0.dev0")
61
62logger = get_logger(__name__, log_level="INFO")
63
64DATASET_NAME_MAPPING = {
65 "lambdalabs/naruto-blip-captions": ("image", "text"),
66}
67
68
69def save_model_card(
70 args,
71 repo_id: str,
72 images: list = None,
73 repo_folder: str = None,
74):
75 img_str = ""
76 if len(images) > 0:
77 image_grid = make_image_grid(images, 1, len(args.validation_prompts))
78 image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
79 img_str += "\n"
80
81 model_description = f"""
82# Text-to-image finetuning - {repo_id}
83
84This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
85{img_str}
86
87## Pipeline usage
88
89You can use the pipeline like so:
90
91```python
92from diffusers import DiffusionPipeline
93import torch
94
95pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
96prompt = "{args.validation_prompts[0]}"
97image = pipeline(prompt).images[0]
98image.save("my_image.png")
99```
100
101## Training info
102
103These are the key hyperparameters used during training:
104
105* Epochs: {args.num_train_epochs}
106* Learning rate: {args.learning_rate}
107* Batch size: {args.train_batch_size}
108* Gradient accumulation steps: {args.gradient_accumulation_steps}
109* Image resolution: {args.resolution}
110* Mixed-precision: {args.mixed_precision}
111
112"""
113 wandb_info = ""
114 if is_wandb_available():
115 wandb_run_url = None
116 if wandb.run is not None:
117 wandb_run_url = wandb.run.url
118
119 if wandb_run_url is not None:
120 wandb_info = f"""
121More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
122"""
123
124 model_description += wandb_info
125
126 model_card = load_or_create_model_card(
127 repo_id_or_path=repo_id,
128 from_training=True,
129 license="creativeml-openrail-m",
130 base_model=args.pretrained_model_name_or_path,
131 model_description=model_description,
132 inference=True,
133 )
134
135 tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
136 model_card = populate_model_card(model_card, tags=tags)
137
138 model_card.save(os.path.join(repo_folder, "README.md"))
139
140
141def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
142 logger.info("Running validation... ")
143
144 pipeline = StableDiffusionPipeline.from_pretrained(
145 args.pretrained_model_name_or_path,
146 vae=accelerator.unwrap_model(vae),
147 text_encoder=accelerator.unwrap_model(text_encoder),
148 tokenizer=tokenizer,
149 unet=accelerator.unwrap_model(unet),
150 safety_checker=None,
151 revision=args.revision,
152 variant=args.variant,
153 torch_dtype=weight_dtype,
154 )
155 pipeline = pipeline.to(accelerator.device)
156 pipeline.set_progress_bar_config(disable=True)
157
158 if args.enable_xformers_memory_efficient_attention:
159 pipeline.enable_xformers_memory_efficient_attention()
160
161 if args.seed is None:
162 generator = None
163 else:
164 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
165
166 images = []
167 for i in range(len(args.validation_prompts)):
168 if torch.backends.mps.is_available():
169 autocast_ctx = nullcontext()
170 else:
171 autocast_ctx = torch.autocast(accelerator.device.type)
172
173 with autocast_ctx:
174 image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
175
176 images.append(image)
177
178 for tracker in accelerator.trackers:
179 if tracker.name == "tensorboard":
180 np_images = np.stack([np.asarray(img) for img in images])
181 tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
182 elif tracker.name == "wandb":
183 tracker.log(
184 {
185 "validation": [
186 wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
187 for i, image in enumerate(images)
188 ]
189 }
190 )
191 else:
192 logger.warning(f"image logging not implemented for {tracker.name}")
193
194 del pipeline
195 torch.cuda.empty_cache()
196
197 return images
198
199
200def parse_args():
201 parser = argparse.ArgumentParser(description="Simple example of a training script.")
202 parser.add_argument(
203 "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
204 )
205 parser.add_argument(
206 "--pretrained_model_name_or_path",
207 type=str,
208 default=None,
209 required=True,
210 help="Path to pretrained model or model identifier from huggingface.co/models.",
211 )
212 parser.add_argument(
213 "--revision",
214 type=str,
215 default=None,
216 required=False,
217 help="Revision of pretrained model identifier from huggingface.co/models.",
218 )
219 parser.add_argument(
220 "--variant",
221 type=str,
222 default=None,
223 help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
224 )
225 parser.add_argument(
226 "--dataset_name",
227 type=str,
228 default=None,
229 help=(
230 "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
231 " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
232 " or to a folder containing files that 🤗 Datasets can understand."
233 ),
234 )
235 parser.add_argument(
236 "--dataset_config_name",
237 type=str,
238 default=None,
239 help="The config of the Dataset, leave as None if there's only one config.",
240 )
241 parser.add_argument(
242 "--train_data_dir",
243 type=str,
244 default=None,
245 help=(
246 "A folder containing the training data. Folder contents must follow the structure described in"
247 " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
248 " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
249 ),
250 )
251 parser.add_argument(
252 "--image_column", type=str, default="image", help="The column of the dataset containing an image."
253 )
254 parser.add_argument(
255 "--caption_column",
256 type=str,
257 default="text",
258 help="The column of the dataset containing a caption or a list of captions.",
259 )
260 parser.add_argument(
261 "--max_train_samples",
262 type=int,
263 default=None,
264 help=(
265 "For debugging purposes or quicker training, truncate the number of training examples to this "
266 "value if set."
267 ),
268 )
269 parser.add_argument(
270 "--validation_prompts",
271 type=str,
272 default=None,
273 nargs="+",
274 help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
275 )
276 parser.add_argument(
277 "--output_dir",
278 type=str,
279 default="sd-model-finetuned",
280 help="The output directory where the model predictions and checkpoints will be written.",
281 )
282 parser.add_argument(
283 "--cache_dir",
284 type=str,
285 default=None,
286 help="The directory where the downloaded models and datasets will be stored.",
287 )
288 parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
289 parser.add_argument(
290 "--resolution",
291 type=int,
292 default=512,
293 help=(
294 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
295 " resolution"
296 ),
297 )
298 parser.add_argument(
299 "--center_crop",
300 default=False,
301 action="store_true",
302 help=(
303 "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
304 " cropped. The images will be resized to the resolution first before cropping."
305 ),
306 )
307 parser.add_argument(
308 "--random_flip",
309 action="store_true",
310 help="whether to randomly flip images horizontally",
311 )
312 parser.add_argument(
313 "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
314 )
315 parser.add_argument("--num_train_epochs", type=int, default=100)
316 parser.add_argument(
317 "--max_train_steps",
318 type=int,
319 default=None,
320 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
321 )
322 parser.add_argument(
323 "--gradient_accumulation_steps",
324 type=int,
325 default=1,
326 help="Number of updates steps to accumulate before performing a backward/update pass.",
327 )
328 parser.add_argument(
329 "--gradient_checkpointing",
330 action="store_true",
331 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
332 )
333 parser.add_argument(
334 "--learning_rate",
335 type=float,
336 default=1e-4,
337 help="Initial learning rate (after the potential warmup period) to use.",
338 )
339 parser.add_argument(
340 "--scale_lr",
341 action="store_true",
342 default=False,
343 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
344 )
345 parser.add_argument(
346 "--lr_scheduler",
347 type=str,
348 default="constant",
349 help=(
350 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
351 ' "constant", "constant_with_warmup"]'
352 ),
353 )
354 parser.add_argument(
355 "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
356 )
357 parser.add_argument(
358 "--snr_gamma",
359 type=float,
360 default=None,
361 help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
362 "More details here: https://arxiv.org/abs/2303.09556.",
363 )
364 parser.add_argument(
365 "--dream_training",
366 action="store_true",
367 help=(
368 "Use the DREAM training method, which makes training more efficient and accurate at the "
369 "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
370 ),
371 )
372 parser.add_argument(
373 "--dream_detail_preservation",
374 type=float,
375 default=1.0,
376 help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
377 )
378 parser.add_argument(
379 "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
380 )
381 parser.add_argument(
382 "--allow_tf32",
383 action="store_true",
384 help=(
385 "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
386 " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
387 ),
388 )
389 parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
390 parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
391 parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
392 parser.add_argument(
393 "--non_ema_revision",
394 type=str,
395 default=None,
396 required=False,
397 help=(
398 "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
399 " remote repository specified with --pretrained_model_name_or_path."
400 ),
401 )
402 parser.add_argument(
403 "--dataloader_num_workers",
404 type=int,
405 default=0,
406 help=(
407 "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
408 ),
409 )
410 parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
411 parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
412 parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
413 parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
414 parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
415 parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
416 parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
417 parser.add_argument(
418 "--prediction_type",
419 type=str,
420 default=None,
421 help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
422 )
423 parser.add_argument(
424 "--hub_model_id",
425 type=str,
426 default=None,
427 help="The name of the repository to keep in sync with the local `output_dir`.",
428 )
429 parser.add_argument(
430 "--logging_dir",
431 type=str,
432 default="logs",
433 help=(
434 "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
435 " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
436 ),
437 )
438 parser.add_argument(
439 "--mixed_precision",
440 type=str,
441 default=None,
442 choices=["no", "fp16", "bf16"],
443 help=(
444 "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
445 " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
446 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
447 ),
448 )
449 parser.add_argument(
450 "--report_to",
451 type=str,
452 default="tensorboard",
453 help=(
454 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
455 ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
456 ),
457 )
458 parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
459 parser.add_argument(
460 "--checkpointing_steps",
461 type=int,
462 default=500,
463 help=(
464 "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
465 " training using `--resume_from_checkpoint`."
466 ),
467 )
468 parser.add_argument(
469 "--checkpoints_total_limit",
470 type=int,
471 default=None,
472 help=("Max number of checkpoints to store."),
473 )
474 parser.add_argument(
475 "--resume_from_checkpoint",
476 type=str,
477 default=None,
478 help=(
479 "Whether training should be resumed from a previous checkpoint. Use a path saved by"
480 ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
481 ),
482 )
483 parser.add_argument(
484 "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
485 )
486 parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
487 parser.add_argument(
488 "--validation_epochs",
489 type=int,
490 default=5,
491 help="Run validation every X epochs.",
492 )
493 parser.add_argument(
494 "--tracker_project_name",
495 type=str,
496 default="text2image-fine-tune",
497 help=(
498 "The `project_name` argument passed to Accelerator.init_trackers for"
499 " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500 ),
501 )
502 parser.add_argument(
503 "--image_interpolation_mode",
504 type=str,
505 default="lanczos",
506 choices=[
507 f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
508 ],
509 help="The image interpolation method to use for resizing images.",
510 )
511
512 args = parser.parse_args()
513 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
514 if env_local_rank != -1 and env_local_rank != args.local_rank:
515 args.local_rank = env_local_rank
516
517 # Sanity checks
518 if args.dataset_name is None and args.train_data_dir is None:
519 raise ValueError("Need either a dataset name or a training folder.")
520
521 # default to using the same revision for the non-ema model if not specified
522 if args.non_ema_revision is None:
523 args.non_ema_revision = args.revision
524
525 return args
526
527
528def main():
529 args = parse_args()
530
531 if args.report_to == "wandb" and args.hub_token is not None:
532 raise ValueError(
533 "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
534 " Please use `huggingface-cli login` to authenticate with the Hub."
535 )
536
537 if args.non_ema_revision is not None:
538 deprecate(
539 "non_ema_revision!=None",
540 "0.15.0",
541 message=(
542 "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
543 " use `--variant=non_ema` instead."
544 ),
545 )
546 logging_dir = os.path.join(args.output_dir, args.logging_dir)
547
548 accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
549
550 accelerator = Accelerator(
551 gradient_accumulation_steps=args.gradient_accumulation_steps,
552 mixed_precision=args.mixed_precision,
553 log_with=args.report_to,
554 project_config=accelerator_project_config,
555 )
556
557 # Disable AMP for MPS.
558 if torch.backends.mps.is_available():
559 accelerator.native_amp = False
560
561 # Make one log on every process with the configuration for debugging.
562 logging.basicConfig(
563 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
564 datefmt="%m/%d/%Y %H:%M:%S",
565 level=logging.INFO,
566 )
567 logger.info(accelerator.state, main_process_only=False)
568 if accelerator.is_local_main_process:
569 datasets.utils.logging.set_verbosity_warning()
570 transformers.utils.logging.set_verbosity_warning()
571 diffusers.utils.logging.set_verbosity_info()
572 else:
573 datasets.utils.logging.set_verbosity_error()
574 transformers.utils.logging.set_verbosity_error()
575 diffusers.utils.logging.set_verbosity_error()
576
577 # If passed along, set the training seed now.
578 if args.seed is not None:
579 set_seed(args.seed)
580
581 # Handle the repository creation
582 if accelerator.is_main_process:
583 if args.output_dir is not None:
584 os.makedirs(args.output_dir, exist_ok=True)
585
586 if args.push_to_hub:
587 repo_id = create_repo(
588 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
589 ).repo_id
590
591 # Load scheduler, tokenizer and models.
592 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
593 tokenizer = CLIPTokenizer.from_pretrained(
594 args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
595 )
596
597 def deepspeed_zero_init_disabled_context_manager():
598 """
599 Returns a list containing a context manager that disables zero.Init, or an empty list if no DeepSpeed plugin is active.
600 """
601 deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
602 if deepspeed_plugin is None:
603 return []
604
605 return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
606
607 # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
608 # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
609 # will try to assign the same optimizer with the same weights to all models during
610 # `deepspeed.initialize`, which of course doesn't work.
611 #
612 # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
613 # frozen models from being partitioned during `zero.Init` which gets called during
614 # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
615 # across multiple GPUs, and only UNet2DConditionModel will get ZeRO-sharded.
616 with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
617 text_encoder = CLIPTextModel.from_pretrained(
618 args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
619 )
620 vae = AutoencoderKL.from_pretrained(
621 args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
622 )
623
624 unet = UNet2DConditionModel.from_pretrained(
625 args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
626 )
627
628 # Freeze vae and text_encoder and set unet to trainable
629 vae.requires_grad_(False)
630 text_encoder.requires_grad_(False)
631 unet.train()
632
633 # Create EMA for the unet.
634 if args.use_ema:
635 ema_unet = UNet2DConditionModel.from_pretrained(
636 args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
637 )
638 ema_unet = EMAModel(
639 ema_unet.parameters(),
640 model_cls=UNet2DConditionModel,
641 model_config=ema_unet.config,
642 foreach=args.foreach_ema,
643 )
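# EMAModel keeps an exponential moving average of the UNet weights; it is
# updated after every optimizer step and later swapped in for validation and export.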
644
645 if args.enable_xformers_memory_efficient_attention:
646 if is_xformers_available():
647 import xformers
648
649 xformers_version = version.parse(xformers.__version__)
650 if xformers_version == version.parse("0.0.16"):
651 logger.warning(
652 "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
653 )
654 unet.enable_xformers_memory_efficient_attention()
655 else:
656 raise ValueError("xformers is not available. Make sure it is installed correctly")
657
658 # `accelerate` 0.16.0 will have better support for customized saving
659 if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
660 # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
661 def save_model_hook(models, weights, output_dir):
662 if accelerator.is_main_process:
663 if args.use_ema:
664 ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
665
666 for i, model in enumerate(models):
667 model.save_pretrained(os.path.join(output_dir, "unet"))
668
669 # make sure to pop weight so that corresponding model is not saved again
670 weights.pop()
671
672 def load_model_hook(models, input_dir):
673 if args.use_ema:
674 load_model = EMAModel.from_pretrained(
675 os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
676 )
677 ema_unet.load_state_dict(load_model.state_dict())
678 if args.offload_ema:
679 ema_unet.pin_memory()
680 else:
681 ema_unet.to(accelerator.device)
682 del load_model
683
684 for _ in range(len(models)):
685 # pop models so that they are not loaded again
686 model = models.pop()
687
688 # load diffusers style into model
689 load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
690 model.register_to_config(**load_model.config)
691
692 model.load_state_dict(load_model.state_dict())
693 del load_model
694
695 accelerator.register_save_state_pre_hook(save_model_hook)
696 accelerator.register_load_state_pre_hook(load_model_hook)
697
698 if args.gradient_checkpointing:
699 unet.enable_gradient_checkpointing()
700
701 # Enable TF32 for faster training on Ampere GPUs,
702 # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
703 if args.allow_tf32:
704 torch.backends.cuda.matmul.allow_tf32 = True
705
706 if args.scale_lr:
707 args.learning_rate = (
708 args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
709 )
710
711 # Initialize the optimizer
712 if args.use_8bit_adam:
713 try:
714 import bitsandbytes as bnb
715 except ImportError:
716 raise ImportError(
717 "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
718 )
719
720 optimizer_cls = bnb.optim.AdamW8bit
721 else:
722 optimizer_cls = torch.optim.AdamW
723
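# Only the UNet's parameters are passed to the optimizer; the frozen VAE and
# text encoder receive no gradient updates.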
724 optimizer = optimizer_cls(
725 unet.parameters(),
726 lr=args.learning_rate,
727 betas=(args.adam_beta1, args.adam_beta2),
728 weight_decay=args.adam_weight_decay,
729 eps=args.adam_epsilon,
730 )
731
732 # Get the datasets: you can either provide your own training and evaluation files (see below)
733 # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
734
735 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
736 # download the dataset.
737 if args.dataset_name is not None:
738 # Downloading and loading a dataset from the hub.
739 dataset = load_dataset(
740 args.dataset_name,
741 args.dataset_config_name,
742 cache_dir=args.cache_dir,
743 data_dir=args.train_data_dir,
744 )
745 else:
746 data_files = {}
747 if args.train_data_dir is not None:
748 data_files["train"] = os.path.join(args.train_data_dir, "**")
749 dataset = load_dataset(
750 "imagefolder",
751 data_files=data_files,
752 cache_dir=args.cache_dir,
753 )
754 # See more about loading custom images at
755 # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
756
757 # Preprocessing the datasets.
758 # We need to tokenize inputs and targets.
759 column_names = dataset["train"].column_names
760
761 # Get the column names for input/target.
762 dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
763 if args.image_column is None:
764 image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
765 else:
766 image_column = args.image_column
767 if image_column not in column_names:
768 raise ValueError(
769 f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
770 )
771 if args.caption_column is None:
772 caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
773 else:
774 caption_column = args.caption_column
775 if caption_column not in column_names:
776 raise ValueError(
777 f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
778 )
779
780 # Preprocessing the datasets.
781 # We need to tokenize input captions and transform the images.
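# Captions are padded/truncated to tokenizer.model_max_length (77 tokens for the CLIP tokenizer).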
782 def tokenize_captions(examples, is_train=True):
783 captions = []
784 for caption in examples[caption_column]:
785 if isinstance(caption, str):
786 captions.append(caption)
787 elif isinstance(caption, (list, np.ndarray)):
788 # take a random caption if there are multiple
789 captions.append(random.choice(caption) if is_train else caption[0])
790 else:
791 raise ValueError(
792 f"Caption column `{caption_column}` should contain either strings or lists of strings."
793 )
794 inputs = tokenizer(
795 captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
796 )
797 return inputs.input_ids
798
799 # Get the specified interpolation method from the args
800 interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
801
802 # Raise an error if the interpolation method is invalid
803 if interpolation is None:
804 raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
805
806 # Data preprocessing transformations
807 train_transforms = transforms.Compose(
808 [
809 transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
810 transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
811 transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
812 transforms.ToTensor(),
813 transforms.Normalize([0.5], [0.5]),
814 ]
815 )
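# Normalize([0.5], [0.5]) maps pixel values from [0, 1] to [-1, 1], the input
# range expected by the Stable Diffusion VAE.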
816
817 def preprocess_train(examples):
818 images = [image.convert("RGB") for image in examples[image_column]]
819 examples["pixel_values"] = [train_transforms(image) for image in images]
820 examples["input_ids"] = tokenize_captions(examples)
821 return examples
822
823 with accelerator.main_process_first():
824 if args.max_train_samples is not None:
825 dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
826 # Set the training transforms
827 train_dataset = dataset["train"].with_transform(preprocess_train)
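# with_transform applies preprocess_train lazily, when batches are accessed,
# instead of precomputing and caching the transformed dataset.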
828
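# Collate function: stacks per-example tensors into batch tensors. Pixel values
# stay in float32 here and are cast to the working dtype later, just before VAE encoding.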
829 def collate_fn(examples):
830 pixel_values = torch.stack([example["pixel_values"] for example in examples])
831 pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
832 input_ids = torch.stack([example["input_ids"] for example in examples])
833 return {"pixel_values": pixel_values, "input_ids": input_ids}
834
835 # DataLoaders creation:
836 train_dataloader = torch.utils.data.DataLoader(
837 train_dataset,
838 shuffle=True,
839 collate_fn=collate_fn,
840 batch_size=args.train_batch_size,
841 num_workers=args.dataloader_num_workers,
842 )
843
844 # Scheduler and math around the number of training steps.
845 # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
846 num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
847 if args.max_train_steps is None:
848 len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
849 num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
850 num_training_steps_for_scheduler = (
851 args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
852 )
853 else:
854 num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
855
856 lr_scheduler = get_scheduler(
857 args.lr_scheduler,
858 optimizer=optimizer,
859 num_warmup_steps=num_warmup_steps_for_scheduler,
860 num_training_steps=num_training_steps_for_scheduler,
861 )
862
863 # Prepare everything with our `accelerator`.
864 unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
865 unet, optimizer, train_dataloader, lr_scheduler
866 )
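# accelerator.prepare moves the UNet to the target device, sets up the
# distributed wrappers, and shards the dataloader so each process sees a
# different slice of the data.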
867
868 if args.use_ema:
869 if args.offload_ema:
870 ema_unet.pin_memory()
871 else:
872 ema_unet.to(accelerator.device)
873
874 # For mixed precision training we cast all non-trainable weights (vae and text_encoder) to half precision,
875 # as these weights are only used for inference, keeping weights in full precision is not required.
876 weight_dtype = torch.float32
877 if accelerator.mixed_precision == "fp16":
878 weight_dtype = torch.float16
879 args.mixed_precision = accelerator.mixed_precision
880 elif accelerator.mixed_precision == "bf16":
881 weight_dtype = torch.bfloat16
882 args.mixed_precision = accelerator.mixed_precision
883
884 # Move text_encoder and vae to the accelerator device and cast to weight_dtype
885 text_encoder.to(accelerator.device, dtype=weight_dtype)
886 vae.to(accelerator.device, dtype=weight_dtype)
887
888 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
889 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
890 if args.max_train_steps is None:
891 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
892 if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
893 logger.warning(
894 f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
895 f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
896 f"This inconsistency may result in the learning rate scheduler not functioning properly."
897 )
898 # Afterwards we recalculate our number of training epochs
899 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
900
901 # We need to initialize the trackers we use, and also store our configuration.
902 # The trackers are initialized automatically on the main process.
903 if accelerator.is_main_process:
904 tracker_config = dict(vars(args))
905 tracker_config.pop("validation_prompts")
906 accelerator.init_trackers(args.tracker_project_name, tracker_config)
907
908 # Function for unwrapping if model was compiled with `torch.compile`.
909 def unwrap_model(model):
910 model = accelerator.unwrap_model(model)
911 model = model._orig_mod if is_compiled_module(model) else model
912 return model
913
914 # Train!
915 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
916
917 logger.info("***** Running training *****")
918 logger.info(f" Num examples = {len(train_dataset)}")
919 logger.info(f" Num Epochs = {args.num_train_epochs}")
920 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
921 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
922 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
923 logger.info(f" Total optimization steps = {args.max_train_steps}")
924 global_step = 0
925 first_epoch = 0
926
927 # Potentially load in the weights and states from a previous save
928 if args.resume_from_checkpoint:
929 if args.resume_from_checkpoint != "latest":
930 path = os.path.basename(args.resume_from_checkpoint)
931 else:
932 # Get the most recent checkpoint
933 dirs = os.listdir(args.output_dir)
934 dirs = [d for d in dirs if d.startswith("checkpoint")]
935 dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
936 path = dirs[-1] if len(dirs) > 0 else None
937
938 if path is None:
939 accelerator.print(
940 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
941 )
942 args.resume_from_checkpoint = None
943 initial_global_step = 0
944 else:
945 accelerator.print(f"Resuming from checkpoint {path}")
946 accelerator.load_state(os.path.join(args.output_dir, path))
947 global_step = int(path.split("-")[1])
948
949 initial_global_step = global_step
950 first_epoch = global_step // num_update_steps_per_epoch
951
952 else:
953 initial_global_step = 0
954
955 progress_bar = tqdm(
956 range(0, args.max_train_steps),
957 initial=initial_global_step,
958 desc="Steps",
959 # Only show the progress bar once on each machine.
960 disable=not accelerator.is_local_main_process,
961 )
962
963 for epoch in range(first_epoch, args.num_train_epochs):
964 train_loss = 0.0
965 for step, batch in enumerate(train_dataloader):
966 with accelerator.accumulate(unet):
967 # Convert images to latent space
968 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
969 latents = latents * vae.config.scaling_factor
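# vae.config.scaling_factor (0.18215 for SD v1.x) rescales the latents to
# roughly unit variance, matching the distribution the UNet was trained on.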
970
971 # Sample noise that we'll add to the latents
972 noise = torch.randn_like(latents)
973 if args.noise_offset:
974 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
975 noise += args.noise_offset * torch.randn(
976 (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
977 )
978 if args.input_perturbation:
979 new_noise = noise + args.input_perturbation * torch.randn_like(noise)
980 bsz = latents.shape[0]
981 # Sample a random timestep for each image
982 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
983 timesteps = timesteps.long()
984
985 # Add noise to the latents according to the noise magnitude at each timestep
986 # (this is the forward diffusion process)
987 if args.input_perturbation:
988 noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
989 else:
990 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
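# add_noise implements the closed-form forward process q(x_t | x_0):
# noisy_latents = sqrt(alpha_bar_t) * latents + sqrt(1 - alpha_bar_t) * noise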
991
992 # Get the text embedding for conditioning
993 encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
994
995 # Get the target for loss depending on the prediction type
996 if args.prediction_type is not None:
997 # set prediction_type of scheduler if defined
998 noise_scheduler.register_to_config(prediction_type=args.prediction_type)
999
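# "epsilon" trains the UNet to predict the added noise directly, while
# "v_prediction" targets v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0
# (see https://arxiv.org/abs/2202.00512).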
1000 if noise_scheduler.config.prediction_type == "epsilon":
1001 target = noise
1002 elif noise_scheduler.config.prediction_type == "v_prediction":
1003 target = noise_scheduler.get_velocity(latents, noise, timesteps)
1004 else:
1005 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1006
1007 if args.dream_training:
1008 noisy_latents, target = compute_dream_and_update_latents(
1009 unet,
1010 noise_scheduler,
1011 timesteps,
1012 noise,
1013 noisy_latents,
1014 target,
1015 encoder_hidden_states,
1016 args.dream_detail_preservation,
1017 )
1018
1019 # Predict the noise residual and compute loss
1020 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
1021
1022 if args.snr_gamma is None:
1023 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1024 else:
1025 # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1026 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1027 # This is discussed in Section 4.2 of the same paper.
1028 snr = compute_snr(noise_scheduler, timesteps)
1029 mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1030 dim=1
1031 )[0]
1032 if noise_scheduler.config.prediction_type == "epsilon":
1033 mse_loss_weights = mse_loss_weights / snr
1034 elif noise_scheduler.config.prediction_type == "v_prediction":
1035 mse_loss_weights = mse_loss_weights / (snr + 1)
1036
1037 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1038 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1039 loss = loss.mean()
1040
1041 # Gather the losses across all processes for logging (if we use distributed training).
1042 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1043 train_loss += avg_loss.item() / args.gradient_accumulation_steps
1044
1045 # Backpropagate
1046 accelerator.backward(loss)
1047 if accelerator.sync_gradients:
1048 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1049 optimizer.step()
1050 lr_scheduler.step()
1051 optimizer.zero_grad()
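# Under accelerator.accumulate, gradients are only synchronized and the
# optimizer step only takes effect every gradient_accumulation_steps
# micro-batches; the intermediate step/zero_grad calls are no-ops.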
1052
1053 # Checks if the accelerator has performed an optimization step behind the scenes
1054 if accelerator.sync_gradients:
1055 if args.use_ema:
1056 if args.offload_ema:
1057 ema_unet.to(device="cuda", non_blocking=True)
1058 ema_unet.step(unet.parameters())
1059 if args.offload_ema:
1060 ema_unet.to(device="cpu", non_blocking=True)
1061 progress_bar.update(1)
1062 global_step += 1
1063 accelerator.log({"train_loss": train_loss}, step=global_step)
1064 train_loss = 0.0
1065
1066 if global_step % args.checkpointing_steps == 0:
1067 if accelerator.is_main_process:
1068 # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1069 if args.checkpoints_total_limit is not None:
1070 checkpoints = os.listdir(args.output_dir)
1071 checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1072 checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1073
1074 # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1075 if len(checkpoints) >= args.checkpoints_total_limit:
1076 num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1077 removing_checkpoints = checkpoints[0:num_to_remove]
1078
1079 logger.info(
1080 f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1081 )
1082 logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1083
1084 for removing_checkpoint in removing_checkpoints:
1085 removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1086 shutil.rmtree(removing_checkpoint)
1087
1088 save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1089 accelerator.save_state(save_path)
1090 logger.info(f"Saved state to {save_path}")
1091
1092 logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1093 progress_bar.set_postfix(**logs)
1094
1095 if global_step >= args.max_train_steps:
1096 break
1097
1098 if accelerator.is_main_process:
1099 if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
1100 if args.use_ema:
1101 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1102 ema_unet.store(unet.parameters())
1103 ema_unet.copy_to(unet.parameters())
1104 log_validation(
1105 vae,
1106 text_encoder,
1107 tokenizer,
1108 unet,
1109 args,
1110 accelerator,
1111 weight_dtype,
1112 global_step,
1113 )
1114 if args.use_ema:
1115 # Switch back to the original UNet parameters.
1116 ema_unet.restore(unet.parameters())
1117
1118 # Create the pipeline using the trained modules and save it.
1119 accelerator.wait_for_everyone()
1120 if accelerator.is_main_process:
1121 unet = unwrap_model(unet)
1122 if args.use_ema:
1123 ema_unet.copy_to(unet.parameters())
1124
1125 pipeline = StableDiffusionPipeline.from_pretrained(
1126 args.pretrained_model_name_or_path,
1127 text_encoder=text_encoder,
1128 vae=vae,
1129 unet=unet,
1130 revision=args.revision,
1131 variant=args.variant,
1132 )
1133 pipeline.save_pretrained(args.output_dir)
1134
1135 # Run a final round of inference.
1136 images = []
1137 if args.validation_prompts is not None:
1138 logger.info("Running inference for collecting generated images...")
1139 pipeline = pipeline.to(accelerator.device)
1140 pipeline.torch_dtype = weight_dtype
1141 pipeline.set_progress_bar_config(disable=True)
1142
1143 if args.enable_xformers_memory_efficient_attention:
1144 pipeline.enable_xformers_memory_efficient_attention()
1145
1146 if args.seed is None:
1147 generator = None
1148 else:
1149 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1150
1151 for i in range(len(args.validation_prompts)):
1152 with torch.autocast("cuda"):
1153 image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1154 images.append(image)
1155
1156 if args.push_to_hub:
1157 save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1158 upload_folder(
1159 repo_id=repo_id,
1160 folder_path=args.output_dir,
1161 commit_message="End of training",
1162 ignore_patterns=["step_*", "epoch_*"],
1163 )
1164
1165 accelerator.end_training()
1166
1167
1168if __name__ == "__main__":
1169 main()
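With the script saved as train_text_to_image.py, training can be launched through accelerate. The example below fine-tunes Stable Diffusion v1.4 on the lambdalabs/naruto-blip-captions dataset with fp16 mixed precision; the batch size, gradient accumulation steps, and --max_train_steps can be tuned to fit the available device memory: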
1export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2export DATASET_NAME="lambdalabs/naruto-blip-captions"
3
4accelerate launch --mixed_precision="fp16" train_text_to_image.py \
5  --pretrained_model_name_or_path=$MODEL_NAME \
6  --dataset_name=$DATASET_NAME \
7  --use_ema \
8  --resolution=512 --center_crop --random_flip \
9  --train_batch_size=1 \
10  --gradient_accumulation_steps=4 \
11  --gradient_checkpointing \
12  --max_train_steps=15000 \
13  --learning_rate=1e-05 \
14  --max_grad_norm=1 \
15  --lr_scheduler="constant" --lr_warmup_steps=0 \
16  --output_dir="sd-naruto-model"
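After training completes, the fine-tuned pipeline saved in the output directory can be loaded for inference on the NPU. The snippet below is a minimal sketch rather than part of the training script; the output directory name follows the launch command above, and the prompt is only an example:

1import torch
2# Importing torch_npu registers the "npu" device with PyTorch.
3import torch_npu
4from diffusers import StableDiffusionPipeline
5
6# Load the fine-tuned pipeline saved by train_text_to_image.py.
7pipe = StableDiffusionPipeline.from_pretrained("sd-naruto-model", torch_dtype=torch.float16)
8pipe = pipe.to("npu")
9
10# Generate an image from a text prompt and save it to disk.
11image = pipe(prompt="A man with green eyes and a red cloak").images[0]
12image.save("result.png")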