From 9fd90674551beaf019cdc635f0d221e95fe274c3 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Wed, 1 Apr 2020 18:29:34 +0200 Subject: [PATCH] handle conv2d activations in older QAT models (#23202) --- .../quantization/qat2_int8_mkldnn_pass.py | 18 +++ .../slim/tests/test_qat2_int8_mkldnn_pass.py | 130 ++++++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py diff --git a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py index 021f188311f..44098e56e16 100644 --- a/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/qat2_int8_mkldnn_pass.py @@ -294,6 +294,23 @@ class Qat2Int8MkldnnPass(object): tensor = self._scope.find_var(name).get_tensor() tensor.set(array, self._place) + def _update_activations(self, graph): + for op in graph.all_op_nodes(): + if op.name() in self._conv_ops and not op.op().has_attr( + "fuse_activation"): + activation = "" + if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): + activation = "relu" + elif op.op().has_attr("fuse_brelu") and op.op().attr( + "fuse_brelu"): + activation = "relu6" + alpha = 6.0 + if op.op().has_attr("fuse_brelu_threshold"): + alpha = op.op().attr("fuse_brelu_threshold") + op.set_attr("fuse_alpha", alpha) + op.set_attr("fuse_activation", activation) + return graph + def _remove_ctrl_vars(self, graph): remove_ctr_vars = set() for node in graph.all_var_nodes(): @@ -303,6 +320,7 @@ class Qat2Int8MkldnnPass(object): return graph def _optimize_fp32_graph(self, graph): + graph = self._update_activations(graph) graph = self._remove_ctrl_vars(graph) graph = self._apply_pass(graph, 'mkldnn_placement_pass', ['mkldnn_enabled_op_types'], [set()]) diff --git a/python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py new file mode 100644 index 00000000000..16cbfdd99d3 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/test_qat2_int8_mkldnn_pass.py @@ -0,0 +1,130 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import unittest +import numpy as np +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import Qat2Int8MkldnnPass + + +class TestQat2Int8MkldnnPass(unittest.TestCase): + def setUp(self): + self.scope = fluid.Scope() + self.place = fluid.CPUPlace() + self.dtype = np.float32 + self.use_cudnn = False + self.use_mkldnn = True + self.data_format = "ANYLAYOUT" + self.pad = [0, 0] + self.stride = [1, 1] + self.dilations = [1, 1] + self.groups = 1 + self.input_size = [1, 3, 5, 5] + self.filter_size = [16, 3, 3, 3] + self.filter_size2 = [1, 16, 2, 2] + self.conv_output_size = [1, 16, 3, 3] + self.conv_output2_size = [1, 1, 2, 2] + self.input = np.random.random(self.input_size).astype(self.dtype) + self.filter = np.random.random(self.filter_size).astype(self.dtype) + self.filter2 = np.random.random(self.filter_size2).astype(self.dtype) + self.conv_output = np.ndarray(self.conv_output_size).astype(self.dtype) + self.conv_output2 = np.ndarray(self.conv_output2_size).astype( + self.dtype) + self.quantized_ops = 'conv2d' + self.variables = { + "input": self.input, + "filter": self.filter, + "filter2": self.filter2, + "conv_output": self.conv_output, + "conv_output2": self.conv_output2, + } + + def prepare_program(self, program): + block = program.global_block() + for name in self.variables: + block.create_var( + name=name, dtype="float32", shape=self.variables[name].shape) + conv2d_op1 = block.append_op( + type="conv2d", + inputs={ + "Input": block.var('input'), + 'Filter': block.var('filter') + }, + outputs={"Output": block.var('conv_output')}, + attrs={ + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format, + 'fuse_relu': True + }) + conv2d_op2 = block.append_op( + type="conv2d", + inputs={ + "Input": block.var('conv_output'), + 'Filter': block.var('filter2') + }, + outputs={"Output": block.var('conv_output2')}, + attrs={ + 'strides': self.stride, + 'paddings': self.pad, + 'groups': self.groups, + 'dilations': self.dilations, + 'use_cudnn': self.use_cudnn, + 'use_mkldnn': self.use_mkldnn, + 'data_format': self.data_format, + 'fuse_brelu': True + }) + + def remove_fuse_activation_attribute(self, graph): + for op in graph.all_op_nodes(): + op.op().remove_attr("fuse_activation") + return graph + + def check_graph_before_pass(self, graph): + for op in graph.all_op_nodes(): + self.assertFalse(op.op().has_attr("fuse_activation")) + + def check_graph_after_pass(self, graph): + for op in graph.all_op_nodes(): + self.assertTrue(op.op().has_attr("fuse_activation")) + if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): + self.assertTrue(op.op().attr("fuse_activation") == "relu") + if op.op().has_attr("fuse_brelu") and op.op().attr("fuse_brelu"): + self.assertTrue(op.op().attr("fuse_activation") == "relu6") + + def test_qat_update_activation(self): + program = fluid.Program() + with fluid.program_guard(program): + self.prepare_program(program) + graph = IrGraph(core.Graph(program.desc), for_test=True) + graph = self.remove_fuse_activation_attribute(graph) + self.check_graph_before_pass(graph) + qat2_int8_mkldnn_pass = Qat2Int8MkldnnPass( + self.quantized_ops, + _scope=self.scope, + _place=self.place, + _core=core, + _debug=False) + graph = qat2_int8_mkldnn_pass._update_activations(graph) + self.check_graph_after_pass(graph) + + +if __name__ == '__main__': + unittest.main() -- GitLab