提交 b18178f1 编写于 作者: D dengkaipeng

refine metric

上级 308447ba
......@@ -72,11 +72,15 @@ class Metric(object):
raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__))
def add_metric_op(self, pred, label):
def add_metric_op(self, *args):
"""
Add process op for metric in program
If :code:`add_metric_op` is defined, it will be called with outputs
of model and labels from data as parameter, all outputs and labels
will be concatenated and flatten to a list like follows:
[output1, output2, ..., label1, label2,...]
"""
return pred, label
return args
class Accuracy(Metric):
......@@ -91,12 +95,12 @@ class Accuracy(Metric):
self._init_name(name)
self.reset()
def add_metric_op(self, pred, label, *args, **kwargs):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
correct = pred == label[0]
def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label
return correct
def update(self, correct, *args, **kwargs):
def update(self, correct, *args):
accs = []
for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum()
......
......@@ -473,7 +473,7 @@ class StaticGraphAdapter(object):
if mode != 'test':
for metric in self.model._metrics:
metrics.append(
to_list(metric.add_metric_op(outputs, labels)))
to_list(metric.add_metric_op(*(outputs + labels))))
if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses)
......@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object):
metrics = []
for metric in self.model._metrics:
metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels))
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
......@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object):
self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels)
metric_outs = metric.add_metric_op(
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册