From d9a46ea47bbb0dcb6eaedc35bd61ca03d75d9449 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 18 Nov 2021 14:40:01 +0800 Subject: [PATCH] fix(dnn): correct behaviour of floor div for int tensor GitOrigin-RevId: 1444f69cce7fea7c3bf34206bce239caad6cc1d0 --- dnn/src/common/elemwise/kern_defs.cuh | 11 ++++++++++- .../python/test/unit/functional/test_elemwise.py | 12 +++++++++++- src/opr/test/basic_arith/elemwise.cpp | 13 +++++++++++++ .../test/basic_arith/elemwise_binary_trait_def.inl | 2 +- 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/dnn/src/common/elemwise/kern_defs.cuh b/dnn/src/common/elemwise/kern_defs.cuh index 643d2a6cd..ce1e6a93f 100644 --- a/dnn/src/common/elemwise/kern_defs.cuh +++ b/dnn/src/common/elemwise/kern_defs.cuh @@ -119,6 +119,15 @@ __device__ __host__ inline float dispatch_powf(float x, float y) { return powf(x, y); } +__device__ __host__ inline int dispatch_floordiv_int(int x, int y) { + if ((x ^ y) < 0) { + const auto quot = x / y; + const auto rem = x % y; + return rem ? quot - 1 : quot; + } + return x / y; +} + #include "src/common/elemwise/each_mode.inl" template @@ -227,7 +236,7 @@ DEF_KERN(dt_bool, LT, x < y); DEF_KERN(dt_bool, LEQ, x <= y); DEF_KERN(dt_bool, EQ, x == y); -DEF_KERN_INT(FLOOR_DIV, x / y); +DEF_KERN_INT(FLOOR_DIV, dispatch_floordiv_int(x, y)); DEF_KERN_FLOAT(FLOOR_DIV, floorf(x / y)); DEF_KERN_INT(MOD, x % y); diff --git a/imperative/python/test/unit/functional/test_elemwise.py b/imperative/python/test/unit/functional/test_elemwise.py index 782c000fd..f4c1788b5 100644 --- a/imperative/python/test/unit/functional/test_elemwise.py +++ b/imperative/python/test/unit/functional/test_elemwise.py @@ -59,7 +59,7 @@ def test_multiply(): def test_div(): np.testing.assert_allclose( - F.div(tensor([3, 4]), 2).numpy(), + F.div(tensor([3.0, 4.0]), 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), ) @@ -67,6 +67,16 @@ def test_div(): (tensor([3, 4]) / 2).numpy(), np.divide(np.array([3, 4], dtype=np.float32), 2), ) + np.testing.assert_allclose( + F.floor_div(tensor([-5.0, -7.0]), 2).numpy(), + np.floor_divide(np.array([-5.0, -7.0], dtype=np.float32), 2), + ) + + np.testing.assert_allclose( + (tensor([-5, -7]) // 2).numpy(), + np.floor_divide(np.array([-5, -7], dtype=np.int32), 2), + ) + def test_clamp(): """Fix an issue when `lower` or `upper` is 0, it will be recognized as `False` and diff --git a/src/opr/test/basic_arith/elemwise.cpp b/src/opr/test/basic_arith/elemwise.cpp index fe964ec5d..2ed8d8bd2 100644 --- a/src/opr/test/basic_arith/elemwise.cpp +++ b/src/opr/test/basic_arith/elemwise.cpp @@ -39,6 +39,19 @@ int do_mod(int a, int b) { return a % b; } +float do_floor_div(float a, float b) { + return std::floor(a / b); +} + +int do_floor_div(int a, int b) { + if ((a ^ b) < 0) { + const auto quot = a / b; + const auto rem = a % b; + return rem ? quot - 1 : quot; + } + return a / b; +} + float do_erfinv(float x) { return erfinvf(x); } diff --git a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl index 20fecb153..e02bc50e8 100644 --- a/src/opr/test/basic_arith/elemwise_binary_trait_def.inl +++ b/src/opr/test/basic_arith/elemwise_binary_trait_def.inl @@ -41,7 +41,7 @@ DEF_TRAIT(LT, x < y) #define _ALLOW_INT true DEF_TRAIT(ABS_GRAD, x > 0 ? y : -y) DEF_TRAIT(ADD, x + y) -DEF_TRAIT(FLOOR_DIV, floor(x / y)) +DEF_TRAIT(FLOOR_DIV, do_floor_div(x, y)) DEF_TRAIT(MAX, std::max(x, y)) DEF_TRAIT(MIN, std::min(x, y)) DEF_TRAIT(MOD, do_mod(x, y)) -- GitLab