Unverified commit dc2a5e54 authored by Kaipeng Deng, committed by GitHub

Merge pull request #37 from heavengate/refine_metric

refine metric
...@@ -39,11 +39,6 @@ class BmnMetric(Metric):
         elif self.mode == 'infer':
             self.get_infer_dataset_dict()
 
-    def add_metric_op(self, preds, label):
-        pred_bm, pred_start, pred_en = preds
-        video_index = label[-1]
-        return [pred_bm, pred_start, pred_en, video_index]  # return list
-
     def update(self, pred_bm, pred_start, pred_end, fid):
         # generate proposals
         pred_start = pred_start[0]
......
...@@ -48,9 +48,16 @@ class Metric(object):
                 format(self.__class__.__name__))
 
     @abc.abstractmethod
-    def update(self, *args, **kwargs):
+    def update(self, *args):
         """
         Update states for metric
 
+        The inputs of :code:`update` are the outputs of :code:`Metric.add_metric_op`;
+        if :code:`add_metric_op` is not defined, the inputs of :code:`update`
+        will be the flattened arguments of the model **outputs** and the **labels**
+        from data:
+        :code:`update(output1, output2, ..., label1, label2, ...)`
+
+        see :code:`Metric.add_metric_op`
         """
         raise NotImplementedError("function 'update' not implemented in {}.".
                                   format(self.__class__.__name__))
...@@ -72,11 +79,26 @@ class Metric(object):
         raise NotImplementedError("function 'name' not implemented in {}.".
                                   format(self.__class__.__name__))
 
-    def add_metric_op(self, pred, label):
+    def add_metric_op(self, *args):
         """
-        Add process op for metric in program
+        This API is for advanced usage, to accelerate metric calculation:
+        the calculations that map model outputs to the states updated by the
+        Metric can be defined here, and Paddle OPs are also supported. The
+        outputs of this API will be the inputs of :code:`Metric.update`.
+
+        If :code:`add_metric_op` is defined, it is called with the **outputs**
+        of the model and the **labels** from data as arguments; all outputs and
+        labels are concatenated and flattened, and each field is passed as a
+        separate argument as follows:
+        :code:`add_metric_op(output1, output2, ..., label1, label2, ...)`
+
+        If :code:`add_metric_op` is not defined, the default behaviour is to
+        pass the inputs through, so the output format will be:
+        :code:`return output1, output2, ..., label1, label2, ...`
+
+        see :code:`Metric.update`
         """
-        return pred, label
+        return args
 
 
 class Accuracy(Metric):
...@@ -91,12 +113,12 @@ class Accuracy(Metric):
         self._init_name(name)
         self.reset()
 
-    def add_metric_op(self, pred, label, *args, **kwargs):
-        pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
-        correct = pred == label[0]
+    def add_metric_op(self, pred, label, *args):
+        pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
+        correct = pred == label
         return correct
 
-    def update(self, correct, *args, **kwargs):
+    def update(self, correct, *args):
         accs = []
         for i, k in enumerate(self.topk):
             num_corrects = correct[:, :k].sum()
......
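
The Accuracy change above drops the pred[0] / label[0] indexing because, under the new convention, add_metric_op receives each output and label as a separate tensor instead of as lists. A minimal NumPy sketch of the equivalent top-k correctness computation (batch size, class count and topk=(1, 5) are illustrative assumptions; np.argsort stands in for fluid.layers.argsort):

    import numpy as np

    maxk = 5                                  # assumed largest value of topk=(1, 5)
    pred = np.random.rand(8, 10)              # [batch, num_classes] scores
    label = np.random.randint(0, 10, (8, 1))  # [batch, 1] ground-truth class ids

    # Equivalent of fluid.layers.argsort(pred, descending=True)[1][:, :maxk]
    topk_idx = np.argsort(-pred, axis=1)[:, :maxk]
    correct = topk_idx == label               # label broadcasts against each column
    top1 = correct[:, :1].sum() / pred.shape[0]
    top5 = correct[:, :5].sum() / pred.shape[0]
    print(top1, top5)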
...@@ -67,7 +67,7 @@ def to_list(value):
     if value is None:
         return value
     if isinstance(value, (list, tuple)):
-        return value
+        return list(value)
     return [value]
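
The to_list tweak above keeps the adapters' argument concatenation safe: with the old behaviour of returning the input unchanged, to_list(outputs) + to_list(labels) would raise a TypeError whenever outputs come back as a tuple. A small self-contained illustration (the placeholder strings stand in for tensors):

    def to_list(value):
        # Always normalize to a plain list so results from separate calls
        # can be concatenated with `+`.
        if value is None:
            return value
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    outputs = ('logits',)   # model outputs may arrive as a tuple
    labels = ['label']      # labels arrive as a list
    args = to_list(outputs) + to_list(labels)
    print(args)             # ['logits', 'label'] -> unpacked into add_metric_op(*args)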
...@@ -473,7 +473,7 @@ class StaticGraphAdapter(object):
         if mode != 'test':
             for metric in self.model._metrics:
                 metrics.append(
-                    to_list(metric.add_metric_op(outputs, labels)))
+                    to_list(metric.add_metric_op(*(outputs + labels))))
 
         if mode == 'train' and self.model._optimizer:
             self._loss_endpoint = fluid.layers.sum(losses)
...@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object): ...@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object):
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.add_metric_op( metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels)) *(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
...@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object):
             self._merge_count[self.mode + '_total'] += samples
             self._merge_count[self.mode + '_batch'] = samples
 
-            metric_outs = metric.add_metric_op(to_list(outputs), labels)
+            metric_outs = metric.add_metric_op(
+                *(to_list(outputs) + to_list(labels)))
             m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
             metrics.append(m)
......
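
Putting the pieces together, a user-defined metric under the new contract only needs update() plus the bookkeeping methods, with add_metric_op() optional. The sketch below is illustrative rather than part of this PR: the class name MeanAbsError and the import path are assumptions, and it simply follows the base-class docstrings shown above.

    import numpy as np
    from hapi.metrics import Metric   # import path assumed

    class MeanAbsError(Metric):
        """Hypothetical metric: running mean absolute error."""

        def __init__(self, name='mae'):
            super(MeanAbsError, self).__init__()
            self._name = name
            self.reset()

        def add_metric_op(self, pred, label):
            # Called as add_metric_op(output1, ..., label1, ...); Paddle OPs
            # could run here, but passing the tensors through is also valid.
            return pred, label

        def update(self, pred, label):
            # Receives the numpy-converted outputs of add_metric_op.
            err = np.abs(pred - label).sum()
            self.total_err += err
            self.count += pred.shape[0]
            return err / pred.shape[0]

        def reset(self):
            self.total_err = 0.
            self.count = 0

        def accumulate(self):
            return self.total_err / max(self.count, 1)

        def name(self):
            return self._name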