diff --git a/python/paddle/fluid/metrics.py b/python/paddle/fluid/metrics.py index fd07ff0ba3d21721fbbc46099f7dcb6937f93524..c7c82f28e7c441b4aa24ffa81a8695e565d737d8 100644 --- a/python/paddle/fluid/metrics.py +++ b/python/paddle/fluid/metrics.py @@ -227,7 +227,7 @@ class Precision(MetricBase): metric.reset() for data in train_reader(): loss, preds, labels = exe.run(fetch_list=[cost, preds, labels]) - metric.update(preds=preds, labels=labels) + metric.update(preds=preds, labels=labels) numpy_precision = metric.eval() """ @@ -241,9 +241,11 @@ class Precision(MetricBase): raise ValueError("The 'preds' must be a numpy ndarray.") if not _is_numpy_(labels): raise ValueError("The 'labels' must be a numpy ndarray.") - sample_num = labels[0] + sample_num = labels.shape[0] + preds = np.rint(preds).astype("int32") + for i in range(sample_num): - pred = preds[i].astype("int32") + pred = preds[i] label = labels[i] if label == 1: if pred == label: