Unverified commit 8e439ccf, authored by qingqing01, committed by GitHub

Fix bug in fake_quantize_op and add more unit testing (#15912)

Parent: f4846bf3
@@ -31,7 +31,7 @@ template <typename T>
 struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
   void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                   const int num, T* out) {
-    *out = *(std::max_element(in + 0, in + num, Compare<T>()));
+    *out = std::abs(*(std::max_element(in + 0, in + num, Compare<T>())));
   }
 };
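Note on the first hunk: `Compare<T>` here presumably orders elements by absolute value, so `std::max_element` returns the element of largest magnitude, which can itself be negative. Without the added `std::abs`, FindAbsMaxFunctor could report a negative "maximum". A minimal numpy sketch of the intended behavior (array values are illustrative):

```python
import numpy as np

x = np.array([-9.0, 2.0, 5.0], dtype=np.float32)

# Old behavior: the element with the largest |value| is returned as-is,
# so the reported scale can come out negative.
buggy_scale = x[np.argmax(np.abs(x))]          # -9.0

# Fixed behavior: take the absolute value of that element.
fixed_scale = np.abs(x[np.argmax(np.abs(x))])  # 9.0

assert fixed_scale == np.max(np.abs(x))
```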
@@ -46,10 +46,8 @@ struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
     platform::Transform<platform::CPUDeviceContext> trans;
     trans(ctx, in.data<T>(), in.data<T>() + in.numel(),
           out->mutable_data<T>(ctx.GetPlace()), ClipFunctor<T>(-s, s));
-    auto in_e = framework::EigenVector<T>::Flatten(in);
     auto out_e = framework::EigenVector<T>::Flatten(*out);
-
-    out_e.device(*ctx.eigen_device()) = (bin_cnt / s * in_e).round();
+    out_e.device(*ctx.eigen_device()) = (bin_cnt / s * out_e).round();
   }
 };
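Note on the second hunk: the `Transform` call has already written the clipped values into `out`, so the rounding step must scale `out_e`, not the unclipped input `in_e`; otherwise values outside `[-s, s]` are quantized into codes beyond the representable range. A numpy sketch of the before/after behavior, assuming `bin_cnt = 2^(bit_length-1) - 1 = 127` for 8 bits:

```python
import numpy as np

s, bin_cnt = 4.0, 127                 # scale and 2**(bit_length - 1) - 1
x = np.array([-5.0, 1.0, 6.0], dtype=np.float32)

clipped = np.clip(x, -s, s)           # what the Transform/ClipFunctor step produces

buggy = np.round(bin_cnt / s * x)        # rounds the *unclipped* input
fixed = np.round(bin_cnt / s * clipped)  # rounds the clipped copy

print(buggy)  # [-159.   32.  190.] -- escapes the int8 code range [-127, 127]
print(fixed)  # [-127.   32.  127.]
```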
@@ -35,7 +35,7 @@ class TestFakeQuantizeOp(OpTest):
         self.check_output()
 
 
-class TestFakeQuantizeOp(OpTest):
+class TestFakeQuantizeRangeAbsMaxOp(OpTest):
     def setUp(self):
         self.op_type = "fake_quantize_range_abs_max"
         self.attrs = {
@@ -43,8 +43,10 @@ class TestFakeQuantizeOp(OpTest):
             'window_size': int(1),
             'is_test': False
         }
+        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
+        x = x.astype("float32")
         self.inputs = {
-            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
+            'X': x,
             'Iter': np.zeros(1).astype("int64"),
             'InScale': np.zeros(1).astype("float32")
         }
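The input here is shifted from [0, 1) to [-5, 5) on purpose: with strictly non-negative inputs the missing `std::abs` in FindAbsMaxFunctor would never change the result, so only data containing negative values of large magnitude can catch the bug this commit fixes.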
@@ -62,5 +64,36 @@ class TestFakeQuantizeOp(OpTest):
         self.check_output()
 
 
+class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
+    def setUp(self):
+        self.op_type = "fake_quantize_range_abs_max"
+        self.attrs = {
+            'bit_length': int(8),
+            'window_size': int(1),
+            'is_test': True
+        }
+        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
+        x = x.astype("float32")
+        scale = np.max(np.abs(x)).astype("float32") - 1.0
+        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
+        out_scales[0] = scale
+        self.inputs = {
+            'X': x,
+            'Iter': np.zeros(1).astype("int64"),
+            'InScale': scale.astype("float32")
+        }
+        xs = np.clip(x, -scale, scale)
+        qs = np.round(xs / scale * ((1 << (self.attrs['bit_length'] - 1)) - 1))
+        self.outputs = {
+            'Out': qs,
+            'OutScale': scale.astype("float32"),
+            'OutScales': out_scales,
+        }
+
+    def test_check_output(self):
+        self.check_output(no_check_set=set(['OutScale', 'OutScales']))
+
+
 if __name__ == "__main__":
     unittest.main()
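The new `TestFakeQuantizeRangeAbsMaxOp2` exercises the `is_test=True` path: `InScale` is deliberately set below `max(|x|)` so that clipping actually occurs, and `OutScale`/`OutScales` are excluded via `no_check_set` since scale bookkeeping is not what this inference-mode test validates. For reference, a sketch of the quantize/dequantize round-trip implied by this scheme (the helper name is illustrative, not a Paddle API):

```python
import numpy as np

def fake_quant(x, scale, bit_length=8):
    # Quantize: clip to [-scale, scale], then map onto the integer grid.
    bin_cnt = (1 << (bit_length - 1)) - 1   # 127 for 8 bits
    q = np.round(np.clip(x, -scale, scale) / scale * bin_cnt)
    # Dequantize: map the integer code back to the float range.
    x_hat = q * scale / bin_cnt
    return q, x_hat

x = np.array([-4.9, 0.3, 3.7], dtype=np.float32)
q, x_hat = fake_quant(x, scale=4.0)
# For in-range values the round-trip error is at most scale / bin_cnt / 2;
# values outside [-scale, scale] additionally incur the clipping error.
print(q, x_hat)
```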