diff --git a/python/paddle/fluid/tests/transpiler/CMakeLists.txt b/python/paddle/fluid/tests/transpiler/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..79bec8c4ad34d682895250bc29b1fddb3a569bd4 --- /dev/null +++ b/python/paddle/fluid/tests/transpiler/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py") +string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}") + +foreach(src ${TEST_OPS}) + py_test(${src} SRCS ${src}.py) +endforeach() diff --git a/python/paddle/fluid/tests/transpiler/test_quantize_transpiler.py b/python/paddle/fluid/tests/transpiler/test_quantize_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..5245b5ea09918d830fcdeb541c65810b139a790e --- /dev/null +++ b/python/paddle/fluid/tests/transpiler/test_quantize_transpiler.py @@ -0,0 +1,254 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import numpy as np +import unittest +import paddle +import paddle.fluid as fluid +from paddle.fluid.transpiler.quantize_transpiler import _original_var_name + + +def linear_fc(num): + data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in xrange(num): + hidden = fluid.layers.fc(hidden, size=128, act='relu') + loss = fluid.layers.cross_entropy(input=hidden, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def residual_block(num): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + hidden = data + for _ in xrange(num): + conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True) + short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None) + hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') + fc = fluid.layers.fc(input=hidden, size=10) + loss = fluid.layers.cross_entropy(input=fc, label=label) + loss = fluid.layers.mean(loss) + return loss + + +def conv_net(img, label): + conv_pool_1 = fluid.nets.simple_img_conv_pool( + input=img, + filter_size=5, + num_filters=20, + pool_size=2, + pool_stride=2, + act="relu") + conv_pool_1 = fluid.layers.batch_norm(conv_pool_1) + conv_pool_2 = fluid.nets.simple_img_conv_pool( + input=conv_pool_1, + filter_size=5, + num_filters=50, + pool_size=2, + pool_stride=2, + act="relu") + prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=prediction, label=label) + avg_loss = fluid.layers.mean(loss) + return avg_loss + + +class TestQuantizeTranspiler(unittest.TestCase): + def setUp(self): + # since quant_op and dequant_op is not ready, use cos and sin for test + self.weight_quant_op_type = 'fake_quantize_abs_max' + self.dequant_op_type = 'fake_dequantize_max_abs' + self.quantizable_op_and_inputs = { + 'conv2d': ['Input', 'Filter'], + 'depthwise_conv2d': ['Input', 'Filter'], + 'mul': ['X', 'Y'] + } + self.quantizable_op_grad_and_inputs = { + 'conv2d_grad': ['Input', 'Filter'], + 'depthwise_conv2d_grad': ['Input', 'Filter'], + 'mul_grad': ['X', 'Y'] + } + + def check_program(self, program): + quantized_ops = {} + + persistable_vars = [ + v.name + for v in filter(lambda var: var.persistable, program.list_vars()) + ] + + for block in program.blocks: + for idx, op in enumerate(block.ops): + # check forward + if op.type in self.quantizable_op_and_inputs: + for i, arg_name in enumerate(op.input_arg_names): + quant_op_type = self.weight_quant_op_type if \ + _original_var_name(arg_name) \ + in persistable_vars else self.act_quant_op_type + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + if arg_name not in quantized_ops: + self.assertEqual(block.ops[idx - 2 * i - 1].type, + self.dequant_op_type) + self.assertEqual(block.ops[idx - 2 * i - 2].type, + quant_op_type) + quantized_ops[arg_name] = block.ops[idx - 2 * i - 2] + else: + op_idx = block.ops.index(quantized_ops[arg_name]) + self.assertLess(op_idx, idx) + + # check backward + if op.type in self.quantizable_op_grad_and_inputs: + for pname in self.quantizable_op_grad_and_inputs[op.type]: + arg_name = op.input(pname)[0] + self.assertTrue( + arg_name.endswith('.quantized.dequantized')) + self.assertTrue(arg_name in quantized_ops) + + def linear_fc_quant(self, quant_type): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = linear_fc(3) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + t = fluid.QuantizeTranspiler(activation_quantize_type=quant_type) + t.training_transpile(main) + self.check_program(main) + + def test_linear_fc_quant_abs_max(self): + self.act_quant_op_type = 'fake_quantize_abs_max' + self.linear_fc_quant('abs_max') + + def test_linear_fc_quant_range_abs_max(self): + self.act_quant_op_type = 'fake_quantize_range_abs_max' + self.linear_fc_quant('range_abs_max') + + def residual_block_quant(self, quant_type): + main = fluid.Program() + startup = fluid.Program() + with fluid.program_guard(main, startup): + loss = residual_block(2) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + t = fluid.QuantizeTranspiler(activation_quantize_type=quant_type) + t.training_transpile(main) + self.check_program(main) + + def test_residual_block_abs_max(self): + self.act_quant_op_type = 'fake_quantize_abs_max' + self.residual_block_quant('abs_max') + + def test_residual_block_range_abs_max(self): + self.act_quant_op_type = 'fake_quantize_range_abs_max' + self.residual_block_quant('range_abs_max') + + def freeze_program(self, use_cuda): + main = fluid.Program() + startup = fluid.Program() + quant_transpiler = fluid.QuantizeTranspiler() + with fluid.program_guard(main, startup): + img = fluid.layers.data( + name='image', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data(name='label', shape=[1], dtype='int64') + loss = conv_net(img, label) + opt = fluid.optimizer.Adam(learning_rate=0.001) + opt.minimize(loss) + quant_transpiler.training_transpile(main) + + test_program = main.clone() + with fluid.program_guard(test_program): + test_program = fluid.io.get_inference_program(loss) + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + iter = 5 + batch_size = 8 + class_num = 10 + exe.run(startup) + + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=batch_size) + test_reader = paddle.batch( + paddle.dataset.mnist.test(), batch_size=batch_size) + feeder = fluid.DataFeeder(feed_list=[img, label], place=place) + + for _ in range(iter): + data = train_reader().next() + loss_v = exe.run(program=main, + feed=feeder.feed(data), + fetch_list=[loss]) + test_data = test_reader().next() + + f_var = fluid.framework.get_var('conv2d_1.tmp_0', test_program) + w_var = fluid.framework.get_var('conv2d_1.w_0.quantized', test_program) + # Testing during training + test_loss1, f_v1, w_quant = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[loss, f_var, w_var]) + + # Freeze program for inference, but the weight of fc/conv is still float type. + quant_transpiler.freeze_program(test_program, place) + fv2 = fluid.framework.get_var('conv2d_1.tmp_0.dequantized', + test_program) + test_loss2, f_v2 = exe.run(program=test_program, + feed=feeder.feed(test_data), + fetch_list=[loss, fv2]) + self.assertAlmostEqual(test_loss1, test_loss2, delta=1e-5) + self.assertAlmostEqual(f_v1.all(), f_v2.all(), delta=1e-5) + w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0') + .get_tensor()) + self.assertEqual(np.sum(w_freeze), np.sum(w_quant)) + + # Convert parameter to 8-bit. + quant_transpiler.convert_to_int8(test_program, place) + # Save the 8-bit parameter and model file. + fluid.io.save_inference_model('model_8bit', ['image', 'label'], [loss], + exe, test_program) + # Test whether the 8-bit parameter and model file can be loaded successfully. + [infer, feed, fetch] = fluid.io.load_inference_model('model_8bit', exe) + # Check the loaded 8-bit weight. + w_8bit = np.array(fluid.global_scope().find_var('conv2d_1.w_0.int8') + .get_tensor()) + + self.assertEqual(w_8bit.dtype, np.int8) + self.assertEqual(np.sum(w_8bit), np.sum(w_freeze)) + + def test_freeze_program_cuda(self): + self.freeze_program(True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/transpiler/quantize_transpiler.py b/python/paddle/fluid/transpiler/quantize_transpiler.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8193760bdb95eebb6891d1ff1791e139091511 --- /dev/null +++ b/python/paddle/fluid/transpiler/quantize_transpiler.py @@ -0,0 +1,545 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections +import numpy as np + +from paddle.fluid.framework import default_main_program, default_startup_program, program_guard +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid import unique_name +from paddle.fluid.initializer import Constant +from paddle.fluid.param_attr import ParamAttr +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.layers.nn import autoincreased_step_counter +from .. import core +from ..framework import Variable +from ..executor import global_scope +from inference_transpiler import InferenceTranspiler + +_QUANTIZABLE_OP_TYPES = ['conv2d', 'depthwise_conv2d', 'mul'] + + +def _quantized_var_name(var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.quantized" % (var_name) + + +def _dequantized_var_name(var_name): + """ + Return dequantized variable name for the input `var_name`. + """ + return "%s.dequantized" % (var_name) + + +def _quantized_scale_name(var_name): + """ + Return quantized variable name for the input `var_name`. + """ + return "%s.scale" % (var_name) + + +def _original_var_name(var_name): + """ + Return the original variable name. + """ + if var_name.endswith('.quantized.dequantized'): + return var_name[:-len('.quantized.dequantized')] + if var_name.endswith('.quantized'): + return var_name[:-len('.quantized')] + if var_name.endswith('.dequantized'): + return var_name[:-len('.dequantized')] + if var_name.endswith('.scale'): + return var_name[:-len('.scale')] + else: + return var_name + + +def _is_float(v): + return isinstance(v, float) or isinstance(v, np.float32) + + +def quant(x, scale, num_bits): + y = np.round(x / scale * ((1 << (num_bits - 1)) - 1)) + return y + + +class QuantizeTranspiler(object): + def __init__(self, + weight_bits=8, + activation_bits=8, + activation_quantize_type='abs_max', + weight_quantize_type='abs_max', + window_size=10000): + """ + Convert and rewrite the fluid Program according to weight and + activation quantization type. + + Args: + weight_bits (int): quantization bit number for weights, + the bias is not quantized. + activation_bits (int): quantization bit number for activation. + activation_quantize_type (str): quantization type for activation, + now support 'abs_max', 'range_abs_max'. If use 'abs_max' mode, + the quantization scale will be calculated dynamically each step + in both training and testing period. If use 'range_abs_max', + a static quantization scale will be calculated during training + and used in inference. + weight_quantize_type (str): quantization type for weights, + support 'abs_max'. The 'range_abs_max' usually is not used for + weight, since weights are fixed once the model is well trained. + window_size (int): the window size for 'range_abs_max' quantization. + + Examples: + + .. code-block:: python + + # the original program will be rewrite, if you don't want to + # change it, please clone at first. + # quantize_program = program.clone() + t = fluid.QuantizeTranspiler() + t.transpile(quantize_program) + + """ + self.weight_bits = weight_bits + self.activation_bits = activation_bits + quant_type = ['abs_max', 'range_abs_max'] + if weight_quantize_type not in quant_type: + raise ValueError( + "Unknown weight_quantize_type: '%s'. It can only be ", + "'abs_max' or 'range_abs_max'.", str(weight_quantize_type)) + if activation_quantize_type not in quant_type: + raise ValueError( + "Unknown activation_quantize_type : '%s'. It can only be ", + "'abs_max' or 'range_abs_max'.", str(activation_quantize_type)) + + self.weight_quantize_type = weight_quantize_type + self.activation_quantize_type = activation_quantize_type + + self.window_size = window_size + self.helper = LayerHelper(self.__class__.__name__) + self.fake_quant_op_types = [ + 'fake_quantize_abs_max', 'fake_quantize_range_abs_max' + ] + self.fake_dequant_op_types = ['fake_dequantize_max_abs'] + self.is_test = None + self.global_step = None + + def training_transpile(self, program=None, startup_program=None): + """Rewrites a training input program in place for simulated + quantization. Insert fake quantization and de-quantization ops into + program to simulate the error introduced by quantization. And change + the graident ops' input by using the faked quantization weights and + activation. Since the program is transformed in place, the graph + connection will change. + + Args: + program (Program): the input program to be transpile. + """ + self.is_test = False + program = default_main_program() if program is None else program + startup_program = default_startup_program() if startup_program is \ + None else startup_program + + # marked the variable which has been quantized and dequantized. + dequanted_vars = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + grad_op_types = ['%s_grad' % (type) for type in _QUANTIZABLE_OP_TYPES] + + params = [p.name for p in program.global_block().iter_parameters()] + + def _transpile_forward(block, op): + idx = block.ops.index(op) + block_id = block.idx + # insert quant op and dequant op + for name in op.input_arg_names: + if name in dequanted_vars[block_id]: + dequant_var = dequanted_vars[block_id][name] + else: + var = block.var(name) + quant_bits = self.weight_bits if var.name in params \ + else self.activation_bits + quant_type = self.weight_quantize_type if var.name \ + in params else self.activation_quantize_type + + quant_var, scale_var = self._insert_quant_op( + block, idx, var, quant_bits, quant_type) + dequant_var = self._insert_dequant_op( + block, idx + 1, quant_var, scale_var, quant_bits) + dequanted_vars[block_id][name] = dequant_var + # rename the forward op inputs + op.rename_input(name, dequant_var.name) + + def _transpile_backward(block, op): + block_id = block.idx + no_dequanted_input_vars = True + for name in op.input_arg_names: + if name in dequanted_vars[block_id]: + dequant_var = dequanted_vars[block_id][name] + op.rename_input(name, dequant_var.name) + no_dequanted_input_vars = False + if no_dequanted_input_vars: + raise ValueError("There is no dequanted inputs for op %s." % + (op.type)) + + with program_guard(program, startup_program): + self._create_globael_step() + for block in program.blocks: + ops = list(block.ops) + block_id = block.idx + for op in ops: + # rewrite the forward ProgramDes + if op.type in _QUANTIZABLE_OP_TYPES: + _transpile_forward(block, op) + # rename the backward op inputs + if op.type in grad_op_types: + _transpile_backward(block, op) + + def _create_globael_step(self): + if self.weight_quantize_type == 'range_abs_max' or \ + self.activation_quantize_type == 'range_abs_max': + self.global_step = autoincreased_step_counter() + + def freeze_program(self, program, place, fuse_bn=False, scope=None): + """Freeze input training program for inference. + + Args: + program (Program): the input program to be transpile. + """ + + self.is_test = True + scope = global_scope() if scope is None else scope + program = default_main_program() if program is None else program + + if fuse_bn: + bn_fuse_transpiler = BNFuseTranspiler() + bn_fuse_transpiler.transpile(program, place) + + persistable_vars = [ + v.name + for v in filter(lambda var: var.persistable, program.list_vars()) + ] + op_in_rename_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + op_out_rename_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + var_scale_map = [ + collections.OrderedDict() for _ in range(len(program.blocks)) + ] + + def _remove_fake_quant_and_dequant_op(block, op): + idx = block.ops.index(op) + block_id = block.idx + k = op.output('Out')[0] + v = op.input('X')[0] + if v not in op_in_rename_map[block_id]: + op_in_rename_map[block_id][k] = v + else: + op_in_rename_map[block_id][k] = op_in_rename_map[block_id][v] + block._remove_op(idx) + + def _insert_post_dequant_op(block, op): + idx = block.ops.index(op) + block_id = block.idx + max_range = None + scale_var = None + for name in op.input_arg_names: + if name in op_in_rename_map[block_id]: + op.rename_input(name, op_in_rename_map[block_id][name]) + + scale_v = var_scale_map[block_id][_original_var_name(name)] + if _original_var_name(name) in persistable_vars: + param_range = (1 << (self.weight_bits - 1)) - 1 + act_range = (1 << (self.activation_bits - 1)) - 1 + assert _is_float(scale_v) + max_range = param_range * act_range / scale_v + else: + assert isinstance(scale_v, Variable) + scale_var = var_scale_map[block_id][_original_var_name( + name)] + + if len(op.output_arg_names) != 1: + raise ValueError("Only support one output, but op %s has" + " more than one output." % (op.type)) + out_var = block.var(op.output_arg_names[0]) + dequant_var = block.create_var( + name=_dequantized_var_name(out_var.name), + type=out_var.type, + shape=out_var.shape, + dtype=out_var.dtype) + # insert fake_dequantize_op + dequant_op = block._insert_op( + idx + 1, + type="fake_dequantize_max_abs", + attrs={'max_range': float(max_range)}, + inputs={"X": out_var, + 'Scale': scale_var}, + outputs={"Out": dequant_var}) + op_out_rename_map[block_id][out_var.name] = dequant_var.name + return dequant_var + + def _load_var(name): + return np.array(scope.find_var(name).get_tensor()) + + def _restore_var(name, arr): + t = scope.find_var(name).get_tensor() + t.set(arr, place) + + for block in program.blocks: + ops = list(block.ops) + block_id = block.idx + for op in ops: + op_type = op.type + + # insert dequant_op after fc/conv, need to rename + # input of the followed ops + for name in op.input_arg_names: + if name in op_out_rename_map[block_id]: + op.rename_input(name, op_out_rename_map[block_id][name]) + + if op_type in self.fake_quant_op_types: + in_arg_name = op.input('X')[0] + if in_arg_name in persistable_vars: + if self.weight_quantize_type == 'abs_max': + param = _load_var(in_arg_name) + scale_v = np.max(np.abs(param)) + else: + scale_v = _load_var(op.output('OutScale')[0]) + var_scale_map[block_id][in_arg_name] = scale_v + else: + scale_v = block.var(op.output('OutScale')[0]) + var_scale_map[block_id][in_arg_name] = scale_v + + if in_arg_name in persistable_vars: + _remove_fake_quant_and_dequant_op(block, op) + # quantize weight and restore + param_t = _load_var(in_arg_name) + param_q_t = quant(param_t, scale_v, self.weight_bits) + _restore_var(in_arg_name, param_q_t) + + if op_type in self.fake_dequant_op_types: + _remove_fake_quant_and_dequant_op(block, op) + + if op_type in _QUANTIZABLE_OP_TYPES: + dequant_var = _insert_post_dequant_op(block, op) + + # remove the unused var in ProgramDesc + self._remove_unused_var(program) + #program = program.clone() + + def convert_to_int8(self, program, place, scope=None): + scope = global_scope() if scope is None else scope + program = default_main_program() if program is None else program + + def _load_var(name): + return np.array(scope.find_var(name).get_tensor()) + + global_block = program.global_block() + + def convert_to_int8(var): + int8_var_name = var.name + ".int8" + int8_var = global_block.create_parameter( + name=int8_var_name.encode('ascii'), + type=var.type, + dtype=core.VarDesc.VarType.INT8, + shape=var.shape) + + tensor = _load_var(var.name) + + scope.var(int8_var_name) + int8_tensor = scope.find_var(int8_var_name).get_tensor() + int8_tensor.set(tensor.astype(np.int8), place) + return int8_var + + input_map = {} + for block in program.blocks: + for op in list(block.ops): + if op.type in _QUANTIZABLE_OP_TYPES: + for name in op.input_arg_names: + var = block.var(name) + if var.persistable: + if name not in input_map: + int8_var = convert_to_int8(var) + input_map[name] = int8_var.name + op.rename_input(name, input_map[name]) + self._remove_unused_var(program) + + def _remove_unused_var(self, program): + for block in program.blocks: + args = [] + for op in block.ops: + args += op.input_arg_names + args += op.output_arg_names + args = list(set(args)) + for var in block.vars.keys(): + if var not in args: + block._remove_var(var) + + def _insert_quant_abs_max_op(self, block, idx, var, quant_bits): + """Insert fake_quantize_abs_max op. + """ + quant_var = block.create_var( + name=_quantized_var_name(var.name), + type=var.type, + shape=var.shape, + dtype=var.dtype) + scale = block.create_var( + name=_quantized_scale_name(var.name), + type=var.type, + shape=var.shape, + dtype=var.dtype) + quant_op = block._insert_op( + idx, + type='fake_quantize_abs_max', + attrs={'bit_length': quant_bits}, + inputs={'X': var}, + outputs={'Out': quant_var, + 'OutScale': scale}) + return quant_var, scale + + def _insert_quant_range_abs_max_op(self, block, idx, var, quant_bits): + """Insert fake_quantize_range_abs_max + """ + quant_var = block.create_var( + name=_quantized_var_name(var.name), + type=var.type, + shape=var.shape, + dtype=var.dtype) + scale = self.helper.create_parameter( + attr=ParamAttr( + name=_quantized_scale_name(var.name), + initializer=Constant(0.001), + trainable=False), + shape=[1], + dtype=var.dtype) + scale.stop_gradient = True + + ins = {'X': var, 'InScale': scale} + outs = {'Out': quant_var, 'OutScale': scale} + if not self.is_test: + # A global step counter variable with type int64 + scales = self.helper.create_global_variable( + name=unique_name.generate('scales'), + persistable=True, + dtype=var.dtype, + shape=[self.window_size]) + self.helper.set_variable_initializer( + scales, initializer=Constant(value=0)) + + ins['Iter'] = self.global_step + outs['OutScales'] = scales + + attrs = { + 'window_size': self.window_size, + 'bit_length': quant_bits, + 'is_test': self.is_test + } + + quant_op = block._insert_op( + idx, + type='fake_quantize_range_abs_max', + attrs=attrs, + inputs=ins, + outputs=outs) + + return quant_var, scale + + def _insert_quant_op(self, block, idx, var, quant_bits, quant_type): + """ + Insert fake_quantize_op + """ + if quant_type == 'abs_max': + return self._insert_quant_abs_max_op(block, idx, var, quant_bits) + elif quant_type == 'range_abs_max': + return self._insert_quant_range_abs_max_op(block, idx, var, + quant_bits) + + def _insert_dequant_op(self, block, idx, var, scale, quant_bits): + """ + Insert fake_quantize_op + """ + dequant_var = block.create_var( + name=_dequantized_var_name(var.name), + type=var.type, + shape=var.shape, + dtype=var.dtype) + # insert fake_dequantize_op + max_range = (1 << (quant_bits - 1)) - 1 + dequant_op = block._insert_op( + idx, + type="fake_dequantize_max_abs", + attrs={'max_range': float(max_range)}, + inputs={"X": var, + 'Scale': scale}, + outputs={"Out": dequant_var}) + return dequant_var + + +class BNFuseTranspiler(InferenceTranspiler): + def _fuse_param(self, current_op, bn_op, bias_op, with_bias): + def _update_param(op, param_name, new_param): + var = self.block.vars[param_name] + tensor = self.scope.find_var(param_name).get_tensor() + tensor.set(np.array(new_param), self.place) + + def _load_param(param_name): + return np.array(self.scope.find_var(param_name).get_tensor()) + + bias_bn = _load_param(bn_op.input("Bias")[0]) #Bias + scale_bn = _load_param(bn_op.input("Scale")[0]) #Scale + mean_bn = _load_param(bn_op.input("Mean")[0]) #Mean + var_bn = _load_param(bn_op.input("Variance")[0]) #Variance + + if current_op.type in ['conv2d', 'depthwise_conv2d']: + current_param = _load_param( + _original_var_name(current_op.input("Filter")[0])) + elif current_op.type == 'mul': + current_param = _load_param( + _original_var_name(current_op.input("Y")[0])) + + std_bn = np.float32(np.sqrt(np.add(var_bn, 1e-5))) + tmp = np.float32(np.divide(scale_bn, std_bn)) + + # add bias of batch_norm_op to conv2d + if with_bias: + bias = _load_param(bias_op.input("Y")) + else: + bias = np.zeros(bias_bn.shape) + bias = np.float32( + np.add(np.multiply(np.subtract(bias, mean_bn), tmp), bias_bn)) + + # re-compute weight of conv2d/fc + tmp = tmp.reshape(tmp.shape[0], -1) + dst_param = current_param.reshape((tmp.shape[0], -1)) + dst_param = np.float32(np.multiply(dst_param, tmp)) + dst_param = dst_param.reshape(current_param.shape) + + # update parameters + if current_op.type in ['conv2d', 'depthwise_conv2d']: + _update_param(current_op, + _original_var_name(current_op.input("Filter")[0]), + dst_param) + elif current_op.type == 'mul': + _update_param(current_op, + _original_var_name(current_op.input("Y")[0]), + dst_param) + + _update_param(bias_op, bias_op.input("Y")[0], bias) + + # collect the renamed input + self.input_map[bn_op.output("Y")[0]] = bias_op.output("Out")[0]