Commit 4286a627 authored by Wojciech Uss, committed by Tao Luo

Add support for new QAT models (#18970)

* Add support for new QAT models

test=develop
Co-Authored-By: Michał Gallus <michal.gallus@intel.com>
Co-Authored-By: Wojciech Uss <wojciech.uss@intel.com>

* fixed fps results

test=develop

* fix top5 accuracy drop problem

* updated for new QAT models

* skip quantizing average pooling - dirty but working

* add missing pass

* added missing conv+brelu fuse pass

* removed a call to non-existent pass

test=develop

* renamed pass

test=develop

* Adjust finding pooling scale to newest QAT models

* Remove unnecessary code from quantization_mkldnn_pass

* Copy Pooling input scale to output scale in QAT

* Refactor & remove unused code in QAT

* Incorporate fp32 FC into QAT

test=develop

* Enable graph drawing with debug flag

test=develop

* Add tests for QATv2

* Fix paths for QATv2 models

test=develop

* Add option to save transformed int8 qat model

test=develop

* Remove redundant lines from qat mkldnn pass

test=develop

* Delegate disablement of avg pooling to qat

test=develop

* fix CI bug, test=develop

* Follow Wangzhen's Review, test=develop

* Update API.spec

test=develop

* Name False in (is_unsigned, TensorScale) tuple

test=develop
Parent 99a9615a
......@@ -219,6 +219,10 @@ PYBIND11_MODULE(core_noavx, m) {
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_double",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<double>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
......@@ -1154,6 +1158,9 @@ All parameter, weight, gradient are variables in Paddle.
m.def("size_of_dtype", framework::SizeOfType);
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
......@@ -1168,6 +1175,20 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<std::string> set) {
self.Set(name, new std::unordered_set<std::string>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<int> set) {
self.Set(name, new std::unordered_set<int>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name, VarQuantScale scales) {
self.Set(name, new VarQuantScale(scales));
})
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
self.Apply(graph.get());
......
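For reference, a minimal sketch of driving the new `Pass::Set` overloads from Python; `core.get_pass` and the attribute names are taken from this diff, while the tensor value and the variable name `conv2d_0.tmp_0` are illustrative:

```python
import numpy as np
from paddle.fluid import core

# VarQuantScale entry: variable name -> (is_unsigned, LoDTensor of scales).
scale_tensor = core.LoDTensor()
scale_tensor.set(np.array([0.25], dtype=np.float64), core.CPUPlace())
var_quant_scales = {'conv2d_0.tmp_0': (False, scale_tensor)}

# Dispatches to the new std::unordered_set<std::string> overload.
placement_pass = core.get_pass('cpu_quantize_placement_pass')
placement_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})

# Dispatches to the new VarQuantScale overload.
quantize_pass = core.get_pass('cpu_quantize_pass')
quantize_pass.set('quant_var_scales', var_quant_scales)
```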
......@@ -17,10 +17,10 @@ from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['TransformForMkldnnPass']
__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass']
class TransformForMkldnnPass(object):
class FakeQAT2MkldnnINT8KernelPass(object):
"""
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. The following transformations are done in this pass:
......@@ -36,7 +36,7 @@ class TransformForMkldnnPass(object):
4. Remove fake_dequantize_abs_max op
"""
def __init__(self, scope=None, place=None):
def __init__(self, _scope=None, _place=None):
"""
Args:
scope(fluid.Scope): scope is used to initialize the new parameters.
......@@ -48,33 +48,37 @@ class TransformForMkldnnPass(object):
# The original graph will be rewritten.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import TransformForMkldnnPass
import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(),
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
self._scope = scope
self._place = place
self._scope = _scope
self._place = _place
self.quantize_type = [
self._quantize_type = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max'
]
self.dequantize_type = ['fake_dequantize_max_abs']
self._dequantize_type = ['fake_dequantize_max_abs']
self._quantize_dequantize_type = [
'fake_quantize_dequantize_moving_average_abs_max'
]
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self.InScale = {}
self.max_range = {}
self.new_output = {}
self.s8_max = 127
self._in_scale = {}
self._max_range = {}
self._new_output = {}
self._s8_max = 127
def apply(self, graph):
"""
......@@ -91,37 +95,53 @@ class TransformForMkldnnPass(object):
ops = graph.all_op_nodes()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
# Collect the InScales and max_range to calculate the new scales for MKL-DNN
# Collect the _in_scale and _max_range values to calculate the new scales for MKL-DNN
# INT8 conv2d and mul
for op_node in ops:
if op_node.name() in self.dequantize_type:
if op_node.name() in self._dequantize_type:
input_name = op_node.input("X")[0]
scale_name = op_node.input("Scale")[0]
self.InScale[input_name] = self._load_param(self._scope,
self._in_scale[input_name] = self._load_param(self._scope,
scale_name)[0]
self._max_range[input_name] = op_node.op().attr("max_range")
self._new_output[input_name] = op_node.output("Out")[0]
if op_node.name() in self._quantize_dequantize_type:
inputs = op_node.op().input_names()
attrs = op_node.op().attr_names()
input_name = op_node.input("X")[0]
scale_name = op_node.input("InScale")[0]
self._in_scale[input_name] = self._load_param(self._scope,
scale_name)[0]
self.max_range[input_name] = op_node.op().attr("max_range")
self.new_output[input_name] = op_node.output("Out")[0]
# self._max_range[input_name] = op_node.op().attr("max_range")
self._new_output[input_name] = op_node.output("Out")[0]
for op_node in ops:
if op_node.name() in self._quantizable_ops:
if op_node.name() in self._conv_ops:
self._transform_to_conv_mkldnn(graph, op_node)
elif op_node.name() in self._pool_ops:
self._transform_to_pool_mkldnn(graph, op_node)
else:
self._transform_to_mul_mkldnn(graph, op_node)
elif op_node.name() in self.quantize_type:
elif op_node.name() in self._quantize_type:
self._transform_to_quantize_mkldnn(graph, op_node)
elif op_node.name() in self.dequantize_type:
elif op_node.name() in self._dequantize_type:
self._remove_fake_dequantize_op(graph, op_node)
self._remove_unused_var_nodes(graph)
return graph
def _transform_to_pool_mkldnn(self, graph, op):
output_name = op.output("Out")[0]
input_name = op.input("X")[0]
def _transform_to_conv_mkldnn(self, graph, op_node):
weight_name = op_node.input("Filter")[0]
output_name = op_node.output("Output")[0]
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self.s8_max), self.max_range[output_name])
np.multiply(weight, self._s8_max), self._max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
......@@ -129,8 +149,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of conv2d
output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
self.new_output[output_name])
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), self._new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
......@@ -144,9 +164,9 @@ class TransformForMkldnnPass(object):
outputs={'Output': output_var_node})
# Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
scale_in = self.s8_max / self.InScale[output_name]
scale_in = self._s8_max / self._in_scale[output_name]
scale_w = []
scale_w = [self.max_range[output_name] / self.s8_max]
scale_w = [self._max_range[output_name] / self._s8_max]
conv_op_node.set_attr("Scale_weights", scale_w)
conv_op_node.set_attr("Scale_in", scale_in)
......@@ -165,7 +185,7 @@ class TransformForMkldnnPass(object):
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self.s8_max), self.max_range[output_name])
np.multiply(weight, self._s8_max), self._max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
......@@ -173,8 +193,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of mul
output_var_node = graph._find_node_by_name(graph.all_var_nodes(),
self.new_output[output_name])
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), self._new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
......@@ -188,9 +208,9 @@ class TransformForMkldnnPass(object):
outputs={'Out': output_var_node})
# Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
scale_in = self.s8_max / self.InScale[output_name]
scale_in = self._s8_max / self._in_scale[output_name]
scale_w = []
scale_w = [self.max_range[output_name] / self.s8_max]
scale_w = [self._max_range[output_name] / self._s8_max]
mul_op_node.set_attr("scale_y", scale_w)
mul_op_node.set_attr("scale_x", scale_in)
......@@ -210,7 +230,7 @@ class TransformForMkldnnPass(object):
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
scale_in = self.s8_max / self._load_param(
scale_in = self._s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
......@@ -254,3 +274,279 @@ class TransformForMkldnnPass(object):
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
class FakeQAT2MkldnnINT8PerfPass(object):
"""
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations:
1. gather scale values from fake quantize/dequantize operators,
2. extract FP32 inference model graph from the QAT graph, i.e.
a. remove fake quantize/dequantize operators,
b. dequantize conv2d and mul's weights,
3. optimize the FP32 graph using standard FP32 optimization fuses
(e.g. `conv2d`+`bn` -> `conv2d`),
4. quantize the optimized FP32 graph using standard INT8v2 quantization
passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
"""
def __init__(self, _scope=None, _place=None, _core=None, _debug=False):
self._scope = _scope
self._place = _place
self._core = _core
self._debug = _debug
self._quantize_types = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_quantize_types = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
self._fc_ops = ['fc']
self._weight_scales = {}
# Collect the input and output scales from the fake QAT models
self._var_quant_scales = {}
self._max_range = {}
self._s8_max = 127
def apply(self, graph):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
graph = self._gather_scales(graph)
graph = self._remove_fake_ops(graph)
graph = self._update_pooling_scales(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
def _convert_scale2tensor(self, scale):
tensor = core.LoDTensor()
tensor.set(scale, core.CPUPlace())
return tensor
def _gather_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._quantize_types:
bit_length = op.op().attr("bit_length")
assert bit_length == 8, 'Unsupported number of quantization bits ({}). Only 8 is supported now.'.format(
bit_length)
input_name = op.input("X")[0]
scale_name = op.input("InScale")[0]
# The scale is stored as the reciprocal of the InScale value read from the model
scale = np.array(1.0 / self._load_param(
self._scope, scale_name)[0]).astype(np.float64)
lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False
self._var_quant_scales[input_name] = (use_unsigned_int,
lod_tensor)
if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0]
_max_range = op.op().attr("max_range")
self._weight_scales[input_name] = _max_range
return graph
def _update_pooling_scales(self, graph):
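# Pooling does not need a scale of its own: reuse the input's quantization
# scale for the output so pool2d can be quantized.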
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
def _remove_fake_ops(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
fake_quant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
fake_quant_in_scale = graph._find_node_by_name(op.inputs,
op.input("InScale")[0])
fake_quant_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
fake_quant_out_scale = graph._find_node_by_name(
op.outputs, op.output("OutScale")[0])
next_ops = fake_quant_out.outputs
for next_op in next_ops:
self._swap_inputs(next_op, fake_quant_out, fake_quant_in)
graph.link_to(fake_quant_in, next_op)
graph.safe_remove_nodes(
{op, fake_quant_in_scale, fake_quant_out, fake_quant_out_scale})
return graph
def _remove_fake_dequantize(self, graph, op):
fake_dequant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
fake_dequant_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
next_ops = fake_dequant_out.outputs
for next_op in next_ops:
self._swap_inputs(next_op, fake_dequant_out, fake_dequant_in)
graph.link_to(fake_dequant_in, next_op)
graph.safe_remove_nodes({op, fake_dequant_out})
return graph
def _swap_inputs(self, op, old_input, new_input):
for input_name in op.op().input_names():
if old_input.name() in op.input(input_name):
op.op().set_input(input_name, [
new_input.name() if x == old_input.name() else x
for x in op.input(input_name)
])
def _dequantize_weights(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op)
elif op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph
def _dequantize_conv_weights(self, graph, op_node):
weight_name = op_node.input("Filter")[0]
output_name = op_node.output("Output")[0]
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_name]
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
def _dequantize_mul_weights(self, graph, op_node):
weight_name = op_node.input("Y")[0]
output_name = op_node.output("Out")[0]
scales = self._weight_scales[output_name]
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
def _restore_var(self, name, array):
tensor = self._scope.find_var(name).get_tensor()
tensor.set(array, self._place)
def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass')
return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
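# Round-trip through a Program to get a fresh core.Graph the C++ pass can
# mutate; the scope is attached so passes can read and update parameters.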
ir_pass = core.get_pass(pass_name)
inference_program = graph.to_program()
ir_graph = core.Graph(inference_program.desc)
ir_graph.set_not_owned('__param_scope__', self._scope)
if attrs:
assert attr_values and len(attrs) == len(
attr_values
), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value)
ir_pass.apply(ir_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name),
graph.all_op_nodes())
self._remove_unused_var_nodes(graph)
return graph
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
return graph
def _compute_weight_scales(self, graph):
def _compute_var_scales(ops, out_name, w_name, axis):
for op in graph.all_op_nodes():
if op.op().type() in ops:
weight_var_name = op.input(w_name)[0]
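# Per-output-channel scales: the reciprocal of the max abs weight
# in each output channel (conv) or output column (fc).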
weights = np.array(
self._load_param(self._scope, weight_var_name))
scales = 1.0 / np.amax(
np.abs(weights.reshape(weights.shape[0], -1)),
axis=axis)
lod_tensor = self._convert_scale2tensor(
scales.astype(np.float64))
use_unsigned_int = False
self._var_quant_scales[weight_var_name] = (use_unsigned_int,
lod_tensor)
_compute_var_scales(self._conv_ops, "Output", "Filter", axis=1)
_compute_var_scales(self._fc_ops, "Out", "W", axis=0)
return graph
def _find_avg_pooling_ids(self, graph):
ids = []
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg":
ids.append(op.id())
return set(ids)
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
inference_program = graph.to_program()
ir_graph = self._core.Graph(inference_program.desc)
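# Quantize only conv2d and pool2d ops; average pooling is excluded by op id
# (see _find_avg_pooling_ids), since quantized avg pooling is not supported here.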
ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph))
ir_pass.apply(ir_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug:
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
graph = self._apply_pass(graph, 'cpu_quantize_pass',
['quant_var_scales'],
[self._var_quant_scales])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
......@@ -32,6 +32,20 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
--acc_diff_threshold 0.1)
endfunction()
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --qat_model ${model_dir}/float
--infer_data ${data_dir}/data.bin
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1
--qat2)
endfunction()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
endif()
......@@ -142,6 +156,19 @@ if(LINUX AND WITH_MKLDNN)
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
......
......@@ -6,11 +6,11 @@ This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle
You need to install at least the PaddlePaddle 1.5 Python package: `pip install paddlepaddle==1.5`.
## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `TransformForMkldnnPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users first use the PaddleSlim quantization strategy to get a saved fake QAT model with [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use `FakeQAT2MkldnnINT8KernelPass` to get a graph that can be run with MKL-DNN INT8 kernels. In Paddle Release 1.5, this pass supports only `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
......@@ -18,9 +18,9 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using
# TransformForMkldnnPass
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
# Apply TransformForMkldnnPass to IrGraph
# FakeQAT2MkldnnINT8KernelPass
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(), place)
# Apply FakeQAT2MkldnnINT8KernelPass to IrGraph
mkldnn_pass.apply(graph)
```
......
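For the second-generation QAT models handled by this commit, a similar sketch using `FakeQAT2MkldnnINT8PerfPass` (mirroring the test code further below; the model path is illustrative and the constructor keywords are those defined in this diff):

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.global_scope()
with fluid.scope_guard(inference_scope):
    # 'qat_model_dir' is an illustrative path to a saved fake QAT model.
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        'qat_model_dir', exe, 'model', 'params')
    graph = IrGraph(core.Graph(program.desc), for_test=True)
    perf_pass = FakeQAT2MkldnnINT8PerfPass(
        _scope=inference_scope, _place=place, _core=core, _debug=False)
    graph = perf_pass.apply(graph)
    program = graph.to_program()
```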
......@@ -24,7 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
......@@ -41,8 +42,21 @@ def parse_args():
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of the QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--qat2',
action='store_true',
help='If used, the QAT model is treated as a second-generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations.')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
......@@ -164,12 +178,24 @@ class TestQatInt8Comparison(unittest.TestCase):
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
mkldnn_int8_pass = TransformForMkldnnPass(
scope=inference_scope, place=place)
mkldnn_int8_pass.apply(graph)
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
else:
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
......@@ -209,7 +235,7 @@ class TestQatInt8Comparison(unittest.TestCase):
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
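# batch_time is measured in milliseconds, hence the factor of 1000
# in the fps computation below.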
fps = samples / batch_time
fps = samples / batch_time * 1000
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
......@@ -230,6 +256,12 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
......@@ -265,6 +297,7 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
......@@ -283,7 +316,6 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
......
......@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
......@@ -90,6 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
seed,
activation_quant_type,
weight_quant_type='abs_max',
qat_perf=False,
for_ci=False):
random.seed(0)
np.random.seed(0)
......@@ -148,7 +149,8 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......
......@@ -2416,6 +2416,20 @@ class IrOpNode(IrNode):
"The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name)
def rename_output(self, old_output_name, new_output_name):
"""
Rename the output of this node.
Args:
old_output_name(str): the old output name.
new_output_name(str): the new output name.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
print("op: {}, old: {}, new: {}\n".format(self.node.op().type(
), old_output_name, new_output_name))
self.node.op()._rename_output(old_output_name, new_output_name)
def input(self, name):
"""
Get the argument name list by the parameter name for input.
......@@ -2709,6 +2723,24 @@ class IrGraph(object):
op_node.append_input(new_input_node)
op_node.rename_input(old_input_node.name(), new_input_node.name())
def update_output_link(self, old_output_node, new_output_node, op_node):
"""
Update the output's link of an operator node.
Args:
old_output_node(IrNode): the old output node of the giving op_node.
new_output_node(IrNode): the new output node of the giving op_node.
op_node(IrOpNode): the operator node whose output link needs updating.
"""
assert old_output_node.node in self.graph.nodes() and new_output_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments (old_output_node, new_output_node and op_node) must be in the graph nodes.'
old_output_node.remove_input(op_node)
op_node.remove_output(old_output_node)
new_output_node.append_input(op_node)
op_node.append_output(new_output_node)
op_node.rename_output(old_output_node.name(), new_output_node.name())
def link_to(self, node_in, node_out):
"""
Connect two nodes.
......
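A hedged sketch of how the new `update_output_link` helper can be used when rewiring a graph (the node-lookup pattern follows the passes above; all names are illustrative):

```python
# Redirect op_node's output "Out" from old_out to new_out. update_output_link
# unlinks the old node, links the new one, and renames the output argument in
# the op description via the new rename_output.
old_out = graph._find_node_by_name(op_node.outputs, op_node.output("Out")[0])
new_out = graph._find_node_by_name(graph.all_var_nodes(), 'new_out_var')
graph.update_output_link(old_out, new_out, op_node)
```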