From 83b22941af652044770c90442f3ed67e7fe08a3f Mon Sep 17 00:00:00 2001 From: Genieliu Date: Fri, 29 Mar 2019 12:49:33 +0800 Subject: [PATCH] 1 metric update should be in the for loop. 2 sample_num is the shape's first element 3 preds cast should use round, astype("int32") method will cast all the elements(which is less than 1) to 0. --- python/paddle/fluid/metrics.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index fd07ff0ba3d..dfe8259139d 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -227,7 +227,7 @@ class Precision(MetricBase): metric.reset() for data in train_reader(): loss, preds, labels = exe.run(fetch_list=[cost, preds, labels]) - metric.update(preds=preds, labels=labels) + metric.update(preds=preds, labels=labels) numpy_precision = metric.eval() """ @@ -241,9 +241,10 @@ class Precision(MetricBase): raise ValueError("The 'preds' must be a numpy ndarray.") if not _is_numpy_(labels): raise ValueError("The 'labels' must be a numpy ndarray.") - sample_num = labels[0] + sample_num = labels.shape[0] + preds = np.rint(preds).astype("int32") for i in range(sample_num): - pred = preds[i].astype("int32") + pred = preds[i] label = labels[i] if label == 1: if pred == label: -- GitLab