提交 83b22941 编写于 作者: G Genieliu

1 metric update should be in the for loop. 2 sample_num is the shape's first...

1 metric update should be in the for loop. 2 sample_num is the shape's first element 3 preds cast should use round, astype("int32") method will cast all the elements(which is less than 1) to 0.
上级 d4f63d82
...@@ -241,9 +241,10 @@ class Precision(MetricBase): ...@@ -241,9 +241,10 @@ class Precision(MetricBase):
raise ValueError("The 'preds' must be a numpy ndarray.") raise ValueError("The 'preds' must be a numpy ndarray.")
if not _is_numpy_(labels): if not _is_numpy_(labels):
raise ValueError("The 'labels' must be a numpy ndarray.") raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels[0] sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num): for i in range(sample_num):
pred = preds[i].astype("int32") pred = preds[i]
label = labels[i] label = labels[i]
if label == 1: if label == 1:
if pred == label: if pred == label:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册