diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
index c96f8c7813310225545e70b2a35731e16e3df6b1..9aa57849661c56ac1c0fe0267bb78480734fc7dc 100644
--- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
+++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc
@@ -321,9 +321,11 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
     MS_LOG(INFO) << "Dynamic select format response successful, use dynamic format.";
     for (size_t i = 0; i < inputs_static.size(); i++) {
       inputs_dyn[i]->set_param_type(inputs_static[i]->param_type());
+      inputs_dyn[i]->set_reshape_type(inputs_static[i]->reshape_type());
     }
     for (size_t j = 0; j < outputs_static.size(); j++) {
       outputs_dyn[j]->set_param_type(outputs_static[j]->param_type());
+      outputs_dyn[j]->set_reshape_type(outputs_static[j]->reshape_type());
     }
     op_info_new_ptr->set_inputs_ptr(inputs_dyn);
     op_info_new_ptr->set_outputs_ptr(outputs_dyn);
@@ -335,6 +337,29 @@ void ReplaceByDynamicFormatDtype(const CNodePtr &kernel_node, const std::shared_
   op_info_new_ptr->set_fusion_type(op_info_ptr->fusion_type());
 }
 
+bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *reshape_type_vec) {
+  for (const auto &c : reshape_type_str) {
+    switch (c) {
+      case 'N':
+        reshape_type_vec->push_back(kernel::N);
+        break;
+      case 'C':
+        reshape_type_vec->push_back(kernel::C);
+        break;
+      case 'H':
+        reshape_type_vec->push_back(kernel::H);
+        break;
+      case 'W':
+        reshape_type_vec->push_back(kernel::W);
+        break;
+      default:
+        MS_LOG(ERROR) << "Unknown axis " << c << " in reshape type.";
+        return false;
+    }
+  }
+  return true;
+}
+
 bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inputs, size_t real_input_num,
                                size_t builder_idex, const std::vector<int> &dyn_input_sizes,
                                const std::shared_ptr<KernelBuildInfo::KernelBuildInfoBuilder> &builder) {
@@ -347,6 +372,7 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
   MS_EXCEPTION_IF_NULL(inputs[0]);
   size_t kernel_info_cnt = inputs[0]->dtypes().size();
 
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &input : inputs) {
     MS_EXCEPTION_IF_NULL(input);
     std::string param_type = input->param_type();
@@ -384,8 +410,14 @@ bool SetKernelBuilderInputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &inp
         inputs_format.push_back(formats[builder_idex]);
       }
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(input->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
 
+  builder->SetInputReshapeType(reshape_types);
   builder->SetInputsDeviceType(inputs_device_type);
   builder->SetInputsFormat(inputs_format);
   return true;
@@ -403,6 +435,7 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
   MS_EXCEPTION_IF_NULL(outputs[0]);
   size_t kernel_info_cnt = outputs[0]->dtypes().size();
 
+  std::vector<std::vector<Axis>> reshape_types;
   for (const auto &output : outputs) {
     MS_EXCEPTION_IF_NULL(output);
     if (output_idx >= real_output_num) {
@@ -436,8 +469,14 @@ bool SetKernelBuilderOutputInfo(const std::vector<std::shared_ptr<OpIOInfo>> &ou
       outputs_format.push_back(formats[builder_idex]);
       output_idx++;
     }
+    std::vector<Axis> reshape_type;
+    if (!StringToAxisVector(output->reshape_type(), &reshape_type)) {
+      return false;
+    }
+    reshape_types.push_back(reshape_type);
   }
 
+  builder->SetOutputReshapeType(reshape_types);
   builder->SetOutputsFormat(outputs_format);
   builder->SetOutputsDeviceType(outputs_device_type);
   return true;
@@ -515,7 +554,7 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   const std::set<std::string> kOpFormatList = {kOpFormat_DEFAULT, kOpFormat_NC1KHKWHWC0, kOpFormat_ND,
                                                kOpFormat_NCHW,    kOpFormat_NHWC,        kOpFormat_HWCN,
                                                kOpFormat_NC1HWC0, kOpFormat_FRAC_Z,      kOpFormat_C1HWNCoC0,
-                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04};
+                                               kOpFormat_FRAC_NZ, kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};
   // if format is default, it remarkes support all format
   if (kOpFormatList.find(format) == kOpFormatList.end()) {
@@ -528,13 +567,13 @@ bool IsShapeMatchFormat(const std::vector<size_t> &shape, const std::string &for
   if (shape.empty()) {
     return true;
   }
-  if (shape.size() > kShapeSupportFormatMap.size()) {
+  if (shape.size() > kShape4dDims) {
     return false;
   }
-  if (format == kOpFormat_FRAC_NZ && shape.size() >= 2) {
-    return true;
+  if (format == kOpFormat_FRAC_NZ && shape.size() < 2) {
+    return false;
   }
-  return !(kShapeSupportFormatMap[shape.size() - 1].find(format) == kShapeSupportFormatMap[shape.size() - 1].end());
+  return true;
 }
 
 bool IsValidKernelInfo(const std::shared_ptr<CNode> &kernel_node, const kernel::KernelBuildInfo &kernel_build_info) {
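For illustration, the StringToAxisVector added above can be exercised on its own. The sketch below is a minimal standalone approximation, assuming a simplified local Axis enum and std::cerr in place of kernel::Axis and MS_LOG(ERROR); it mirrors the added code but is not the build-system code itself:

    // Standalone sketch: map a reshape_type string such as "NC" to axis ids.
    #include <iostream>
    #include <string>
    #include <vector>

    enum Axis { N, C, H, W };  // stand-in for the kernel::Axis enum

    bool StringToAxisVector(const std::string &reshape_type_str, std::vector<Axis> *axes) {
      for (const auto &c : reshape_type_str) {
        switch (c) {
          case 'N': axes->push_back(N); break;
          case 'C': axes->push_back(C); break;
          case 'H': axes->push_back(H); break;
          case 'W': axes->push_back(W); break;
          default:
            std::cerr << "Unknown axis " << c << " in reshape type." << std::endl;
            return false;
        }
      }
      return true;
    }

    int main() {
      std::vector<Axis> axes;
      // "NC" is the reshape_type used by the BNTraining* registrations below.
      if (StringToAxisVector("NC", &axes)) {
        std::cout << "parsed " << axes.size() << " axes" << std::endl;  // parsed 2 axes
      }
      return 0;
    }

As the hunks above show, a parsing failure simply propagates as a false return from SetKernelBuilderInputInfo/SetKernelBuilderOutputInfo.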
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
index fbb3e345dfa5d361eb56061504b54a5d535acf1b..9fdb2080b392bf816b7d837743608081e4caf3ed 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc
@@ -55,12 +55,17 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input,
   trans_inputs.push_back(input);
   CNodePtr trans_node = func_graph->NewCNode(trans_inputs);
   MS_EXCEPTION_IF_NULL(trans_node);
+  std::vector<kernel::Axis> padding_axis;
+  if (AnfAlgo::IsRealKernel(input)) {
+    padding_axis = AnfAlgo::GetOutputReshapeType(input, 0);
+  } else {
+    padding_axis = AnfAlgo::GetPrevNodeOutputReshapeType(input, 0);
+  }
   if (need_padding) {
     // if need padding we should set the transdata node's shape to the padding shape
-    AnfAlgo::SetOutputInferTypeAndShape(
-      {AnfAlgo::GetOutputInferDataType(input, 0)},
-      {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), AnfAlgo::GetOutputReshapeType(input, 0))},
-      trans_node.get());
+    AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
+                                        {trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input, 0), padding_axis)},
+                                        trans_node.get());
   } else {
     AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)},
                                         {AnfAlgo::GetOutputInferShape(input, 0)}, trans_node.get());
@@ -194,8 +199,14 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
     MS_EXCEPTION_IF_NULL(cnode);
     input_node = AnfAlgo::GetInputNode(cnode, insert_index);
   }
-  bool need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
-                       op_name == kTransDataOpName);
+  bool need_padding = false;
+  if (is_insert_input) {
+    need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  } else {
+    need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) &&
+                    op_name == kTransDataOpName);
+  }
   if (!need_padding) {
     // don't need padding insert transdata only
     trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name);
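The padding_axis chosen above feeds trans::PaddingShapeTo4d when the TransData node needs a padded shape. The sketch below is an illustrative approximation of that padding step, not the library routine: the assumption is that each input dimension lands at the NCHW position named by its reshape-type axis and the remaining positions default to 1, with an assumed in-order fallback when no usable axis hint exists:

    // Illustrative approximation of trans::PaddingShapeTo4d (not the library code).
    #include <cstddef>
    #include <iostream>
    #include <vector>

    enum Axis { N, C, H, W };  // stand-in for kernel::Axis; values index NCHW

    std::vector<size_t> PaddingShapeTo4d(const std::vector<size_t> &shape,
                                         const std::vector<Axis> &padding_axis) {
      std::vector<size_t> shape_4d(4, 1);
      if (padding_axis.size() != shape.size()) {
        // assumed fallback: copy dimensions in order when no axis hint is available
        for (size_t i = 0; i < shape.size(); ++i) {
          shape_4d[i] = shape[i];
        }
      } else {
        // place shape[i] at the axis named by padding_axis[i]
        for (size_t i = 0; i < shape.size(); ++i) {
          shape_4d[padding_axis[i]] = shape[i];
        }
      }
      return shape_4d;
    }

    int main() {
      // A 2-D (N, C) tensor with reshape_type "NC" pads to (N, C, 1, 1).
      for (size_t dim : PaddingShapeTo4d({32, 16}, {N, C})) {
        std::cout << dim << ' ';  // prints: 32 16 1 1
      }
      std::cout << std::endl;
      return 0;
    }

This also motivates the IsRealKernel branch above: a virtual node carries no kernel build info of its own, so the axis hint has to be fetched from the producing kernel via the new GetPrevNodeOutputReshapeType.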
diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc
index cb8670dd00046270ec0423692e61070ab3d0aecf..270b02cb00e59ce102028589fe07e42f37e58c5c 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/batch_norm_grad_split.cc
@@ -86,7 +86,6 @@ void CreateOutputsOfReduceGrad(const FuncGraphPtr &graph, const CNodePtr &bn_gra
   AnfAlgo::CopyNodeAttr(kAttrEpsilon, bn_grad_node, bn_reduce_grad);
   (*bn_reduce_grad_outputs).push_back(bn_reduce_grad);
 }
-
 }  // namespace
 const BaseRef BatchNormGradSplit::DefinePattern() const {
   VarPtr Xs = std::make_shared<SeqVar>();
diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc
index 62334880899d9fb85a2ca4f68595cc41033a4510..5cc3374ea5eabc80e9467f6e1a0b2bec8f776211 100644
--- a/mindspore/ccsrc/pre_activate/common/helper.cc
+++ b/mindspore/ccsrc/pre_activate/common/helper.cc
@@ -344,7 +344,7 @@ bool IsNopNode(const AnfNodePtr &node) {
   return true;
 }
 
-bool IsAllNopNode(session::KernelGraph *const graph) {
+bool IsAllNopNode(const session::KernelGraph *const graph) {
   MS_EXCEPTION_IF_NULL(graph);
   auto execution_order = graph->execution_order();
   for (auto &cnode : execution_order) {
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index 45588052b09de8ad667e66e0c2306ad52ad0b686..2853f6760fe3da88c8240a8612745d07c70fce21 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -347,6 +347,11 @@ std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_n
   return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
 }
 
+std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
+  KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
+  return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
+}
+
 std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
   MS_EXCEPTION_IF_NULL(node);
   abstract::BaseShapePtr base_shape = node->Shape();
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index a70a63b6786b12c64320767f497d56520e6e18b2..78ebf3121020ae1a422a400d46114a3c8e69390d 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -95,6 +95,8 @@ class AnfRuntimeAlgorithm {
   static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
   // get output format from prev node,input_index is the input index of current node related to prev node
   static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
+  // get reshape_type from the output of the input node.
+  static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
   // get output shapes inferred by ME from input nodes.
   static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
   // get input shapes inferred by ME from input nodes.
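The new GetPrevNodeOutputReshapeType mirrors GetPrevNodeOutputFormat directly above it: first resolve which (producer node, output index) pair feeds the given input, then reuse the existing per-output accessor. A minimal sketch of that two-step delegation, with hypothetical stand-in types for AnfNodePtr and KernelWithIndex:

    #include <cstddef>
    #include <iostream>
    #include <utility>
    #include <vector>

    enum Axis { N, C, H, W };  // stand-in for kernel::Axis

    // Hypothetical stand-in for a graph node, just to show the pattern.
    struct Node {
      std::vector<std::pair<Node *, size_t>> inputs;        // producer node + its output index
      std::vector<std::vector<Axis>> output_reshape_types;  // reshape type per output
    };

    std::vector<Axis> GetOutputReshapeType(const Node *node, size_t output_idx) {
      return node->output_reshape_types.at(output_idx);
    }

    // Resolve input_idx to the producing (node, output) pair first, then reuse
    // the per-output query -- the same shape as the new AnfRuntimeAlgorithm method.
    std::vector<Axis> GetPrevNodeOutputReshapeType(const Node *node, size_t input_idx) {
      auto kernel_with_index = node->inputs.at(input_idx);
      return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
    }

    int main() {
      Node producer{{}, {{N, C}}};            // one output with reshape type "NC"
      Node consumer{{{&producer, 0}}, {}};    // consumes producer's output 0
      std::cout << GetPrevNodeOutputReshapeType(&consumer, 0).size() << std::endl;  // 2
      return 0;
    }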
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index b85bc1b0f84a25408493cb48ed727eace75fc8aa..b1d71c18ac256744080a18a5460cad6754e10110 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -204,6 +204,7 @@ constexpr auto kOpFormat_FRAC_Z = "FracZ";
 constexpr auto kOpFormat_FRAC_NZ = "FRACTAL_NZ";
 constexpr auto kOpFormat_C1HWNCoC0 = "C1HWNCoC0";
 constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
+constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";
 
 const std::set<std::string> k1DSupportFormat = {kOpFormat_DEFAULT, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_Z,
                                                 kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_C1HWNCoC0};
@@ -225,8 +226,9 @@ const std::set<std::string> kOptOperatorSet = {
   kApplyRMSPropOpName,
 };
 
-const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
-                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0};
+const std::set<std::string> kNeedTransFormatSet = {kOpFormat_FRAC_Z,  kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0,
+                                                   kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0,   kOpFormat_NC1HWC0_C04,
+                                                   kOpFormat_FRACTAL_Z_C04};
 
 static inline void ChangeFileMode(const std::string &file_name, mode_t mode) {
   if (access(file_name.c_str(), F_OK) != 0) {
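Adding the two C04 formats to kNeedTransFormatSet means edges carrying these device layouts are treated as requiring a format transformation rather than as default-compatible. A small sketch of the membership check (the constants are stand-ins and NeedTransFormat is a hypothetical helper name, not an API from the codebase):

    #include <iostream>
    #include <set>
    #include <string>

    // Stand-ins for the constants declared in utils.h.
    constexpr auto kOpFormat_NCHW = "NCHW";
    constexpr auto kOpFormat_NC1HWC0_C04 = "NC1HWC0_C04";
    constexpr auto kOpFormat_FRACTAL_Z_C04 = "FRACTAL_Z_C04";

    const std::set<std::string> kNeedTransFormatSet = {kOpFormat_NC1HWC0_C04, kOpFormat_FRACTAL_Z_C04};

    // Hypothetical helper: formats in the set need an explicit transformation,
    // formats outside it can reuse the host-side layout as-is.
    bool NeedTransFormat(const std::string &device_format) {
      return kNeedTransFormatSet.count(device_format) != 0;
    }

    int main() {
      std::cout << std::boolalpha << NeedTransFormat(kOpFormat_NCHW) << '\n';           // false
      std::cout << std::boolalpha << NeedTransFormat(kOpFormat_FRACTAL_Z_C04) << '\n';  // true
      return 0;
    }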
diff --git a/mindspore/ops/_op_impl/tbe/bn_training_reduce.py b/mindspore/ops/_op_impl/tbe/bn_training_reduce.py
index 16d75d06be6d674014c6e415028c9ee98eacf06c..e19d4b65ffd83a7bb6f0d09fc0aa8a251055c6e7 100644
--- a/mindspore/ops/_op_impl/tbe/bn_training_reduce.py
+++ b/mindspore/ops/_op_impl/tbe/bn_training_reduce.py
@@ -23,7 +23,7 @@ bn_training_reduce_op_info = TBERegOp("BNTrainingReduce") \
     .compute_cost(10) \
     .kernel_name("bn_training_reduce") \
     .partial_flag(True) \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .output(0, "sum", False, "required", "all") \
     .output(1, "square_sum", False, "required", "all") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD) \
diff --git a/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py b/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
index e92054670d5bc708e0489122882e9e929d8820e3..66dc55ab105a9c49cbf039d8789a4cb0e9034ded 100644
--- a/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
+++ b/mindspore/ops/_op_impl/tbe/bn_training_reduce_grad.py
@@ -24,14 +24,14 @@ bn_training_reduce_grad_op_info = TBERegOp("BNTrainingReduceGrad") \
     .kernel_name("bn_training_reduce_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x_norm", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x_norm", False, "required", "all", reshape_type="NC") \
     .input(2, "diff_scale", False, "required", "all") \
     .input(3, "diff_offset", False, "required", "all") \
     .input(4, "scale", False, "required", "all") \
     .input(5, "batch_mean", False, "required", "all") \
     .input(6, "batch_variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .dtype_format(DataType.F16_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
diff --git a/mindspore/ops/_op_impl/tbe/bn_training_update.py b/mindspore/ops/_op_impl/tbe/bn_training_update.py
index 49b572e31ee8ca683bb755c93eaef5ff03c42018..c79c09be4236d4dfa3f445f87128674b462a5c93 100644
--- a/mindspore/ops/_op_impl/tbe/bn_training_update.py
+++ b/mindspore/ops/_op_impl/tbe/bn_training_update.py
@@ -26,14 +26,14 @@ bn_training_update_op_info = TBERegOp("BNTrainingUpdate") \
     .attr("factor", "optional", "float", "all") \
     .attr("epsilon", "optional", "float", "all") \
     .attr("isRef", "optional", "bool", "all", "true") \
-    .input(0, "x", False, "required", "all") \
+    .input(0, "x", False, "required", "all", reshape_type="NC") \
     .input(1, "sum", False, "required", "all") \
     .input(2, "square_sum", False, "required", "all") \
     .input(3, "scale", False, "required", "all") \
     .input(4, "offset", False, "required", "all") \
     .input(5, "mean", False, "required", "all") \
     .input(6, "variance", False, "required", "all") \
-    .output(0, "y", False, "required", "all") \
+    .output(0, "y", False, "required", "all", reshape_type="NC") \
     .output(1, "mean", False, "required", "all") \
     .output(2, "variance", False, "required", "all") \
     .output(3, "batch_mean", False, "required", "all") \
diff --git a/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py b/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
index 5e693bea42ded764103e7390ad3e107ff424b965..5098923281595765b57082f6ce9f4d6692cb324a 100644
--- a/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
+++ b/mindspore/ops/_op_impl/tbe/bn_training_update_grad.py
@@ -24,8 +24,8 @@ bn_training_update_grad_op_info = TBERegOp("BNTrainingUpdateGrad") \
     .kernel_name("bn_training_update_grad") \
     .partial_flag(True) \
     .attr("epsilon", "optional", "float", "all") \
-    .input(0, "grads", False, "required", "all") \
-    .input(1, "x", False, "required", "all") \
+    .input(0, "grads", False, "required", "all", reshape_type="NC") \
+    .input(1, "x", False, "required", "all", reshape_type="NC") \
     .input(2, "batch_mean", False, "required", "all") \
     .input(3, "batch_variance", False, "required", "all") \
     .output(0, "diff_scale", False, "required", "all") \