提交 1e1b4d0f 编写于 作者: D dongshuilong

fix map metric bugs

上级 5b3d4031
......@@ -104,7 +104,7 @@ class mINP(nn.Layer):
keep_mask.astype('float32'), choosen_indices)
equal_flag = paddle.logical_and(equal_flag,
keep_mask.astype('bool'))
equal_flag = paddle.cast(equal_flag, 'float64')
equal_flag = paddle.cast(equal_flag, 'float32')
num_rel = paddle.sum(equal_flag, axis=1)
num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.))
......@@ -113,10 +113,10 @@ class mINP(nn.Layer):
equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0)
#do accumulative sum
div = paddle.arange(equal_flag.shape[1]).astype("float64") + 2
div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2
minus = paddle.divide(equal_flag, div)
auxilary = paddle.subtract(equal_flag, minus)
hard_index = paddle.argmax(auxilary, axis=1).astype("float64")
hard_index = paddle.argmax(auxilary, axis=1).astype("float32")
all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_index)
mINP = paddle.mean(all_INP)
metric_dict["mINP"] = mINP.numpy()[0]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册