From ec88b6cc5aa403ebeca600adbe5f2be99e40f064 Mon Sep 17 00:00:00 2001
From: Zhen Wang
Date: Tue, 19 Mar 2019 22:42:26 +0800
Subject: [PATCH] add channel wise quantization in ir pass.

---
 paddle/fluid/operators/fake_dequantize_op.h   |  46 +++-
 paddle/fluid/operators/fake_quantize_op.cc    |   8 +-
 paddle/fluid/operators/fake_quantize_op.h     |   8 +-
 .../slim/quantization/quantization_pass.py    | 221 +++++++++++++++---
 .../slim/tests/test_quantization_pass.py      |  25 ++-
 .../unittests/test_fake_dequantize_op.py      |  45 +++-
 .../tests/unittests/test_fake_quantize_op.py  |   2 +-
 7 files changed, 285 insertions(+), 70 deletions(-)

diff --git a/paddle/fluid/operators/fake_dequantize_op.h b/paddle/fluid/operators/fake_dequantize_op.h
index d05f20385..1a504bf03 100644
--- a/paddle/fluid/operators/fake_dequantize_op.h
+++ b/paddle/fluid/operators/fake_dequantize_op.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 
 #include <vector>
+#include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 
@@ -54,10 +55,6 @@ class FakeChannelWiseDequantizeMaxAbsKernel : public framework::OpKernel<T> {
     auto scales = ctx.MultiInput<framework::Tensor>("Scales");
     auto* out = ctx.Output<framework::Tensor>("Out");
 
-    PADDLE_ENFORCE_EQ(scales[0]->numel(), in->dims()[0],
-                      "The number of first scale values must be the same with "
-                      "first dimension value of Input(X).");
-
     auto quant_bits = ctx.Attr<std::vector<int>>("quant_bits");
     int max_range = std::pow(2, quant_bits[0] - 1) - 1;
 
@@ -65,15 +62,38 @@
     out->mutable_data<T>(dev_ctx.GetPlace());
 
     auto dequant = DequantizeFunctor<DeviceContext, T>();
-    for (int64_t i = 0; i < in->dims()[0]; i++) {
-      framework::Tensor one_channel_in = in->Slice(i, i + 1);
-      framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
-      dequant(dev_ctx, &one_channel_in, &one_channel_scale,
-              static_cast<T>(max_range), &one_channel_out);
-    }
-
-    if (scales.size() == 2) {
+    if (scales.size() == 1) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[0],
+          "The size of the first scale tensor must match the first "
+          "dimension of Input(X) when `Scales` has only one element.");
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_channel_in = in->Slice(i, i + 1);
+        framework::Tensor one_channel_out = out->Slice(i, i + 1);
+        framework::Tensor one_channel_scale = scales[0]->Slice(i, i + 1);
+        dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                static_cast<T>(max_range), &one_channel_out);
+      }
+    } else if (scales.size() == 2) {
+      PADDLE_ENFORCE_EQ(
+          scales[0]->numel(), in->dims()[1],
+          "The size of the first scale tensor must match the second "
+          "dimension of Input(X) when `Scales` has two elements.");
+      for (int64_t i = 0; i < in->dims()[0]; i++) {
+        framework::Tensor one_batch_in = in->Slice(i, i + 1).Resize(
+            framework::slice_ddim(in->dims(), 1, in->dims().size()));
+        framework::Tensor one_batch_out = out->Slice(i, i + 1).Resize(
+            framework::slice_ddim(out->dims(), 1, out->dims().size()));
+        for (int64_t j = 0; j < in->dims()[1]; j++) {
+          framework::Tensor one_channel_in = one_batch_in.Slice(j, j + 1);
+          framework::Tensor one_channel_out = one_batch_out.Slice(j, j + 1);
+          framework::Tensor one_channel_scale = scales[0]->Slice(j, j + 1);
+          dequant(dev_ctx, &one_channel_in, &one_channel_scale,
+                  static_cast<T>(max_range), &one_channel_out);
+        }
+      }
       PADDLE_ENFORCE_EQ(
           scales[1]->numel(), 1,
           "The second scale tensor should only have one value at now.");
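In NumPy terms, the rewritten kernel computes the following. This is a minimal sketch of the intended semantics (the helper name is illustrative, not part of the patch): with one scale tensor, the input is treated as a conv-style weight and dequantized along its first axis; with two, it is treated as an [N, C, H, W] activation batch, dequantized per channel along the second axis and then by the single activation scale.

    import numpy as np

    def channel_wise_dequantize(x, scales, quant_bits):
        max_range = (1 << (quant_bits[0] - 1)) - 1
        y = x.astype("float32")
        if len(scales) == 1:
            # One scale per slice of the first axis, e.g. [C_out, C_in, H, W].
            for i in range(x.shape[0]):
                y[i] = x[i] * scales[0][i] / max_range
        else:
            # Per-channel scales over the second axis, then one activation scale.
            for j in range(x.shape[1]):
                y[:, j] = x[:, j] * scales[0][j] / max_range
            y *= scales[1][0] / ((1 << (quant_bits[1] - 1)) - 1)
        return y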
diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc
index d51d51b49..f9f28f60c 100644
--- a/paddle/fluid/operators/fake_quantize_op.cc
+++ b/paddle/fluid/operators/fake_quantize_op.cc
@@ -169,10 +169,10 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
         ctx->HasOutput("Out"),
         "Output(Out) of FakeChannelWiseQuantizeOp should not be null.");
     PADDLE_ENFORCE(
-        ctx->HasOutput("OutScales"),
-        "Output(Scales) of FakeChannelWiseQuantizeOp should not be null.");
+        ctx->HasOutput("OutScale"),
+        "Output(Scale) of FakeChannelWiseQuantizeOp should not be null.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
-    ctx->SetOutputDim("OutScales", {ctx->GetInputDim("X")[0]});
+    ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[0]});
     ctx->ShareLoD("X", /*->*/ "Out");
   }
 
@@ -192,7 +192,7 @@ class FakeChannelWiseQuantizeAbsMaxOpMaker
     AddOutput("Out",
               "(Tensor) Output of quantized low level tensor, "
               "but also saved as float data type.");
-    AddOutput("OutScales", "(Tensor) Current channel wise scale");
+    AddOutput("OutScale", "(Tensor) Current channel wise scale");
     AddAttr<int>("bit_length", "(int, default 8)")
         .SetDefault(8)
         .AddCustomChecker([](const int& bit_length) {
diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h
index ec667e89e..2616cd996 100644
--- a/paddle/fluid/operators/fake_quantize_op.h
+++ b/paddle/fluid/operators/fake_quantize_op.h
@@ -78,8 +78,8 @@ class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
     auto* in = context.Input<framework::Tensor>("X");
 
     auto* out = context.Output<framework::Tensor>("Out");
-    auto* out_scales = context.Output<framework::Tensor>("OutScales");
-    T* out_scales_data = out_scales->mutable_data<T>(context.GetPlace());
+    auto* out_scale = context.Output<framework::Tensor>("OutScale");
+    T* out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
     out->mutable_data<T>(context.GetPlace());
 
     int bit_length = context.Attr<int>("bit_length");
@@ -91,13 +91,13 @@
       framework::Tensor one_channel = in->Slice(i, i + 1);
       const T* one_channel_data = one_channel.data<T>();
       find_abs_max(dev_ctx, one_channel_data, one_channel.numel(),
-                   &out_scales_data[i]);
+                   &out_scale_data[i]);
     }
     auto clip_quant = ClipAndFakeQuantFunctor<DeviceContext, T>();
     for (int64_t i = 0; i < in->dims()[0]; i++) {
       framework::Tensor one_channel_in = in->Slice(i, i + 1);
       framework::Tensor one_channel_out = out->Slice(i, i + 1);
-      framework::Tensor one_channel_scale = out_scales->Slice(i, i + 1);
+      framework::Tensor one_channel_scale = out_scale->Slice(i, i + 1);
       clip_quant(dev_ctx, one_channel_in, one_channel_scale, bin_cnt,
                  &one_channel_out);
     }
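For reference, the per-channel quantize kernel above amounts to this sketch (illustrative helper name; assumes a float32 input whose first axis is the channel axis):

    import numpy as np

    def channel_wise_quantize(x, bit_length=8):
        # One abs-max scale per first-axis slice, then clip and round into
        # the signed range [-bin_cnt, bin_cnt], where bin_cnt = 2^(b-1) - 1.
        bin_cnt = (1 << (bit_length - 1)) - 1
        scales = np.abs(x).reshape(x.shape[0], -1).max(axis=1)
        y = np.zeros_like(x)
        for i, s in enumerate(scales):
            y[i] = np.round(np.clip(x[i], -s, s) / s * bin_cnt)
        return y, scales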
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 919db4c78..03ffd2795 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -22,6 +22,7 @@ from ....framework import IrGraph
 from ....framework import IrNode
 from ....framework import Program
 from ....initializer import Constant
+from ....initializer import NumpyArrayInitializer
 from .... import unique_name
 
 __all__ = [
@@ -54,14 +55,15 @@ class QuantizationTransformPass(object):
             the bias is not quantized.
         activation_bits (int): quantization bit number for activation.
         activation_quantize_type (str): quantization type for activation,
-            now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode,
-            the quantization scale will be calculated dynamically each step
-            in both training and testing period. If use 'range_abs_max',
-            a static quantization scale will be calculated during training
-            and used in inference.
+            now supports 'abs_max', 'range_abs_max' and 'moving_average_abs_max'.
+            If 'abs_max' is used, the quantization scale will be calculated
+            dynamically at each step in both training and testing. If
+            'range_abs_max' is used, a static quantization scale will be
+            calculated during training and used for inference.
         weight_quantize_type (str): quantization type for weights,
-            support 'abs_max'. The 'range_abs_max' usually is not used for
-            weight, since weights are fixed once the model is well trained.
+            supports 'abs_max' and 'channel_wise_abs_max'. 'range_abs_max'
+            is usually not used for weights, since weights are fixed once
+            the model is well trained.
         window_size (int): the window size for 'range_abs_max' quantization.
 
     Examples:
@@ -84,7 +86,11 @@
         self._weight_bits = weight_bits
         self._activation_bits = activation_bits
 
-        quant_type = ['abs_max', 'range_abs_max', 'moving_average_abs_max']
+        quant_type = [
+            'abs_max', 'channel_wise_abs_max', 'range_abs_max',
+            'moving_average_abs_max'
+        ]
+        assert activation_quantize_type != 'channel_wise_abs_max', "The activation quantization type does not support 'channel_wise_abs_max'."
         if activation_quantize_type not in quant_type:
             raise ValueError(
                 "Unknown activation_quantize_type : '%s'. It can only be ",
@@ -93,7 +99,7 @@
         if weight_quantize_type not in quant_type:
             raise ValueError(
                 "Unknown weight_quantize_type: '%s'. It can only be ",
-                "'abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
+                "'abs_max' or 'channel_wise_abs_max' or 'range_abs_max' or 'moving_average_abs_max'.",
                 str(weight_quantize_type))
 
         self._activation_quantize_type = activation_quantize_type
@@ -103,6 +109,7 @@
         self._need_initialized = collections.OrderedDict()
 
         self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._quantizable_grad_ops = [
             '%s_grad' % (op) for op in self._quantizable_ops
         ]
@@ -135,10 +142,26 @@
                         else self._activation_bits
                     quant_type = self._weight_quantize_type if var_node.name() \
                         in persistable_vars else self._activation_quantize_type
-                    quant_var_node, scale_var_node = self._insert_quant_op(
-                        graph, var_node, quant_bits, quant_type)
-                    dequant_var_node = self._insert_dequant_op(
-                        graph, quant_var_node, scale_var_node, quant_bits)
+                    if quant_type == 'channel_wise_abs_max':
+                        assert var_node.name(
+                        ) in persistable_vars, "'channel_wise_abs_max' can only be applied on weights."
+                        if op.name() in self._conv_ops:
+                            quant_var_node, scale_var_node = self._insert_channel_quant_op(
+                                graph, var_node, quant_bits)
+                            dequant_var_node = self._insert_channel_dequant_op(
+                                graph, quant_var_node, [scale_var_node],
+                                [quant_bits])
+                        else:
+                            quant_var_node, scale_var_node = self._insert_quant_op(
+                                graph, var_node, quant_bits, 'abs_max')
+                            dequant_var_node = self._insert_dequant_op(
+                                graph, quant_var_node, scale_var_node,
+                                quant_bits)
+                    else:
+                        quant_var_node, scale_var_node = self._insert_quant_op(
+                            graph, var_node, quant_bits, quant_type)
+                        dequant_var_node = self._insert_dequant_op(
+                            graph, quant_var_node, scale_var_node, quant_bits)
                     dequantized_vars[var_node.name()] = dequant_var_node
                 graph.update_input_link(var_node, dequant_var_node, op)
@@ -244,7 +267,7 @@
         scale_var_node = graph.create_var_node(
             name=self._quantized_scale_name(var_node.name()),
             var_type=var_node.type(),
-            shape=var_node.shape(),
+            shape=[1],
             var_dtype=var_node.dtype())
         quant_op_node = graph.create_op_node(
             op_type='fake_quantize_abs_max',
@@ -384,6 +407,36 @@
 
         return quant_var_node, scale_out_node
 
+    def _insert_channel_quant_op(self, graph, var_node, quant_bits):
+        """
+        Insert fake_channel_wise_quantize_abs_max op in the graph.
+        """
+        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
+
+        quant_var_node = graph.create_var_node(
+            name=self._quantized_var_name(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
+        scale_var_node = graph.create_var_node(
+            name=self._quantized_scale_name(var_node.name()),
+            var_type=var_node.type(),
+            shape=[var_node.shape()[0]],
+            var_dtype=var_node.dtype())
+        quant_op_node = graph.create_op_node(
+            op_type='fake_channel_wise_quantize_abs_max',
+            attrs={
+                'bit_length': quant_bits,
+                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+            },
+            inputs={'X': var_node},
+            outputs={'Out': quant_var_node,
+                     'OutScale': scale_var_node})
+        graph.link_to(var_node, quant_op_node)
+        graph.link_to(quant_op_node, quant_var_node)
+        graph.link_to(quant_op_node, scale_var_node)
+        return quant_var_node, scale_var_node
+
     def _insert_dequant_op(self, graph, var_node, scale_var_node, quant_bits):
         """
         Insert fake_dequantize_op in the graph.
@@ -410,6 +463,33 @@
         graph.link_to(dequant_op_node, dequant_var_node)
         return dequant_var_node
 
+    def _insert_channel_dequant_op(self, graph, var_node, scale_var_nodes,
+                                   quant_bits):
+        """
+        Insert fake_channel_wise_dequantize_max_abs op in the graph.
+        """
+        assert var_node.is_var(), '{} is not a var'.format(var_node.name())
+
+        dequant_var_node = graph.create_var_node(
+            name=self._dequantized_var_name(var_node.name()),
+            var_type=var_node.type(),
+            shape=var_node.shape(),
+            var_dtype=var_node.dtype())
+        dequant_op_node = graph.create_op_node(
+            op_type='fake_channel_wise_dequantize_max_abs',
+            attrs={
+                'quant_bits': quant_bits,
+                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+            },
+            inputs={'X': var_node,
+                    'Scales': scale_var_nodes},
+            outputs={'Out': dequant_var_node})
+        graph.link_to(var_node, dequant_op_node)
+        for scale_n in scale_var_nodes:
+            graph.link_to(scale_n, dequant_op_node)
+        graph.link_to(dequant_op_node, dequant_var_node)
+        return dequant_var_node
+
     def _quantized_var_name(self, var_name):
         """
         Return quantized variable name for the input `var_name`.
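Note the two scale shapes used by the transform pass: `_insert_quant_op` now creates a single-element scale variable (`shape=[1]`), while `_insert_channel_quant_op` creates one scale per output channel (`shape=[var_node.shape()[0]]`). A sketch with hypothetical weights (not taken from the patch):

    import numpy as np

    w_conv = np.random.randn(32, 16, 3, 3).astype("float32")  # conv2d weight
    w_mul = np.random.randn(128, 10).astype("float32")        # mul (fc) weight

    # 'abs_max': one scalar scale for the whole tensor -> scale shape [1].
    scale_mul = np.array([np.abs(w_mul).max()])

    # 'channel_wise_abs_max': one scale per output channel -> shape [32].
    scale_conv = np.abs(w_conv).reshape(w_conv.shape[0], -1).max(axis=1)
    assert scale_mul.shape == (1,) and scale_conv.shape == (32,)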
@@ -442,7 +522,7 @@ class QuantizationFreezePass(object):
         place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
         weight_bits (int): quantization bit number for weights.
         activation_bits (int): quantization bit number for activation.
-        weight_quantize_type (str): quantization type for weights, support 'abs_max'.
+        weight_quantize_type (str): quantization type for weights, supports 'abs_max' and 'channel_wise_abs_max'.
         The 'range_abs_max' usually is not used for weight, since weights are fixed once the
         model is well trained.
     """
@@ -463,11 +543,15 @@
         self._activation_bits = activation_bits
         self._weight_quantize_type = weight_quantize_type
         self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._conv_ops = ['conv2d', 'depthwise_conv2d']
         self._fake_quant_op_names = [
             'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
-            'fake_quantize_moving_average_abs_max'
+            'fake_quantize_moving_average_abs_max',
+            'fake_channel_wise_quantize_abs_max'
+        ]
+        self._fake_dequant_op_names = [
+            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
         ]
-        self._fake_dequant_op_names = ['fake_dequantize_max_abs']
         self._op_input_rename_map = collections.OrderedDict()
         self._op_output_rename_map = collections.OrderedDict()
         self._var_scale_map = collections.OrderedDict()
@@ -489,20 +573,27 @@
                     if self._weight_quantize_type == 'abs_max':
                         param = self._load_var(input_arg_name)
                         scale_v = np.max(np.abs(param))
+                    elif self._weight_quantize_type == 'channel_wise_abs_max':
+                        param = self._load_var(input_arg_name)
+                        if len(param.shape) == 4:  # conv2d or depthwise_conv2d
+                            scale_v = []
+                            for i in range(param.shape[0]):
+                                scale_v.append(np.max(np.abs(param[i])))
+                        else:
+                            scale_v = np.max(np.abs(param))
                     else:
                         scale_v = self._load_var(
                             op_node.output('OutScale')[0])[0]
                     self._var_scale_map[input_arg_name] = scale_v
-                else:
-                    scale_v = graph.var_node(op_node.output('OutScale')[0])
-                    self._var_scale_map[input_arg_name] = scale_v
-                if input_arg_name in persistable_vars:
                     self._remove_fake_quant_and_dequant_op(graph, op_node)
                     # quantize weight and restore
                     param_v = self._load_var(input_arg_name)
                     quantized_param_v = self._quant(param_v, scale_v,
                                                     self._weight_bits)
                     self._restore_var(input_arg_name, quantized_param_v)
+                else:
+                    scale_v = graph.var_node(op_node.output('OutScale')[0])
+                    self._var_scale_map[input_arg_name] = scale_v
 
         ops = graph.all_op_nodes()
         for op_node in ops:
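At freeze time the channel-wise weight scales are recomputed from the float parameters, one abs-max per output channel of a 4-D conv weight, and then applied channel by channel by `_quant` further below. A rough NumPy equivalent (sketch only, illustrative helper name):

    import numpy as np

    def freeze_weight_channel_wise(param, weight_bits=8):
        # param: [C_out, C_in, H, W]; one scale per output channel.
        scales = [np.max(np.abs(param[i])) for i in range(param.shape[0])]
        quantized = param.copy()
        for i, s in enumerate(scales):
            quantized[i] = np.round(param[i] / s * ((1 << (weight_bits - 1)) - 1))
        return quantized, scales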
@@ -514,7 +605,10 @@
         for op_node in ops:
             op_name = op_node.name()
             if op_name in self._quantizable_ops:
-                self._insert_post_dequant_op(graph, op_node)
+                if self._weight_quantize_type == 'channel_wise_abs_max' and op_name in self._conv_ops:
+                    self._insert_post_channel_dequant_op(graph, op_node)
+                else:
+                    self._insert_post_dequant_op(graph, op_node)
 
         for op_node in ops:
             # insert dequant_op after fc/conv, need to rename inputs of the followed ops
@@ -538,9 +632,73 @@
                 self._op_input_rename_map[k] = self._op_input_rename_map[v]
         graph.safe_remove_nodes(op_node)
 
+    def _insert_post_channel_dequant_op(self, graph, op_node):
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        for var_node in op_node.inputs:
+            name = var_node.name()
+            if name in self._op_input_rename_map:
+                old_in = graph.var_node(name)
+                new_in = graph.var_node(self._op_input_rename_map[name])
+                new_in.clear_outputs()
+                graph.update_input_link(old_in, new_in, op_node)
+            original_var_name = self._original_var_name(name)
+            scale_v = self._var_scale_map[original_var_name]
+            if original_var_name in persistable_vars:
+                assert isinstance(
+                    scale_v,
+                    list), 'The scale of parameter %s is not a list.' % (
+                        original_var_name)
+                channel_scale = np.array(scale_v)
+            else:
+                assert isinstance(scale_v, IrNode)
+                scale_var_node = self._var_scale_map[original_var_name]
+
+        if len(op_node.outputs) != 1:
+            raise ValueError("Only support one output, but op %s has"
+                             " more than one output." % (op_node.name()))
+
+        output_var_node = op_node.outputs[0]
+        weight_scale_node = graph.create_persistable_node(
+            name=unique_name.generate('channel_scale'),
+            var_type=core.VarDesc.VarType.LOD_TENSOR,
+            shape=[channel_scale.shape[0]],
+            var_dtype=output_var_node.dtype())
+        init_program = Program()
+        weight_scale_var = init_program.global_block().create_var(
+            name=weight_scale_node.name(),
+            shape=weight_scale_node.shape(),
+            dtype=weight_scale_node.dtype(),
+            type=weight_scale_node.type(),
+            lod_level=weight_scale_node.var().lod_level(),
+            persistable=weight_scale_node.persistable())
+        initializer = NumpyArrayInitializer(value=channel_scale)
+        initializer(weight_scale_var, init_program.global_block())
+        exe = Executor(self._place)
+        exe.run(program=init_program, scope=self._scope)
+        dequant_var_node = graph.create_var_node(
+            name=self._dequantized_var_name(output_var_node.name()),
+            var_type=output_var_node.type(),
+            shape=output_var_node.shape(),
+            var_dtype=output_var_node.dtype())
+        dequant_op_node = graph.create_op_node(
+            op_type='fake_channel_wise_dequantize_max_abs',
+            attrs={
+                'quant_bits': [self._weight_bits, self._activation_bits],
+                'op_role': core.op_proto_and_checker_maker.OpRole.Forward
+            },
+            inputs={
+                'X': output_var_node,
+                'Scales': [weight_scale_node, scale_var_node]
+            },
+            outputs={'Out': dequant_var_node})
+        graph.link_to(output_var_node, dequant_op_node)
+        graph.link_to(scale_var_node, dequant_op_node)
+        graph.link_to(weight_scale_node, dequant_op_node)
+        graph.link_to(dequant_op_node, dequant_var_node)
+        self._op_output_rename_map[output_var_node.name()] = dequant_var_node
+        return dequant_var_node
+
     def _insert_post_dequant_op(self, graph, op_node):
-        max_range = None
-        scale_var_node = None
         persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
         for var_node in op_node.inputs:
             name = var_node.name()
@@ -637,7 +795,12 @@
             or isinstance(v, np.float64)
 
     def _quant(self, x, scale, num_bits):
-        return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
+        if isinstance(scale, list):
+            for i, s in enumerate(scale):
+                x[i] = np.round(x[i] / s * ((1 << (num_bits - 1)) - 1))
+            return x
+        else:
+            return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
 
 
 class ConvertToInt8Pass(object):
@@ -731,9 +894,13 @@ class TransformForMobilePass(object):
 
     def __init__(self):
         self._fake_quant_op_names = [
-            'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
+            'fake_quantize_abs_max', 'fake_quantize_range_abs_max',
+            'fake_quantize_moving_average_abs_max',
+            'fake_channel_wise_quantize_abs_max'
+        ]
+        self._fake_dequant_op_names = [
+            'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
         ]
-        self._fake_dequant_op_names = ['fake_dequantize_max_abs']
 
     def apply(self, graph):
         """
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
index 0b4b2a285..cda7ecbd8 100644
--- a/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
@@ -243,7 +243,10 @@ class TestQuantizationFreezePass(unittest.TestCase):
         with fluid.scope_guard(scope):
             exe.run(startup)
         transform_pass = QuantizationTransformPass(
-            scope=scope, place=place, activation_quantize_type=quant_type)
+            scope=scope,
+            place=place,
+            activation_quantize_type=quant_type,
+            weight_quantize_type='channel_wise_abs_max')
         transform_pass.apply(main_graph)
         transform_pass.apply(test_graph)
         dev_name = '_gpu_' if use_cuda else '_cpu_'
@@ -296,7 +299,10 @@
                                           fetch_list=[loss, w_var])
 
         # Freeze graph for inference, but the weight of fc/conv is still float type.
-        freeze_pass = QuantizationFreezePass(scope=scope, place=place)
+        freeze_pass = QuantizationFreezePass(
+            scope=scope,
+            place=place,
+            weight_quantize_type='channel_wise_abs_max')
         freeze_pass.apply(test_graph)
         if not for_ci:
             marked_nodes = set()
@@ -375,29 +381,32 @@
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
                 self.freeze_graph(
-                    True, seed=1, quant_type='abs_max', for_ci=True)
+                    True, seed=1, quant_type='abs_max', for_ci=False)
 
     def test_freeze_graph_cpu_dynamic(self):
         with fluid.unique_name.guard():
-            self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=True)
+            self.freeze_graph(False, seed=2, quant_type='abs_max', for_ci=False)
 
     def test_freeze_graph_cuda_static(self):
         if fluid.core.is_compiled_with_cuda():
             with fluid.unique_name.guard():
                 self.freeze_graph(
-                    True, seed=1, quant_type='range_abs_max', for_ci=True)
+                    True, seed=1, quant_type='range_abs_max', for_ci=False)
                 self.freeze_graph(
                     True,
                     seed=1,
                     quant_type='moving_average_abs_max',
-                    for_ci=True)
+                    for_ci=False)
 
     def test_freeze_graph_cpu_static(self):
         with fluid.unique_name.guard():
             self.freeze_graph(
-                False, seed=2, quant_type='range_abs_max', for_ci=True)
+                False, seed=2, quant_type='range_abs_max', for_ci=False)
             self.freeze_graph(
-                False, seed=2, quant_type='moving_average_abs_max', for_ci=True)
+                False,
+                seed=2,
+                quant_type='moving_average_abs_max',
+                for_ci=False)
 
 
 if __name__ == '__main__':
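A minimal invocation of the two passes with channel-wise weight quantization, following the test above (a sketch; `scope`, `place`, `main_graph` and `test_graph` are assumed to be built as in the test):

    from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    transform_pass = QuantizationTransformPass(
        scope=scope,
        place=place,
        activation_quantize_type='moving_average_abs_max',
        weight_quantize_type='channel_wise_abs_max')
    transform_pass.apply(main_graph)  # insert fake quant/dequant op pairs
    transform_pass.apply(test_graph)

    # ... train, then freeze the test graph for inference:
    freeze_pass = QuantizationFreezePass(
        scope=scope,
        place=place,
        weight_quantize_type='channel_wise_abs_max')
    freeze_pass.apply(test_graph)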
diff --git a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
index 32cb23cbf..0812b02b4 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_dequantize_op.py
@@ -31,15 +31,27 @@ def dequantize_max_abs(x, scale, max_range):
     return y
 
 
-def channel_wise_quantize_max_abs(x, quant_bit=8):
+def channel_wise_quantize_max_abs(x, quant_bit=8, use_second_dim=False):
     scales = []
-    for i in range(x.shape[0]):
-        scales.append(np.max(np.abs(x[i])).astype("float32"))
-
-    y = x.copy()
-    max_range = math.pow(2, quant_bit - 1) - 1
-    for i, scale in enumerate(scales):
-        y[i] = np.round(y[i] / scale * max_range)
+    if not use_second_dim:
+        for i in range(x.shape[0]):
+            scales.append(np.max(np.abs(x[i])).astype("float32"))
+        y = x.copy()
+        max_range = math.pow(2, quant_bit - 1) - 1
+        for i, scale in enumerate(scales):
+            y[i] = np.round(x[i] / scale * max_range)
+    else:
+        for i in range(x.shape[0]):
+            s = []
+            for j in range(x.shape[1]):
+                s.append(np.max(np.abs(x[i][j])).astype("float32"))
+            scales.append(s)
+        scales = np.amax(np.array(scales), axis=0)
+        y = x.copy()
+        max_range = math.pow(2, quant_bit - 1) - 1
+        for i in range(x.shape[0]):
+            for j, scale in enumerate(scales):
+                y[i][j] = np.round(x[i][j] / scale * max_range)
     return y, scales
 
@@ -47,10 +59,16 @@ def channel_wise_dequantize_max_abs(x,
                                     scales,
                                     quant_bits,
                                     activation_scale=None):
-    y = x.copy()
-    for i in range(x.shape[0]):
-        y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * y[i]
-    if activation_scale is not None:
+    if activation_scale is None:
+        y = x.copy()
+        for i in range(x.shape[0]):
+            y[i] = (scales[i] / (math.pow(2, quant_bits[0] - 1) - 1)) * x[i]
+    else:
+        y = x.copy()
+        for i in range(x.shape[0]):
+            for j in range(x.shape[1]):
+                y[i][j] = (scales[j] /
+                           (math.pow(2, quant_bits[0] - 1) - 1)) * x[i][j]
         y *= activation_scale / (math.pow(2, quant_bits[1] - 1) - 1)
     return y
 
@@ -65,7 +83,8 @@ class TestFakeChannelWiseDequantizeMaxAbsOpTwoScales(OpTest):
         self.set_args()
         self.op_type = "fake_channel_wise_dequantize_max_abs"
         x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
-        yq, scales = channel_wise_quantize_max_abs(x, self.quant_bits[0])
+        yq, scales = channel_wise_quantize_max_abs(
+            x, self.quant_bits[0], use_second_dim=True)
         ydq = channel_wise_dequantize_max_abs(yq, scales, self.quant_bits,
                                               self.activation_scale)
 
diff --git a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
index cf8f01edb..07038b044 100644
--- a/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+++ b/python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
@@ -53,7 +53,7 @@ class TestFakeChannelWiseQuantizeOp(OpTest):
 
         self.outputs = {
             'Out': outputs,
-            'OutScales': np.array(scales).astype("float32"),
+            'OutScale': np.array(scales).astype("float32"),
         }
 
     def test_check_output(self):
-- 
GitLab