Commit 485ac838 authored by M mindspore-ci-bot, committed by Gitee

!3162 split tuple output node to maketuple

Merge pull request !3162 from lianliguang/split-tuple-node-to-make-tuple
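At a high level, this PR makes the backend split a tuple-output Parameter or ValueNode into per-element nodes and re-pack them with an explicit MakeTuple, and it centralizes kernel-info creation in KernelGraph::SetKernelInfoForNode. A minimal before/after sketch of the intended IR rewrite (an illustration, not part of the patch):

// Before: a Parameter with a tuple abstract feeds an op directly.
//   %p : (Tensor, Tensor) = Parameter
//   %r = SomeOp(%p)
// After: the tuple node is split element-wise and re-packed, so every
// input reaching SomeOp is an ordinary single-output node.
//   %p0 : Tensor = Parameter
//   %p1 : Tensor = Parameter
//   %t  = MakeTuple(%p0, %p1)
//   %r  = SomeOp(%t)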
......@@ -47,8 +47,8 @@ void BackendCommonOptimization(const std::shared_ptr<session::KernelGraph> &kern
common_pm->AddPass(std::make_shared<ConvertConstInputToAttr>());
common_pm->AddPass(std::make_shared<ConstToAttrStridedSliceGradPass>());
common_pm->AddPass(std::make_shared<ConvertConstInputToTensorInput>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>());
optimizer->AddPassManager(common_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
......
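The ordering above is deliberate: tuple outputs must be rewritten into explicit MakeTuple nodes before ConvertTupleInputToDynamicInput flattens tuple inputs, so the flattening pass only ever sees MakeTuple inputs. A sketch of the resulting registration order, using the same AddPass API as the hunk above:

// Sketch only: relative order established by this hunk.
common_pm->AddPass(std::make_shared<ConvertTupleOutputToMaketuple>());   // split tuple outputs first
common_pm->AddPass(std::make_shared<ConvertTupleInputToDynamicInput>()); // then flatten tuple inputs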
......@@ -27,86 +27,33 @@
namespace mindspore {
namespace opt {
namespace {
bool MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return false;
}
// create kernel_info for new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
TypeId infer_data_type;
if (AnfAlgo::GetOutputTensorNum(value_node) == 0) {
infer_data_type = kTypeUnknown;
} else {
infer_data_type = AnfAlgo::GetOutputInferDataType(value_node, 0);
}
kernel_build_info_builder->SetOutputsDeviceType(std::vector<TypeId>{infer_data_type});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), value_node.get());
return true;
}
void ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node,
std::vector<AnfNodePtr> *plant_inputs, std::vector<int> *dyn_input_sizes) {
MS_EXCEPTION_IF_NULL(plant_inputs);
MS_EXCEPTION_IF_NULL(dyn_input_sizes);
MS_EXCEPTION_IF_NULL(graph);
auto output_size = AnfAlgo::GetOutputTensorNum(input_node);
dyn_input_sizes->push_back(output_size);
std::vector<AnfNodePtr> convert_inputs;
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
if (input_node->isa<ValueNode>()) {
auto value_node = input_node->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
convert_inputs = kernel_graph->SplitTupleValueNodeToNodeList(value_node);
} else {
for (size_t index = 0; index < output_size; ++index) {
auto tuple_get_item = CreatTupleGetItemNode(graph, input_node, index);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, index)},
{AnfAlgo::GetOutputInferShape(input_node, index)}, tuple_get_item.get());
convert_inputs.emplace_back(tuple_get_item);
}
}
(void)std::copy(convert_inputs.begin(), convert_inputs.end(), std::back_inserter(*plant_inputs));
}
void ConvertMakeTupleInputToPlantInputs(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph);
auto &ori_args = cnode_ptr->inputs();
if (ori_args.size() < 1) {
return;
}
std::vector<AnfNodePtr> plant_inputs;
std::vector<int> dyn_input_sizes;
plant_inputs.push_back(ori_args[kAnfPrimitiveIndex]);
for (size_t i = 1; i < ori_args.size(); ++i) {
auto input_node = ori_args[i];
if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
plant_inputs.push_back(AnfAlgo::GetCNodePrimitiveNode(cnode_ptr));
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(cnode_ptr); ++i) {
auto input_node = AnfAlgo::GetInputNode(cnode_ptr, i);
MS_EXCEPTION_IF_NULL(input_node);
if (input_node->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimMakeTuple)) {
auto input_size = AnfAlgo::GetOutputTensorNum(input_node);
dyn_input_sizes.push_back(input_size);
auto cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto inputs = cnode->inputs();
for (size_t j = 1; j < inputs.size(); ++j) {
MS_EXCEPTION_IF_NULL(inputs[j]);
if (IsValueNode<tensor::Tensor>(inputs[j])) {
auto success = MakeValueNode(inputs[j]);
auto make_tuple = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t j = 0; j < AnfAlgo::GetInputTensorNum(make_tuple); ++j) {
auto dyn_input_node = AnfAlgo::GetInputNode(make_tuple, j);
MS_EXCEPTION_IF_NULL(dyn_input_node);
if (IsValueNode<tensor::Tensor>(dyn_input_node)) {
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto success = kernel_graph->NewValueNode(dyn_input_node->cast<ValueNodePtr>());
if (!success) {
MS_LOG(WARNING) << "Make value node failed, " << inputs[j]->DebugString();
MS_LOG(WARNING) << "Make value node failed, " << dyn_input_node->DebugString();
}
}
plant_inputs.push_back(inputs[j]);
plant_inputs.push_back(dyn_input_node);
}
} else if (input_node->Type() != nullptr && AnfAlgo::IsTupleOutput(input_node)) {
ConvertTupleOuputToPlantInputs(graph, input_node, &plant_inputs, &dyn_input_sizes);
} else {
dyn_input_sizes.push_back(-1);
plant_inputs.push_back(input_node);
......@@ -139,9 +86,8 @@ const AnfNodePtr ConvertTupleInputToDynamicInput::Process(const FuncGraphPtr &fu
for (auto &t : todos) {
ConvertMakeTupleInputToPlantInputs(sub_graph, t->cast<CNodePtr>());
}
} else {
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
}
ConvertMakeTupleInputToPlantInputs(func_graph, node->cast<CNodePtr>());
return node;
}
} // namespace opt
......
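The bookkeeping contract maintained by ConvertMakeTupleInputToPlantInputs can be modeled in isolation: each input of the target CNode contributes one entry to dyn_input_sizes, N for a MakeTuple input whose N elements are spliced into the flattened list, and -1 for an ordinary input. A self-contained sketch, with plain integers standing in for AnfNode pointers (an assumption for illustration):

#include <iostream>
#include <vector>

// Self-contained model of ConvertMakeTupleInputToPlantInputs' bookkeeping.
// An inner vector with more than one element models a MakeTuple input.
int main() {
  std::vector<std::vector<int>> inputs = {{7}, {1, 2, 3}, {9}};
  std::vector<int> plant_inputs;
  std::vector<int> dyn_input_sizes;
  for (const auto &in : inputs) {
    if (in.size() > 1) {  // MakeTuple input: splice elements, record its arity
      dyn_input_sizes.push_back(static_cast<int>(in.size()));
      plant_inputs.insert(plant_inputs.end(), in.begin(), in.end());
    } else {  // ordinary input: keep as-is, record -1
      dyn_input_sizes.push_back(-1);
      plant_inputs.push_back(in.front());
    }
  }
  for (int v : plant_inputs) std::cout << v << ' ';    // 7 1 2 3 9
  std::cout << '\n';
  for (int v : dyn_input_sizes) std::cout << v << ' ';  // -1 3 -1
  std::cout << '\n';
}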
......@@ -25,6 +25,38 @@
namespace mindspore {
namespace opt {
namespace {
CNodePtr ConvertTupleOuputToPlantInputs(const FuncGraphPtr &graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(graph);
if (!AnfAlgo::IsTupleOutput(input_node)) {
MS_LOG(EXCEPTION) << "Cannot using the function to convert a not tuple output node to maketuple!";
}
if (input_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << input_node->DebugString();
}
std::vector<AnfNodePtr> convert_inputs = {NewValueNode(prim::kPrimMakeTuple)};
auto kernel_graph = graph->cast<KernelGraphPtr>();
MS_EXCEPTION_IF_NULL(kernel_graph);
auto splited_node_list = kernel_graph->SplitTupleOutputNodeToNodeList(input_node);
for (const auto &node : splited_node_list) {
if (AnfAlgo::IsTupleOutput(node)) {
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, node));
continue;
}
convert_inputs.emplace_back(node);
}
auto make_tuple = graph->NewCNode(convert_inputs);
std::vector<abstract::AbstractBasePtr> abstract_list;
auto make_tuple_input_size = AnfAlgo::GetInputTensorNum(make_tuple);
for (size_t index = 0; index < make_tuple_input_size; ++index) {
auto make_tuple_input = AnfAlgo::GetInputNode(make_tuple, index);
MS_EXCEPTION_IF_NULL(make_tuple_input);
abstract_list.emplace_back(make_tuple_input->abstract());
}
make_tuple->set_abstract(std::make_shared<abstract::AbstractTuple>(abstract_list));
return make_tuple;
}
CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr &cnode_ptr) {
MS_EXCEPTION_IF_NULL(cnode_ptr);
MS_EXCEPTION_IF_NULL(graph);
......@@ -35,19 +67,25 @@ CNodePtr ConvertTupleInputToMakeTuple(const FuncGraphPtr &graph, const CNodePtr
std::vector<TypeId> types;
std::vector<std::vector<size_t>> shapes;
std::vector<AnfNodePtr> make_tuple_inputs_list = {NewValueNode(prim::kPrimMakeTuple)};
for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) {
make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index));
types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index));
shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index));
if (input_node->isa<CNode>()) {
for (size_t tuple_out_index = 0; tuple_out_index < AnfAlgo::GetOutputTensorNum(input_node); ++tuple_out_index) {
make_tuple_inputs_list.emplace_back(CreatTupleGetItemNode(graph, input_node, tuple_out_index));
types.push_back(AnfAlgo::GetOutputInferDataType(input_node, tuple_out_index));
shapes.emplace_back(AnfAlgo::GetOutputInferShape(input_node, tuple_out_index));
}
auto make_tuple = graph->NewCNode(make_tuple_inputs_list);
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
convert_inputs.emplace_back(make_tuple);
continue;
}
auto make_tuple = graph->NewCNode(make_tuple_inputs_list);
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, make_tuple.get());
convert_inputs.emplace_back(make_tuple);
convert_inputs.emplace_back(ConvertTupleOuputToPlantInputs(graph, input_node));
} else {
convert_inputs.push_back(input_node);
}
}
return graph->NewCNode(convert_inputs);
auto new_node = graph->NewCNode(convert_inputs);
new_node->set_abstract(cnode_ptr->abstract());
return new_node;
}
} // namespace
......
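The rewritten ConvertTupleInputToMakeTuple thus distinguishes two cases; roughly:

// Illustration only (not the patch): the two rewrite strategies.
//  - tuple output of a CNode: address each element through TupleGetItem:
//      SomeOp(cnode) ==> SomeOp(MakeTuple(TupleGetItem(cnode, 0), ..., TupleGetItem(cnode, k-1)))
//  - tuple output of a Parameter/ValueNode: physically split it first via
//    SplitTupleOutputNodeToNodeList, then re-pack the pieces with MakeTuple,
//    recursing into nested tuples (ConvertTupleOuputToPlantInputs above).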
......@@ -79,31 +79,6 @@ std::vector<AnfNodePtr> GetCallRealOutputs(const AnfNodePtr &call_node) {
return real_inputs;
}
AnfNodePtr MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
// create kernel_info for new value node
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
std::vector<TypeId> types;
for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(value_node); ++index) {
types.push_back(kTypeUnknown);
}
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
return new_value_node;
}
bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
if (left == right) {
return true;
......@@ -121,6 +96,18 @@ bool IsSameLabel(const CNodePtr &left, const CNodePtr &right) {
return false;
}
} // namespace
AnfNodePtr KernelGraph::MakeValueNode(const AnfNodePtr &node) {
auto value_node = node->cast<ValueNodePtr>();
if (value_node == nullptr) {
return nullptr;
}
ValueNodePtr new_value_node = std::make_shared<ValueNode>(value_node->value());
new_value_node->set_abstract(value_node->abstract());
this->SetKernelInfoForNode(new_value_node);
return new_value_node;
}
std::vector<AnfNodePtr> KernelGraph::outputs() const {
auto graph_output = output();
if (IsPrimitiveCNode(graph_output, prim::kPrimMakeTuple)) {
......@@ -290,28 +277,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
MS_EXCEPTION_IF_NULL(cnode);
cnode->set_abstract(std::make_shared<abstract::AbstractNone>());
CreateKernelInfoFromNewParameter(cnode);
auto kernel_info = std::make_shared<device::KernelInfo>();
std::vector<size_t> feature_map_input_indexs;
// if the node only has the primitive (such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
for (size_t index = 1; index < inputs.size(); ++index) {
auto node = inputs[index];
if (AnfAlgo::IsFeatureMapOutput(node)) {
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimCast->name()) {
AnfAlgo::SetNodeAttr(kIsBackendCast, MakeValue(false), cnode);
}
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
cnode->set_kernel_info(kernel_info);
SetKernelInfoForNode(cnode);
AnfAlgo::SetGraphId(graph_id_, cnode.get());
return cnode;
}
......@@ -351,6 +320,50 @@ void KernelGraph::CreateKernelInfoFromNewParameter(const CNodePtr &cnode) {
}
}
void KernelGraph::SetKernelInfoForNode(const AnfNodePtr &node) const {
MS_EXCEPTION_IF_NULL(node);
auto kernel_info = std::make_shared<device::KernelInfo>();
node->set_kernel_info(kernel_info);
if (node->isa<CNode>()) {
std::vector<size_t> feature_map_input_indexs;
kernel_info->SetFeatureMapFlag(false);
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) {
if (AnfAlgo::IsFeatureMapInput(node, index)) {
kernel_info->SetFeatureMapFlag(true);
feature_map_input_indexs.push_back(index);
}
}
if (AnfAlgo::GetInputTensorNum(node) == 0) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealKernel(node)) {
// if the node only has the primitive (such as getNext) or the node's input has a feature map input
// then the node's output is a feature map output
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), node);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), node);
}
return;
}
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
std::vector<TypeId> types;
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>{kOpFormat_DEFAULT});
if (node->isa<ValueNode>()) {
kernel_info->SetFeatureMapFlag(false);
types.emplace_back(kTypeUnknown);
}
if (node->isa<Parameter>()) {
auto parameter = node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter);
bool is_weight = AnfAlgo::IsParameterWeight(parameter);
kernel_info->SetFeatureMapFlag(!is_weight);
types.push_back(is_weight ? kTypeUnknown : AnfAlgo::GetOutputInferDataType(parameter, 0));
}
// set the node initial device data type
kernel_build_info_builder->SetOutputsDeviceType(types);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), node.get());
}
CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto new_cnode = std::make_shared<CNode>(*cnode);
......@@ -366,75 +379,97 @@ CNodePtr KernelGraph::NewCNode(const CNodePtr &cnode) {
}
ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
ParameterPtr new_parameter = add_parameter();
auto abstract = parameter == nullptr ? std::make_shared<abstract::AbstractNone>() : parameter->abstract();
auto new_parameter = NewParameter(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info for new parameter
auto kernel_info = std::make_shared<device::KernelInfo>();
size_t output_tensor_num = 1;
// if the default parameter == nullptr is used, it means creating a new parameter from scratch
if (parameter == nullptr) {
new_parameter->set_abstract(std::make_shared<abstract::AbstractNone>());
kernel_info->SetFeatureMapFlag(true);
} else {
// if a non-null parameter is passed, it means creating a new parameter from an old parameter
new_parameter->set_abstract(parameter->abstract());
// if a non-null parameter is passed, it means creating a new parameter from an old parameter
if (parameter != nullptr) {
new_parameter->set_name(parameter->name());
if (AnfAlgo::IsParameterWeight(parameter)) {
new_parameter->set_default_param(parameter->default_param());
kernel_info->SetFeatureMapFlag(false);
} else {
kernel_info->SetFeatureMapFlag(true);
}
}
new_parameter->set_kernel_info(kernel_info);
// create kernel_build_info for new parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// create init data type
std::vector<TypeId> init_data_type = {};
TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0);
init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
// create kernel_info for new parameter
SetKernelInfoForNode(new_parameter);
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
return new_parameter;
}
// set the format of parameter to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
// set parameter initial device data type
kernel_build_info_builder->SetOutputsDeviceType(init_data_type);
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_parameter.get());
ParameterPtr KernelGraph::NewParameter(const abstract::AbstractBasePtr &abstract) {
ParameterPtr new_parameter = add_parameter();
new_parameter->set_abstract(abstract);
MS_EXCEPTION_IF_NULL(new_parameter);
// create kernel_info for new parameter
SetKernelInfoForNode(new_parameter);
AnfAlgo::SetGraphId(graph_id_, new_parameter.get());
return new_parameter;
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleParameterToNodeList(const ParameterPtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
std::vector<AnfNodePtr> convert_nodes_list;
auto abstract = parameter->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (!abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "Multiple output Parameter's output must be a tuple abstract but got " << abstract->ToString();
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
for (size_t index = 0; index < tuple_abstract->size(); ++index) {
auto new_parameter = this->NewParameter((*tuple_abstract)[index]);
SetKernelInfoForNode(new_parameter);
convert_nodes_list.emplace_back(new_parameter);
}
auto new_inputs = std::make_shared<std::vector<AnfNodePtr>>();
auto old_inputs = inputs();
for (const auto &input_node : old_inputs) {
if (input_node != parameter) {
new_inputs->emplace_back(input_node);
continue;
}
std::copy(convert_nodes_list.begin(), convert_nodes_list.end(), std::back_inserter(*new_inputs));
}
inputs_ = new_inputs;
return convert_nodes_list;
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleOutputNodeToNodeList(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The function can only split a parameter or valuenode bug got " << node->DebugString();
}
if (node->isa<Parameter>()) {
return SplitTupleParameterToNodeList(node->cast<ParameterPtr>());
}
return SplitTupleValueNodeToNodeList(node->cast<ValueNodePtr>());
}
std::vector<AnfNodePtr> KernelGraph::SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(value_node);
auto node_value = value_node->value();
auto output_size = AnfAlgo::GetOutputTensorNum(value_node);
std::vector<AnfNodePtr> convert_inputs;
if (!node_value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "Multiple output valuenode's value must be a value tuple but got " << node_value->ToString();
}
auto value_tuple = node_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
if (value_tuple->size() != output_size) {
MS_LOG(EXCEPTION) << "Value tuple size" << value_tuple->size()
<< " is not mathced with the value node's output size" << output_size;
auto abstract = value_node->abstract();
if (!abstract->isa<abstract::AbstractTuple>()) {
MS_LOG(EXCEPTION) << "Spilted node's output abstract is not type tuple";
}
auto tuple_abstract = abstract->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(tuple_abstract);
if (tuple_abstract->size() != value_tuple->size()) {
MS_LOG(EXCEPTION) << "The node output index [" << value_tuple->size() << "]is outof range "
<< tuple_abstract->size();
}
for (size_t index = 0; index < value_tuple->value().size(); ++index) {
auto new_value_node = std::make_shared<ValueNode>(value_tuple->value()[index]);
AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(value_node, index)},
{AnfAlgo::GetOutputInferShape(value_node, index)}, new_value_node.get());
new_value_node->set_abstract((*tuple_abstract)[index]);
AddValueNodeToGraph(new_value_node);
auto kernel_info = std::make_shared<device::KernelInfo>();
new_value_node->set_kernel_info(kernel_info);
kernel_info->SetFeatureMapFlag(false);
// create kernel_build_info for new value node
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// set the format of value_node to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
// set value node initial device data type = infer data type
kernel_build_info_builder->SetOutputsDeviceType({kTypeUnknown});
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), new_value_node.get());
SetKernelInfoForNode(new_value_node);
AnfAlgo::SetGraphId(graph_id_, new_value_node.get());
AddValueNodeToGraph(new_value_node);
convert_inputs.emplace_back(new_value_node);
}
if (!RemoveValueNodeFromGraph(value_node)) {
......
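SplitTupleParameterToNodeList above also rewrites the graph's input list in place, keeping every other input and splicing the element parameters in where the tuple parameter used to be. That splice can be modeled with standard containers (string stand-ins for parameter nodes are an assumption):

#include <cassert>
#include <string>
#include <vector>

// Self-contained model of the input-list rewrite: every input is kept except
// the split parameter, which is replaced by its element parameters in place.
std::vector<std::string> ReplaceInput(const std::vector<std::string> &old_inputs,
                                      const std::string &split_param,
                                      const std::vector<std::string> &pieces) {
  std::vector<std::string> new_inputs;
  for (const auto &input : old_inputs) {
    if (input != split_param) {
      new_inputs.push_back(input);
      continue;
    }
    new_inputs.insert(new_inputs.end(), pieces.begin(), pieces.end());
  }
  return new_inputs;
}

int main() {
  auto out = ReplaceInput({"a", "p", "b"}, "p", {"p0", "p1"});
  assert((out == std::vector<std::string>{"a", "p0", "p1", "b"}));
}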
......@@ -54,8 +54,10 @@ class KernelGraph : public FuncGraph {
void CreateKernelInfoFromNewParameter(const CNodePtr &cnode);
CNodePtr NewCNode(const CNodePtr &cnode);
ParameterPtr NewParameter(const ParameterPtr &parameter = nullptr);
ParameterPtr NewParameter(const abstract::AbstractBasePtr &abstract);
ValueNodePtr NewValueNode(const ValuePtr &value);
ValueNodePtr NewValueNode(const ValueNodePtr &value_node = nullptr);
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
std::vector<AnfNodePtr> SplitTupleOutputNodeToNodeList(const AnfNodePtr &node);
void set_execution_order(const std::vector<CNodePtr> &order) { execution_order_ = order; }
const std::vector<CNodePtr> &execution_order() const { return execution_order_; }
void SetExecOrderByDefault();
......@@ -166,6 +168,10 @@ class KernelGraph : public FuncGraph {
private:
// remove value node from graph
bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
void SetKernelInfoForNode(const AnfNodePtr &node) const;
std::vector<AnfNodePtr> SplitTupleValueNodeToNodeList(const ValueNodePtr &value_node);
std::vector<AnfNodePtr> SplitTupleParameterToNodeList(const ParameterPtr &parameter);
AnfNodePtr MakeValueNode(const AnfNodePtr &node);
void VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
std::unordered_set<AnfNodePtr> *visited_nodes);
// update node edge list
......
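Taken together, the new public surface in kernel_graph.h might be exercised like this (a hedged sketch: graph and node are assumed to exist, with node a Parameter or ValueNode whose abstract is an AbstractTuple):

// Sketch only, not code from this PR.
auto pieces = graph->SplitTupleOutputNodeToNodeList(node);  // one node per tuple element
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
make_tuple_inputs.insert(make_tuple_inputs.end(), pieces.begin(), pieces.end());
auto make_tuple = graph->NewCNode(make_tuple_inputs);  // re-packed tuple for downstream users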
......@@ -60,7 +60,7 @@ TEST_F(KernelGraphTest, NewParameter) {
auto anf_graph = std::make_shared<FuncGraph>();
auto kernel_graph = std::make_shared<KernelGraph>();
// test nullptr as input
auto new_paramter = kernel_graph->NewParameter(nullptr);
auto new_paramter = kernel_graph->NewParameter();
EXPECT_NE(new_paramter, nullptr);
EXPECT_TRUE(new_paramter->isa<Parameter>());
EXPECT_EQ(AnfAlgo::GetOutputFormat(new_paramter, 0), kOpFormat_DEFAULT);
......