From 37338813f09103da3e9ea8db085a31db8830f8e9 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Mon, 6 Jul 2020 21:04:59 +0800 Subject: [PATCH] skip strategy ckpt save for reshape --- mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc | 1 - mindspore/ccsrc/parallel/ops_info/ops_utils.h | 1 + mindspore/ccsrc/parallel/step_parallel.cc | 3 +++ .../strategy_checkpoint/parallel_strategy_checkpoint.cc | 1 + 4 files changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 05be097e6..d5523aaa6 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; int32_t RUN_PHASE = DEFAULT_RUN_PHASE; -constexpr char RESHAPEINFO[] = "ReshapeInfo"; void CostGraph::SetDeviceMemoryAndCostParameter() { MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 9cb3c7040..93e14d7f3 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; constexpr char REQUIRES_GRAD[] = "requires_grad"; constexpr char PARAM_NAME[] = "name"; +constexpr char RESHAPEINFO[] = "ReshapeInfo"; constexpr char RELU_TYPE[] = "relu"; constexpr char RELU6_TYPE[] = "relu6"; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 7d1200b19..b4ba7dd69 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_info = cnode->operator_info(); if (operator_info) { + if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { + continue; + } StrategyPtr strategyPtr = operator_info->strategy(); MS_EXCEPTION_IF_NULL(node->scope()); stra_map[param_name] = strategyPtr; diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index de10f4beb..a83b5eb62 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { parallel_strategy_item->set_node_name(node_stra.first); straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); MS_EXCEPTION_IF_NULL(parallel_strategys); + MS_EXCEPTION_IF_NULL(node_stra.second); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); for (auto &dims : node_stra.second->GetInputDim()) { straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); -- GitLab