未验证 提交 587f66ee 编写于 作者: H huangjiyi 提交者: GitHub

Support code generation for op multiclass_nms3 (#54272)

* update

* update eager_gen

* update

* rm intermediate
上级 8d30a5f3
......@@ -1469,9 +1469,12 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
# Get return type list & outputs
returns_type_list = ["" for i in range(num_outputs)]
returns_list = ["" for i in range(num_outputs)]
num_visited_intermediate_outputs = 0
for name, (rtype, pos) in forward_outputs_position_map.items():
if name in intermediate_outputs:
num_visited_intermediate_outputs += 1
continue
pos -= num_visited_intermediate_outputs
returns_list[pos] = f"{name}"
if IsPlainTensorType(rtype):
......
......@@ -607,45 +607,12 @@ class MultiClassNMS2OpMaker : public MultiClassNMSOpMaker {
}
};
class MultiClassNMS3Op : public MultiClassNMS2Op {
public:
MultiClassNMS3Op(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMS2Op(type, inputs, outputs, attrs) {}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), ctx.GetPlace());
}
};
class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
public:
void Make() override {
MultiClassNMS2OpMaker::Make();
AddInput("RoisNum",
"(Tensor) The number of RoIs in shape (B),"
"B is the number of images")
.AsDispensable();
AddOutput("NmsRoisNum", "(Tensor), The number of NMS RoIs in each image")
.AsDispensable();
}
};
template <typename T, typename DeviceContext>
class MultiClassNMS2Kernel : public MultiClassNMSKernel<T, DeviceContext> {};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(multiclass_nms3,
MultiClassNMSShapeFunctor,
PD_INFER_META(phi::MultiClassNMSInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(
multiclass_nms,
......@@ -668,11 +635,3 @@ PD_REGISTER_STRUCT_KERNEL(multiclass_nms2,
ops::MultiClassNMS2Kernel,
float,
double) {}
REGISTER_OPERATOR(
multiclass_nms3,
ops::MultiClassNMS3Op,
ops::MultiClassNMS3OpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
MultiClassNMSShapeFunctor);
......@@ -666,15 +666,6 @@
func : mish
backward : mish_grad
- op : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
infer_meta :
func : MultiClassNMSInferMeta
kernel :
func : multiclass_nms3
optional : rois_num
- op : multiply
args : (Tensor x, Tensor y)
output : Tensor
......
......@@ -1702,6 +1702,12 @@
out : Out
drop_empty_grad : [x_grad]
- op : multiclass_nms3
inputs :
{bboxes : BBoxes, scores : Scores, rois_num : RoisNum}
outputs :
{out : Out, index : Index, nms_rois_num : NmsRoisNum}
- op : multinomial
inputs :
{x : X}
......
......@@ -1548,6 +1548,16 @@
func : multi_dot
backward : multi_dot_grad
- op : multiclass_nms3
args : (Tensor bboxes, Tensor scores, Tensor rois_num, float score_threshold, int nms_top_k, int keep_top_k, float nms_threshold=0.3, bool normalized=true, float nms_eta=1.0, int background_label=0)
output : Tensor(out), Tensor(index), Tensor(nms_rois_num)
infer_meta :
func : MultiClassNMSInferMeta
kernel :
func : multiclass_nms3
data_type : scores
optional : rois_num, nms_rois_num
- op : multinomial
args : (Tensor x, Scalar(int) num_samples = 1, bool replacement = false)
output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature MultiClassNMS3OpArgumentMapping(
const ArgumentMappingContext& ctx UNUSED) {
return KernelSignature("multiclass_nms3",
{"BBoxes", "Scores", "RoisNum"},
{"score_threshold",
"nms_top_k",
"keep_top_k",
"nms_threshold",
"normalized",
"nms_eta",
"background_label"},
{"Out", "Index", "NmsRoisNum"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(multiclass_nms3,
phi::MultiClassNMS3OpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册