# Getting Started with Parallel Distributed Training

<!-- TOC -->

- [Getting Started with Parallel Distributed Training](#getting-started-with-parallel-distributed-training)
    - [Overview](#overview)
    - [Preparations](#preparations)
        - [Downloading the Dataset](#downloading-the-dataset)
        - [Configuring Distributed Environment Variables](#configuring-distributed-environment-variables)
        - [Calling the Collective Communication Library](#calling-the-collective-communication-library)
    - [Loading the Dataset in Data Parallel Mode](#loading-the-dataset-in-data-parallel-mode)
    - [Defining the Network](#defining-the-network)
    - [Defining the Loss Function and Optimizer](#defining-the-loss-function-and-optimizer)
        - [Defining the Loss Function](#defining-the-loss-function)
        - [Defining the Optimizer](#defining-the-optimizer)
    - [Training the Network](#training-the-network)
    - [Running the Script](#running-the-script)

<!-- /TOC -->

<a href="https://gitee.com/mindspore/docs/blob/master/tutorials/source_en/advanced_use/distributed_training.md" target="_blank"><img src="../_static/logo_source.png"></a>

## Overview

In deep learning, as datasets and models grow larger, training takes longer and demands more hardware resources, which becomes a training bottleneck. Parallel distributed training is an important optimization for training, because it reduces the requirements on hardware such as memory and computing performance. Based on different parallel principles and modes, parallelism is generally classified into the following types:

- Data parallelism: splits the data into multiple batches and allocates them to the workers for model computation.
- Model parallelism: splits the model. MindSpore supports intra-layer model parallelism: parameters are split and then allocated to the workers for training.
- Hybrid parallelism: combines data parallelism and model parallelism.

MindSpore also provides the parallel distributed training capability and supports the following modes (a minimal sketch of selecting a mode follows this list):
- `DATA_PARALLEL`: data parallelism.
- `AUTO_PARALLEL`: automatic parallelism, which integrates data parallelism, model parallelism, and hybrid parallelism. A cost model is built automatically to select a suitable parallel strategy for the user. Here, building a cost model means modeling the training time from the computation and communication overheads of the Ascend 910 chip, taking memory into account, and using efficient algorithms to search for a parallel strategy with a relatively short training time.
- `HYBRID_PARALLEL`: hybrid parallelism, in which users manually split parameters to implement intra-layer model parallelism on MindSpore.
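
The mode is selected through `context.set_auto_parallel_context`, which is covered in detail in [Training the Network](#training-the-network). A minimal sketch (not taken from the sample code) looks like this:

```python
from mindspore import context
from mindspore.train.model import ParallelMode

# Choose one of the supported modes before the Model is created.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL)
# Alternatives: ParallelMode.AUTO_PARALLEL, ParallelMode.HYBRID_PARALLEL
```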

This tutorial describes how to train the ResNet-50 network in data parallel and automatic parallel modes on MindSpore.
> The example in this tutorial applies to hardware platforms based on the Ascend 910 AI processor and does not support CPU or GPU scenarios.
> Download address of the complete sample code: <https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py>

## Preparations

### Downloading the Dataset

This sample uses the `CIFAR-10` dataset, which consists of color images of 32 x 32 pixels in 10 classes, with 6000 images per class. There are 50,000 images in the training set and 10,000 images in the test set.

> `CIFAR-10` dataset (binary version) download address: <https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>

Download the dataset and decompress it to a local path. The folder generated after the decompression is `cifar-10-batches-bin`.
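
If you prefer to script the download, the following is a minimal sketch using only the Python standard library; it is not part of the tutorial code, and the URL and folder name are the ones given above.

```python
import os
import tarfile
import urllib.request

URL = "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
ARCHIVE = "cifar-10-binary.tar.gz"

if not os.path.exists("cifar-10-batches-bin"):
    if not os.path.exists(ARCHIVE):
        urllib.request.urlretrieve(URL, ARCHIVE)  # download the archive
    with tarfile.open(ARCHIVE, "r:gz") as tar:
        tar.extractall(".")                       # creates ./cifar-10-batches-bin
```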

### Configuring Distributed Environment Variables

When distributed training is performed in a bare-metal environment (that is, with Ascend 910 AI processors deployed on the local host rather than accessed through a cloud service), you need to configure the networking information file for the current multi-device environment. If the HUAWEI CLOUD environment is used, skip this section because the cloud service has already been configured.

The following uses the Ascend 910 AI processor as an example. The JSON configuration file for an environment with eight devices is as follows. In this example, the configuration file is named `rank_table_8pcs.json`. For details about how to configure the 2-device environment, see the `rank_table_2pcs.json` file in the sample code.

```json
{
    "board_id": "0x0000",
    "chip_info": "910",
    "deploy_mode": "lab",
    "group_count": "1",
    "group_list": [
        {
            "device_num": "8",
            "server_num": "1",
            "group_name": "",
            "instance_count": "8",
            "instance_list": [
                {"devices": [{"device_id": "0","device_ip": "192.1.27.6"}],"rank_id": "0","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "1","device_ip": "192.2.27.6"}],"rank_id": "1","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "2","device_ip": "192.3.27.6"}],"rank_id": "2","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "3","device_ip": "192.4.27.6"}],"rank_id": "3","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "4","device_ip": "192.1.27.7"}],"rank_id": "4","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "5","device_ip": "192.2.27.7"}],"rank_id": "5","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "6","device_ip": "192.3.27.7"}],"rank_id": "6","server_id": "10.155.111.140"},
                {"devices": [{"device_id": "7","device_ip": "192.4.27.7"}],"rank_id": "7","server_id": "10.155.111.140"},
                ]
        }
    ],
    "para_plane_nic_location": "device",
    "para_plane_nic_name": ["eth0","eth1","eth2","eth3","eth4","eth5","eth6","eth7"],
    "para_plane_nic_num": "8",
    "status": "completed"
}

```
The following parameters need to be modified based on the actual training environment (a sketch for generating a 2-device file programmatically follows this list):

- `board_id`: current running environment. Set this parameter to `0x0000` for x86, and to `0x0020` for ARM.
- `server_num`: number of hosts.
- `server_id`: IP address of the local host.
- `device_num`, `para_plane_nic_num`, and `instance_count`: number of devices.
- `rank_id`: logical sequence number of a device, which starts from 0.
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on the corresponding host.
- `device_ip`: IP address of the integrated NIC. You can run the `cat /etc/hccn.conf` command on the current host; the value of the `address_x` field is the IP address of the NIC.
- `para_plane_nic_name`: name of the corresponding NIC.
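
If you would rather generate the networking file from a script than edit JSON by hand, the following sketch writes a 2-device `rank_table_2pcs.json` that mirrors the 8-device example above. The device and server IP addresses are the example values and must be replaced with the ones reported by `cat /etc/hccn.conf` on your host; the file shipped with the sample code remains the authoritative template.

```python
import json

# A sketch only: the 8-device example above, trimmed to 2 devices.
rank_table = {
    "board_id": "0x0000",
    "chip_info": "910",
    "deploy_mode": "lab",
    "group_count": "1",
    "group_list": [
        {
            "device_num": "2",
            "server_num": "1",
            "group_name": "",
            "instance_count": "2",
            "instance_list": [
                {"devices": [{"device_id": "0", "device_ip": "192.1.27.6"}],
                 "rank_id": "0", "server_id": "10.155.111.140"},
                {"devices": [{"device_id": "1", "device_ip": "192.2.27.6"}],
                 "rank_id": "1", "server_id": "10.155.111.140"},
            ],
        }
    ],
    "para_plane_nic_location": "device",
    "para_plane_nic_name": ["eth0", "eth1"],
    "para_plane_nic_num": "2",
    "status": "completed",
}

with open("rank_table_2pcs.json", "w") as f:
    json.dump(rank_table, f, indent=4)
```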


### Calling the Collective Communication Library

The Huawei Collective Communication Library (HCCL) is used for the communication of MindSpore parallel distributed training and can be found in the Ascend 910 AI processor software package. In addition, `mindspore.communication.management` encapsulates the collective communication APIs provided by the HCCL to help users configure distributed information.
> HCCL implements multi-device and multi-node communication based on the Ascend AI processor. The common restrictions on using the distributed service are as follows; for details, see the HCCL documentation.
> - In a single-node system, a cluster of 1, 2, 4, or 8 devices is supported. In a multi-node system, a cluster of 8 x N devices is supported.
> - On each host, devices 0 to 3 and devices 4 to 7 are deployed on two different networks. For training on 2 or 4 devices, the devices must be on the same network; clusters cannot be created across networks.
> - The server hardware architecture and operating system must use the symmetrical multi-processing (SMP) mode.

The sample code for calling the HCCL is as follows:

```python
import os
from mindspore import context
from mindspore.communication.management import init

if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=int(os.environ["DEVICE_ID"]))
    init()
    ...   
```

In the preceding code:  
- `mode=context.GRAPH_MODE`: sets the running mode to graph mode for distributed training. (The PyNative mode does not support parallel running.)
- `device_id`: physical sequence number of a device, that is, the actual sequence number of the device on the corresponding host.
- `init()`: enables HCCL communication and completes the distributed training initialization.

## Loading the Dataset in Data Parallel Mode

During distributed training, data is loaded in data parallel mode. The following takes the CIFAR-10 dataset as an example to describe how to load it in data parallel mode. `data_path` indicates the dataset path, which is also the path of the `cifar-10-batches-bin` folder.


```python
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.communication.management import get_rank, get_group_size

def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
    shift = 0.0
    
    # get rank_id and rank_size
    rank_id = get_rank()
    rank_size = get_group_size()
    data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id)
    
    # define map operations
    random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))
    random_horizontal_op = vision.RandomHorizontalFlip()
    resize_op = vision.Resize((resize_height, resize_width))
    rescale_op = vision.Rescale(rescale, shift)
    normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023))
    changeswap_op = vision.HWC2CHW()
    type_cast_op = C.TypeCast(mstype.int32)

    c_trans = [random_crop_op, random_horizontal_op]
    c_trans += [resize_op, rescale_op, normalize_op, changeswap_op]

    # apply map operations on images
    data_set = data_set.map(input_columns="label", operations=type_cast_op)
    data_set = data_set.map(input_columns="image", operations=c_trans)

    # apply shuffle operations
    data_set = data_set.shuffle(buffer_size=10)

    # apply batch operations
    data_set = data_set.batch(batch_size=batch_size, drop_remainder=True)

    # apply repeat operations
    data_set = data_set.repeat(repeat_num)

    return data_set
```
Different from the single-node system, the distributed scenario requires the `num_shards` and `shard_id` parameters to be passed to the dataset API. They correspond to the number of devices and the logical sequence number of the current device, respectively. You are advised to obtain them through the HCCL APIs (a quick sanity check is sketched after this list):  
- `get_rank`: obtains the ID of the current device in the cluster.
- `get_group_size`: obtains the number of devices.
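
As a quick sanity check, the sketch below (not part of the sample code, assuming `init()` has already been called as shown earlier and that the dataset sits in a local `cifar-10-batches-bin` folder) prints how many samples each process sees in its shard:

```python
import mindspore.dataset as ds
from mindspore.communication.management import get_rank, get_group_size

rank_id = get_rank()          # logical sequence number of the current device
rank_size = get_group_size()  # total number of devices in the cluster

# Each process reads only its own shard of the dataset.
data_set = ds.Cifar10Dataset("./cifar-10-batches-bin",
                             num_shards=rank_size, shard_id=rank_id)
print("rank %d of %d sees %d samples" % (rank_id, rank_size, data_set.get_dataset_size()))
```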

## Defining the Network

In data parallel and automatic parallel modes, the network definition method is the same as that in a single-node system. The reference code is as follows: <https://gitee.com/mindspore/docs/blob/master/tutorials/tutorial_code/resnet/resnet.py>

## Defining the Loss Function and Optimizer

### Defining the Loss Function

Automatic parallelism splits the model at operator granularity and obtains the optimal parallel strategy through algorithm search. Therefore, to achieve a better parallel training effect, you are advised to use small operators to implement the loss function.

In the loss function below, `SoftmaxCrossEntropyWithLogits` is expanded into multiple small operators according to its mathematical formula. The sample code is as follows:

```python
from mindspore.ops import operations as P
from mindspore import Tensor
import mindspore.ops.functional as F
import mindspore.common.dtype as mstype
import mindspore.nn as nn

class SoftmaxCrossEntropyExpand(nn.Cell):
    def __init__(self, sparse=False):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        self.exp = P.Exp()
        self.sum = P.ReduceSum(keep_dims=True)
        self.onehot = P.OneHot()
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div()
        self.log = P.Log()
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False)
        self.mul = P.Mul()
        self.mul2 = P.Mul()
        self.mean = P.ReduceMean(keep_dims=False)
        self.sparse = sparse
        self.max = P.ReduceMax(keep_dims=True)
        self.sub = P.Sub()
        
    def construct(self, logit, label):
        logit_max = self.max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        softmax_result_log = self.log(softmax_result)
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        loss = self.mean(loss, -1)

        return loss
```
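
The following is a minimal usage sketch of the loss cell defined above, assuming the Ascend graph-mode context configured in the earlier sections; the batch size and class count are illustrative:

```python
import numpy as np
from mindspore import Tensor

loss_fn = SoftmaxCrossEntropyExpand(sparse=True)
logits = Tensor(np.random.randn(32, 10).astype(np.float32))        # [batch, classes]
labels = Tensor(np.random.randint(0, 10, (32,)).astype(np.int32))  # sparse class IDs
print(loss_fn(logits, labels))                                     # scalar loss value
```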

### Defining the Optimizer

The `Momentum` optimizer is used as the parameter update tool. The definition is the same as that in the single-node system. For details, see the implementation in the sample code.
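
For reference, a minimal construction that matches the training script in the next section (the network is the ResNet-50 from the reference code above, and the hyperparameter values are the ones used there):

```python
from mindspore.nn.optim.momentum import Momentum
from resnet import resnet50

net = resnet50(32, 10)  # batch size 32, 10 classes, as in the training script
# Learning rate 0.01 and momentum 0.9, applied to all trainable parameters.
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
```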

## Training the Network

`context.set_auto_parallel_context()` is an API for users to set parallel training parameters and must be called before the initialization of `Model`. If no parameters are specified, MindSpore will automatically set parameters to the empirical values based on the parallel mode. For example, in data parallel mode, `parameter_broadcast` is enabled by default. The related parameters are as follows:

- `parallel_mode`: parallel distributed mode. The default value is `ParallelMode.STAND_ALONE`. The options are `ParallelMode.DATA_PARALLEL` and `ParallelMode.AUTO_PARALLEL`.
- `parameter_broadcast`: whether to broadcast the initialized parameters. The default value is `True` in `DATA_PARALLEL` and `HYBRID_PARALLEL` modes.
- `mirror_mean`: During backward computation, the framework collects the gradients of parameters across devices in data parallel mode, obtains the global gradient value, and passes it to the optimizer for update. The default value is `False`, which applies the `allreduce_sum` operation; the value `True` applies the `allreduce_mean` operation.

> You are advised to set `device_num` and `global_rank` to their default values. The framework calls the HCCL API to obtain the values.

If multiple network cases exist in the script, call `context.reset_auto_parallel_context()` to restore all parameters to default values before executing the next case.
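
For example, the following sketch (not part of the sample code) switches the parallel configuration between two cases in one script; the training calls themselves are elided:

```python
from mindspore import context
from mindspore.train.model import ParallelMode

# Case 1: data parallel.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
# ... build the Model and train the first case ...

# Restore the defaults before configuring the next case.
context.reset_auto_parallel_context()

# Case 2: automatic parallel.
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
# ... build the Model and train the second case ...
```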

In the following sample code, the automatic parallel mode is specified. To switch to the data parallel mode, you only need to change `parallel_mode` to `DATA_PARALLEL`.

```python
import os
from mindspore import context
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import LossMonitor
from mindspore.train.model import Model, ParallelMode
from resnet import resnet50

device_id = int(os.getenv('DEVICE_ID'))
data_path = os.getenv('DATA_PATH')  # dataset path exported by the running script in the next section
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(device_id=device_id) # set device_id

def test_train_cifar(num_classes=10, epoch_size=10):
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    loss_cb = LossMonitor()
    dataset = create_dataset(data_path, repeat_num=epoch_size)
    net = resnet50(32, num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, callbacks=[loss_cb], dataset_sink_mode=True)
```
In the preceding code:  
- `dataset_sink_mode=True`: uses the dataset sink mode, that is, the training computation is sunk to the hardware platform for execution.
- `LossMonitor`: monitors the training loss by returning the loss value through the callback function.

## Running the Script

After the script required for training is prepared, run the corresponding command to call the script.

Currently, MindSpore distributed execution uses the single-device single-process running mode. That is, one process runs on each device, and the total number of processes is the same as the number of devices in use. The process for device 0 is executed in the foreground, and the processes for the other devices are executed in the background. Each process needs its own directory to store log information and operator compilation information. The following takes the distributed training script for eight devices as an example to describe how to run the script:

```bash
#!/bin/bash

DATA_PATH=$1
export DATA_PATH=${DATA_PATH}
RANK_SIZE=$2

EXEC_PATH=$(pwd)

test_dist_8pcs()
{
    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_8pcs.json
    export RANK_SIZE=8
}

test_dist_2pcs()
{
    export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table_2pcs.json
    export RANK_SIZE=2
}

test_dist_${RANK_SIZE}pcs

for((i=1;i<${RANK_SIZE};i++))
do
    rm -rf device$i
    mkdir device$i
    cp ./resnet50_distributed_training.py ./resnet.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 &
    cd ../
done
rm -rf device0
mkdir device0
cp ./resnet50_distributed_training.py ./resnet.py ./device0
cd ./device0
export DEVICE_ID=0
export RANK_ID=0
echo "start training for device 0"
env > env0.log
pytest -s -v ./resnet50_distributed_training.py > train.log0 2>&1
if [ $? -eq 0 ];then
    echo "training success"
else
    echo "training failed"
    exit 2
fi
cd ../
```

The variables `DATA_PATH` and `RANK_SIZE` need to be passed to the script; they indicate the dataset path and the number of devices, respectively.

The necessary environment variables are as follows:  
- `MINDSPORE_HCCL_CONFIG_PATH`: path for storing the networking information file.
- `DEVICE_ID`: actual sequence number of the current device on the corresponding host.
- `RANK_ID`: logical sequence number of the current device.

For details about other environment variables, see the configuration items in the installation guide.

The running time is about 5 minutes, which is mainly occupied by operator compilation. The actual training time is within 20 seconds. You can use `ps -ef | grep pytest` to monitor task processes.

Log files are saved in the corresponding `device` directories. The `env.log` file records environment variable information, and the `train.log` file records the loss values. The following is an example:

```
resnet50_distributed_training.py::test_train_feed ===============ds_num 195
global_step: 194, loss: 1.997
global_step: 389, loss: 1.655
global_step: 584, loss: 1.723
global_step: 779, loss: 1.807
global_step: 974, loss: 1.417
global_step: 1169, loss: 1.195
global_step: 1364, loss: 1.238
global_step: 1559, loss: 1.456
global_step: 1754, loss: 0.987
global_step: 1949, loss: 1.035
end training
PASSED
```
```