Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
06af0f75
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
06af0f75
编写于
4月 28, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 28, 2020
浏览文件
操作
浏览文件
下载
差异文件
!773 Set precision mode and allreduce split strategy
Merge pull request !773 from gengdongjie/r0.2
上级
c90b66a0
e8621ce1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
12 addition
and
9 deletion
+12
-9
example/resnet50_cifar10/dataset.py
example/resnet50_cifar10/dataset.py
+4
-7
example/resnet50_cifar10/run_distribute_train.sh
example/resnet50_cifar10/run_distribute_train.sh
+1
-0
example/resnet50_cifar10/train.py
example/resnet50_cifar10/train.py
+1
-1
mindspore/ccsrc/utils/context/ms_context.cc
mindspore/ccsrc/utils/context/ms_context.cc
+6
-1
未找到文件。
example/resnet50_cifar10/dataset.py
浏览文件 @
06af0f75
...
...
@@ -40,9 +40,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
rank_id
=
int
(
os
.
getenv
(
"RANK_ID"
))
if
device_num
==
1
:
ds
=
de
.
Cifar10Dataset
(
dataset_path
,
num_parallel_workers
=
4
,
shuffle
=
True
)
ds
=
de
.
Cifar10Dataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
)
else
:
ds
=
de
.
Cifar10Dataset
(
dataset_path
,
num_parallel_workers
=
4
,
shuffle
=
True
,
ds
=
de
.
Cifar10Dataset
(
dataset_path
,
num_parallel_workers
=
8
,
shuffle
=
True
,
num_shards
=
device_num
,
shard_id
=
rank_id
)
resize_height
=
config
.
image_height
...
...
@@ -68,11 +68,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
type_cast_op
=
C2
.
TypeCast
(
mstype
.
int32
)
ds
=
ds
.
map
(
input_columns
=
"label"
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"image"
,
operations
=
trans
)
# apply shuffle operations
ds
=
ds
.
shuffle
(
buffer_size
=
config
.
buffer_size
)
ds
=
ds
.
map
(
input_columns
=
"label"
,
num_parallel_workers
=
8
,
operations
=
type_cast_op
)
ds
=
ds
.
map
(
input_columns
=
"image"
,
num_parallel_workers
=
8
,
operations
=
trans
)
# apply batch operations
ds
=
ds
.
batch
(
batch_size
,
drop_remainder
=
True
)
...
...
example/resnet50_cifar10/run_distribute_train.sh
浏览文件 @
06af0f75
...
...
@@ -36,6 +36,7 @@ ulimit -u unlimited
export
DEVICE_NUM
=
8
export
RANK_SIZE
=
8
export
MINDSPORE_HCCL_CONFIG_PATH
=
$1
export
RANK_TABLE_FILE
=
$1
for
((
i
=
0
;
i<
${
DEVICE_NUM
}
;
i++
))
do
...
...
example/resnet50_cifar10/train.py
浏览文件 @
06af0f75
...
...
@@ -61,7 +61,7 @@ if __name__ == '__main__':
context
.
set_context
(
enable_hccl
=
True
)
context
.
set_auto_parallel_context
(
device_num
=
args_opt
.
device_num
,
parallel_mode
=
ParallelMode
.
DATA_PARALLEL
,
mirror_mean
=
True
)
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
1
4
0
])
auto_parallel_context
().
set_all_reduce_fusion_split_indices
([
1
07
,
16
0
])
init
()
else
:
context
.
set_context
(
enable_hccl
=
False
)
...
...
mindspore/ccsrc/utils/context/ms_context.cc
浏览文件 @
06af0f75
...
...
@@ -359,7 +359,12 @@ void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) con
}
// Enable auto mixed precision according to the context options
(
*
ge_options
)[
"ge.exec.auto_mix_precision"
]
=
std
::
to_string
(
auto_mixed_precision_flag_
);
if
(
auto_mixed_precision_flag_
)
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"allow_mix_precision"
;
}
else
{
(
*
ge_options
)[
"ge.exec.precision_mode"
]
=
"must_keep_origin_dtype"
;
}
// Disable the global variable acc, only enable it whlie adding training graph in pipeline
(
*
ge_options
)[
"ge.exec.variable_acc"
]
=
"0"
;
#endif
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录