Commit 037d8432 authored by: Z Ziyan

fix distributed training tutorial and code

Parent 601b3e29
......@@ -27,6 +27,7 @@
MindSpore also provides distributed parallel training. It supports multiple modes, including:
- `DATA_PARALLEL`: data parallel mode.
- `AUTO_PARALLEL`: automatic parallel mode, a distributed parallel mode that combines data parallelism, model parallelism, and hybrid parallelism. It automatically builds a cost model and selects a parallel strategy for the user. Here, the cost model models training time from the memory-based computation cost and communication cost of the Ascend 910 chip, and an efficient algorithm is used to find a parallel strategy with a shorter training time.
- `HYBRID_PARALLEL`: in MindSpore, this refers to scenarios where the user manually partitions parameters to realize intra-layer model parallelism.
This tutorial mainly explains how to train a ResNet-50 network on MindSpore in data parallel and automatic parallel modes.
> This example targets the Ascend 910 AI processor hardware platform; CPU and GPU are not supported yet.
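As a quick orientation before the detailed walkthrough, the sketch below shows how one of these modes is selected through `context.set_auto_parallel_context`, the call used later in this tutorial. The import path of `ParallelMode` is an assumption and may differ across MindSpore versions.

```python
from mindspore import context
from mindspore.train.model import ParallelMode  # assumed import path; may vary by version

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

# Choose one of the modes listed above; mirror_mean averages gradients across devices.
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
```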
......@@ -34,6 +35,14 @@
## Preparation
### Downloading the Dataset
This example uses the `CIFAR-10` dataset, which consists of 32*32 color images in 10 classes, with 6000 images per class. The training set contains 50000 images and the test set contains 10000 images.
> `CIFAR-10` dataset download link: <http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>.
Download the dataset and extract it to a local path; in this tutorial the dataset is extracted to the `./dataset` directory under the workspace.
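If you prefer to script this step, the following is a minimal sketch using only the Python standard library; the URL is the download link above, and the archive extracts into `dataset/cifar-10-batches-bin`.

```python
import os
import tarfile
import urllib.request

URL = "http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
ARCHIVE = "./dataset/cifar-10-binary.tar.gz"

os.makedirs("./dataset", exist_ok=True)
if not os.path.exists(ARCHIVE):
    urllib.request.urlretrieve(URL, ARCHIVE)      # download the binary CIFAR-10 archive
with tarfile.open(ARCHIVE, "r:gz") as tar:
    tar.extractall("./dataset")                   # creates ./dataset/cifar-10-batches-bin
```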
### Configuring Distributed Environment Variables
When performing distributed training in a bare-metal environment (that is, with Ascend 910 AI processors available locally, as opposed to a cloud environment), the networking information file for the current multi-device setup needs to be configured. If you use the HUAWEI CLOUD environment, you can skip this section because the cloud service has already completed the configuration.
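Once the networking information file has been prepared and exported through the environment variables used by the run script later in this tutorial, a quick sanity check can confirm that each process sees its expected rank. The sketch below is an illustration only; it relies on the `init`, `get_rank`, and `get_group_size` interfaces that also appear in the sample code.

```python
import os
from mindspore import context
from mindspore.communication.management import init, get_rank, get_group_size

# DEVICE_ID is exported per process by the run script shown later in this tutorial.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
                    device_id=int(os.environ["DEVICE_ID"]))
init()  # initializes HCCL from the configured networking information file

print("rank {} of {}".format(get_rank(), get_group_size()))
```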
......@@ -109,7 +118,7 @@ if __name__ == "__main__":
## Loading the Dataset in Data Parallel Mode
During distributed training, data is loaded in data parallel mode. Taking CIFAR-10 as an example, the following shows how to load the CIFAR-10 dataset in data parallel mode; `data_path` refers to the dataset path.
During distributed training, data is loaded in data parallel mode. Taking CIFAR-10 as an example, the following shows how to load the CIFAR-10 dataset in data parallel mode; `data_path` refers to the dataset path, which in the sample code is the path of the `dataset/cifar-10-batches-bin` folder under the workspace.
```python
......@@ -119,7 +128,7 @@ import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.communication.management import get_rank, get_group_size
def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
def create_dataset(data_path=data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
resize_height = 224
resize_width = 224
rescale = 1.0 / 255.0
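# The rest of this function is elided in the diff above. As a hedged sketch
# (assuming `mindspore.dataset` is imported as `ds`; the `num_shards`/`shard_id`
# parameters are an assumption to verify against your MindSpore version), the key
# step for data parallelism is to shard the source dataset by rank:
#
#     data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id)
#
# followed by the usual resize/rescale/normalize map operations, shuffle,
# batch(batch_size, drop_remainder=True), and repeat(repeat_num).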
......@@ -227,6 +236,8 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
> It is recommended to keep the default values of `device_num` and `global_rank`; the framework calls the HCCL interface internally to obtain them.
If a script contains multiple network cases, call `context.reset_auto_parallel_context()` to restore all parameters to their defaults before executing the next case, as in the sketch below.
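For instance, a test script that runs a data parallel case followed by an automatic parallel case could reset the configuration in between; a minimal sketch (the case functions are hypothetical placeholders):

```python
from mindspore import context
from mindspore.train.model import ParallelMode  # assumed import path; may vary by version

def case_data_parallel():
    context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
    # ... build the network and train ...

def case_auto_parallel():
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    # ... build the network and train ...

case_data_parallel()
context.reset_auto_parallel_context()  # restore all parallel settings to their defaults
case_auto_parallel()
```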
In the example below, the parallel mode is set to automatic parallel. To switch to data parallel mode, simply change `parallel_mode` to `DATA_PARALLEL`.
```python
......@@ -263,15 +274,18 @@ def test_train_cifar(num_classes=10, epoch_size=10):
```bash
#!/bin/bash
export RANK_TABLE_FILE=./rank_table.json
EXEC_PATH=$(pwd)
export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
export RANK_SIZE=8
for((i=0;i<$RANK_SIZE;i++))
do
rm -rf device$i
mkdir device$i
cp ./resnet50_distributed_training.py ./device$i
cp ./resnet50_distributed_training.py ./resnet.py ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 &
......@@ -280,8 +294,9 @@ done
```
The required environment variables are:
- `RANK_TABLE_FILE`: path of the networking information file.
- `MINDSPORE_HCCL_CONFIG_PATH`: path of the networking information file.
- `DEVICE_ID`: the physical index of the current card on the machine.
- `RANK_ID`: the logical index of the current card.
For the other environment variables, refer to the configuration items in the installation tutorial.
The run takes about 5 minutes, most of which is spent compiling operators; the actual training time is within 20 seconds. You can monitor the task processes with `ps -ef | grep pytest`.
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''resnet
The sample can be run on Ascend 910 AI processor.
'''
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype
def weight_variable(shape):
"""weight_variable"""
return initializer('XavierUniform', shape=shape, dtype=mstype.float32)
def weight_variable_uniform(shape):
"""weight_variable_uniform"""
return initializer('Uniform', shape=shape, dtype=mstype.float32)
def weight_variable_0(shape):
"""weight_variable_0"""
zeros = np.zeros(shape).astype(np.float32)
return Tensor(zeros)
def weight_variable_1(shape):
"""weight_variable_1"""
ones = np.ones(shape).astype(np.float32)
return Tensor(ones)
def conv3x3(in_channels, out_channels, stride=1, padding=0):
"""3x3 convolution """
weight_shape = (out_channels, in_channels, 3, 3)
weight = weight_variable(weight_shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
def conv1x1(in_channels, out_channels, stride=1, padding=0):
"""1x1 convolution"""
weight_shape = (out_channels, in_channels, 1, 1)
weight = weight_variable(weight_shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
def conv7x7(in_channels, out_channels, stride=1, padding=0):
"""1x1 convolution"""
weight_shape = (out_channels, in_channels, 7, 7)
weight = weight_variable(weight_shape)
return nn.Conv2d(in_channels, out_channels,
kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same")
def bn_with_initialize(out_channels):
"""bn_with_initialize"""
shape = (out_channels)
mean = weight_variable_0(shape)
var = weight_variable_1(shape)
beta = weight_variable_0(shape)
gamma = weight_variable_uniform(shape)
bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
beta_init=beta, moving_mean_init=mean, moving_var_init=var)
return bn
def bn_with_initialize_last(out_channels):
"""bn_with_initialize_last"""
shape = (out_channels)
mean = weight_variable_0(shape)
var = weight_variable_1(shape)
beta = weight_variable_0(shape)
gamma = weight_variable_uniform(shape)
bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
beta_init=beta, moving_mean_init=mean, moving_var_init=var)
return bn
def fc_with_initialize(input_channels, out_channels):
"""fc_with_initialize"""
weight_shape = (out_channels, input_channels)
weight = weight_variable(weight_shape)
bias_shape = (out_channels)
bias = weight_variable_uniform(bias_shape)
return nn.Dense(input_channels, out_channels, weight, bias)
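# Bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand, with an identity shortcut.
# `expansion` is the ratio of the block's output channels to its inner (reduced) width.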
class ResidualBlock(nn.Cell):
"""ResidualBlock"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1):
"""init block"""
super(ResidualBlock, self).__init__()
out_chls = out_channels // self.expansion
self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
self.bn1 = bn_with_initialize(out_chls)
self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
self.bn2 = bn_with_initialize(out_chls)
self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
self.bn3 = bn_with_initialize_last(out_channels)
self.relu = P.ReLU()
self.add = P.TensorAdd()
def construct(self, x):
"""construct"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out = self.add(out, identity)
out = self.relu(out)
return out
class ResidualBlockWithDown(nn.Cell):
"""ResidualBlockWithDown"""
expansion = 4
def __init__(self,
in_channels,
out_channels,
stride=1,
down_sample=False):
"""init block with down"""
super(ResidualBlockWithDown, self).__init__()
out_chls = out_channels // self.expansion
self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
self.bn1 = bn_with_initialize(out_chls)
self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
self.bn2 = bn_with_initialize(out_chls)
self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
self.bn3 = bn_with_initialize_last(out_channels)
self.relu = P.ReLU()
self.down_sample = down_sample
self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
self.bn_down_sample = bn_with_initialize(out_channels)
self.add = P.TensorAdd()
def construct(self, x):
"""construct"""
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
identity = self.conv_down_sample(identity)
identity = self.bn_down_sample(identity)
out = self.add(out, identity)
out = self.relu(out)
return out
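# MakeLayer0..MakeLayer3 build the four ResNet-50 stages with 3, 4, 6 and 3
# bottleneck blocks respectively; the first block of each stage uses a projection shortcut.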
class MakeLayer0(nn.Cell):
"""MakeLayer0"""
def __init__(self, block, in_channels, out_channels, stride):
"""init"""
super(MakeLayer0, self).__init__()
self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
self.b = block(out_channels, out_channels, stride=stride)
self.c = block(out_channels, out_channels, stride=1)
def construct(self, x):
"""construct"""
x = self.a(x)
x = self.b(x)
x = self.c(x)
return x
class MakeLayer1(nn.Cell):
"""MakeLayer1"""
def __init__(self, block, in_channels, out_channels, stride):
"""init"""
super(MakeLayer1, self).__init__()
self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
self.b = block(out_channels, out_channels, stride=1)
self.c = block(out_channels, out_channels, stride=1)
self.d = block(out_channels, out_channels, stride=1)
def construct(self, x):
"""construct"""
x = self.a(x)
x = self.b(x)
x = self.c(x)
x = self.d(x)
return x
class MakeLayer2(nn.Cell):
"""MakeLayer2"""
def __init__(self, block, in_channels, out_channels, stride):
"""init"""
super(MakeLayer2, self).__init__()
self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
self.b = block(out_channels, out_channels, stride=1)
self.c = block(out_channels, out_channels, stride=1)
self.d = block(out_channels, out_channels, stride=1)
self.e = block(out_channels, out_channels, stride=1)
self.f = block(out_channels, out_channels, stride=1)
def construct(self, x):
"""construct"""
x = self.a(x)
x = self.b(x)
x = self.c(x)
x = self.d(x)
x = self.e(x)
x = self.f(x)
return x
class MakeLayer3(nn.Cell):
"""MakeLayer3"""
def __init__(self, block, in_channels, out_channels, stride):
"""init"""
super(MakeLayer3, self).__init__()
self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
self.b = block(out_channels, out_channels, stride=1)
self.c = block(out_channels, out_channels, stride=1)
def construct(self, x):
"""construct"""
x = self.a(x)
x = self.b(x)
x = self.c(x)
return x
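# Full network: 7x7 stem convolution + max-pool, the four stages above,
# global average pooling, and a fully connected classifier.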
class ResNet(nn.Cell):
"""ResNet"""
def __init__(self, block, num_classes=100, batch_size=32):
"""init"""
super(ResNet, self).__init__()
self.batch_size = batch_size
self.num_classes = num_classes
self.conv1 = conv7x7(3, 64, stride=2, padding=0)
self.bn1 = bn_with_initialize(64)
self.relu = P.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)
self.pool = P.ReduceMean(keep_dims=True)
self.squeeze = P.Squeeze(axis=(2, 3))
self.fc = fc_with_initialize(512 * block.expansion, num_classes)
def construct(self, x):
"""construct"""
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.pool(x, (2, 3))
x = self.squeeze(x)
x = self.fc(x)
return x
def resnet50(batch_size, num_classes):
"""create resnet50"""
return ResNet(ResidualBlock, num_classes, batch_size)
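# Usage sketch (illustration only, not part of the training pipeline): build the
# network and run a forward pass on dummy data; the shapes below are assumptions.
if __name__ == '__main__':
    net = resnet50(batch_size=32, num_classes=10)
    dummy = Tensor(np.ones((32, 3, 224, 224)).astype(np.float32))
    print(net(dummy).shape)  # expected: (32, 10)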
......@@ -43,7 +43,9 @@ init()
rank_id = get_rank()
rank_size = get_group_size()
def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
EXEC_PATH=os.getcwd()
def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
resize_height = 224
resize_width = 224
rescale = 1.0 / 255.0
......@@ -120,7 +122,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell):
def test_train_cifar(num_classes=10, epoch_size=10):
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
loss_cb = LossMonitor()
dataset = create_dataset(epoch_size)
dataset = create_dataset(os.path.join(EXEC_PATH, '../dataset/cifar-10-batches-bin/'), epoch_size)
net = resnet50(32, num_classes)
loss = SoftmaxCrossEntropyExpand(sparse=True)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
......
#!/bin/bash
export RANK_TABLE_FILE=./rank_table.json
EXEC_PATH=$(pwd)
export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
export RANK_SIZE=8
for((i=0;i<$RANK_SIZE;i++))
do
rm -rf device$i
mkdir device$i
cp ./resnet50_distributed_training.py ./device$i
cp ./resnet50_distributed_training.py ./resnet.py ./device$i
cd ./device$i
export DEVICE_ID=$i
export RANK_ID=$i
echo "start training for device $i"
env > env$i.log
pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 &
......