diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc
index 75e7df1882b89c441261ebf72215c8f2d03159be..a1cc80f96bac07e7046d6d01400791206b602f69 100644
--- a/mindspore/ccsrc/debug/anf_ir_dump.cc
+++ b/mindspore/ccsrc/debug/anf_ir_dump.cc
@@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
-  auto operator_info = node->GetUserData<parallel::OperatorInfo>();
+  auto operator_info = node->user_data<parallel::OperatorInfo>();
   if (operator_info == nullptr) {
     return;
   }
diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc
index eee5394dadc9782f4a1778c917c544ec1f5f8bec..6855fd8bd505fab4d376195fc5e63e98f8af267b 100644
--- a/mindspore/ccsrc/debug/draw.cc
+++ b/mindspore/ccsrc/debug/draw.cc
@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
   if (graph_obj == nullptr || node == nullptr) {
     return;
   }
-  auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
+  auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
   if (distributed_operation_info != nullptr) {
     auto strategyPtr = distributed_operation_info->strategy();
     if (strategyPtr != nullptr) {
diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc
index f69bda41001c28501bc4226507812934834e731c..927bb2e73b660eda4cb6fd6d94920774958ab722 100644
--- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc
+++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc
@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
     if (node_prim->name() == DEPEND && node_pair.second != 1) {
       continue;
     }
-    if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
      (void)cnode_set.emplace(cnode);
    } else {
      auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
@@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
    return cnode_dist;
  }
 
-  auto operator_info = cnode->GetUserData<OperatorInfo>();
+  auto operator_info = cnode->user_data<OperatorInfo>();
   MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
                 << " operator_info: " << (operator_info != nullptr);
diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc
index 93f01943bfb98f1a88373817530fa2f38100e692..05b40e3bb9c324fab116b88342231dc2b570bc96 100644
--- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc
+++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc
@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
   }
   auto para_ptr = node_ptr->cast<ParameterPtr>();
   MS_EXCEPTION_IF_NULL(para_ptr);
-  auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
+  auto layout_ptr = para_ptr->user_data<TensorLayout>();
   if (layout_ptr == nullptr) {
     MS_LOG(ERROR) << "layout_ptr is nullptr!";
     return FAILED;
diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
index bdcbe9fd3d0c51a443c3ac2a50533d67d3316fe1..d8794240d4b121f0cc2499a7e45024659c96da44 100644
--- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc
@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
 
   for (auto para : graph_params) {
     std::string name = std::static_pointer_cast<Parameter>(para)->name();
-    auto tensor_layout = para->GetUserData<TensorLayout>();
+    auto tensor_layout = para->user_data<TensorLayout>();
     if (tensor_layout == nullptr) {
       MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
     } else {
@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
     if (node->isa<CNode>()) {
       auto cnode = node->cast<CNodePtr>();
       MS_EXCEPTION_IF_NULL(cnode);
-      auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
+      auto distributed_operation_info = cnode->user_data<OperatorInfo>();
       if (distributed_operation_info != nullptr) {
         auto strategyPtr = distributed_operation_info->strategy();
         if (strategyPtr != nullptr) {
diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
index 08cdde8d22a8d805476ce15adb121851003af979..4a2da2002480a74eee486028a02132ba3e4ca921 100644
--- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc
@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
     std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
 
     entire_costgraph->AddOperator(operator_info);
-    cnode->SetUserData<OperatorInfo>(operator_info);
+    cnode->set_user_data<OperatorInfo>(operator_info);
     MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                  << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                  << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
     std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
 
     entire_costgraph->AddOperator(operator_info);
-    cnode->SetUserData<OperatorInfo>(operator_info);
+    cnode->set_user_data<OperatorInfo>(operator_info);
     MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                  << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                  << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
      MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
                        << " does not match the Prim: " << prim->name();
    }
-    cnode->SetUserData<OperatorInfo>(current_op_ptr);
+    cnode->set_user_data<OperatorInfo>(current_op_ptr);
     MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
                  << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
                  << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
@@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
     size_t edge_count = 0;
-    auto node_op_info = cnode->GetUserData<OperatorInfo>();
+    auto node_op_info = cnode->user_data<OperatorInfo>();
 
     for (size_t i = 1; i < inputs.size(); ++i) {
       auto prev_cnode = inputs[i]->cast<CNodePtr>();
@@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
          (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
      while (bool_result) {
        if (IsAutoParallelCareNode(prev_cnode)) {
-          auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
+          auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
          std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
          // If the edge between these two operators already has been added, then the edge will not be added again.
          if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
@@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
      auto target_cnode = target.first->cast<CNodePtr>();
      auto input_index = target.second;
      (void)target_without_duplicate.insert(std::to_string(input_index) +
-                                            target_cnode->GetUserData<OperatorInfo>()->name());
+                                            target_cnode->user_data<OperatorInfo>()->name());
    }
    if (target_without_duplicate.size() <= 1) {
      continue;
    }
@@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
    auto target_cnode = target.first->cast<CNodePtr>();
    auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
    auto input_index = target.second;
-    auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
+    auto target_op_info = target_cnode->user_data<OperatorInfo>();
 
    std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
    // If the edge between these two operators already has been added, then the edge will not be added again.
@@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) {
  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
-  if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+  if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return false;
  }
-  auto node_op_info = cnode->GetUserData<OperatorInfo>();
+  auto node_op_info = cnode->user_data<OperatorInfo>();
  if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
    *pre_operator_info = node_op_info;
    *out_index = 0;
@@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
      MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
    }
    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
-    auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
+    auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
    if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
      *pre_operator_info = pre_op_info;
      return true;
@@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
-    auto op_info = use_apply->GetUserData<OperatorInfo>();
+    auto op_info = use_apply->user_data<OperatorInfo>();
    if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
      *next_operator_info = op_info;
@@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
    int32_t out_index = 0;
    OperatorInfoPtr pre_operator_info;
    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
-    auto operator_info = cnode->GetUserData<OperatorInfo>();
+    auto operator_info = cnode->user_data<OperatorInfo>();
    if (pre_node->isa<Parameter>()) {
      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
      reshape_info->SetCostForReshapeWithParameter();
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 9b469809b6194c6a0a076a10f7b21f77b3758855..9d407eba6053aa6a3bda111bb97b960713214242 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
  if (!IsParallelCareNode(node)) {
    return nullptr;
  }
-  OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
  }
@@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
  if (prim->name() == GET_NEXT) {
    return true;
  }
-  if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
+  if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
    return false;
  }
 
@@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
-    if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
      Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
                     pre_node);
    } else {
@@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
 void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(next_node);
-  OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(op_info);
 
  // If the shape of tensor is [] or [1], no need to split it.
@@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
 
 void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
  // step1:get graph manager distribute_operator
-  OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
  }
@@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
      (void)prim->SetAttrs(attrs);
    }
    if (index == replace_op.size() - 1) {
-      replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
+      replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
    }
    replace_node->set_in_forward_flag(true);
    replace_input[0]->set_scope(scope);
@@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+  if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
    pre_node = pre_cnode->input(1);
  }
 
@@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
-    if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
      return node_pair;
    } else if (FindParallelCareNode(node_pair.first).first != nullptr) {
      return FindParallelCareNode(node_pair.first);
@@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair
  MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
  CNodePtr cnode = res.first->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
-  OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
+  OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
  if (distribute_operator == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
  }
@@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair
  ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
  MS_EXCEPTION_IF_NULL(parameter_ptr);
-  parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+  parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
 }
 
 void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
    if (found_be_cloned_parameter) {
      // set the shape and tensor layout for cloned parameter
-      cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
+      cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
      MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
      MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
      auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
@@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
    (*operator_).set_outputs_dtype(cnode->Type());
    (*operator_).set_cnode(cnode);
    if (prim->name() == RESHAPE) {
-      cnode->SetUserData<OperatorInfo>(operator_);
+      cnode->set_user_data<OperatorInfo>(operator_);
      continue;
    }
    // load strategy checkpoint
@@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
      if (operator_->Init(strategyPtr) == FAILED) {
        MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
      }
-      cnode->SetUserData<OperatorInfo>(operator_);
+      cnode->set_user_data<OperatorInfo>(operator_);
    } else {
      MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
    }
@@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
    if (node_prim->name() == DEPEND && node_pair.second != 1) {
      continue;
    }
-    if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
+    if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
      MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
      auto layout = GetInputLayoutFromCNode(node_pair);
      return std::make_shared<TensorLayout>(layout);
    }
    MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
-                  << " " << use_apply->HasUserData<OperatorInfo>();
+                  << " " << use_apply->has_user_data<OperatorInfo>();
 
    auto layout_ptr = FindNextLayout(use_apply);
    if (layout_ptr) {
@@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
-  if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
  if (!IsValueNode<Primitive>(cnode->input(0))) {
    return nullptr;
  }
-  if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+  if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
    auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
    if (!layout_ptr) {
      MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
      continue;
    }
    ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
-    if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+    if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
      continue;
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
    if (operator_info == nullptr) {
      MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
    }
@@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
  auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
 
  // return -> cast
-  if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+  if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
    pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(pre_cnode);
    current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
@@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
    return ret;
  }
 
-  OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
+  OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
  MS_EXCEPTION_IF_NULL(operator_info);
  TensorInfo loss_grad_tensor_info;
  size_t op_output_size = operator_info->outputs_tensor_info().size();
@@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
    if (sens_tensor_node->isa<Parameter>()) {
      auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
      MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
-      sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+      sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
    }
    MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
    return;
@@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
    cloned_abstract->set_shape(parallel_shape);
    sens_tensor_node->set_abstract(cloned_abstract);
    auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
-    sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+    sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
    return;
  }
  MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
@@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
    }
    PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
    MS_EXCEPTION_IF_NULL(prim);
-    OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+    OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
    if (operator_info) {
      if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
        continue;
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 96acdf4248b29c085603ea1b4a20ee49878b9ca0..82f02a55f88339ce08e8f3578cce3413cc501642 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -158,29 +158,29 @@ class AnfNode : public Base {
  size_t seen_{0};
 
  template <typename T>
-  void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
+  void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
    user_data_.set<T>(key, value);
  }
 
  template <typename T>
-  void SetUserData(const std::shared_ptr<T> &value) {
+  void set_user_data(const std::shared_ptr<T> &value) {
    user_data_.set<T>(T::key, value);
  }
 
  template <typename T>
-  std::shared_ptr<T> GetUserData(const std::string &key) const {
+  std::shared_ptr<T> user_data(const std::string &key) const {
    return user_data_.get<T>(key);
  }
 
  template <typename T>
-  std::shared_ptr<T> GetUserData() const {
+  std::shared_ptr<T> user_data() const {
    return user_data_.get<T>(T::key);
  }
 
-  bool HasUserData(const std::string &key) const { return user_data_.has(key); }
+  bool has_user_data(const std::string &key) const { return user_data_.has(key); }
 
  template <typename T>
-  bool HasUserData() const {
+  bool has_user_data() const {
    return user_data_.has(T::key);
  }
 
diff --git a/tests/ut/cpp/parallel/step_auto_parallel_test.cc b/tests/ut/cpp/parallel/step_auto_parallel_test.cc
index cca7efd62ff566f08acfe961c4781922aed0be06..fd455017965452779285f5b9c438aa3aeeeda9c3 100644
--- a/tests/ut/cpp/parallel/step_auto_parallel_test.cc
+++ b/tests/ut/cpp/parallel/step_auto_parallel_test.cc
@@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
  StrategyPtr strategyPtr;
 
  std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
-  node->SetUserData<OperatorInfo>(matmul_info);
+  node->set_user_data<OperatorInfo>(matmul_info);
  std::string name_expect = "MatMulInfo00";
  std::string name_test = matmul_info->name();
  ASSERT_EQ(name_expect, name_test);
diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc
index 3c54f80eda8d298fc18bfbf4ab317bdcd5ad86e4..18898597a71e2f6b6070d8610b61405541ae7b72 100644
--- a/tests/ut/cpp/parallel/step_parallel_test.cc
+++ b/tests/ut/cpp/parallel/step_parallel_test.cc
@@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
  std::vector<Shapes> shape = {inputs_shape, outputs_shape};
  OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
  matmul_info->Init(strategyPtr);
-  node->SetUserData<OperatorInfo>(matmul_info);
-  OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>();
+  node->set_user_data<OperatorInfo>(matmul_info);
+  OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
  TensorLayout tensorlayout_e;
  std::vector<int32_t> array = {64, 64};
  TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);
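
For reference, a minimal usage sketch of the renamed AnfNode accessors defined in mindspore/core/ir/anf.h above. This is a hypothetical illustration, not part of the patch; it assumes an existing CNodePtr cnode and OperatorInfoPtr op_info are in scope.

  // Attach, query and read back parallel OperatorInfo on a node with the
  // snake_case accessors; the template argument selects the user-data slot.
  cnode->set_user_data<OperatorInfo>(op_info);                  // was SetUserData<OperatorInfo>(...)
  if (cnode->has_user_data<OperatorInfo>()) {                   // was HasUserData<OperatorInfo>()
    OperatorInfoPtr info = cnode->user_data<OperatorInfo>();    // was GetUserData<OperatorInfo>()
    MS_LOG(INFO) << "operator info: " << info->name();
  }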