diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index 05be097e6a8dc7d292c54da7c31329483d48c821..d5523aaa62e3c56814efcb2694d1a5544101f9b1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; int32_t RUN_PHASE = DEFAULT_RUN_PHASE; -constexpr char RESHAPEINFO[] = "ReshapeInfo"; void CostGraph::SetDeviceMemoryAndCostParameter() { MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index 9cb3c7040af835c9c0628c71055ccd0ba0f0843e..93e14d7f34870073e1ea93f1e5c27d2c29da3787 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; constexpr char REQUIRES_GRAD[] = "requires_grad"; constexpr char PARAM_NAME[] = "name"; +constexpr char RESHAPEINFO[] = "ReshapeInfo"; constexpr char RELU_TYPE[] = "relu"; constexpr char RELU6_TYPE[] = "relu6"; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 7d1200b1904e055e1cad3954431ba63e75d55c7f..b4ba7dd6959f5eaf82178336a6cff6d079c676b1 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_info = cnode->operator_info(); if (operator_info) { + if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { + continue; + } StrategyPtr strategyPtr = operator_info->strategy(); MS_EXCEPTION_IF_NULL(node->scope()); stra_map[param_name] = strategyPtr; diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index de10f4beb405c9717c6f503c33e4bdd670f637ca..a83b5eb627cfd537ea1893f1234cca825534959c 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { parallel_strategy_item->set_node_name(node_stra.first); straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); MS_EXCEPTION_IF_NULL(parallel_strategys); + MS_EXCEPTION_IF_NULL(node_stra.second); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); for (auto &dims : node_stra.second->GetInputDim()) { straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();