Unverified · Commit 915b892a, authored by Liufang Sang and committed by GitHub

Fix div zero in fake quantize op (#22966)

* fix div zero test=develop

* fix div zero test=develop

* add hostdevice function test=develop

* add eps when is zero test=develop
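The change replaces every `bin_cnt / s` with `bin_cnt * inverse(s)`, where `inverse` adds a small epsilon before taking the reciprocal whenever the scale is effectively zero (as happens when the input tensor is all zeros, so its abs-max scale is 0). Below is a minimal standalone sketch of that guarded reciprocal; the `main` driver and printed values are illustrative only, not part of the commit.

```cpp
#include <cstdio>

// Guarded reciprocal mirroring the inverse() helper added in the header
// hunk below: for scales at or below 1e-30, add eps = 1e-6 first so that
// 1.0 / s stays finite instead of overflowing to inf.
template <typename T>
inline T inverse(T s) {
  T eps = 1e-6;
  return s <= 1e-30 ? 1.0 / (s + eps) : 1.0 / s;
}

int main() {
  const int bin_cnt = 127;    // (1 << (8 - 1)) - 1 for bit_length = 8
  float zero_scale = 0.0f;    // abs-max scale of an all-zero tensor
  float tiny_scale = 1e-40f;  // below the 1e-30 threshold
  float normal_scale = 0.5f;

  // The old form bin_cnt / zero_scale evaluates to inf; the guarded
  // form stays finite in all three cases.
  std::printf("zero:   %g\n", bin_cnt * inverse(zero_scale));    // ~1.27e8
  std::printf("tiny:   %g\n", bin_cnt * inverse(tiny_scale));    // ~1.27e8
  std::printf("normal: %g\n", bin_cnt * inverse(normal_scale));  // 254
  return 0;
}
```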
Parent fb7b008a
......
@@ -58,11 +58,12 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& in, const framework::Tensor& scale,
                   const int bin_cnt, framework::Tensor* out) {
     T s = scale.data<T>()[0];
+    T inv_s = inverse(s);
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
           out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
+    out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
   }
 };
......
@@ -74,12 +75,14 @@ struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
                   const framework::Tensor& in, const framework::Tensor& scale,
                   const int bin_cnt, framework::Tensor* out) {
     T s = scale.data<T>()[0];
+    T inv_s = inverse(s);
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
           out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
     auto out_e = framework::EigenVector<T>::Flatten(*out);
     out_e.device(*ctx.eigen_device()) =
-        (s / bin_cnt) * (bin_cnt / s * out_e).round();
+        (s / bin_cnt) * (bin_cnt * inv_s * out_e).round();
   }
 };
 template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
......
@@ -105,9 +108,10 @@ struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
     }
     for (int i = 0; i < channel; i++) {
       T s = scale_data[i];
+      T inv_s = inverse(s);
       framework::Tensor one_channel_out = out->Slice(i, i + 1);
       auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
-      out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
+      out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
     }
   }
 };
......
......
@@ -120,11 +120,12 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
   int tid = threadIdx.x;
 
   T s = scale[0];
+  T inv_s = inverse(s);
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt / s * v;
+    v = bin_cnt * inv_s * v;
     out[i] = round(v);
   }
 }
......
@@ -139,9 +140,10 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
   T s = scale[0];
   for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
     T x = in[i];
+    T inv_s = inverse(s);
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt / s * v;
+    v = bin_cnt * inv_s * v;
     out[i] = round(v) * s / bin_cnt;
   }
 }
......
@@ -198,11 +200,13 @@ __global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
   T* out_c = out + blockIdx.x * channel_size;
 
   T s = scale[blockIdx.x];
+  T inv_s = inverse(s);
   for (int i = tid; i < channel_size; i += blockDim.x) {
     T x = in_c[i];
     T v = x > s ? s : x;
     v = v < -s ? -s : v;
-    v = bin_cnt / s * v;
+    v = bin_cnt * inv_s * v;
     out_c[i] = round(v);
   }
 }
......
......
@@ -20,10 +20,17 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/platform/hostdevice.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename T>
+inline HOSTDEVICE T inverse(T s) {
+  T eps = 1e-6;
+  return s <= 1e-30 ? 1.0 / (s + eps) : 1.0 / s;
+}
+
 template <typename DeviceContext, typename T>
 struct FindAbsMaxFunctor {
   void operator()(const DeviceContext& ctx, const T* in, const int num, T* out);
......
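Per the "add hostdevice function" commit message above, `inverse` is marked `HOSTDEVICE` so that a single definition in the header serves both the CPU functors and the CUDA kernels. The macro comes from `paddle/fluid/platform/hostdevice.h`, included above; the sketch below shows the conventional pattern behind such a macro, with the exact spelling treated as an assumption rather than a quote from that header.

```cpp
// Conventional host/device annotation macro (assumed form, not copied from
// the Paddle header): under nvcc, mark a function callable from both host
// and device code; under a plain C++ compiler, expand to nothing.
#ifdef __CUDACC__
#define HOSTDEVICE __host__ __device__
#else
#define HOSTDEVICE
#endif
```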
......
@@ -36,6 +36,40 @@ class TestFakeQuantizeOp(OpTest):
         self.check_output()
 
 
+class TestFakeQuantizeOp1(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.zeros((10, 10)).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
+        self.outputs = {
+            'Out': np.round(self.inputs['X'] * inv_scale * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestFakeQuantizeOp2(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_abs_max"
+        self.attrs = {'bit_length': 8}
+        self.inputs = {'X': np.full((10, 10), 1e-40).astype("float32"), }
+        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
+        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
+        self.outputs = {
+            'Out': np.round(self.inputs['X'] * inv_scale * (
+                (1 << (self.attrs['bit_length'] - 1)) - 1)),
+            'OutScale': np.array(scale).astype("float32"),
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 class TestFakeChannelWiseQuantizeOp(OpTest):
     def setUp(self):
         self.op_type = "fake_channel_wise_quantize_abs_max"
......
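Both new tests exercise the guard: in `TestFakeQuantizeOp1` the all-zero input yields scale = 0, and in `TestFakeQuantizeOp2` the constant 1e-40 input yields a scale below the 1e-30 threshold, so both take the eps branch with inv_scale ≈ 1e6. The expected `Out` is all zeros in both cases (e.g. round(1e-40 * 1e6 * 127) = 0), whereas without the guard the reference computation would divide by zero.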