diff --git a/example/resnet50_cifar10/dataset.py b/example/resnet50_cifar10/dataset.py
index 9ed16f08b5555d2f6d73c28388a28282fb1dca8e..1889da95b65ff79277fb2cdd5c988933520bd4e0 100755
--- a/example/resnet50_cifar10/dataset.py
+++ b/example/resnet50_cifar10/dataset.py
@@ -40,9 +40,9 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
     rank_id = int(os.getenv("RANK_ID"))
 
     if device_num == 1:
-        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True)
+        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
     else:
-        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=4, shuffle=True,
+        ds = de.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
                                num_shards=device_num, shard_id=rank_id)
 
     resize_height = config.image_height
@@ -68,11 +68,8 @@ def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32):
 
     type_cast_op = C2.TypeCast(mstype.int32)
 
-    ds = ds.map(input_columns="label", operations=type_cast_op)
-    ds = ds.map(input_columns="image", operations=trans)
-
-    # apply shuffle operations
-    ds = ds.shuffle(buffer_size=config.buffer_size)
+    ds = ds.map(input_columns="label", num_parallel_workers=8, operations=type_cast_op)
+    ds = ds.map(input_columns="image", num_parallel_workers=8, operations=trans)
 
     # apply batch operations
     ds = ds.batch(batch_size, drop_remainder=True)
diff --git a/example/resnet50_cifar10/run_distribute_train.sh b/example/resnet50_cifar10/run_distribute_train.sh
index e78e2bf104598249383fd6e9abb8d0e28a4bd713..5165f58cab237c4733da84f60b27980e0ca67c8c 100755
--- a/example/resnet50_cifar10/run_distribute_train.sh
+++ b/example/resnet50_cifar10/run_distribute_train.sh
@@ -36,6 +36,7 @@ ulimit -u unlimited
 export DEVICE_NUM=8
 export RANK_SIZE=8
 export MINDSPORE_HCCL_CONFIG_PATH=$1
+export RANK_TABLE_FILE=$1
 
 for((i=0; i<${DEVICE_NUM}; i++))
 do
diff --git a/example/resnet50_cifar10/train.py b/example/resnet50_cifar10/train.py
index b18c3778de013e77bac8a0f0386d00b1930ee5d5..0a0299b2bdd7d10d6de27cc3bee07de8c81a0194 100755
--- a/example/resnet50_cifar10/train.py
+++ b/example/resnet50_cifar10/train.py
@@ -61,7 +61,7 @@ if __name__ == '__main__':
         context.set_context(enable_hccl=True)
         context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
-        auto_parallel_context().set_all_reduce_fusion_split_indices([140])
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
         init()
     else:
         context.set_context(enable_hccl=False)
diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc
index 46c28dec888da86a335cb33a637473b5f541238b..9cd6d8e05b3bc58726906ab72ae1ce6a60d14c53 100644
--- a/mindspore/ccsrc/utils/context/ms_context.cc
+++ b/mindspore/ccsrc/utils/context/ms_context.cc
@@ -359,7 +359,12 @@ void MsContext::GetGeOptions(std::map<std::string, std::string> *ge_options) const {
   }
 
   // Enable auto mixed precision according to the context options
-  (*ge_options)["ge.exec.auto_mix_precision"] = std::to_string(auto_mixed_precision_flag_);
+  if (auto_mixed_precision_flag_) {
+    (*ge_options)["ge.exec.precision_mode"] = "allow_mix_precision";
+  } else {
+    (*ge_options)["ge.exec.precision_mode"] = "must_keep_origin_dtype";
+  }
+
   // Disable the global variable acc, only enable it whlie adding training graph in pipeline
   (*ge_options)["ge.exec.variable_acc"] = "0";
 #endif
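A minimal sketch (not part of the patch) of how the reworked create_dataset is typically driven on one of the eight devices. The dataset path and the environment defaults below are illustrative assumptions; DEVICE_NUM and RANK_ID are normally exported per device by run_distribute_train.sh rather than set in Python.

import os

from dataset import create_dataset  # example/resnet50_cifar10/dataset.py

# Normally exported per device by run_distribute_train.sh; set here only so
# the sketch is self-contained (hypothetical single-shard values).
os.environ.setdefault("DEVICE_NUM", "8")  # total number of shards/devices
os.environ.setdefault("RANK_ID", "0")     # shard consumed by this process

# With this patch, each shard reads CIFAR-10 with 8 parallel workers, both
# map() stages also run with 8 workers, and the separate ds.shuffle() pass is
# gone (Cifar10Dataset is already created with shuffle=True).
train_ds = create_dataset("/path/to/cifar-10-batches-bin",  # illustrative path
                          do_train=True, repeat_num=1, batch_size=32)
print("batches per epoch:", train_ds.get_dataset_size())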