提交 5d25bf7c 编写于 作者: W WilliamLian

add more transform format insert transdata

上级 b9e59f9d
......@@ -70,7 +70,7 @@ string GetPriorityMatchFormat(const CNodePtr &cnode) {
for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cnode); ++index) {
auto pre_output_format = AnfAlgo::GetPrevNodeOutputFormat(cnode, index);
if (AnfAlgo::IsFeatureMapInput(cnode, index) &&
kNeedTransFormatSet.find(pre_output_format) != kNeedTransFormatSet.end()) {
kHWSpecialFormatSet.find(pre_output_format) != kHWSpecialFormatSet.end()) {
priority_matched_format = !is_init ? pre_output_format : priority_matched_format;
is_init = true;
}
......
......@@ -31,6 +31,7 @@ namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
namespace {
const std::set<std::string> kCommonFormatSet = {kOpFormat_DEFAULT, kOpFormat_ND, kOpFormat_NCHW};
AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node,
const KernelSelectPtr &kernel_select, const std::vector<size_t> &dst_shape) {
std::vector<AnfNodePtr> trans_inputs;
......@@ -110,13 +111,9 @@ AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &
MS_EXCEPTION_IF_NULL(input_node);
AnfAlgo::SetNodeInput(node, input_node, index);
}
if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) {
MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index)
<< "when inserting the transdata node " << node->DebugString();
}
std::vector<size_t> origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index);
std::string dest_format = AnfAlgo::GetInputFormat(node, index);
if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(dest_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index)
<< " To DefaultFormat , index: " << index;
return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true);
......@@ -133,7 +130,7 @@ AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const An
MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node "
<< node->DebugString();
}
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0";
return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false);
}
......@@ -154,7 +151,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
}
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) {
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
} else {
// No need insert trans op.
......
......@@ -97,7 +97,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
std::string convert_format;
for (const auto &do_mask : do_mask_node_list) {
auto do_mask_data_format = AnfAlgo::GetInputFormat(do_mask, 0);
if (special_format.empty() && kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end()) {
if (special_format.empty() && kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end()) {
special_format = do_mask_data_format;
}
if (format_counter.find(do_mask_data_format) == format_counter.end()) {
......@@ -111,7 +111,7 @@ void RectifyDoMaskKernelInfo::RectifyKernelInfo(const std::vector<CNodePtr> &do_
convert_format = kOpFormat_DEFAULT;
break;
}
if (kNeedTransFormatSet.find(do_mask_data_format) != kNeedTransFormatSet.end() &&
if (kHWSpecialFormatSet.find(do_mask_data_format) != kHWSpecialFormatSet.end() &&
special_format != do_mask_data_format) {
convert_format = kOpFormat_DEFAULT;
break;
......@@ -133,7 +133,7 @@ std::string RectifyDoMaskKernelInfo::GetConvertFormat(const std::map<std::string
if (counter < iter.second) {
convert_format = iter.first;
}
if (counter == iter.second && kNeedTransFormatSet.find(convert_format) == kNeedTransFormatSet.end()) {
if (counter == iter.second && kHWSpecialFormatSet.find(convert_format) == kHWSpecialFormatSet.end()) {
convert_format = iter.first;
}
}
......
......@@ -265,7 +265,7 @@ const std::set<std::string> kOptOperatorSet = {
kApplyRMSPropOpName,
};
const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
const std::set<std::string> kHWSpecialFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0, kOpFormat_NC1HWC0_C04,
kOpFormat_FRACTAL_Z_C04};
......
......@@ -58,6 +58,8 @@ trans_data_op_info = TBERegOp("TransData") \
.dtype_format(DataType.F32_HWCN, DataType.F32_FracZ) \
.dtype_format(DataType.F32_HWCN, DataType.F32_C1HWNCoC0) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_HWCN) \
.dtype_format(DataType.F32_Default, DataType.F32_NCHW) \
.dtype_format(DataType.F32_HWCN, DataType.F32_Default) \
.get_op_info()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册