提交 6f6fc75b 编写于 作者: L liubuyu

bug fix

上级 c8f69f5d
......@@ -512,7 +512,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
<< "] cannot find valid TBE kernel info, try to get aicpu kernel info";
kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
}
// The kernel info not finded both in the aicpu kernel list & aicore kernel list
......
......@@ -33,8 +33,7 @@ const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
return VectorRef({prim::kPrimNMSWithMask, Xs});
}
AnfNodePtr INsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
const TypeId &input_type, const TypeId &output_type, const TypeId &origin_type,
AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const TypeId &origin_type,
const std::vector<size_t> &origin_shape) {
MS_EXCEPTION_IF_NULL(func_graph);
std::vector<AnfNodePtr> new_pad_inputs;
......@@ -43,25 +42,6 @@ AnfNodePtr INsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &in
new_pad_inputs.push_back(input);
CNodePtr pad = func_graph->NewCNode(new_pad_inputs);
MS_EXCEPTION_IF_NULL(pad);
// set kernel build info
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetInputsFormat({format});
builder.SetOutputsFormat({format});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsDeviceType({output_type});
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetProcessor(kernel::Processor::AICORE);
if (kernel::OpLib::FindOp(prim::kPrimPad->name(), kernel::kTBE) != nullptr) {
builder.SetKernelType(KernelType::TBE_KERNEL);
} else {
builder.SetKernelType(KernelType::AICPU_KERNEL);
}
if (pad->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
pad->set_kernel_info(kernel_info);
}
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), pad.get());
AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
return pad;
}
......@@ -81,14 +61,12 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) {
auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, input_idx);
auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);
if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) {
return nullptr;
}
origin_shape[1] = 8;
auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_idx);
auto pad = INsertPadToGraph(func_graph, cur_input, format, origin_type, device_type, origin_type, origin_shape);
auto pad = InsertPadToGraph(func_graph, cur_input, origin_type, origin_shape);
MS_EXCEPTION_IF_NULL(pad);
pad->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector<std::vector<int>>{{0, 0}, {0, 3}}), pad);
......
......@@ -90,7 +90,7 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
new_transdata_node =
NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
new_transpose_node);
new_transdata_node);
new_replace_node = new_transdata_node;
}
FuncGraphManagerPtr manager = func_graph->manager();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册