diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 643d2a6cd08e46941f34fd660b1f160e46063c6c..ce1e6a93fcc3f8c07187d4bf51468c9691ee9ee2 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -119,6 +119,15 @@ __device__ __host__ inline float dispatch_powf(float x, float y) { return powf(x, y); } +__device__ __host__ inline int dispatch_floordiv_int(int x, int y) { + if ((x ^ y) < 0) { + const auto quot = x / y; + const auto rem = x % y; + return rem ? quot - 1 : quot; + } + return x / y; +} + #include "src/common/elemwise/each_mode.inl" template @@ -227,7 +236,7 @@ DEF_KERN(dt_bool, LT, x < y); DEF_KERN(dt_bool, LEQ, x <= y); DEF_KERN(dt_bool, EQ, x == y); -DEF_KERN_INT(FLOOR_DIV, x / y); +DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); DEF_KERN_INT(MOD, x % y); diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 782c000fdfa2dd158d37d4ac95e7ef57de51a318..f4c1788b51b359cf5a040871a7c816b519339fb2 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -59,7 +59,7 @@ def test_multiply(): def test_div(): np.testing.assert_allclose( - F.div(tensor([3, 4]), 2).numpy(), + F.div(tensor([3.0, 4.0]), 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), ) @@ -67,6 +67,16 @@ def test_div(): (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), ) + np.testing.assert_allclose( + F.floor_div(tensor([-5.0, -7.0]), 2).numpy(), + np.floor_divide(np.array([-5.0, -7.0], dtype=np.float32), 2), + ) + + np.testing.assert_allclose( + (tensor([-5, -7]) // 2).numpy(), + np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), + ) + def test_clamp(): """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index fe964ec5dd17b4aefde8ff923a68531f29e8164e..2ed8d8bd2879d3cb12536bee7f32e36a6bba54bb 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -39,6 +39,19 @@ int do_mod(int a, int b) { return a % b; } +float do_floor_div(float a, float b) { + return std::floor(a / b); +} + +int do_floor_div(int a, int b) { + if ((a ^ b) < 0) { + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + return a / b; +} + float do_erfinv(float x) { return erfinvf(x); } diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 20fecb1539df61a46cacf3c25bfd726ac55b1513..e02bc50e8a76fe9ed8f9eb7cd5d29d7cf2e3fc5e 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -41,7 +41,7 @@ DEF_TRAIT(LT, x < y) #define _ALLOW_INT true DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) DEF_TRAIT(ADD, x + y) -DEF_TRAIT(FLOOR_DIV, floor(x / y)) +DEF_TRAIT(FLOOR_DIV, do_floor_div(x, y)) DEF_TRAIT(MAX, std::max(x, y)) DEF_TRAIT(MIN, std::min(x, y)) DEF_TRAIT(MOD, do_mod(x, y))