From 8fabb1c32fec435f08a18151422ac9566f563cde Mon Sep 17 00:00:00 2001
From: cc <52520497+juncaipeng@users.noreply.github.com>
Date: Mon, 12 Oct 2020 10:10:25 +0800
Subject: [PATCH] Add test attribute in channelwise_quant op, test=develop (#27742)

* Add test attribute in channelwise_quant op, test=develop
---
 paddle/fluid/operators/fake_quantize_op.cc  |  4 ++++
 paddle/fluid/operators/fake_quantize_op.h   |  9 ++++++---
 .../slim/quantization/quantization_pass.py  | 22 ++++++++++++++++++------
 3 files changed, 26 insertions(+), 9 deletions(-)

diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index e9b4c7dacf..04fa8db9a5 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -404,6 +404,10 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
                               "the received is %d",
                               bit_length));
         });
+    AddAttr<bool>("is_test",
+                  "(bool, default false) Set to true for inference only, false "
+                  "for training. Some layers may run faster when this is true.")
+        .SetDefault(false);
     AddComment(R"DOC(
 The scale of FakeChannelWiseQuantize operator is a vector.
 In detail, each channel of the input X has a scale value.
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index 2f5afbe0ee..94a75f930b 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -146,16 +146,19 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 
     auto* out = context.Output<framework::LoDTensor>("Out");
     auto* out_scale = context.Output<framework::LoDTensor>("OutScale");
-    T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
     out->mutable_data<T>(context.GetPlace());
 
     int bit_length = context.Attr<int>("bit_length");
     int bin_cnt = std::pow(2, bit_length - 1) - 1;
     int quant_axis = context.Attr<int>("quant_axis");
+    bool is_test = context.Attr<bool>("is_test");
 
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
-                                                 out_scale_data);
+    if (!is_test) {
+      T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
+      FindChannelAbsMaxFunctor<DeviceContext, T>()(dev_ctx, *in, quant_axis,
+                                                   out_scale_data);
+    }
     ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
         dev_ctx, *in, *out_scale, bin_cnt, quant_axis, out);
   }
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index b5a8d90194..eba881a263 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -758,6 +758,7 @@ class QuantizationTransformPass(object):
             attrs={
                 'bit_length': quant_bits,
                 'quant_axis': quant_axis,
+                'is_test': self._is_test,
                 'op_role': core.op_proto_and_checker_maker.OpRole.Forward
             },
             inputs={'X': var_node},
@@ -1125,7 +1126,7 @@ class QuantizationFreezePass(object):
                     self._restore_var(input_arg_name, quantized_param_v)
                     self._remove_fake_quant_and_dequant_op(graph, op_node)
 
-# Remove all fake dequant op
+        # Remove all fake dequant op
         ops = graph.all_op_nodes()
         for op_node in ops:
             op_name = op_node.name()
@@ -1331,16 +1332,25 @@ class QuantizationFreezePass(object):
 
     def _quant(self, x, scale, num_bits, quant_axis):
         assert quant_axis in [0, 1], 'quant_axis should be 0 or 1 for now.'
+        bnt = (1 << (num_bits - 1)) - 1
+
+        def _clip(x, scale):
+            x[x > scale] = scale
+            x[x < -scale] = -scale
+            return x
+
         if isinstance(scale, list):
             for i, s in enumerate(scale):
                 if quant_axis == 0:
-                    x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
+                    x[i] = _clip(x[i], s)
+                    x[i] = np.round(x[i] / s * bnt)
                 else:
-                    x[:, i] = np.round(x[:, i] / s * (
-                        (1 << (num_bits - 1)) - 1))
-            return x
+                    x[:, i] = _clip(x[:, i], s)
+                    x[:, i] = np.round(x[:, i] / s * bnt)
         else:
-            return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
+            x = _clip(x, scale)
+            x = np.round(x / scale * bnt)
+        return x
 
 
 class ConvertToInt8Pass(object):
-- 
GitLab
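
Note: the kernel change above can be summarized with a minimal NumPy sketch of the math it performs. This is not Paddle API; the function and variable names below are hypothetical. With is_test=false the per-channel scale is recomputed from the input's absolute maximum (FindChannelAbsMaxFunctor); with is_test=true the scale recorded during training is reused and only the clip-and-fake-quantize step (ChannelClipAndFakeQuantFunctor) runs.

    import numpy as np

    def fake_channel_wise_quant_abs_max(x, stored_scale=None, bit_length=8,
                                        quant_axis=0, is_test=False):
        # Hypothetical NumPy model of the modified kernel, for illustration.
        bnt = (1 << (bit_length - 1)) - 1  # "bin_cnt" in the kernel
        if not is_test:
            # Training path: recompute the per-channel abs-max scale.
            reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
            scale = np.abs(x).max(axis=reduce_axes)
        else:
            # Inference path enabled by the new is_test attribute:
            # reuse the scale found during training.
            scale = stored_scale
        shape = [1] * x.ndim
        shape[quant_axis] = -1
        s = scale.reshape(shape)
        # Clip to the per-channel scale, then fake-quantize.
        out = np.round(np.clip(x, -s, s) / s * bnt)
        return out, scale

    w = np.random.randn(4, 3).astype("float32")
    q_train, found_scale = fake_channel_wise_quant_abs_max(w, is_test=False)
    q_infer, _ = fake_channel_wise_quant_abs_max(w, stored_scale=found_scale,
                                                 is_test=True)
    assert np.array_equal(q_train, q_infer)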
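The _quant change in quantization_pass.py follows the same logic: it clips before rounding because a weight whose magnitude exceeds the recorded scale would otherwise round to a value outside the representable signed range. A quick illustration with made-up values:

    import numpy as np

    bnt = (1 << 7) - 1  # 8-bit signed range: 127
    w = np.array([0.5, -1.2, 0.9])
    scale = 1.0  # recorded abs-max scale, smaller than |-1.2|

    print(np.round(w / scale * bnt))                          # [  64. -152.  114.], -152 overflows int8
    print(np.round(np.clip(w, -scale, scale) / scale * bnt))  # [  64. -127.  114.], stays in range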