提交 06af0f75 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!773 Set precision mode and allreduce split strategy

Merge pull request !773 from gengdongjie/r0.2
......@@ -40,9 +40,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
rank_id = int(os.getenv("RANK_ID"))
if device_num == 1:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
else:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id)
resize_height = config.image_height
......@@ -68,11 +68,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="label", operations=type_cast_op)
ds = ds.map(input_columns="image", operations=trans)
# apply shuffle operations
ds = ds.shuffle(buffer_size=config.buffer_size)
ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
# apply batch operations
ds = ds.batch(batch_size, drop_remainder=True)
......
......@@ -36,6 +36,7 @@ ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
for((i=0; i<${DEVICE_NUM}; i++))
do
......
......@@ -61,7 +61,7 @@ if __name__ == '__main__':
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140])
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
init()
else:
context.set_context(enable_hccl=False)
......
......@@ -359,7 +359,12 @@ void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) con
}
// Enable auto mixed precision according to the context options
(*ge_options)["ge.exec.auto_mix_precision"] = std::to_string(auto_mixed_precision_flag_);
if (auto_mixed_precision_flag_) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else {
(*ge_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
}
// Disable the global variable acc, only enable it whlie adding training graph in pipeline
(*ge_options)["ge.exec.variable_acc"] = "0";
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册