提交 215f88f3 编写于 作者: M Megvii Engine Team

fix(dnn/argmxx): fix argmxx on inf

GitOrigin-RevId: 740f67b73a242b9254699a4a09835d4c3c11ca02
上级 03d0cc02
......@@ -56,7 +56,7 @@ struct ArgmxxOp {
ArgmxxOp(stype_ *src, dt_int32 *dst, uint32_t A, uint32_t B, uint32_t C):
src(src), dst(dst), A(A), B(B), C(C),
INIT(wtype(is_max ? DTypeTrait<stype_>::min() :
DTypeTrait<stype_>::max(), -1))
DTypeTrait<stype_>::max(), 0))
{
}
MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx)
......
......@@ -45,7 +45,7 @@ void exec_forward(_megdnn_tensor_in src,
reduce::get_ABC(src.layout, A, B, C, param.axis);
for (size_t a = 0; a < A; ++a) for (size_t c = 0; c < C; ++c) {
float best_val = traits<is_max>::init;
size_t best_arg = -1;
size_t best_arg = 0;
for (size_t b = 0; b < B; ++b) {
float curr_val = float(src.ptr<T>()[(a*B+b)*C+c]);
if (traits<is_max>::better_than(curr_val, best_val)) {
......
......@@ -527,3 +527,20 @@ def test_nms_is_same():
assert op3 != op4
def test_argmxx_on_inf():
def run_argmax():
x = F.zeros((100, 100))
x[:] = -float("inf")
idxs = F.argmax(x, axis=0)
return idxs
def run_argmin():
x = F.zeros((100, 100))
x[:] = float("inf")
idxs = F.argmin(x, axis=0)
return idxs
assert all(run_argmax() >= 0)
assert all(run_argmin() >= 0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册