未验证 提交 797f7cbd 编写于 作者: W Wojciech Uss 提交者: GitHub

QAT: support for new models (#23928) (#24121)

* QAT: support range-based quantization and scales from attribute

* added support for channelwise

test=release/2.0
上级 e0d0b129
...@@ -45,16 +45,14 @@ class Qat2Int8MkldnnPass(object): ...@@ -45,16 +45,14 @@ class Qat2Int8MkldnnPass(object):
self._place = _place self._place = _place
self._core = _core self._core = _core
self._debug = _debug self._debug = _debug
self._quantize_types = [ self._fake_quantize_types = [
'fake_quantize_moving_average_abs_max', 'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max', 'fake_quantize_range_abs_max',
'fake_quantize_dequantize_moving_average_abs_max' 'fake_quantize_dequantize_moving_average_abs_max'
] ]
self._fake_quantize_types = [ self._fake_dequantize_types = [
'fake_quantize_moving_average_abs_max', 'fake_dequantize_max_abs', 'fake_channel_wise_dequantize_max_abs'
'fake_quantize_dequantize_moving_average_abs_max'
] ]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._quantized_ops = _quantized_ops self._quantized_ops = _quantized_ops
self._scale_immutable_ops = [ self._scale_immutable_ops = [
'transpose2', 'reshape2', 'pool2d', 'scale' 'transpose2', 'reshape2', 'pool2d', 'scale'
...@@ -74,7 +72,9 @@ class Qat2Int8MkldnnPass(object): ...@@ -74,7 +72,9 @@ class Qat2Int8MkldnnPass(object):
assert isinstance(graph, assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.' IrGraph), 'graph must be the instance of IrGraph.'
graph = self._gather_scales(graph) graph = self._gather_weight_scales_from_fake(graph)
graph = self._gather_output_scales_from_attr(graph)
graph = self._gather_input_scales_from_fake(graph)
graph = self._remove_fake_ops(graph) graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph) graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph) graph = self._optimize_fp32_graph(graph)
...@@ -83,6 +83,7 @@ class Qat2Int8MkldnnPass(object): ...@@ -83,6 +83,7 @@ class Qat2Int8MkldnnPass(object):
graph = self._propagate_scales(graph) graph = self._propagate_scales(graph)
graph = self._set_dummy_out_scales(graph) graph = self._set_dummy_out_scales(graph)
graph = self._quantize_fp32_graph(graph) graph = self._quantize_fp32_graph(graph)
graph = self._optimize_int8_graph(graph)
graph = self._cleanup(graph) graph = self._cleanup(graph)
return graph return graph
...@@ -90,9 +91,6 @@ class Qat2Int8MkldnnPass(object): ...@@ -90,9 +91,6 @@ class Qat2Int8MkldnnPass(object):
assert isinstance(graph, assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.' IrGraph), 'graph must be the instance of IrGraph.'
graph = self._gather_scales(graph)
graph = self._remove_fake_ops(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph) graph = self._optimize_fp32_graph(graph)
graph = self._cleanup(graph) graph = self._cleanup(graph)
return graph return graph
...@@ -108,29 +106,61 @@ class Qat2Int8MkldnnPass(object): ...@@ -108,29 +106,61 @@ class Qat2Int8MkldnnPass(object):
def _is_fc_quantized(self): def _is_fc_quantized(self):
return 'fc' in self._quantized_ops return 'fc' in self._quantized_ops
def _gather_scales(self, graph): def _gather_input_scales_from_fake(self, graph):
def _add_scale_for_vars(var_names, use_unsigned_int, lod_tensor):
scales = self._var_quant_scales
for var_name in var_names:
scales[var_name] = (use_unsigned_int, lod_tensor)
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._quantize_types: if op.name() in self._fake_quantize_types:
bit_length = op.op().attr("bit_length") bit_length = op.op().attr("bit_length")
assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format( assert bit_length == 8, 'Unsupported number quantization bits ({}). Only 8 is supported now.'.format(
bit_length) bit_length)
input_name = op.input("X")[0] input_name = op.input("X")[0]
scale_name = op.input("InScale")[0] scale_name = op.input("InScale")[0]
output_name = op.output("Out")[0]
# Gather new weights scale after folding batchnorm in convolution # Gather new weights scale after folding batchnorm in convolution
scale = np.array(1.0 / self._load_param( scale = np.array(1.0 / self._load_param(
self._scope, scale_name)[0]).astype(np.float64) self._scope, scale_name)[0]).astype(np.float64)
lod_tensor = self._convert_scale2tensor(scale) lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False use_unsigned_int = False
self._var_quant_scales[input_name] = (use_unsigned_int, _add_scale_for_vars([input_name, output_name], use_unsigned_int,
lod_tensor) lod_tensor)
self._var_quant_scales[scale_name.replace(".scale", "")] = (
use_unsigned_int, lod_tensor) return graph
def _gather_weight_scales_from_fake(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types: if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0] input_name = op.input("X")[0]
_max_range = op.op().attr("max_range") if op.op().has_attr("max_range"):
self._weight_scales[input_name] = _max_range _max_range = np.array(op.op().attr("max_range")).astype(
np.float64)
self._weight_scales[input_name] = _max_range
else:
scale_name = op.input("Scales")[0]
scale = np.array(
self._s8_max * self._s8_max / self._load_param(
self._scope, scale_name)).astype(np.float64)
self._weight_scales[input_name] = scale
return graph
def _gather_output_scales_from_attr(self, graph):
for op in graph.all_op_nodes():
if op.op().has_attr("out_threshold"):
attr_scale = op.op().attr("out_threshold")
if attr_scale == 0.0: continue
scale = np.array(1.0 / attr_scale).astype(np.float64)
scale_lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False
for output_name in op.op().outputs():
for out_var_name in op.op().output(output_name):
self._var_quant_scales[out_var_name] = (
use_unsigned_int, scale_lod_tensor)
return graph return graph
def _propagate_scales(self, graph): def _propagate_scales(self, graph):
...@@ -274,29 +304,24 @@ class Qat2Int8MkldnnPass(object): ...@@ -274,29 +304,24 @@ class Qat2Int8MkldnnPass(object):
def _dequantize_weights(self, graph): def _dequantize_weights(self, graph):
for op in graph.all_op_nodes(): for op in graph.all_op_nodes():
if op.name() in self._conv_ops: if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op) self._dequantize_op_weights(graph, op, "Filter", "Output")
elif self._is_fc_quantized() and op.name() in self._mul_ops: elif self._is_fc_quantized() and op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op) self._dequantize_op_weights(graph, op, "Y", "Out")
return graph return graph
def _dequantize_conv_weights(self, graph, op_node): def _dequantize_op_weights(self, graph, op_node, weight_name, output_name):
weight_name = op_node.input("Filter")[0] weight_var_name = op_node.input(weight_name)[0]
output_name = op_node.output("Output")[0] output_var_name = op_node.output(output_name)[0]
# Convert int8 range weights to fp32 range weights # Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_name] scales = self._weight_scales[output_var_name]
weight = self._load_param(self._scope, weight_name) weight = self._load_param(self._scope, weight_var_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales) assert scales.size == 1 or scales.size == len(
w_fp32 = w_fp32.reshape(weight.shape) weight
self._restore_var(weight_name, w_fp32) ), "The size of weight scales vector ({}) does not match the number of output channels ({}) in the weights tensor {}.".format(
scales.size, len(weight), weight_var_name)
def _dequantize_mul_weights(self, graph, op_node): w_fp32 = np.divide(np.multiply(weight, self._s8_max).T, scales.T).T
weight_name = op_node.input("Y")[0] w_fp32 = w_fp32.reshape(weight.shape).astype(np.float32)
output_name = op_node.output("Out")[0] self._restore_var(weight_var_name, w_fp32)
scales = self._weight_scales[output_name]
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
def _restore_var(self, name, array): def _restore_var(self, name, array):
tensor = self._scope.find_var(name).get_tensor() tensor = self._scope.find_var(name).get_tensor()
...@@ -366,11 +391,14 @@ class Qat2Int8MkldnnPass(object): ...@@ -366,11 +391,14 @@ class Qat2Int8MkldnnPass(object):
self._remove_unused_var_nodes(graph) self._remove_unused_var_nodes(graph)
return graph return graph
def _cleanup(self, graph): def _optimize_int8_graph(self, graph):
# remove dropout ops # remove dropout ops
graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass') graph = self._apply_pass(graph, 'simplify_with_basic_ops_pass')
# make some MKL-DNN ops working inplace # make some MKL-DNN ops working inplace
graph = self._apply_pass(graph, 'mkldnn_inplace_pass') graph = self._apply_pass(graph, 'mkldnn_inplace_pass')
return graph
def _cleanup(self, graph):
graph = self._remove_unused_var_nodes(graph) graph = self._remove_unused_var_nodes(graph)
graph = self._set_op_role_forward(graph) graph = self._set_op_role_forward(graph)
return graph return graph
......
...@@ -68,7 +68,7 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32 ...@@ -68,7 +68,7 @@ function(inference_qat2_int8_image_classification_test target qat_model_dir fp32
--batch_size 10 --batch_size 10
--batch_num 2 --batch_num 2
--acc_diff_threshold 0.1 --acc_diff_threshold 0.1
--quantized_ops ${quantized_ops}) --quantized_ops ${quantized_ops})
endfunction() endfunction()
# set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20 # set batch_size 10 for UT only (avoid OOM). For whole dataset, use batch_size 20
...@@ -215,13 +215,28 @@ if(LINUX AND WITH_MKLDNN) ...@@ -215,13 +215,28 @@ if(LINUX AND WITH_MKLDNN)
set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d") set(QAT2_IC_QUANTIZED_OPS "conv2d,pool2d")
# QAT2 ResNet50 # QAT2 ResNet50 with input/output scales in `fake_quantize_moving_average_abs_max` operators,
# with weight scales in `fake_dequantize_max_abs` operators
set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf") set(QAT2_RESNET50_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_perf")
set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50") set(FP32_RESNET50_MODEL_DIR "${INT8_INSTALL_DIR}/resnet50")
set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz") set(QAT2_RESNET50_MODEL_ARCHIVE "ResNet50_qat_perf.tar.gz")
download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE}) download_qat_model(${QAT2_RESNET50_MODEL_DIR} ${QAT2_RESNET50_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS}) inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_dequantize_max_abs` operators
set(QAT2_RESNET50_RANGE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_range")
set(QAT2_RESNET50_RANGE_MODEL_ARCHIVE "ResNet50_qat_range.tar.gz")
download_qat_model(${QAT2_RESNET50_RANGE_MODEL_DIR} ${QAT2_RESNET50_RANGE_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_range_mkldnn ${QAT2_RESNET50_RANGE_MODEL_DIR}/ResNet50_qat_range ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
# QAT2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
set(QAT2_RESNET50_CHANNELWISE_MODEL_DIR "${QAT_INSTALL_DIR}/ResNet50_qat_channelwise")
set(QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
download_qat_model(${QAT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QAT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
inference_qat2_int8_image_classification_test(test_qat2_int8_resnet50_channelwise_mkldnn ${QAT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise ${FP32_RESNET50_MODEL_DIR}/model ${IMAGENET_DATA_PATH} ${QAT2_IC_QUANTIZED_OPS})
# QAT2 MobileNetV1 # QAT2 MobileNetV1
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf") set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_INSTALL_DIR}/MobileNet_qat_perf")
set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1") set(FP32_MOBILENETV1_MODEL_DIR "${INT8_INSTALL_DIR}/mobilenetv1")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册