Commit 4286a627 authored by Wojciech Uss, committed by Tao Luo

Add support for new QAT models (#18970)

* Add support for new QAT models

test=develop
Co-Authored-By: Michał Gallus <michal.gallus@intel.com>
Co-Authored-By: Wojciech Uss <wojciech.uss@intel.com>

* fixed fps results

test=develop

* fix top5 accuracy drop problem

* updated for new QAT models

* skip quantizing average pooling - dirty but working

* add missing pass

* added missing conv+brelu fuse pass

* removed a call to non-existent pass

test=develop

* renamed pass

test=develop

* Adjust finding pooling scale to newest QAT models

* Remove unnecessary code from quantization_mkldnn_pass

* Copy Pooling input scale to output scale in QAT

* Refactor & remove unused code in QAT

* Incorporate fp32 FC into QAT

test=develop

* Enable graph drawing with debug flag

test=develop

* Add tests for QATv2

* Fix paths for QATv2 models

test=develop

* Add option to save transformed int8 qat model

test=develop

* Remove redundant lines from qat mkldnn pass

test=develop

* Delegate disablement of avg pooling to qat

test=develop

* fix CI bug, test=develop

* Follow Wangzhen's Review, test=develop

* Update API.spec

test=develop

* Name False in (is_unsigned, TensorScale) tuple

test=develop
Parent 99a9615a
@@ -219,6 +219,10 @@ PYBIND11_MODULE(core_noavx, m) {
        [](Tensor &self, paddle::platform::CPUPlace &place) {
          self.mutable_data<float>(place);
        })
+      .def("_alloc_double",
+           [](Tensor &self, paddle::platform::CPUPlace &place) {
+             self.mutable_data<double>(place);
+           })
       .def("_alloc_int",
            [](Tensor &self, paddle::platform::CPUPlace &place) {
              self.mutable_data<int>(place);
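The new binding pre-allocates FP64 storage for a tensor from Python. A minimal usage sketch (hedged: it relies on `LoDTensor` inheriting the `Tensor` bindings, which is how the other `_alloc_*` helpers are reached):

```python
from paddle.fluid import core

t = core.LoDTensor()              # LoDTensor inherits the Tensor bindings
t._alloc_double(core.CPUPlace())  # reserve double-precision storage on CPU
```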
@@ -1154,6 +1158,9 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("size_of_dtype", framework::SizeOfType);

+  using VarQuantScale =
+      std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
+
   py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
   pass.def(py::init())
       .def("has", &ir::Pass::Has)
@@ -1168,6 +1175,20 @@ All parameter, weight, gradient are variables in Paddle.
       })
       .def("set", [](ir::Pass &self, const std::string &name,
                      int val) { self.Set<const int>(name, new int(val)); })
+      .def("set",
+           [](ir::Pass &self, const std::string &name,
+              std::unordered_set<std::string> set) {
+             self.Set(name, new std::unordered_set<std::string>(set));
+           })
+      .def("set",
+           [](ir::Pass &self, const std::string &name,
+              std::unordered_set<int> set) {
+             self.Set(name, new std::unordered_set<int>(set));
+           })
+      .def("set",
+           [](ir::Pass &self, const std::string &name, VarQuantScale scales) {
+             self.Set(name, new VarQuantScale(scales));
+           })
       .def("type", &ir::Pass::Type)
       .def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
         self.Apply(graph.get());
...
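These overloads are what let the Python-side QAT passes hand quantization data to C++ IR passes. A hedged sketch of driving them (the pass name and attribute names are assumptions based on the MKL-DNN quantization flow, not part of this diff):

```python
import numpy as np
from paddle.fluid import core

# Retrieve a registered C++ IR pass (pass name assumed).
quantize_pass = core.get_pass('cpu_quantize_pass')
# unordered_set<std::string> overload:
quantize_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})
# VarQuantScale overload: variable name -> (is_unsigned, scale LoDTensor);
# the commit's last message item names the bool explicitly as is_unsigned=False.
scale = core.LoDTensor()
scale.set(np.array([0.125], dtype='float32'), core.CPUPlace())
quantize_pass.set('quant_var_scales', {'conv2d_0.tmp_0': (False, scale)})
```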
@@ -32,6 +32,20 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
           --acc_diff_threshold 0.1)
 endfunction()

+function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
+    py_test(${target} SRCS ${test_script}
+            ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
+                 OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
+                 FLAGS_use_mkldnn=${use_mkldnn}
+            ARGS --qat_model ${model_dir}/float
+                 --infer_data ${data_dir}/data.bin
+                 --batch_size 25
+                 --batch_num 2
+                 --acc_diff_threshold 0.1
+                 --qat2)
+endfunction()
+
 if(WIN32)
   list(REMOVE_ITEM TEST_OPS test_light_nas)
 endif()
@@ -142,6 +156,19 @@ if(LINUX AND WITH_MKLDNN)
     inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
   endif()
   inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
+  if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
+    inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
+  endif()
+  inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
+
+  set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
+  if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
+    inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
+  endif()
+  inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
 endif()

 # Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
...
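Assuming a standard Linux build with `-DWITH_MKLDNN=ON`, the newly registered tests could then be run individually from the build tree, e.g.:

```bash
cd build
ctest -R "test_qat2_int8_(resnet50|mobilenetv1)_mkldnn" -V
```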
@@ -6,11 +6,11 @@ This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle
 You need to install at least PaddlePaddle-1.5 python package `pip install paddlepaddle==1.5`.

 ## 1. How to generate INT8 MKL-DNN QAT model

-You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `TransformForMkldnnPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
+You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `FakeQAT2MkldnnINT8KernelPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.

 ```python
 import paddle.fluid as fluid
-from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
+from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
 from paddle.fluid.framework import IrGraph
 from paddle.fluid import core

@@ -18,9 +18,9 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
 graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
 place = fluid.CPUPlace()
 # Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using
-# TransformForMkldnnPass
-mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
-# Apply TransformForMkldnnPass to IrGraph
+# FakeQAT2MkldnnINT8KernelPass
+mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(), place)
+# Apply FakeQAT2MkldnnINT8KernelPass to IrGraph
 mkldnn_pass.apply(graph)
 ```
...
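The commit also introduces a second-generation pass, `FakeQAT2MkldnnINT8PerfPass`, which this README does not yet cover. A sketch of its use, with the constructor arguments taken from the test changes below (graph setup mirrors the README example; treat it as illustrative, not documented API):

```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core

graph = IrGraph(core.Graph(fluid.Program().desc), for_test=True)
place = fluid.CPUPlace()
perf_pass = FakeQAT2MkldnnINT8PerfPass(
    _scope=fluid.global_scope(), _place=place, _core=core, _debug=False)
graph = perf_pass.apply(graph)  # returns the transformed graph
```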
@@ -24,7 +24,8 @@ import time
 import paddle
 import paddle.fluid as fluid
 from paddle.fluid.framework import IrGraph
-from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
+from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
+from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
 from paddle.fluid import core

 logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
@@ -41,8 +42,21 @@ def parse_args():
         default=0,
         help='Number of the first minibatches to skip in performance statistics.'
     )
+    parser.add_argument(
+        '--debug',
+        action='store_true',
+        help='If used, the graph of QAT model is drawn.')
     parser.add_argument(
         '--qat_model', type=str, default='', help='A path to a QAT model.')
+    parser.add_argument(
+        '--qat2',
+        action='store_true',
+        help='If used, the QAT model is treated as a second generation model for performance optimization.'
+    )
+    parser.add_argument(
+        '--save_model',
+        action='store_true',
+        help='If used, the QAT model will be saved after all transformations')
     parser.add_argument('--infer_data', type=str, default='', help='Data file.')
     parser.add_argument(
         '--batch_num',
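Taken together with the CMake `ARGS` above, a full invocation of the comparison test might look like this (script name and paths are placeholders):

```bash
python qat_int8_comparison.py \
    --qat_model=/data/ResNet50_qat_perf/float \
    --infer_data=/data/ILSVRC2012/data.bin \
    --batch_size=25 \
    --batch_num=2 \
    --acc_diff_threshold=0.1 \
    --qat2 --debug --save_model
```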
@@ -164,12 +178,24 @@ class TestQatInt8Comparison(unittest.TestCase):
                 model_path, exe, 'model', 'params')
             graph = IrGraph(core.Graph(inference_program.desc), for_test=True)

+            if (self._debug):
+                graph.draw('.', 'qat_orig', graph.all_op_nodes())
             if (transform_to_int8):
-                mkldnn_int8_pass = TransformForMkldnnPass(
-                    scope=inference_scope, place=place)
-                mkldnn_int8_pass.apply(graph)
+                if (test_case_args.qat2):
+                    transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
+                        _scope=inference_scope,
+                        _place=place,
+                        _core=core,
+                        _debug=self._debug)
+                    graph = transform_to_mkldnn_int8_pass.apply(graph)
+                else:
+                    mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
+                        _scope=inference_scope, _place=place)
+                    graph = mkldnn_int8_pass.apply(graph)
             else:
                 graph = self._prepare_for_fp32_mkldnn(graph)

             inference_program = graph.to_program()

             dshape = [3, 224, 224]
@@ -209,7 +235,7 @@ class TestQatInt8Comparison(unittest.TestCase):
                 samples = len(data)
                 total_samples += samples
                 batch_times.append(batch_time)
-                fps = samples / batch_time
+                fps = samples / batch_time * 1000
                 fpses.append(fps)
                 iters += 1
                 appx = ' (warm-up)' if iters <= skip_batch_num else ''
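The `* 1000` factor is the "fixed fps results" item from the commit message: `batch_time` is evidently recorded in milliseconds, so dividing samples by it gives samples per millisecond, and the scale converts that to frames per second. For example:

```python
# FPS fix in context; the millisecond unit of batch_time is an assumption.
samples = 25         # images in the batch
batch_time = 125.0   # elapsed time in ms
fps = samples / batch_time * 1000   # 25 / 125 * 1000 = 200 frames per second
```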
@@ -230,6 +256,12 @@ class TestQatInt8Comparison(unittest.TestCase):
         _logger.info('Total inference run time: {:.2f} s'.format(
             infer_total_time))

+        if test_case_args.save_model:
+            with fluid.scope_guard(inference_scope):
+                fluid.io.save_inference_model(
+                    'transformed_qat_int8_model', feed_target_names,
+                    fetch_targets, exe, inference_program)
+
         return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg

     def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
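A model saved via `--save_model` can later be reloaded with the standard inference API; a minimal sketch (the directory name matches the `save_inference_model` call above):

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Load the transformed INT8 model saved by the test run.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    'transformed_qat_int8_model', exe)
```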
@@ -265,6 +297,7 @@ class TestQatInt8Comparison(unittest.TestCase):
         batch_num = test_case_args.batch_num
         skip_batch_num = test_case_args.skip_batch_num
         acc_diff_threshold = test_case_args.acc_diff_threshold
+        self._debug = test_case_args.debug

         _logger.info('QAT FP32 & INT8 prediction run.')
         _logger.info('QAT model: {0}'.format(qat_model_path))
@@ -283,7 +316,6 @@ class TestQatInt8Comparison(unittest.TestCase):
             batch_num,
             skip_batch_num,
             transform_to_int8=False)
-
         _logger.info('--- QAT INT8 prediction start ---')
         val_reader = paddle.batch(
             self._reader_creator(data_path), batch_size=batch_size)
...
@@ -22,7 +22,7 @@ import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
 from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
-from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
+from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
 from paddle.fluid import core

 os.environ["CPU_NUM"] = "1"
@@ -90,6 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
                            seed,
                            activation_quant_type,
                            weight_quant_type='abs_max',
+                           qat_perf=False,
                            for_ci=False):
         random.seed(0)
         np.random.seed(0)
@@ -148,7 +149,8 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
         freeze_pass.apply(test_graph)

         # Transform quantized graph for MKL-DNN INT8 inference
-        mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
+        mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
+            _scope=scope, _place=place)
         mkldnn_int8_pass.apply(test_graph)

         dev_name = '_cpu_'
         if not for_ci:
...
@@ -2416,6 +2416,20 @@ class IrOpNode(IrNode):
             "The node operator description cannot be None."
         self.node.op()._rename_input(old_input_name, new_input_name)

+    def rename_output(self, old_output_name, new_output_name):
+        """
+        Rename the output of this node.
+
+        Args:
+            old_output_name(str): the old output name.
+            new_output_name(str): the new output name.
+        """
+        assert self.node.op() is not None, \
+            "The node operator description cannot be None."
+        print("op: {}, old: {}, new: {}\n".format(self.node.op().type(
+        ), old_output_name, new_output_name))
+        self.node.op()._rename_output(old_output_name, new_output_name)
+
     def input(self, name):
         """
         Get the argument name list by the parameter name for input.
@@ -2709,6 +2723,24 @@ class IrGraph(object):
         op_node.append_input(new_input_node)
         op_node.rename_input(old_input_node.name(), new_input_node.name())

+    def update_output_link(self, old_output_node, new_output_node, op_node):
+        """
+        Update the output's link of an operator node.
+
+        Args:
+            old_output_node(IrNode): the old output node of the given op_node.
+            new_output_node(IrNode): the new output node of the given op_node.
+            op_node(IrOpNode): the operator node that is needed to update output's link.
+        """
+        assert old_output_node.node in self.graph.nodes() and new_output_node.node in \
+            self.graph.nodes() and op_node.node in self.graph.nodes(), \
+            'The three arguments (old_output_node, new_output_node and op_node) must be in the graph nodes.'
+        old_output_node.remove_input(op_node)
+        op_node.remove_output(old_output_node)
+        new_output_node.append_input(op_node)
+        op_node.append_output(new_output_node)
+        op_node.rename_output(old_output_node.name(), new_output_node.name())
+
     def link_to(self, node_in, node_out):
         """
         Connect two nodes.
...
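A hedged sketch of how a transformation pass might use the new helper to swap an operator's output variable (the variable name and the `create_var_node` arguments are illustrative, and `old_out`/`op_node` are assumed to already exist in the graph):

```python
# Inside a hypothetical pass operating on an IrGraph `graph`:
# detach op_node from its current output variable and attach a fresh one.
new_out = graph.create_var_node(
    name='quantized_out.tmp',   # illustrative name
    var_type=old_out.type(),
    shape=old_out.shape(),
    var_dtype=old_out.dtype())
graph.update_output_link(old_out, new_out, op_node)
# update_output_link unlinks old_out from op_node, links new_out in its
# place, and renames the output argument in the op desc via rename_output().
```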