提交 32d17598 编写于 作者: D Dang Qingqing

Refine detection mAP in metrics.py.

evaluator.py throws warnings and tell users to use DetectionMAP in metrics.py, but it is wrong in metrics.py.
So refine this API in metrics.py.

test=develop
上级 765085d2
......@@ -316,7 +316,7 @@ class DetectionMAP(Evaluator):
gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 6]. The layout is [xmin, ymin, xmax, ymax].
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
gt_difficult (Variable|None): Whether this ground truth is a difficult
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
it means all the ground truth labels are not difficult bbox.
......
......@@ -478,67 +478,6 @@ class EditDistance(MetricBase):
return avg_distance, avg_instance_error
class DetectionMAP(MetricBase):
"""
Calculate the detection mean average precision (mAP).
mAP is the metric to measure the accuracy of object detectors
like Faster R-CNN, SSD, etc.
It is the average of the maximum precisions at different recall values.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Examples:
.. code-block:: python
pred = fluid.layers.fc(input=data, size=1000, act="tanh")
batch_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
metric = fluid.metrics.DetectionMAP()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, batch_map])
batch_size = data[0]
metric.update(value=batch_map, weight=batch_size)
numpy_map = metric.eval()
"""
def __init__(self, name=None):
super(DetectionMAP, self).__init__(name)
# the current map value
self.value = .0
self.weight = .0
def update(self, value, weight):
if not _is_number_or_matrix_(value):
raise ValueError(
"The 'value' must be a number(int, float) or a numpy ndarray.")
if not _is_number_(weight):
raise ValueError("The 'weight' must be a number(int, float).")
self.value += value
self.weight += weight
def eval(self):
if self.weight == 0:
raise ValueError(
"There is no data in DetectionMAP Metrics. "
"Please check layers.detection_map output has added to DetectionMAP."
)
return self.value / self.weight
class Auc(MetricBase):
"""
Auc metric adapts to the binary classification.
......@@ -616,3 +555,185 @@ class Auc(MetricBase):
idx -= 1
return auc / tot_pos / tot_neg if tot_pos > 0.0 and tot_neg > 0.0 else 0.0
class DetectionMAP(object):
"""
Calculate the detection mean average precision (mAP).
The general steps are as follows:
1. calculate the true positive and false positive according to the input
of detection and labels.
2. calculate mAP value, support two versions: '11 point' and 'integral'.
Please get more information from the following articles:
https://sanchom.wordpress.com/tag/average-precision/
https://arxiv.org/abs/1512.02325
Args:
input (Variable): The detection results, which is a LoDTensor with shape
[M, 6]. The layout is [label, confidence, xmin, ymin, xmax, ymax].
gt_label (Variable): The ground truth label index, which is a LoDTensor
with shape [N, 1].
gt_box (Variable): The ground truth bounding box (bbox), which is a
LoDTensor with shape [N, 4]. The layout is [xmin, ymin, xmax, ymax].
gt_difficult (Variable|None): Whether this ground truth is a difficult
bounding bbox, which can be a LoDTensor [N, 1] or not set. If None,
it means all the ground truth labels are not difficult bbox.
class_num (int): The class number.
background_label (int): The index of background label, the background
label will be ignored. If set to -1, then all categories will be
considered, 0 by defalut.
overlap_threshold (float): The threshold for deciding true/false
positive, 0.5 by defalut.
evaluate_difficult (bool): Whether to consider difficult ground truth
for evaluation, True by defalut. This argument does not work when
gt_difficult is None.
ap_version (string): The average precision calculation ways, it must be
'integral' or '11point'. Please check
https://sanchom.wordpress.com/tag/average-precision/ for details.
- 11point: the 11-point interpolated average precision.
- integral: the natural integral of the precision-recall curve.
Examples:
.. code-block:: python
exe = fluid.executor(place)
map_evaluator = fluid.Evaluator.DetectionMAP(input,
gt_label, gt_box, gt_difficult)
cur_map, accum_map = map_evaluator.get_map_var()
fetch = [cost, cur_map, accum_map]
for epoch in PASS_NUM:
map_evaluator.reset(exe)
for data in batches:
loss, cur_map_v, accum_map_v = exe.run(fetch_list=fetch)
In the above example:
'cur_map_v' is the mAP of current mini-batch.
'accum_map_v' is the accumulative mAP of one pass.
"""
def __init__(self,
input,
gt_label,
gt_box,
gt_difficult=None,
class_num=None,
background_label=0,
overlap_threshold=0.5,
evaluate_difficult=True,
ap_version='integral'):
from . import layers
from .layer_helper import LayerHelper
from .initializer import Constant
self.helper = LayerHelper('map_eval')
gt_label = layers.cast(x=gt_label, dtype=gt_box.dtype)
if gt_difficult:
gt_difficult = layers.cast(x=gt_difficult, dtype=gt_box.dtype)
label = layers.concat([gt_label, gt_difficult, gt_box], axis=1)
else:
label = layers.concat([gt_label, gt_box], axis=1)
# calculate mean average precision (mAP) of current mini-batch
map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
ap_version=ap_version)
states = []
states.append(
self._create_state(
dtype='int32', shape=None, suffix='accum_pos_count'))
states.append(
self._create_state(
dtype='float32', shape=None, suffix='accum_true_pos'))
states.append(
self._create_state(
dtype='float32', shape=None, suffix='accum_false_pos'))
var = self._create_state(dtype='int32', shape=[1], suffix='has_state')
self.helper.set_variable_initializer(
var, initializer=Constant(value=int(0)))
self.has_state = var
# calculate accumulative mAP
accum_map = layers.detection_map(
input,
label,
class_num,
background_label,
overlap_threshold=overlap_threshold,
evaluate_difficult=evaluate_difficult,
has_state=self.has_state,
input_states=states,
out_states=states,
ap_version=ap_version)
layers.fill_constant(
shape=self.has_state.shape,
value=1,
dtype=self.has_state.dtype,
out=self.has_state)
self.cur_map = map
self.accum_map = accum_map
def _create_state(self, suffix, dtype, shape):
"""
Create state variable.
Args:
suffix(str): the state suffix.
dtype(str|core.VarDesc.VarType): the state data type
shape(tuple|list): the shape of state
Returns: State variable
"""
from . import unique_name
state = self.helper.create_variable(
name="_".join([unique_name.generate(self.helper.name), suffix]),
persistable=True,
dtype=dtype,
shape=shape)
return state
def get_map_var(self):
"""
Returns: mAP variable of current mini-batch and
accumulative mAP variable cross mini-batches.
"""
return self.cur_map, self.accum_map
def reset(self, executor, reset_program=None):
"""
reset metric states at the begin of each pass/user specified batch.
Args:
executor(Executor|ParallelExecutor): a executor for executing
the reset_program.
reset_program(Program|None): a single Program for reset process.
If None, will create a Program.
"""
from .framework import Program, Variable, program_guard
from . import layers
def _clone_var_(block, var):
assert isinstance(var, Variable)
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level,
persistable=var.persistable)
if reset_program is None:
reset_program = Program()
with program_guard(main_program=reset_program):
var = _clone_var_(reset_program.current_block(), self.has_state)
layers.fill_constant(
shape=var.shape, value=0, dtype=var.dtype, out=var)
executor.run(reset_program)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册