From ca7c0a9afbe59ccc031ad625556d96daaef7517c Mon Sep 17 00:00:00 2001
From: dengkaipeng
Date: Mon, 9 Mar 2020 07:08:38 +0000
Subject: [PATCH] refine accuracy

---
 metrics.py | 34 ++++++++++++++++++-------
 mnist.py   | 18 +++++---------
 model.py   | 73 ++++++++++++++++++++++++++++++++++++------------------
 3 files changed, 80 insertions(+), 45 deletions(-)

diff --git a/metrics.py b/metrics.py
index b9cb49c..f3772d7 100644
--- a/metrics.py
+++ b/metrics.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
 import six
 import abc
 import numpy as np
+import paddle.fluid as fluid
 import logging
 
 FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
@@ -59,6 +60,12 @@ class Metric(object):
         """
         raise NotImplementedError("function 'accumulate' not implemented in {}.".format(self.__class__.__name__))
 
+    def add_metric_op(self, pred, label):
+        """
+        Add process op for metric in program
+        """
+        return pred, label
+
 
 class Accuracy(Metric):
     """
@@ -71,19 +78,28 @@
         self.maxk = max(topk)
         self.reset()
 
-    def update(self, pred, label, *args, **kwargs):
-        pred = np.argsort(pred[0])[:, ::-1][:, :self.maxk]
-        corr = (pred == np.repeat(label[0], self.maxk, 1))
-        self.correct = np.append(self.correct, corr, axis=0)
+    def add_metric_op(self, pred, label, *args, **kwargs):
+        pred = fluid.layers.argsort(pred[0], descending=True)[1][:, :self.maxk]
+        correct = pred == label[0]
+        return correct
+
+    def update(self, correct, *args, **kwargs):
+        accs = []
+        for i, k in enumerate(self.topk):
+            num_corrects = correct[:, :k].sum()
+            num_samples = len(correct)
+            accs.append(float(num_corrects) / num_samples)
+            self.total[i] += num_corrects
+            self.count[i] += num_samples
+        return accs
 
     def reset(self):
-        self.correct = np.empty((0, self.maxk), dtype="int32")
+        self.total = [0.] * len(self.topk)
+        self.count = [0] * len(self.topk)
 
     def accumulate(self):
         res = []
-        num_samples = self.correct.shape[0]
-        for k in self.topk:
-            correct_k = self.correct[:, :k].sum()
-            res.append(round(100.0 * correct_k / num_samples, 2))
+        for t, c in zip(self.total, self.count):
+            res.append(float(t) / c)
         return res
diff --git a/mnist.py b/mnist.py
index 91fa18d..c09fe30 100644
--- a/mnist.py
+++ b/mnist.py
@@ -150,20 +150,16 @@
 
     for e in range(FLAGS.epoch):
         train_loss = 0.0
-        train_acc = 0.0
         val_loss = 0.0
-        val_acc = 0.0
 
         print("======== train epoch {} ========".format(e))
         for idx, batch in enumerate(train_loader()):
-            outputs, losses = model.train(batch[0], batch[1], device='gpu',
+            losses, metrics = model.train(batch[0], batch[1], device='gpu',
                                           device_ids=device_ids)
-            acc = accuracy(outputs[0], batch[1])[0]
             train_loss += np.sum(losses)
-            train_acc += acc
             if idx % 10 == 0:
-                print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
-                    idx, train_loss / (idx + 1), train_acc / (idx + 1)))
+                print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
+                    idx, train_loss / (idx + 1), metrics[0][0], metrics[0][1]))
         for metric in model._metrics:
             res = metric.accumulate()
             print("train epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
@@ -171,15 +167,13 @@
 
         print("======== eval epoch {} ========".format(e))
         for idx, batch in enumerate(val_loader()):
-            outputs, losses = model.eval(batch[0], batch[1], device='gpu',
+            losses, metrics = model.eval(batch[0], batch[1], device='gpu',
                                          device_ids=device_ids)
-            acc = accuracy(outputs[0], batch[1])[0]
             val_loss += np.sum(losses)
-            val_acc += acc
             if idx % 10 == 0:
-                print("{:04d}: loss {:0.3f} top1: {:0.3f}%".format(
-                    idx, val_loss / (idx + 1), val_acc / (idx + 1)))
+                print("{:04d}: loss {:0.3f} top1: {:0.3f}% top2: {:0.3f}%".format(
+                    idx, val_loss / (idx + 1), metrics[0][0], metrics[0][1]))
         for metric in model._metrics:
             res = metric.accumulate()
             print("eval epoch {:03d}: top1: {:0.3f}%, top2: {:0.3f}".format(e, res[0], res[1]))
diff --git a/model.py b/model.py
index 3949e31..0c30e55 100644
--- a/model.py
+++ b/model.py
@@ -45,6 +45,26 @@ def to_numpy(var):
     return np.array(t)
 
 
+def flatten_list(l):
+    assert isinstance(l, list), "not a list"
+    outl = []
+    splits = []
+    for sl in l:
+        assert isinstance(sl, list), "sub content not a list"
+        splits.append(len(sl))
+        outl += sl
+    return outl, splits
+
+
+def restore_flatten_list(l, splits):
+    outl = []
+    for split in splits:
+        assert len(l) >= split, "list length invalid"
+        sl, l = l[:split], l[split:]
+        outl.append(sl)
+    return outl
+
+
 def extract_args(func):
     if hasattr(inspect, 'getfullargspec'):
         return inspect.getfullargspec(func)[0]
@@ -278,28 +298,26 @@
             feed[v.name] = labels[idx]
 
         endpoints = self._endpoints[self.mode]
-        fetch_list = endpoints['output'] + endpoints['label'] + endpoints['loss']
-        num_output = len(endpoints['output'])
-        num_label = len(endpoints['label'])
+        if self.mode == 'test':
+            fetch_list = endpoints['output']
+        else:
+            metric_list, metric_splits = flatten_list(endpoints['metric'])
+            fetch_list = endpoints['loss'] + metric_list
+        num_loss = len(endpoints['loss'])
         rets = self._executor.run(
             compiled_prog, feed=feed,
             fetch_list=fetch_list,
             return_numpy=False)
         # LoDTensor cannot be fetch as numpy directly
         rets = [np.array(v) for v in rets]
-        outputs = rets[:num_output]
-        labels = rets[num_output:num_output+num_label]
-        losses = rets[num_output+num_label:]
         if self.mode == 'test':
-            return outputs
-        elif self.mode == 'eval':
-            for metric in self.model._metrics:
-                metric.update(outputs, labels)
-            return outputs, losses
-        else: # train
-            for metric in self.model._metrics:
-                metric.update(outputs, labels)
-            return outputs, losses
+            return rets[:]
+        losses = rets[:num_loss]
+        metric_states = restore_flatten_list(rets[num_loss:], metric_splits)
+        metrics = []
+        for metric, state in zip(self.model._metrics, metric_states):
+            metrics.append(metric.update(*state))
+        return losses, metrics
 
     def _make_program(self, inputs):
         prog = self._orig_prog.clone()
@@ -314,6 +332,9 @@
         label_vars = self._infer_label_vars(outputs)
         self._label_vars[self.mode] = label_vars
         losses = self.model._loss_function(outputs, label_vars)
+        metrics = []
+        for metric in self.model._metrics:
+            metrics.append(to_list(metric.add_metric_op(outputs, label_vars)))
         if self.mode == 'train':
             self._loss_endpoint = fluid.layers.sum(losses)
             self.model._optimizer.minimize(self._loss_endpoint)
@@ -322,8 +343,8 @@
         self._progs[self.mode] = prog
         self._endpoints[self.mode] = {
             "output": outputs,
-            "label": label_vars,
             "loss": losses,
+            "metric": metrics,
         }
 
     def _infer_input_vars(self, inputs):
@@ -421,16 +442,18 @@
         self.mode = 'train'
         inputs = to_list(inputs)
         labels = to_list(labels)
-        outputs = self.model.forward(*[to_variable(x) for x in inputs])
+        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
         losses = self.model._loss_function(outputs, labels)
         final_loss = fluid.layers.sum(losses)
         final_loss.backward()
         self.model._optimizer.minimize(final_loss)
         self.model.clear_gradients()
+        metrics = []
         for metric in self.model._metrics:
-            metric.update([to_numpy(o) for o in to_list(outputs)], labels)
-        return [to_numpy(o) for o in to_list(outputs)], \
-            [to_numpy(l) for l in losses]
+            metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels])
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            metrics.append(m)
+        return [to_numpy(l) for l in losses], metrics
 
     def eval(self, inputs, labels, device='CPU', device_ids=None):
         assert self.model._loss_function, \
@@ -439,12 +462,14 @@
         self.mode = 'eval'
         inputs = to_list(inputs)
         labels = to_list(labels)
-        outputs = self.model.forward(*[to_variable(x) for x in inputs])
+        outputs = to_list(self.model.forward(*[to_variable(x) for x in inputs]))
         losses = self.model._loss_function(outputs, labels)
+        metrics = []
         for metric in self.model._metrics:
-            metric.update([to_numpy(o) for o in to_list(outputs)], labels)
-        return [to_numpy(o) for o in to_list(outputs)], \
-            [to_numpy(l) for l in losses]
+            metric_outs = metric.add_metric_op(outputs, [to_variable(l) for l in labels])
+            m = metric.update(*[to_numpy(m) for m in to_list(metric_outs)])
+            metrics.append(m)
+        return [to_numpy(l) for l in losses], metrics
 
     def test(self, inputs, device='CPU', device_ids=None):
         super(Model, self.model).eval()
--
GitLab
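
Note (outside the patch): a minimal plain-numpy sketch of the bookkeeping the reworked Accuracy metric performs — add_metric_op reduces predictions to a per-sample top-k correctness matrix, update() folds that matrix into running total/count state, and accumulate() reports the running top-k accuracy. The helper correct_matrix and the toy data below are hypothetical stand-ins; in the patch itself the correctness matrix is produced in-program via fluid.layers.argsort.

    import numpy as np

    def correct_matrix(pred, label, maxk):
        # emulate add_metric_op: indices of the top-maxk scores compared to the true label
        top = np.argsort(pred, axis=1)[:, ::-1][:, :maxk]
        return (top == label.reshape(-1, 1)).astype("int64")

    topk = (1, 2)
    total = [0.] * len(topk)   # same running state as Accuracy.reset()
    count = [0] * len(topk)

    pred = np.array([[0.1, 0.7, 0.2],   # top-1 -> class 1
                     [0.8, 0.1, 0.1],   # top-1 -> class 0
                     [0.3, 0.3, 0.4]])  # top-1 -> class 2
    label = np.array([1, 2, 2])

    correct = correct_matrix(pred, label, maxk=max(topk))
    for i, k in enumerate(topk):        # mirrors Accuracy.update()
        total[i] += correct[:, :k].sum()
        count[i] += len(correct)

    # mirrors Accuracy.accumulate(): running top-1 / top-2 accuracy
    print([float(t) / c for t, c in zip(total, count)])  # [0.666..., 1.0]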