未验证 提交 542ba214 编写于 作者: W whs 提交者: GitHub

Fix inverse in fake quant (#36762)

上级 63f3ae07
@@ -216,14 +216,14 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int tid = threadIdx.x;
   T s = scale[0];
+  T inv_s = inverse(s);
   T bin_cnt_t = static_cast<T>(bin_cnt);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
     x = x > s ? s : x;
     x = x < -s ? -s : x;
-    x = (bin_cnt_t / s) * x;
+    x = bin_cnt_t * inv_s * x;
     x = static_cast<T>(round(static_cast<float>(x)));
     out[i] = (x * s) / bin_cnt_t;
   }
......
@@ -28,8 +28,9 @@ namespace operators {
 template <typename T>
 inline HOSTDEVICE T inverse(T s) {
-  T eps = 1e-6;
-  return s <= 1e-30 ? 1.0 / (s + eps) : 1.0 / s;
+  T eps = static_cast<T>(1e-6);
+  T one = static_cast<T>(1.0);
+  return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
 }
 template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册