

在运行本示例之前，请确保已安装 DeepSpeed。如果尚未安装，可以执行 pip install deepspeed 进行安装。

1. 使用DeepSpeed多卡并行训练

以下代码（来自 DeepSpeed Examples）使用 CIFAR-10 数据集，借助 DeepSpeed 在多张 NPU 卡上训练模型。自 DeepSpeed v0.12.6 起，代码无需任何修改即可自动检测 NPU 并进行训练。

  1import argparse
  2import os
  4import deepspeed
  5import torch
  6import torch.nn as nn
  7import torch.nn.functional as F
  8import torchvision
  9from torchvision import transforms
 10from deepspeed.accelerator import get_accelerator
 11from deepspeed.moe.utils import split_params_into_different_moe_groups_for_optimizer
 14def add_argument():
 15    parser = argparse.ArgumentParser(description="CIFAR")
 17    # For train.
 18    parser.add_argument(
 19        "-e",
 20        "--epochs",
 21        default=30,
 22        type=int,
 23        help="number of total epochs (default: 30)",
 24    )
 25    parser.add_argument(
 26        "--local_rank",
 27        type=int,
 28        default=-1,
 29        help="local rank passed from distributed launcher",
 30    )
 31    parser.add_argument(
 32        "--log-interval",
 33        type=int,
 34        default=2000,
 35        help="output logging information at a given interval",
 36    )
 38    # For mixed precision training.
 39    parser.add_argument(
 40        "--dtype",
 41        default="fp16",
 42        type=str,
 43        choices=["bf16", "fp16", "fp32"],
 44        help="Datatype used for training",
 45    )
 47    # For ZeRO Optimization.
 48    parser.add_argument(
 49        "--stage",
 50        default=0,
 51        type=int,
 52        choices=[0, 1, 2, 3],
 53        help="Datatype used for training",
 54    )
 56    # For MoE (Mixture of Experts).
 57    parser.add_argument(
 58        "--moe",
 59        default=False,
 60        action="store_true",
 61        help="use deepspeed mixture of experts (moe)",
 62    )
 63    parser.add_argument(
 64        "--ep-world-size", default=1, type=int, help="(moe) expert parallel world size"
 65    )
 66    parser.add_argument(
 67        "--num-experts",
 68        type=int,
 69        nargs="+",
 70        default=[
 71            1,
 72        ],
 73        help="number of experts list, MoE related.",
 74    )
 75    parser.add_argument(
 76        "--mlp-type",
 77        type=str,
 78        default="standard",
 79        help="Only applicable when num-experts > 1, accepts [standard, residual]",
 80    )
 81    parser.add_argument(
 82        "--top-k", default=1, type=int, help="(moe) gating top 1 and 2 supported"
 83    )
 84    parser.add_argument(
 85        "--min-capacity",
 86        default=0,
 87        type=int,
 88        help="(moe) minimum capacity of an expert regardless of the capacity_factor",
 89    )
 90    parser.add_argument(
 91        "--noisy-gate-policy",
 92        default=None,
 93        type=str,
 94        help="(moe) noisy gating (only supported with top-1). Valid values are None, RSample, and Jitter",
 95    )
 96    parser.add_argument(
 97        "--moe-param-group",
 98        default=False,
 99        action="store_true",
100        help="(moe) create separate moe param groups, required when using ZeRO w. MoE",
101    )
103    # Include DeepSpeed configuration arguments.
104    parser = deepspeed.add_config_arguments(parser)
106    args = parser.parse_args()
108    return args
def create_moe_param_groups(model):
    """Split the model's parameters into DeepSpeed MoE optimizer param groups.

    Every parameter of ``model`` is first wrapped into one group named
    "parameters". DeepSpeed's helper then splits it so each expert gets its
    own group; the CLI help notes this is required when using ZeRO with MoE.
    """
    all_params = list(model.parameters())
    base_group = dict(params=all_params, name="parameters")
    return split_params_into_different_moe_groups_for_optimizer(base_group)
def get_ds_config(args):
    """Assemble the DeepSpeed configuration dictionary for this example.

    Only two settings depend on ``args``:

    * precision: ``args.dtype == "bf16"`` enables the bf16 section,
      ``args.dtype == "fp16"`` enables the fp16 section, and "fp32" enables
      neither;
    * ZeRO stage: taken from ``args.stage``.

    Everything else (batch size, Adam optimizer, WarmupLR scheduler, gradient
    clipping, ZeRO bucket sizes) is fixed. A fresh dict is built on each call.
    """
    use_bf16 = args.dtype == "bf16"
    use_fp16 = args.dtype == "fp16"

    optimizer_section = {
        "type": "Adam",
        "params": {
            "lr": 0.001,
            "betas": [0.8, 0.999],
            "eps": 1e-8,
            "weight_decay": 3e-7,
        },
    }

    # Linear warm-up from 0 to the optimizer's lr over the first 1000 steps.
    scheduler_section = {
        "type": "WarmupLR",
        "params": {
            "warmup_min_lr": 0,
            "warmup_max_lr": 0.001,
            "warmup_num_steps": 1000,
        },
    }

    fp16_section = {
        "enabled": use_fp16,
        "fp16_master_weights_and_grads": False,
        "loss_scale": 0,
        "loss_scale_window": 500,
        "hysteresis": 2,
        "min_loss_scale": 1,
        "initial_scale_power": 15,
    }

    zero_section = {
        "stage": args.stage,
        "allgather_partitions": True,
        "reduce_scatter": True,
        "allgather_bucket_size": 50_000_000,
        "reduce_bucket_size": 50_000_000,
        "overlap_comm": True,
        "contiguous_gradients": True,
        "cpu_offload": False,
    }

    return {
        "train_batch_size": 16,
        "steps_per_print": 2000,
        "optimizer": optimizer_section,
        "scheduler": scheduler_section,
        "gradient_clipping": 1.0,
        "prescale_gradients": False,
        "bf16": {"enabled": use_bf16},
        "fp16": fp16_section,
        "wall_clock_breakdown": False,
        "zero_optimization": zero_section,
    }
class Net(nn.Module):
    """Small LeNet-style CNN that maps 3x32x32 images to 10 class logits.

    Two conv + max-pool stages feed two ReLU fully connected layers. The head
    is chosen by ``args.moe``:

    * ``False``: a single ``fc3`` linear layer (84 -> 10).
    * ``True``: one DeepSpeed ``MoE`` layer per entry of ``args.num_experts``,
      applied in sequence, then ``fc4`` (84 -> 10). The same 84 -> 84
      ``nn.Linear`` instance is passed as ``expert`` to every MoE layer.
    """

    def __init__(self, args):
        super().__init__()
        # Spatial sizes: 32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5,
        # hence the 16 * 5 * 5 input width of fc1.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.moe = args.moe

        if not self.moe:
            self.fc3 = nn.Linear(84, 10)
            return

        expert_template = nn.Linear(84, 84)
        # One MoE layer per requested expert count, in the order given.
        self.moe_layer_list = nn.ModuleList(
            [
                deepspeed.moe.layer.MoE(
                    hidden_size=84,
                    expert=expert_template,
                    num_experts=expert_count,
                    ep_size=args.ep_world_size,
                    use_residual=args.mlp_type == "residual",
                    k=args.top_k,
                    min_capacity=args.min_capacity,
                    noisy_gate_policy=args.noisy_gate_policy,
                )
                for expert_count in args.num_experts
            ]
        )
        self.fc4 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the class logits for the image batch ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.view(-1, 16 * 5 * 5)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))

        if not self.moe:
            return self.fc3(x)

        for moe_layer in self.moe_layer_list:
            # Each MoE layer returns a 3-tuple; only the first element
            # (the transformed activations) is carried forward.
            x, _, _ = moe_layer(x)
        return self.fc4(x)
212def test(model_engine, testset, local_device, target_dtype, test_batch_size=4):
213    """Test the network on the test data.
215    Args:
216        model_engine (deepspeed.runtime.engine.DeepSpeedEngine): the DeepSpeed engine.
217        testset (torch.utils.data.Dataset): the test dataset.
218        local_device (str): the local device name.
219        target_dtype (torch.dtype): the target datatype for the test data.
220        test_batch_size (int): the test batch size.
222    """
223    # The 10 classes for CIFAR10.
224    classes = (
225        "plane",
226        "car",
227        "bird",
228        "cat",
229        "deer",
230        "dog",
231        "frog",
232        "horse",
233        "ship",
234        "truck",
235    )
237    # Define the test dataloader.
238    testloader = torch.utils.data.DataLoader(
239        testset, batch_size=test_batch_size, shuffle=False, num_workers=0
240    )
242    # For total accuracy.
243    correct, total = 0, 0
244    # For accuracy per class.
245    class_correct = list(0.0 for i in range(10))
246    class_total = list(0.0 for i in range(10))
248    # Start testing.
249    model_engine.eval()
250    with torch.no_grad():
251        for data in testloader:
252            images, labels = data
253            if target_dtype != None:
254                images = images.to(target_dtype)
255            outputs = model_engine(images.to(local_device))
256            _, predicted = torch.max(outputs.data, 1)
257            # Count the total accuracy.
258            total += labels.size(0)
259            correct += (predicted == labels.to(local_device)).sum().item()
261            # Count the accuracy per class.
262            batch_correct = (predicted == labels.to(local_device)).squeeze()
263            for i in range(test_batch_size):
264                label = labels[i]
265                class_correct[label] += batch_correct[i].item()
266                class_total[label] += 1
268    if model_engine.local_rank == 0:
269        print(
270            f"Accuracy of the network on the {total} test images: {100 * correct / total : .0f} %"
271        )
273        # For all classes, print the accuracy.
274        for i in range(10):
275            print(
276                f"Accuracy of {classes[i] : >5s} : {100 * class_correct[i] / class_total[i] : 2.0f} %"
277            )
def main(args):
    """Train the CIFAR-10 network with DeepSpeed, then evaluate it on the test set.

    Args:
        args (argparse.Namespace): options produced by ``add_argument()``.
    """
    # Initialize DeepSpeed distributed backend.
    deepspeed.init_distributed()
    # LOCAL_RANK is normally exported by the launcher; fall back to the
    # --local_rank flag instead of crashing on int(None) when it is absent.
    _local_rank = int(os.environ.get("LOCAL_RANK", args.local_rank))
    get_accelerator().set_device(_local_rank)

    ########################################################################
    # Step1. Data Preparation.
    #
    # The output of torchvision datasets are PILImage images of range [0, 1].
    # We transform them to Tensors of normalized range [-1, 1].
    #
    # Note:
    #     If running on Windows and you get a BrokenPipeError, try setting
    #     the num_worker of torch.utils.data.DataLoader() to 0.
    ########################################################################
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    if torch.distributed.get_rank() != 0:
        # Might be downloading cifar data, let rank 0 download first.
        torch.distributed.barrier()

    # Load or download cifar data.
    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )

    if torch.distributed.get_rank() == 0:
        # Cifar data is downloaded, indicate other ranks can proceed.
        torch.distributed.barrier()

    ########################################################################
    # Step 2. Define the network with DeepSpeed.
    #
    # First, we define a Convolution Neural Network.
    # Then, we define the DeepSpeed configuration dictionary and use it to
    # initialize the DeepSpeed engine.
    ########################################################################
    net = Net(args)

    if args.moe_param_group:
        # If using MoE, create separate param groups for each expert.
        parameters = create_moe_param_groups(net)
    else:
        # Otherwise hand DeepSpeed only the parameters that require gradients.
        parameters = [p for p in net.parameters() if p.requires_grad]

    # Initialize DeepSpeed to use the following features.
    #   1) Distributed model.
    #   2) Distributed data loader.
    #   3) DeepSpeed optimizer.
    ds_config = get_ds_config(args)
    model_engine, _, trainloader, _ = deepspeed.initialize(
        args=args,
        model=net,
        model_parameters=parameters,
        training_data=trainset,
        config=ds_config,
    )

    # Get the local device name (str) and local rank (int).
    local_device = get_accelerator().device_name(model_engine.local_rank)
    local_rank = model_engine.local_rank

    # For float32, target_dtype will be None so no datatype conversion needed.
    target_dtype = None
    if model_engine.bfloat16_enabled():
        target_dtype = torch.bfloat16
    elif model_engine.fp16_enabled():
        target_dtype = torch.half

    # Define the Classification Cross-Entropy loss function.
    criterion = nn.CrossEntropyLoss()

    ########################################################################
    # Step 3. Train the network.
    #
    # This is when things start to get interesting.
    # We simply have to loop over our data iterator, and feed the inputs to the
    # network and optimize. (DeepSpeed handles the distributed details for us!)
    ########################################################################
    for epoch in range(args.epochs):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader):
            # Get the inputs. ``data`` is a list of [inputs, labels].
            inputs, labels = data[0].to(local_device), data[1].to(local_device)

            # Try to convert to target_dtype if needed.
            if target_dtype is not None:
                inputs = inputs.to(target_dtype)

            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)

            model_engine.backward(loss)
            model_engine.step()

            # Print statistics
            running_loss += loss.item()
            if local_rank == 0 and i % args.log_interval == (
                args.log_interval - 1
            ):  # Print every log_interval mini-batches.
                print(
                    f"[{epoch + 1 : d}, {i + 1 : 5d}] loss: {running_loss / args.log_interval : .3f}"
                )
                running_loss = 0.0

    # Report from local rank 0 only, consistent with the loss logging above
    # and the accuracy report in test().
    if local_rank == 0:
        print("Finished Training")

    ########################################################################
    # Step 4. Test the network on the test data.
    ########################################################################
    test(model_engine, testset, local_device, target_dtype)
if __name__ == "__main__":
    # Parse CLI options (including DeepSpeed's config flags), then train and test.
    main(add_argument())

2. 查看训练结果

训练结束后，local_rank 为 0 的进程会在测试集上评估模型并打印如下结果（模型权重随机初始化，具体数值每次运行会有所不同）：


Finished Training
Accuracy of the network on the 10000 test images:  57 %
Accuracy of plane :  65 %
Accuracy of   car :  67 %
Accuracy of  bird :  52 %
Accuracy of   cat :  34 %
Accuracy of  deer :  52 %
Accuracy of   dog :  49 %
Accuracy of  frog :  59 %
Accuracy of horse :  66 %
Accuracy of  ship :  66 %
Accuracy of truck :  56 %