提交 83b22941 编写于 作者: G Genieliu

1 metric update should be in the for loop. 2 sample_num is the shape's first...

1 metric update should be in the for loop. 2 sample_num is the shape's first element 3 preds cast should use round, astype("int32") method will cast all the elements(which is less than 1) to 0.
上级 d4f63d82
......@@ -227,7 +227,7 @@ class Precision(MetricBase):
metric.reset()
for data in train_reader():
loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
metric.update(preds=preds, labels=labels)
metric.update(preds=preds, labels=labels)
numpy_precision = metric.eval()
"""
......@@ -241,9 +241,10 @@ class Precision(MetricBase):
raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0]
sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num):
pred = preds[i].astype("int32")
pred = preds[i]
label = labels[i]
if label == 1:
if pred == label:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册