diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 9112f1b9b5c2d5fdb5f8abf945c666674fa76d40..9908b3541d612f985e4bd0bbcecbcf488e8ec719 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -104,7 +104,7 @@ class mINP(nn.Layer): keep_mask.astype('float32'), choosen_indices) equal_flag = paddle.logical_and(equal_flag, keep_mask.astype('bool')) - equal_flag = paddle.cast(equal_flag, 'float64') + equal_flag = paddle.cast(equal_flag, 'float32') num_rel = paddle.sum(equal_flag, axis=1) num_rel = paddle.greater_than(num_rel, paddle.to_tensor(0.)) @@ -113,10 +113,10 @@ class mINP(nn.Layer): equal_flag = paddle.index_select(equal_flag, num_rel_index, axis=0) #do accumulative sum - div = paddle.arange(equal_flag.shape[1]).astype("float64") + 2 + div = paddle.arange(equal_flag.shape[1]).astype("float32") + 2 minus = paddle.divide(equal_flag, div) auxilary = paddle.subtract(equal_flag, minus) - hard_index = paddle.argmax(auxilary, axis=1).astype("float64") + hard_index = paddle.argmax(auxilary, axis=1).astype("float32") all_INP = paddle.divide(paddle.sum(equal_flag, axis=1), hard_index) mINP = paddle.mean(all_INP) metric_dict["mINP"] = mINP.numpy()[0]