未验证 提交 eb65877c 编写于 作者: S Sylwester Fraczek 提交者: GitHub

fix dimensions error for mobilenetv1_KL_quant (#26776)

* fix dimensions error for mobilenetv1_KL_quant

fixes AssertionError: The size of weight scales vector (1000) does not match the number of output channels (1024) in the weights tensor fc7_weights.

add mul test

* remove comment

* add third case unit test
上级 24ec5175
......@@ -299,11 +299,14 @@ class Quant2Int8MkldnnPass(object):
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_var_name)
assert scales.size == 1 or scales.size == len(
weight
), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format(
scales.size, len(weight), weight_var_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
else:
raise ValueError(
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
.format(scales.size, weight.shape, weight_var_name))
w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
self._restore_var(weight_var_name, w_fp32)
......
......@@ -43,7 +43,7 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
self.conv_output = np.ndarray(self.conv_output_size).astype(self.dtype)
self.conv_output2 = np.ndarray(self.conv_output2_size).astype(
self.dtype)
self.quantized_ops = 'conv2d'
self.quantized_ops = 'conv2d,mul'
self.variables = {
"input": self.input,
"filter": self.filter,
......@@ -51,6 +51,22 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
"conv_output": self.conv_output,
"conv_output2": self.conv_output2,
}
self.mul_input_size = [1, 3]
self.mul_weights_size = [3, 5]
self.mul_output_size = [1, 5]
self.mul_input = np.random.random(self.mul_input_size).astype(
self.dtype)
self.mul_weights = np.ones(self.mul_weights_size, self.dtype)
self.mul_weights_bad = np.ones([1, 1], self.dtype)
self.mul_output = np.ndarray(self.mul_output_size).astype(self.dtype)
self.mul_output_scale = np.linspace(1, 5, num=5).astype(self.dtype)
self.variables_mul = {
"mul_input": self.mul_input,
"mul_weights": self.mul_weights,
"mul_output": self.mul_output,
"mul_weights_bad": self.mul_weights_bad
}
def prepare_program(self, program):
block = program.global_block()
......@@ -92,6 +108,23 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
'fuse_brelu': True
})
def prepare_program_mul(self, program):
    """Populate ``program`` with the variables and the single op of the mul test.

    One float32 variable is created in the program's global block for every
    entry of ``self.variables_mul`` (using that entry's ``shape``), then a
    single ``mul`` op is appended that reads ``mul_input`` (X) and
    ``mul_weights`` (Y) and writes ``mul_output`` (Out).

    Args:
        program: object exposing ``global_block()``; the returned block must
            provide ``create_var``, ``var`` and ``append_op``.
    """
    block = program.global_block()

    # Declare every test variable up front so block.var() can resolve them.
    for name, value in self.variables_mul.items():
        block.create_var(name=name, dtype="float32", shape=value.shape)

    # The op handle returned by append_op is not needed by any caller here.
    block.append_op(
        type="mul",
        inputs={
            "X": block.var('mul_input'),
            "Y": block.var('mul_weights')
        },
        outputs={"Out": block.var('mul_output')},
        attrs={'use_mkldnn': self.use_mkldnn})
def remove_fuse_activation_attribute(self, graph):
for op in graph.all_op_nodes():
op.op().remove_attr("fuse_activation")
......@@ -103,11 +136,13 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
def check_graph_after_pass(self, graph):
    """Assert that every conv2d op in ``graph`` has a consistent ``fuse_activation``.

    Only ops whose type is ``"conv2d"`` are inspected; other ops (for
    example ``mul``) are skipped because the checks below do not apply to
    them. For each conv2d op:

    * the ``fuse_activation`` attribute must be present;
    * a truthy ``fuse_relu`` attribute requires ``fuse_activation == "relu"``;
    * a truthy ``fuse_brelu`` attribute requires ``fuse_activation == "relu6"``.

    Args:
        graph: object exposing ``all_op_nodes()``; each node's ``op()`` must
            provide ``type()``, ``has_attr(name)`` and ``attr(name)``.
    """
    for op in graph.all_op_nodes():
        desc = op.op()
        # Guard on op type: non-conv ops are not expected to carry
        # fuse_activation, so an unconditional check would fail for them.
        if desc.type() != "conv2d":
            continue
        self.assertTrue(desc.has_attr("fuse_activation"))
        if desc.has_attr("fuse_relu") and desc.attr("fuse_relu"):
            self.assertEqual(desc.attr("fuse_activation"), "relu")
        if desc.has_attr("fuse_brelu") and desc.attr("fuse_brelu"):
            self.assertEqual(desc.attr("fuse_activation"), "relu6")
def test_quant_update_activation(self):
program = fluid.Program()
......@@ -125,6 +160,39 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
graph = quant2_int8_mkldnn_pass._update_activations(graph)
self.check_graph_after_pass(graph)
def test_dequantize_op_weights(self):
    """Check ``_dequantize_op_weights`` for a mul op with per-column scales.

    The test builds the mul program, registers ``self.mul_output_scale`` as
    the weight scales of ``mul_output``, loads ``mul_weights`` into the
    scope and runs the dequantization on the op's ``Y`` input. The expected
    tensor rows are 127 divided by each of the scales 1..5. It then swaps
    in the ``mul_weights_bad`` tensor, whose shape matches the scales on
    neither axis, and expects a ``ValueError``.
    """
    program = fluid.Program()
    # NOTE(review): the original indentation was not available; the whole
    # body is kept under program_guard -- confirm against the upstream file.
    with fluid.program_guard(program):
        self.prepare_program_mul(program)
        graph = IrGraph(core.Graph(program.desc), for_test=True)

        # Pick the first "mul" op; fail with a clear message instead of a
        # NameError further down if the graph unexpectedly has none.
        op_node = next(
            (op for op in graph.all_op_nodes() if op.op().type() == "mul"),
            None)
        self.assertIsNotNone(op_node, "no 'mul' op found in the test graph")

        qpass = Quant2Int8MkldnnPass(
            self.quantized_ops,
            _scope=self.scope,
            _place=self.place,
            _core=core,
            _debug=False)
        qpass._weight_scales["mul_output"] = self.mul_output_scale

        # Valid case: scales length equals the weights' second dimension.
        param = self.scope.var("mul_weights").get_tensor()
        param.set(self.variables_mul["mul_weights"], self.place)
        qpass._dequantize_op_weights(graph, op_node, "Y", "Out")

        # Use a unittest assertion so the check survives `python -O`.
        self.assertTrue(
            np.allclose(
                self.scope.find_var("mul_weights").get_tensor(),
                [[127, 63.5, 42.3333, 31.75, 25.4],
                 [127, 63.5, 42.3333, 31.75, 25.4],
                 [127, 63.5, 42.3333, 31.75, 25.4]]))

        # Invalid case: weights whose shape matches the scales on no axis.
        param = self.scope.var("mul_weights").get_tensor()
        param.set(self.variables_mul["mul_weights_bad"], self.place)
        with self.assertRaises(ValueError):
            qpass._dequantize_op_weights(graph, op_node, "Y", "Out")
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册