Commit 0e7e82d1 authored by mindspore-ci-bot, committed via Gitee

!1397 add pad for nms_with_mask

Merge pull request !1397 from liubuyu/master
......@@ -145,6 +145,8 @@ const PrimitivePtr kPrimReshape = std::make_shared<Primitive>("Reshape");
const PrimitivePtr kPrimTile = std::make_shared<Primitive>("Tile");
const PrimitivePtr kPrimAddN = std::make_shared<Primitive>("AddN");
const PrimitivePtr KPrimTransData = std::make_shared<Primitive>("TransData");
const PrimitivePtr kPrimNMSWithMask = std::make_shared<Primitive>("NMSWithMask");
const PrimitivePtr kPrimPad = std::make_shared<Primitive>("Pad");
// Maths
const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
......
......@@ -151,6 +151,8 @@ extern const PrimitivePtr kPrimReshape;
extern const PrimitivePtr kPrimTile;
extern const PrimitivePtr kPrimAddN;
extern const PrimitivePtr KPrimTransData;
extern const PrimitivePtr kPrimNMSWithMask;
extern const PrimitivePtr kPrimPad;
// Maths
extern const PrimitivePtr kPrimTensorAdd;
......
......@@ -78,6 +78,7 @@
#include "pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h"
#include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
#include "pre_activate/ascend/enhancer/add_memcpy_async.h"
#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
#include "pre_activate/ascend/format_type/insert_transdata_for_runop.h"
#include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h"
......@@ -227,6 +228,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGrap
ir_fusion_pm->AddPass(std::make_shared<FusedBatchNormMixPrecisionFusion>());
}
ir_fusion_pm->AddPass(std::make_shared<AddMemcpyAsync>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
if (context_ptr->ir_fusion_flag()) {
AddAscendBackendOptionalIRFusion(ir_fusion_pm.get());
}
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pre_activate/ascend/enhancer/insert_pad_for_nms_with_mask.h"
#include <vector>
#include <string>
#include <memory>
#include "pre_activate/ascend/ascend_helper.h"
#include "pre_activate/common/helper.h"
#include "session/anf_runtime_algorithm.h"
#include "utils/utils.h"
#include "device/kernel_info.h"
#include "kernel//oplib/oplib.h"
#include "operator/ops.h"
namespace mindspore {
namespace opt {
const BaseRef InsertPadForNMSWithMask::DefinePattern() const {
  // Match an NMSWithMask node with any number of inputs (SeqVar = variadic).
  const VarPtr inputs_seq = std::make_shared<SeqVar>();
  return VectorRef({prim::kPrimNMSWithMask, inputs_seq});
}
AnfNodePtr INsertPadToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format,
                            const TypeId &input_type, const TypeId &output_type, const TypeId &origin_type,
                            const std::vector<size_t> &origin_shape) {
  // Build a Pad CNode consuming `input`, then attach a fully-specified kernel
  // build info (format, device types, fusion type, processor, kernel type) and
  // its inferred output type/shape.
  MS_EXCEPTION_IF_NULL(func_graph);
  auto pad_prim = std::make_shared<Primitive>(prim::kPrimPad->name());
  std::vector<AnfNodePtr> pad_inputs{NewValueNode(pad_prim), input};
  auto pad = func_graph->NewCNode(pad_inputs);
  MS_EXCEPTION_IF_NULL(pad);
  // Prefer the TBE (AICORE) implementation when a Pad op registration exists;
  // otherwise fall back to the AICPU kernel.
  const bool has_tbe_kernel = kernel::OpLib::FindOp(prim::kPrimPad->name(), kernel::kTBE) != nullptr;
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({format});
  builder.SetOutputsFormat({format});
  builder.SetInputsDeviceType({input_type});
  builder.SetOutputsDeviceType({output_type});
  builder.SetFusionType(kernel::FusionType::OPAQUE);
  builder.SetProcessor(kernel::Processor::AICORE);
  builder.SetKernelType(has_tbe_kernel ? KernelType::TBE_KERNEL : KernelType::AICPU_KERNEL);
  if (pad->kernel_info() == nullptr) {
    pad->set_kernel_info(std::make_shared<device::KernelInfo>());
  }
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), pad.get());
  AnfAlgo::SetOutputInferTypeAndShape({origin_type}, {origin_shape}, pad.get());
  return pad;
}
const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                                  const EquivPtr &) const {
  // Rewrite: insert a Pad op in front of every input of a matched NMSWithMask
  // node, padding the last dimension from 5 to 8 (paddings {{0,0},{0,3}}) so
  // the backend kernel receives the (N, 8) layout. Returns the rewritten node,
  // or nullptr when the pattern does not apply (no inputs, or an input whose
  // inferred shape is not (N, 5)).
  MS_EXCEPTION_IF_NULL(func_graph);
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Hoisted: the original re-queried AnfAlgo::GetInputTensorNum(cnode) in the
  // loop condition on every iteration even though input_num was already known.
  size_t input_num = AnfAlgo::GetInputTensorNum(cnode);
  if (input_num == 0) {
    return nullptr;
  }
  std::vector<AnfNodePtr> new_inputs = {AnfAlgo::GetCNodePrimitiveNode(cnode)};
  for (size_t input_idx = 0; input_idx < input_num; input_idx++) {
    auto cur_input = AnfAlgo::GetInputNode(cnode, input_idx);
    auto origin_type = AnfAlgo::GetPrevNodeOutputInferDataType(cnode, input_idx);
    auto format = AnfAlgo::GetPrevNodeOutputFormat(cnode, input_idx);
    auto origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, input_idx);
    // Only the (N, 5) bbox layout is padded; any other shape leaves the node
    // untouched (pass does not apply).
    if (!(origin_shape.size() == 2 && origin_shape[1] == 5)) {
      return nullptr;
    }
    origin_shape[1] = 8;
    auto device_type = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode, input_idx);
    auto pad = INsertPadToGraph(func_graph, cur_input, format, origin_type, device_type, origin_type, origin_shape);
    MS_EXCEPTION_IF_NULL(pad);
    pad->set_scope(cnode->scope());
    // Pad 3 trailing zeros on the second axis: (N, 5) -> (N, 8).
    AnfAlgo::SetNodeAttr("paddings", MakeValue(std::vector<std::vector<int>>{{0, 0}, {0, 3}}), pad);
    new_inputs.push_back(pad);
  }
  // Re-create the NMSWithMask node with the padded inputs; use the kernel
  // graph's factory when available so graph bookkeeping stays consistent.
  auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
  CNodePtr new_node = nullptr;
  if (kernel_graph == nullptr) {
    new_node = std::make_shared<CNode>(*cnode);
  } else {
    new_node = kernel_graph->NewCNode(cnode);
  }
  MS_EXCEPTION_IF_NULL(new_node);
  new_node->set_inputs(new_inputs);
  return new_node;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass.h"
namespace mindspore {
namespace opt {
// Backend IR-fusion pass that matches NMSWithMask nodes and inserts Pad ops in
// front of their inputs (the matching .cc pads the bbox shape (N, 5) -> (N, 8)).
class InsertPadForNMSWithMask : public PatternProcessPass {
 public:
  // multigraph = true lets the pattern match across all graphs in the pass run.
  explicit InsertPadForNMSWithMask(bool multigraph = true)
      : PatternProcessPass("insert_pad_for_nms_with_mask", multigraph) {}
  ~InsertPadForNMSWithMask() override = default;
  // Pattern: an NMSWithMask primitive with any sequence of inputs.
  const BaseRef DefinePattern() const override;
  // Rewrites a matched node; returns the replacement node or nullptr when the
  // pass does not apply.
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_ENHANCER_INSERT_PAD_FOR_NMS_WITH_MASK_H
......@@ -148,6 +148,7 @@ constexpr auto kReturnOpName = "return";
constexpr auto kLarsV2OpName = "LarsV2";
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
constexpr auto kSquareSumAllOpName = "SquareSumAll";
constexpr auto kNMSWithMaskOpName = "NMSWithMask";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......
......@@ -2021,11 +2021,6 @@ class NMSWithMask(PrimitiveWithInfer):
cls_name = self.name
validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, cls_name)
validator.check_integer("bboxes.shape()[0]", bboxes_shape[0], 0, Rel.GT, cls_name)
if not self.is_ge:
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 8, Rel.EQ, cls_name)
num = bboxes_shape[0]
return ((num, 5), (num,), (num,))
validator.check_integer("bboxes.shape()[1]", bboxes_shape[1], 5, Rel.EQ, cls_name)
num = bboxes_shape[0]
return (bboxes_shape, (num,), (num,))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册