diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
index 44767220d06a976ef86186b69765c2b3e44b3ac2..0d989903a9aea018913e3ee30e2b80f9341f77c0 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
@@ -22,7 +22,7 @@ from .... import unique_name
 
 __all__ = [
     'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
-    'TransformForMobilePass'
+    'TransformForMobilePass', 'ScaleForTrainingPass', 'ScaleForInferencePass'
 ]
 
 
@@ -962,3 +962,158 @@ class TransformForMobilePass(object):
                 graph.safe_remove_nodes(op_node)
         graph.resolve_hazard()
         return graph
+
+
+class ScaleForTrainingPass(object):
+    def __init__(self, scope=None, place=None, moving_rate=0.9):
+        """
+        This pass is used for calculating output scales of some operators.
+        These output scales may be used by TensorRT or some other inference engines.
+
+        Args:
+            scope(fluid.Scope): The scope is used to initialize these new parameters.
+            place(fluid.CPUPlace|fluid.CUDAPlace): The place is used to initialize new parameters.
+            moving_rate(float): The decay coefficient of moving average. The default value is 0.9.
+        """
+        self._scope = scope
+        self._place = place
+        self._moving_rate = moving_rate
+        self._is_test = None
+        self._teller_set = [
+            "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid",
+            "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad",
+            "elementwise_add", "elementwise_mul", "dropout", "split", "prelu",
+            "conv2d_transpose", "leaky_relu"
+        ]
+
+    def apply(self, graph):
+        """
+        Insert the `moving_average_abs_max_scale` op in order to calculate output scales
+        of operators in the teller_set.
+
+        Args:
+            graph(IrGraph): the target graph.
+ """ + self._is_test = graph.is_test() + ops = graph.all_op_nodes() + for op_node in ops: + name = op_node.name() + if name in self._teller_set: + if len(op_node.output_arg_names()) != 1: + continue + in_node = graph._find_node_by_name( + op_node.outputs, op_node.output_arg_names()[0]) + out_node = graph.create_var_node_from_desc(in_node.var()) + scale_node = graph.create_persistable_node( + name=self._scale_name(in_node.name()), + var_type=core.VarDesc.VarType.LOD_TENSOR, + shape=[1], + var_dtype=in_node.dtype()) + ins = {'X': in_node} + outs = {'Out': out_node, 'OutScale': scale_node} + if not self._is_test: + state_in_node = graph.create_persistable_node( + name=unique_name.generate('scale_state@'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + var_dtype=in_node.dtype(), + shape=[1]) + data_type = 'float64' if in_node.dtype( + ) == core.VarDesc.VarType.FP64 else 'float32' + _init_var_node( + state_in_node, + np.ones( + [1], dtype=data_type), + self._scope, + self._place) + accum_in_node = graph.create_persistable_node( + name=unique_name.generate('scale_accum@'), + var_type=core.VarDesc.VarType.LOD_TENSOR, + var_dtype=in_node.dtype(), + shape=[1]) + _init_var_node( + accum_in_node, + np.ones( + [1], dtype=data_type), + self._scope, + self._place) + state_out_node = graph.create_var_node_from_desc( + state_in_node.var()) + accum_out_node = graph.create_var_node_from_desc( + accum_in_node.var()) + + ins['InState'] = state_in_node + ins['InAccum'] = accum_in_node + outs['OutState'] = state_out_node + outs['OutAccum'] = accum_out_node + + attrs = { + 'moving_rate': self._moving_rate, + 'is_test': self._is_test, + 'op_role': core.op_proto_and_checker_maker.OpRole.Forward + } + scale_op_node = graph.create_op_node( + op_type='moving_average_abs_max_scale', + attrs=attrs, + inputs=ins, + outputs=outs) + graph.link_to(in_node, scale_op_node) + graph.link_to(scale_op_node, out_node) + graph.link_to(scale_op_node, scale_node) + if not self._is_test: + graph.link_to(state_in_node, scale_op_node) + graph.link_to(accum_in_node, scale_op_node) + graph.link_to(scale_op_node, state_out_node) + graph.link_to(scale_op_node, accum_out_node) + graph.resolve_hazard() + return graph + + def _scale_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@scale" % (var_name) + + +class ScaleForInferencePass(object): + def __init__(self, scope=None): + """ + This pass is used for setting output scales of some operators. + These output scales may be used by tensorRT or some other inference engines. + + Args: + scope(fluid.Scope): The scope is used to initialize these new parameters. + """ + self._scope = scope + self._teller_set = [ + "mul", "conv2d", "pool2d", "relu", "softmax", "sigmoid", + "depthwise_conv2d", "batch_norm", "concat", "tanh", "pad", + "elementwise_add", "elementwise_mul", "dropout", "split", "prelu", + "conv2d_transpose", "leaky_relu" + ] + + def apply(self, graph): + """ + Get output scales from the scope and set these scales in op_descs + of operators in the teller_set. + + Args: + graph(IrGraph): the target graph. 
+ """ + ops = graph.all_op_nodes() + for op_node in ops: + name = op_node.name() + if name in self._teller_set: + if len(op_node.output_arg_names()) != 1: + continue + scale_name = self._scale_name(op_node.output_arg_names()[0]) + scale_v = np.array( + self._scope.find_var(scale_name).get_tensor())[0] + op_node.op()._set_attr("out_scale", float(scale_v)) + graph.resolve_hazard() + return graph + + def _scale_name(self, var_name): + """ + Return the scale name for the var named `var_name`. + """ + return "%s@scale" % (var_name) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..1ed41da0f842b5eac8fd622a96a2fbd68adf98ae --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_scale_pass.py @@ -0,0 +1,190 @@ +# copyright (c) 2018 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import os +import unittest +import random +import numpy as np +import six +import paddle.fluid as fluid +import paddle +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.contrib.slim.quantization import ScaleForTrainingPass +from paddle.fluid.contrib.slim.quantization import ScaleForInferencePass +from paddle.fluid import core + +os.environ["CUDA_VISIBLE_DEVICES"] = "0" +os.environ["CPU_NUM"] = "1" + + +def residual_block(img, label, num=1): + def conv_bn_layer(input, + ch_out, + filter_size, + stride, + padding, + act='relu', + bias_attr=False): + tmp = fluid.layers.conv2d( + input=input, + filter_size=filter_size, + num_filters=ch_out, + stride=stride, + padding=padding, + act=None, + bias_attr=bias_attr) + return fluid.layers.batch_norm(input=tmp, act=act) + + hidden = img + for _ in six.moves.xrange(num): + conv = conv_bn_layer(hidden, 20, 3, 1, 1, act=None, bias_attr=True) + short = conv_bn_layer(hidden, 20, 1, 1, 0, act=None) + hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu') + fc = fluid.layers.fc(input=hidden, size=10, act='softmax') + loss = fluid.layers.cross_entropy(input=fc, label=label) + loss = fluid.layers.mean(loss) + return loss + + +class TestQuantizationScalePass(unittest.TestCase): + def quantization_scale(self, + use_cuda, + seed, + activation_quant_type, + weight_quant_type='abs_max', + for_ci=False): + def build_program(main, startup, is_test): + main.random_seed = seed + startup.random_seed = seed + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + img = fluid.layers.data( + name='image', shape=[1, 28, 28], dtype='float32') + label = fluid.layers.data( + name='label', shape=[1], dtype='int64') + loss = residual_block(img, label, 1) + if not is_test: + opt = fluid.optimizer.Adam(learning_rate=0.0001) + opt.minimize(loss) + return [img, label], loss + + 
+        random.seed(0)
+        np.random.seed(0)
+
+        main = fluid.Program()
+        startup = fluid.Program()
+        test_program = fluid.Program()
+        feeds, loss = build_program(main, startup, False)
+        build_program(test_program, startup, True)
+        test_program = test_program.clone(for_test=True)
+        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
+        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
+
+        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            exe.run(startup)
+        transform_pass = QuantizationTransformPass(
+            scope=scope,
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quant_type)
+        transform_pass.apply(main_graph)
+        transform_pass.apply(test_graph)
+        scale_training_pass = ScaleForTrainingPass(scope=scope, place=place)
+        scale_training_pass.apply(main_graph)
+        dev_name = '_gpu' if use_cuda else '_cpu'
+        if not for_ci:
+            marked_nodes = set()
+            for op in main_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            main_graph.draw('.', 'main_scale' + dev_name, marked_nodes)
+            marked_nodes = set()
+            for op in test_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            test_graph.draw('.', 'test_scale' + dev_name, marked_nodes)
+
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
+        iters = 5
+        batch_size = 8
+
+        train_reader = paddle.batch(
+            paddle.reader.shuffle(
+                paddle.dataset.mnist.train(), buf_size=500),
+            batch_size=batch_size)
+        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+        with fluid.scope_guard(scope):
+            for _ in range(iters):
+                data = next(train_reader())
+                loss_v = exe.run(binary,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss])
+                if not for_ci:
+                    print('{}: {}'.format('loss' + dev_name, loss_v))
+
+        scale_inference_pass = ScaleForInferencePass(scope=scope)
+        scale_inference_pass.apply(test_graph)
+
+        # Freeze graph for inference, but the weight of fc/conv is still float type.
+        freeze_pass = QuantizationFreezePass(
+            scope=scope, place=place, weight_quantize_type=weight_quant_type)
+        freeze_pass.apply(test_graph)
+        server_program = test_graph.to_program()
+
+        if not for_ci:
+            marked_nodes = set()
+            for op in test_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            test_graph.draw('.', 'quant_scale' + dev_name, marked_nodes)
+
+        with open('quant_scale_model' + dev_name + '.txt', 'w') as f:
+            f.write(str(server_program))
+
+        with fluid.scope_guard(scope):
+            fluid.io.save_inference_model('quant_scale_model' + dev_name,
+                                          ['image', 'label'], [loss], exe,
+                                          server_program)
+
+    def test_quant_scale_cuda(self):
+        if fluid.core.is_compiled_with_cuda():
+            with fluid.unique_name.guard():
+                self.quantization_scale(
+                    True,
+                    seed=1,
+                    activation_quant_type='moving_average_abs_max',
+                    weight_quant_type='channel_wise_abs_max',
+                    for_ci=True)
+
+    def test_quant_scale_cpu(self):
+        with fluid.unique_name.guard():
+            self.quantization_scale(
+                False,
+                seed=2,
+                activation_quant_type='moving_average_abs_max',
+                weight_quant_type='channel_wise_abs_max',
+                for_ci=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
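
For reference, below is a minimal usage sketch of the new passes, distilled from the test above. It only illustrates the intended pass order; `scope`, `place`, `main_graph` (built with for_test=False) and `test_graph` (built with for_test=True) are assumed to be prepared as in test_quantization_scale_pass.py, and the constructor defaults stand in for the test's explicit quantize types.

    # Sketch only: assumes scope, place, main_graph and test_graph are already
    # set up as in test_quantization_scale_pass.py.
    from paddle.fluid.contrib.slim.quantization import (
        QuantizationTransformPass, QuantizationFreezePass, ScaleForTrainingPass,
        ScaleForInferencePass)

    transform_pass = QuantizationTransformPass(scope=scope, place=place)
    transform_pass.apply(main_graph)
    transform_pass.apply(test_graph)

    # Training graph: insert moving_average_abs_max_scale ops so output scales
    # are tracked as moving averages while the model trains.
    ScaleForTrainingPass(scope=scope, place=place, moving_rate=0.9).apply(main_graph)

    # ... compile main_graph and run training iterations here ...

    # Inference graph: write the collected scales into the op descs as
    # `out_scale`, then freeze the graph for inference.
    ScaleForInferencePass(scope=scope).apply(test_graph)
    QuantizationFreezePass(scope=scope, place=place).apply(test_graph)
    inference_program = test_graph.to_program()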