提交 b18178f1 编写于 作者: D dengkaipeng

refine metric

上级 308447ba
...@@ -72,11 +72,15 @@ class Metric(object): ...@@ -72,11 +72,15 @@ class Metric(object):
raise NotImplementedError("function 'name' not implemented in {}.". raise NotImplementedError("function 'name' not implemented in {}.".
format(self.__class__.__name__)) format(self.__class__.__name__))
def add_metric_op(self, pred, label): def add_metric_op(self, *args):
""" """
Add process op for metric in program Add process op for metric in program
If :code:`add_metric_op` is defined, it will be called with outputs
of model and labels from data as parameter, all outputs and labels
will be concatenated and flatten to a list like follows:
[output1, output2, ..., label1, label2,...]
""" """
return pred, label return args
class Accuracy(Metric): class Accuracy(Metric):
...@@ -91,12 +95,12 @@ class Accuracy(Metric): ...@@ -91,12 +95,12 @@ class Accuracy(Metric):
self._init_name(name) self._init_name(name)
self.reset() self.reset()
def add_metric_op(self, pred, label, *args, **kwargs): def add_metric_op(self, pred, label, *args):
pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk] pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
correct = pred == label[0] correct = pred == label
return correct return correct
def update(self, correct, *args, **kwargs): def update(self, correct, *args):
accs = [] accs = []
for i, k in enumerate(self.topk): for i, k in enumerate(self.topk):
num_corrects = correct[:, :k].sum() num_corrects = correct[:, :k].sum()
......
...@@ -473,7 +473,7 @@ class StaticGraphAdapter(object): ...@@ -473,7 +473,7 @@ class StaticGraphAdapter(object):
if mode != 'test': if mode != 'test':
for metric in self.model._metrics: for metric in self.model._metrics:
metrics.append( metrics.append(
to_list(metric.add_metric_op(outputs, labels))) to_list(metric.add_metric_op(*(outputs + labels))))
if mode == 'train' and self.model._optimizer: if mode == 'train' and self.model._optimizer:
self._loss_endpoint = fluid.layers.sum(losses) self._loss_endpoint = fluid.layers.sum(losses)
...@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object): ...@@ -593,7 +593,7 @@ class DynamicGraphAdapter(object):
metrics = [] metrics = []
for metric in self.model._metrics: for metric in self.model._metrics:
metric_outs = metric.add_metric_op( metric_outs = metric.add_metric_op(
to_list(outputs), to_list(labels)) *(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
...@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object): ...@@ -632,7 +632,8 @@ class DynamicGraphAdapter(object):
self._merge_count[self.mode + '_total'] += samples self._merge_count[self.mode + '_total'] += samples
self._merge_count[self.mode + '_batch'] = samples self._merge_count[self.mode + '_batch'] = samples
metric_outs = metric.add_metric_op(to_list(outputs), labels) metric_outs = metric.add_metric_op(
*(to_list(outputs) + to_list(labels)))
m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)]) m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
metrics.append(m) metrics.append(m)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册