From 915b892a15536d7cdeeb97f6b9d7386693c19129 Mon Sep 17 00:00:00 2001 From: Liufang Sang Date: Wed, 18 Mar 2020 21:24:46 -0500 Subject: [PATCH] Fix div zero in fake quantize op (#22966) * fix div zero test=develop * fix div zero test=develop * add hostdevice function test=develop * add eps when is zero test=develop --- paddle/fluid/operators/fake_quantize_op.cc | 10 ++++-- paddle/fluid/operators/fake_quantize_op.cu | 10 ++++-- paddle/fluid/operators/fake_quantize_op.h | 7 ++++ .../tests/unittests/test_fake_quantize_op.py | 34 +++++++++++++++++++ 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 085356f77d..292a69e82b 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -58,11 +58,12 @@ struct ClipAndFakeQuantFunctor { const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, framework::Tensor* out) { T s = scale.data()[0]; + T inv_s = inverse(s); platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); - out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round(); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); } }; @@ -74,12 +75,14 @@ struct ClipAndFakeQuantDequantFunctor { const framework::Tensor& in, const framework::Tensor& scale, const int bin_cnt, framework::Tensor* out) { T s = scale.data()[0]; + T inv_s = inverse(s); + platform::Transform trans; trans(ctx, in.data(), in.data() + in.numel(), out->mutable_data(ctx.GetPlace()), ClipFunctor(-s, s)); auto out_e = framework::EigenVector::Flatten(*out); out_e.device(*ctx.eigen_device()) = - (s / bin_cnt) * (bin_cnt / s * out_e).round(); + (s / bin_cnt) * (bin_cnt * inv_s * out_e).round(); } }; template struct ClipAndFakeQuantDequantFunctor { } for (int i = 0; i < channel; i++) { T s = scale_data[i]; + T inv_s = inverse(s); framework::Tensor one_channel_out = out->Slice(i, i + 1); auto out_e = framework::EigenVector::Flatten(one_channel_out); - out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round(); + out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round(); } } }; diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index e9a7201bc0..c8182f3a9a 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -120,11 +120,12 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale, int tid = threadIdx.x; T s = scale[0]; + T inv_s = inverse(s); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T x = in[i]; T v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt / s * v; + v = bin_cnt * inv_s * v; out[i] = round(v); } } @@ -139,9 +140,10 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, T s = scale[0]; for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T x = in[i]; + T inv_s = inverse(s); T v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt / s * v; + v = bin_cnt * inv_s * v; out[i] = round(v) * s / bin_cnt; } } @@ -198,11 +200,13 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale, T* out_c = out + blockIdx.x * channel_size; T s = scale[blockIdx.x]; + T inv_s = inverse(s); + for (int i = tid; i < channel_size; i += blockDim.x) { T x = in_c[i]; T v = x > s ? s : x; v = v < -s ? -s : v; - v = bin_cnt / s * v; + v = bin_cnt * inv_s * v; out_c[i] = round(v); } } diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 285947567e..5c27ee8748 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -20,10 +20,17 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/platform/hostdevice.h" namespace paddle { namespace operators { +template +inline HOSTDEVICE T inverse(T s) { + T eps = 1e-6; + return s <= 1e-30 ? 1.0 / (s + eps) : 1.0 / s; +} + template struct FindAbsMaxFunctor { void operator()(const DeviceContext& ctx, const T* in, const int num, T* out); diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py index 7cd27e2c89..6943f3d0ff 100644 --- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py +++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py @@ -36,6 +36,40 @@ class TestFakeQuantizeOp(OpTest): self.check_output() +class TestFakeQuantizeOp1(OpTest): + def setUp(self): + self.op_type = "fake_quantize_abs_max" + self.attrs = {'bit_length': 8} + self.inputs = {'X': np.zeros((10, 10)).astype("float32"), } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale + self.outputs = { + 'Out': np.round(self.inputs['X'] * inv_scale * ( + (1 << (self.attrs['bit_length'] - 1)) - 1)), + 'OutScale': np.array(scale).astype("float32"), + } + + def test_check_output(self): + self.check_output() + + +class TestFakeQuantizeOp2(OpTest): + def setUp(self): + self.op_type = "fake_quantize_abs_max" + self.attrs = {'bit_length': 8} + self.inputs = {'X': np.full((10, 10), 1e-40).astype("float32"), } + scale = np.max(np.abs(self.inputs['X'])).astype("float32") + inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale + self.outputs = { + 'Out': np.round(self.inputs['X'] * inv_scale * ( + (1 << (self.attrs['bit_length'] - 1)) - 1)), + 'OutScale': np.array(scale).astype("float32"), + } + + def test_check_output(self): + self.check_output() + + class TestFakeChannelWiseQuantizeOp(OpTest): def setUp(self): self.op_type = "fake_channel_wise_quantize_abs_max" -- GitLab