Commit 4286a627 authored by Wojciech Uss, committed by Tao Luo

Add support for new QAT models (#18970)

* Add support for new QAT models

test=develop
Co-Authored-By: Michał Gallus <michal.gallus@intel.com>
Co-Authored-By: Wojciech Uss <wojciech.uss@intel.com>

* fixed fps results

test=develop

* fix top5 accuracy drop problem

* updated for new QAT models

* skip quantizing average pooling - dirty but working

* add missing pass

* added missing conv+brelu fuse pass

* removed a call to non-existent pass

test=develop

* renamed pass

test=develop

* Adjust finding pooling scale to newest QAT models

* Remove unnecessary code from quantization_mkldnn_pass

* Copy Pooling input scale to output scale in QAT

* Refactor & remove unused code in QAT

* Incorporate fp32 FC into QAT

test=develop

* Enable graph drawing with debug flag

test=develop

* Add tests for QATv2

* Fix paths for QATv2 models

test=develop

* Add option to save transformed int8 qat model

test=develop

* Remove redundant lines from qat mkldnn pass

test=develop

* Delegate disablement of avg pooling to qat

test=develop

* fix CI bug, test=develop

* Follow Wangzhen's Review, test=develop

* Update API.spec

test=develop

* Name False in (is_unsigned, TensorScale) tuple

test=develop
Parent 99a9615a
......@@ -219,6 +219,10 @@ PYBIND11_MODULE(core_noavx, m) {
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place);
})
.def("_alloc_double",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<double>(place);
})
.def("_alloc_int",
[](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place);
......@@ -1154,6 +1158,9 @@ All parameter, weight, gradient are variables in Paddle.
m.def("size_of_dtype", framework::SizeOfType);
using VarQuantScale =
std::unordered_map<std::string, std::pair<bool, LoDTensor>>;
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
......@@ -1168,6 +1175,20 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<std::string> set) {
self.Set(name, new std::unordered_set<std::string>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name,
std::unordered_set<int> set) {
self.Set(name, new std::unordered_set<int>(set));
})
.def("set",
[](ir::Pass &self, const std::string &name, VarQuantScale scales) {
self.Set(name, new VarQuantScale(scales));
})
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
self.Apply(graph.get());
......
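The new `Pass.set` overloads above make it possible to hand string sets, int sets, and a `VarQuantScale` map (variable name -> (is_unsigned, scale tensor)) to an IR pass from Python. Below is a minimal sketch of such usage, not part of this diff; it assumes a helper like `core.get_pass` is available for retrieving a registered pass, and the pass/attribute/variable names are illustrative only.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid import core

place = fluid.CPUPlace()

# Build a double-precision scale tensor; the new _alloc_double binding above
# exists so that double tensors can also be allocated explicitly from Python.
scale_tensor = core.LoDTensor()
scale_tensor.set(np.array([0.0352], dtype=np.float64), place)

# VarQuantScale is unordered_map<string, pair<bool, LoDTensor>> on the C++ side,
# i.e. a dict mapping a variable name to an (is_unsigned, scale tensor) tuple.
var_quant_scales = {'conv2d_0.tmp_0': (False, scale_tensor)}

# Assumption: core.get_pass() returns a bound ir::Pass registered under this name.
quantize_pass = core.get_pass('cpu_quantize_pass')
quantize_pass.set('quant_var_scales', var_quant_scales)  # VarQuantScale overload

placement_pass = core.get_pass('cpu_quantize_placement_pass')
placement_pass.set('quantize_enabled_op_types', {'conv2d', 'pool2d'})  # unordered_set<string>
placement_pass.set('quantize_excluded_op_ids', {0})                    # unordered_set<int>
```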
......@@ -32,6 +32,20 @@ function(inference_qat_int8_test target model_dir data_dir test_script use_mkldn
--acc_diff_threshold 0.1)
endfunction()
function(inference_qat2_int8_test target model_dir data_dir test_script use_mkldnn)
py_test(${target} SRCS ${test_script}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
FLAGS_use_mkldnn=${use_mkldnn}
ARGS --qat_model ${model_dir}/float
--infer_data ${data_dir}/data.bin
--batch_size 25
--batch_num 2
--acc_diff_threshold 0.1
--qat2)
endfunction()
if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
endif()
......@@ -142,6 +156,19 @@ if(LINUX AND WITH_MKLDNN)
inference_download_and_uncompress(${QAT_VGG19_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "VGG19_qat_model.tar.gz" )
endif()
inference_qat_int8_test(test_qat_int8_vgg19_mkldnn ${QAT_VGG19_MODEL_DIR} ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
if (NOT EXISTS ${QAT2_RESNET50_MODEL_DIR})
inference_download_and_uncompress(${QAT2_RESNET50_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "ResNet50_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_resnet50_mkldnn ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
set(QAT2_MOBILENETV1_MODEL_DIR "${QAT_DATA_DIR}/MobileNet_qat_perf")
if (NOT EXISTS ${QAT2_MOBILENETV1_MODEL_DIR})
inference_download_and_uncompress(${QAT2_MOBILENETV1_MODEL_DIR} "${QAT_MODELS_BASE_URL}" "MobileNet_qat_perf.tar.gz" )
endif()
inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true)
endif()
# Since the test for QAT FP32 & INT8 comparison supports only testing on Linux
......
......@@ -6,11 +6,11 @@ This document describes how to use [Paddle Slim](https://github.com/PaddlePaddle
You need to install at least PaddlePaddle-1.5 python package `pip install paddlepaddle==1.5`.
## 1. How to generate INT8 MKL-DNN QAT model
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `TransformForMkldnnPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quantization_mkldnn_pass.py). Users firstly use PaddleSlim quantization strategy to get a saved fake QAT model by [QuantizationFreezePass](https://github.com/PaddlePaddle/models/tree/develop/PaddleSlim/quant_low_level_api), then use the `FakeQAT2MkldnnINT8KernelPass` to get the graph which can be run with MKL-DNN INT8 kernel. In Paddle Release 1.5, this pass only supports `conv2d` and `depthwise_conv2d` with channel-wise quantization for weights.
```python
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
......@@ -18,9 +18,9 @@ You can refer to the unit test in [test_quantization_mkldnn_pass.py](test_quanti
graph = IrGraph(core.Graph(fluid.Program().desc), for_test=False)
place = fluid.CPUPlace()
# Convert the IrGraph to MKL-DNN supported INT8 IrGraph by using
# TransformForMkldnnPass
mkldnn_pass = TransformForMkldnnPass(fluid.global_scope(), place)
# Apply TransformForMkldnnPass to IrGraph
# FakeQAT2MkldnnINT8KernelPass
mkldnn_pass = FakeQAT2MkldnnINT8KernelPass(fluid.global_scope(), place)
# Apply FakeQAT2MkldnnINT8KernelPass to IrGraph
mkldnn_pass.apply(graph)
```
......
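For the second-generation (QATv2) models added in this change, the transformation is performed by `FakeQAT2MkldnnINT8PerfPass` instead. The following is a hedged sketch of that flow, mirroring its usage in `qat_int8_comparison.py` below; the model path is hypothetical.

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass

place = fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.global_scope()

with fluid.scope_guard(scope):
    # Load the saved fake-quantized (QATv2) float model; the path is hypothetical.
    [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
        'path/to/qat2_float_model', exe, 'model', 'params')

    graph = IrGraph(core.Graph(program.desc), for_test=True)
    perf_pass = FakeQAT2MkldnnINT8PerfPass(
        _scope=scope, _place=place, _core=core, _debug=False)
    graph = perf_pass.apply(graph)       # returns the transformed graph
    int8_program = graph.to_program()    # runnable with MKL-DNN INT8 kernels
```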
......@@ -24,7 +24,8 @@ import time
import paddle
import paddle.fluid as fluid
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass
from paddle.fluid import core
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
......@@ -41,8 +42,21 @@ def parse_args():
default=0,
help='Number of the first minibatches to skip in performance statistics.'
)
parser.add_argument(
'--debug',
action='store_true',
help='If used, the graph of QAT model is drawn.')
parser.add_argument(
'--qat_model', type=str, default='', help='A path to a QAT model.')
parser.add_argument(
'--qat2',
action='store_true',
help='If used, the QAT model is treated as a second generation model for performance optimization.'
)
parser.add_argument(
'--save_model',
action='store_true',
help='If used, the QAT model will be saved after all transformations')
parser.add_argument('--infer_data', type=str, default='', help='Data file.')
parser.add_argument(
'--batch_num',
......@@ -164,12 +178,24 @@ class TestQatInt8Comparison(unittest.TestCase):
model_path, exe, 'model', 'params')
graph = IrGraph(core.Graph(inference_program.desc), for_test=True)
if (self._debug):
graph.draw('.', 'qat_orig', graph.all_op_nodes())
if (transform_to_int8):
mkldnn_int8_pass = TransformForMkldnnPass(
scope=inference_scope, place=place)
mkldnn_int8_pass.apply(graph)
if (test_case_args.qat2):
transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass(
_scope=inference_scope,
_place=place,
_core=core,
_debug=self._debug)
graph = transform_to_mkldnn_int8_pass.apply(graph)
else:
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=inference_scope, _place=place)
graph = mkldnn_int8_pass.apply(graph)
else:
graph = self._prepare_for_fp32_mkldnn(graph)
inference_program = graph.to_program()
dshape = [3, 224, 224]
......@@ -209,7 +235,7 @@ class TestQatInt8Comparison(unittest.TestCase):
samples = len(data)
total_samples += samples
batch_times.append(batch_time)
fps = samples / batch_time
fps = samples / batch_time * 1000
fpses.append(fps)
iters += 1
appx = ' (warm-up)' if iters <= skip_batch_num else ''
......@@ -230,6 +256,12 @@ class TestQatInt8Comparison(unittest.TestCase):
_logger.info('Total inference run time: {:.2f} s'.format(
infer_total_time))
if test_case_args.save_model:
with fluid.scope_guard(inference_scope):
fluid.io.save_inference_model(
'transformed_qat_int8_model', feed_target_names,
fetch_targets, exe, inference_program)
return outputs, acc1_avg, acc5_avg, fps_avg, latency_avg
def _summarize_performance(self, fp32_fps, fp32_lat, int8_fps, int8_lat):
......@@ -265,6 +297,7 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num = test_case_args.batch_num
skip_batch_num = test_case_args.skip_batch_num
acc_diff_threshold = test_case_args.acc_diff_threshold
self._debug = test_case_args.debug
_logger.info('QAT FP32 & INT8 prediction run.')
_logger.info('QAT model: {0}'.format(qat_model_path))
......@@ -283,7 +316,6 @@ class TestQatInt8Comparison(unittest.TestCase):
batch_num,
skip_batch_num,
transform_to_int8=False)
_logger.info('--- QAT INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path), batch_size=batch_size)
......
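When `--save_model` is passed, the transformed INT8 program is persisted to the `transformed_qat_int8_model` directory via `fluid.io.save_inference_model`, as shown above. A small sketch of loading it back for inspection, assuming that directory exists:

```python
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Directory name matches the one used by --save_model above.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    'transformed_qat_int8_model', exe)
print('feed targets:', feed_names)
print('fetch targets:', [t.name for t in fetch_targets])
```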
......@@ -22,7 +22,7 @@ import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import TransformForMkldnnPass
from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8KernelPass
from paddle.fluid import core
os.environ["CPU_NUM"] = "1"
......@@ -90,6 +90,7 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
seed,
activation_quant_type,
weight_quant_type='abs_max',
qat_perf=False,
for_ci=False):
random.seed(0)
np.random.seed(0)
......@@ -148,7 +149,8 @@ class TestMKLDNNTransformBasedFreezePass(unittest.TestCase):
freeze_pass.apply(test_graph)
# Transform quantized graph for MKL-DNN INT8 inference
mkldnn_int8_pass = TransformForMkldnnPass(scope=scope, place=place)
mkldnn_int8_pass = FakeQAT2MkldnnINT8KernelPass(
_scope=scope, _place=place)
mkldnn_int8_pass.apply(test_graph)
dev_name = '_cpu_'
if not for_ci:
......
......@@ -2416,6 +2416,20 @@ class IrOpNode(IrNode):
"The node operator description cannot be None."
self.node.op()._rename_input(old_input_name, new_input_name)
def rename_output(self, old_output_name, new_output_name):
"""
Rename the output of this node.
Args:
old_output_name(str): the old output name.
new_output_name(str): the new output name.
"""
assert self.node.op() is not None, \
"The node operator description cannot be None."
print("op: {}, old: {}, new: {}\n".format(self.node.op().type(
), old_output_name, new_output_name))
self.node.op()._rename_output(old_output_name, new_output_name)
def input(self, name):
"""
Get the argument name list by the parameter name for input.
......@@ -2709,6 +2723,24 @@ class IrGraph(object):
op_node.append_input(new_input_node)
op_node.rename_input(old_input_node.name(), new_input_node.name())
def update_output_link(self, old_output_node, new_output_node, op_node):
"""
Update the output's link of an operator node.
Args:
old_output_node(IrNode): the old output node of the giving op_node.
new_output_node(IrNode): the new output node of the giving op_node.
op_node(IrOpNode): the operator node whose output link needs to be updated.
"""
assert old_output_node.node in self.graph.nodes() and new_output_node.node in \
self.graph.nodes() and op_node.node in self.graph.nodes(), \
'The three arguments (old_output_node, new_output_node and op_node) must be in the graph nodes.'
old_output_node.remove_input(op_node)
op_node.remove_output(old_output_node)
new_output_node.append_input(op_node)
op_node.append_output(new_output_node)
op_node.rename_output(old_output_node.name(), new_output_node.name())
def link_to(self, node_in, node_out):
"""
Connect two nodes.
......
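The new `IrOpNode.rename_output` and `IrGraph.update_output_link` mirror the existing `rename_input`/`update_input_link` helpers. Below is a hedged toy example, not part of this diff, of rewiring an op's output to a freshly created variable node; the program, the choice of the 'mul' op, and the new variable name are illustrative.

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

# Build a tiny program whose fc layer produces a 'mul' op we can rewire.
program = fluid.Program()
with fluid.program_guard(program):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=2)

graph = IrGraph(core.Graph(program.desc), for_test=True)
op_node = [op for op in graph.all_op_nodes() if op.name() == 'mul'][0]
old_out = op_node.outputs[0]

# Create a replacement output variable node; the name is hypothetical.
new_out = graph.create_var_node(
    name='mul_out_rewired',
    var_type=core.VarDesc.VarType.LOD_TENSOR,
    shape=old_out.shape(),
    var_dtype=core.VarDesc.VarType.FP32)

# Detaches old_out from op_node, attaches new_out, and renames the output
# argument inside the op desc via the new IrOpNode.rename_output().
graph.update_output_link(old_out, new_out, op_node)
```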