Commit 993c703b authored by 翟飞跃, committed by Tao Luo

Integrate INT8 MKL-DNN v2 into slim (#17634)

* refactor PR 16865

* delete mergetool files

* test=develop

* test=develop

* test=develop

* test=develop

* create dir for int8 model before calling SaveOptimModel

* test=develop

* mkldnn int8 only supports Linux; test=develop

* refine code; test=develop

* remove comment; test=develop

* refine code; test=develop

* fix bug; test=develop

* add exception for mkldnn_post_training_strategy

* reuse int8v2 CAPI dataset; test=develop

* fix accuracy check bug; test=develop

* remove tab

* convert files to unix format

* test=develop

* reduce CI time;test=develop

* reduce CI time and refine code;test=develop

* refine comment; test=develop

* add cmake FLAGS;test=develop

* remove predict_num;test=develop
Parent 6a1df469
@@ -182,11 +182,10 @@ void AnalysisConfig::EnableNgraph() {
#endif
}
std::shared_ptr<MkldnnQuantizerConfig> AnalysisConfig::mkldnn_quantizer_config()
const {
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
"MkldnnQuantizer was not enabled yet.");
return mkldnn_quantizer_config_;
return mkldnn_quantizer_config_.get();
}
void AnalysisConfig::EnableTensorRtEngine(
......
@@ -260,7 +260,7 @@ class MkldnnQuantizerTest : public testing::Test {
predictor.reset(new AnalysisPredictor(config));
auto* predictor_p = static_cast<AnalysisPredictor*>(predictor.get());
auto qconfig = std::make_shared<MkldnnQuantizerConfig>();
auto qconfig = new MkldnnQuantizerConfig();
mkldnn_quantizer.reset(
new AnalysisPredictor::MkldnnQuantizer(*predictor_p, qconfig));
......
@@ -45,9 +45,8 @@ using VarQuantScale =
class AnalysisPredictor::MkldnnQuantizer {
public:
explicit MkldnnQuantizer(
AnalysisPredictor& predictor, // NOLINT
const std::shared_ptr<MkldnnQuantizerConfig>& qconfig)
explicit MkldnnQuantizer(AnalysisPredictor& predictor, // NOLINT
const MkldnnQuantizerConfig* qconfig)
: predictor_(predictor), qconfig_(qconfig) {}
// Execute full quantization procedure.
@@ -95,7 +94,7 @@ class AnalysisPredictor::MkldnnQuantizer {
private:
AnalysisPredictor& predictor_;
const std::shared_ptr<MkldnnQuantizerConfig> qconfig_;
const MkldnnQuantizerConfig* qconfig_;
// A map: variable name -> scale
VarQuantScale scales_;
......
@@ -210,7 +210,7 @@ struct AnalysisConfig {
*/
bool mkldnn_quantizer_enabled() const { return use_mkldnn_quantizer_; }
std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config() const;
MkldnnQuantizerConfig* mkldnn_quantizer_config() const;
/** Specify the memory buffer of program and parameter
* @param prog_buffer the memory buffer of program.
......
@@ -17,7 +17,9 @@
#include <cstring>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
@@ -45,6 +47,10 @@ static void BindNativePredictor(py::module *m);
static void BindAnalysisConfig(py::module *m);
static void BindAnalysisPredictor(py::module *m);
#ifdef PADDLE_WITH_MKLDNN
static void BindMkldnnQuantizerConfig(py::module *m);
#endif
void BindInferenceApi(py::module *m) {
BindPaddleDType(m);
BindPaddleBuf(m);
@@ -55,7 +61,9 @@ void BindInferenceApi(py::module *m) {
BindNativePredictor(m);
BindAnalysisConfig(m);
BindAnalysisPredictor(m);
#ifdef PADDLE_WITH_MKLDNN
BindMkldnnQuantizerConfig(m);
#endif
m->def("create_paddle_predictor",
&paddle::CreatePaddlePredictor<AnalysisConfig>);
m->def("create_paddle_predictor",
@@ -249,6 +257,11 @@ void BindAnalysisConfig(py::module *m) {
.def("cpu_math_library_num_threads",
&AnalysisConfig::cpu_math_library_num_threads)
.def("to_native_config", &AnalysisConfig::ToNativeConfig)
.def("enable_quantizer", &AnalysisConfig::EnableMkldnnQuantizer)
#ifdef PADDLE_WITH_MKLDNN
.def("quantizer_config", &AnalysisConfig::mkldnn_quantizer_config,
py::return_value_policy::reference)
#endif
.def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
.def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
.def("model_from_memory", &AnalysisConfig::model_from_memory)
@@ -256,6 +269,28 @@
py::return_value_policy::reference);
}
#ifdef PADDLE_WITH_MKLDNN
void BindMkldnnQuantizerConfig(py::module *m) {
py::class_<MkldnnQuantizerConfig> quantizer_config(*m,
"MkldnnQuantizerConfig");
quantizer_config.def(py::init<const MkldnnQuantizerConfig &>())
.def(py::init<>())
.def("set_quant_data",
[](MkldnnQuantizerConfig &self,
const std::vector<PaddleTensor> &data) {
auto warmup_data =
std::make_shared<std::vector<PaddleTensor>>(data);
self.SetWarmupData(warmup_data);
return;
})
.def("set_quant_batch_size", &MkldnnQuantizerConfig::SetWarmupBatchSize)
.def(
"set_enabled_op_types",
(void (MkldnnQuantizerConfig::*)(std::unordered_set<std::string> &)) &
MkldnnQuantizerConfig::SetEnabledOpTypes);
}
#endif
void BindAnalysisPredictor(py::module *m) {
py::class_<AnalysisPredictor, PaddlePredictor>(*m, "AnalysisPredictor")
.def(py::init<const AnalysisConfig &>())
@@ -272,7 +307,9 @@
.def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
.def("clone", &AnalysisPredictor::Clone)
.def("scope", &AnalysisPredictor::scope,
py::return_value_policy::reference);
py::return_value_policy::reference)
.def("SaveOptimModel", &AnalysisPredictor::SaveOptimModel,
py::arg("dir"));
}
} // namespace pybind
......
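On the Python side, the BindAnalysisConfig and BindMkldnnQuantizerConfig bindings above are driven through paddle.fluid.core. A minimal sketch of that flow, mirroring the strategy code further below (the model directory is a placeholder and warmup_tensors stands for an already prepared list of core.PaddleTensor):

from paddle.fluid import core

config = core.AnalysisConfig("AnalysisConfig")
config.disable_gpu()
config.enable_mkldnn()
config.set_model("/path/to/fp32_model")  # placeholder model directory
config.enable_quantizer()  # bound to AnalysisConfig::EnableMkldnnQuantizer
config.quantizer_config().set_quant_data(warmup_tensors)  # warmup_tensors: list of core.PaddleTensor
config.quantizer_config().set_quant_batch_size(100)  # number of samples in the warmup batch
predictor = core.create_paddle_predictor(config)  # INT8 MKL-DNN quantization runs while the predictor is built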
@@ -467,6 +467,10 @@ class Compressor(object):
for strategy in self.strategies:
strategy.on_compression_begin(context)
if 'MKLDNNPostTrainingQuantStrategy' in [
i.__class__.__name__ for i in self.strategies
]:
return None
start = context.epoch_id
self._eval(context)
for epoch in range(start, self.epoch):
......
@@ -18,5 +18,8 @@ from . import quantization_pass
from .quantization_pass import *
from . import quantization_strategy
from .quantization_strategy import *
from . import mkldnn_post_training_strategy
from .mkldnn_post_training_strategy import *
__all__ = quantization_pass.__all__ + quantization_strategy.__all__
__all__ += mkldnn_post_training_strategy.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import logging
import six
import numpy as np
from .... import core
from ..core.strategy import Strategy
__all__ = ['MKLDNNPostTrainingQuantStrategy']
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class MKLDNNPostTrainingQuantStrategy(Strategy):
"""
The MKL-DNN post-training quantization strategy.
"""
def __init__(self,
int8_model_save_path=None,
fp32_model_path=None,
cpu_math_library_num_threads=1):
"""
Args:
int8_model_save_path(str): The path used to save an int8 ProgramDesc with fp32 weights,
which is used for MKL-DNN int8 inference. For post-training quantization,
MKLDNNPostTrainingQuantStrategy currently only supports converting an fp32 ProgramDesc
with fp32 weights into an int8 ProgramDesc with fp32 weights. The saved
int8 ProgramDesc with fp32 weights can only be executed with MKL-DNN enabled.
None means the int8 ProgramDesc with fp32 weights is not saved. Default: None.
fp32_model_path(str): The path used to load the original fp32 ProgramDesc with fp32 weights.
It must be provided; passing None raises an exception. Default: None.
cpu_math_library_num_threads(int): The number of cpu math library threads used by
MKLDNNPostTrainingQuantStrategy. 1 means only one cpu math library
thread is used. Default: 1
"""
super(MKLDNNPostTrainingQuantStrategy, self).__init__(0, 0)
self.int8_model_save_path = int8_model_save_path
if fp32_model_path is None:
raise Exception("fp32_model_path is None")
self.fp32_model_path = fp32_model_path
self.cpu_math_library_num_threads = cpu_math_library_num_threads
def on_compression_begin(self, context):
"""
Prepare the warmup data and quantize the model
"""
super(MKLDNNPostTrainingQuantStrategy,
self).on_compression_begin(context)
_logger.info('MKLDNNPostTrainingQuantStrategy::on_compression_begin')
# Prepare the Analysis Config
infer_config = core.AnalysisConfig("AnalysisConfig")
infer_config.switch_ir_optim(True)
infer_config.disable_gpu()
infer_config.set_model(self.fp32_model_path)
infer_config.enable_mkldnn()
infer_config.set_cpu_math_library_num_threads(
self.cpu_math_library_num_threads)
# Prepare the data for calculating the quantization scales
warmup_reader = context.eval_reader()
if six.PY2:
data = warmup_reader.next()
if six.PY3:
data = warmup_reader.__next__()
# TODO(Intel): Remove the limitation that MKLDNNPostTrainingQuantStrategy
# only supports image classification
num_images = len(data)
images = core.PaddleTensor()
images.name = "x"
images.shape = [num_images, ] + list(data[0][0].shape)
images.dtype = core.PaddleDType.FLOAT32
image_data = [img.tolist() for (img, _) in data]
image_data = np.array(image_data).astype("float32")
image_data = image_data.ravel()
images.data = core.PaddleBuf(image_data.tolist())
labels = core.PaddleTensor()
labels.name = "y"
labels.shape = [num_images, 1]
labels.dtype = core.PaddleDType.INT64
label_data = [label for (_, label) in data]
labels.data = core.PaddleBuf(label_data)
warmup_data = [images, labels]
# Enable the INT8 Quantization
infer_config.enable_quantizer()
infer_config.quantizer_config().set_quant_data(warmup_data)
infer_config.quantizer_config().set_quant_batch_size(num_images)
# Run INT8 MKL-DNN Quantization
predictor = core.create_paddle_predictor(infer_config)
if self.int8_model_save_path:
if not os.path.exists(self.int8_model_save_path):
os.makedirs(self.int8_model_save_path)
predictor.SaveOptimModel(self.int8_model_save_path)
_logger.info(
'Finish MKLDNNPostTrainingQuantStrategy::on_compression_begin')
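The constructor docstring above describes the three arguments; a minimal construction sketch (the import path is assumed from the quantization package __init__ edited earlier, and both paths are placeholders):

from paddle.fluid.contrib.slim.quantization import MKLDNNPostTrainingQuantStrategy

strategy = MKLDNNPostTrainingQuantStrategy(
    int8_model_save_path='/path/to/int8_out',  # where the int8 ProgramDesc (fp32 weights) is written
    fp32_model_path='/path/to/fp32_model',     # original fp32 ProgramDesc with fp32 weights
    cpu_math_library_num_threads=4)

In practice the strategy is not constructed by hand but declared in a YAML config (see config_mkldnn_int8.yaml below) and triggered through Compressor.run(), which calls on_compression_begin.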
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
function(inference_analysis_python_api_int8_test target model_dir data_dir filename)
py_test(${target} SRCS ${filename}
ENVS FLAGS_OMP_NUM_THREADS=${CPU_NUM_THREADS_ON_CI}
ARGS --infer_model ${model_dir}/model
--infer_data ${data_dir}/data.bin
--int8_model_save_path int8_models/${target}
--warmup_batch_size 100
--batch_size 50)
endfunction()
# NOTE: TODO
# Temporarily disable test_distillation_strategy since it always fails on a specific machine with 4 GPUs.
# Need to figure out the root cause and then add it back.
list(REMOVE_ITEM TEST_OPS test_distillation_strategy)
# int8 image classification python api test
if(LINUX AND WITH_MKLDNN)
set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
set(MKLDNN_INT8_TEST_FILE "test_mkldnn_int8_quantization_strategy.py")
# googlenet int8
set(INT8_GOOGLENET_MODEL_DIR "${INT8_DATA_DIR}/googlenet")
inference_analysis_python_api_int8_test(test_slim_int8_googlenet ${INT8_GOOGLENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# mobilenet int8
set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# The WITH_SLIM_MKLDNN_FULL_TEST flag is added temporarily so that QA can run the following UTs locally,
# since these UTs cost too much time on CI (an example invocation is sketched after this block).
if (WITH_SLIM_MKLDNN_FULL_TEST)
# resnet50 int8
set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
inference_analysis_python_api_int8_test(test_slim_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# mobilenetv2 int8
set(INT8_MOBILENETV2_MODEL_DIR "${INT8_DATA_DIR}/mobilenetv2")
inference_analysis_python_api_int8_test(test_slim_int8_mobilenetv2 ${INT8_MOBILENETV2_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# resnet101 int8
set(INT8_RESNET101_MODEL_DIR "${INT8_DATA_DIR}/resnet101")
inference_analysis_python_api_int8_test(test_slim_int8_resnet101 ${INT8_RESNET101_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# vgg16 int8
set(INT8_VGG16_MODEL_DIR "${INT8_DATA_DIR}/vgg16")
inference_analysis_python_api_int8_test(test_slim_int8_vgg16 ${INT8_VGG16_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
# vgg19 int8
set(INT8_VGG19_MODEL_DIR "${INT8_DATA_DIR}/vgg19")
inference_analysis_python_api_int8_test(test_slim_int8_vgg19 ${INT8_VGG19_MODEL_DIR} ${INT8_DATA_DIR} ${MKLDNN_INT8_TEST_FILE})
endif()
endif()
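A hedged example of how the flag above might be toggled for a local run (the exact cmake/ctest invocation is an assumption; only WITH_MKLDNN and WITH_SLIM_MKLDNN_FULL_TEST come from this file, and the test_slim_int8_* target names are defined above):
# Example (assumed) local invocation enabling the full set of slim INT8 UTs:
#   cmake .. -DWITH_MKLDNN=ON -DWITH_SLIM_MKLDNN_FULL_TEST=ON
#   ctest -R test_slim_int8 -V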
# Since test_mkldnn_int8_quantization_strategy only supports testing on Linux
# with MKL-DNN, we remove it here to avoid running it twice and to skip it on other systems.
list(REMOVE_ITEM TEST_OPS test_mkldnn_int8_quantization_strategy)
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
#int8_model_save_path(str): The path used to save an int8 ProgramDesc with
#   fp32 weights, which is used for MKL-DNN int8 inference. For post-training quantization,
#   MKLDNNPostTrainingQuantStrategy currently only supports converting an fp32 ProgramDesc
#   with fp32 weights into an int8 ProgramDesc with fp32 weights. The saved
#   int8 ProgramDesc with fp32 weights can only be executed with MKL-DNN enabled.
#   None means the int8 ProgramDesc with fp32 weights is not saved. default: None.
#
#fp32_model_path(str): The path used to load the original fp32 ProgramDesc with fp32 weights.
#   It must be provided; passing None raises an exception. default: None.
#
#cpu_math_library_num_threads(int): The number of cpu math library threads used by
#   MKLDNNPostTrainingQuantStrategy. 1 means only one cpu math library
#   thread is used. default: 1
# Note: Here we set cpu_math_library_num_threads to 4, which is the maximum number of
#   cpu math library threads on the CI machine.
#
version: 1.0
strategies:
mkldnn_post_training_strategy:
class: 'MKLDNNPostTrainingQuantStrategy'
int8_model_save_path: 'OUTPUT_PATH'
fp32_model_path: 'MODEL_PATH'
cpu_math_library_num_threads: 4
compressor:
epoch: 0
checkpoint_path: ''
strategies:
- mkldnn_post_training_strategy
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import os
import sys
import argparse
import shutil
import logging
import struct
import six
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.core import Compressor
logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument(
'--infer_model',
type=str,
default='',
help='infer_model is used to load an original fp32 ProgramDesc with fp32 weights'
)
parser.add_argument('--infer_data', type=str, default='', help='data file')
parser.add_argument(
'--int8_model_save_path',
type=str,
default='./output',
help='int8_model_save_path is used to save an int8 ProgramDesc with fp32 weights')
parser.add_argument(
'--warmup_batch_size',
type=int,
default=100,
help='batch size for quantization warmup')
parser.add_argument(
'--accuracy_diff_threshold',
type=float,
default=0.01,
help='accepted accuracy drop threshold.')
test_args, args = parser.parse_known_args(namespace=unittest)
return test_args, sys.argv[:1] + args
class TestMKLDNNPostTrainingQuantStrategy(unittest.TestCase):
"""
Test the API of the post-training quantization strategy for int8 with MKL-DNN.
"""
def _reader_creator(self, data_file='data.bin', cycle=False):
def reader():
with open(data_file, 'rb') as fp:
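                # Layout of data.bin, inferred from the reads below: an int64 image count `num`,
                # then num images of 3x224x224 float32 pixels, then num int64 labels.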
num = fp.read(8)
num = struct.unpack('q', num)[0]
imgs_offset = 8
img_ch = 3
img_w = 224
img_h = 224
img_pixel_size = 4
img_size = img_ch * img_h * img_w * img_pixel_size
label_size = 8
labels_offset = imgs_offset + num * img_size
step = 0
while step < num:
fp.seek(imgs_offset + img_size * step)
img = fp.read(img_size)
img = struct.unpack_from('{}f'.format(img_ch * img_w *
img_h), img)
img = np.array(img)
img.shape = (img_ch, img_w, img_h)
fp.seek(labels_offset + label_size * step)
label = fp.read(label_size)
label = struct.unpack('q', label)[0]
yield img, int(label)
step += 1
if cycle and step == num:
step = 0
return reader
def _update_config_file(self, fp32_model_path, output_path):
config_path = './quantization/config_mkldnn_int8.yaml'
new_config_path = './quantization/temp.yaml'
shutil.copy(config_path, new_config_path)
with open(new_config_path, 'r+') as fp:
data = fp.read()
data = data.replace('MODEL_PATH', fp32_model_path)
data = data.replace('OUTPUT_PATH', output_path)
with open(new_config_path, 'w') as fp:
fp.write(data)
return new_config_path
def _predict(self, test_reader=None, model_path=None):
place = fluid.CPUPlace()
exe = fluid.Executor(place)
inference_scope = fluid.executor.global_scope()
with fluid.scope_guard(inference_scope):
if os.path.exists(os.path.join(model_path, '__model__')):
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(model_path, exe)
else:
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(
model_path, exe, 'model', 'params')
dshape = [3, 224, 224]
top1 = 0.0
top5 = 0.0
total_samples = 0
for _, data in enumerate(test_reader()):
if six.PY2:
images = map(lambda x: x[0].reshape(dshape), data)
if six.PY3:
images = list(map(lambda x: x[0].reshape(dshape), data))
images = np.array(images).astype('float32')
labels = np.array([x[1] for x in data]).astype("int64")
labels = labels.reshape([-1, 1])
out = exe.run(inference_program,
feed={
feed_target_names[0]: images,
feed_target_names[1]: labels
},
fetch_list=fetch_targets)
top1 += np.sum(out[1]) * len(data)
top5 += np.sum(out[2]) * len(data)
total_samples += len(data)
return top1 / total_samples, top5 / total_samples
def _warmup(self, reader=None, config_path=''):
com_pass = Compressor(
place=None,
scope=None,
train_program=None,
train_reader=None,
train_feed_list=[],
train_fetch_list=[],
eval_program=None,
eval_reader=reader,
eval_feed_list=[],
eval_fetch_list=[],
teacher_programs=[],
checkpoint_path='',
train_optimizer=None,
distiller_optimizer=None)
com_pass.config(config_path)
com_pass.run()
def test_compression(self):
if not fluid.core.is_compiled_with_mkldnn():
return
int8_model_path = test_case_args.int8_model_save_path
data_path = test_case_args.infer_data
fp32_model_path = test_case_args.infer_model
batch_size = test_case_args.batch_size
warmup_batch_size = test_case_args.warmup_batch_size
accuracy_diff_threshold = test_case_args.accuracy_diff_threshold
_logger.info(
'FP32 & INT8 prediction run: batch_size {0}, warmup batch size {1}.'.
format(batch_size, warmup_batch_size))
# Warmup dataset: only the first batch of data is used
warmup_reader = paddle.batch(
self._reader_creator(data_path, False),
batch_size=warmup_batch_size)
config_path = self._update_config_file(fp32_model_path, int8_model_path)
self._warmup(warmup_reader, config_path)
_logger.info('--- INT8 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, False), batch_size=batch_size)
int8_model_result = self._predict(val_reader, int8_model_path)
_logger.info('--- FP32 prediction start ---')
val_reader = paddle.batch(
self._reader_creator(data_path, False), batch_size=batch_size)
fp32_model_result = self._predict(val_reader, fp32_model_path)
_logger.info('--- comparing outputs ---')
_logger.info('Avg top1 INT8 accuracy: {0:.4f}'.format(int8_model_result[
0]))
_logger.info('Avg top1 FP32 accuracy: {0:.4f}'.format(fp32_model_result[
0]))
_logger.info('Accepted accuracy drop threshold: {0}'.format(
accuracy_diff_threshold))
assert fp32_model_result[0] - int8_model_result[
0] <= accuracy_diff_threshold
if __name__ == '__main__':
global test_case_args
test_case_args, remaining_args = parse_args()
unittest.main(argv=remaining_args)
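For reference, the inference_analysis_python_api_int8_test helper in the CMakeLists above launches this file roughly as follows, with FLAGS_OMP_NUM_THREADS taken from CPU_NUM_THREADS_ON_CI (the model and data directories are placeholders following the int8v2 layout):

python test_mkldnn_int8_quantization_strategy.py \
    --infer_model /path/to/int8v2/mobilenet/model \
    --infer_data /path/to/int8v2/data.bin \
    --int8_model_save_path int8_models/test_slim_int8_mobilenet \
    --warmup_batch_size 100 \
    --batch_size 50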