Commit 4834a6b3 authored by mindspore-ci-bot, committed by Gitee

!3574 Rename AnfNode::user_data related functions to follow naming rule

Merge pull request !3574 from hewei/rename_user_data_func
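The rename converts the CamelCase user-data accessors on AnfNode to snake_case, matching the accessor naming rule. The full mapping, as applied in every hunk below, is:

    SetUserData<T>(value)         ->  set_user_data<T>(value)
    SetUserData<T>(key, value)    ->  set_user_data<T>(key, value)
    GetUserData<T>()              ->  user_data<T>()
    GetUserData<T>(key)           ->  user_data<T>(key)
    HasUserData<T>()              ->  has_user_data<T>()
    HasUserData(key)              ->  has_user_data(key)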
@@ -267,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
return;
}
- auto operator_info = node->GetUserData<parallel::OperatorInfo>();
+ auto operator_info = node->user_data<parallel::OperatorInfo>();
if (operator_info == nullptr) {
return;
}
@@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
- auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
+ auto distributed_operation_info = node->user_data<parallel::OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {
@@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
- if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
(void)cnode_set.emplace(cnode);
} else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
@@ -98,7 +98,7 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
return cnode_dist;
}
- auto operator_info = cnode->GetUserData<OperatorInfo>();
+ auto operator_info = cnode->user_data<OperatorInfo>();
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
<< " operator_info: " << (operator_info != nullptr);
@@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
}
auto para_ptr = node_ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para_ptr);
- auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
+ auto layout_ptr = para_ptr->user_data<TensorLayout>();
if (layout_ptr == nullptr) {
MS_LOG(ERROR) << "layout_ptr is nullptr!";
return FAILED;
@@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name();
- auto tensor_layout = para->GetUserData<parallel::TensorLayout>();
+ auto tensor_layout = para->user_data<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else {
@@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
- auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
+ auto distributed_operation_info = cnode->user_data<OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {
@@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
- cnode->SetUserData<OperatorInfo>(operator_info);
+ cnode->set_user_data<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
- cnode->SetUserData<OperatorInfo>(operator_info);
+ cnode->set_user_data<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}
- cnode->SetUserData<OperatorInfo>(current_op_ptr);
+ cnode->set_user_data<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
@@ -549,7 +549,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0;
- auto node_op_info = cnode->GetUserData<OperatorInfo>();
+ auto node_op_info = cnode->user_data<OperatorInfo>();
for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>();
@@ -565,7 +565,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
while (bool_result) {
if (IsAutoParallelCareNode(prev_cnode)) {
- auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
+ auto prev_op_info = prev_cnode->user_data<OperatorInfo>();
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
@@ -751,7 +751,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>();
auto input_index = target.second;
(void)target_without_duplicate.insert(std::to_string(input_index) +
- target_cnode->GetUserData<OperatorInfo>()->name());
+ target_cnode->user_data<OperatorInfo>()->name());
}
if (target_without_duplicate.size() <= 1) {
continue;
@@ -831,7 +831,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
auto input_index = target.second;
- auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
+ auto target_op_info = target_cnode->user_data<OperatorInfo>();
std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
@@ -862,7 +862,7 @@ bool FindReshape(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
- if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+ if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
@@ -884,7 +884,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
- auto node_op_info = cnode->GetUserData<OperatorInfo>();
+ auto node_op_info = cnode->user_data<OperatorInfo>();
if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
*pre_operator_info = node_op_info;
*out_index = 0;
@@ -900,7 +900,7 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
}
CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
- auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
+ auto pre_op_info = pre_cnode->user_data<OperatorInfo>();
if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
*pre_operator_info = pre_op_info;
return true;
@@ -941,7 +941,7 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
- auto op_info = use_apply->GetUserData<OperatorInfo>();
+ auto op_info = use_apply->user_data<OperatorInfo>();
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
*next_operator_info = op_info;
@@ -970,7 +970,7 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
int32_t out_index = 0;
OperatorInfoPtr pre_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
- auto operator_info = cnode->GetUserData<OperatorInfo>();
+ auto operator_info = cnode->user_data<OperatorInfo>();
if (pre_node->isa<Parameter>()) {
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->SetCostForReshapeWithParameter();
@@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
if (!IsParallelCareNode(node)) {
return nullptr;
}
- OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+ OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
}
@@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
if (prim->name() == GET_NEXT) {
return true;
}
- if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
+ if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
return false;
}
@@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
- if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
pre_node);
} else {
@@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(next_node);
- OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
+ OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
MS_EXCEPTION_IF_NULL(op_info);
// If the shape of tensor is [] or [1], no need to split it.
@@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
// step1:get graph manager distribute_operator
- OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
+ OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
}
@@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
- replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
+ replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
}
replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope);
@@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
- if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+ if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_node = pre_cnode->input(1);
}
@@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
- if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
return node_pair;
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
return FindParallelCareNode(node_pair.first);
@@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
CNodePtr cnode = res.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
- OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
+ OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
}
@@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr);
- parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
+ parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
}
void CoverSliceShape(const FuncGraphPtr &root) {
@@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
- cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
+ cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
@@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
(*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) {
- cnode->SetUserData<OperatorInfo>(operator_);
+ cnode->set_user_data<OperatorInfo>(operator_);
continue;
}
// load strategy checkpoint
@@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
- cnode->SetUserData<OperatorInfo>(operator_);
+ cnode->set_user_data<OperatorInfo>(operator_);
} else {
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
}
@@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
- if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout);
}
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << use_apply->HasUserData<OperatorInfo>();
<< " " << use_apply->has_user_data<OperatorInfo>();
auto layout_ptr = FindNextLayout(use_apply);
if (layout_ptr) {
@@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
- if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
- if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
+ if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
- if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
+ if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
- OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+ OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
@@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast
- if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
+ if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
@@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
return ret;
}
- OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
+ OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
MS_EXCEPTION_IF_NULL(operator_info);
TensorInfo loss_grad_tensor_info;
size_t op_output_size = operator_info->outputs_tensor_info().size();
@@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
if (sens_tensor_node->isa<Parameter>()) {
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
- sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+ sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
}
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
return;
@@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
cloned_abstract->set_shape(parallel_shape);
sens_tensor_node->set_abstract(cloned_abstract);
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
- sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
+ sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
return;
}
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
@@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
- OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
+ OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;
@@ -158,29 +158,29 @@ class AnfNode : public Base {
size_t seen_{0};
template <typename T>
- void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
+ void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
user_data_.set<T>(key, value);
}
template <typename T>
- void SetUserData(const std::shared_ptr<T> &value) {
+ void set_user_data(const std::shared_ptr<T> &value) {
user_data_.set<T>(T::key, value);
}
template <typename T>
- std::shared_ptr<T> GetUserData(const std::string &key) const {
+ std::shared_ptr<T> user_data(const std::string &key) const {
return user_data_.get<T>(key);
}
template <typename T>
- std::shared_ptr<T> GetUserData() const {
+ std::shared_ptr<T> user_data() const {
return user_data_.get<T>(T::key);
}
- bool HasUserData(const std::string &key) const { return user_data_.has(key); }
+ bool has_user_data(const std::string &key) const { return user_data_.has(key); }
template <typename T>
- bool HasUserData() const {
+ bool has_user_data() const {
return user_data_.has(T::key);
}
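For reference, a minimal self-contained sketch of how the renamed accessors are used. NodeSketch and UserDataMap below are hypothetical stand-ins (the real AnfNode and its user_data_ container are MindSpore internals not fully shown in this diff); OperatorInfo is reduced to a stub carrying the static key member that the keyless overloads rely on:

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-in for the internal user_data_ container (not from this diff).
class UserDataMap {
 public:
  template <typename T>
  void set(const std::string &key, const std::shared_ptr<T> &value) {
    data_[key] = value;
  }
  template <typename T>
  std::shared_ptr<T> get(const std::string &key) const {
    auto iter = data_.find(key);
    return (iter == data_.end()) ? nullptr : std::static_pointer_cast<T>(iter->second);
  }
  bool has(const std::string &key) const { return data_.count(key) > 0; }

 private:
  std::map<std::string, std::shared_ptr<void>> data_;
};

// Stub with the static `key` member that the keyless overloads dispatch on (T::key).
struct OperatorInfo {
  static const std::string key;
  std::string name = "MatMulInfo00";
};
const std::string OperatorInfo::key = "OperatorInfo";

// Mirrors only the renamed accessor surface shown in the hunk above.
class NodeSketch {
 public:
  template <typename T>
  void set_user_data(const std::shared_ptr<T> &value) {
    user_data_.set<T>(T::key, value);
  }
  template <typename T>
  std::shared_ptr<T> user_data() const {
    return user_data_.get<T>(T::key);
  }
  template <typename T>
  bool has_user_data() const {
    return user_data_.has(T::key);
  }

 private:
  UserDataMap user_data_;
};

int main() {
  NodeSketch node;
  node.set_user_data<OperatorInfo>(std::make_shared<OperatorInfo>());
  if (node.has_user_data<OperatorInfo>()) {
    std::cout << node.user_data<OperatorInfo>()->name << std::endl;  // prints "MatMulInfo00"
  }
  return 0;
}

The updated tests at the end of this diff exercise the same call pattern on real graph nodes.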
@@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) {
StrategyPtr strategyPtr;
std::shared_ptr<OperatorInfo> matmul_info = NewOperatorInstance(prim, attrs, shape);
- node->SetUserData<OperatorInfo>(matmul_info);
+ node->set_user_data<OperatorInfo>(matmul_info);
std::string name_expect = "MatMulInfo00";
std::string name_test = matmul_info->name();
ASSERT_EQ(name_expect, name_test);
@@ -522,8 +522,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) {
std::vector<Shapes> shape = {inputs_shape, outputs_shape};
OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape);
matmul_info->Init(strategyPtr);
- node->SetUserData<OperatorInfo>(matmul_info);
- OperatorInfoPtr distribute_operator_pre = node->GetUserData<OperatorInfo>();
+ node->set_user_data<OperatorInfo>(matmul_info);
+ OperatorInfoPtr distribute_operator_pre = node->user_data<OperatorInfo>();
TensorLayout tensorlayout_e;
std::vector<int32_t> array = {64, 64};
TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);