提交 5a654045 编写于 作者: Y yao_yf

use param name as the key of strategy checkpoint

上级 1779479d
...@@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch"; ...@@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch";
constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin"; constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
constexpr char STEP_PARALLEL_END[] = "step_parallel_end"; constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot"; constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name";
constexpr char RELU_TYPE[] = "relu"; constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6"; constexpr char RELU6_TYPE[] = "relu6";
......
...@@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr & ...@@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_outputs_dtype(cnode->Type()); operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode); operator_info->set_cnode(cnode);
// key of strategy map // key of strategy map
std::string instance_name = prim->instance_name(); std::string strategy_key_name = NodeParameterName(cnode);
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
bool load_strategy_from_ckpt = bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end(); StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for // If no strategy has been configured for this operator, then candidate strategies are generated for
......
...@@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
} }
// load strategy checkpoint // load strategy checkpoint
// key of strategy map // key of strategy map
std::string instance_name = prim->instance_name(); std::string strategy_key_name = NodeParameterName(cnode);
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
bool load_strategy_from_ckpt = bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end(); StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) { if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name() MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel"; << " is empty, using batch parallel";
...@@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo ...@@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
} }
} }
bool NodeWithParameter(const CNodePtr &node) { std::string NodeParameterName(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()}; std::vector<AnfNodePtr> node_inputs{node->inputs()};
for (auto input : node_inputs) { for (auto input : node_inputs) {
if (input->isa<Parameter>()) { if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>(); auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) { if (input_parameter->has_default()) {
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad")); if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
return py::cast<std::string>(
parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
}
} }
} }
} }
return false; return "";
} }
void CheckpointStrategy(const FuncGraphPtr &func_graph) { void CheckpointStrategy(const FuncGraphPtr &func_graph) {
...@@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { ...@@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
for (auto &node : all_nodes) { for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) { if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
std::string param_name = NodeParameterName(cnode);
if (param_name.empty()) {
continue; continue;
} }
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info(); OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) { if (operator_info) {
if (prim->instance_name().empty()) {
MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
}
std::string instance_name = prim->instance_name();
StrategyPtr strategyPtr = operator_info->strategy(); StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope()); MS_EXCEPTION_IF_NULL(node->scope());
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name; stra_map[param_name] = strategyPtr;
stra_map[node_name] = strategyPtr;
} }
} }
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) { if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
......
...@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes); ...@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes, void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager); const FuncGraphManagerPtr &manager);
bool NodeWithParameter(const CNodePtr &node); std::string NodeParameterName(const CNodePtr &node);
void CheckpointStrategy(const FuncGraphPtr &func_graph); void CheckpointStrategy(const FuncGraphPtr &func_graph);
......
...@@ -25,7 +25,6 @@ ...@@ -25,7 +25,6 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>; using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
class StrategyCheckpoint { class StrategyCheckpoint {
public: public:
......
...@@ -59,6 +59,7 @@ def test_six_matmul_save(): ...@@ -59,6 +59,7 @@ def test_six_matmul_save():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6): def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1) out = self.matmul1(x1, self.weight1)
...@@ -66,6 +67,7 @@ def test_six_matmul_save(): ...@@ -66,6 +67,7 @@ def test_six_matmul_save():
out = self.matmul3(out, self.weight3) out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4) out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5) out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6) out = self.matmul6(out, x6)
return out return out
...@@ -118,12 +120,14 @@ def test_six_matmul_load(): ...@@ -118,12 +120,14 @@ def test_six_matmul_load():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6, x7): def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1) out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3) out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4) out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5) out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6) out = self.matmul6(out, x6)
out = self.matmul7(out, x7) out = self.matmul7(out, x7)
return out return out
...@@ -179,6 +183,7 @@ def test_six_matmul_save_auto(): ...@@ -179,6 +183,7 @@ def test_six_matmul_save_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6): def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1) out = self.matmul1(x1, self.weight1)
...@@ -186,6 +191,7 @@ def test_six_matmul_save_auto(): ...@@ -186,6 +191,7 @@ def test_six_matmul_save_auto():
out = self.matmul3(out, self.weight3) out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4) out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5) out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6) out = self.matmul6(out, x6)
return out return out
...@@ -232,12 +238,14 @@ def test_six_matmul_load_auto(): ...@@ -232,12 +238,14 @@ def test_six_matmul_load_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3") self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4") self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5") self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6, x7): def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1) out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3) out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4) out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5) out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6) out = self.matmul6(out, x6)
out = self.matmul7(out, x7) out = self.matmul7(out, x7)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册