提交 8ae61752 编写于 作者: M Megvii Engine Team

feat(opr): let nms support empty IO

GitOrigin-RevId: 4c51b1aedb742fa97a655c74b7ba8d09bf3cdc96
上级 1a1748da
......@@ -600,7 +600,19 @@ def test_hinge_loss():
opr_test(cases, hinge_loss_with_l2_norm)
def test_nms():
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_nms(is_symbolic):
def fn(inp, scores):
return F.vision.nms(
inp,
scores=scores,
iou_thresh=0.5,
max_output=None if is_symbolic is None else 4,
)
if is_symbolic is not None:
fn = jit.trace(symbolic=is_symbolic)(fn)
x = np.array(
[
[0, 0, 100, 100],
......@@ -612,8 +624,16 @@ def test_nms():
)
inp = tensor(x)
scores = tensor([0.5, 0.8, 0.9, 0.6], dtype=np.float32)
result = F.vision.nms(inp, scores=scores, iou_thresh=0.5)
np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
for _ in range(3):
result = fn(inp, scores=scores)
np.testing.assert_equal(result.numpy(), np.array([2, 1, 3], dtype=np.int32))
x = np.array([], dtype=np.float32,).reshape(0, 4)
inp = tensor(x)
scores = tensor([], dtype=np.float32)
for _ in range(3):
result = fn(inp, scores=scores)
np.testing.assert_equal(result.numpy(), np.array([], dtype=np.int32))
@pytest.mark.skipif(
......
......@@ -23,6 +23,7 @@ bool box_iou(Box a, Box b, float thresh) {
} // anonymous namespace
size_t mgb::opr::standalone::nms::cpu_kern_workspace(size_t nr_boxes) {
if (nr_boxes == 0) return 0;
return (((nr_boxes - 1) / sizeof(size_t)) + 1) * sizeof(size_t);
}
......
......@@ -40,11 +40,17 @@ class NMSKeep::CUDAKern final : public Kern {
void init(const NMSKeep* opr, const TensorShape& boxes) {
auto align = opr->comp_node().get_mem_addr_alignment();
size_t nr_boxes = boxes[1];
m_workspace_overlap_mask_bytes =
nr_boxes * DIVUP(nr_boxes, 64) * sizeof(uint64_t);
m_workspace_overlap_mask_bytes_align =
get_aligned_power2(m_workspace_overlap_mask_bytes, align);
m_workspace_rm_mask_bytes = DIVUP(nr_boxes, 64) * sizeof(uint64_t);
if (nr_boxes == 0) {
m_workspace_overlap_mask_bytes = 0;
m_workspace_overlap_mask_bytes_align = 0;
m_workspace_rm_mask_bytes = 0;
} else {
m_workspace_overlap_mask_bytes =
nr_boxes * DIVUP(nr_boxes, 64) * sizeof(uint64_t);
m_workspace_overlap_mask_bytes_align =
get_aligned_power2(m_workspace_overlap_mask_bytes, align);
m_workspace_rm_mask_bytes = DIVUP(nr_boxes, 64) * sizeof(uint64_t);
}
}
public:
......@@ -88,7 +94,10 @@ void NMSKeep::CUDAKern::exec(const NMSKeep* opr, const DeviceTensorND& inp,
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()),
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
size_t batch = inp.shape(0), nr_boxes = inp.shape(1);
if (nr_boxes == 0) {
MGB_CUDA_CHECK(cudaMemsetAsync(out_size_ptr, 0, batch*sizeof(uint32_t), stream));
return;
}
MGB_CUDA_CHECK(cudaMemsetAsync(dev_overlap_mask, 0,
m_workspace_overlap_mask_bytes, stream));
......@@ -136,6 +145,12 @@ void NMSKeep::CPUKern::exec(const NMSKeep* opr, const DeviceTensorND& inp,
auto out_idx_ptr = reinterpret_cast<uint32_t*>(out_idx.ptr<int32_t>()),
out_size_ptr = reinterpret_cast<uint32_t*>(out_size.ptr<int32_t>());
size_t batch = inp.shape(0), nr_boxes = inp.shape(1);
if (nr_boxes == 0) {
for (size_t i = 0; i < batch; ++i) {
*(out_size_ptr + i) = 0;
}
return;
}
auto param = opr->param();
auto workspace_ptr = workspace.raw_ptr();
......@@ -183,7 +198,8 @@ NMSKeep::NMSKeep(VarNode* boxes, const Param& param,
}
add_input({boxes});
add_output("indices")->dtype(dtype::Int32());
add_output("indices")->dtype(dtype::Int32())
.add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
add_output("sizes")->dtype(dtype::Int32());
cg::add_workspace_output(this); // workspace is also an output var
......@@ -233,6 +249,13 @@ void NMSKeep::scn_do_execute() {
: empty_workspace);
}
NMSKeep::NodeProp* NMSKeep::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
#if MGB_ENABLE_FBS_SERIALIZATION
namespace mgb {
......
......@@ -53,6 +53,8 @@ private:
//! execute the operator
void scn_do_execute() override;
NodeProp* do_make_node_prop() const override;
};
} // namespace standalone
......
......@@ -55,6 +55,25 @@ void run_on_comp_node(const char* cn_name) {
}
}
void run_empty_input_on_comp_node(const char* cn_name) {
auto cn = CompNode::load(cn_name);
auto graph = ComputingGraph::make();
auto host_x = std::make_shared<HostTensorND>(cn, TensorShape{1, 0, 4},
dtype::Float32{});
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
{
auto idx = opr::standalone::NMSKeep::make(x, {0.2, 16});
auto size = idx.node()->owner_opr()->output(1);
HostTensorND host_idx, host_size;
auto func = graph->compile({make_callback_copy(idx, host_idx),
make_callback_copy(size, host_size)});
func->execute().wait();
auto size_ptr = host_size.ptr<int32_t>();
ASSERT_EQ(size_ptr[0], 0);
}
}
}
TEST(TestOprNMS, CPU) {
......@@ -66,6 +85,15 @@ TEST(TestOprNMS, GPU) {
run_on_comp_node("gpu0");
}
TEST(TestOprNMSEmptyIO, CPU) {
run_empty_input_on_comp_node("cpu0");
}
TEST(TestOprNMSEmptyIO, GPU) {
REQUIRE_GPU(1);
run_empty_input_on_comp_node("gpu0");
}
#if MGB_ENABLE_EXCEPTION
TEST(TestOprNMS, InvalidInput) {
HostTensorGenerator<> gen;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册