未验证 提交 1fa863da 编写于 作者: C cc 提交者: GitHub

Support dygraph quant model (#29927)

* Avoid the scale to be infinity in quant2_int8_mkldnn_pass, test=develop
* support quantized model for paddle2.0 dygraph, test=develop
上级 46c46954
...@@ -49,11 +49,14 @@ class Quant2Int8MkldnnPass(object): ...@@ -49,11 +49,14 @@ class Quant2Int8MkldnnPass(object):
self._fake_quantize_types = [ self._fake_quantize_types = [
'fake_quantize_moving_average_abs_max', 'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
] ]
self._fake_dequantize_types = [ self._fake_dequantize_types = [
'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs' 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
] ]
self._fake_quantize_dequantize_types = [
'fake_quantize_dequantize_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._ops_to_quantize = _ops_to_quantize self._ops_to_quantize = _ops_to_quantize
self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set( self._op_ids_to_skip = _op_ids_to_skip if _op_ids_to_skip is not None else set(
[-1]) [-1])
...@@ -137,8 +140,12 @@ class Quant2Int8MkldnnPass(object): ...@@ -137,8 +140,12 @@ class Quant2Int8MkldnnPass(object):
for var_name in var_names: for var_name in var_names:
scales[var_name] = (use_unsigned_int, lod_tensor) scales[var_name] = (use_unsigned_int, lod_tensor)
# fake_quantize_dequantize_abs_max doesn't have scale value
fake_ops = ['fake_quantize_dequantize_moving_average_abs_max']
fake_ops.extend(self._fake_quantize_types)
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types: if op.name() in fake_ops:
bit_length = op.op().attr("bit_length") bit_length = op.op().attr("bit_length")
assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format( assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
bit_length) bit_length)
...@@ -164,14 +171,14 @@ class Quant2Int8MkldnnPass(object): ...@@ -164,14 +171,14 @@ class Quant2Int8MkldnnPass(object):
if op.op().has_attr("max_range"): if op.op().has_attr("max_range"):
_max_range = np.array(op.op().attr("max_range")).astype( _max_range = np.array(op.op().attr("max_range")).astype(
np.float64) np.float64)
self._weight_scales[input_name] = _max_range self._weight_scales[input_name] = np.array(
self._s8_max * self._s8_max /
_max_range).astype(np.float64)
else: else:
scale_name = op.input("Scales")[0] scale_name = op.input("Scales")[0]
scales = np.array( self._weight_scales[input_name] = np.array(
self._s8_max * self._s8_max / self._load_param( self._load_param(self._scope, scale_name)).astype(
self._scope, scale_name)).astype(np.float64) np.float64)
scales[scales == np.Inf] = 0.0
self._weight_scales[input_name] = scales
return graph return graph
...@@ -243,9 +250,9 @@ class Quant2Int8MkldnnPass(object): ...@@ -243,9 +250,9 @@ class Quant2Int8MkldnnPass(object):
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types: if op.name() in self._fake_quantize_types:
self._remove_fake_quantize(graph, op) self._remove_fake_quantize(graph, op)
elif op.name() in self._fake_dequantize_types:
for op in graph.all_op_nodes(): self._remove_fake_dequantize(graph, op)
if op.name() in self._fake_dequantize_types: elif op.name() in self._fake_quantize_dequantize_types:
self._remove_fake_dequantize(graph, op) self._remove_fake_dequantize(graph, op)
return graph return graph
...@@ -290,10 +297,15 @@ class Quant2Int8MkldnnPass(object): ...@@ -290,10 +297,15 @@ class Quant2Int8MkldnnPass(object):
]) ])
def _dequantize_weights(self, graph): def _dequantize_weights(self, graph):
def _is_int8_weights(op_node, weight_name):
weight_var_name = op_node.input(weight_name)[0]
weight = self._load_param(self._scope, weight_var_name)
return np.all(np.mod(weight, 1) == 0)
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._conv_ops: if op.name() in self._conv_ops and _is_int8_weights(op, "Filter"):
self._dequantize_op_weights(graph, op, "Filter", "Output") self._dequantize_op_weights(graph, op, "Filter", "Output")
elif op.name() in self._mul_ops: elif op.name() in self._mul_ops and _is_int8_weights(op, "Y"):
self._dequantize_op_weights(graph, op, "Y", "Out") self._dequantize_op_weights(graph, op, "Y", "Out")
return graph return graph
...@@ -304,9 +316,9 @@ class Quant2Int8MkldnnPass(object): ...@@ -304,9 +316,9 @@ class Quant2Int8MkldnnPass(object):
scales = self._weight_scales[output_var_name] scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_var_name) weight = self._load_param(self._scope, weight_var_name)
if scales.size == 1 or scales.size == weight.shape[0]: if scales.size == 1 or scales.size == weight.shape[0]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T w_fp32 = np.multiply(np.divide(weight, self._s8_max).T, scales.T).T
elif len(weight.shape) > 1 and scales.size == weight.shape[1]: elif len(weight.shape) > 1 and scales.size == weight.shape[1]:
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) w_fp32 = np.multiply(np.divide(weight, self._s8_max), scales)
else: else:
raise ValueError( raise ValueError(
"The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}." "The size of weight scales vector ({}) does not match the dimensions ({}) of the weights tensor {}."
......
...@@ -187,9 +187,9 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase): ...@@ -187,9 +187,9 @@ class TestQuant2Int8MkldnnPass(unittest.TestCase):
assert np.allclose( assert np.allclose(
self.scope.find_var("mul_weights").get_tensor(), self.scope.find_var("mul_weights").get_tensor(),
[[127, 63.5, 42.3333, 31.75, 25.4], [[1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.],
[127, 63.5, 42.3333, 31.75, 25.4], [1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.],
[127, 63.5, 42.3333, 31.75, 25.4]]) [1. / 127., 2. / 127., 3. / 127., 4. / 127., 5. / 127.]])
param = self.scope.var("mul_weights").get_tensor() param = self.scope.var("mul_weights").get_tensor()
param.set(self.variables_mul["mul_weights_bad"], self.place) param.set(self.variables_mul["mul_weights_bad"], self.place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册