未验证 提交 1fa863da 编写于 作者: C cc 提交者: GitHub

Support dygraph quant model (#29927)

* Avoid the scale to be infinity in quant2_int8_mkldnn_pass, test=develop
* support quantized model for paddle2.0 dygraph, test=develop
上级 46c46954
......@@ -49,11 +49,14 @@ class Quant2Int8MkldnnPass(object):
self._fake_quantize_types = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
]
self._fake_quantize_dequantize_types = [
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1])
......@@ -137,8 +140,12 @@ class Quant2Int8MkldnnPass(object):
for var_name in var_names:
scales[var_name] = (use_unsigned_int, lod_tensor)
# fake_quantize_dequantize_abs_max doesn't have scale value
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
fake_ops.extend(self._fake_quantize_types)
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
if op.name() in fake_ops:
bit_length = op.op().attr("bit_length")
assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
bit_length)
......@@ -164,14 +171,14 @@ class Quant2Int8MkldnnPass(object):
if op.op().has_attr("max_range"):
_max_range = np.array(op.op().attr("max_range")).astype(
np.float64)
self._weight_scales[input_name] = _max_range
self._weight_scales[input_name] = np.array(
self._s8_max * self._s8_max /
_max_range).astype(np.float64)
else:
scale_name = op.input("Scales")[0]
scales = np.array(
self._s8_max * self._s8_max / self._load_param(
self._scope, scale_name)).astype(np.float64)
scales[scales == np.Inf] = 0.0
self._weight_scales[input_name] = scales
self._weight_scales[input_name] = np.array(
self._load_param(self._scope, scale_name)).astype(
np.float64)
return graph
......@@ -243,9 +250,9 @@ class Quant2Int8MkldnnPass(object):
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
elif op.name() in self._fake_dequantize_types:
self._remove_fake_dequantize(graph, op)
elif op.name() in self._fake_quantize_dequantize_types:
self._remove_fake_dequantize(graph, op)
return graph
......@@ -290,10 +297,15 @@ class Quant2Int8MkldnnPass(object):
])
def _dequantize_weights(self, graph):
def _is_int8_weights(op_node, weight_name):
weight_var_name = op_node.input(weight_name)[0]
weight = self._load_param(self._scope, weight_var_name)
return np.all(np.mod(weight, 1) == 0)
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
if op.name() in self._conv_ops and _is_int8_weights(op, "Filter"):
self._dequantize_op_weights(graph, op, "Filter", "Output")
elif op.name() in self._mul_ops:
elif op.name() in self._mul_ops and _is_int8_weights(op, "Y"):
self._dequantize_op_weights(graph, op, "Y", "Out")
return graph
......@@ -304,9 +316,9 @@ class Quant2Int8MkldnnPass(object):
scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_var_name)
if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales)
else:
raise ValueError(
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
......
......@@ -187,9 +187,9 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
assert np.allclose(
self.scope.find_var("mul_weights").get_tensor(),
[[127, 63.5, 42.3333, 31.75, 25.4],
[127, 63.5, 42.3333, 31.75, 25.4],
[127, 63.5, 42.3333, 31.75, 25.4]])
[[1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.],
[1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.],
[1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.]])
param = self.scope.var("mul_weights").get_tensor()
param.set(self.variables_mul["mul_weights_bad"], self.place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册