提交 0be1ce58 编写于 作者: G Genieliu

test=develop

上级 83b22941
......@@ -243,6 +243,7 @@ class Precision(MetricBase):
raise ValueError("The 'labels' must be a numpy ndarray.")
sample_num = labels.shape[0]
preds = np.rint(preds).astype("int32")
for i in range(sample_num):
pred = preds[i]
label = labels[i]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册