Commit 1f8b00df authored by zhoufeng

Fix empty graph dump ir

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
Parent a342a615
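
The diff below changes SetOutputInferTypeAndShape so that an empty shape list no longer raises an exception; instead the node's abstract is set to AbstractNone, which lets a graph with no outputs ("empty graph") be dumped to IR. The following is a minimal, self-contained sketch of that pattern only; Node, Abstract, AbstractNone and the function body are simplified stand-ins, not the real MindSpore classes or the actual implementation.

// Sketch (assumed simplified types): empty output list -> AbstractNone
#include <iostream>
#include <memory>
#include <stdexcept>
#include <vector>

struct Abstract { virtual ~Abstract() = default; };
struct AbstractNone : Abstract {};    // marker for "node has no outputs"
struct AbstractTensor : Abstract {};  // placeholder for single/tuple outputs

struct Node {
  std::shared_ptr<Abstract> abstract;
  void set_abstract(std::shared_ptr<Abstract> a) { abstract = std::move(a); }
};

void SetOutputInferTypeAndShape(const std::vector<int> &types,
                                const std::vector<std::vector<size_t>> &shapes,
                                Node *node) {
  if (node == nullptr) {
    throw std::runtime_error("node is null");
  }
  if (types.size() != shapes.size()) {
    throw std::runtime_error("types/shapes size mismatch");
  }
  if (shapes.empty()) {
    // New behaviour in this commit: mark the node as having no output
    // instead of throwing, so an empty graph can still be dumped.
    node->set_abstract(std::make_shared<AbstractNone>());
  } else {
    node->set_abstract(std::make_shared<AbstractTensor>());  // details elided
  }
}

int main() {
  Node add;
  SetOutputInferTypeAndShape({}, {}, &add);  // no longer throws
  std::cout << std::boolalpha
            << (dynamic_cast<AbstractNone *>(add.abstract.get()) != nullptr)
            << "\n";
  return 0;
}
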
......@@ -694,7 +694,7 @@ void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &
MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
}
if (shapes.empty()) {
MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
node->set_abstract(std::make_shared<abstract::AbstractNone>());
} else if (shapes.size() == 1) {
// single output handle
std::vector<int> shape_int;
......@@ -1012,6 +1012,9 @@ std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CN
auto get_switch_kernel_graph = [switch_node](size_t input_index) -> KernelGraphPtr {
auto partial = switch_node->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) {
return GetValueNode<KernelGraphPtr>(partial);
}
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto graph_node = partial_cnode->input(1);
......
......@@ -386,8 +386,7 @@ void AscendControlParser::RecurseSwitch(NotNull<KernelGraphPtr> kg, NotNull<CNod
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = kCNodeSwitchCond + 1; i < kCNodeSwitchLength; ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg;
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
......@@ -432,8 +431,7 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
origin_switch_inputs[kCNodeSwitchCond]};
for (size_t i = 0; i < branch_partial.size(); ++i) {
// 3.1 branch kernel graph and args
KernelGraphPtr branch_fg;
std::tie(std::ignore, branch_fg) = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
KernelGraphPtr branch_fg = ParsePartial(NOT_NULL(origin_switch_inputs[i]));
// 3.2 recurse sub graph
CNodePtr branch_label = ProcessKernelGraph(NOT_NULL(branch_fg), cur_node, back_label, memo);
new_switch_inputs.push_back(branch_label);
......@@ -444,8 +442,11 @@ void AscendControlParser::RecurseSwitchLayer(NotNull<KernelGraphPtr> kg, NotNull
MS_LOG(INFO) << "Succeed processing switch layer " << cur_node->DebugString();
}
std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
KernelGraphPtr AscendControlParser::ParsePartial(NotNull<AnfNodePtr> node) {
if (!node.get()->isa<CNode>()) {
if (IsValueNode<KernelGraph>(node)) {
return GetValueNode<KernelGraphPtr>(node);
}
MS_LOG(EXCEPTION) << "Switch branches must be partial, node: " << node->DebugString();
}
// 2.1 branch kernel graph and args
......@@ -460,7 +461,7 @@ std::tuple<CNodePtr, KernelGraphPtr> AscendControlParser::ParsePartial(NotNull<A
MS_LOG(EXCEPTION) << "Index out of range:" << partial_inputs.size() << ".";
}
auto branch_kg = GetValueNode<KernelGraphPtr>(partial_inputs[kCNodePartialFunc]);
return {partial_cnode, branch_kg};
return branch_kg;
}
void AscendControlParser::InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph,
......
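
The ascend_control_parser changes above simplify ParsePartial: it now returns only the branch KernelGraphPtr (the tuple with the partial CNode is dropped), and a switch branch that is already a KernelGraph value node is returned directly instead of being rejected. A minimal sketch of that revised contract, using simplified stand-in types rather than the real AnfNode/Partial/KernelGraph classes:

// Sketch (assumed simplified types): branch may be a graph value or a partial
#include <memory>
#include <stdexcept>
#include <variant>

struct KernelGraph {};
using KernelGraphPtr = std::shared_ptr<KernelGraph>;

struct Partial { KernelGraphPtr graph; };          // stands in for partial(graph, args...)
using BranchNode = std::variant<KernelGraphPtr, Partial>;

KernelGraphPtr ParsePartial(const BranchNode &node) {
  // New fast path: the branch is already a graph value node.
  if (std::holds_alternative<KernelGraphPtr>(node)) {
    return std::get<KernelGraphPtr>(node);
  }
  // Otherwise it must be a partial; unwrap the graph it wraps.
  const auto &partial = std::get<Partial>(node);
  if (partial.graph == nullptr) {
    throw std::runtime_error("Switch branches must be partial");
  }
  return partial.graph;
}

int main() {
  auto graph = std::make_shared<KernelGraph>();
  // Both forms resolve to the same branch graph.
  bool same = ParsePartial(graph) == ParsePartial(Partial{graph});
  return same ? 0 : 1;
}
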
......@@ -52,7 +52,7 @@ class AscendControlParser {
static void LinkParentGraph(NotNull<KernelGraphPtr> kg, const CNodePtr &from_graph_call_node,
const CNodePtr &last_label);
static std::tuple<CNodePtr, KernelGraphPtr> ParsePartial(NotNull<AnfNodePtr> node);
static KernelGraphPtr ParsePartial(NotNull<AnfNodePtr> node);
static void InsertMultipleAssignToGraph(NotNull<KernelGraphPtr> from_graph, NotNull<KernelGraphPtr> to_graph,
NotNull<AnfNodePtr> from, NotNull<AnfNodePtr> to);
......
......@@ -247,6 +247,9 @@ static void UpdateRealInput(NotNull<KernelGraphPtr> graph, bool split_flag) {
MS_EXCEPTION_IF_NULL(switch_cnode);
auto partial = switch_cnode->input(input_index);
MS_EXCEPTION_IF_NULL(partial);
if (IsValueNode<KernelGraph>(partial)) {
return {};
}
auto partial_cnode = partial->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(partial_cnode);
auto ret = std::vector<AnfNodePtr>(partial_cnode->inputs().begin() + 2, partial_cnode->inputs().end());
......
......@@ -357,18 +357,16 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
} else {
kernel_info->SetFeatureMapFlag(true);
}
// if output is a tuple tensor,now can use for loop to handle tuple tensor
output_tensor_num = AnfAlgo::GetOutputTensorNum(parameter);
}
new_parameter->set_kernel_info(kernel_info);
// create kernel_build_info for new parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
// create init data type,
std::vector<TypeId> init_data_type = {};
for (size_t i = 0; i < output_tensor_num; i++) {
TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, i);
init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
}
TypeId infer_data_type = AnfAlgo::GetOutputInferDataType(new_parameter, 0);
init_data_type.push_back(AnfAlgo::IsParameterWeight(new_parameter) ? kTypeUnknown : infer_data_type);
// set the format of parameter to DEFAULT_FORMAT
kernel_build_info_builder->SetOutputsFormat(std::vector<std::string>(output_tensor_num, kOpFormat_DEFAULT));
// set parameter initaial device data type
......
......@@ -590,7 +590,8 @@ TEST_F(AnfRuntimeAlgorithmTest, SetOutputInferTypeAndShape) {
std::vector<TypeId> none_types = {};
std::vector<std::vector<size_t>> none_shapes = {};
EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, nullptr), std::runtime_error);
EXPECT_THROW(AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get()), std::runtime_error);
AnfAlgo::SetOutputInferTypeAndShape(none_types, none_shapes, add.get());
EXPECT_EQ((*add->abstract()), abstract::AbstractNone());
// set single input
std::vector<TypeId> single_types = {kFloat32->type_id()};
std::vector<std::vector<size_t>> single_shapes = {{2, 32, 224, 224}};
......