提交 06af0f75 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!773 Set precision mode and allreduce split strategy

Merge pull request !773 from gengdongjie/r0.2
...@@ -40,9 +40,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): ...@@ -40,9 +40,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
rank_id = int(os.getenv("RANK_ID")) rank_id = int(os.getenv("RANK_ID"))
if device_num == 1: if device_num == 1:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True) ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
else: else:
ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True, ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
num_shards=device_num, shard_id=rank_id) num_shards=device_num, shard_id=rank_id)
resize_height = config.image_height resize_height = config.image_height
...@@ -68,11 +68,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32): ...@@ -68,11 +68,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
type_cast_op = C2.TypeCast(mstype.int32) type_cast_op = C2.TypeCast(mstype.int32)
ds = ds.map(input_columns="label", operations=type_cast_op) ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
ds = ds.map(input_columns="image", operations=trans) ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
# apply shuffle operations
ds = ds.shuffle(buffer_size=config.buffer_size)
# apply batch operations # apply batch operations
ds = ds.batch(batch_size, drop_remainder=True) ds = ds.batch(batch_size, drop_remainder=True)
......
...@@ -36,6 +36,7 @@ ulimit -u unlimited ...@@ -36,6 +36,7 @@ ulimit -u unlimited
export DEVICE_NUM=8 export DEVICE_NUM=8
export RANK_SIZE=8 export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$1 export MINDSPORE_HCCL_CONFIG_PATH=$1
export RANK_TABLE_FILE=$1
for((i=0; i<${DEVICE_NUM}; i++)) for((i=0; i<${DEVICE_NUM}; i++))
do do
......
...@@ -61,7 +61,7 @@ if __name__ == '__main__': ...@@ -61,7 +61,7 @@ if __name__ == '__main__':
context.set_context(enable_hccl=True) context.set_context(enable_hccl=True)
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True) mirror_mean=True)
auto_parallel_context().set_all_reduce_fusion_split_indices([140]) auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
init() init()
else: else:
context.set_context(enable_hccl=False) context.set_context(enable_hccl=False)
......
...@@ -359,7 +359,12 @@ void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) con ...@@ -359,7 +359,12 @@ void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) con
} }
// Enable auto mixed precision according to the context options // Enable auto mixed precision according to the context options
(*ge_options)["ge.exec.auto_mix_precision"] = std::to_string(auto_mixed_precision_flag_); if (auto_mixed_precision_flag_) {
(*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
} else {
(*ge_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
}
// Disable the global variable acc, only enable it whlie adding training graph in pipeline // Disable the global variable acc, only enable it whlie adding training graph in pipeline
(*ge_options)["ge.exec.variable_acc"] = "0"; (*ge_options)["ge.exec.variable_acc"] = "0";
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册