提交 4161328e 编写于 作者: W wanghaox

add warrper for detection map operator

上级 1924aa14
......@@ -16,6 +16,7 @@ All layers just related to the detection neural network.
"""
from layer_function_generator import generate_layer_fn
from layer_function_generator import autodoc
from ..layer_helper import LayerHelper
import tensor
import ops
......@@ -28,6 +29,7 @@ __all__ = [
'target_assign',
'detection_output',
'ssd_loss',
'detection_map',
]
__auto__ = [
......@@ -132,6 +134,44 @@ def detection_output(scores,
return nmsed_outs
@autodoc()
def detection_map(detect_res,
label,
pos_count=None,
true_pos=None,
false_pos=None,
overlap_threshold=0.3,
evaluate_difficult=True,
ap_type='integral'):
helper = LayerHelper("detection_map", **locals())
map_out = helper.create_tmp_variable(dtype='float32')
accum_pos_count_out = helper.create_tmp_variable(dtype='int32')
accum_true_pos_out = helper.create_tmp_variable(dtype='float32')
accum_false_pos_out = helper.create_tmp_variable(dtype='float32')
helper.append_op(
type="detection_map",
inputs={
'Label': label,
'DetectRes': detect_res,
'PosCount': pos_count,
'TruePos': true_pos,
'FalsePos': false_pos
},
outputs={
'MAP': map_out,
'AccumPosCount': accum_pos_count_out,
'AccumTruePos': accum_true_pos_out,
'AccumFalsePos': accum_false_pos_out
},
attrs={
'overlap_threshold': overlap_threshold,
'evaluate_difficult': evaluate_difficult,
'ap_type': ap_type
})
return map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out
def bipartite_match(dist_matrix, name=None):
"""
**Bipartite matchint operator**
......
......@@ -145,5 +145,43 @@ class TestMultiBoxHead(unittest.TestCase):
return mbox_locs, mbox_confs, box, var
class TestDetectionMAP(unittest.TestCase):
def test_detection_map(self):
program = Program()
with program_guard(program):
detect_res = layers.data(
name='detect_res',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
label = layers.data(
name='label',
shape=[10, 6],
append_batch_size=False,
dtype='float32')
map_out, accum_pos_count_out, accum_true_pos_out, accum_false_pos_out = layers.detection_map(
detect_res=detect_res, label=label)
self.assertIsNotNone(map_out)
self.assertIsNotNone(accum_pos_count_out)
self.assertIsNotNone(accum_true_pos_out)
self.assertIsNotNone(accum_false_pos_out)
self.assertEqual(map_out.shape, (1, ))
map_out, accum_pos_count_out2, accum_true_pos_out2, accum_false_pos_out2 = layers.detection_map(
detect_res=detect_res, label=label)
self.assertIsNotNone(map_out)
self.assertIsNotNone(accum_pos_count_out2)
self.assertIsNotNone(accum_true_pos_out2)
self.assertIsNotNone(accum_false_pos_out2)
self.assertEqual(map_out.shape, (1, ))
self.assertEqual(accum_pos_count_out.shape,
accum_pos_count_out2.shape)
self.assertEqual(accum_true_pos_out.shape,
accum_true_pos_out2.shape)
self.assertEqual(accum_false_pos_out.shape,
accum_false_pos_out2.shape)
print(str(program))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册