From 90ebce9eadde0dca2731cc8fa5717797260e412c Mon Sep 17 00:00:00 2001
From: bingyanghuang <33643817+bingyanghuang@users.noreply.github.com>
Date: Mon, 10 Jun 2019 16:39:46 +0800
Subject: [PATCH] QAT int8 MKL-DNN transformation pass (#17819)

---
 .../contrib/slim/quantization/__init__.py     |   3 +
 .../quantization/quantization_mkldnn_pass.py  | 229 ++++++++++++++++++
 .../tests/test_quantization_mkldnn_pass.py    | 193 +++++++++++++++
 3 files changed, 425 insertions(+)
 create mode 100644 python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
 create mode 100644 python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py

diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py
index 445cbc776a..659265895a 100644
--- a/python/paddle/fluid/contrib/slim/quantization/__init__.py
+++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py
@@ -20,6 +20,9 @@ from . import quantization_strategy
 from .quantization_strategy import *
 from . import mkldnn_post_training_strategy
 from .mkldnn_post_training_strategy import *
+from . import quantization_mkldnn_pass
+from .quantization_mkldnn_pass import *
 __all__ = quantization_pass.__all__ + quantization_strategy.__all__
 __all__ += mkldnn_post_training_strategy.__all__
+__all__ += quantization_mkldnn_pass.__all__
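The re-exported names above make the new pass importable directly from the slim quantization package. A minimal usage sketch (it mirrors the docstring example in the new file below; the graph here is a placeholder and would in practice be an IrGraph already processed by QuantizationFreezePass):

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.framework import IrGraph
    from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass

    graph = IrGraph(core.Graph(fluid.Program().desc), for_test=True)
    place = fluid.CPUPlace()
    mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
    mkldnn_pass.apply(graph)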
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
new file mode 100644
index 0000000000..2fc9dfac8e
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
@@ -0,0 +1,229 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from .... import core
+from ....framework import IrGraph
+from ....framework import IrNode
+
+__all__ = ['TransformForMkldnnPass']
+
+
+class TransformForMkldnnPass(object):
+    """
+    Convert the IrGraph generated by QuantizationFreezePass into an INT8
+    IrGraph supported by MKL-DNN. The following transformations are performed
+    by this pass:
+    1. Convert the int8-range weights (stored with float32 data type), which
+       were produced by QuantizationFreezePass, back to float32-range weights
+       with float32 data type by using the corresponding scales. This
+       conversion is needed because the MKL-DNN INT8 conv2d kernel currently
+       accepts only float32 weights and quantizes them inside the kernel.
+    2. Create a new conv2d op with the converted weights, link its output to
+       the output of the fake_dequantize_abs_max op, and set the conv2d
+       attribute "force_fp32_output" to true.
+    3. Transform each fake_quantize_xx op into a quantize op.
+    4. Remove the fake_dequantize_abs_max ops.
+    """
+
+    def __init__(self, scope=None, place=None):
+        """
+        Args:
+            scope(fluid.Scope): scope is used to initialize the new parameters.
+            place(fluid.CPUPlace): place is used to initialize the new parameters.
+
+        Examples:
+        .. code-block:: python
+            # The original graph will be rewritten.
+            import paddle.fluid as fluid
+            from paddle.fluid.contrib.slim.quantization \
+                import TransformForMkldnnPass
+            from paddle.fluid.framework import IrGraph
+            from paddle.fluid import core
+
+            graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
+            place = fluid.CPUPlace()
+            mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(),
+                                                 place)
+            mkldnn_pass.apply(graph)
+        """
+
+        self._scope = scope
+        self._place = place
+
+        self.quantize_type = [
+            'fake_quantize_moving_average_abs_max',
+            'fake_quantize_range_abs_max'
+        ]
+        self.dequantize_type = ['fake_dequantize_max_abs']
+
+        self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
+        self._conv_ops = ['conv2d', 'depthwise_conv2d']
+
+        self.InScale = {}
+        self.max_range = {}
+        self.conv_new_output = {}
+        self.s8_max = 127
+        # Temporary code that keeps the mul op fake-quantized.
+        # TODO(Intel): Remove the following code once the mul INT8 MKL-DNN
+        # kernel is enabled.
+        self.mul_input_id = []
+        self.mul_output_id = []
+
+    def apply(self, graph):
+        """
+        Quantize the graph for MKL-DNN INT8 inference. According to the
+        activation quantization type, fake quantize ops are transformed into
+        quantize ops and the fake dequantize ops are removed.
+
+        Args:
+            graph(IrGraph): the applied graph.
+        """
+
+        assert isinstance(graph,
+                          IrGraph), 'graph must be the instance of IrGraph.'
+        ops = graph.all_op_nodes()
+
+        persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
+        # Collect the InScale and max_range values needed to calculate the new
+        # scales for the MKL-DNN INT8 conv2d.
+        for op_node in ops:
+            if op_node.name() in self.dequantize_type:
+                input_name = op_node.input("X")[0]
+                scale_name = op_node.input("Scale")[0]
+                self.InScale[input_name] = self._load_param(self._scope,
+                                                            scale_name)[0]
+                self.max_range[input_name] = op_node.op().attr("max_range")
+                self.conv_new_output[input_name] = op_node.output("Out")[0]
+            # Temporary graph transform that keeps the mul op untouched.
+            # TODO(Intel): Remove the following code.
+            elif op_node.name() in ['mul']:
+                input_node = graph._find_node_by_name(op_node.inputs,
+                                                      op_node.input('X')[0])
+                output_node = graph._find_node_by_name(op_node.outputs,
+                                                       op_node.output('Out')[0])
+                self.mul_input_id.append(input_node.id())
+                self.mul_output_id.append(output_node.id())
+
+        for op_node in ops:
+            if op_node.name() in self._conv_ops:
+                self._transform_to_conv_mkldnn(graph, op_node)
+            elif op_node.name() in self.quantize_type:
+                self._transform_to_quantize_mkldnn(graph, op_node)
+            elif op_node.name() in self.dequantize_type:
+                self._remove_fake_dequantize_op(graph, op_node)
+        self._remove_unused_var_nodes(graph)
+        return graph
+
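+    # Worked example of the scale bookkeeping used by the methods below
+    # (purely illustrative numbers, not taken from a real model): if the QAT
+    # pass recorded InScale = 0.5 and max_range = 254.0 for a conv2d output,
+    # then _transform_to_conv_mkldnn sets
+    #     Scale_in      = s8_max / InScale     = 127 / 0.5   = 254.0
+    #     Scale_weights = [max_range / s8_max] = [254 / 127] = [2.0]
+    # and rescales the stored weights by 127 / max_range = 0.5, i.e. the
+    # int8-range float32 weights are converted back to float32-range values.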
+    def _transform_to_conv_mkldnn(self, graph, op_node):
+        weight_name = op_node.input("Filter")[0]
+        output_name = op_node.output("Output")[0]
+        # Convert int8-range weights to fp32-range weights.
+        weight = self._load_param(self._scope, weight_name)
+        w_fp32 = np.divide(
+            np.multiply(weight, 127), self.max_range[output_name])
+        w_fp32 = w_fp32.reshape(weight.shape)
+        self._restore_var(weight_name, w_fp32)
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("Input")[0])
+        weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
+
+        # Set fake_dequantize_abs_max's output as the new output of conv2d.
+        output_var_node = graph._find_node_by_name(
+            graph.all_var_nodes(), self.conv_new_output[output_name])
+        attrs = {
+            name: op_node.op().attr(name)
+            for name in op_node.op().attr_names()
+        }
+
+        conv_op_node = graph.create_op_node(
+            op_type='conv2d',
+            attrs=attrs,
+            inputs={'Input': input_var_node,
+                    'Filter': weight_var_node},
+            outputs={'Output': output_var_node})
+
+        # Calculate the scales of the MKL-DNN INT8 conv2d from the QAT scales.
+        scale_in = self.s8_max / self.InScale[output_name]
+        scale_w = []
+        scale_w.append(self.max_range[output_name] / self.s8_max)
+
+        conv_op_node.set_attr("Scale_weights", scale_w)
+        conv_op_node.set_attr("Scale_in", scale_in)
+        conv_op_node.set_attr("Scale_out", 1.0)
+        conv_op_node.set_attr("use_mkldnn", 1)
+        conv_op_node.set_attr("force_fp32_output", 1)
+        graph.link_to(input_var_node, conv_op_node)
+        graph.link_to(weight_var_node, conv_op_node)
+        graph.link_to(conv_op_node, output_var_node)
+        graph.safe_remove_nodes(op_node)
+
+    def _transform_to_quantize_mkldnn(self, graph, op_node):
+        """
+        Transform a fake_quantize_xx op into a quantize MKL-DNN op in the graph.
+        """
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("X")[0])
+        output_var_node = graph._find_node_by_name(op_node.outputs,
+                                                   op_node.output("Out")[0])
+        if output_var_node.id() in self.mul_input_id:
+            return
+        else:
+            scale_in = self.s8_max / self._load_param(
+                self._scope, op_node.input("InScale")[0])[0]
+            quant_op_node = graph.create_op_node(
+                op_type='quantize',
+                attrs={
+                    'data_format': 'MKLDNNLAYOUT',
+                    'use_mkldnn': 1,
+                    'Scale': scale_in,
+                    'is_negative_input': 1
+                },
+                inputs={'Input': input_var_node},
+                outputs={'Output': output_var_node})
+            graph.link_to(input_var_node, quant_op_node)
+            graph.link_to(quant_op_node, output_var_node)
+            graph.safe_remove_nodes(op_node)
+
+    def _remove_fake_dequantize_op(self, graph, op_node):
+        input_var_node = graph._find_node_by_name(op_node.inputs,
+                                                  op_node.input("X")[0])
+        if input_var_node.id() in self.mul_output_id:
+            return
+        else:
+            graph.safe_remove_nodes(op_node)
+
+    def _load_param(self, scope, param_name):
+        return np.array(scope.find_var(param_name).get_tensor())
+
+    def _restore_var(self, name, array):
+        tensor = self._scope.find_var(name).get_tensor()
+        tensor.set(array, self._place)
+
+    def _remove_unused_var_nodes(self, graph):
+        all_used_vars = set()
+        ops = graph.all_op_nodes()
+        for op_node in ops:
+            for input_node in op_node.inputs:
+                all_used_vars.add(input_node)
+            for output_node in op_node.outputs:
+                all_used_vars.add(output_node)
+
+        all_used_vars = {n.node for n in all_used_vars}
+        all_unused_vars = {
+            n
+            for n in filter(lambda node: node.node not in all_used_vars,
+                            graph.all_var_nodes())
+        }
+        graph.safe_remove_nodes(all_unused_vars)
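The test file that follows exercises the complete flow on a small MNIST convolutional network. Condensed to its essential steps, the flow looks like this (an outline only; scope, place, test_graph and the quant-type arguments are the variables created inside the test below):

    # 1. Insert fake quantize / dequantize ops into the graphs (QAT).
    QuantizationTransformPass(
        scope=scope, place=place,
        activation_quantize_type=activation_quant_type,
        weight_quantize_type=weight_quant_type).apply(test_graph)
    # 2. Train, then freeze: weights become int8-range float32 values and
    #    the fake_dequantize ops carry the scales.
    QuantizationFreezePass(
        scope=scope, place=place,
        weight_quantize_type=weight_quant_type).apply(test_graph)
    # 3. Rewrite the frozen graph for MKL-DNN INT8 execution.
    TransformForMkldnnPass(scope=scope, place=place).apply(test_graph)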
diff --git a/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py
new file mode 100644
index 0000000000..90cc28b3aa
--- /dev/null
+++ b/python/paddle/fluid/contrib/slim/tests/test_quantization_mkldnn_pass.py
@@ -0,0 +1,193 @@
+# copyright (c) 2019 paddlepaddle authors. all rights reserved.
+#
+# licensed under the apache license, version 2.0 (the "license");
+# you may not use this file except in compliance with the license.
+# you may obtain a copy of the license at
+#
+#     http://www.apache.org/licenses/license-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the license is distributed on an "as is" basis,
+# without warranties or conditions of any kind, either express or implied.
+# see the license for the specific language governing permissions and
+# limitations under the license.
+
+import os
+import unittest
+import random
+import numpy as np
+import paddle.fluid as fluid
+import six
+import paddle
+from paddle.fluid.framework import IrGraph
+from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
+from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
+from paddle.fluid import core
+
+os.environ["CPU_NUM"] = "1"
+
+
+def conv_net(img, label):
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=img,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu")
+    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
+    loss = fluid.layers.cross_entropy(input=prediction, label=label)
+    avg_loss = fluid.layers.mean(loss)
+    return avg_loss
+
+
+class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
+    def setUp(self):
+        self.quantizable_op_and_inputs = {
+            'conv2d': ['Input', 'Filter'],
+            'depthwise_conv2d': ['Input', 'Filter'],
+            # The mul INT8 op is still under internal test.
+            # TODO: Update this once the mul op is merged.
+            # 'mul': ['X', 'Y']
+        }
+
+    def check_program(self, program):
+        for block in program.blocks:
+            for op in block.ops:
+                if op.type in self.quantizable_op_and_inputs:
+                    for arg_name in op.output_arg_names:
+                        # Check that the quantizable op's output is linked to
+                        # fake_dequantize's output.
+                        self.assertTrue(arg_name.endswith('.dequantized'))
+
+    def isinteger(self, x):
+        return np.equal(np.mod(x, 1), 0)
+
+    def build_program(self, main, startup, is_test, seed):
+        main.random_seed = seed
+        startup.random_seed = seed
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                img = fluid.layers.data(
+                    name='image', shape=[1, 28, 28], dtype='float32')
+                label = fluid.layers.data(
+                    name='label', shape=[1], dtype='int64')
+                loss = conv_net(img, label)
+                if not is_test:
+                    opt = fluid.optimizer.Adam(learning_rate=0.001)
+                    opt.minimize(loss)
+        return [img, label], loss
+
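+    # The driver below walks through the full QAT-to-MKL-DNN flow on the
+    # MNIST conv net defined above: apply QuantizationTransformPass, train
+    # for a few iterations, freeze the graph with QuantizationFreezePass,
+    # apply TransformForMkldnnPass, and finally check the weights and the
+    # structure of the resulting program.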
+    def mkldnn_based_freeze_graph(self,
+                                  use_cuda,
+                                  seed,
+                                  activation_quant_type,
+                                  weight_quant_type='abs_max',
+                                  for_ci=False):
+        random.seed(0)
+        np.random.seed(0)
+
+        main = fluid.Program()
+        startup = fluid.Program()
+        test_program = fluid.Program()
+        feeds, loss = self.build_program(main, startup, False, seed)
+        self.build_program(test_program, startup, True, seed)
+        test_program = test_program.clone(for_test=True)
+        main_graph = IrGraph(core.Graph(main.desc), for_test=False)
+        test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
+
+        place = fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        scope = fluid.Scope()
+        with fluid.scope_guard(scope):
+            exe.run(startup)
+        # Apply the QAT QuantizationTransformPass.
+        transform_pass = QuantizationTransformPass(
+            scope=scope,
+            place=place,
+            activation_quantize_type=activation_quant_type,
+            weight_quantize_type=weight_quant_type)
+        transform_pass.apply(main_graph)
+        transform_pass.apply(test_graph)
+
+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+        binary = fluid.CompiledProgram(main_graph.graph).with_data_parallel(
+            loss_name=loss.name, build_strategy=build_strategy)
+        quantized_test_program = test_graph.to_program()
+        iters = 5
+        batch_size = 8
+
+        train_reader = paddle.batch(
+            paddle.reader.shuffle(
+                paddle.dataset.mnist.train(), buf_size=500),
+            batch_size=batch_size)
+        test_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
+
+        # Train the model for a few iterations to get realistic weight values.
+        with fluid.scope_guard(scope):
+            for _ in range(iters):
+                data = next(train_reader())
+                loss_v = exe.run(binary,
+                                 feed=feeder.feed(data),
+                                 fetch_list=[loss])
+
+        # Freeze the graph for inference; the fc/conv weights are still stored
+        # as float32 values.
+        freeze_pass = QuantizationFreezePass(
+            scope=scope, place=place, weight_quantize_type=weight_quant_type)
+        freeze_pass.apply(test_graph)
+
+        # Transform the quantized graph for MKL-DNN INT8 inference.
+        mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
+        mkldnn_int8_pass.apply(test_graph)
+        dev_name = '_cpu_'
+        if not for_ci:
+            marked_nodes = set()
+            for op in test_graph.all_op_nodes():
+                if op.name().find('quantize') > -1:
+                    marked_nodes.add(op)
+            test_graph.draw('.', 'test_mkldnn' + dev_name +
+                            activation_quant_type + '_' + weight_quant_type,
+                            marked_nodes)
+        mkldnn_program = test_graph.to_program()
+        w_mkldnn = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
+        # The converted weights should no longer be integer-valued.
+        self.assertFalse(self.isinteger(np.sum(w_mkldnn)))
+
+        # Check that the conv2d output is correctly linked to fake_dequantize's
+        # output.
+        self.check_program(mkldnn_program)
+        if not for_ci:
+            print('{}: {}'.format('w_mkldnn' + dev_name + activation_quant_type
+                                  + '_' + weight_quant_type, np.sum(w_mkldnn)))
+
+    def test_mkldnn_graph_cpu_static(self):
+        with fluid.unique_name.guard():
+            self.mkldnn_based_freeze_graph(
+                False,
+                seed=2,
+                activation_quant_type='range_abs_max',
+                weight_quant_type='abs_max',
+                for_ci=True)
+            self.mkldnn_based_freeze_graph(
+                False,
+                seed=2,
+                activation_quant_type='moving_average_abs_max',
+                weight_quant_type='abs_max',
+                for_ci=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
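After the pass has been applied, the rewritten graph can also be inspected to confirm that MKL-DNN quantize ops and rescaled conv2d ops are present. A small, hedged sketch reusing the IrGraph API from the test above (test_graph is assumed to be the graph after mkldnn_int8_pass.apply; the exact op mix depends on the model):

    from collections import Counter

    # Count op types in the transformed graph; 'quantize' and 'conv2d'
    # entries are expected for the conv layers of the MNIST net.
    op_counts = Counter(op.name() for op in test_graph.all_op_nodes())
    print(op_counts)
    for op in test_graph.all_op_nodes():
        if op.name() == 'conv2d':
            # Scales set by _transform_to_conv_mkldnn.
            print(op.op().attr('Scale_in'), op.op().attr('Scale_weights'))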