提交 37338813 编写于 作者: Y yao_yf

skip strategy ckpt save for reshape

上级 51f8ffab
...@@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES; ...@@ -41,7 +41,6 @@ bool FULLY_USE_DEVICES = DEFAULT_FULLY_USE_DEVICES;
bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW; bool ELEMENTWISE_OP_STRA_FOLLOW = DEFAULT_ELEMENTWISE_OP_STRA_FOLLOW;
bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS; bool MULTI_SUBGRAPHS = DEFAULT_IS_MULTI_SUBGRAPHS;
int32_t RUN_PHASE = DEFAULT_RUN_PHASE; int32_t RUN_PHASE = DEFAULT_RUN_PHASE;
constexpr char RESHAPEINFO[] = "ReshapeInfo";
void CostGraph::SetDeviceMemoryAndCostParameter() { void CostGraph::SetDeviceMemoryAndCostParameter() {
MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance()); MS_EXCEPTION_IF_NULL(CostModelContext::GetInstance());
......
...@@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; ...@@ -65,6 +65,7 @@ constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad"; constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name"; constexpr char PARAM_NAME[] = "name";
constexpr char RESHAPEINFO[] = "ReshapeInfo";
constexpr char RELU_TYPE[] = "relu"; constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6"; constexpr char RELU6_TYPE[] = "relu6";
......
...@@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { ...@@ -2120,6 +2120,9 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info(); OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) { if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;
}
StrategyPtr strategyPtr = operator_info->strategy(); StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope()); MS_EXCEPTION_IF_NULL(node->scope());
stra_map[param_name] = strategyPtr; stra_map[param_name] = strategyPtr;
......
...@@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { ...@@ -93,6 +93,7 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) {
parallel_strategy_item->set_node_name(node_stra.first); parallel_strategy_item->set_node_name(node_stra.first);
straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys();
MS_EXCEPTION_IF_NULL(parallel_strategys); MS_EXCEPTION_IF_NULL(parallel_strategys);
MS_EXCEPTION_IF_NULL(node_stra.second);
parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage()));
for (auto &dims : node_stra.second->GetInputDim()) { for (auto &dims : node_stra.second->GetInputDim()) {
straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册