未验证 提交 220190d5 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #16543 from Genie-Liu/fix-metrics.Precision

fix fluid.metrics.Precision bug
...@@ -227,7 +227,7 @@ class Precision(MetricBase): ...@@ -227,7 +227,7 @@ class Precision(MetricBase):
metric.reset() metric.reset()
for data in train_reader(): for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels]) loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels) metric.update(preds=preds, labels=labels)
numpy_precision = metric.eval() numpy_precision = metric.eval()
""" """
...@@ -241,9 +241,11 @@ class Precision(MetricBase): ...@@ -241,9 +241,11 @@ class Precision(MetricBase):
raise ValueError("The 'preds' must be a numpy ndarray.") raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels): if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.") raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0] sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num): for i in range(sample_num):
pred = preds[i].astype("int32") pred = preds[i]
label = labels[i] label = labels[i]
if label == 1: if label == 1:
if pred == label: if pred == label:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册