Quick Start
Note
Before running the examples below, you need a working PyTorch NPU environment. For environment setup, see the Installation Guide.
In general, to run training or inference on an NPU, your code needs the following changes:
Import the torch_npu extension package
import torch_npu
Move the model and its inputs onto the NPU
device = torch.device("npu")
model = model.to(device)
input = input.to(device)
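As a quick sanity check, the following minimal sketch (not taken from any official example; the layer and tensor sizes are arbitrary) verifies that an NPU is visible and runs a single forward pass on it:

import torch
import torch_npu

# Pick the first NPU if one is available, otherwise fall back to the CPU.
device = torch.device("npu:0" if torch.npu.is_available() else "cpu")

# An arbitrary small model and input, moved to the chosen device.
model = torch.nn.Linear(4, 2).to(device)
x = torch.randn(1, 4).to(device)

with torch.no_grad():
    y = model(x)
print(device, y.shape)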
The following examples show how to run training and inference tasks on an NPU:
1. Single-device training
The following code trains a model on the CIFAR-10 dataset on an NPU (adapted from the PyTorch tutorials). Pay particular attention to the highlighted, NPU-related lines.
1"""
2Training an image classifier
3----------------------------
4
5We will do the following steps in order:
6
71. Load and normalize the CIFAR10 training and test datasets using
8``torchvision``
92. Define a Convolutional Neural Network
103. Define a loss function
114. Train the network on the training data
125. Test the network on the test data
13
141. Load and normalize CIFAR10
15^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
16
17Using ``torchvision``, it’s extremely easy to load CIFAR10.
18"""
19import torch
20# Import the torch_npu package
21import torch_npu
22
23# Define the device
24device = torch.device('npu:0' if torch.npu.is_available() else 'cpu')
25print(device)
26
27import torchvision
28import torchvision.transforms as transforms
29
30########################################################################
31# The output of torchvision datasets are PILImage images of range [0, 1].
32# We transform them to Tensors of normalized range [-1, 1].
33transform = transforms.Compose(
34 [transforms.ToTensor(),
35 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
36
37batch_size = 4
38
39trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
40 download=True, transform=transform)
41trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
42 shuffle=True, num_workers=2)
43
44testset = torchvision.datasets.CIFAR10(root='./data', train=False,
45 download=True, transform=transform)
46testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
47 shuffle=False, num_workers=2)
48
49classes = ('plane', 'car', 'bird', 'cat',
50 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
51
52########################################################################
53# 2. Define a Convolutional Neural Network
54# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
55# Copy the neural network from the Neural Networks section before and modify it to
56# take 3-channel images (instead of 1-channel images as it was defined).
57import torch.nn as nn
58import torch.nn.functional as F
59
60
61class Net(nn.Module):
62 def __init__(self):
63 super().__init__()
64 self.conv1 = nn.Conv2d(3, 6, 5)
65 self.pool = nn.MaxPool2d(2, 2)
66 self.conv2 = nn.Conv2d(6, 16, 5)
67 self.fc1 = nn.Linear(16 * 5 * 5, 120)
68 self.fc2 = nn.Linear(120, 84)
69 self.fc3 = nn.Linear(84, 10)
70
71 def forward(self, x):
72 x = self.pool(F.relu(self.conv1(x)))
73 x = self.pool(F.relu(self.conv2(x)))
74 x = torch.flatten(x, 1) # flatten all dimensions except batch
75 x = F.relu(self.fc1(x))
76 x = F.relu(self.fc2(x))
77 x = self.fc3(x)
78 return x
79
80net = Net()
81
82# Move the model to the NPU
83net.to(device)
84
85########################################################################
86# 3. Define a Loss function and optimizer
87# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
88# Let's use a Classification Cross-Entropy loss and SGD with momentum.
89import torch.optim as optim
90
91criterion = nn.CrossEntropyLoss()
92optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
93
94########################################################################
95# 4. Train the network
96# ^^^^^^^^^^^^^^^^^^^^
97#
98# This is when things start to get interesting.
99# We simply have to loop over our data iterator, and feed the inputs to the
100# network and optimize.
101
102for epoch in range(2): # loop over the dataset multiple times
103
104 running_loss = 0.0
105 for i, data in enumerate(trainloader, 0):
106 # get the inputs; data is a list of [inputs, labels]
107# Send the input data to the NPU
108 inputs, labels = data[0].to(device), data[1].to(device)
109
110 # zero the parameter gradients
111 optimizer.zero_grad()
112
113 # forward + backward + optimize
114 outputs = net(inputs)
115 loss = criterion(outputs, labels)
116 loss.backward()
117 optimizer.step()
118
119 # print statistics
120 running_loss += loss.item()
121 if i % 2000 == 1999: # print every 2000 mini-batches
122 print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
123 running_loss = 0.0
124
125print('Finished Training')
126
127########################################################################
128# 5. Test the network on the test data
129# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
130#
131# We have trained the network for 2 passes over the training dataset.
132# But we need to check if the network has learnt anything at all.
133#
134# We will check this by predicting the class label that the neural network
135# outputs, and checking it against the ground-truth. If the prediction is
136# correct, we add the sample to the list of correct predictions.
137#
138# Let us look at how the network performs on the whole dataset.
139correct = 0
140total = 0
141# since we're not training, we don't need to calculate the gradients for our outputs
142with torch.no_grad():
143 for data in testloader:
144# Send the input data to the NPU
145 images, labels = data[0].to(device), data[1].to(device)
146 # calculate outputs by running images through the network
147 outputs = net(images)
148 # the class with the highest energy is what we choose as prediction
149 _, predicted = torch.max(outputs.data, 1)
150 total += labels.size(0)
151 correct += (predicted == labels).sum().item()
152
153print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
154########################################################################
155# That looks way better than chance, which is 10% accuracy (randomly picking
156# a class out of 10 classes).
157# Seems like the network learnt something.
158#
159# Hmmm, what are the classes that performed well, and the classes that did
160# not perform well:
161
162# prepare to count predictions for each class
163correct_pred = {classname: 0 for classname in classes}
164total_pred = {classname: 0 for classname in classes}
165
166# again no gradients needed
167with torch.no_grad():
168 for data in testloader:
169# Send the input data to the NPU
170 images, labels = data[0].to(device), data[1].to(device)
171 outputs = net(images)
172 _, predictions = torch.max(outputs, 1)
173 # collect the correct predictions for each class
174 for label, prediction in zip(labels, predictions):
175 if label == prediction:
176 correct_pred[classes[label]] += 1
177 total_pred[classes[label]] += 1
178
179
180# print accuracy for each class
181for classname, correct_count in correct_pred.items():
182 accuracy = 100 * float(correct_count) / total_pred[classname]
183 print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
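To reuse the trained network for inference, you can save and reload its weights in the same way the original tutorial does. The sketch below assumes the checkpoint path ./cifar_net.pth and reuses net, device, testloader, batch_size and classes defined above:

# Save the trained weights (the path is only an example).
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

# Reload them into a fresh model and run one batch of test images on the NPU.
net = Net()
net.load_state_dict(torch.load(PATH, map_location='cpu'))
net.to(device)
net.eval()

images, labels = next(iter(testloader))
with torch.no_grad():
    outputs = net(images.to(device))
    _, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}' for j in range(batch_size)))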
2. Multi-device parallel training with DeepSpeed
The following code trains a model on the CIFAR-10 dataset across multiple NPUs with DeepSpeed (taken from DeepSpeed Examples). Since DeepSpeed v0.12.6, the code automatically detects NPUs and trains on them without any modification; a typical launch command is shown after the script.
1import argparse
2import os
3
4import deepspeed
5import torch
6import torch.nn as nn
7import torch.nn.functional as F
8import torchvision
9import torchvision.transforms as transforms
10from deepspeed.accelerator import get_accelerator
11from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
12
13
14def add_argument():
15 parser = argparse.ArgumentParser(description="CIFAR")
16
17 # For train.
18 parser.add_argument(
19 "-e",
20 "--epochs",
21 default=30,
22 type=int,
23 help="number of total epochs (default: 30)",
24 )
25 parser.add_argument(
26 "--local_rank",
27 type=int,
28 default=-1,
29 help="local rank passed from distributed launcher",
30 )
31 parser.add_argument(
32 "--log-interval",
33 type=int,
34 default=2000,
35 help="output logging information at a given interval",
36 )
37
38 # For mixed precision training.
39 parser.add_argument(
40 "--dtype",
41 default="fp16",
42 type=str,
43 choices=["bf16", "fp16", "fp32"],
44 help="Datatype used for training",
45 )
46
47 # For ZeRO Optimization.
48 parser.add_argument(
49 "--stage",
50 default=0,
51 type=int,
52 choices=[0, 1, 2, 3],
53 help="Datatype used for training",
54 )
55
56 # For MoE (Mixture of Experts).
57 parser.add_argument(
58 "--moe",
59 default=False,
60 action="store_true",
61 help="use deepspeed mixture of experts (moe)",
62 )
63 parser.add_argument(
64 "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
65 )
66 parser.add_argument(
67 "--num-experts",
68 type=int,
69 nargs="+",
70 default=[
71 1,
72 ],
73 help="number of experts list, MoE related.",
74 )
75 parser.add_argument(
76 "--mlp-type",
77 type=str,
78 default="standard",
79 help="Only applicable when num-experts > 1, accepts [standard, residual]",
80 )
81 parser.add_argument(
82 "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
83 )
84 parser.add_argument(
85 "--min-capacity",
86 default=0,
87 type=int,
88 help="(moe) minimum capacity of an expert regardless of the capacity_factor",
89 )
90 parser.add_argument(
91 "--noisy-gate-policy",
92 default=None,
93 type=str,
94 help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
95 )
96 parser.add_argument(
97 "--moe-param-group",
98 default=False,
99 action="store_true",
100 help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
101 )
102
103 # Include DeepSpeed configuration arguments.
104 parser = deepspeed.add_config_arguments(parser)
105
106 args = parser.parse_args()
107
108 return args
109
110
111def create_moe_param_groups(model):
112 """Create separate parameter groups for each expert."""
113 parameters = {"params": [p for p in model.parameters()], "name": "parameters"}
114 return split_params_into_different_moe_groups_for_optimizer(parameters)
115
116
117def get_ds_config(args):
118 """Get the DeepSpeed configuration dictionary."""
119 ds_config = {
120 "train_batch_size": 16,
121 "steps_per_print": 2000,
122 "optimizer": {
123 "type": "Adam",
124 "params": {
125 "lr": 0.001,
126 "betas": [0.8, 0.999],
127 "eps": 1e-8,
128 "weight_decay": 3e-7,
129 },
130 },
131 "scheduler": {
132 "type": "WarmupLR",
133 "params": {
134 "warmup_min_lr": 0,
135 "warmup_max_lr": 0.001,
136 "warmup_num_steps": 1000,
137 },
138 },
139 "gradient_clipping": 1.0,
140 "prescale_gradients": False,
141 "bf16": {"enabled": args.dtype == "bf16"},
142 "fp16": {
143 "enabled": args.dtype == "fp16",
144 "fp16_master_weights_and_grads": False,
145 "loss_scale": 0,
146 "loss_scale_window": 500,
147 "hysteresis": 2,
148 "min_loss_scale": 1,
149 "initial_scale_power": 15,
150 },
151 "wall_clock_breakdown": False,
152 "zero_optimization": {
153 "stage": args.stage,
154 "allgather_partitions": True,
155 "reduce_scatter": True,
156 "allgather_bucket_size": 50000000,
157 "reduce_bucket_size": 50000000,
158 "overlap_comm": True,
159 "contiguous_gradients": True,
160 "cpu_offload": False,
161 },
162 }
163 return ds_config
164
165
166class Net(nn.Module):
167 def __init__(self, args):
168 super(Net, self).__init__()
169 self.conv1 = nn.Conv2d(3, 6, 5)
170 self.pool = nn.MaxPool2d(2, 2)
171 self.conv2 = nn.Conv2d(6, 16, 5)
172 self.fc1 = nn.Linear(16 * 5 * 5, 120)
173 self.fc2 = nn.Linear(120, 84)
174 self.moe = args.moe
175 if self.moe:
176 fc3 = nn.Linear(84, 84)
177 self.moe_layer_list = []
178 for n_e in args.num_experts:
179 # Create moe layers based on the number of experts.
180 self.moe_layer_list.append(
181 deepspeed.moe.layer.MoE(
182 hidden_size=84,
183 expert=fc3,
184 num_experts=n_e,
185 ep_size=args.ep_world_size,
186 use_residual=args.mlp_type == "residual",
187 k=args.top_k,
188 min_capacity=args.min_capacity,
189 noisy_gate_policy=args.noisy_gate_policy,
190 )
191 )
192 self.moe_layer_list = nn.ModuleList(self.moe_layer_list)
193 self.fc4 = nn.Linear(84, 10)
194 else:
195 self.fc3 = nn.Linear(84, 10)
196
197 def forward(self, x):
198 x = self.pool(F.relu(self.conv1(x)))
199 x = self.pool(F.relu(self.conv2(x)))
200 x = x.view(-1, 16 * 5 * 5)
201 x = F.relu(self.fc1(x))
202 x = F.relu(self.fc2(x))
203 if self.moe:
204 for layer in self.moe_layer_list:
205 x, _, _ = layer(x)
206 x = self.fc4(x)
207 else:
208 x = self.fc3(x)
209 return x
210
211
212def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
213 """Test the network on the test data.
214
215 Args:
216 model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
217 testset (torch.utils.data.Dataset): the test dataset.
218 local_device (str): the local device name.
219 target_dtype (torch.dtype): the target datatype for the test data.
220 test_batch_size (int): the test batch size.
221
222 """
223 # The 10 classes for CIFAR10.
224 classes = (
225 "plane",
226 "car",
227 "bird",
228 "cat",
229 "deer",
230 "dog",
231 "frog",
232 "horse",
233 "ship",
234 "truck",
235 )
236
237 # Define the test dataloader.
238 testloader = torch.utils.data.DataLoader(
239 testset, batch_size=test_batch_size, shuffle=False, num_workers=0
240 )
241
242 # For total accuracy.
243 correct, total = 0, 0
244 # For accuracy per class.
245 class_correct = list(0.0 for i in range(10))
246 class_total = list(0.0 for i in range(10))
247
248 # Start testing.
249 model_engine.eval()
250 with torch.no_grad():
251 for data in testloader:
252 images, labels = data
253 if target_dtype != None:
254 images = images.to(target_dtype)
255 outputs = model_engine(images.to(local_device))
256 _, predicted = torch.max(outputs.data, 1)
257 # Count the total accuracy.
258 total += labels.size(0)
259 correct += (predicted == labels.to(local_device)).sum().item()
260
261 # Count the accuracy per class.
262 batch_correct = (predicted == labels.to(local_device)).squeeze()
263 for i in range(test_batch_size):
264 label = labels[i]
265 class_correct[label] += batch_correct[i].item()
266 class_total[label] += 1
267
268 if model_engine.local_rank == 0:
269 print(
270 f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
271 )
272
273 # For all classes, print the accuracy.
274 for i in range(10):
275 print(
276 f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
277 )
278
279
280def main(args):
281 # Initialize DeepSpeed distributed backend.
282 deepspeed.init_distributed()
283 _local_rank = int(os.environ.get("LOCAL_RANK"))
284 get_accelerator().set_device(_local_rank)
285
286 ########################################################################
287 # Step1. Data Preparation.
288 #
289 # The output of torchvision datasets are PILImage images of range [0, 1].
290 # We transform them to Tensors of normalized range [-1, 1].
291 #
292 # Note:
293 # If running on Windows and you get a BrokenPipeError, try setting
294 # the num_worker of torch.utils.data.DataLoader() to 0.
295 ########################################################################
296 transform = transforms.Compose(
297 [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
298 )
299
300 if torch.distributed.get_rank() != 0:
301 # Might be downloading cifar data, let rank 0 download first.
302 torch.distributed.barrier()
303
304 # Load or download cifar data.
305 trainset = torchvision.datasets.CIFAR10(
306 root="./data", train=True, download=True, transform=transform
307 )
308 testset = torchvision.datasets.CIFAR10(
309 root="./data", train=False, download=True, transform=transform
310 )
311
312 if torch.distributed.get_rank() == 0:
313 # Cifar data is downloaded, indicate other ranks can proceed.
314 torch.distributed.barrier()
315
316 ########################################################################
317 # Step 2. Define the network with DeepSpeed.
318 #
319 # First, we define a Convolution Neural Network.
320 # Then, we define the DeepSpeed configuration dictionary and use it to
321 # initialize the DeepSpeed engine.
322 ########################################################################
323 net = Net(args)
324
325 # Get list of parameters that require gradients.
326 parameters = filter(lambda p: p.requires_grad, net.parameters())
327
328 # If using MoE, create separate param groups for each expert.
329 if args.moe_param_group:
330 parameters = create_moe_param_groups(net)
331
332 # Initialize DeepSpeed to use the following features.
333 # 1) Distributed model.
334 # 2) Distributed data loader.
335 # 3) DeepSpeed optimizer.
336 ds_config = get_ds_config(args)
337 model_engine, optimizer, trainloader, __ = deepspeed.initialize(
338 args=args,
339 model=net,
340 model_parameters=parameters,
341 training_data=trainset,
342 config=ds_config,
343 )
344
345 # Get the local device name (str) and local rank (int).
346 local_device = get_accelerator().device_name(model_engine.local_rank)
347 local_rank = model_engine.local_rank
348
349 # For float32, target_dtype will be None so no datatype conversion needed.
350 target_dtype = None
351 if model_engine.bfloat16_enabled():
352 target_dtype = torch.bfloat16
353 elif model_engine.fp16_enabled():
354 target_dtype = torch.half
355
356 # Define the Classification Cross-Entropy loss function.
357 criterion = nn.CrossEntropyLoss()
358
359 ########################################################################
360 # Step 3. Train the network.
361 #
362 # This is when things start to get interesting.
363 # We simply have to loop over our data iterator, and feed the inputs to the
364 # network and optimize. (DeepSpeed handles the distributed details for us!)
365 ########################################################################
366
367 for epoch in range(args.epochs): # loop over the dataset multiple times
368 running_loss = 0.0
369 for i, data in enumerate(trainloader):
370 # Get the inputs. ``data`` is a list of [inputs, labels].
371 inputs, labels = data[0].to(local_device), data[1].to(local_device)
372
373 # Try to convert to target_dtype if needed.
374 if target_dtype != None:
375 inputs = inputs.to(target_dtype)
376
377 outputs = model_engine(inputs)
378 loss = criterion(outputs, labels)
379
380 model_engine.backward(loss)
381 model_engine.step()
382
383 # Print statistics
384 running_loss += loss.item()
385 if local_rank == 0 and i % args.log_interval == (
386 args.log_interval - 1
387 ): # Print every log_interval mini-batches.
388 print(
389 f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
390 )
391 running_loss = 0.0
392 print("Finished Training")
393
394 ########################################################################
395 # Step 4. Test the network on the test data.
396 ########################################################################
397 test(model_engine, testset, local_device, target_dtype)
398
399
400if __name__ == "__main__":
401 args = add_argument()
402 main(args)
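Assuming the script above is saved as cifar10_deepspeed.py, it can be launched with the standard DeepSpeed launcher. The command below is only a sketch; the device count and training options are illustrative:

# Launch on 8 devices of the local node (adjust --num_gpus to your machine);
# --stage and --dtype map to the script's own arguments defined above.
deepspeed --num_gpus=8 cifar10_deepspeed.py --deepspeed --stage 2 --dtype fp16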
3. Fine-tuning a model with Transformers
The following code fine-tunes an LLM with Transformers (taken from the transformers examples). Starting from transformers version xxx and accelerate 0.21.0, the code automatically detects NPUs and runs without any modification.
1#!/usr/bin/env python
2# coding=utf-8
3# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""
17Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
18
19Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
20https://huggingface.co/models?filter=text-generation
21"""
22# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
23
24import logging
25import math
26import os
27import sys
28from dataclasses import dataclass, field
29from itertools import chain
30from typing import Optional
31
32import datasets
33import evaluate
34import torch
35from datasets import load_dataset
36
37import transformers
38from transformers import (
39 CONFIG_MAPPING,
40 MODEL_FOR_CAUSAL_LM_MAPPING,
41 AutoConfig,
42 AutoModelForCausalLM,
43 AutoTokenizer,
44 HfArgumentParser,
45 Trainer,
46 TrainingArguments,
47 default_data_collator,
48 is_torch_xla_available,
49 set_seed,
50)
51from transformers.testing_utils import CaptureLogger
52from transformers.trainer_utils import get_last_checkpoint
53from transformers.utils import check_min_version, send_example_telemetry
54from transformers.utils.versions import require_version
55
56
57# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
58check_min_version("4.47.0.dev0")
59
60require_version("datasets>=2.14.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
61
62logger = logging.getLogger(__name__)
63
64
65MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
66MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
67
68
69@dataclass
70class ModelArguments:
71 """
72 Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
73 """
74
75 model_name_or_path: Optional[str] = field(
76 default=None,
77 metadata={
78 "help": (
79 "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
80 )
81 },
82 )
83 model_type: Optional[str] = field(
84 default=None,
85 metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
86 )
87 config_overrides: Optional[str] = field(
88 default=None,
89 metadata={
90 "help": (
91 "Override some existing default config settings when a model is trained from scratch. Example: "
92 "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
93 )
94 },
95 )
96 config_name: Optional[str] = field(
97 default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
98 )
99 tokenizer_name: Optional[str] = field(
100 default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
101 )
102 cache_dir: Optional[str] = field(
103 default=None,
104 metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
105 )
106 use_fast_tokenizer: bool = field(
107 default=True,
108 metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
109 )
110 model_revision: str = field(
111 default="main",
112 metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
113 )
114 token: str = field(
115 default=None,
116 metadata={
117 "help": (
118 "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
119 "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
120 )
121 },
122 )
123 trust_remote_code: bool = field(
124 default=False,
125 metadata={
126 "help": (
127 "Whether to trust the execution of code from datasets/models defined on the Hub."
128 " This option should only be set to `True` for repositories you trust and in which you have read the"
129 " code, as it will execute code present on the Hub on your local machine."
130 )
131 },
132 )
133 torch_dtype: Optional[str] = field(
134 default=None,
135 metadata={
136 "help": (
137 "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
138 "dtype will be automatically derived from the model's weights."
139 ),
140 "choices": ["auto", "bfloat16", "float16", "float32"],
141 },
142 )
143 low_cpu_mem_usage: bool = field(
144 default=False,
145 metadata={
146 "help": (
147 "It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded. "
148 "set True will benefit LLM loading time and RAM consumption."
149 )
150 },
151 )
152
153 def __post_init__(self):
154 if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
155 raise ValueError(
156 "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
157 )
158
159
160@dataclass
161class DataTrainingArguments:
162 """
163 Arguments pertaining to what data we are going to input our model for training and eval.
164 """
165
166 dataset_name: Optional[str] = field(
167 default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
168 )
169 dataset_config_name: Optional[str] = field(
170 default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
171 )
172 train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
173 validation_file: Optional[str] = field(
174 default=None,
175 metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
176 )
177 max_train_samples: Optional[int] = field(
178 default=None,
179 metadata={
180 "help": (
181 "For debugging purposes or quicker training, truncate the number of training examples to this "
182 "value if set."
183 )
184 },
185 )
186 max_eval_samples: Optional[int] = field(
187 default=None,
188 metadata={
189 "help": (
190 "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
191 "value if set."
192 )
193 },
194 )
195 streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
196 block_size: Optional[int] = field(
197 default=None,
198 metadata={
199 "help": (
200 "Optional input sequence length after tokenization. "
201 "The training dataset will be truncated in block of this size for training. "
202 "Default to the model max input length for single sentence inputs (take into account special tokens)."
203 )
204 },
205 )
206 overwrite_cache: bool = field(
207 default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
208 )
209 validation_split_percentage: Optional[int] = field(
210 default=5,
211 metadata={
212 "help": "The percentage of the train set used as validation set in case there's no validation split"
213 },
214 )
215 preprocessing_num_workers: Optional[int] = field(
216 default=None,
217 metadata={"help": "The number of processes to use for the preprocessing."},
218 )
219 keep_linebreaks: bool = field(
220 default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
221 )
222
223 def __post_init__(self):
224 if self.streaming:
225 require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
226
227 if self.dataset_name is None and self.train_file is None and self.validation_file is None:
228 raise ValueError("Need either a dataset name or a training/validation file.")
229 else:
230 if self.train_file is not None:
231 extension = self.train_file.split(".")[-1]
232 assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
233 if self.validation_file is not None:
234 extension = self.validation_file.split(".")[-1]
235 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
236
237
238def main():
239 # See all possible arguments in src/transformers/training_args.py
240 # or by passing the --help flag to this script.
241 # We now keep distinct sets of args, for a cleaner separation of concerns.
242
243 parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
244 if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
245 # If we pass only one argument to the script and it's the path to a json file,
246 # let's parse it to get our arguments.
247 model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
248 else:
249 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
250
251 # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
252 # information sent is the one passed as arguments along with your Python/PyTorch versions.
253 send_example_telemetry("run_clm", model_args, data_args)
254
255 # Setup logging
256 logging.basicConfig(
257 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
258 datefmt="%m/%d/%Y %H:%M:%S",
259 handlers=[logging.StreamHandler(sys.stdout)],
260 )
261
262 if training_args.should_log:
263 # The default of training_args.log_level is passive, so we set log level at info here to have that default.
264 transformers.utils.logging.set_verbosity_info()
265
266 log_level = training_args.get_process_log_level()
267 logger.setLevel(log_level)
268 datasets.utils.logging.set_verbosity(log_level)
269 transformers.utils.logging.set_verbosity(log_level)
270 transformers.utils.logging.enable_default_handler()
271 transformers.utils.logging.enable_explicit_format()
272
273 # Log on each process the small summary:
274 logger.warning(
275 f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
276 + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
277 )
278 logger.info(f"Training/evaluation parameters {training_args}")
279
280 # Detecting last checkpoint.
281 last_checkpoint = None
282 if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
283 last_checkpoint = get_last_checkpoint(training_args.output_dir)
284 if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
285 raise ValueError(
286 f"Output directory ({training_args.output_dir}) already exists and is not empty. "
287 "Use --overwrite_output_dir to overcome."
288 )
289 elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
290 logger.info(
291 f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
292 "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
293 )
294
295 # Set seed before initializing model.
296 set_seed(training_args.seed)
297
298 # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
299 # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
300 # (the dataset will be downloaded automatically from the datasets Hub).
301 #
302 # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
303 # 'text' is found. You can easily tweak this behavior (see below).
304 #
305 # In distributed training, the load_dataset function guarantee that only one local process can concurrently
306 # download the dataset.
307 if data_args.dataset_name is not None:
308 # Downloading and loading a dataset from the hub.
309 raw_datasets = load_dataset(
310 data_args.dataset_name,
311 data_args.dataset_config_name,
312 cache_dir=model_args.cache_dir,
313 token=model_args.token,
314 streaming=data_args.streaming,
315 trust_remote_code=model_args.trust_remote_code,
316 )
317 if "validation" not in raw_datasets.keys():
318 raw_datasets["validation"] = load_dataset(
319 data_args.dataset_name,
320 data_args.dataset_config_name,
321 split=f"train[:{data_args.validation_split_percentage}%]",
322 cache_dir=model_args.cache_dir,
323 token=model_args.token,
324 streaming=data_args.streaming,
325 trust_remote_code=model_args.trust_remote_code,
326 )
327 raw_datasets["train"] = load_dataset(
328 data_args.dataset_name,
329 data_args.dataset_config_name,
330 split=f"train[{data_args.validation_split_percentage}%:]",
331 cache_dir=model_args.cache_dir,
332 token=model_args.token,
333 streaming=data_args.streaming,
334 trust_remote_code=model_args.trust_remote_code,
335 )
336 else:
337 data_files = {}
338 dataset_args = {}
339 if data_args.train_file is not None:
340 data_files["train"] = data_args.train_file
341 if data_args.validation_file is not None:
342 data_files["validation"] = data_args.validation_file
343 extension = (
344 data_args.train_file.split(".")[-1]
345 if data_args.train_file is not None
346 else data_args.validation_file.split(".")[-1]
347 )
348 if extension == "txt":
349 extension = "text"
350 dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
351 raw_datasets = load_dataset(
352 extension,
353 data_files=data_files,
354 cache_dir=model_args.cache_dir,
355 token=model_args.token,
356 **dataset_args,
357 )
358 # If no validation data is there, validation_split_percentage will be used to divide the dataset.
359 if "validation" not in raw_datasets.keys():
360 raw_datasets["validation"] = load_dataset(
361 extension,
362 data_files=data_files,
363 split=f"train[:{data_args.validation_split_percentage}%]",
364 cache_dir=model_args.cache_dir,
365 token=model_args.token,
366 **dataset_args,
367 )
368 raw_datasets["train"] = load_dataset(
369 extension,
370 data_files=data_files,
371 split=f"train[{data_args.validation_split_percentage}%:]",
372 cache_dir=model_args.cache_dir,
373 token=model_args.token,
374 **dataset_args,
375 )
376
377 # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
378 # https://huggingface.co/docs/datasets/loading_datasets.
379
380 # Load pretrained model and tokenizer
381 #
382 # Distributed training:
383 # The .from_pretrained methods guarantee that only one local process can concurrently
384 # download model & vocab.
385
386 config_kwargs = {
387 "cache_dir": model_args.cache_dir,
388 "revision": model_args.model_revision,
389 "token": model_args.token,
390 "trust_remote_code": model_args.trust_remote_code,
391 }
392 if model_args.config_name:
393 config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
394 elif model_args.model_name_or_path:
395 config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
396 else:
397 config = CONFIG_MAPPING[model_args.model_type]()
398 logger.warning("You are instantiating a new config instance from scratch.")
399 if model_args.config_overrides is not None:
400 logger.info(f"Overriding config: {model_args.config_overrides}")
401 config.update_from_string(model_args.config_overrides)
402 logger.info(f"New config: {config}")
403
404 tokenizer_kwargs = {
405 "cache_dir": model_args.cache_dir,
406 "use_fast": model_args.use_fast_tokenizer,
407 "revision": model_args.model_revision,
408 "token": model_args.token,
409 "trust_remote_code": model_args.trust_remote_code,
410 }
411 if model_args.tokenizer_name:
412 tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
413 elif model_args.model_name_or_path:
414 tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
415 else:
416 raise ValueError(
417 "You are instantiating a new tokenizer from scratch. This is not supported by this script. "
418 "You can do it from another script, save it, and load it from here, using --tokenizer_name."
419 )
420
421 if model_args.model_name_or_path:
422 torch_dtype = (
423 model_args.torch_dtype
424 if model_args.torch_dtype in ["auto", None]
425 else getattr(torch, model_args.torch_dtype)
426 )
427 model = AutoModelForCausalLM.from_pretrained(
428 model_args.model_name_or_path,
429 from_tf=bool(".ckpt" in model_args.model_name_or_path),
430 config=config,
431 cache_dir=model_args.cache_dir,
432 revision=model_args.model_revision,
433 token=model_args.token,
434 trust_remote_code=model_args.trust_remote_code,
435 torch_dtype=torch_dtype,
436 low_cpu_mem_usage=model_args.low_cpu_mem_usage,
437 )
438 else:
439 model = AutoModelForCausalLM.from_config(config, trust_remote_code=model_args.trust_remote_code)
440 n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
441 logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
442
443 # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
444 # on a small vocab and want a smaller embedding size, remove this test.
445 embedding_size = model.get_input_embeddings().weight.shape[0]
446 if len(tokenizer) > embedding_size:
447 model.resize_token_embeddings(len(tokenizer))
448
449 # Preprocessing the datasets.
450 # First we tokenize all the texts.
451 if training_args.do_train:
452 column_names = list(raw_datasets["train"].features)
453 else:
454 column_names = list(raw_datasets["validation"].features)
455 text_column_name = "text" if "text" in column_names else column_names[0]
456
457 # since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
458 tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
459
460 def tokenize_function(examples):
461 with CaptureLogger(tok_logger) as cl:
462 output = tokenizer(examples[text_column_name])
463 # clm input could be much much longer than block_size
464 if "Token indices sequence length is longer than the" in cl.out:
465 tok_logger.warning(
466 "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
467 " before being passed to the model."
468 )
469 return output
470
471 with training_args.main_process_first(desc="dataset map tokenization"):
472 if not data_args.streaming:
473 tokenized_datasets = raw_datasets.map(
474 tokenize_function,
475 batched=True,
476 num_proc=data_args.preprocessing_num_workers,
477 remove_columns=column_names,
478 load_from_cache_file=not data_args.overwrite_cache,
479 desc="Running tokenizer on dataset",
480 )
481 else:
482 tokenized_datasets = raw_datasets.map(
483 tokenize_function,
484 batched=True,
485 remove_columns=column_names,
486 )
487 if hasattr(config, "max_position_embeddings"):
488 max_pos_embeddings = config.max_position_embeddings
489 else:
490 # Define a default value if the attribute is missing in the config.
491 max_pos_embeddings = 1024
492
493 if data_args.block_size is None:
494 block_size = tokenizer.model_max_length
495 if block_size > max_pos_embeddings:
496 logger.warning(
497 f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
498 f"Using block_size={min(1024, max_pos_embeddings)} instead. You can change that default value by passing --block_size xxx."
499 )
500 if max_pos_embeddings > 0:
501 block_size = min(1024, max_pos_embeddings)
502 else:
503 block_size = 1024
504 else:
505 if data_args.block_size > tokenizer.model_max_length:
506 logger.warning(
507 f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model "
508 f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
509 )
510 block_size = min(data_args.block_size, tokenizer.model_max_length)
511
512 # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
513 def group_texts(examples):
514 # Concatenate all texts.
515 concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
516 total_length = len(concatenated_examples[list(examples.keys())[0]])
517 # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
518 # We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
519 total_length = (total_length // block_size) * block_size
520 # Split by chunks of max_len.
521 result = {
522 k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
523 for k, t in concatenated_examples.items()
524 }
525 result["labels"] = result["input_ids"].copy()
526 return result
527
528 # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
529 # for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
530 # to preprocess.
531 #
532 # To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
533 # https://huggingface.co/docs/datasets/process#map
534
535 with training_args.main_process_first(desc="grouping texts together"):
536 if not data_args.streaming:
537 lm_datasets = tokenized_datasets.map(
538 group_texts,
539 batched=True,
540 num_proc=data_args.preprocessing_num_workers,
541 load_from_cache_file=not data_args.overwrite_cache,
542 desc=f"Grouping texts in chunks of {block_size}",
543 )
544 else:
545 lm_datasets = tokenized_datasets.map(
546 group_texts,
547 batched=True,
548 )
549
550 if training_args.do_train:
551 if "train" not in tokenized_datasets:
552 raise ValueError("--do_train requires a train dataset")
553 train_dataset = lm_datasets["train"]
554 if data_args.max_train_samples is not None:
555 max_train_samples = min(len(train_dataset), data_args.max_train_samples)
556 train_dataset = train_dataset.select(range(max_train_samples))
557
558 if training_args.do_eval:
559 if "validation" not in tokenized_datasets:
560 raise ValueError("--do_eval requires a validation dataset")
561 eval_dataset = lm_datasets["validation"]
562 if data_args.max_eval_samples is not None:
563 max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
564 eval_dataset = eval_dataset.select(range(max_eval_samples))
565
566 def preprocess_logits_for_metrics(logits, labels):
567 if isinstance(logits, tuple):
568 # Depending on the model and config, logits may contain extra tensors,
569 # like past_key_values, but logits always come first
570 logits = logits[0]
571 return logits.argmax(dim=-1)
572
573 metric = evaluate.load("accuracy", cache_dir=model_args.cache_dir)
574
575 def compute_metrics(eval_preds):
576 preds, labels = eval_preds
577 # preds have the same shape as the labels, after the argmax(-1) has been calculated
578 # by preprocess_logits_for_metrics but we need to shift the labels
579 labels = labels[:, 1:].reshape(-1)
580 preds = preds[:, :-1].reshape(-1)
581 return metric.compute(predictions=preds, references=labels)
582
583 # Initialize our Trainer
584 trainer = Trainer(
585 model=model,
586 args=training_args,
587 train_dataset=train_dataset if training_args.do_train else None,
588 eval_dataset=eval_dataset if training_args.do_eval else None,
589 processing_class=tokenizer,
590 # Data collator will default to DataCollatorWithPadding, so we change it.
591 data_collator=default_data_collator,
592 compute_metrics=compute_metrics if training_args.do_eval and not is_torch_xla_available() else None,
593 preprocess_logits_for_metrics=preprocess_logits_for_metrics
594 if training_args.do_eval and not is_torch_xla_available()
595 else None,
596 )
597
598 # Training
599 if training_args.do_train:
600 checkpoint = None
601 if training_args.resume_from_checkpoint is not None:
602 checkpoint = training_args.resume_from_checkpoint
603 elif last_checkpoint is not None:
604 checkpoint = last_checkpoint
605 train_result = trainer.train(resume_from_checkpoint=checkpoint)
606 trainer.save_model() # Saves the tokenizer too for easy upload
607
608 metrics = train_result.metrics
609
610 max_train_samples = (
611 data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
612 )
613 metrics["train_samples"] = min(max_train_samples, len(train_dataset))
614
615 trainer.log_metrics("train", metrics)
616 trainer.save_metrics("train", metrics)
617 trainer.save_state()
618
619 # Evaluation
620 if training_args.do_eval:
621 logger.info("*** Evaluate ***")
622
623 metrics = trainer.evaluate()
624
625 max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
626 metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
627 try:
628 perplexity = math.exp(metrics["eval_loss"])
629 except OverflowError:
630 perplexity = float("inf")
631 metrics["perplexity"] = perplexity
632
633 trainer.log_metrics("eval", metrics)
634 trainer.save_metrics("eval", metrics)
635
636 kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
637 if data_args.dataset_name is not None:
638 kwargs["dataset_tags"] = data_args.dataset_name
639 if data_args.dataset_config_name is not None:
640 kwargs["dataset_args"] = data_args.dataset_config_name
641 kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
642 else:
643 kwargs["dataset"] = data_args.dataset_name
644
645 if training_args.push_to_hub:
646 trainer.push_to_hub(**kwargs)
647 else:
648 trainer.create_model_card(**kwargs)
649
650
651def _mp_fn(index):
652 # For xla_spawn (TPUs)
653 main()
654
655
656if __name__ == "__main__":
657 main()
python run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --train_file path_to_train_file \
    --validation_file path_to_validation_file \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm
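To fine-tune on several NPUs at once, the same script can be started with torchrun. The command below is a sketch; the wikitext dataset and the device count of 8 are only examples:

torchrun --nproc_per_node 8 run_clm.py \
    --model_name_or_path openai-community/gpt2 \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 8 \
    --per_device_eval_batch_size 8 \
    --do_train \
    --do_eval \
    --output_dir /tmp/test-clm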
4. Fine-tuning a model with Diffusers
The following code fine-tunes a text-to-image model with Diffusers (taken from the diffusers examples). Since diffusers v0.27.0, the code automatically detects NPUs and runs without any modification; a typical launch command is sketched below, followed by the training script.
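The script is usually started through accelerate launch. The command below is a sketch that assumes the script is saved as train_text_to_image.py; the model and dataset names are examples (the Naruto captions dataset also appears in the script's DATASET_NAME_MAPPING):

# Example launch; adjust paths, dataset and hyperparameters to your setup.
accelerate launch --mixed_precision="fp16" train_text_to_image.py \
    --pretrained_model_name_or_path="stable-diffusion-v1-5/stable-diffusion-v1-5" \
    --dataset_name="lambdalabs/naruto-blip-captions" \
    --resolution=512 --center_crop --random_flip \
    --train_batch_size=1 \
    --gradient_accumulation_steps=4 \
    --max_train_steps=15000 \
    --learning_rate=1e-05 \
    --lr_scheduler="constant" --lr_warmup_steps=0 \
    --output_dir="sd-naruto-model"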
1#!/usr/bin/env python
2# coding=utf-8
3# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17import argparse
18import logging
19import math
20import os
21import random
22import shutil
23from contextlib import nullcontext
24from pathlib import Path
25
26import accelerate
27import datasets
28import numpy as np
29import torch
30import torch.nn.functional as F
31import torch.utils.checkpoint
32import transformers
33from accelerate import Accelerator
34from accelerate.logging import get_logger
35from accelerate.state import AcceleratorState
36from accelerate.utils import ProjectConfiguration, set_seed
37from datasets import load_dataset
38from huggingface_hub import create_repo, upload_folder
39from packaging import version
40from torchvision import transforms
41from tqdm.auto import tqdm
42from transformers import CLIPTextModel, CLIPTokenizer
43from transformers.utils import ContextManagers
44
45import diffusers
46from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
47from diffusers.optimization import get_scheduler
48from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
49from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
50from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
51from diffusers.utils.import_utils import is_xformers_available
52from diffusers.utils.torch_utils import is_compiled_module
53
54
55if is_wandb_available():
56 import wandb
57
58
59# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
60check_min_version("0.32.0.dev0")
61
62logger = get_logger(__name__, log_level="INFO")
63
64DATASET_NAME_MAPPING = {
65 "lambdalabs/naruto-blip-captions": ("image", "text"),
66}
67
68
69def save_model_card(
70 args,
71 repo_id: str,
72 images: list = None,
73 repo_folder: str = None,
74):
75 img_str = ""
76 if len(images) > 0:
77 image_grid = make_image_grid(images, 1, len(args.validation_prompts))
78 image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
79 img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"
80
81 model_description = f"""
82# Text-to-image finetuning - {repo_id}
83
84This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
85{img_str}
86
87## Pipeline usage
88
89You can use the pipeline like so:
90
91```python
92from diffusers import DiffusionPipeline
93import torch
94
95pipeline = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype=torch.float16)
96prompt = "{args.validation_prompts[0]}"
97image = pipeline(prompt).images[0]
98image.save("my_image.png")
99```
100
101## Training info
102
103These are the key hyperparameters used during training:
104
105* Epochs: {args.num_train_epochs}
106* Learning rate: {args.learning_rate}
107* Batch size: {args.train_batch_size}
108* Gradient accumulation steps: {args.gradient_accumulation_steps}
109* Image resolution: {args.resolution}
110* Mixed-precision: {args.mixed_precision}
111
112"""
113 wandb_info = ""
114 if is_wandb_available():
115 wandb_run_url = None
116 if wandb.run is not None:
117 wandb_run_url = wandb.run.url
118
119 if wandb_run_url is not None:
120 wandb_info = f"""
121More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
122"""
123
124 model_description += wandb_info
125
126 model_card = load_or_create_model_card(
127 repo_id_or_path=repo_id,
128 from_training=True,
129 license="creativeml-openrail-m",
130 base_model=args.pretrained_model_name_or_path,
131 model_description=model_description,
132 inference=True,
133 )
134
135 tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
136 model_card = populate_model_card(model_card, tags=tags)
137
138 model_card.save(os.path.join(repo_folder, "README.md"))
139
140
141def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
142 logger.info("Running validation... ")
143
144 pipeline = StableDiffusionPipeline.from_pretrained(
145 args.pretrained_model_name_or_path,
146 vae=accelerator.unwrap_model(vae),
147 text_encoder=accelerator.unwrap_model(text_encoder),
148 tokenizer=tokenizer,
149 unet=accelerator.unwrap_model(unet),
150 safety_checker=None,
151 revision=args.revision,
152 variant=args.variant,
153 torch_dtype=weight_dtype,
154 )
155 pipeline = pipeline.to(accelerator.device)
156 pipeline.set_progress_bar_config(disable=True)
157
158 if args.enable_xformers_memory_efficient_attention:
159 pipeline.enable_xformers_memory_efficient_attention()
160
161 if args.seed is None:
162 generator = None
163 else:
164 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
165
166 images = []
167 for i in range(len(args.validation_prompts)):
168 if torch.backends.mps.is_available():
169 autocast_ctx = nullcontext()
170 else:
171 autocast_ctx = torch.autocast(accelerator.device.type)
172
173 with autocast_ctx:
174 image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
175
176 images.append(image)
177
178 for tracker in accelerator.trackers:
179 if tracker.name == "tensorboard":
180 np_images = np.stack([np.asarray(img) for img in images])
181 tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
182 elif tracker.name == "wandb":
183 tracker.log(
184 {
185 "validation": [
186 wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
187 for i, image in enumerate(images)
188 ]
189 }
190 )
191 else:
192 logger.warning(f"image logging not implemented for {tracker.name}")
193
194 del pipeline
195 torch.cuda.empty_cache()
196
197 return images
198
199
200def parse_args():
201 parser = argparse.ArgumentParser(description="Simple example of a training script.")
202 parser.add_argument(
203 "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
204 )
205 parser.add_argument(
206 "--pretrained_model_name_or_path",
207 type=str,
208 default=None,
209 required=True,
210 help="Path to pretrained model or model identifier from huggingface.co/models.",
211 )
212 parser.add_argument(
213 "--revision",
214 type=str,
215 default=None,
216 required=False,
217 help="Revision of pretrained model identifier from huggingface.co/models.",
218 )
219 parser.add_argument(
220 "--variant",
221 type=str,
222 default=None,
223 help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
224 )
225 parser.add_argument(
226 "--dataset_name",
227 type=str,
228 default=None,
229 help=(
230 "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
231 " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
232 " or to a folder containing files that 🤗 Datasets can understand."
233 ),
234 )
235 parser.add_argument(
236 "--dataset_config_name",
237 type=str,
238 default=None,
239 help="The config of the Dataset, leave as None if there's only one config.",
240 )
241 parser.add_argument(
242 "--train_data_dir",
243 type=str,
244 default=None,
245 help=(
246 "A folder containing the training data. Folder contents must follow the structure described in"
247 " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
248 " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
249 ),
250 )
251 parser.add_argument(
252 "--image_column", type=str, default="image", help="The column of the dataset containing an image."
253 )
254 parser.add_argument(
255 "--caption_column",
256 type=str,
257 default="text",
258 help="The column of the dataset containing a caption or a list of captions.",
259 )
260 parser.add_argument(
261 "--max_train_samples",
262 type=int,
263 default=None,
264 help=(
265 "For debugging purposes or quicker training, truncate the number of training examples to this "
266 "value if set."
267 ),
268 )
269 parser.add_argument(
270 "--validation_prompts",
271 type=str,
272 default=None,
273 nargs="+",
274 help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
275 )
276 parser.add_argument(
277 "--output_dir",
278 type=str,
279 default="sd-model-finetuned",
280 help="The output directory where the model predictions and checkpoints will be written.",
281 )
282 parser.add_argument(
283 "--cache_dir",
284 type=str,
285 default=None,
286 help="The directory where the downloaded models and datasets will be stored.",
287 )
288 parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
289 parser.add_argument(
290 "--resolution",
291 type=int,
292 default=512,
293 help=(
294 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
295 " resolution"
296 ),
297 )
298 parser.add_argument(
299 "--center_crop",
300 default=False,
301 action="store_true",
302 help=(
303 "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
304 " cropped. The images will be resized to the resolution first before cropping."
305 ),
306 )
307 parser.add_argument(
308 "--random_flip",
309 action="store_true",
310 help="whether to randomly flip images horizontally",
311 )
312 parser.add_argument(
313 "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
314 )
315 parser.add_argument("--num_train_epochs", type=int, default=100)
316 parser.add_argument(
317 "--max_train_steps",
318 type=int,
319 default=None,
320 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
321 )
322 parser.add_argument(
323 "--gradient_accumulation_steps",
324 type=int,
325 default=1,
326 help="Number of updates steps to accumulate before performing a backward/update pass.",
327 )
328 parser.add_argument(
329 "--gradient_checkpointing",
330 action="store_true",
331 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
332 )
333 parser.add_argument(
334 "--learning_rate",
335 type=float,
336 default=1e-4,
337 help="Initial learning rate (after the potential warmup period) to use.",
338 )
339 parser.add_argument(
340 "--scale_lr",
341 action="store_true",
342 default=False,
343 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
344 )
345 parser.add_argument(
346 "--lr_scheduler",
347 type=str,
348 default="constant",
349 help=(
350 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
351 ' "constant", "constant_with_warmup"]'
352 ),
353 )
354 parser.add_argument(
355 "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
356 )
357 parser.add_argument(
358 "--snr_gamma",
359 type=float,
360 default=None,
361 help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
362 "More details here: https://arxiv.org/abs/2303.09556.",
363 )
364 parser.add_argument(
365 "--dream_training",
366 action="store_true",
367 help=(
368 "Use the DREAM training method, which makes training more efficient and accurate at the "
369 "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
370 ),
371 )
372 parser.add_argument(
373 "--dream_detail_preservation",
374 type=float,
375 default=1.0,
376 help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
377 )
378 parser.add_argument(
379 "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
380 )
381 parser.add_argument(
382 "--allow_tf32",
383 action="store_true",
384 help=(
385 "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
386 " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
387 ),
388 )
389 parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
390 parser.add_argument("--offload_ema", action="store_true", help="Offload EMA model to CPU during training step.")
391 parser.add_argument("--foreach_ema", action="store_true", help="Use faster foreach implementation of EMAModel.")
392 parser.add_argument(
393 "--non_ema_revision",
394 type=str,
395 default=None,
396 required=False,
397 help=(
398 "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
399 " remote repository specified with --pretrained_model_name_or_path."
400 ),
401 )
402 parser.add_argument(
403 "--dataloader_num_workers",
404 type=int,
405 default=0,
406 help=(
407 "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
408 ),
409 )
410 parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
411 parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
412 parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
413 parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
414 parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
415 parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
416 parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
417 parser.add_argument(
418 "--prediction_type",
419 type=str,
420 default=None,
421 help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
422 )
423 parser.add_argument(
424 "--hub_model_id",
425 type=str,
426 default=None,
427 help="The name of the repository to keep in sync with the local `output_dir`.",
428 )
429 parser.add_argument(
430 "--logging_dir",
431 type=str,
432 default="logs",
433 help=(
434 "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
435 " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
436 ),
437 )
438 parser.add_argument(
439 "--mixed_precision",
440 type=str,
441 default=None,
442 choices=["no", "fp16", "bf16"],
443 help=(
444 "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
445 " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
446 " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
447 ),
448 )
449 parser.add_argument(
450 "--report_to",
451 type=str,
452 default="tensorboard",
453 help=(
454 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
455 ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
456 ),
457 )
458 parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
459 parser.add_argument(
460 "--checkpointing_steps",
461 type=int,
462 default=500,
463 help=(
464 "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
465 " training using `--resume_from_checkpoint`."
466 ),
467 )
468 parser.add_argument(
469 "--checkpoints_total_limit",
470 type=int,
471 default=None,
472 help=("Max number of checkpoints to store."),
473 )
474 parser.add_argument(
475 "--resume_from_checkpoint",
476 type=str,
477 default=None,
478 help=(
479 "Whether training should be resumed from a previous checkpoint. Use a path saved by"
480 ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
481 ),
482 )
483 parser.add_argument(
484 "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
485 )
486 parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
487 parser.add_argument(
488 "--validation_epochs",
489 type=int,
490 default=5,
491 help="Run validation every X epochs.",
492 )
493 parser.add_argument(
494 "--tracker_project_name",
495 type=str,
496 default="text2image-fine-tune",
497 help=(
498 "The `project_name` argument passed to Accelerator.init_trackers. For"
499 " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500 ),
501 )
502
503 args = parser.parse_args()
504 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
505 if env_local_rank != -1 and env_local_rank != args.local_rank:
506 args.local_rank = env_local_rank
507
508 # Sanity checks
509 if args.dataset_name is None and args.train_data_dir is None:
510 raise ValueError("Need either a dataset name or a training folder.")
511
512 # default to using the same revision for the non-ema model if not specified
513 if args.non_ema_revision is None:
514 args.non_ema_revision = args.revision
515
516 return args
517
518
519def main():
520 args = parse_args()
521
522 if args.report_to == "wandb" and args.hub_token is not None:
523 raise ValueError(
524 "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
525 " Please use `huggingface-cli login` to authenticate with the Hub."
526 )
527
528 if args.non_ema_revision is not None:
529 deprecate(
530 "non_ema_revision!=None",
531 "0.15.0",
532 message=(
533 "Downloading 'non_ema' weights from revision branches of the Hub is deprecated. Please make sure to"
534 " use `--variant=non_ema` instead."
535 ),
536 )
537 logging_dir = os.path.join(args.output_dir, args.logging_dir)
538
539 accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
540
541 accelerator = Accelerator(
542 gradient_accumulation_steps=args.gradient_accumulation_steps,
543 mixed_precision=args.mixed_precision,
544 log_with=args.report_to,
545 project_config=accelerator_project_config,
546 )
547
548 # Disable AMP for MPS.
549 if torch.backends.mps.is_available():
550 accelerator.native_amp = False
551
552 # Make one log on every process with the configuration for debugging.
553 logging.basicConfig(
554 format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
555 datefmt="%m/%d/%Y %H:%M:%S",
556 level=logging.INFO,
557 )
558 logger.info(accelerator.state, main_process_only=False)
559 if accelerator.is_local_main_process:
560 datasets.utils.logging.set_verbosity_warning()
561 transformers.utils.logging.set_verbosity_warning()
562 diffusers.utils.logging.set_verbosity_info()
563 else:
564 datasets.utils.logging.set_verbosity_error()
565 transformers.utils.logging.set_verbosity_error()
566 diffusers.utils.logging.set_verbosity_error()
567
568 # If passed along, set the training seed now.
569 if args.seed is not None:
570 set_seed(args.seed)
571
572 # Handle the repository creation
573 if accelerator.is_main_process:
574 if args.output_dir is not None:
575 os.makedirs(args.output_dir, exist_ok=True)
576
577 if args.push_to_hub:
578 repo_id = create_repo(
579 repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
580 ).repo_id
581
582 # Load scheduler, tokenizer and models.
583 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
584 tokenizer = CLIPTokenizer.from_pretrained(
585 args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
586 )
587
588 def deepspeed_zero_init_disabled_context_manager():
589 """
590 Returns either a context list that includes one that will disable zero.Init or an empty context list.
591 """
592 deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
593 if deepspeed_plugin is None:
594 return []
595
596 return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
597
598 # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
599 # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
600 # will try to assign the same optimizer with the same weights to all models during
601 # `deepspeed.initialize`, which of course doesn't work.
602 #
603 # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
604 # frozen models from being partitioned during `zero.Init` which gets called during
605 # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
606 # across multiple gpus and only UNet2DConditionModel will get ZeRO sharded.
607 with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
608 text_encoder = CLIPTextModel.from_pretrained(
609 args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
610 )
611 vae = AutoencoderKL.from_pretrained(
612 args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, variant=args.variant
613 )
614
615 unet = UNet2DConditionModel.from_pretrained(
616 args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
617 )
618
619 # Freeze vae and text_encoder and set unet to trainable
620 vae.requires_grad_(False)
621 text_encoder.requires_grad_(False)
622 unet.train()
623
624 # Create EMA for the unet.
625 if args.use_ema:
626 ema_unet = UNet2DConditionModel.from_pretrained(
627 args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
628 )
629 ema_unet = EMAModel(
630 ema_unet.parameters(),
631 model_cls=UNet2DConditionModel,
632 model_config=ema_unet.config,
633 foreach=args.foreach_ema,
634 )
635
636 if args.enable_xformers_memory_efficient_attention:
637 if is_xformers_available():
638 import xformers
639
640 xformers_version = version.parse(xformers.__version__)
641 if xformers_version == version.parse("0.0.16"):
642 logger.warning(
643 "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
644 )
645 unet.enable_xformers_memory_efficient_attention()
646 else:
647 raise ValueError("xformers is not available. Make sure it is installed correctly")
648
649 # `accelerate` 0.16.0 will have better support for customized saving
650 if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
651 # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
652 def save_model_hook(models, weights, output_dir):
653 if accelerator.is_main_process:
654 if args.use_ema:
655 ema_unet.save_pretrained(os.path.join(output_dir, "unet_ema"))
656
657 for i, model in enumerate(models):
658 model.save_pretrained(os.path.join(output_dir, "unet"))
659
660 # make sure to pop weight so that corresponding model is not saved again
661 weights.pop()
662
663 def load_model_hook(models, input_dir):
664 if args.use_ema:
665 load_model = EMAModel.from_pretrained(
666 os.path.join(input_dir, "unet_ema"), UNet2DConditionModel, foreach=args.foreach_ema
667 )
668 ema_unet.load_state_dict(load_model.state_dict())
669 if args.offload_ema:
670 ema_unet.pin_memory()
671 else:
672 ema_unet.to(accelerator.device)
673 del load_model
674
675 for _ in range(len(models)):
676 # pop models so that they are not loaded again
677 model = models.pop()
678
679 # load diffusers style into model
680 load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
681 model.register_to_config(**load_model.config)
682
683 model.load_state_dict(load_model.state_dict())
684 del load_model
685
686 accelerator.register_save_state_pre_hook(save_model_hook)
687 accelerator.register_load_state_pre_hook(load_model_hook)
688
689 if args.gradient_checkpointing:
690 unet.enable_gradient_checkpointing()
691
692 # Enable TF32 for faster training on Ampere GPUs,
693 # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
694 if args.allow_tf32:
695 torch.backends.cuda.matmul.allow_tf32 = True
696
697 if args.scale_lr:
698 args.learning_rate = (
699 args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
700 )
701
702 # Initialize the optimizer
703 if args.use_8bit_adam:
704 try:
705 import bitsandbytes as bnb
706 except ImportError:
707 raise ImportError(
708 "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
709 )
710
711 optimizer_cls = bnb.optim.AdamW8bit
712 else:
713 optimizer_cls = torch.optim.AdamW
714
715 optimizer = optimizer_cls(
716 unet.parameters(),
717 lr=args.learning_rate,
718 betas=(args.adam_beta1, args.adam_beta2),
719 weight_decay=args.adam_weight_decay,
720 eps=args.adam_epsilon,
721 )
722
723 # Get the datasets: you can either provide your own training and evaluation files (see below)
724 # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
725
726 # In distributed training, the load_dataset function guarantees that only one local process can concurrently
727 # download the dataset.
728 if args.dataset_name is not None:
729 # Downloading and loading a dataset from the hub.
730 dataset = load_dataset(
731 args.dataset_name,
732 args.dataset_config_name,
733 cache_dir=args.cache_dir,
734 data_dir=args.train_data_dir,
735 )
736 else:
737 data_files = {}
738 if args.train_data_dir is not None:
739 data_files["train"] = os.path.join(args.train_data_dir, "**")
740 dataset = load_dataset(
741 "imagefolder",
742 data_files=data_files,
743 cache_dir=args.cache_dir,
744 )
745 # See more about loading custom images at
746 # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
747
748 # Preprocessing the datasets.
749 # We need to tokenize inputs and targets.
750 column_names = dataset["train"].column_names
751
752 # 6. Get the column names for input/target.
753 dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
754 if args.image_column is None:
755 image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
756 else:
757 image_column = args.image_column
758 if image_column not in column_names:
759 raise ValueError(
760 f"'--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
761 )
762 if args.caption_column is None:
763 caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
764 else:
765 caption_column = args.caption_column
766 if caption_column not in column_names:
767 raise ValueError(
768 f"'--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
769 )
770
771 # Preprocessing the datasets.
772 # We need to tokenize input captions and transform the images.
773 def tokenize_captions(examples, is_train=True):
774 captions = []
775 for caption in examples[caption_column]:
776 if isinstance(caption, str):
777 captions.append(caption)
778 elif isinstance(caption, (list, np.ndarray)):
779 # take a random caption if there are multiple
780 captions.append(random.choice(caption) if is_train else caption[0])
781 else:
782 raise ValueError(
783 f"Caption column `{caption_column}` should contain either strings or lists of strings."
784 )
785 inputs = tokenizer(
786 captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
787 )
788 return inputs.input_ids
789
790 # Preprocessing the datasets.
791 train_transforms = transforms.Compose(
792 [
793 transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
794 transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
795 transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
796 transforms.ToTensor(),
797 transforms.Normalize([0.5], [0.5]),
798 ]
799 )
800
801 def preprocess_train(examples):
802 images = [image.convert("RGB") for image in examples[image_column]]
803 examples["pixel_values"] = [train_transforms(image) for image in images]
804 examples["input_ids"] = tokenize_captions(examples)
805 return examples
806
807 with accelerator.main_process_first():
808 if args.max_train_samples is not None:
809 dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
810 # Set the training transforms
811 train_dataset = dataset["train"].with_transform(preprocess_train)
812
813 def collate_fn(examples):
814 pixel_values = torch.stack([example["pixel_values"] for example in examples])
815 pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
816 input_ids = torch.stack([example["input_ids"] for example in examples])
817 return {"pixel_values": pixel_values, "input_ids": input_ids}
818
819 # DataLoaders creation:
820 train_dataloader = torch.utils.data.DataLoader(
821 train_dataset,
822 shuffle=True,
823 collate_fn=collate_fn,
824 batch_size=args.train_batch_size,
825 num_workers=args.dataloader_num_workers,
826 )
827
828 # Scheduler and math around the number of training steps.
829 # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
830 num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
831 if args.max_train_steps is None:
832 len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
833 num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
834 num_training_steps_for_scheduler = (
835 args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
836 )
837 else:
838 num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
839
840 lr_scheduler = get_scheduler(
841 args.lr_scheduler,
842 optimizer=optimizer,
843 num_warmup_steps=num_warmup_steps_for_scheduler,
844 num_training_steps=num_training_steps_for_scheduler,
845 )
846
847 # Prepare everything with our `accelerator`.
848 unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
849 unet, optimizer, train_dataloader, lr_scheduler
850 )
851
852 if args.use_ema:
853 if args.offload_ema:
854 ema_unet.pin_memory()
855 else:
856 ema_unet.to(accelerator.device)
857
858 # For mixed precision training we cast the frozen weights (vae and text_encoder) to half-precision
859 # as these weights are only used for inference, keeping weights in full precision is not required.
860 weight_dtype = torch.float32
861 if accelerator.mixed_precision == "fp16":
862 weight_dtype = torch.float16
863 args.mixed_precision = accelerator.mixed_precision
864 elif accelerator.mixed_precision == "bf16":
865 weight_dtype = torch.bfloat16
866 args.mixed_precision = accelerator.mixed_precision
867
868 # Move text_encoder and vae to the accelerator device and cast to weight_dtype
869 text_encoder.to(accelerator.device, dtype=weight_dtype)
870 vae.to(accelerator.device, dtype=weight_dtype)
871
872 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
873 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
874 if args.max_train_steps is None:
875 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
876 if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
877 logger.warning(
878 f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
879 f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
880 f"This inconsistency may result in the learning rate scheduler not functioning properly."
881 )
882 # Afterwards we recalculate our number of training epochs
883 args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
884
885 # We need to initialize the trackers we use, and also store our configuration.
886 # The trackers initialize automatically on the main process.
887 if accelerator.is_main_process:
888 tracker_config = dict(vars(args))
889 tracker_config.pop("validation_prompts")
890 accelerator.init_trackers(args.tracker_project_name, tracker_config)
891
892 # Function for unwrapping if model was compiled with `torch.compile`.
893 def unwrap_model(model):
894 model = accelerator.unwrap_model(model)
895 model = model._orig_mod if is_compiled_module(model) else model
896 return model
897
898 # Train!
899 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
900
901 logger.info("***** Running training *****")
902 logger.info(f" Num examples = {len(train_dataset)}")
903 logger.info(f" Num Epochs = {args.num_train_epochs}")
904 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
905 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
906 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
907 logger.info(f" Total optimization steps = {args.max_train_steps}")
908 global_step = 0
909 first_epoch = 0
910
911 # Potentially load in the weights and states from a previous save
912 if args.resume_from_checkpoint:
913 if args.resume_from_checkpoint != "latest":
914 path = os.path.basename(args.resume_from_checkpoint)
915 else:
916 # Get the most recent checkpoint
917 dirs = os.listdir(args.output_dir)
918 dirs = [d for d in dirs if d.startswith("checkpoint")]
919 dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
920 path = dirs[-1] if len(dirs) > 0 else None
921
922 if path is None:
923 accelerator.print(
924 f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
925 )
926 args.resume_from_checkpoint = None
927 initial_global_step = 0
928 else:
929 accelerator.print(f"Resuming from checkpoint {path}")
930 accelerator.load_state(os.path.join(args.output_dir, path))
931 global_step = int(path.split("-")[1])
932
933 initial_global_step = global_step
934 first_epoch = global_step // num_update_steps_per_epoch
935
936 else:
937 initial_global_step = 0
938
939 progress_bar = tqdm(
940 range(0, args.max_train_steps),
941 initial=initial_global_step,
942 desc="Steps",
943 # Only show the progress bar once on each machine.
944 disable=not accelerator.is_local_main_process,
945 )
946
947 for epoch in range(first_epoch, args.num_train_epochs):
948 train_loss = 0.0
949 for step, batch in enumerate(train_dataloader):
950 with accelerator.accumulate(unet):
951 # Convert images to latent space
952 latents = vae.encode(batch["pixel_values"].to(weight_dtype)).latent_dist.sample()
953 latents = latents * vae.config.scaling_factor
954
955 # Sample noise that we'll add to the latents
956 noise = torch.randn_like(latents)
957 if args.noise_offset:
958 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
959 noise += args.noise_offset * torch.randn(
960 (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
961 )
962 if args.input_perturbation:
963 new_noise = noise + args.input_perturbation * torch.randn_like(noise)
964 bsz = latents.shape[0]
965 # Sample a random timestep for each image
966 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
967 timesteps = timesteps.long()
968
969 # Add noise to the latents according to the noise magnitude at each timestep
970 # (this is the forward diffusion process)
971 if args.input_perturbation:
972 noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
973 else:
974 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
975
976 # Get the text embedding for conditioning
977 encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0]
978
979 # Get the target for loss depending on the prediction type
980 if args.prediction_type is not None:
981 # set prediction_type of scheduler if defined
982 noise_scheduler.register_to_config(prediction_type=args.prediction_type)
983
984 if noise_scheduler.config.prediction_type == "epsilon":
985 target = noise
986 elif noise_scheduler.config.prediction_type == "v_prediction":
987 target = noise_scheduler.get_velocity(latents, noise, timesteps)
988 else:
989 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
990
991 if args.dream_training:
992 noisy_latents, target = compute_dream_and_update_latents(
993 unet,
994 noise_scheduler,
995 timesteps,
996 noise,
997 noisy_latents,
998 target,
999 encoder_hidden_states,
1000 args.dream_detail_preservation,
1001 )
1002
1003 # Predict the noise residual and compute loss
1004 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
1005
1006 if args.snr_gamma is None:
1007 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
1008 else:
1009 # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1010 # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1011 # This is discussed in Section 4.2 of the same paper.
1012 snr = compute_snr(noise_scheduler, timesteps)
1013 mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
1014 dim=1
1015 )[0]
1016 if noise_scheduler.config.prediction_type == "epsilon":
1017 mse_loss_weights = mse_loss_weights / snr
1018 elif noise_scheduler.config.prediction_type == "v_prediction":
1019 mse_loss_weights = mse_loss_weights / (snr + 1)
1020
1021 loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1022 loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1023 loss = loss.mean()
1024
1025 # Gather the losses across all processes for logging (if we use distributed training).
1026 avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1027 train_loss += avg_loss.item() / args.gradient_accumulation_steps
1028
1029 # Backpropagate
1030 accelerator.backward(loss)
1031 if accelerator.sync_gradients:
1032 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
1033 optimizer.step()
1034 lr_scheduler.step()
1035 optimizer.zero_grad()
1036
1037 # Checks if the accelerator has performed an optimization step behind the scenes
1038 if accelerator.sync_gradients:
1039 if args.use_ema:
1040 if args.offload_ema:
1041 ema_unet.to(device="cuda", non_blocking=True)
1042 ema_unet.step(unet.parameters())
1043 if args.offload_ema:
1044 ema_unet.to(device="cpu", non_blocking=True)
1045 progress_bar.update(1)
1046 global_step += 1
1047 accelerator.log({"train_loss": train_loss}, step=global_step)
1048 train_loss = 0.0
1049
1050 if global_step % args.checkpointing_steps == 0:
1051 if accelerator.is_main_process:
1052 # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1053 if args.checkpoints_total_limit is not None:
1054 checkpoints = os.listdir(args.output_dir)
1055 checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1056 checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1057
1058 # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1059 if len(checkpoints) >= args.checkpoints_total_limit:
1060 num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1061 removing_checkpoints = checkpoints[0:num_to_remove]
1062
1063 logger.info(
1064 f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1065 )
1066 logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1067
1068 for removing_checkpoint in removing_checkpoints:
1069 removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1070 shutil.rmtree(removing_checkpoint)
1071
1072 save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1073 accelerator.save_state(save_path)
1074 logger.info(f"Saved state to {save_path}")
1075
1076 logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1077 progress_bar.set_postfix(**logs)
1078
1079 if global_step >= args.max_train_steps:
1080 break
1081
1082 if accelerator.is_main_process:
1083 if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
1084 if args.use_ema:
1085 # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
1086 ema_unet.store(unet.parameters())
1087 ema_unet.copy_to(unet.parameters())
1088 log_validation(
1089 vae,
1090 text_encoder,
1091 tokenizer,
1092 unet,
1093 args,
1094 accelerator,
1095 weight_dtype,
1096 global_step,
1097 )
1098 if args.use_ema:
1099 # Switch back to the original UNet parameters.
1100 ema_unet.restore(unet.parameters())
1101
1102 # Create the pipeline using the trained modules and save it.
1103 accelerator.wait_for_everyone()
1104 if accelerator.is_main_process:
1105 unet = unwrap_model(unet)
1106 if args.use_ema:
1107 ema_unet.copy_to(unet.parameters())
1108
1109 pipeline = StableDiffusionPipeline.from_pretrained(
1110 args.pretrained_model_name_or_path,
1111 text_encoder=text_encoder,
1112 vae=vae,
1113 unet=unet,
1114 revision=args.revision,
1115 variant=args.variant,
1116 )
1117 pipeline.save_pretrained(args.output_dir)
1118
1119 # Run a final round of inference.
1120 images = []
1121 if args.validation_prompts is not None:
1122 logger.info("Running inference for collecting generated images...")
1123 pipeline = pipeline.to(accelerator.device)
1124 pipeline.torch_dtype = weight_dtype
1125 pipeline.set_progress_bar_config(disable=True)
1126
1127 if args.enable_xformers_memory_efficient_attention:
1128 pipeline.enable_xformers_memory_efficient_attention()
1129
1130 if args.seed is None:
1131 generator = None
1132 else:
1133 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1134
1135 for i in range(len(args.validation_prompts)):
1136 with torch.autocast("cuda"):
1137 image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]
1138 images.append(image)
1139
1140 if args.push_to_hub:
1141 save_model_card(args, repo_id, images, repo_folder=args.output_dir)
1142 upload_folder(
1143 repo_id=repo_id,
1144 folder_path=args.output_dir,
1145 commit_message="End of training",
1146 ignore_patterns=["step_*", "epoch_*"],
1147 )
1148
1149 accelerator.end_training()
1150
1151
1152if __name__ == "__main__":
1153 main()
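Note that the script above hard-codes the CUDA device in two places: the EMA offload path (`ema_unet.to(device="cuda", non_blocking=True)`) and the final inference loop (`with torch.autocast("cuda"):`). When running on an NPU, these two spots will likely need to target the NPU device instead. Below is a minimal sketch of that kind of adjustment, not a verified patch; it assumes `torch_npu` is installed, that the installed version supports the `npu` autocast device type, and it reuses the `accelerator` object already created in the script:

# EMA offload: move the EMA weights back to the device managed by Accelerate
# instead of hard-coding "cuda".
ema_unet.to(device=accelerator.device, non_blocking=True)

# Final inference: run autocast on the NPU device type instead of "cuda".
with torch.autocast("npu"):
    image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]

The script is then launched with `accelerate launch`, for example: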
1export MODEL_NAME="CompVis/stable-diffusion-v1-4"
2export DATASET_NAME="lambdalabs/naruto-blip-captions"
3
4accelerate launch --mixed_precision="fp16" train_text_to_image.py \
5--pretrained_model_name_or_path=$MODEL_NAME \
6--dataset_name=$DATASET_NAME \
7--use_ema \
8--resolution=512 --center_crop --random_flip \
9--train_batch_size=1 \
10--gradient_accumulation_steps=4 \
11--gradient_checkpointing \
12--max_train_steps=15000 \
13--learning_rate=1e-05 \
14--max_grad_norm=1 \
15--lr_scheduler="constant" --lr_warmup_steps=0 \
16--output_dir="sd-naruto-model"
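After training finishes, the fine-tuned weights are saved to `--output_dir` as a standard diffusers pipeline, so they can be loaded for inference like any other Stable Diffusion checkpoint. The following is a minimal sketch of running inference on the NPU with the newly trained model; it assumes `torch_npu` is installed, uses the output directory from the command above, and the prompt is an arbitrary example:

import torch
import torch_npu
from diffusers import StableDiffusionPipeline

# Load the fine-tuned pipeline from the training output directory (the value of --output_dir).
pipe = StableDiffusionPipeline.from_pretrained("sd-naruto-model", torch_dtype=torch.float16)
pipe = pipe.to("npu")

# Generate an image from a text prompt and save it.
image = pipe("a naruto character riding a horse", num_inference_steps=30).images[0]
image.save("sample.png")

To scale the same command to multiple NPUs, the number of processes can be set through `accelerate config` or the `--num_processes` option of `accelerate launch`.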