提交 4834a6b3 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3574 Rename AnfNode::user_data related functions to follow naming rule

Merge pull request !3574 from hewei/rename_user_data_func
...@@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo ...@@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
return; return;
} }
auto operator_info = node->GetUserData<parallel::OperatorInfo>(); auto operator_info = node->user_data<parallel::OperatorInfo>();
if (operator_info == nullptr) { if (operator_info == nullptr) {
return; return;
} }
......
...@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { ...@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) { if (graph_obj == nullptr || node == nullptr) {
return; return;
} }
auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>(); auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
if (distributed_operation_info != nullptr) { if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy(); auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) { if (strategyPtr != nullptr) {
......
...@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t ...@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
if (node_prim->name() == DEPEND && node_pair.second != 1) { if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue; continue;
} }
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
(void)cnode_set.emplace(cnode); (void)cnode_set.emplace(cnode);
} else { } else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
...@@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi ...@@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
return cnode_dist; return cnode_dist;
} }
auto operator_info = cnode->GetUserData<OperatorInfo>(); auto operator_info = cnode->user_data<OperatorInfo>();
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
<< " operator_info: " << (operator_info != nullptr); << " operator_info: " << (operator_info != nullptr);
......
...@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { ...@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
} }
auto para_ptr = node_ptr->cast<ParameterPtr>(); auto para_ptr = node_ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para_ptr); MS_EXCEPTION_IF_NULL(para_ptr);
auto layout_ptr = para_ptr->GetUserData<TensorLayout>(); auto layout_ptr = para_ptr->user_data<TensorLayout>();
if (layout_ptr == nullptr) { if (layout_ptr == nullptr) {
MS_LOG(ERROR) << "layout_ptr is nullptr!"; MS_LOG(ERROR) << "layout_ptr is nullptr!";
return FAILED; return FAILED;
......
...@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { ...@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
for (auto para : graph_params) { for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name(); std::string name = std::static_pointer_cast<Parameter>(para)->name();
auto tensor_layout = para->GetUserData<parallel::TensorLayout>(); auto tensor_layout = para->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) { if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else { } else {
...@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { ...@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
if (node->isa<CNode>()) { if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>(); auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
auto distributed_operation_info = cnode->GetUserData<OperatorInfo>(); auto distributed_operation_info = cnode->user_data<OperatorInfo>();
if (distributed_operation_info != nullptr) { if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy(); auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) { if (strategyPtr != nullptr) {
......
...@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node ...@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info); entire_costgraph->AddOperator(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info); cnode->set_user_data<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
...@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no ...@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode); std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info); entire_costgraph->AddOperator(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info); cnode->set_user_data<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
...@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no ...@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name(); << " does not match the Prim: " << prim->name();
} }
cnode->SetUserData<OperatorInfo>(current_op_ptr); cnode->set_user_data<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
...@@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0; size_t edge_count = 0;
auto node_op_info = cnode->GetUserData<OperatorInfo>(); auto node_op_info = cnode->user_data<OperatorInfo>();
for (size_t i = 1; i < inputs.size(); ++i) { for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>(); auto prev_cnode = inputs[i]->cast<CNodePtr>();
...@@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
while (bool_result) { while (bool_result) {
if (IsAutoParallelCareNode(prev_cnode)) { if (IsAutoParallelCareNode(prev_cnode)) {
auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>(); auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again. // If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
...@@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>(); auto target_cnode = target.first->cast<CNodePtr>();
auto input_index = target.second; auto input_index = target.second;
(void)target_without_duplicate.insert(std::to_string(input_index) + (void)target_without_duplicate.insert(std::to_string(input_index) +
target_cnode->GetUserData<OperatorInfo>()->name()); target_cnode->user_data<OperatorInfo>()->name());
} }
if (target_without_duplicate.size() <= 1) { if (target_without_duplicate.size() <= 1) {
continue; continue;
...@@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>(); auto target_cnode = target.first->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0)); auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
auto input_index = target.second; auto input_index = target.second;
auto target_op_info = target_cnode->GetUserData<OperatorInfo>(); auto target_op_info = target_cnode->user_data<OperatorInfo>();
std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again. // If the edge between these two operators already has been added, then the edge will not be added again.
...@@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) { ...@@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) { if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false; return false;
} }
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
return false; return false;
} }
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
...@@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ ...@@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
return false; return false;
} }
auto node_op_info = cnode->GetUserData<OperatorInfo>(); auto node_op_info = cnode->user_data<OperatorInfo>();
if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
*pre_operator_info = node_op_info; *pre_operator_info = node_op_info;
*out_index = 0; *out_index = 0;
...@@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ ...@@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
} }
CNodePtr pre_cnode = pre_node->cast<CNodePtr>(); CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>(); auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
*pre_operator_info = pre_op_info; *pre_operator_info = pre_op_info;
return true; return true;
...@@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator ...@@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
if (node_prim->name() == DEPEND && node_pair.second != 1) { if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue; continue;
} }
auto op_info = use_apply->GetUserData<OperatorInfo>(); auto op_info = use_apply->user_data<OperatorInfo>();
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
*next_operator_info = op_info; *next_operator_info = op_info;
...@@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
int32_t out_index = 0; int32_t out_index = 0;
OperatorInfoPtr pre_operator_info; OperatorInfoPtr pre_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs; std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
auto operator_info = cnode->GetUserData<OperatorInfo>(); auto operator_info = cnode->user_data<OperatorInfo>();
if (pre_node->isa<Parameter>()) { if (pre_node->isa<Parameter>()) {
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info); auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->SetCostForReshapeWithParameter(); reshape_info->SetCostForReshapeWithParameter();
......
...@@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { ...@@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
if (!IsParallelCareNode(node)) { if (!IsParallelCareNode(node)) {
return nullptr; return nullptr;
} }
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
if (distribute_operator == nullptr) { if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
} }
...@@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { ...@@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
if (prim->name() == GET_NEXT) { if (prim->name() == GET_NEXT) {
return true; return true;
} }
if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) { if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
return false; return false;
} }
...@@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ ...@@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
if (node_prim->name() == DEPEND && node_pair.second != 1) { if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue; continue;
} }
if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
pre_node); pre_node);
} else { } else {
...@@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ ...@@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(next_node); MS_EXCEPTION_IF_NULL(next_node);
OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>(); OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
MS_EXCEPTION_IF_NULL(op_info); MS_EXCEPTION_IF_NULL(op_info);
// If the shape of tensor is [] or [1], no need to split it. // If the shape of tensor is [] or [1], no need to split it.
...@@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { ...@@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
// step1:get graph manager distribute_operator // step1:get graph manager distribute_operator
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>(); OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
if (distribute_operator == nullptr) { if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
} }
...@@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { ...@@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
(void)prim->SetAttrs(attrs); (void)prim->SetAttrs(attrs);
} }
if (index == replace_op.size() - 1) { if (index == replace_op.size() - 1) {
replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>()); replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
} }
replace_node->set_in_forward_flag(true); replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope); replace_input[0]->set_scope(scope);
...@@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { ...@@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
auto pre_cnode = pre_node->cast<CNodePtr>(); auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode); MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_node = pre_cnode->input(1); pre_node = pre_cnode->input(1);
} }
...@@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) { ...@@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
if (node_prim->name() == DEPEND && node_pair.second != 1) { if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue; continue;
} }
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
return node_pair; return node_pair;
} else if (FindParallelCareNode(node_pair.first).first != nullptr) { } else if (FindParallelCareNode(node_pair.first).first != nullptr) {
return FindParallelCareNode(node_pair.first); return FindParallelCareNode(node_pair.first);
...@@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i ...@@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString(); MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
CNodePtr cnode = res.first->cast<CNodePtr>(); CNodePtr cnode = res.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>(); OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
if (distribute_operator == nullptr) { if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
} }
...@@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i ...@@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
TensorLayout tensor_layout = tensorinfo_in.tensor_layout(); TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>(); ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr); MS_EXCEPTION_IF_NULL(parameter_ptr);
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout)); parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
} }
void CoverSliceShape(const FuncGraphPtr &root) { void CoverSliceShape(const FuncGraphPtr &root) {
...@@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { ...@@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) { if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter // set the shape and tensor layout for cloned parameter
cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>()); cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
...@@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
(*operator_).set_outputs_dtype(cnode->Type()); (*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode); (*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) { if (prim->name() == RESHAPE) {
cnode->SetUserData<OperatorInfo>(operator_); cnode->set_user_data<OperatorInfo>(operator_);
continue; continue;
} }
// load strategy checkpoint // load strategy checkpoint
...@@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
if (operator_->Init(strategyPtr) == FAILED) { if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
} }
cnode->SetUserData<OperatorInfo>(operator_); cnode->set_user_data<OperatorInfo>(operator_);
} else { } else {
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
} }
...@@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) { ...@@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
if (node_prim->name() == DEPEND && node_pair.second != 1) { if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue; continue;
} }
if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
auto layout = GetInputLayoutFromCNode(node_pair); auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout); return std::make_shared<TensorLayout>(layout);
} }
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << use_apply->HasUserData<OperatorInfo>(); << " " << use_apply->has_user_data<OperatorInfo>();
auto layout_ptr = FindNextLayout(use_apply); auto layout_ptr = FindNextLayout(use_apply);
if (layout_ptr) { if (layout_ptr) {
...@@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n ...@@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr; return nullptr;
} }
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
if (!layout_ptr) { if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
...@@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) { ...@@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(cnode->input(0))) { if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr; return nullptr;
} }
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) { if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
if (!layout_ptr) { if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
...@@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
continue; continue;
} }
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>(); ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) { if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
continue; continue;
} }
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node); PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
if (operator_info == nullptr) { if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
} }
...@@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { ...@@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast // return -> cast
if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) { if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>(); pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode); MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0)); current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
...@@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { ...@@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
return ret; return ret;
} }
OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>(); OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
MS_EXCEPTION_IF_NULL(operator_info); MS_EXCEPTION_IF_NULL(operator_info);
TensorInfo loss_grad_tensor_info; TensorInfo loss_grad_tensor_info;
size_t op_output_size = operator_info->outputs_tensor_info().size(); size_t op_output_size = operator_info->outputs_tensor_info().size();
...@@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay ...@@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
if (sens_tensor_node->isa<Parameter>()) { if (sens_tensor_node->isa<Parameter>()) {
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
} }
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
return; return;
...@@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay ...@@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
cloned_abstract->set_shape(parallel_shape); cloned_abstract->set_shape(parallel_shape);
sens_tensor_node->set_abstract(cloned_abstract); sens_tensor_node->set_abstract(cloned_abstract);
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>(); auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout)); sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
return; return;
} }
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
...@@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { ...@@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
} }
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0)); PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>(); OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
if (operator_info) { if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue; continue;
......
...@@ -158,29 +158,29 @@ class AnfNode : public Base { ...@@ -158,29 +158,29 @@ class AnfNode : public Base {
size_t seen_{0}; size_t seen_{0};
template <typename T> template <typename T>
void SetUserData(const std::string &key, const std::shared_ptr<T> &value) { void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
user_data_.set<T>(key, value); user_data_.set<T>(key, value);
} }
template <typename T> template <typename T>
void SetUserData(const std::shared_ptr<T> &value) { void set_user_data(const std::shared_ptr<T> &value) {
user_data_.set<T>(T::key, value); user_data_.set<T>(T::key, value);
} }
template <typename T> template <typename T>
std::shared_ptr<T> GetUserData(const std::string &key) const { std::shared_ptr<T> user_data(const std::string &key) const {
return user_data_.get<T>(key); return user_data_.get<T>(key);
} }
template <typename T> template <typename T>
std::shared_ptr<T> GetUserData() const { std::shared_ptr<T> user_data() const {
return user_data_.get<T>(T::key); return user_data_.get<T>(T::key);
} }
bool HasUserData(const std::string &key) const { return user_data_.has(key); } bool has_user_data(const std::string &key) const { return user_data_.has(key); }
template <typename T> template <typename T>
bool HasUserData() const { bool has_user_data() const {
return user_data_.has(T::key); return user_data_.has(T::key);
} }
......
...@@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { ...@@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
StrategyPtr strategyPtr; StrategyPtr strategyPtr;
std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape); std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
node->SetUserData<OperatorInfo>(matmul_info); node->set_user_data<OperatorInfo>(matmul_info);
std::string name_expect = "MatMulInfo00"; std::string name_expect = "MatMulInfo00";
std::string name_test = matmul_info->name(); std::string name_test = matmul_info->name();
ASSERT_EQ(name_expect, name_test); ASSERT_EQ(name_expect, name_test);
......
...@@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { ...@@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
std::vector<Shapes> shape = {inputs_shape, outputs_shape}; std::vector<Shapes> shape = {inputs_shape, outputs_shape};
OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
matmul_info->Init(strategyPtr); matmul_info->Init(strategyPtr);
node->SetUserData<OperatorInfo>(matmul_info); node->set_user_data<OperatorInfo>(matmul_info);
OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>(); OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
TensorLayout tensorlayout_e; TensorLayout tensorlayout_e;
std::vector<int32_t> array = {64, 64}; std::vector<int32_t> array = {64, 64};
TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre); TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册