From 89186edc5d463d37b34a990ae307e11f667de9f1 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 2 Nov 2021 17:38:13 +0800
Subject: [PATCH] fix(dnn): correct reduce/argmxx/fakequant calculation with nan

GitOrigin-RevId: 7e78bdae9106186c5d1a1b8ee2ab337ed69db21b
---
 dnn/src/common/argmxx_helper.h                 | 66 +++++++++++++++++++
 dnn/src/common/reduce_helper.h                 | 44 +++++++++++++
 dnn/src/cuda/fake_quant/kern.cuh               | 12 +++-
 dnn/src/naive/argmxx/opr_impl.cpp              |  8 ++-
 dnn/src/naive/reduce/opr_impl.cpp              | 26 +++++---
 dnn/test/cuda/fake_quant.cpp                   | 52 ++++++++++++++-
 dnn/test/cuda/reduce.cpp                       | 14 ++++
 .../python/test/unit/functional/test_math.py  |  6 +-
 .../test/unit/quantization/test_fake_quant.py |  5 ++
 9 files changed, 219 insertions(+), 14 deletions(-)

diff --git a/dnn/src/common/argmxx_helper.h b/dnn/src/common/argmxx_helper.h
index bcd7a2f37..9544cc8d3 100644
--- a/dnn/src/common/argmxx_helper.h
+++ b/dnn/src/common/argmxx_helper.h
@@ -78,6 +78,72 @@ struct ArgmxxOp {
     const wtype INIT;
 };
 
+template <bool is_max>
+struct ArgmxxOp<dt_float32, is_max> {
+    using stype_ = dt_float32;
+    struct wtype {
+        stype_ key;
+        dt_int32 val;
+        MEGDNN_HOST MEGDNN_DEVICE wtype() {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(stype_ key, dt_int32 val)
+                : key(key), val(val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(wtype& rhs) : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE wtype(const volatile wtype& rhs)
+                : key(rhs.key), val(rhs.val) {}
+        MEGDNN_HOST MEGDNN_DEVICE volatile wtype& operator=(const wtype& rhs) volatile {
+            this->key = rhs.key;
+            this->val = rhs.val;
+            return *this;
+        }
+    };
+    MEGDNN_HOST MEGDNN_DEVICE
+    ArgmxxOp(stype_* src, dt_int32* dst, uint32_t A, uint32_t B, uint32_t C)
+            : src(src),
+              dst(dst),
+              A(A),
+              B(B),
+              C(C),
+              INIT(wtype(
+                      is_max ? DTypeTrait<dt_float32>::min() : DTypeTrait<dt_float32>::max(),
+                      0)) {}
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) {
+        wtype res;
+        res.key = src[idx];
+        res.val = idx / C % B;
+        return res;
+    }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) {
+        dst[idx] = val.val;
+    }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        if (isnan(lhs.key))
+#else
+        if (std::isnan(lhs.key))
+#endif
+            return lhs;
+        if (is_max) {
+            if (lhs.key > rhs.key)
+                return lhs;
+            else
+                return rhs;
+        } else {
+            if (lhs.key < rhs.key)
+                return lhs;
+            else
+                return rhs;
+        }
+    }
+    stype_* src;
+    dt_int32* dst;
+    uint32_t A, B, C;
+    const wtype INIT;
+};
+
 }  // namespace argmxx
 }  // namespace megdnn
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/common/reduce_helper.h b/dnn/src/common/reduce_helper.h
index 14a0f6899..b72340f61 100644
--- a/dnn/src/common/reduce_helper.h
+++ b/dnn/src/common/reduce_helper.h
@@ -119,6 +119,28 @@ struct MinOp {
             : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MinOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs < rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MinOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::max())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename dst_ctype, typename wtype_>
 struct MaxOp {
     typedef wtype_ wtype;
@@ -141,6 +163,28 @@ struct MaxOp {
             : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
 };
 
+template <typename src_ctype, typename dst_ctype>
+struct MaxOp<src_ctype, dst_ctype, dt_float32> {
+    typedef dt_float32 wtype;
+    const wtype INIT;
+
+    src_ctype* src;
+    dst_ctype* dst;
+    const size_t B;
+
+    MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { return src[idx]; }
+    MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { dst[idx] = val; }
+    static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) {
+#if defined(__CUDA_ARCH__)
+        return (isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#else
+        return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
+#endif
+    }
+    MEGDNN_HOST MEGDNN_DEVICE MaxOp(src_ctype* src, dst_ctype* dst, size_t B)
+            : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {}
+};
+
 template <typename src_ctype, typename wtype_>
 struct CheckNonFiniteOp {
     typedef wtype_ wtype;
diff --git a/dnn/src/cuda/fake_quant/kern.cuh b/dnn/src/cuda/fake_quant/kern.cuh
index efc952583..bbb60affb 100644
--- a/dnn/src/cuda/fake_quant/kern.cuh
+++ b/dnn/src/cuda/fake_quant/kern.cuh
@@ -30,6 +30,10 @@ struct FakeQuantKernOp {
 
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
+        if (isnan(x)) {
+            output[idx] = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output[idx] = (x - zero_point) * scale;
     }
@@ -54,7 +58,7 @@ struct FakeQuantBwdKernOp {
 
     __device__ void operator()(uint32_t idx, ctype scale, ctype zero_point) {
         ctype x = round(input[idx] / scale) + zero_point;
-        grad[idx] = x <= qmax && x >= qmin ? diff[idx] : 0.0;
+        grad[idx] = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff[idx] : 0.0;
     }
 
 #if MEGDNN_CC_HOST
@@ -77,6 +81,10 @@ struct FakeQuantKernOpNonContig {
     __device__ void operator()(
             uint32_t, ctype& output, ctype input, ctype scale, ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
+        if (isnan(x)) {
+            output = NAN;
+            return;
+        }
         x = fmaxf(fminf(x, qmax), qmin);
         output = (x - zero_point) * scale;
     }
@@ -96,7 +104,7 @@ struct FakeQuantBwdKernOpNonContig {
             uint32_t, ctype& grad, ctype diff, ctype input, ctype scale,
             ctype zero_point) {
         ctype x = round(input / scale) + zero_point;
-        grad = x <= qmax && x >= qmin ? diff : 0.0;
+        grad = isnan(x) ? NAN : x <= qmax && x >= qmin ? diff : 0.0;
     }
 
 #if MEGDNN_CC_HOST
diff --git a/dnn/src/naive/argmxx/opr_impl.cpp b/dnn/src/naive/argmxx/opr_impl.cpp
index 2dc0c662f..ec2737cd7 100644
--- a/dnn/src/naive/argmxx/opr_impl.cpp
+++ b/dnn/src/naive/argmxx/opr_impl.cpp
@@ -26,14 +26,18 @@ struct traits;
 template <>
 struct traits<true> {
     static const float init;
-    static bool better_than(float lhs, float rhs) { return lhs > rhs; }
+    static bool better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs > rhs;
+    }
 };
 const float traits<true>::init = std::numeric_limits<float>::lowest();
 
 template <>
 struct traits<false> {
     static const float init;
-    static float better_than(float lhs, float rhs) { return lhs < rhs; }
+    static float better_than(float lhs, float rhs) {
+        return std::isnan(lhs) ? true : lhs < rhs;
+    }
 };
 const float traits<false>::init = std::numeric_limits<float>::max();
 
diff --git a/dnn/src/naive/reduce/opr_impl.cpp b/dnn/src/naive/reduce/opr_impl.cpp
index 3ec8a3bbe..c1dd5fa76 100644
--- a/dnn/src/naive/reduce/opr_impl.cpp
+++ b/dnn/src/naive/reduce/opr_impl.cpp
@@ -73,25 +73,35 @@ const ctype Trait<Mode::PRODUCT, ctype>::INIT = ctype(1);
 
 template <typename ctype>
 struct Trait<Mode::MIN, ctype> {
-    static const ctype INIT;
-
     static ctype apply(ctype x, ctype y) { return x < y ? x : y; }
     static ctype visit(ctype x) { return x; }
     static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MIN, ctype>::INIT = DTypeTrait<ctype>::max();
+
+template <>
+struct Trait<Mode::MIN, dt_float32> {
+    using ctype = dt_float32;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x < y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <typename ctype>
 struct Trait<Mode::MAX, ctype> {
-    static const ctype INIT;
-
     static ctype apply(ctype x, ctype y) { return x > y ? x : y; }
     static ctype visit(ctype x) { return x; }
    static ctype write(ctype x, size_t) { return x; }
 };
-template <typename ctype>
-const ctype Trait<Mode::MAX, ctype>::INIT = DTypeTrait<ctype>::min();
+
+template <>
+struct Trait<Mode::MAX, dt_float32> {
+    using ctype = dt_float32;
+
+    static ctype apply(ctype x, ctype y) { return (std::isnan(x) || x > y) ? x : y; }
+    static ctype visit(ctype x) { return x; }
+    static ctype write(ctype x, size_t) { return x; }
+};
 
 template <typename ctype>
 void reduce_fwd(
diff --git a/dnn/test/cuda/fake_quant.cpp b/dnn/test/cuda/fake_quant.cpp
index c024cefd9..e55557a09 100644
--- a/dnn/test/cuda/fake_quant.cpp
+++ b/dnn/test/cuda/fake_quant.cpp
@@ -21,7 +21,9 @@ using namespace fake_quant;
 TEST_F(CUDA, FAKE_QUANT) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
-    std::unique_ptr<RNG> rng;
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
 
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -35,6 +37,17 @@ TEST_F(CUDA, FAKE_QUANT) {
                 .set_dtype(2, dtype)
                 .set_dtype(3, dtype)
                 .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .execs(TensorShapeArray{ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -53,12 +66,25 @@ TEST_F(CUDA, FAKE_QUANT) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
 
 TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
     std::vector<TestArg> args = get_args();
     auto dtype = dtype::Float32();
+    UniformFloatRNG rng(-1.0f, 1.0f);
+    const auto nan = std::numeric_limits<float>::quiet_NaN();
+    UniformFloatWithValueRNG rng1 = UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
 
     for (auto&& arg : args) {
         auto param = arg.param;
@@ -74,6 +100,19 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 .set_dtype(4, dtype)
                 .execs(TensorShapeArray{
                         ishape, ishape, scale_shape, zeropoint_shape, ishape});
+
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param)
+                .set_dtype(0, dtype)
+                .set_dtype(1, dtype)
+                .set_dtype(2, dtype)
+                .set_dtype(3, dtype)
+                .set_dtype(4, dtype)
+                .execs(TensorShapeArray{
+                        ishape, ishape, scale_shape, zeropoint_shape, ishape});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
     // test noncontiguous layout
     for (auto&& arg : args) {
@@ -93,6 +132,17 @@ TEST_F(CUDA, FAKE_QUANT_BACKWARD) {
                 {scale_shape, dtype::Float32()},
                 {zeropoint_shape, dtype::Float32()},
                 ilayout});
+
+        checker.set_allow_invalid_check(true);
+        checker.set_rng(0, &rng1);
+        checker.set_param(param).execl(
+                {ilayout,
+                 ilayout,
+                 {scale_shape, dtype::Float32()},
+                 {zeropoint_shape, dtype::Float32()},
+                 ilayout});
+        checker.set_rng(0, &rng);
+        checker.set_allow_invalid_check(false);
     }
 }
 
diff --git a/dnn/test/cuda/reduce.cpp b/dnn/test/cuda/reduce.cpp
index b1696e93b..bf6b091d9 100644
--- a/dnn/test/cuda/reduce.cpp
+++ b/dnn/test/cuda/reduce.cpp
@@ -54,6 +54,20 @@ TEST_F(CUDA, REDUCE) {
     // very large reduce
     checker.execs({{1, 4194304, 1}, {}});
 
+    // inputs have nan
+    {
+        const auto nan = std::numeric_limits<float>::quiet_NaN();
+        UniformFloatWithValueRNG rng1 =
+                UniformFloatWithValueRNG(-1.0f, 1.0f, 0.5f, nan);
+        checker.set_allow_invalid_check(true).set_rng(0, &rng1);
+        for (auto mode : {Mode::MIN, Mode::MAX}) {
+            checker.set_param({mode, 1});
+            checker.execs({{2, 64, 32}, {}});
+        }
+        checker.set_allow_invalid_check(false);
+    }
+    checker.set_rng(0, &rng);
+
     auto check = [&](Reduce::Mode mode, DType src_dtype, DType dst_dtype,
                      Reduce::DataType data_type) {
         for (int32_t axis : {0, 1, 2, 3}) {
diff --git a/imperative/python/test/unit/functional/test_math.py b/imperative/python/test/unit/functional/test_math.py
index e5cc03811..a59e64938 100644
--- a/imperative/python/test/unit/functional/test_math.py
+++ b/imperative/python/test/unit/functional/test_math.py
@@ -21,7 +21,11 @@ def common_test_reduce(opr, ref_opr):
     data2_shape = (2, 9, 12)
     data1 = np.random.random(data1_shape).astype(np.float32)
     data2 = np.random.random(data2_shape).astype(np.float32)
-    cases = [{"input": data1}, {"input": data2}]
+    cases = [
+        {"input": data1},
+        {"input": data2},
+        {"input": np.array([[[1, 2, np.nan, 4], [8, 6, 5, 2], [2, 3, 4, 5]]])},
+    ]
 
     if opr not in (F.argmin, F.argmax):
         # test default axis
diff --git a/imperative/python/test/unit/quantization/test_fake_quant.py b/imperative/python/test/unit/quantization/test_fake_quant.py
index 8ee961ed7..96f6f4fc0 100644
--- a/imperative/python/test/unit/quantization/test_fake_quant.py
+++ b/imperative/python/test/unit/quantization/test_fake_quant.py
@@ -143,6 +143,11 @@ def test_fakequant():
         assert np.allclose(x.grad.numpy(), x1.grad.numpy())
         assert make_shape_tuple(x.grad.shape) == make_shape_tuple(x1.grad.shape)
 
+        # test nan
+        x = F.full((1, 32, 3, 3), np.nan)
+        y = fake_quant_tensor(x, qparams).numpy()
+        assert np.isnan(y).all()
+
     zero_point = tensor([1.0], dtype=np.float32)
     scale = tensor([4.0], dtype=np.float32)
     run(zero_point, scale)
--
GitLab
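
The NaN handling in this patch comes down to one IEEE-754 property: every ordered comparison involving NaN is false, so a plain "x > y ? x : y" combine silently drops a NaN that has already reached the accumulator. Checking the left operand with isnan() first, as the patched MinOp/MaxOp/ArgmxxOp::apply do, makes the NaN sticky so it propagates to the reduction result. The snippet below is a minimal standalone sketch of that difference, not code from MegDNN; the helper names max_plain and max_nan_aware are made up for illustration.

// nan_reduce_sketch.cpp -- standalone illustration, not part of the patch.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Plain combine: any comparison with NaN is false, so NaN can be lost.
static float max_plain(float lhs, float rhs) {
    return lhs > rhs ? lhs : rhs;
}

// NaN-aware combine, mirroring the patched MaxOp::apply: once the running
// value (lhs) is NaN it is kept, so the final result stays NaN.
static float max_nan_aware(float lhs, float rhs) {
    return (std::isnan(lhs) || lhs > rhs) ? lhs : rhs;
}

int main() {
    const std::vector<float> data = {1.0f, std::nanf(""), 3.0f};
    float plain = data[0], nan_aware = data[0];
    for (std::size_t i = 1; i < data.size(); ++i) {
        plain = max_plain(plain, data[i]);
        nan_aware = max_nan_aware(nan_aware, data[i]);
    }
    // plain ends up 3.0 (the NaN picked up at step 1 is dropped at step 2),
    // nan_aware ends up NaN.
    std::printf("plain: %f  nan-aware: %f\n", plain, nan_aware);
    return 0;
}

The fake-quant changes apply the same idea per element rather than per reduction: when the scaled input is NaN, the forward output and the backward gradient are set to NAN directly instead of being clamped into [qmin, qmax], which is what the new nan case in test_fakequant asserts.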