未验证 提交 dc2a5e54 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #37 from heavengate/refine_metric

refine metric
......@@ -39,11 +39,6 @@ class BmnMetric(Metric):
elif self.mode == 'infer':
self.get_infer_dataset_dict()
def add_metric_op(self, preds, label):
pred_bm, pred_start, pred_en = preds
video_index = label[-1]
return [pred_bm, pred_start, pred_en, video_index] #return list
def update(self, pred_bm, pred_start, pred_end, fid):
# generate proposals
pred_start = pred_start[0]
......
......@@ -48,9 +48,16 @@ class Metric(object):
format(self.__class__.__name__))
@abc.abstractmethod
def update(self, *args, **kwargs):
def update(self, *args):
"""
Update states for metric
Inputs of :code:`update` is the outputs of :code:`Metric.add_metric_op`,
if :code:`add_metric_op` is not defined, the inputs of :code:`update`
will be flatten arguments of **output** of mode and **label** from data:
:code:`update(output1, output2, ..., label1, label2,...)`
see :code:`Metric.add_metric_op`
"""
raise NotImplementedError("function 'update' not implemented in {}.".
format(self.__class__.__name__))
......@@ -72,11 +79,26 @@ class Metric(object):
raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__))
def add_metric_op(self, pred, label):
def add_metric_op(self, *args):
"""
Add process op for metric in program
This API is advanced usage to accelerate metric calculating, calulations
from outputs of model to the states which should be updated by Metric can
be defined here, where Paddle OPs is also supported. Outputs of this API
will be the inputs of "Metric.update".
If :code:`add_metric_op` is defined, it will be called with **outputs**
of model and **labels** from data as arguments, all outputs and labels
will be concatenated and flatten and each filed as a separate argument
as follows:
:code:`add_metric_op(output1, output2, ..., label1, label2,...)`
If :code:`add_metric_op` is not defined, default behaviour is to pass
input to output, so output format will be:
:code:`return output1, output2, ..., label1, label2,...`
see :code:`Metric.update`
"""
return pred, label
return args
class Accuracy(Metric):
......@@ -91,12 +113,12 @@ class Accuracy(Metric):
self._init_name(name)
self.reset()
def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label
return correct
def update(self, correct, *args, **kwargs):
def update(self, correct, *args):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
......
......@@ -67,7 +67,7 @@ def to_list(value):
if value is None:
return value
if isinstance(value, (list, tuple)):
return value
return list(value)
return [value]
......@@ -473,7 +473,7 @@ class StaticGraphAdapter(object):
if mode != 'test':
for metric in self.model._metrics:
metrics.append(
to_list(metric.add_metric_op(outputs, labels)))
to_list(metric.add_metric_op(*(outputs + labels))))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
......@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object):
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels))
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
......@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object):
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels)
metric_outs = metric.add_metric_op(
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册