Commit 4286a627 authored by Wojciech Uss, committed by Tao Luo

Add support for new QAT models (#18970)

* Add support for new QAT models

test=develop
Co-Authored-By: Michał Gallus <michal.gallus@intel.com>
Co-Authored-By: Wojciech Uss <wojciech.uss@intel.com>

* fixed fps results

test=develop

* fix top5 accuracy drop problem

* updated for new QAT models

* skip quantizing average pooling - dirty but working

* add missing pass

* added missing conv+brelu fuse pass

* removed a call to non-existent pass

test=develop

* renamed pass

test=develop

* Adjust finding pooling scale to newest QAT models

* Remove unnecessary code from quantization_mkldnn_pass

* Copy Pooling input scale to output scale in QAT

* Refactor & remove unused code in QAT

* Incorporate fp32 FC into QAT

test=develop

* Enable graph drawing with debug flag

test=develop

* Add tests for QATv2

* Fix paths for QATv2 models

test=develop

* Add option to save transformed int8 qat model

test=develop

* Remove redundant lines from qat mkldnn pass

test=develop

* Delegate disablement of avg pooling to qat

test=develop

* fix CI bug, test=develop

* Follow Wangzhen's Review, test=develop

* Update API.spec

test=develop

* Name False in (is_unsigned, TensorScale) tuple

test=develop
Parent 99a9615a
...@@ -219,6 +219,10 @@ PYBIND11_MODULE(core_noavx, m) {
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_double",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<double>(place);
})
.def("_alloc_int", .def("_alloc_int",
[](Tensor &self, paddle::platform::CPUPlace &place) { [](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place); self.mutable_data<int>(place);
...@@ -1154,6 +1158,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -1154,6 +1158,9 @@ All parameter, weight, gradient are variables in Paddle.
m.def("size_of_dtype", framework::SizeOfType); m.def("size_of_dtype", framework::SizeOfType);
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
...@@ -1168,6 +1175,20 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<std::string> set) {
self.Set(name, new std::unordered_set<std::string>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<int> set) {
self.Set(name, new std::unordered_set<int>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name, VarQuantScale scales) {
self.Set(name, new VarQuantScale(scales));
})
.def("type", &ir::Pass::Type) .def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) { .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
self.Apply(graph.get()); self.Apply(graph.get());
......
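The three `set` overloads added to the `Pass` binding above are what let the Python side hand a string set, an int set, or a `VarQuantScale` map (variable name -> (is_unsigned, scale LoDTensor)) to a C++ IR pass. A minimal sketch of driving them from Python, mirroring how the QAT pass below uses them; the variable name and scale value here are made up:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

# Build one VarQuantScale entry: variable name -> (is_unsigned, LoDTensor holding the scale)
scale_tensor = core.LoDTensor()
scale_tensor.set(np.array([0.125]).astype(np.float64), core.CPUPlace())
var_quant_scales = {'conv2d_0.tmp_0': (False, scale_tensor)}  # hypothetical variable name

placement_pass = core.get_pass('cpu_quantize_placement_pass')
placement_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})  # unordered_set<string>
placement_pass.set('quantize_excluded_op_ids', {7, 42})                # unordered_set<int>

quantize_pass = core.get_pass('cpu_quantize_pass')
quantize_pass.set('quant_var_scales', var_quant_scales)                # VarQuantScale

# A pass is then applied to a core.Graph, e.g. one built from an empty program:
graph = core.Graph(fluid.Program().desc)
placement_pass.apply(graph)
```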
...@@ -17,17 +17,17 @@ from .... import core
from ....framework import IrGraph
from ....framework import IrNode
__all__ = ['FakeQAT2MkldnnINT8KernelPass', 'FakeQAT2MkldnnINT8PerfPass']
class FakeQAT2MkldnnINT8KernelPass(object):
"""
Convert QuantizationFreezePass generated IrGraph to MKL-DNN supported INT8
IrGraph. The following transformations are done in this pass:
1. Convert int8 range weights with float32 data type, which are generated by
the QuantizationFreezePass, to float32 range weights with float32 data type
by using the corresponding scales. This conversion is because MKL-DNN INT8
conv2d kernel and mul kernel now only support float32 weights input, hence
weights quantization will happen inside the conv2d and mul INT8 kernel.
2. Create the new conv2d or mul op with the converted weights and link its output
to fake_dequantize_abs_max op's output and set conv2d's attribute "force_fp32
...@@ -36,7 +36,7 @@ class TransformForMkldnnPass(object):
4. Remove fake_dequantize_abs_max op
"""
def __init__(self, _scope=None, _place=None):
"""
Args:
scope(fluid.Scope): scope is used to initialize the new parameters.
...@@ -48,40 +48,44 @@ class TransformForMkldnnPass(object):
# The original graph will be rewritten.
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization \
import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(),
place)
mkldnn_pass.apply(graph)
"""
self._scope = _scope
self._place = _place
self._quantize_type = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max'
]
self._dequantize_type = ['fake_dequantize_max_abs']
self._quantize_dequantize_type = [
'fake_quantize_dequantize_moving_average_abs_max'
]
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._in_scale = {}
self._max_range = {}
self._new_output = {}
self._s8_max = 127
def apply(self, graph):
"""
Quantize the graph for running MKL-DNN INT8 inference. According
to activation quantization type, the graph will transform fake
quantize ops to quantize ops and remove the fake dequantize ops.
Args:
graph(IrGraph): the applied graph.
"""
...@@ -91,37 +95,53 @@ class TransformForMkldnnPass(object):
ops = graph.all_op_nodes()
persistable_vars = [p.name() for p in graph.all_persistable_nodes()]
# Collect the _in_scales and _max_range to calculate the new scales for MKL-DNN
# INT8 conv2d and mul
for op_node in ops:
if op_node.name() in self._dequantize_type:
input_name = op_node.input("X")[0]
scale_name = op_node.input("Scale")[0]
self._in_scale[input_name] = self._load_param(self._scope,
scale_name)[0]
self._max_range[input_name] = op_node.op().attr("max_range")
self._new_output[input_name] = op_node.output("Out")[0]
if op_node.name() in self._quantize_dequantize_type:
inputs = op_node.op().input_names()
attrs = op_node.op().attr_names()
input_name = op_node.input("X")[0]
scale_name = op_node.input("InScale")[0]
self._in_scale[input_name] = self._load_param(self._scope,
scale_name)[0]
# self._max_range[input_name] = op_node.op().attr("max_range")
self._new_output[input_name] = op_node.output("Out")[0]
for op_node in ops:
if op_node.name() in self._quantizable_ops:
if op_node.name() in self._conv_ops:
self._transform_to_conv_mkldnn(graph, op_node)
elif op_node.name() in self._pool_ops:
self._transform_to_pool_mkldnn(graph, op_node)
else:
self._transform_to_mul_mkldnn(graph, op_node)
elif op_node.name() in self._quantize_type:
self._transform_to_quantize_mkldnn(graph, op_node)
elif op_node.name() in self._dequantize_type:
self._remove_fake_dequantize_op(graph, op_node)
self._remove_unused_var_nodes(graph)
return graph
def _transform_to_pool_mkldnn(self, graph, op):
output_name = op.output("Out")[0]
input_name = op.input("X")[0]
def _transform_to_conv_mkldnn(self, graph, op_node):
weight_name = op_node.input("Filter")[0]
output_name = op_node.output("Output")[0]
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self._s8_max), self._max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
...@@ -129,8 +149,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of conv2d
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), self._new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
...@@ -144,9 +164,9 @@ class TransformForMkldnnPass(object):
outputs={'Output': output_var_node})
# Based on the QAT's scales to calculate the scales of MKL-DNN INT8 conv2d
scale_in = self._s8_max / self._in_scale[output_name]
scale_w = []
scale_w = [self._max_range[output_name] / self._s8_max]
conv_op_node.set_attr("Scale_weights", scale_w)
conv_op_node.set_attr("Scale_in", scale_in)
...@@ -165,7 +185,7 @@ class TransformForMkldnnPass(object):
# Convert int8 range weights to fp32 range weights
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(
np.multiply(weight, self._s8_max), self._max_range[output_name])
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
input_var_node = graph._find_node_by_name(op_node.inputs,
...@@ -173,8 +193,8 @@ class TransformForMkldnnPass(object):
weight_var_node = graph._find_node_by_name(op_node.inputs, weight_name)
# Set fake_dequantize_abs_max's output as new output of mul
output_var_node = graph._find_node_by_name(
graph.all_var_nodes(), self._new_output[output_name])
attrs = {
name: op_node.op().attr(name)
for name in op_node.op().attr_names()
...@@ -188,9 +208,9 @@ class TransformForMkldnnPass(object):
outputs={'Out': output_var_node})
# Based on the QAT's scales to calculate MKL-DNN INT8 mul's scales
scale_in = self._s8_max / self._in_scale[output_name]
scale_w = []
scale_w = [self._max_range[output_name] / self._s8_max]
mul_op_node.set_attr("scale_y", scale_w)
mul_op_node.set_attr("scale_x", scale_in)
...@@ -210,7 +230,7 @@ class TransformForMkldnnPass(object):
op_node.input("X")[0])
output_var_node = graph._find_node_by_name(op_node.outputs,
op_node.output("Out")[0])
scale_in = self._s8_max / self._load_param(
self._scope, op_node.input("InScale")[0])[0]
quant_op_node = graph.create_op_node(
op_type='quantize',
...@@ -254,3 +274,279 @@ class TransformForMkldnnPass(object):
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
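For reference, the weight dequantization and scale computation that `_transform_to_conv_mkldnn` and `_transform_to_mul_mkldnn` perform above reduce to the following standalone numpy sketch; `in_scale` and `max_range` stand in for the values collected from the fake quantize/dequantize ops and are made up here:

```python
import numpy as np

s8_max = 127.0
in_scale = 6.2       # hypothetical scale read from the fake_dequantize op's "Scale" input
max_range = 2601.0   # hypothetical "max_range" attribute of fake_dequantize_max_abs

# int8-range weights stored with float32 dtype by QuantizationFreezePass (made-up values)
weight = np.array([[-95.0, 14.0], [63.0, -127.0]], dtype=np.float32)

# Dequantize weights back to the fp32 range, as the pass does before rebuilding the op
w_fp32 = np.divide(np.multiply(weight, s8_max), max_range).reshape(weight.shape)

# Scales handed to the MKL-DNN INT8 kernels
scale_in = s8_max / in_scale    # conv2d "Scale_in" / mul "scale_x"
scale_w = [max_range / s8_max]  # conv2d "Scale_weights" / mul "scale_y"

print(w_fp32, scale_in, scale_w)
```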
class FakeQAT2MkldnnINT8PerfPass(object):
"""
Transform a QAT model IrGraph into MKL-DNN supported INT8 IrGraph.
The pass consists of the following transformations:
1. gather scale values from fake quantize/dequantize operators,
2. extract FP32 inference model graph from the QAT graph, i.e.
a. remove fake quantize/dequantize operators,
b. dequantize conv2d and mul's weights,
3. optimize the FP32 graph using standard FP32 optimization fuses
(e.g. `conv2d`+`bn` -> `conv2d`),
4. quantize the optimized FP32 graph using standard INT8v2 quantization
passes (`cpu_quantize_pass`, `cpu_quantize_squash_pass`).
"""
def __init__(self, _scope=None, _place=None, _core=None, _debug=False):
self._scope = _scope
self._place = _place
self._core = _core
self._debug = _debug
self._quantize_types = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_range_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_quantize_types = [
'fake_quantize_moving_average_abs_max',
'fake_quantize_dequantize_moving_average_abs_max'
]
self._fake_dequantize_types = ['fake_dequantize_max_abs']
self._conv_ops = ['conv2d', 'depthwise_conv2d']
self._pool_ops = ['pool2d']
self._mul_ops = ['mul']
self._fc_ops = ['fc']
self._weight_scales = {}
# Collect the Input and Output scales from Fake QAT models
self._var_quant_scales = {}
self._max_range = {}
self._s8_max = 127
def apply(self, graph):
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
graph = self._gather_scales(graph)
graph = self._remove_fake_ops(graph)
graph = self._update_pooling_scales(graph)
graph = self._dequantize_weights(graph)
graph = self._optimize_fp32_graph(graph)
graph = self._compute_weight_scales(graph)
graph = self._quantize_fp32_graph(graph)
graph = self._remove_unused_var_nodes(graph)
return graph
def _convert_scale2tensor(self, scale):
tensor = core.LoDTensor()
tensor.set(scale, core.CPUPlace())
return tensor
def _gather_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._quantize_types:
bit_length = op.op().attr("bit_length")
assert bit_length == 8, 'Unsupported number of quantization bits ({}). Only 8 is supported now.'.format(
bit_length)
input_name = op.input("X")[0]
scale_name = op.input("InScale")[0]
# Gather new weights scale after folding batchnorm in convolution
scale = np.array(1.0 / self._load_param(
self._scope, scale_name)[0]).astype(np.float64)
lod_tensor = self._convert_scale2tensor(scale)
use_unsigned_int = False
self._var_quant_scales[input_name] = (use_unsigned_int,
lod_tensor)
if op.name() in self._fake_dequantize_types:
input_name = op.input("X")[0]
_max_range = op.op().attr("max_range")
self._weight_scales[input_name] = _max_range
return graph
def _update_pooling_scales(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
input_name = op.input("X")[0]
output_name = op.output("Out")[0]
if input_name in self._var_quant_scales:
self._var_quant_scales[
output_name] = self._var_quant_scales[input_name]
return graph
def _load_param(self, scope, param_name):
return np.array(scope.find_var(param_name).get_tensor())
def _remove_fake_ops(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._fake_quantize_types:
op_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
self._remove_fake_quantize(graph, op)
for op in graph.all_op_nodes():
if op.name() in self._fake_dequantize_types:
op_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
self._remove_fake_dequantize(graph, op)
return graph
def _remove_fake_quantize(self, graph, op):
fake_quant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
fake_quant_in_scale = graph._find_node_by_name(op.inputs,
op.input("InScale")[0])
fake_quant_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
fake_quant_out_scale = graph._find_node_by_name(
op.outputs, op.output("OutScale")[0])
next_ops = fake_quant_out.outputs
for next_op in next_ops:
self._swap_inputs(next_op, fake_quant_out, fake_quant_in)
graph.link_to(fake_quant_in, next_op)
graph.safe_remove_nodes(
{op, fake_quant_in_scale, fake_quant_out, fake_quant_out_scale})
return graph
def _remove_fake_dequantize(self, graph, op):
fake_dequant_in = graph._find_node_by_name(op.inputs, op.input("X")[0])
fake_dequant_out = graph._find_node_by_name(op.outputs,
op.output("Out")[0])
next_ops = fake_dequant_out.outputs
for next_op in next_ops:
self._swap_inputs(next_op, fake_dequant_out, fake_dequant_in)
graph.link_to(fake_dequant_in, next_op)
graph.safe_remove_nodes({op, fake_dequant_out})
return graph
def _swap_inputs(self, op, old_input, new_input):
for input_name in op.op().input_names():
if old_input.name() in op.input(input_name):
op.op().set_input(input_name, [
new_input.name() if x == old_input.name() else x
for x in op.input(input_name)
])
def _dequantize_weights(self, graph):
for op in graph.all_op_nodes():
if op.name() in self._conv_ops:
self._dequantize_conv_weights(graph, op)
elif op.name() in self._mul_ops:
self._dequantize_mul_weights(graph, op)
return graph
def _dequantize_conv_weights(self, graph, op_node):
weight_name = op_node.input("Filter")[0]
output_name = op_node.output("Output")[0]
# Convert int8 range weights to fp32 range weights
scales = self._weight_scales[output_name]
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
def _dequantize_mul_weights(self, graph, op_node):
weight_name = op_node.input("Y")[0]
output_name = op_node.output("Out")[0]
scales = self._weight_scales[output_name]
weight = self._load_param(self._scope, weight_name)
w_fp32 = np.divide(np.multiply(weight, self._s8_max), scales)
w_fp32 = w_fp32.reshape(weight.shape)
self._restore_var(weight_name, w_fp32)
def _restore_var(self, name, array):
tensor = self._scope.find_var(name).get_tensor()
tensor.set(array, self._place)
def _optimize_fp32_graph(self, graph):
graph = self._apply_pass(graph, 'mkldnn_placement_pass',
['mkldnn_enabled_op_types'], [set()])
graph = self._apply_pass(graph, 'depthwise_conv_mkldnn_pass')
graph = self._apply_pass(graph, 'conv_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_eltwiseadd_bn_fuse_pass')
graph = self._apply_pass(graph, 'conv_bias_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
graph = self._apply_pass(graph, 'fc_fuse_pass')
return graph
def _apply_pass(self, graph, pass_name, attrs=None, attr_values=None):
ir_pass = core.get_pass(pass_name)
inference_program = graph.to_program()
ir_graph = core.Graph(inference_program.desc)
ir_graph.set_not_owned('__param_scope__', self._scope)
if attrs:
assert attr_values and len(attrs) == len(
attr_values
), "Different number of pass attributes and their values."
for attr, value in zip(attrs, attr_values):
ir_pass.set(attr, value)
ir_pass.apply(ir_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug:
graph.draw('.', 'qat_fp32_{}'.format(pass_name),
graph.all_op_nodes())
self._remove_unused_var_nodes(graph)
return graph
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_op_nodes()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_used_vars = {n.node for n in all_used_vars}
all_unused_vars = {
n
for n in filter(lambda node: node.node not in all_used_vars,
graph.all_var_nodes())
}
graph.safe_remove_nodes(all_unused_vars)
return graph
def _compute_weight_scales(self, graph):
def _compute_var_scales(ops, out_name, w_name, axis):
for op in graph.all_op_nodes():
if op.op().type() in ops:
weight_var_name = op.input(w_name)[0]
weights = np.array(
self._load_param(self._scope, weight_var_name))
scales = 1.0 / np.amax(
np.abs(weights.reshape(weights.shape[0], -1)),
axis=axis)
lod_tensor = self._convert_scale2tensor(
scales.astype(np.float64))
use_unsigned_int = False
self._var_quant_scales[weight_var_name] = (use_unsigned_int,
lod_tensor)
_compute_var_scales(self._conv_ops, "Output", "Filter", axis=1)
_compute_var_scales(self._fc_ops, "Out", "W", axis=0)
return graph
def _find_avg_pooling_ids(self, graph):
ids = []
for op in graph.all_op_nodes():
if op.name() in self._pool_ops:
if op.op().attr("pooling_type") == "avg":
ids.append(op.id())
return set(ids)
def _quantize_fp32_graph(self, graph):
ir_pass = self._core.get_pass('cpu_quantize_placement_pass')
inference_program = graph.to_program()
ir_graph = self._core.Graph(inference_program.desc)
ir_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
ir_pass.set('quantize_excluded_op_ids',
self._find_avg_pooling_ids(graph))
ir_pass.apply(ir_graph)
graph = IrGraph(ir_graph, for_test=True)
if self._debug:
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
graph = self._apply_pass(graph, 'cpu_quantize_pass',
['quant_var_scales'],
[self._var_quant_scales])
graph = self._apply_pass(graph, 'cpu_quantize_squash_pass')
return graph
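A minimal usage sketch for the new `FakeQAT2MkldnnINT8PerfPass`, mirroring the test code further down in this diff; the model directory is a placeholder, and the 'model'/'params' file names follow the test's `load_inference_model` call:

```python
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid import core

place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.global_scope()

with fluid.scope_guard(inference_scope):
    # 'qat_model_dir' is a placeholder for a saved fake-QAT inference model
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        'qat_model_dir', exe, 'model', 'params')
    graph = IrGraph(core.Graph(program.desc), for_test=True)

    qat2_pass = FakeQAT2MkldnnINT8PerfPass(
        _scope=inference_scope, _place=place, _core=core, _debug=False)
    graph = qat2_pass.apply(graph)      # gather scales, strip fake ops, fuse, quantize
    int8_program = graph.to_program()   # can then be run or saved with save_inference_model
```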
...@@ -32,6 +32,20 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
--acc_diff_threshold 0.1)
endfunction()
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --qat_model ${model_dir}/float
--infer_data ${data_dir}/data.bin
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1
--qat2)
endfunction()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
endif()
...@@ -142,6 +156,19 @@ if(LINUX AND WITH_MKLDNN)
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
......
...@@ -6,11 +6,11 @@ This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle
You need to install at least the PaddlePaddle-1.5 python package: `pip install paddlepaddle==1.5`.
## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users first use the PaddleSlim quantization strategy to get a saved fake QAT model via [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `FakeQAT2MkldnnINT8KernelPass` to get a graph that can be run with MKL-DNN INT8 kernels. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
...@@ -18,9 +18,9 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using
# FakeQAT2MkldnnINT8KernelPass
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(), place)
# Apply FakeQAT2MkldnnINT8KernelPass to IrGraph
mkldnn_pass.apply(graph)
```
......
...@@ -24,7 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
...@@ -41,8 +42,21 @@ def parse_args():
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--qat2',
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
...@@ -164,12 +178,24 @@ class TestQatInt8Comparison(unittest.TestCase):
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
else:
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
...@@ -209,7 +235,7 @@ class TestQatInt8Comparison(unittest.TestCase):
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
fps = samples / batch_time * 1000
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
...@@ -230,6 +256,12 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
...@@ -265,6 +297,7 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
...@@ -283,7 +316,6 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
......
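The corrected throughput formula above (`fps = samples / batch_time * 1000`) implies that `batch_time` is measured in milliseconds; a standalone sanity check of that arithmetic with made-up numbers:

```python
samples = 25        # images in the batch
batch_time = 50.0   # elapsed time for the batch, in milliseconds (assumed)

latency = batch_time / samples        # 2.0 ms per image
fps = samples / batch_time * 1000     # 500 images per second

print(latency, fps)
```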
...@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
...@@ -90,6 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
seed,
activation_quant_type,
weight_quant_type='abs_max',
qat_perf=False,
for_ci=False):
random.seed(0)
np.random.seed(0)
...@@ -148,7 +149,8 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......
...@@ -2416,6 +2416,20 @@ class IrOpNode(IrNode):
"The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name)
def rename_output(self, old_output_name, new_output_name):
"""
Rename the output of this node.
Args:
old_output_name(str): the old output name.
new_output_name(str): the new output name.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
print("op: {}, old: {}, new: {}\n".format(self.node.op().type(
), old_output_name, new_output_name))
self.node.op()._rename_output(old_output_name, new_output_name)
def input(self, name):
"""
Get the argument name list by the parameter name for input.
...@@ -2709,6 +2723,24 @@ class IrGraph(object):
op_node.append_input(new_input_node)
op_node.rename_input(old_input_node.name(), new_input_node.name())
def update_output_link(self, old_output_node, new_output_node, op_node):
"""
Update the output's link of an operator node.
Args:
old_output_node(IrNode): the old output node of the giving op_node.
new_output_node(IrNode): the new output node of the giving op_node.
op_node(IrOpNode): the operator node whose output link needs to be updated.
"""
assert old_output_node.node in self.graph.nodes() and new_output_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments(old_output_node &new_output_node &op_node) must be in the graph nodes.'
old_output_node.remove_input(op_node)
op_node.remove_output(old_output_node)
new_output_node.append_input(op_node)
op_node.append_output(new_output_node)
op_node.rename_output(old_output_node.name(), new_output_node.name())
def link_to(self, node_in, node_out):
"""
Connect two nodes.
......
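A rough sketch of the new `IrGraph.update_output_link` helper in use; the program construction and node lookups below are illustrative only and assume the Paddle 1.5 `fluid.layers` API, with a made-up new variable name:

```python
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

prog = fluid.Program()
with fluid.program_guard(prog):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

graph = IrGraph(core.Graph(prog.desc), for_test=True)
# fc is lowered to a 'mul' (+ elementwise_add); take the mul op and its current output
mul_op = graph._find_node_by_name(graph.all_op_nodes(), 'mul')
old_out = graph._find_node_by_name(mul_op.outputs, mul_op.output('Out')[0])
# Create a fresh variable node and re-route the op's output to it
new_out = graph.create_var_node(
    name='mul_out_renamed',  # hypothetical name
    var_type=old_out.type(),
    shape=old_out.shape(),
    var_dtype=old_out.dtype())
graph.update_output_link(old_out, new_out, mul_op)  # unlink old, link new, rename output arg
```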