提交 5a654045 编写于 作者: Y yao_yf

use param name as the key of strategy checkpoint

上级 1779479d
......@@ -61,6 +61,8 @@ constexpr char CROSS_BATCH[] = "cross_batch";
constexpr char STEP_PARALLEL_BEGIN[] = "step_parallel_begin";
constexpr char STEP_PARALLEL_END[] = "step_parallel_end";
constexpr char STEP_AUTO_PARALLEL_BEGIN[] = "step_auto_parallel_begin.dot";
constexpr char REQUIRES_GRAD[] = "requires_grad";
constexpr char PARAM_NAME[] = "name";
constexpr char RELU_TYPE[] = "relu";
constexpr char RELU6_TYPE[] = "relu6";
......
......@@ -387,8 +387,7 @@ OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &
operator_info->set_outputs_dtype(cnode->Type());
operator_info->set_cnode(cnode);
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
std::string strategy_key_name = NodeParameterName(cnode);
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map->find(strategy_key_name) != stra_map->end();
// If no strategy has been configured for this operator, then candidate strategies are generated for
......
......@@ -1423,11 +1423,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
}
// load strategy checkpoint
// key of strategy map
std::string instance_name = prim->instance_name();
std::string strategy_key_name = cnode->scope()->name() + std::string(CONNSYMBOL) + instance_name;
std::string strategy_key_name = NodeParameterName(cnode);
bool load_strategy_from_ckpt =
StrategyCheckpoint::GetInstance().LoadCheckPointOn() && stra_map.find(strategy_key_name) != stra_map.end();
if (!StrategyFound(attrs) && !load_strategy_from_ckpt) {
MS_LOG(INFO) << "ExtractInformation: the strategy of node " << node->ToString() << " prim " << prim->name()
<< " is empty, using batch parallel";
......@@ -2038,17 +2036,20 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector<AnfNo
}
}
bool NodeWithParameter(const CNodePtr &node) {
std::string NodeParameterName(const CNodePtr &node) {
std::vector<AnfNodePtr> node_inputs{node->inputs()};
for (auto input : node_inputs) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), "requires_grad"));
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), REQUIRES_GRAD))) {
return py::cast<std::string>(
parse::python_adapter::GetPyObjAttr(input_parameter->default_param(), PARAM_NAME));
}
}
}
}
return false;
return "";
}
void CheckpointStrategy(const FuncGraphPtr &func_graph) {
......@@ -2060,21 +2061,20 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
for (auto &node : all_nodes) {
MS_EXCEPTION_IF_NULL(node);
auto cnode = node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0)) || !NodeWithParameter(cnode)) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
continue;
}
std::string param_name = NodeParameterName(cnode);
if (param_name.empty()) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info) {
if (prim->instance_name().empty()) {
MS_LOG(EXCEPTION) << "Node with parameter to checkpoint strategy needs instance name";
}
std::string instance_name = prim->instance_name();
StrategyPtr strategyPtr = operator_info->strategy();
MS_EXCEPTION_IF_NULL(node->scope());
std::string node_name = node->scope()->name() + std::string(CONNSYMBOL) + instance_name;
stra_map[node_name] = strategyPtr;
stra_map[param_name] = strategyPtr;
}
}
if (StrategyCheckpoint::GetInstance().Save(stra_map) != SUCCESS) {
......
......@@ -135,7 +135,7 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes);
void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePtr> &all_nodes,
const FuncGraphManagerPtr &manager);
bool NodeWithParameter(const CNodePtr &node);
std::string NodeParameterName(const CNodePtr &node);
void CheckpointStrategy(const FuncGraphPtr &func_graph);
......
......@@ -25,7 +25,6 @@
namespace mindspore {
namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
class StrategyCheckpoint {
public:
......
......@@ -59,6 +59,7 @@ def test_six_matmul_save():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
......@@ -66,6 +67,7 @@ def test_six_matmul_save():
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
return out
......@@ -118,12 +120,14 @@ def test_six_matmul_load():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out
......@@ -179,6 +183,7 @@ def test_six_matmul_save_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6):
out = self.matmul1(x1, self.weight1)
......@@ -186,6 +191,7 @@ def test_six_matmul_save_auto():
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
return out
......@@ -232,12 +238,14 @@ def test_six_matmul_load_auto():
self.weight3 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight3")
self.weight4 = Parameter(Tensor(np.ones([128, 64]), dtype=ms.float32), name="weight4")
self.weight5 = Parameter(Tensor(np.ones([64, 128]), dtype=ms.float32), name="weight5")
self.weight6 = Parameter(Tensor(np.ones([32, 128]), dtype=ms.float32), name="weight6")
def construct(self, x1, x6, x7):
out = self.matmul1(x1, self.weight1)
out = self.matmul3(out, self.weight3)
out = self.matmul4(out, self.weight4)
out = self.matmul5(out, self.weight5)
out = out + self.weight6
out = self.matmul6(out, x6)
out = self.matmul7(out, x7)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册