From 0ebd4400d5d360d055858e06f708698365c3bba8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 3 Nov 2022 14:21:40 +0800 Subject: [PATCH] fix(dnn): fix the modulo of int GitOrigin-RevId: 6f7280246b7fcc10a972fa7ebfd856a968b153fd --- dnn/src/common/elemwise/kern_defs.cuh | 2 +- dnn/test/common/elemwise.cpp | 2 +- dnn/test/naive/elemwise_multi_type.cpp | 20 ++++++++++++++++++++ src/opr/test/basic_arith/elemwise.cpp | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 87788c0d8..53953f227 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -229,7 +229,7 @@ DEF_KERN(dt_bool, EQ, x == y); DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); -DEF_KERN_INT(MOD, x % y); +DEF_KERN_INT(MOD, ((y + x % y) % y)); // consistent with python modulo DEF_KERN_FLOAT(MOD, fmodf(x, y)); DEF_KERN_INT(SHL, x << y); diff --git a/dnn/test/common/elemwise.cpp b/dnn/test/common/elemwise.cpp index 31f9bcf02..602954c6a 100644 --- a/dnn/test/common/elemwise.cpp +++ b/dnn/test/common/elemwise.cpp @@ -878,8 +878,8 @@ DEF_TEST(all_modes) { } while (0) if (trait.allow_int) { - run(dtype::Int8{}); run(dtype::Int32{}); + run(dtype::Int8{}); } if (trait.allow_float) { DNN_FLOAT16_SELECT( diff --git a/dnn/test/naive/elemwise_multi_type.cpp b/dnn/test/naive/elemwise_multi_type.cpp index cf81591ba..dae114823 100644 --- a/dnn/test/naive/elemwise_multi_type.cpp +++ b/dnn/test/naive/elemwise_multi_type.cpp @@ -280,4 +280,24 @@ TEST_F(NAIVE, ELEMWISE_QUANTIZED_MODE_TERNARY) { } } +TEST_F(NAIVE, ELELMWISE_INT_MODULO) { + Checker checker(handle(), /* check_dispatch */ false); + Elemwise::Param param; + param.mode = Elemwise::Param::Mode::MOD; + + checker.set_param(param).exect( + Testcase{ + TensorValue( + {10}, dtype::Int32(), + {10, 24, -6, -20, 10, -90, 45, 3, -1, 0}), + TensorValue( + {10}, dtype::Int32(), {3, 7, 5, -3, -6, 11, 7, -1, 8, -1}), + {}}, + Testcase{ + {}, + {}, + TensorValue( + {10}, dtype::Int32(), {1, 3, 4, -2, -2, 9, 3, 0, 7, 0})}); +} + // vim: syntax=cpp.doxygen diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index b4bab90d5..e9342d3d2 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -25,7 +25,7 @@ float do_mod(float a, float b) { } int do_mod(int a, int b) { - return a % b; + return (a % b + b) % b; } float do_floor_div(float a, float b) { -- GitLab