未验证 提交 3fc24e09 编写于 作者: W whs 提交者: GitHub

Fix inverse in fake quant (#36763)

上级 417b22d2
......@@ -216,14 +216,14 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
int tid = threadIdx.x;
T s = scale[0];
T inv_s = inverse(s);
T bin_cnt_t = static_cast<T>(bin_cnt);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
x = x > s ? s : x;
x = x < -s ? -s : x;
x = (bin_cnt_t / s) * x;
x = bin_cnt_t * inv_s * x;
x = static_cast<T>(round(static_cast<float>(x)));
out[i] = (x * s) / bin_cnt_t;
}
......
......@@ -28,8 +28,9 @@ namespace operators {
// Returns the reciprocal of `s`, guarding against division by a vanishing
// scale: when `s` is at or below 1e-30 (after conversion to T), a small
// epsilon is added to the denominator before dividing.
//
// Every constant is cast to T explicitly so that the comparison, the addition
// and the division are all carried out in T instead of relying on implicit
// conversions between T and double literals.
template <typename T>
inline HOSTDEVICE T inverse(T s) {
  T eps = static_cast<T>(1e-6);
  T one = static_cast<T>(1.0);
  return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
}
template <typename DeviceContext, typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册