From eb65877ce0b2afc2589ba6ae8fb3018a44828067 Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Mon, 7 Sep 2020 07:54:17 +0200 Subject: [PATCH] fix dimensions error for mobilenetv1_KL_quant (#26776) * fix dimensions error for mobilenetv1_KL_quant fixes AssertionError: The size of weight scales vector (1000) does not match the number of output channels (1024) in the weights tensor fc7_weights. add mul test * remove comment * add third case unit test --- .../quantization/quant2_int8_mkldnn_pass.py | 13 +-- .../tests/test_quant2_int8_mkldnn_pass.py | 80 +++++++++++++++++-- 2 files changed, 82 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py index 75e1ea43d1..dadc756c43 100644 --- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py @@ -299,11 +299,14 @@ class Quant2Int8MkldnnPass(object): # Convert int8 range weights to fp32 range weights scales = self._weight_scales[output_var_name] weight = self._load_param(self._scope, weight_var_name) - assert scales.size == 1 or scales.size == len( - weight - ), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format( - scales.size, len(weight), weight_var_name) - w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T + if scales.size == 1 or scales.size == weight.shape[0]: + w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T + elif len(weight.shape) > 1 and scales.size == weight.shape[1]: + w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) + else: + raise ValueError( + "The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}." + .format(scales.size, weight.shape, weight_var_name)) w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32) self._restore_var(weight_var_name, w_fp32) diff --git a/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py index fcbb1b66ad..7b51973131 100644 --- a/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py +++ b/python/paddle/fluid/contrib/slim/tests/test_quant2_int8_mkldnn_pass.py @@ -43,7 +43,7 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): self.conv_output = np.ndarray(self.conv_output_size).astype(self.dtype) self.conv_output2 = np.ndarray(self.conv_output2_size).astype( self.dtype) - self.quantized_ops = 'conv2d' + self.quantized_ops = 'conv2d,mul' self.variables = { "input": self.input, "filter": self.filter, @@ -51,6 +51,22 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): "conv_output": self.conv_output, "conv_output2": self.conv_output2, } + self.mul_input_size = [1, 3] + self.mul_weights_size = [3, 5] + self.mul_output_size = [1, 5] + self.mul_input = np.random.random(self.mul_input_size).astype( + self.dtype) + self.mul_weights = np.ones(self.mul_weights_size, self.dtype) + self.mul_weights_bad = np.ones([1, 1], self.dtype) + self.mul_output = np.ndarray(self.mul_output_size).astype(self.dtype) + self.mul_output_scale = np.linspace(1, 5, num=5).astype(self.dtype) + + self.variables_mul = { + "mul_input": self.mul_input, + "mul_weights": self.mul_weights, + "mul_output": self.mul_output, + "mul_weights_bad": self.mul_weights_bad + } def prepare_program(self, program): block = program.global_block() @@ -92,6 +108,23 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): 'fuse_brelu': True }) + def prepare_program_mul(self, program): + block = program.global_block() + for name in self.variables_mul: + block.create_var( + name=name, + dtype="float32", + shape=self.variables_mul[name].shape) + + mul_op1 = block.append_op( + type="mul", + inputs={ + "X": block.var('mul_input'), + "Y": block.var('mul_weights') + }, + outputs={"Out": block.var('mul_output')}, + attrs={'use_mkldnn': self.use_mkldnn}) + def remove_fuse_activation_attribute(self, graph): for op in graph.all_op_nodes(): op.op().remove_attr("fuse_activation") @@ -103,11 +136,13 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): def check_graph_after_pass(self, graph): for op in graph.all_op_nodes(): - self.assertTrue(op.op().has_attr("fuse_activation")) - if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): - self.assertTrue(op.op().attr("fuse_activation") == "relu") - if op.op().has_attr("fuse_brelu") and op.op().attr("fuse_brelu"): - self.assertTrue(op.op().attr("fuse_activation") == "relu6") + if op.op().type() == "conv2d": + self.assertTrue(op.op().has_attr("fuse_activation")) + if op.op().has_attr("fuse_relu") and op.op().attr("fuse_relu"): + self.assertTrue(op.op().attr("fuse_activation") == "relu") + if op.op().has_attr("fuse_brelu") and op.op().attr( + "fuse_brelu"): + self.assertTrue(op.op().attr("fuse_activation") == "relu6") def test_quant_update_activation(self): program = fluid.Program() @@ -125,6 +160,39 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): graph = quant2_int8_mkldnn_pass._update_activations(graph) self.check_graph_after_pass(graph) + def test_dequantize_op_weights(self): + program = fluid.Program() + with fluid.program_guard(program): + self.prepare_program_mul(program) + graph = IrGraph(core.Graph(program.desc), for_test=True) + + for op in graph.all_op_nodes(): + if op.op().type() == "mul": + op_node = op + break + + qpass = Quant2Int8MkldnnPass( + self.quantized_ops, + _scope=self.scope, + _place=self.place, + _core=core, + _debug=False) + qpass._weight_scales["mul_output"] = self.mul_output_scale + param = self.scope.var("mul_weights").get_tensor() + param.set(self.variables_mul["mul_weights"], self.place) + qpass._dequantize_op_weights(graph, op_node, "Y", "Out") + + assert np.allclose( + self.scope.find_var("mul_weights").get_tensor(), + [[127, 63.5, 42.3333, 31.75, 25.4], + [127, 63.5, 42.3333, 31.75, 25.4], + [127, 63.5, 42.3333, 31.75, 25.4]]) + + param = self.scope.var("mul_weights").get_tensor() + param.set(self.variables_mul["mul_weights_bad"], self.place) + with self.assertRaises(ValueError): + qpass._dequantize_op_weights(graph, op_node, "Y", "Out") + if __name__ == '__main__': unittest.main() -- GitLab