Unverified commit 0ae8a2d6, authored by Leo Chen, committed by GitHub

Fix the underflow of fp16 fake quantize operators (#43088)

Co-authored-by: Ryan Jeng <rjeng@nvidia.com>
Parent 4700a08e
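
Every kernel in the diff below routes its intermediate arithmetic through QuantizeDataType<T>::type, aliased locally as ComputeDataType, and casts back to the storage type T only when writing the output. As orientation, here is a minimal sketch of what such a trait presumably looks like; it is not taken from the patch, the real definition lives elsewhere in the operator sources, and CUDA's __half stands in for Paddle's fp16 storage type.

#include <cuda_fp16.h>

// Hypothetical trait sketch: wider types keep computing in their own
// precision, while half-precision storage is promoted to float so that
// intermediates such as 1/scale and bin_cnt/scale are not squeezed into
// fp16's narrow range.
template <typename T>
struct QuantizeDataType {
  using type = T;
};

template <>
struct QuantizeDataType<__half> {
  using type = float;
};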
@@ -217,16 +217,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
 
-  T s = scale[0];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
+
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out[i] = static_cast<T>(round(v));
   }
 }
 
@@ -237,18 +239,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   int bid = threadIdx.x + blockIdx.x * blockDim.x;
   int tid = threadIdx.x;
 
-  T s = scale[0];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+
+  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
 
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
-    T x = in[i];
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
     x = x > s ? s : x;
     x = x < -s ? -s : x;
     x = bin_cnt_t * inv_s * x;
-    x = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(x)));
-    out[i] = (x * s) / bin_cnt_t;
+    x = round(x);
+    out[i] = static_cast<T>((x * s) / bin_cnt_t);
   }
 }
 
@@ -302,17 +305,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
   const T* in_c = in + blockIdx.x * channel_size;
   T* out_c = out + blockIdx.x * channel_size;
 
-  T s = scale[blockIdx.x];
-  T inv_s = inverse(s);
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+
+  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
+  ComputeDataType inv_s = inverse(s);
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
 
   for (int64_t i = tid; i < channel_size; i += blockDim.x) {
-    T x = in_c[i];
-    T v = x > s ? s : x;
+    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out_c[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out_c[i] = static_cast<T>(round(v));
   }
 }
 
@@ -322,16 +326,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
     const T* in, const T* scale, const int bin_cnt, const int64_t n,
     const int nScale, const int quant_stride, T* out) {
   int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
-  T bin_cnt_t = static_cast<T>(bin_cnt);
+  using ComputeDataType = typename QuantizeDataType<T>::type;
+  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
   for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
-    T s = scale[(i / quant_stride) % nScale];
-    T inv_s = inverse(s);
-    T x = in[i];
-    T v = x > s ? s : x;
+    ComputeDataType s =
+        static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
+    ComputeDataType inv_s = inverse(s);
+    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
+    ComputeDataType v = x > s ? s : x;
     v = v < -s ? -s : v;
     v = bin_cnt_t * inv_s * v;
-    out[i] = static_cast<T>(
-        round(static_cast<typename QuantizeDataType<T>::type>(v)));
+    out[i] = static_cast<T>(round(v));
   }
 }
 
...
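
For context on why the promotion matters: in fp16 an intermediate such as 1/scale can underflow into the subnormal range and lose most of its significant bits, and the product bin_cnt * (1/scale) can overflow past the fp16 maximum of 65504, so a quantized level computed entirely in fp16 can drift from the one computed in float. The standalone CUDA sketch below is not part of this commit; the kernel, file name, and input values are invented for illustration, and it assumes a GPU with compute capability 5.3 or newer for the half-precision intrinsics.

// underflow_sketch.cu -- illustrative only; contrasts the pre-patch all-fp16
// clip-and-quantize arithmetic with the patched float arithmetic for one value.
#include <cstdio>
#include <cuda_fp16.h>

__global__ void QuantHalfVsFloat(float in, float scale, int bin_cnt,
                                 float* out_fp16_path, float* out_fp32_path) {
  // Pre-patch style: every intermediate stays in fp16.
  __half s_h = __float2half(scale);
  __half inv_s_h = __hdiv(__float2half(1.f), s_h);
  __half bin_h = __float2half(static_cast<float>(bin_cnt));
  __half x_h = __float2half(in);
  __half v_h = __hgt(x_h, s_h) ? s_h : x_h;
  v_h = __hlt(v_h, __hneg(s_h)) ? __hneg(s_h) : v_h;
  *out_fp16_path = roundf(__half2float(__hmul(__hmul(bin_h, inv_s_h), v_h)));

  // Patched style: clip, scale, and round in float, cast to storage type last.
  float inv_s = 1.f / scale;
  float v = in > scale ? scale : in;
  v = v < -scale ? -scale : v;
  *out_fp32_path = roundf(static_cast<float>(bin_cnt) * inv_s * v);
}

int main() {
  float *d_fp16_path, *d_fp32_path;
  cudaMalloc(&d_fp16_path, sizeof(float));
  cudaMalloc(&d_fp32_path, sizeof(float));

  // Made-up inputs: with a scale of 1e-4, bin_cnt * (1/scale) exceeds 65504,
  // so the fp16-only path saturates to infinity while the float path yields a
  // finite level (about 51 for an input of 4e-5 and a bin count of 127).
  QuantHalfVsFloat<<<1, 1>>>(4e-5f, 1e-4f, 127, d_fp16_path, d_fp32_path);

  float q16 = 0.f, q32 = 0.f;
  cudaMemcpy(&q16, d_fp16_path, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&q32, d_fp32_path, sizeof(float), cudaMemcpyDeviceToHost);
  printf("fp16 arithmetic: %f, float arithmetic: %f\n", q16, q32);

  cudaFree(d_fp16_path);
  cudaFree(d_fp32_path);
  return 0;
}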