Commit 037d8432

Authored May 11, 2020 by Ziyan

fix distributed training tutorial and code

Parent: 601b3e29

Showing 4 changed files with 358 additions and 9 deletions (+358 -9)
- tutorials/source_zh_cn/advanced_use/distributed_training.md (+20 -5)
- tutorials/tutorial_code/distributed_training/resnet.py (+329 -0)
- tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py (+4 -2)
- tutorials/tutorial_code/distributed_training/run.sh (+5 -2)
tutorials/source_zh_cn/advanced_use/distributed_training.md @ 037d8432
...
@@ -27,6 +27,7 @@

MindSpore also provides distributed parallel training. It supports several modes, including:

- `DATA_PARALLEL`: data parallel mode.
- `AUTO_PARALLEL`: auto-parallel mode, a distributed parallel mode that fuses data parallelism, model parallelism, and hybrid parallelism. It can automatically build a cost model and select a parallel strategy for the user. Here, the cost model models training time in terms of the memory-based computation overhead and the communication overhead of the Ascend 910 chip, and an efficient algorithm searches for a parallel strategy with a shorter training time.
- `HYBRID_PARALLEL`: in MindSpore, the scenario in which the user manually partitions parameters to implement intra-layer model parallelism.

This tutorial explains how to train the ResNet-50 network on MindSpore in data parallel and auto-parallel modes.

> This example targets the Ascend 910 AI processor hardware platform; CPU and GPU scenarios are not supported yet.
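To make the mode switch concrete, here is a minimal sketch of configuring the parallel context before training. The `ParallelMode` import path follows the MindSpore 0.x API assumed by this tutorial and may differ in later versions; treat the snippet as illustrative rather than as the tutorial's exact code.

```python
from mindspore import context
from mindspore.train.model import ParallelMode          # import path assumed for MindSpore 0.x
from mindspore.communication.management import init

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
init()  # initialize HCCL communication for the current device

# Auto-parallel, as used later in this tutorial; switch to
# ParallelMode.DATA_PARALLEL for pure data parallelism.
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
```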
...
@@ -34,6 +35,14 @@

## Preparation

### Downloading the Dataset

This example uses the `CIFAR-10` dataset, which consists of 32*32 color images in 10 classes, with 6,000 images per class: 50,000 images in the training set and 10,000 in the test set.

> `CIFAR-10` dataset download link: <http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz>.

Download the dataset and extract it to a local path; here the dataset is extracted to the `./dataset` directory of the workspace, as in the sketch below.
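A minimal sketch of fetching and extracting the dataset into that directory (the URL and the `./dataset` path come from the tutorial; this helper script itself is only illustrative):

```python
# Download the binary CIFAR-10 release and unpack it into ./dataset,
# producing ./dataset/cifar-10-batches-bin as used by the sample code.
import os
import tarfile
import urllib.request

url = "http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
os.makedirs("./dataset", exist_ok=True)
archive = "./dataset/cifar-10-binary.tar.gz"
urllib.request.urlretrieve(url, archive)
with tarfile.open(archive, "r:gz") as tar:
    tar.extractall("./dataset")
```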
### Configuring Distributed Environment Variables

When running distributed training on bare metal (that is, with local Ascend 910 AI processors, as opposed to a cloud environment), you need to configure the networking information file for the current multi-device setup. If you use the Huawei Cloud environment, you can skip this section, because the cloud service has already completed this configuration.
...
@@ -109,7 +118,7 @@ if __name__ == "__main__":

## Loading the Dataset in Data Parallel Mode

- During distributed training, data is loaded in data parallel mode. Taking CIFAR-10 as an example, the following describes how to import the CIFAR-10 dataset in data parallel mode, where `data_path` is the dataset path.
+ During distributed training, data is loaded in data parallel mode. Taking CIFAR-10 as an example, the following describes how to import the CIFAR-10 dataset in data parallel mode, where `data_path` is the dataset path; in the sample code this is the path of the `dataset/cifar-10-batches-bin` folder under the workspace.
```python
...
```

@@ -119,7 +128,7 @@ import mindspore.dataset.transforms.c_transforms as C

```python
import mindspore.dataset.transforms.vision.c_transforms as vision
from mindspore.communication.management import get_rank, get_group_size

-def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
+def create_dataset(data_path=data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
...
```
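The body of `create_dataset` (collapsed above) is where the per-device sharding happens. The following is a hedged sketch of that idea using the `num_shards`/`shard_id` parameters of `mindspore.dataset.Cifar10Dataset`; the full implementation lives in the tutorial code, and the helper name below is only for illustration.

```python
import mindspore.dataset as ds

def create_dataset_sketch(data_path, rank_id=0, rank_size=1, batch_size=32):
    # Each device reads only its own 1/rank_size shard of CIFAR-10;
    # this per-device sharding is what makes the loading data parallel.
    data_set = ds.Cifar10Dataset(data_path, num_shards=rank_size, shard_id=rank_id)
    data_set = data_set.batch(batch_size, drop_remainder=True)
    return data_set
```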
@@ -227,6 +236,8 @@ class SoftmaxCrossEntropyExpand(nn.Cell):

> `device_num` and `global_rank` should be left at their default values; the framework calls the HCCL interface internally to obtain them.

If a script contains multiple network use cases, call `context.reset_auto_parallel_context()` to restore all parameters to their defaults before running the next use case (see the sketch below).

In the following example we set the parallel mode to auto-parallel. To switch to data parallel mode, simply change `parallel_mode` to `DATA_PARALLEL`.
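As a hedged illustration of that reset between two use cases in one script (the two `run_*_case` routines are hypothetical placeholders, not part of the tutorial code):

```python
# First use case: auto-parallel.
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
run_first_case()   # hypothetical training routine

# Restore all auto-parallel settings to their defaults before the next use case.
context.reset_auto_parallel_context()

# Second use case: data parallel.
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True)
run_second_case()  # hypothetical training routine
```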
```python
...
```

@@ -263,15 +274,18 @@ def test_train_cifar(num_classes=10, epoch_size=10):

```bash
#!/bin/bash

export RANK_TABLE_FILE=./rank_table.json
+EXEC_PATH=$(pwd)
+export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
export RANK_SIZE=8

for((i=0;i<$RANK_SIZE;i++))
do
    rm -rf device$i
    mkdir device$i
-    cp ./resnet50_distributed_training.py ./device$i
+    cp ./resnet50_distributed_training.py ./resnet.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 &
...
```

@@ -280,8 +294,9 @@ done
The required environment variables are:

- `RANK_TABLE_FILE`: path of the networking information file.
- `MINDSPORE_HCCL_CONFIG_PATH`: path of the networking information file.
- `DEVICE_ID`: the physical sequence number of the current device on the machine.
- `RANK_ID`: the logical sequence number of the current device.

For the remaining environment variables, see the configuration items in the installation tutorial.

The total run time is about 5 minutes, most of which is spent compiling operators; the actual training time is within 20 seconds. You can monitor the task processes with `ps -ef | grep pytest`.

...
tutorials/tutorial_code/distributed_training/resnet.py (new file, mode 100644) @ 037d8432
```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
'''resnet
The sample can be run on Ascend 910 AI processor.
'''
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.initializer import initializer
from mindspore.common import dtype as mstype


def weight_variable(shape):
    """weight_variable"""
    return initializer('XavierUniform', shape=shape, dtype=mstype.float32)


def weight_variable_uniform(shape):
    """weight_variable_uniform"""
    return initializer('Uniform', shape=shape, dtype=mstype.float32)


def weight_variable_0(shape):
    """weight_variable_0"""
    zeros = np.zeros(shape).astype(np.float32)
    return Tensor(zeros)


def weight_variable_1(shape):
    """weight_variable_1"""
    ones = np.ones(shape).astype(np.float32)
    return Tensor(ones)


def conv3x3(in_channels, out_channels, stride=1, padding=0):
    """3x3 convolution"""
    weight_shape = (out_channels, in_channels, 3, 3)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init=weight,
                     has_bias=False, pad_mode="same")


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution"""
    weight_shape = (out_channels, in_channels, 1, 1)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init=weight,
                     has_bias=False, pad_mode="same")


def conv7x7(in_channels, out_channels, stride=1, padding=0):
    """7x7 convolution"""
    weight_shape = (out_channels, in_channels, 7, 7)
    weight = weight_variable(weight_shape)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=7, stride=stride, padding=padding, weight_init=weight,
                     has_bias=False, pad_mode="same")


def bn_with_initialize(out_channels):
    """bn_with_initialize"""
    shape = (out_channels)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_uniform(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def bn_with_initialize_last(out_channels):
    """bn_with_initialize_last"""
    shape = (out_channels)
    mean = weight_variable_0(shape)
    var = weight_variable_1(shape)
    beta = weight_variable_0(shape)
    gamma = weight_variable_uniform(shape)
    bn = nn.BatchNorm2d(out_channels, momentum=0.99, eps=0.00001, gamma_init=gamma,
                        beta_init=beta, moving_mean_init=mean, moving_var_init=var)
    return bn


def fc_with_initialize(input_channels, out_channels):
    """fc_with_initialize"""
    weight_shape = (out_channels, input_channels)
    weight = weight_variable(weight_shape)
    bias_shape = (out_channels)
    bias = weight_variable_uniform(bias_shape)
    return nn.Dense(input_channels, out_channels, weight, bias)


class ResidualBlock(nn.Cell):
    """ResidualBlock"""
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1):
        """init block"""
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.add = P.TensorAdd()

    def construct(self, x):
        """construct"""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class ResidualBlockWithDown(nn.Cell):
    """ResidualBlockWithDown"""
    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, down_sample=False):
        """init block with down"""
        super(ResidualBlockWithDown, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=0)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize_last(out_channels)

        self.relu = P.ReLU()
        self.down_sample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels, stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.TensorAdd()

    def construct(self, x):
        """construct"""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        identity = self.conv_down_sample(identity)
        identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out


class MakeLayer0(nn.Cell):
    """MakeLayer0"""

    def __init__(self, block, in_channels, out_channels, stride):
        """init"""
        super(MakeLayer0, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=1, down_sample=True)
        self.b = block(out_channels, out_channels, stride=stride)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        """construct"""
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class MakeLayer1(nn.Cell):
    """MakeLayer1"""

    def __init__(self, block, in_channels, out_channels, stride):
        """init"""
        super(MakeLayer1, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        """construct"""
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)

        return x


class MakeLayer2(nn.Cell):
    """MakeLayer2"""

    def __init__(self, block, in_channels, out_channels, stride):
        """init"""
        super(MakeLayer2, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)
        self.d = block(out_channels, out_channels, stride=1)
        self.e = block(out_channels, out_channels, stride=1)
        self.f = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        """construct"""
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)
        x = self.d(x)
        x = self.e(x)
        x = self.f(x)

        return x


class MakeLayer3(nn.Cell):
    """MakeLayer3"""

    def __init__(self, block, in_channels, out_channels, stride):
        """init"""
        super(MakeLayer3, self).__init__()
        self.a = ResidualBlockWithDown(in_channels, out_channels, stride=stride, down_sample=True)
        self.b = block(out_channels, out_channels, stride=1)
        self.c = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        """construct"""
        x = self.a(x)
        x = self.b(x)
        x = self.c(x)

        return x


class ResNet(nn.Cell):
    """ResNet"""

    def __init__(self, block, num_classes=100, batch_size=32):
        """init"""
        super(ResNet, self).__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes

        self.conv1 = conv7x7(3, 64, stride=2, padding=0)

        self.bn1 = bn_with_initialize(64)
        self.relu = P.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")

        self.layer1 = MakeLayer0(block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer1(block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer2(block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(block, in_channels=1024, out_channels=2048, stride=2)

        self.pool = P.ReduceMean(keep_dims=True)
        self.squeeze = P.Squeeze(axis=(2, 3))
        self.fc = fc_with_initialize(512 * block.expansion, num_classes)

    def construct(self, x):
        """construct"""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.pool(x, (2, 3))
        x = self.squeeze(x)
        x = self.fc(x)
        return x


def resnet50(batch_size, num_classes):
    """create resnet50"""
    return ResNet(ResidualBlock, num_classes, batch_size)
```
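For reference, a minimal sketch of how the training script in this commit consumes the module above (the call mirrors `resnet50(32, num_classes)` in the diff below; the keyword form is used here only for readability):

```python
from resnet import resnet50

# ResNet-50 configured for CIFAR-10: batch size 32 per device, 10 classes.
net = resnet50(batch_size=32, num_classes=10)
```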
tutorials/tutorial_code/distributed_training/resnet50_distributed_training.py @ 037d8432
...
@@ -43,7 +43,9 @@ init()

```python
rank_id = get_rank()
rank_size = get_group_size()

-def create_dataset(repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
+EXEC_PATH = os.getcwd()
+
+def create_dataset(data_path, repeat_num=1, batch_size=32, rank_id=0, rank_size=1):
    resize_height = 224
    resize_width = 224
    rescale = 1.0 / 255.0
...
```
@@ -120,7 +122,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell):

```python
def test_train_cifar(num_classes=10, epoch_size=10):
    context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
    loss_cb = LossMonitor()
-    dataset = create_dataset(epoch_size)
+    dataset = create_dataset(os.path.join(EXEC_PATH, '../dataset/cifar-10-batches-bin/'), epoch_size)
    net = resnet50(32, num_classes)
    loss = SoftmaxCrossEntropyExpand(sparse=True)
    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, 0.9)
...
```
tutorials/tutorial_code/distributed_training/run.sh @ 037d8432
```bash
#!/bin/bash

export RANK_TABLE_FILE=./rank_table.json
+EXEC_PATH=$(pwd)
+export MINDSPORE_HCCL_CONFIG_PATH=${EXEC_PATH}/rank_table.json
export RANK_SIZE=8

for((i=0;i<$RANK_SIZE;i++))
do
    rm -rf device$i
    mkdir device$i
-    cp ./resnet50_distributed_training.py ./device$i
+    cp ./resnet50_distributed_training.py ./resnet.py ./device$i
    cd ./device$i
    export DEVICE_ID=$i
    export RANK_ID=$i
    echo "start training for device $i"
    env > env$i.log
    pytest -s -v ./resnet50_distributed_training.py > train.log$i 2>&1 &
...
```