diff --git a/mindspore/ccsrc/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/parallel/ops_info/ops_utils.h index bdae87858d92d44c966a21df5fe016fa525fd40f..e0b62eb23376002f9a3c76e674416aae3ddb0491 100644 --- a/mindspore/ccsrc/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/parallel/ops_info/ops_utils.h @@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch"; constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; +constexpr char REQUIRES_GRAD[] = "requires_grad"; +constexpr char PARAM_NAME[] = "name"; constexpr char RELU_TYPE[] = "relu"; constexpr char RELU6_TYPE[] = "relu6"; diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index f0be47642e5664acb1a29ea20cbd5799f98f997f..b16108a279bce77b702da372605dfd4a83730f4c 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_cnode(cnode); // key of strategy map - std::string instance_name = prim->instance_name(); - std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; + std::string strategy_key_name = NodeParameterName(cnode); bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); // If no strategy has been configured for this operator, then candidate strategies are generated for diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index 62fb96c2979a0680b2b88c52785cef3f334036aa..21a515ff854917a954b6952ee6137e1553587c4a 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector &all_nodes) { } // load strategy checkpoint // key of strategy map - std::string instance_name = prim->instance_name(); - std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name; + std::string strategy_key_name = NodeParameterName(cnode); bool load_strategy_from_ckpt = StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); - if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() << " is empty, using batch parallel"; @@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector node_inputs{node->inputs()}; for (auto input : node_inputs) { if (input->isa()) { auto input_parameter = input->cast(); if (input_parameter->has_default()) { - return py::cast(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad")); + if (py::cast(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) { + return py::cast( + parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME)); + } } } } - return false; + return ""; } void CheckpointStrategy(const FuncGraphPtr &func_graph) { @@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); - if ((cnode == nullptr) || !IsValueNode(cnode->input(0)) || !NodeWithParameter(cnode)) { + if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { + continue; + } + std::string param_name = NodeParameterName(cnode); + if (param_name.empty()) { continue; } PrimitivePtr prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_info = cnode->operator_info(); if (operator_info) { - if (prim->instance_name().empty()) { - MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name"; - } - std::string instance_name = prim->instance_name(); StrategyPtr strategyPtr = operator_info->strategy(); MS_EXCEPTION_IF_NULL(node->scope()); - std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name; - stra_map[node_name] = strategyPtr; + stra_map[param_name] = strategyPtr; } } if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h index c26f65ec656425107737eb39bdb5a709b680606a..93c3ed798cb5935ee0302f1dff07875a30922123 100644 --- a/mindspore/ccsrc/parallel/step_parallel.h +++ b/mindspore/ccsrc/parallel/step_parallel.h @@ -135,7 +135,7 @@ void ReshapeInit(const std::vector &all_nodes); void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager); -bool NodeWithParameter(const CNodePtr &node); +std::string NodeParameterName(const CNodePtr &node); void CheckpointStrategy(const FuncGraphPtr &func_graph); diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index 0cf6229fa363978c86aff242937e459f1c411153..a758a9e7bb4739f43d2a0e3d824192d4db1a9e3b 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -25,7 +25,6 @@ namespace mindspore { namespace parallel { - using StrategyMap = std::unordered_map; class StrategyCheckpoint { public: diff --git a/tests/ut/python/parallel/test_strategy_checkpoint.py b/tests/ut/python/parallel/test_strategy_checkpoint.py index 89b6dd1dbbb1f496725d8874f85726e6343434f7..d95b13f435f1d0a51f2e15c39a4355254fcef5d1 100644 --- a/tests/ut/python/parallel/test_strategy_checkpoint.py +++ b/tests/ut/python/parallel/test_strategy_checkpoint.py @@ -59,6 +59,7 @@ def test_six_matmul_save(): self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") + self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") def construct(self, x1, x6): out = self.matmul1(x1, self.weight1) @@ -66,6 +67,7 @@ def test_six_matmul_save(): out = self.matmul3(out, self.weight3) out = self.matmul4(out, self.weight4) out = self.matmul5(out, self.weight5) + out = out + self.weight6 out = self.matmul6(out, x6) return out @@ -118,12 +120,14 @@ def test_six_matmul_load(): self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") + self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") def construct(self, x1, x6, x7): out = self.matmul1(x1, self.weight1) out = self.matmul3(out, self.weight3) out = self.matmul4(out, self.weight4) out = self.matmul5(out, self.weight5) + out = out + self.weight6 out = self.matmul6(out, x6) out = self.matmul7(out, x7) return out @@ -179,6 +183,7 @@ def test_six_matmul_save_auto(): self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") + self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") def construct(self, x1, x6): out = self.matmul1(x1, self.weight1) @@ -186,6 +191,7 @@ def test_six_matmul_save_auto(): out = self.matmul3(out, self.weight3) out = self.matmul4(out, self.weight4) out = self.matmul5(out, self.weight5) + out = out + self.weight6 out = self.matmul6(out, x6) return out @@ -232,12 +238,14 @@ def test_six_matmul_load_auto(): self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") + self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6") def construct(self, x1, x6, x7): out = self.matmul1(x1, self.weight1) out = self.matmul3(out, self.weight3) out = self.matmul4(out, self.weight4) out = self.matmul5(out, self.weight5) + out = out + self.weight6 out = self.matmul6(out, x6) out = self.matmul7(out, x7) return out