Commit 338d7c1a authored by WilliamLian

Decouple transdata insertion from the deal-ref and transdata-split passes

Parent d402b944
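In short: call sites that previously created a trans op and rewrote its kernel build info through the single `AddTransOpNodeToGraph(..., origin_format, dest_format, op_name, is_insert_input)` entry point now do it in two decoupled steps, `NewTransOpNode` plus `RefreshKernelBuildInfo`. A schematic of the new call pattern, lifted from the deal-ref hunk below (`final_node`, `cur_format`, `origin_format`, and `origin_type` are that hunk's locals; this fragment is not runnable on its own):

```cpp
// Before: one call built the TransData node and rebuilt its kernel build info.
// final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
//                                    kTransDataOpName, false);

// After: node creation and kernel-build-info refresh are separate, reusable steps.
final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
```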
@@ -503,6 +503,7 @@ KernelSelectStatus SetMatchedKernelInfo(const CNodePtr &kernel_node,
 KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
   std::vector<std::shared_ptr<kernel::KernelBuildInfo>> kernel_info_list;
+  std::vector<std::shared_ptr<kernel::KernelBuildInfo>> aicpu_kernel_info_list;
   MS_EXCEPTION_IF_NULL(kernel_node);
   kernel::KernelQuery(kernel_node, &kernel_info_list);
   auto select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
@@ -510,7 +511,7 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
   if (select_status == kNoMatched) {
     MS_LOG(WARNING) << "The node [" << kernel_node->DebugString()
                     << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
-    kernel::AICPUQuery(kernel_node, &kernel_info_list);
-    select_status = SetMatchedKernelInfo(kernel_node, kernel_info_list);
+    kernel::AICPUQuery(kernel_node, &aicpu_kernel_info_list);
+    select_status = SetMatchedKernelInfo(kernel_node, aicpu_kernel_info_list);
     AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
   }
@@ -518,6 +519,15 @@ KernelSelectStatus SelectKernelInfo(const CNodePtr &kernel_node) {
   if (select_status == kNoMatched) {
     std::ostringstream buffer;
     PrintInputAndOutputInferType(buffer, kernel_node);
+    MS_LOG(WARNING) << "=========================kernel info list=====================================";
+    for (size_t index = 0; index < kernel_info_list.size(); ++index) {
+      MS_LOG(WARNING) << "kernel [" << index << "] :" << kernel_info_list[index]->ToString();
+    }
+    for (size_t index = 0; index < aicpu_kernel_info_list.size(); ++index) {
+      MS_LOG(WARNING) << "kernel [" << (kernel_info_list.size() + index)
+                      << "] :" << aicpu_kernel_info_list[index]->ToString();
+    }
+    MS_LOG(WARNING) << "========================= end ====================================";
     MS_EXCEPTION(TypeError) << "The node [" << kernel_node->DebugString()
                             << "] cannot find valid kernel info, not supported the type " << buffer.str();
   }
......
@@ -110,9 +110,9 @@ bool KernelBuildInfo::operator==(const KernelBuildInfo &other) const {
   return !(inputs_device_type_ != other.inputs_device_type_ || outputs_device_type_ != other.outputs_device_type_);
 }
 
-bool KernelBuildInfo::IsInputDefaultPadding() const { return output_reshape_type_.empty(); }
-bool KernelBuildInfo::IsOutputDefaultPadding() const { return input_reshape_type_.empty(); }
+bool KernelBuildInfo::IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
+bool KernelBuildInfo::IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
 
 void KernelBuildInfo::KernelBuildInfoBuilder::SetKernelType(const KernelType &kernel_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
......
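The swapped return values above meant `IsInputDefaultPadding()` consulted the output reshape types and vice versa. A minimal self-contained illustration of the corrected pairing, with a simplified stand-in for `kernel::KernelBuildInfo` (reshape tags reduced to `int` for brevity; not the real class):

```cpp
#include <cassert>
#include <vector>

// Simplified stand-in for kernel::KernelBuildInfo, showing the corrected pairing:
// each predicate now reads its *own* reshape-type list; before this commit the
// two lists were swapped.
struct KernelBuildInfoSketch {
  std::vector<int> input_reshape_type_;   // per-input reshape tags
  std::vector<int> output_reshape_type_;  // per-output reshape tags
  bool IsInputDefaultPadding() const { return input_reshape_type_.empty(); }
  bool IsOutputDefaultPadding() const { return output_reshape_type_.empty(); }
};

int main() {
  KernelBuildInfoSketch info;
  info.output_reshape_type_ = {0, 2};    // only the outputs carry reshape axes
  assert(info.IsInputDefaultPadding());  // no longer affected by output tags
  assert(!info.IsOutputDefaultPadding());
  return 0;
}
```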
@@ -56,6 +56,11 @@ void KernelQuery(const CNodePtr &kernel_node, std::vector<std::shared_ptr<kernel
   TbeMetadataInfo(kernel_node, kernel_info_list);
   if (kernel_info_list->empty()) {
     AicpuMetadataInfo(kernel_node, kernel_info_list);
+    if (!kernel_info_list->empty()) {
+      MS_LOG(INFO) << "Warning The node [" << kernel_node->DebugString()
+                   << "] cannot find valid TBE kernel info, try to get aicpu kernel info";
+      AnfAlgo::SetNodeAttr(kAttrIsAICPUKernel, MakeValue(true), kernel_node);
+    }
   }
   if (kernel_info_list->empty()) {
......
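The branch added above is a simple two-tier query: TBE first, AICPU only as a fallback, tagging the node when the fallback supplied the result. A toy, self-contained rendering of that control flow (the two `Query*` functions are hypothetical stand-ins for `TbeMetadataInfo`/`AicpuMetadataInfo`):

```cpp
#include <string>
#include <vector>

// Hypothetical stand-ins: the primary (TBE) provider finds nothing,
// the fallback (AICPU) provider finds one kernel.
std::vector<std::string> QueryTbe(const std::string &) { return {}; }
std::vector<std::string> QueryAicpu(const std::string &) { return {"Cast"}; }

int main() {
  bool is_aicpu = false;
  auto kernels = QueryTbe("Cast");
  if (kernels.empty()) {
    kernels = QueryAicpu("Cast");
    is_aicpu = !kernels.empty();  // mirrors setting kAttrIsAICPUKernel on the node
  }
  return is_aicpu ? 0 : 1;
}
```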
@@ -31,54 +31,6 @@ namespace mindspore {
 namespace opt {
 using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
 namespace {
-kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format,
-                                                  const AnfNodePtr &node, const TypeId device_type,
-                                                  const kernel::KernelBuildInfo &ori_build_info) {
-  KernelBuildInfoBuilder builder;
-  builder.SetInputsFormat({input_format});
-  builder.SetOutputsFormat({output_format});
-  builder.SetInputsDeviceType({device_type});
-  builder.SetOutputsDeviceType({device_type});
-  builder.SetKernelType(ori_build_info.kernel_type());
-  builder.SetFusionType(ori_build_info.fusion_type());
-  builder.SetProcessor(ori_build_info.processor());
-  return builder.Build();
-}
-
-CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
-                        const bool need_padding, const std::string &op_name) {
-  MS_EXCEPTION_IF_NULL(func_graph);
-  MS_EXCEPTION_IF_NULL(input);
-  std::vector<AnfNodePtr> trans_inputs;
-  auto prim = std::make_shared<Primitive>(op_name);
-  trans_inputs.push_back(NewValueNode(prim));
-  trans_inputs.push_back(input);
-  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
-  MS_EXCEPTION_IF_NULL(trans_node);
-  std::vector<kernel::Axis> padding_axis;
-  padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
-  if (need_padding) {
-    // if need padding we should set the transdata node's shape to the padding shape
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
-                                        trans_node.get());
-  } else {
-    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
-                                        {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
-  }
-  // special handle for ut
-  if (trans_node->kernel_info() == nullptr) {
-    auto kernel_info = std::make_shared<device::KernelInfo>();
-    trans_node->set_kernel_info(kernel_info);
-  }
-  MS_EXCEPTION_IF_NULL(kernel_select);
-  kernel_select->SelectKernel(trans_node);
-  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
-  MS_EXCEPTION_IF_NULL(trans_node);
-  trans_node->set_scope(input->scope());
-  return trans_node;
-}
-
 AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
                              const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
   std::vector<AnfNodePtr> trans_inputs;
@@ -94,6 +46,58 @@ AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &i
   return reshape;
 }
+
+AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                 const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) {
+  AnfNodePtr trans_node = nullptr;
+  AnfNodePtr input_node = node;
+  CNodePtr trans_data = nullptr;
+  std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0);
+  std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT;
+  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
+  std::vector<kernel::Axis> padding_axis = AnfAlgo::GetOutputReshapeType(node, 0);
+  MS_EXCEPTION_IF_NULL(node);
+  // if insert transdata for input we need to change the input
+  if (is_insert_input) {
+    if (!node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
+    }
+    auto cnode = node->cast<CNodePtr>();
+    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
+    dst_format = AnfAlgo::GetInputFormat(cnode, insert_index);
+    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
+    padding_axis = AnfAlgo::GetInputReshapeType(node, 0);
+  }
+  bool need_padding = false;
+  if (is_insert_input) {
+    need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+  } else {
+    need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()));
+  }
+  if (!need_padding) {
+    // don't need padding insert transdata only
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
+    trans_node = trans_data;
+  } else if (is_insert_input) {
+    // if need padding & is input need insert a transdata
+    // reshape[padding shape] -> transdata[padding shape] -> node
+    auto padding_shape =
+      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
+    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
+    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name());
+    trans_node = trans_data;
+  } else {
+    // if need padding & is output need insert a transdata
+    // node -> transdata[padding shape] -> reshape[ori_shape]
+    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name());
+    auto reshape_node =
+      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
+    trans_node = reshape_node;
+  }
+  // refresh the transdata's format to ori format & dst format
+  RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis);
+  return trans_node;
+}
+
 AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index,
                                 const KernelSelectPtr &kernel_select) {
   MS_EXCEPTION_IF_NULL(node);
@@ -111,13 +115,11 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
                       << "when inserting the transdata node " << node->DebugString();
   }
   std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
-  std::string origin_format = kOpFormat_DEFAULT;
   std::string dest_format = AnfAlgo::GetInputFormat(node, index);
   if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
     MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
                   << " To DefaultFormat , index: " << index;
-    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName,
-                                 true);
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
   }
   return input_node;
 }
@@ -131,12 +133,9 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
     MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
                       << node->DebugString();
   }
-  std::string origin_format = output_format;
-  std::string dest_format = kOpFormat_DEFAULT;
   if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
     MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
-    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName,
-                                 false);
+    return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
   }
   return node;
 }
@@ -155,10 +154,8 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
     }
     auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
     std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
-    std::string dest_format = kOpFormat_DEFAULT;
     if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
-      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format,
-                                                           dest_format, kTransDataOpName, false));
+      make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
     } else {
       // No need insert trans op.
       make_tuple_inputs.push_back(tuple_getitem);
@@ -168,62 +165,54 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
   return make_tuple;
 }
 }  // namespace
-AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index,
-                                 const std::string &origin_format, const std::string &dest_format,
-                                 const std::string &op_name, bool is_insert_input) {
-  AnfNodePtr trans_node = nullptr;
-  AnfNodePtr input_node = node;
-  AnfNodePtr trans_data = nullptr;
-  TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0);
-  MS_EXCEPTION_IF_NULL(node);
-  if (origin_format.empty() || dest_format.empty()) {
-    MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format;
-  }
-  // if insert transdata for input we need to change the input
-  if (is_insert_input) {
-    if (!node->isa<CNode>()) {
-      MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode";
-    }
-    auto cnode = node->cast<CNodePtr>();
-    dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index);
-    MS_EXCEPTION_IF_NULL(cnode);
-    input_node = AnfAlgo::GetInputNode(cnode, insert_index);
-  }
-  bool need_padding = false;
-  if (is_insert_input) {
-    need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
-                    op_name == kTransDataOpName);
-  } else {
-    need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
-                    op_name == kTransDataOpName);
-  }
-  if (!need_padding) {
-    // don't need padding insert transdata only
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
-    trans_node = trans_data;
-  } else if (is_insert_input) {
-    // if need padding & is input need insert a transdata
-    // reshape[padding shape] -> transdata[padding shape] -> node
-    auto padding_shape =
-      trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0));
-    auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape);
-    trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name);
-    trans_node = trans_data;
-  } else {
-    // if need padding & is output need insert a transdata
-    // node -> transdata[padding shape] -> reshape[ori_shape]
-    trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
-    auto reshape_node =
-      CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0));
-    trans_node = reshape_node;
-  }
-  // refresh the transdata's format to ori format & dst format
-  MS_EXCEPTION_IF_NULL(trans_data);
-  MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
-  auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
-  auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info);
-  AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get());
-  return trans_node;
-}
+void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
+                            const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type) {
+  MS_EXCEPTION_IF_NULL(trans_data);
+  MS_EXCEPTION_IF_NULL(trans_data->kernel_info());
+  auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info();
+  KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({input_format});
+  builder.SetInputReshapeType({reshape_type});
+  builder.SetOutputReshapeType({reshape_type});
+  builder.SetOutputsFormat({output_format});
+  builder.SetInputsDeviceType({device_type});
+  builder.SetOutputsDeviceType({device_type});
+  builder.SetKernelType(ori_build_info->kernel_type());
+  builder.SetFusionType(ori_build_info->fusion_type());
+  builder.SetProcessor(ori_build_info->processor());
+  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get());
+}
+
+CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
+                        const bool need_padding, const std::string &op_name) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(input);
+  std::vector<AnfNodePtr> trans_inputs;
+  auto prim = std::make_shared<Primitive>(op_name);
+  trans_inputs.push_back(NewValueNode(prim));
+  trans_inputs.push_back(input);
+  CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
+  MS_EXCEPTION_IF_NULL(trans_node);
+  auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
+  if (need_padding) {
+    // if need padding we should set the transdata node's shape to the padding shape
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
+                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
+                                        trans_node.get());
+  } else {
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
+                                        {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
+  }
+  // special handle for ut
+  if (trans_node->kernel_info() == nullptr) {
+    auto kernel_info = std::make_shared<device::KernelInfo>();
+    trans_node->set_kernel_info(kernel_info);
+  }
+  MS_EXCEPTION_IF_NULL(kernel_select);
+  kernel_select->SelectKernel(trans_node);
+  AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), trans_node);
+  MS_EXCEPTION_IF_NULL(trans_node);
+  trans_node->set_scope(input->scope());
+  return trans_node;
+}
......
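The relocated `AddTransOpNodeToGraph` keeps the two padding chains described in its comments: `reshape[padding shape] -> transdata[padding shape] -> node` on the input side and `node -> transdata[padding shape] -> reshape[ori_shape]` on the output side. A toy, self-contained sketch of those chains (the `Node` type and `MakeNode` helper are stand-ins, not MindSpore API):

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Toy node type standing in for AnfNodePtr, just to make the two insertion
// chains concrete.
struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeNode(const std::string &op, std::vector<NodePtr> inputs = {}) {
  return std::make_shared<Node>(Node{op, std::move(inputs)});
}

int main() {
  auto producer = MakeNode("Producer");
  // Input side with padding: reshape to the padded shape first, then TransData.
  auto in_chain = MakeNode("TransData", {MakeNode("Reshape", {producer})});
  // Output side with padding: TransData first, then reshape back to the original shape.
  auto out_chain = MakeNode("Reshape", {MakeNode("TransData", {producer})});
  (void)in_chain;
  (void)out_chain;
  return 0;
}
```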
@@ -58,11 +58,11 @@ class KernelQuery {
   }
 };
 using KernelQueryPtr = std::shared_ptr<KernelQuery>;
-AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
-                                 const KernelSelectPtr &kernel_select, size_t insert_index,
-                                 const std::string &origin_format, const std::string &dest_format,
-                                 const std::string &op_name, bool is_insert_input);
+void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type,
+                            const AnfNodePtr &trans_data, const std::vector<kernel::Axis> &reshape_type = {});
+CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select,
+                        const bool need_padding, const std::string &op_name);
 AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                                 const TypeId &input_type, const TypeId &output_type,
......
@@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP
   // insert trans
   if (origin_format != cur_format && cur_shape.size() > 1) {
     auto kernel_select = std::make_shared<KernelSelect>();
-    final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format,
-                                       kTransDataOpName, false);
+    final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node);
     final_index = 0;
     MS_EXCEPTION_IF_NULL(final_node);
     MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString();
......
@@ -67,22 +67,30 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n
   // if output_format=default transdata need split transdata->transpose else transpose->transdata
   if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) {
     // trans input_format to hwcn
-    new_transdata_node =
-      AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true);
+    new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
+                                        false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
+                           new_transdata_node);
     // trans hwcn to default_format
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN,
-                                               output_format, prim::kPrimTranspose->name(), false);
+    new_transpose_node =
+      NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name());
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
+                           new_transpose_node);
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{3, 2, 0, 1}), new_transpose_node);
     new_replace_node = new_transpose_node;
   } else {
     // trans default to hwcn
-    new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN,
-                                               prim::kPrimTranspose->name(), true);
+    new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast<CNodePtr>(), 0), kernel_select_,
+                                        false, prim::kPrimTranspose->name());
     AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector<int>{2, 3, 1, 0}), new_transpose_node);
+    RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0),
+                           new_transpose_node);
     // trans hwcn to output_format
-    new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN,
-                                               output_format, kTransDataOpName, false);
+    new_transdata_node =
+      NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name());
+    RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0),
+                           new_transdata_node);
     new_replace_node = new_transdata_node;
   }
   FuncGraphManagerPtr manager = func_graph->manager();
......
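The split pass pairs each `TransData` with a `Transpose` whose `kAttrPerm` converts between HWCN and the default NCHW layout. A small standalone check of the two permutations used above, assuming the usual transpose convention that output dimension `i` comes from input dimension `perm[i]`:

```cpp
#include <array>
#include <cassert>
#include <string>

// Applies a transpose permutation to a 4-character layout string:
// output dim i = input dim perm[i]. Pure illustration, independent of MindSpore.
std::string ApplyPerm(const std::string &layout, const std::array<int, 4> &perm) {
  std::string out(4, ' ');
  for (int i = 0; i < 4; ++i) out[i] = layout[perm[i]];
  return out;
}

int main() {
  assert(ApplyPerm("HWCN", {3, 2, 0, 1}) == "NCHW");  // transdata -> transpose branch
  assert(ApplyPerm("NCHW", {2, 3, 1, 0}) == "HWCN");  // transpose -> transdata branch
  return 0;
}
```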
@@ -196,10 +196,10 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
   }
   if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
     kernel_info->SetFeatureMapFlag(true);
-    AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(true), cnode);
+  }
+  if (AnfAlgo::IsRealCNodeKernel(cnode)) {
+    AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
     AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
-  } else {
-    AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(false), cnode);
   }
   cnode->set_kernel_info(kernel_info);
   AnfAlgo::SetGraphId(graph_id_, cnode.get());
......
@@ -151,7 +151,7 @@ constexpr auto kSquareSumAllOpName = "SquareSumAll";
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
-constexpr auto kAttrIsAICPUKernel = "is_ai_cpu_kernel";
+constexpr auto kAttrIsAICPUKernel = "is_AICPU_kernel";
 constexpr auto kIsBackendCast = "is_backed_cast";
 constexpr auto kAttrOutputNames = "output_names";
 constexpr auto kAttrVisited = "visited";
......