未验证 提交 f71c805e 编写于 作者: T Tian Zheng 提交者: GitHub

Add multiclass_nms3 GPU kernel (#52401)

* Add GPU kernel for multiclass_nms3 op

* Make multiclass_nms3 gpu kernel output consistent with cpu kernel

* Fix API incompatibility

* Fix unittests on builds without CUDA

* Fix ROCM build

* Remove fluid headers; Use default atol for unittest

* Change function and variable naming

* Add comments; Reduce redundant code

* Use paddle test framework
上级 d2fa26f6
......@@ -386,6 +386,7 @@ function(op_library TARGET)
list(REMOVE_ITEM hip_srcs "eigh_op.cu")
list(REMOVE_ITEM hip_srcs "lstsq_op.cu")
list(REMOVE_ITEM hip_srcs "multinomial_op.cu")
list(REMOVE_ITEM hip_srcs "multiclass_nms3_op.cu")
hip_library(
${TARGET}
SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${miopen_cu_srcs}
......
......@@ -614,6 +614,13 @@ class MultiClassNMS3Op : public MultiClassNMS2Op {
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: MultiClassNMS2Op(type, inputs, outputs, attrs) {}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"), ctx.GetPlace());
}
};
class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
......
此差异已折叠。
......@@ -20,7 +20,7 @@ from eager_op_test import OpTest
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import _non_static_mode, in_dygraph_mode
from paddle.fluid import _non_static_mode, core, in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -355,6 +355,7 @@ def batched_multiclass_nms(
nms_top_k,
keep_top_k,
normalized=True,
gpu_logic=False,
):
batch_size = scores.shape[0]
num_boxes = scores.shape[2]
......@@ -392,6 +393,11 @@ def batched_multiclass_nms(
idx + n * num_boxes,
]
)
if gpu_logic:
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[1], reverse=True
)
else:
sorted_det_out = sorted(
tmp_det_out, key=lambda tup: tup[0], reverse=False
)
......@@ -747,7 +753,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
background = 0
nms_threshold = 0.3
nms_top_k = 400
keep_top_k = 200
keep_top_k = 200 if not hasattr(self, 'keep_top_k') else self.keep_top_k
score_threshold = self.score_threshold
scores = np.random.random((N * M, C)).astype('float32')
......@@ -768,6 +774,7 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
nms_threshold,
nms_top_k,
keep_top_k,
gpu_logic=self.gpu_logic if hasattr(self, 'gpu_logic') else None,
)
det_outs = np.array(det_outs)
......@@ -797,7 +804,8 @@ class TestMulticlassNMS3Op(TestMulticlassNMS2Op):
}
def test_check_output(self):
self.check_output()
place = paddle.CPUPlace()
self.check_output_with_place(place)
class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
......@@ -807,6 +815,51 @@ class TestMulticlassNMS3OpNoOutput(TestMulticlassNMS3Op):
self.score_threshold = 2.0
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPU(TestMulticlassNMS2Op):
def test_check_output(self):
place = paddle.CUDAPlace(0)
self.check_output_with_place(place)
def set_argument(self):
self.score_threshold = 0.01
self.gpu_logic = True
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPULessOutput(TestMulticlassNMS3OpGPU):
def set_argument(self):
# Here set 0.08 to make output box size less than keep_top_k
self.score_threshold = 0.08
self.gpu_logic = True
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPUNoOutput(TestMulticlassNMS3OpGPU):
def set_argument(self):
# Here set 2.0 to test the case there is no outputs.
# In practical use, 0.0 < score_threshold < 1.0
self.score_threshold = 2.0
self.gpu_logic = True
@unittest.skipIf(
not core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMulticlassNMS3OpGPUFallback(TestMulticlassNMS3OpGPU):
def set_argument(self):
# Setting keep_top_k < 0 will fall back to CPU kernel
self.score_threshold = 0.01
self.keep_top_k = -1
self.gpu_logic = True
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册