diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index ddd698f6e022f0fed9a16a8ffe0b3002d63951f0..dcb8e606ae922103abff5db6872763292fa401ff 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -145,6 +145,8 @@ const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
 const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
 const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
 const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
+const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
+const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
 
 // Maths
 const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 1125feee7dc976f4d6f5d6bc9da31d69fe697417..8a60055298b02785d84baff3fd767ff05cc3206c 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -151,6 +151,8 @@ extern const PrimitivePtr kPrimReshape;
 extern const PrimitivePtr kPrimTile;
 extern const PrimitivePtr kPrimAddN;
 extern const PrimitivePtr KPrimTransData;
+extern const PrimitivePtr kPrimNMSWithMask;
+extern const PrimitivePtr kPrimPad;
 
 // Maths
 extern const PrimitivePtr kPrimTensorAdd;
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index f6feb0440f459e5f5614d3d440a735f453b21425..24bf90dd69fb084a9237c04b32a262fe9631f161 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -78,6 +78,7 @@
 #include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
 #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "pre_activate/ascend/enhancer/add_memcpy_async.h"
+#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
 #include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
 #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
 #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
@@ -227,6 +228,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
     ir_fusion_pm->AddPass(std::make_shared<GetnextMemcpyElimination>());
   }
   ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
+  ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
   if (context_ptr->ir_fusion_flag()) {
     AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
   }
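Note: the pass registered above rewrites every `NMSWithMask` node so its box input is padded from `(N, 5)` to `(N, 8)` before the Ascend kernel runs. A minimal NumPy sketch of the transform the inserted `Pad` node performs (illustration only; the real pass builds a `Pad` CNode in the graph with attribute `paddings = {{0, 0}, {0, 3}}`):

```python
import numpy as np

# Equivalent of the Pad node this pass inserts: paddings [[0, 0], [0, 3]]
# appends three zero columns, turning (N, 5) boxes (four coordinates plus a
# score) into the (N, 8) layout the Ascend NMSWithMask kernel expects.
bboxes = np.array([[10.0, 10.0, 50.0, 50.0, 0.9],
                   [20.0, 20.0, 60.0, 60.0, 0.8]], dtype=np.float32)
padded = np.pad(bboxes, ((0, 0), (0, 3)), mode="constant", constant_values=0.0)
print(padded.shape)  # (2, 8)
```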
diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc
new file mode 100644
index 0000000000000000000000000000000000000000..20a10e7d22c1667982528b8a610e86941375818d
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.cc
@@ -0,0 +1,109 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
+#include <memory>
+#include <vector>
+#include <string>
+#include "pre_activate/ascend/ascend_helper.h"
+#include "pre_activate/common/helper.h"
+#include "session/anf_runtime_algorithm.h"
+#include "utils/utils.h"
+#include "device/kernel_info.h"
+#include "kernel/oplib/oplib.h"
+#include "operator/ops.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
+  VarPtr Xs = std::make_shared<SeqVar>();
+  return VectorRef({prim::kPrimNMSWithMask, Xs});
+}
+
+AnfNodePtr InsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
+                            const TypeId &input_type, const TypeId &output_type, const TypeId &origin_type,
+                            const std::vector<size_t> &origin_shape) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  std::vector<AnfNodePtr> new_pad_inputs;
+  auto prim = std::make_shared<Primitive>(prim::kPrimPad->name());
+  new_pad_inputs.push_back(NewValueNode(prim));
+  new_pad_inputs.push_back(input);
+  CNodePtr pad = func_graph->NewCNode(new_pad_inputs);
+  MS_EXCEPTION_IF_NULL(pad);
+  // set kernel build info
+  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
+  builder.SetInputsFormat({format});
+  builder.SetOutputsFormat({format});
+  builder.SetInputsDeviceType({input_type});
+  builder.SetOutputsDeviceType({output_type});
+  builder.SetFusionType(kernel::FusionType::OPAQUE);
+  builder.SetProcessor(kernel::Processor::AICORE);
+  if (kernel::OpLib::FindOp(prim::kPrimPad->name(), kernel::kTBE) != nullptr) {
+    builder.SetKernelType(KernelType::TBE_KERNEL);
+  } else {
+    builder.SetKernelType(KernelType::AICPU_KERNEL);
+  }
+
+  if (pad->kernel_info() == nullptr) {
+    auto kernel_info = std::make_shared<device::KernelInfo>();
+    pad->set_kernel_info(kernel_info);
+  }
+  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), pad.get());
+  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
+  return pad;
+}
+
+const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                  const EquivPtr &) const {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  MS_EXCEPTION_IF_NULL(node);
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+
+  size_t input_num = AnfAlgo::GetInputTensorNum(node);
+  if (input_num == 0) {
+    return nullptr;
+  }
+  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
+  for (size_t input_idx = 0; input_idx < AnfAlgo::GetInputTensorNum(cnode); input_idx++) {
+    auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
+    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
+    auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, input_idx);
+    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);
+    if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) {
+      return nullptr;
+    }
+    origin_shape[1] = 8;
+    auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_idx);
+    auto pad = InsertPadToGraph(func_graph, cur_input, format, origin_type, device_type, origin_type, origin_shape);
+    MS_EXCEPTION_IF_NULL(pad);
+    pad->set_scope(cnode->scope());
+    AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector<std::vector<int>>{{0, 0}, {0, 3}}), pad);
+    new_inputs.push_back(pad);
+  }
+  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
+  CNodePtr new_node = nullptr;
+  if (kernel_graph == nullptr) {
+    new_node = std::make_shared<CNode>(*cnode);
+  } else {
+    new_node = kernel_graph->NewCNode(cnode);
+  }
+  MS_EXCEPTION_IF_NULL(new_node);
+  new_node->set_inputs(new_inputs);
+  return new_node;
+}
}  // namespace opt
}  // namespace mindspore
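With the backend pass in place, the user-facing contract (see the math_ops.py change below) is a single `(N, 5)` input on every backend. A hedged usage sketch against the 2020-era `mindspore.ops.operations` API; actually executing it requires an Ascend target, and output values depend on the kernel:

```python
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# NMSWithMask now always takes (N, 5) boxes; the (N, 8) padding is an
# internal backend detail handled by InsertPadForNMSWithMask.
nms = P.NMSWithMask(iou_threshold=0.5)
bboxes = Tensor(np.array([[10, 10, 50, 50, 0.9],
                          [15, 15, 55, 55, 0.8]], np.float32))
output_boxes, output_idx, selected_mask = nms(bboxes)  # shapes: (2, 5), (2,), (2,)
```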
diff --git a/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h
new file mode 100644
index 0000000000000000000000000000000000000000..bfc201ed1164df5d2a498ca48c807034f900f47d
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
+
+#include "pre_activate/common/optimizer.h"
+#include "pre_activate/common/pass.h"
+
+namespace mindspore {
+namespace opt {
+class InsertPadForNMSWithMask : public PatternProcessPass {
+ public:
+  explicit InsertPadForNMSWithMask(bool multigraph = true)
+      : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {}
+  ~InsertPadForNMSWithMask() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+};
+}  // namespace opt
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index e829a49d793ceff19827279031fe78a779a9538d..ae420cfaec5e9841196df9368db9f5453d4e872c 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -148,6 +148,7 @@ constexpr auto kReturnOpName = "return";
 constexpr auto kLarsV2OpName = "LarsV2";
 constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
 constexpr auto kSquareSumAllOpName = "SquareSumAll";
+constexpr auto kNMSWithMaskOpName = "NMSWithMask";
 
 // attr key name
 constexpr auto kAttrInputNames = "input_names";
diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py
index ab64a21498e47a616abb3dac8693fc677c67ce28..69bdd268658735c5c8c9dc77ae6bc3b6ed25d45c 100644
--- a/mindspore/ops/operations/math_ops.py
+++ b/mindspore/ops/operations/math_ops.py
@@ -2021,11 +2021,6 @@ class NMSWithMask(PrimitiveWithInfer):
         cls_name = self.name
         validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
         validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
-        if not self.is_ge:
-            validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
-            num = bboxes_shape[0]
-            return ((num, 5), (num,), (num,))
-
         validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
         num = bboxes_shape[0]
         return (bboxes_shape, (num,), (num,))
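The removed `is_ge` branch means shape inference no longer depends on the backend: the op now validates a `(N, 5)` input everywhere and returns `(bboxes_shape, (num,), (num,))`. A standalone sketch of the unified rule (hypothetical helper mirroring the simplified `infer_shape`, not MindSpore API):

```python
def nms_with_mask_output_shapes(bboxes_shape):
    """(N, 5) in -> ((N, 5), (N,), (N,)) out, matching the unified infer_shape."""
    assert len(bboxes_shape) == 2 and bboxes_shape[1] == 5
    num = bboxes_shape[0]
    return (tuple(bboxes_shape), (num,), (num,))

print(nms_with_mask_output_shapes((100, 5)))  # ((100, 5), (100,), (100,))
```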