From 037d843288ca5cca890f56568168556ae6bafe92 Mon Sep 17 00:00:00 2001 From: Ziyan Date: Mon, 11 May 2020 15:32:43 +0800 Subject: [PATCH] fix distributd training tutroial and code --- .../advanced_use/distributed_training.md | 25 +- .../distributed_training/resnet.py | 329 ++++++++++++++++++ .../resnet50_distributed_training.py | 6 +- .../tutorial_code/distributed_training/run.sh | 7 +- 4 files changed, 358 insertions(+), 9 deletions(-) create mode 100644 tutorials/tutorial_code/distributed_training/resnet.py diff --git a/tutorials/source_zh_cn/advanced_use/distributed_training.md b/tutorials/source_zh_cn/advanced_use/distributed_training.md index 33effbcb..3fd78a51 100644 --- a/tutorials/source_zh_cn/advanced_use/distributed_training.md +++ b/tutorials/source_zh_cn/advanced_use/distributed_training.md @@ -27,6 +27,7 @@ 当前MindSpore也提供分布式并行训练的功能。它支持了多种模式包括: - `DATA_PARALLEL`:数据并行模式。 - `AUTO_PARALLEL`:自动并行模式,融合了数据并行、模型并行及混合并行的1种分布式并行模式,可以自动建立代价模型,为用户选择1种并行模式。其中,代价模型指围绕Ascend 910芯片基于内存的计算开销和通信开销对训练时间建模,并设计高效的算法找到训练时间较短的并行策略。 +- `HYBRID_PARALLEL`: 在MindSpore中指用户手动切分参数实现层内模型并行的场景。 本篇教程我们主要讲解如何在MindSpore上通过数据并行及自动并行模式训练ResNet-50网络。 > 本例面向Ascend 910 AI处理器硬件平台,暂不支持CPU和GPU场景。 @@ -34,6 +35,14 @@ ## 准备环节 +### 下载数据集 + +本样例采用`CIFAR-10`数据集,由10类32*32的彩色图片组成,每类包含6000张图片。其中训练集共50000张图片,测试集共10000张图片。 + +> `CIFAR-10`数据集下载链接:。 + +将数据集下载并解压到本地路径下,这里将数据集解压存放到工作区的`./dataset`路径下。 + ### 配置分布式环境变量 在裸机环境(对比云上环境,即本地有Ascend 910 AI 处理器)进行分布式训练时,需要配置当前多卡环境的组网信息文件。如果使用华为云环境,因为云服务本身已经做好了配置,可以跳过本小节。 @@ -109,7 +118,7 @@ if __name__ == "__main__": ## 数据并行模式加载数据集 -分布式训练时,数据是以数据并行的方式导入的。下面我们以CIFAR-10数据集为例,介绍以数据并行方式导入CIFAR-10数据集的方法,`data_path`是指数据集的路径。 +分布式训练时,数据是以数据并行的方式导入的。下面我们以CIFAR-10数据集为例,介绍以数据并行方式导入CIFAR-10数据集的方法,`data_path`是指数据集的路径,在样例代码中采用工作区下`dataset/cifar-10-batches-bin`文件夹的路径。 ```python @@ -119,7 +128,7 @@ import mindspore.dataset.transforms.c_transforms as C import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore.communication.management import get_rank, get_group_size -def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1): +def create_dataset(data_path=data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1): resize_height = 224 resize_width = 224 rescale = 1.0 / 255.0 @@ -227,6 +236,8 @@ class SoftmaxCrossEntropyExpand(nn.Cell): > `device_num`和`global_rank`建议采用默认值,框架内会调用HCCL接口获取。 +如脚本中存在多个网络用例,请在执行下个用例前调用`context.reset_auto_parallel_context()`将所有参数还原到默认值。 + 在下面的样例中我们指定并行模式为自动并行,用户如需切换为数据并行模式,只需将`parallel_mode`改为`DATA_PARALLEL`。 ```python @@ -263,15 +274,18 @@ def test_train_cifar(num_classes=10, epoch_size=10): ```bash #!/bin/bash -export RANK_TABLE_FILE=./rank_table.json +EXEC_PATH=$(pwd) +export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json export RANK_SIZE=8 + for((i=0;i<$RANK_SIZE;i++)) do rm -rf device$i mkdir device$i - cp ./resnet50_distributed_training.py ./device$i + cp ./resnet50_distributed_training.py ./resnet.py ./device$i cd ./device$i export DEVICE_ID=$i + export RANK_ID=$i echo "start training for device $i" env > env$i.log pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 & @@ -280,8 +294,9 @@ done ``` 其中必要的环境变量有, -- `RANK_TABLE_FILE`:组网信息文件的路径。 +- `MINDSPORE_HCCL_CONFIG_PATH`:组网信息文件的路径。 - `DEVICE_ID`:当前网卡在机器上的实际序号。 +- `RANK_ID`: 当前网卡的逻辑序号。 其余环境变量请参考安装教程中的配置项。 运行时间大约在5分钟内,主要时间是用于算子的编译,实际训练时间在20秒内。用户可以通过`ps -ef | grep pytest`来监控任务进程。 diff --git a/tutorials/tutorial_code/distributed_training/resnet.py b/tutorials/tutorial_code/distributed_training/resnet.py new file mode 100644 index 00000000..d0155563 --- /dev/null +++ 
b/tutorials/tutorial_code/distributed_training/resnet.py @@ -0,0 +1,329 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +'''resnet +The sample can be run on Ascend 910 AI processor. +''' +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindspore.common import dtype as mstype +def weight_variable(shape): + """weight_variable""" + return initializer('XavierUniform', shape=shape, dtype=mstype.float32) + + +def weight_variable_uniform(shape): + """weight_variable_uniform""" + return initializer('Uniform', shape=shape, dtype=mstype.float32) + + +def weight_variable_0(shape): + """weight_variable_0""" + zeros = np.zeros(shape).astype(np.float32) + return Tensor(zeros) + + +def weight_variable_1(shape): + """weight_variable_1""" + ones = np.ones(shape).astype(np.float32) + return Tensor(ones) + + +def conv3x3(in_channels, out_channels, stride=1, padding=0): + """3x3 convolution """ + weight_shape = (out_channels, in_channels, 3, 3) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv1x1(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 1, 1) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def conv7x7(in_channels, out_channels, stride=1, padding=0): + """1x1 convolution""" + weight_shape = (out_channels, in_channels, 7, 7) + weight = weight_variable(weight_shape) + return nn.Conv2d(in_channels, out_channels, + kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + + +def bn_with_initialize(out_channels): + """bn_with_initialize""" + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_uniform(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def bn_with_initialize_last(out_channels): + """bn_with_initialize_last""" + shape = (out_channels) + mean = weight_variable_0(shape) + var = weight_variable_1(shape) + beta = weight_variable_0(shape) + gamma = weight_variable_uniform(shape) + bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma, + beta_init=beta, moving_mean_init=mean, moving_var_init=var) + return bn + + +def fc_with_initialize(input_channels, out_channels): + """fc_with_initialize""" + weight_shape = (out_channels, input_channels) + weight = weight_variable(weight_shape) + bias_shape = 
(out_channels) + bias = weight_variable_uniform(bias_shape) + return nn.Dense(input_channels, out_channels, weight, bias) + + +class ResidualBlock(nn.Cell): + """ResidualBlock""" + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1): + """init block""" + super(ResidualBlock, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.add = P.TensorAdd() + + def construct(self, x): + """construct""" + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class ResidualBlockWithDown(nn.Cell): + """ResidualBlockWithDown""" + expansion = 4 + + def __init__(self, + in_channels, + out_channels, + stride=1, + down_sample=False): + """init block with down""" + super(ResidualBlockWithDown, self).__init__() + + out_chls = out_channels // self.expansion + self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0) + self.bn1 = bn_with_initialize(out_chls) + + self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0) + self.bn2 = bn_with_initialize(out_chls) + + self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0) + self.bn3 = bn_with_initialize_last(out_channels) + + self.relu = P.ReLU() + self.down_sample = down_sample + + self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0) + self.bn_down_sample = bn_with_initialize(out_channels) + self.add = P.TensorAdd() + + def construct(self, x): + """construct""" + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + identity = self.conv_down_sample(identity) + identity = self.bn_down_sample(identity) + + out = self.add(out, identity) + out = self.relu(out) + + return out + + +class MakeLayer0(nn.Cell): + """MakeLayer0""" + + def __init__(self, block, in_channels, out_channels, stride): + """init""" + super(MakeLayer0, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True) + self.b = block(out_channels, out_channels, stride=stride) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + """construct""" + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class MakeLayer1(nn.Cell): + """MakeLayer1""" + + def __init__(self, block, in_channels, out_channels, stride): + """init""" + super(MakeLayer1, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + + def construct(self, x): + """construct""" + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + + return x + + +class MakeLayer2(nn.Cell): + """MakeLayer2""" + + def __init__(self, block, in_channels, out_channels, stride): + """init""" + super(MakeLayer2, 
self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + self.d = block(out_channels, out_channels, stride=1) + self.e = block(out_channels, out_channels, stride=1) + self.f = block(out_channels, out_channels, stride=1) + + def construct(self, x): + """construct""" + x = self.a(x) + x = self.b(x) + x = self.c(x) + x = self.d(x) + x = self.e(x) + x = self.f(x) + + return x + + +class MakeLayer3(nn.Cell): + """MakeLayer3""" + + def __init__(self, block, in_channels, out_channels, stride): + """init""" + super(MakeLayer3, self).__init__() + self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True) + self.b = block(out_channels, out_channels, stride=1) + self.c = block(out_channels, out_channels, stride=1) + + def construct(self, x): + """construct""" + x = self.a(x) + x = self.b(x) + x = self.c(x) + + return x + + +class ResNet(nn.Cell): + """ResNet""" + + def __init__(self, block, num_classes=100, batch_size=32): + """init""" + super(ResNet, self).__init__() + self.batch_size = batch_size + self.num_classes = num_classes + + self.conv1 = conv7x7(3, 64, stride=2, padding=0) + + self.bn1 = bn_with_initialize(64) + self.relu = P.ReLU() + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same") + + self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1) + self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2) + self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2) + self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2) + + self.pool = P.ReduceMean(keep_dims=True) + self.squeeze = P.Squeeze(axis=(2, 3)) + self.fc = fc_with_initialize(512 * block.expansion, num_classes) + + def construct(self, x): + """construct""" + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + + x = self.pool(x, (2, 3)) + x = self.squeeze(x) + x = self.fc(x) + return x + + +def resnet50(batch_size, num_classes): + """create resnet50""" + return ResNet(ResidualBlock, num_classes, batch_size) diff --git a/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py b/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py index fa7b3231..5d84a7af 100644 --- a/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py +++ b/tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py @@ -43,7 +43,9 @@ init() rank_id = get_rank() rank_size = get_group_size() -def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1): +EXEC_PATH=os.getcwd() + +def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1): resize_height = 224 resize_width = 224 rescale = 1.0 / 255.0 @@ -120,7 +122,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell): def test_train_cifar(num_classes=10, epoch_size=10): context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True) loss_cb = LossMonitor() - dataset = create_dataset(epoch_size) + dataset = create_dataset(os.path.join(EXEC_PATH, '../dataset/cifar-10-batches-bin/'), epoch_size) net = resnet50(32, num_classes) loss = SoftmaxCrossEntropyExpand(sparse=True) opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9) diff --git 
a/tutorials/tutorial_code/distributed_training/run.sh b/tutorials/tutorial_code/distributed_training/run.sh index 611982ef..11376ebb 100644 --- a/tutorials/tutorial_code/distributed_training/run.sh +++ b/tutorials/tutorial_code/distributed_training/run.sh @@ -1,14 +1,17 @@ #!/bin/bash -export RANK_TABLE_FILE=./rank_table.json +EXEC_PATH=$(pwd) +export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json export RANK_SIZE=8 + for((i=0;i<$RANK_SIZE;i++)) do rm -rf device$i mkdir device$i - cp ./resnet50_distributed_training.py ./device$i + cp ./resnet50_distributed_training.py ./resnet.py ./device$i cd ./device$i export DEVICE_ID=$i + export RANK_ID=$i echo "start training for device $i" env > env$i.log pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 & -- GitLab
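
Reviewer note (not part of the patch): the signature change to `create_dataset(data_path, ...)` exists so every device can point at the shared CIFAR-10 directory while still reading only its own shard. Below is a minimal sketch of that sharded input pipeline; it assumes the dataset was unpacked to `./dataset/cifar-10-batches-bin` as the tutorial describes, and it omits the resize/normalize `map` operations for brevity.

```python
# Minimal sketch of the data-parallel CIFAR-10 pipeline (illustrative only;
# the path and batch size mirror the patched resnet50_distributed_training.py).
import os

import mindspore.dataset as ds
from mindspore.communication.management import init, get_rank, get_group_size

init()                        # set up HCCL using MINDSPORE_HCCL_CONFIG_PATH,
                              # DEVICE_ID and RANK_ID exported by run.sh
rank_id = get_rank()          # logical rank of this device
rank_size = get_group_size()  # total number of devices (RANK_SIZE)

data_path = os.path.join(os.getcwd(), '../dataset/cifar-10-batches-bin/')

# num_shards/shard_id make each device read a distinct slice, so the eight
# processes launched by run.sh together cover the training set once per epoch.
data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id)
data_set = data_set.shuffle(buffer_size=10)
data_set = data_set.batch(batch_size=32, drop_remainder=True)
data_set = data_set.repeat(1)
```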
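
A second sketch covers the context setup the tutorial walks through: set the execution context, call `init()`, choose the parallel mode before building the network, and call `context.reset_auto_parallel_context()` between cases when one script runs several of them. The `ParallelMode` import path shown is an assumption for the MindSpore release this patch targets and may differ in later versions.

```python
# Sketch of the distributed context configuration (illustrative; verify the
# ParallelMode import path against your MindSpore release).
from mindspore import context
from mindspore.train.model import ParallelMode
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # reads MINDSPORE_HCCL_CONFIG_PATH / DEVICE_ID / RANK_ID set by run.sh

# AUTO_PARALLEL lets the framework's cost model pick a strategy; switch the
# argument to ParallelMode.DATA_PARALLEL for plain data parallelism.
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL,
                                  mirror_mean=True)

# ... build ResNet-50, loss and optimizer, then train as in the tutorial ...

# When several network cases run in one script, restore the defaults before
# the next case starts:
context.reset_auto_parallel_context()
```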