未验证 提交 0ae8a2d6 编写于 作者: Leo Chen 提交者: GitHub

Fix the underflow of fp16 fake quantize operators (#43088)

Co-authored-by: Ryan Jeng <rjeng@nvidia.com>
上级 4700a08e
......@@ -217,16 +217,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
T s = scale[0];
T inv_s = inverse(s);
T bin_cnt_t = static_cast<T>(bin_cnt);
using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
T v = x > s ? s : x;
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(
round(static_cast<typename QuantizeDataType<T>::type>(v)));
out[i] = static_cast<T>(round(v));
}
}
......@@ -237,18 +239,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
int bid = threadIdx.x + blockIdx.x * blockDim.x;
int tid = threadIdx.x;
T s = scale[0];
T inv_s = inverse(s);
T bin_cnt_t = static_cast<T>(bin_cnt);
using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
T x = in[i];
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
x = x > s ? s : x;
x = x < -s ? -s : x;
x = bin_cnt_t * inv_s * x;
x = static_cast<T>(
round(static_cast<typename QuantizeDataType<T>::type>(x)));
out[i] = (x * s) / bin_cnt_t;
x = round(x);
out[i] = static_cast<T>((x * s) / bin_cnt_t);
}
}
......@@ -302,17 +305,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const T* in_c = in + blockIdx.x * channel_size;
T* out_c = out + blockIdx.x * channel_size;
T s = scale[blockIdx.x];
T inv_s = inverse(s);
T bin_cnt_t = static_cast<T>(bin_cnt);
using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
ComputeDataType inv_s = inverse(s);
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
for (int64_t i = tid; i < channel_size; i += blockDim.x) {
T x = in_c[i];
T v = x > s ? s : x;
ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out_c[i] = static_cast<T>(
round(static_cast<typename QuantizeDataType<T>::type>(v)));
out_c[i] = static_cast<T>(round(v));
}
}
......@@ -322,16 +326,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
const T* in, const T* scale, const int bin_cnt, const int64_t n,
const int nScale, const int quant_stride, T* out) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
T bin_cnt_t = static_cast<T>(bin_cnt);
using ComputeDataType = typename QuantizeDataType<T>::type;
ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
for (int64_t i = idx; i < n; i += blockDim.x * gridDim.x) {
T s = scale[(i / quant_stride) % nScale];
T inv_s = inverse(s);
T x = in[i];
T v = x > s ? s : x;
ComputeDataType s =
static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
ComputeDataType inv_s = inverse(s);
ComputeDataType x = static_cast<ComputeDataType>(in[i]);
ComputeDataType v = x > s ? s : x;
v = v < -s ? -s : v;
v = bin_cnt_t * inv_s * v;
out[i] = static_cast<T>(
round(static_cast<typename QuantizeDataType<T>::type>(v)));
out[i] = static_cast<T>(round(v));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册