From 542ba21432aae51d63bed27b9feee43da86613ca Mon Sep 17 00:00:00 2001 From: whs Date: Wed, 27 Oct 2021 10:23:59 +0800 Subject: [PATCH] Fix inverse in fake quant (#36762) --- paddle/fluid/operators/fake_quantize_op.cu | 4 ++-- paddle/fluid/operators/fake_quantize_op.h | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/fake_quantize_op.cu b/paddle/fluid/operators/fake_quantize_op.cu index 583ff157a0d..8f2235c7e3d 100644 --- a/paddle/fluid/operators/fake_quantize_op.cu +++ b/paddle/fluid/operators/fake_quantize_op.cu @@ -216,14 +216,14 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale, int tid = threadIdx.x; T s = scale[0]; + T inv_s = inverse(s); T bin_cnt_t = static_cast(bin_cnt); for (int i = bid; i < n; i += blockDim.x * gridDim.x) { T x = in[i]; x = x > s ? s : x; x = x < -s ? -s : x; - x = (bin_cnt_t / s) * x; - + x = bin_cnt_t * inv_s * x; x = static_cast(round(static_cast(x))); out[i] = (x * s) / bin_cnt_t; } diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 11a2d2de8bc..21e7079ff62 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -28,8 +28,9 @@ namespace operators { template inline HOSTDEVICE T inverse(T s) { - T eps = 1e-6; - return s <= 1e-30 ? 1.0 / (s + eps) : 1.0 / s; + T eps = static_cast(1e-6); + T one = static_cast(1.0); + return s <= static_cast(1e-30) ? one / (s + eps) : one / s; } template -- GitLab