From c0aa13672edf484b280988f3400636b1a3aff050 Mon Sep 17 00:00:00 2001
From: lidanqing
Date: Thu, 28 Nov 2019 07:22:22 +0100
Subject: [PATCH] Fp32 vs int8 qat C++ performance (#21244)

* add ut for comparing FP32 and QAT INT8
* add save qat transformed model python script test=develop
* updated
* added missing file
* add "with_label" test=develop
* performance benchmark as unit test test=develop
* change names of unnecessary thing
* Change CMakeList.txt for model downloading and UT test=develop
* change names of functions and params for more readable code test=develop
* Change PADDLE_ENFORCE messages test=develop
* fix indent problems test=develop
---
 .../fluid/inference/tests/api/CMakeLists.txt  |  43 +++++-
 ...alyzer_int8_image_classification_tester.cc |  29 ++--
 .../analyzer_int8_object_detection_tester.cc  |   8 +-
 ...nalyzer_qat_image_classification_tester.cc | 129 ++++++++++++++++++
 .../fluid/inference/tests/api/tester_helper.h | 120 +++++++++++++---
 .../quantization/quantization_mkldnn_pass.py  |  11 ++
 .../fluid/contrib/slim/tests/CMakeLists.txt   |  13 ++
 .../contrib/slim/tests/save_qat_model.py      |  87 ++++++++++++
 8 files changed, 399 insertions(+), 41 deletions(-)
 create mode 100644 paddle/fluid/inference/tests/api/analyzer_qat_image_classification_tester.cc
 create mode 100644 python/paddle/fluid/contrib/slim/tests/save_qat_model.py

diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt
index b4fc327d679..686433c5608 100644
--- a/paddle/fluid/inference/tests/api/CMakeLists.txt
+++ b/paddle/fluid/inference/tests/api/CMakeLists.txt
@@ -16,6 +16,12 @@ function(download_int8_data install_dir data_file)
   endif()
 endfunction()
 
+function(download_qat_data install_dir data_file)
+  if (NOT EXISTS ${install_dir}/${data_file})
+    inference_download_and_uncompress(${install_dir} ${INFERENCE_URL}/int8/QAT_models ${data_file})
+  endif()
+endfunction()
+
 function(download_model_and_data install_dir model_name data_name)
   download_data(${install_dir} ${model_name})
   download_data(${install_dir} ${data_name})
@@ -31,7 +37,7 @@ function(inference_analysis_api_test target install_dir filename)
     ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt --refer_result=${install_dir}/result.txt)
 endfunction()
 
-function(inference_analysis_api_int8_test_build TARGET_NAME filename)
+function(inference_analysis_api_test_build TARGET_NAME filename)
   inference_analysis_test_build(${TARGET_NAME} SRCS ${filename}
     EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark)
 endfunction()
@@ -77,6 +83,18 @@ function(inference_analysis_api_test_with_refer_result target install_dir filena
     --refer_result=${install_dir}/result.txt)
 endfunction()
 
+function(inference_analysis_api_qat_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
+  inference_analysis_test_run(${TARGET_NAME}
+    COMMAND ${test_binary}
+    ARGS --fp32_model=${fp32_model_dir}
+         --int8_model=${int8_model_dir}
+         --infer_data=${data_path}
+         --batch_size=50
+         --paddle_num_threads=${CPU_NUM_THREADS_ON_CI}
+         --with_accuracy_layer=false
+         --iterations=2)
+endfunction()
+
 if(NOT APPLE AND WITH_MKLML)
   # RNN1
   set(RNN1_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/rnn1")
@@ -196,9 +214,10 @@ download_data(${MOBILENET_MODEL_DIR} "mobilenet_model.tar.gz")
 inference_analysis_api_test_with_fake_data_run(test_analyzer_mobilenet_depthwise_conv
   ${IMG_CLASS_TEST_APP} ${MOBILENET_MODEL_DIR} false)
 
-### INT8 tests
 if(WITH_MKLDNN)
+  ### INT8 tests
+
   set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
 
   ### Image classification tests
@@ -210,7 +229,7 @@ if(WITH_MKLDNN)
   download_int8_data(${INT8_DATA_DIR} "imagenet_val_100_tail.tar.gz")
 
   # build test binary to be used in subsequent tests
-  inference_analysis_api_int8_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
+  inference_analysis_api_test_build(${INT8_IMG_CLASS_TEST_APP} ${INT8_IMG_CLASS_TEST_APP_SRC})
 
   # resnet50 int8
   set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50")
@@ -259,13 +278,29 @@ if(WITH_MKLDNN)
   download_int8_data(${INT8_DATA_DIR} "pascalvoc_small.tar.gz")
 
   # build test binary to be used in subsequent tests
-  inference_analysis_api_int8_test_build(${INT8_OBJ_DETECT_TEST_APP} ${INT8_OBJ_DETECT_TEST_APP_SRC})
+  inference_analysis_api_test_build(${INT8_OBJ_DETECT_TEST_APP} ${INT8_OBJ_DETECT_TEST_APP_SRC})
 
   # mobilenet-ssd int8
   set(INT8_MOBILENET_SSD_MODEL_DIR "${INT8_DATA_DIR}/mobilenet-ssd")
   download_int8_data(${INT8_MOBILENET_SSD_MODEL_DIR} "mobilenet_ssd_int8_model.tar.gz" )
   inference_analysis_api_object_dection_int8_test_run(test_analyzer_int8_mobilenet_ssd ${INT8_OBJ_DETECT_TEST_APP} ${INT8_MOBILENET_SSD_MODEL_DIR} ${PASCALVOC_DATA_PATH})
 
+  ### optimized FP32 vs. QAT INT8 tests
+
+  set(QAT_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2")
+  set(QAT_IMG_CLASS_TEST_APP "test_analyzer_qat_image_classification")
+  set(QAT_IMG_CLASS_TEST_APP_SRC "analyzer_qat_image_classification_tester.cc")
+
+  # build test binary to be used in subsequent tests
+  inference_analysis_api_test_build(${QAT_IMG_CLASS_TEST_APP} ${QAT_IMG_CLASS_TEST_APP_SRC})
+
+  # ResNet50 FP32 vs. QAT INT8
+  set(QAT2_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf")
+  download_qat_data(${QAT2_RESNET50_MODEL_DIR} "ResNet50_qat_perf.tar.gz")
+  set(QAT2_INT8_RESNET50_MODEL_DIR "${QAT_DATA_DIR}/ResNet50_qat_perf_int8")
+  download_qat_data(${QAT2_INT8_RESNET50_MODEL_DIR} "ResNet50_qat_perf_int8.tar.gz")
+  inference_analysis_api_qat_test_run(test_analyzer_qat_performance_benchmark ${QAT_IMG_CLASS_TEST_APP} ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_INT8_RESNET50_MODEL_DIR}/ResNet50_qat_perf_int8 ${IMAGENET_DATA_PATH})
+
 endif()
 
 # bert, max_len=20, embedding_dim=128
diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
index 3e4a8f3ff38..3e337adbc2a 100644
--- a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc
@@ -35,9 +35,9 @@ class TensorReader {
  public:
   TensorReader(std::ifstream &file, size_t beginning_offset,
                std::vector<int> shape, std::string name)
-      : file_(file), position(beginning_offset), shape_(shape), name_(name) {
-    numel = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
-                            std::multiplies<size_t>());
+      : file_(file), position_(beginning_offset), shape_(shape), name_(name) {
+    numel_ = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
+                             std::multiplies<size_t>());
   }
 
   PaddleTensor NextBatch() {
@@ -45,11 +45,11 @@ class TensorReader {
     tensor.name = name_;
     tensor.shape = shape_;
     tensor.dtype = GetPaddleDType<T>();
-    tensor.data.Resize(numel * sizeof(T));
+    tensor.data.Resize(numel_ * sizeof(T));
 
-    file_.seekg(position);
-    file_.read(static_cast<char *>(tensor.data.data()), numel * sizeof(T));
-    position = file_.tellg();
+    file_.seekg(position_);
+    file_.read(static_cast<char *>(tensor.data.data()), numel_ * sizeof(T));
+    position_ = file_.tellg();
 
     if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
     if (file_.fail())
@@ -60,10 +60,10 @@ class TensorReader {
 
  protected:
   std::ifstream &file_;
-  size_t position;
+  size_t position_;
   std::vector<int> shape_;
   std::string name_;
-  size_t numel;
+  size_t numel_;
 };
 
 std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
@@ -71,10 +71,13 @@ std::shared_ptr<std::vector<PaddleTensor>> GetWarmupData(
     int num_images = FLAGS_warmup_batch_size) {
   int test_data_batch_size = test_data[0][0].shape[0];
   auto iterations = test_data.size();
-  PADDLE_ENFORCE(
-      static_cast<size_t>(num_images) <= iterations * test_data_batch_size,
-      "The requested quantization warmup data size " +
-          std::to_string(num_images) + " is bigger than all test data size.");
+  auto all_test_data_size = iterations * test_data_batch_size;
+  PADDLE_ENFORCE_LE(static_cast<size_t>(num_images), all_test_data_size,
+                    platform::errors::InvalidArgument(
+                        "The requested quantization warmup data size must be "
+                        "smaller than the test data size. But received warmup "
+                        "size is %d and test data size is %d",
+                        num_images, all_test_data_size));
 
   PaddleTensor images;
   images.name = "image";
diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc
index b13e454876f..7d9a73d05ca 100644
--- a/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_int8_object_detection_tester.cc
@@ -50,7 +50,7 @@ template <typename T>
 class TensorReader {
  public:
   TensorReader(std::ifstream &file, size_t beginning_offset, std::string name)
-      : file_(file), position(beginning_offset), name_(name) {}
+      : file_(file), position_(beginning_offset), name_(name) {}
 
   PaddleTensor NextBatch(std::vector<int> shape, std::vector<size_t> lod) {
     int numel =
@@ -64,9 +64,9 @@ class TensorReader {
       tensor.lod.clear();
       tensor.lod.push_back(lod);
     }
-    file_.seekg(position);
+    file_.seekg(position_);
     file_.read(reinterpret_cast<char *>(tensor.data.data()), numel * sizeof(T));
-    position = file_.tellg();
+    position_ = file_.tellg();
     if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
     if (file_.fail())
       throw std::runtime_error(name_ + ": failed reading file.");
@@ -75,7 +75,7 @@ class TensorReader {
 
  protected:
   std::ifstream &file_;
-  size_t position;
+  size_t position_;
   std::string name_;
 };
 
diff --git a/paddle/fluid/inference/tests/api/analyzer_qat_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_qat_image_classification_tester.cc
new file mode 100644
index 00000000000..fd3210c3384
--- /dev/null
+++ b/paddle/fluid/inference/tests/api/analyzer_qat_image_classification_tester.cc
@@ -0,0 +1,129 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <fstream>
+#include <vector>
+#include "paddle/fluid/inference/api/paddle_analysis_config.h"
+#include "paddle/fluid/inference/tests/api/tester_helper.h"
+
+namespace paddle {
+namespace inference {
+namespace analysis {
+
+void SetConfig(AnalysisConfig *cfg, std::string model_path) {
+  cfg->SetModel(model_path);
+  cfg->DisableGpu();
+  cfg->SwitchIrOptim(false);
+  cfg->SwitchSpecifyInputNames();
+  cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads);
+  cfg->EnableMKLDNN();
+}
+
+template <typename T>
+class TensorReader {
+ public:
+  TensorReader(std::ifstream &file, size_t beginning_offset,
+               std::vector<int> shape, std::string name)
+      : file_(file), position_(beginning_offset), shape_(shape), name_(name) {
+    numel_ = std::accumulate(shape_.begin(), shape_.end(), size_t{1},
+                             std::multiplies<size_t>());
+  }
+
+  PaddleTensor NextBatch() {
+    PaddleTensor tensor;
+    tensor.name = name_;
+    tensor.shape = shape_;
+    tensor.dtype = GetPaddleDType<T>();
+    tensor.data.Resize(numel_ * sizeof(T));
+
+    file_.seekg(position_);
+    file_.read(static_cast<char *>(tensor.data.data()), numel_ * sizeof(T));
+    position_ = file_.tellg();
+
+    if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream";
+    if (file_.fail())
+      throw std::runtime_error(name_ + ": failed reading file.");
+
+    return tensor;
+  }
+
+ protected:
+  std::ifstream &file_;
+  size_t position_;
+  std::vector<int> shape_;
+  std::string name_;
+  size_t numel_;
+};
+
+void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
+              bool with_accuracy_layer = FLAGS_with_accuracy_layer,
+              int32_t batch_size = FLAGS_batch_size) {
+  std::ifstream file(FLAGS_infer_data, std::ios::binary);
+  if (!file) {
+    FAIL() << "Couldn't open file: " << FLAGS_infer_data;
+  }
+
+  int64_t total_images{0};
+  file.read(reinterpret_cast<char *>(&total_images), sizeof(total_images));
+  LOG(INFO) << "Total images in file: " << total_images;
+
+  std::vector<int> image_batch_shape{batch_size, 3, 224, 224};
+  std::vector<int> label_batch_shape{batch_size, 1};
+  auto images_offset_in_file = static_cast<size_t>(file.tellg());
+
+  TensorReader<float> image_reader(file, images_offset_in_file,
+                                   image_batch_shape, "image");
+
+  auto iterations_max = total_images / batch_size;
+  auto iterations = iterations_max;
+  if (FLAGS_iterations > 0 && FLAGS_iterations < iterations_max) {
+    iterations = FLAGS_iterations;
+  }
+
+  auto labels_offset_in_file =
+      images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224;
+
+  TensorReader<int64_t> label_reader(file, labels_offset_in_file,
+                                     label_batch_shape, "label");
+  for (auto i = 0; i < iterations; i++) {
+    auto images = image_reader.NextBatch();
+    std::vector<PaddleTensor> tmp_vec;
+    tmp_vec.push_back(std::move(images));
+    if (with_accuracy_layer) {
+      auto labels = label_reader.NextBatch();
+      tmp_vec.push_back(std::move(labels));
+    }
+    inputs->push_back(std::move(tmp_vec));
+  }
+}
+
+TEST(Analyzer_qat_image_classification, quantization) {
+  AnalysisConfig fp32_cfg;
+  SetConfig(&fp32_cfg, FLAGS_fp32_model);
+
+  AnalysisConfig int8_cfg;
+  SetConfig(&int8_cfg, FLAGS_int8_model);
+
+  // read data from file and prepare batches with test data
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+
+  // 0 is avg_cost, 1 is top1_accuracy, 2 is top5_accuracy or mAP
+  CompareAnalysisAndAnalysis(&fp32_cfg, &int8_cfg, input_slots_all,
+                             FLAGS_with_accuracy_layer, 1);
+}
+
+}  // namespace analysis
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index b3c3da54d19..bf06ed0fa92 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -39,9 +39,13 @@ DEFINE_string(model_name, "", "model name");
 DEFINE_string(infer_model, "", "model path");
+DEFINE_string(fp32_model, "", "FP32 model path");
+DEFINE_string(int8_model, "", "INT8 model path");
 DEFINE_string(infer_data, "", "data file");
 DEFINE_string(refer_result, "", "reference result for comparison");
 DEFINE_int32(batch_size, 1, "batch size");
+DEFINE_bool(with_accuracy_layer, true,
+            "Calculate the accuracy while label is in the input");
 DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
 DEFINE_bool(enable_int8, true, "Enable INT8 type prediction");
 DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
@@ -246,7 +250,11 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
                        const std::vector<std::string> *feed_names = nullptr,
                        const int continuous_inuput_index = 0) {
   // Set fake_image_data
-  PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0, "Only have single batch of data.");
+  PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0,
+                    platform::errors::InvalidArgument(
+                        "In SetFakeImageInput, expected test_all_data = false, "
+                        "but now test_all_data=",
+                        FLAGS_test_all_data));
   std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
       dirname, is_combined, model_filename, params_filename);
   std::ostringstream os;
@@ -259,7 +267,13 @@ void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
   }
   LOG(INFO) << os.str();
   if (feed_names) {
-    PADDLE_ENFORCE_EQ(feed_names->size(), feed_target_shapes.size());
+    PADDLE_ENFORCE_EQ(
+        feed_names->size(), feed_target_shapes.size(),
+        platform::errors::InvalidArgument(
+            "The size of feeds_names and size of "
+            "feed_target_shapes must be equal, but now feeds_names "
+            "size is %d and feed_target_shapes size is %d",
+            feed_names->size(), feed_target_shapes.size()));
   }
   std::vector<PaddleTensor> input_slots(feed_target_shapes.size());
   for (size_t i = 0; i < feed_target_shapes.size(); ++i) {
@@ -474,12 +488,20 @@ void TestPrediction(const PaddlePredictor::Config *config,
 
 void SummarizeAccuracy(float avg_acc_fp32, float avg_acc_int8,
                        int compared_idx) {
-  PADDLE_ENFORCE_LE(compared_idx, 2,
-                    "Compare either top1 accuracy or mAP (top5), the "
-                    "compared_idx is out of range");
-  PADDLE_ENFORCE_GE(compared_idx, 1,
-                    "Compare either top1 accuracy or mAP (top5), the "
-                    "compared_idx is out of range");
+  PADDLE_ENFORCE_LE(
+      compared_idx, 2,
+      platform::errors::InvalidArgument(
+          "The compared_idx should be <= 2. But received compared_idx = %d. "
+          "For top1 accuracy, set compared_idx = 1; For top5 accuracy or mean "
+          "Average Precision (mAP), set compared_idx = 2.",
+          compared_idx));
+  PADDLE_ENFORCE_GE(
+      compared_idx, 1,
+      platform::errors::InvalidArgument(
+          "The compared_idx should be >= 1. But received compared_idx = %d. "
+          "For top1 accuracy, set compared_idx = 1; For top5 accuracy or mean "
+          "Average Precision (mAP), set compared_idx = 2.",
+          compared_idx));
   std::string prefix = (compared_idx == 1) ? "top1_accuracy " : "mAP ";
   LOG(INFO) << "--- Accuracy summary --- ";
   LOG(INFO) << "Accepted " << prefix
@@ -509,9 +531,10 @@ void SummarizePerformance(float sample_latency_fp32,
 
 float CompareAccuracyOne(
     const std::vector<std::vector<PaddleTensor>> &output_slots,
     int compared_idx) {
-  if (output_slots.size() == 0)
-    throw std::invalid_argument(
-        "CompareAccuracy: output_slots vector is empty.");
+  PADDLE_ENFORCE_GT(output_slots.size(), 0,
+                    platform::errors::InvalidArgument(
+                        "The accuracy vector is empty. The accuracy vector "
+                        "size should be bigger than 0"));
 
   float total_accs{0};
 
@@ -520,12 +543,19 @@ float CompareAccuracyOne(
       case 1:
         PADDLE_ENFORCE_GE(
             output_slots[i].size(), 2UL,
-            "To achieve top 1 accuracy, output_slots_quant[i].size()>=2");
+            platform::errors::InvalidArgument(
+                "To achieve top 1 accuracy, output_slots size "
+                "must be bigger than or equal to 2, but now the size is %d",
+                output_slots[i].size()));
         break;
       case 2:
         PADDLE_ENFORCE_GE(
-            output_slots[i].size(), 2UL,
-            "To achieve top 1 accuracy, output_slots_ref[i].size()>=2");
+            output_slots[i].size(), 3UL,
+            platform::errors::InvalidArgument(
+                "To achieve top 5 accuracy or mean Average "
+                "Precision (mAP), output_slots size must be "
+                "bigger than or equal to 3, but now the size is %d",
+                output_slots[i].size()));
         break;
       default:
         throw std::invalid_argument(
@@ -543,8 +573,6 @@ float CompareAccuracyOne(
         *static_cast<float *>(output_slots[i][compared_idx].data.data());
   }
 
-  CHECK_GT(output_slots.size(), 0);
-
   return total_accs / output_slots.size();
 }
 
@@ -602,8 +630,14 @@ void CompareNativeAndAnalysis(
   std::vector<std::vector<PaddleTensor>> native_outputs, analysis_outputs;
   TestOneThreadPrediction(config, inputs, &native_outputs, false);
   TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
-  PADDLE_ENFORCE_GT(native_outputs.size(), 0, "Native output is empty.");
-  PADDLE_ENFORCE_GT(analysis_outputs.size(), 0, "Analysis output is empty.");
+  PADDLE_ENFORCE_GT(native_outputs.size(), 0,
+                    platform::errors::InvalidArgument(
+                        "The native outputs is empty. The native outputs "
+                        "vector size must be bigger than 0"));
+  PADDLE_ENFORCE_GT(analysis_outputs.size(), 0,
+                    platform::errors::InvalidArgument(
+                        "The analysis outputs is empty. The analysis outputs "
+                        "vector size must be bigger than 0"));
   CompareResult(analysis_outputs.back(), native_outputs.back());
 }
 
@@ -611,8 +645,12 @@ void CompareQuantizedAndAnalysis(
     const AnalysisConfig *config, const AnalysisConfig *qconfig,
     const std::vector<std::vector<PaddleTensor>> &inputs,
     const int compared_idx = 1) {
-  PADDLE_ENFORCE_EQ(inputs[0][0].shape[0], FLAGS_batch_size,
-                    "Input data has to be packed batch by batch.");
+  PADDLE_ENFORCE_EQ(
+      inputs[0][0].shape[0], FLAGS_batch_size,
+      platform::errors::InvalidArgument(
+          "Input data has to be packed batch by batch. The batchsize is set to "
+          "%d, but the real input is packed with batchsize = %d",
+          FLAGS_batch_size, inputs[0][0].shape[0]));
   LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
             << ", warmup batch size " << FLAGS_warmup_batch_size << ".";
 
@@ -642,6 +680,48 @@ void CompareQuantizedAndAnalysis(
   CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx);
 }
 
+void CompareAnalysisAndAnalysis(
+    const AnalysisConfig *config1, const AnalysisConfig *config2,
+    const std::vector<std::vector<PaddleTensor>> &inputs,
+    const bool with_accuracy_layer = FLAGS_with_accuracy_layer,
+    const int compared_idx = 1) {
+  PADDLE_ENFORCE_EQ(
+      inputs[0][0].shape[0], FLAGS_batch_size,
+      platform::errors::InvalidArgument(
+          "Input data has to be packed batch by batch. The batchsize is set to "
+          "%d, but the real input is packed with batchsize = %d",
+          FLAGS_batch_size, inputs[0][0].shape[0]));
+
+  LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
+            << ", warmup batch size " << FLAGS_warmup_batch_size << ".";
+
+  LOG(INFO) << "--- FP32 prediction start ---";
+  auto *cfg1 = reinterpret_cast<const PaddlePredictor::Config *>(config1);
+  PrintConfig(cfg1, true);
+  std::vector<std::vector<PaddleTensor>> analysis_outputs;
+  float sample_latency_fp32{-1};
+
+  if (FLAGS_enable_fp32) {
+    TestOneThreadPrediction(cfg1, inputs, &analysis_outputs, true,
+                            VarType::FP32, &sample_latency_fp32);
+  }
+
+  LOG(INFO) << "--- INT8 prediction start ---";
+  auto *cfg2 = reinterpret_cast<const PaddlePredictor::Config *>(config2);
+  PrintConfig(cfg2, true);
+  std::vector<std::vector<PaddleTensor>> int8_outputs;
+  float sample_latency_int8{-1};
+
+  if (FLAGS_enable_int8) {
+    TestOneThreadPrediction(cfg2, inputs, &int8_outputs, true, VarType::INT8,
+                            &sample_latency_int8);
+  }
+  SummarizePerformance(sample_latency_fp32, sample_latency_int8);
+  if (with_accuracy_layer) {
+    CompareAccuracy(int8_outputs, analysis_outputs, compared_idx);
+  }
+}
+
 void CompareNativeAndAnalysis(
     PaddlePredictor *native_pred, PaddlePredictor *analysis_pred,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
diff --git a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
index eb8c131357b..0feaa62e2f6 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quantization_mkldnn_pass.py
@@ -330,6 +330,17 @@ class FakeQAT2MkldnnINT8PerfPass(object):
         graph = self._remove_unused_var_nodes(graph)
         return graph
 
+    def apply_fp32(self, graph):
+        assert isinstance(graph,
+                          IrGraph), 'graph must be the instance of IrGraph.'
+ + graph = self._gather_scales(graph) + graph = self._remove_fake_ops(graph) + graph = self._dequantize_weights(graph) + graph = self._optimize_fp32_graph(graph) + graph = self._remove_unused_var_nodes(graph) + return graph + def _convert_scale2tensor(self, scale): tensor = core.LoDTensor() tensor.set(scale, core.CPUPlace()) diff --git a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt index c69eceb1ac1..dfc5134ca1f 100644 --- a/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt +++ b/python/paddle/fluid/contrib/slim/tests/CMakeLists.txt @@ -45,6 +45,12 @@ function(inference_qat2_int8_test target model_dir data_dir test_script use_mkld --qat2) endfunction() +function(save_qat_model_test target qat_model_dir fp32_model_save_path int8_model_save_path test_script) + py_test(${target} SRCS ${test_script} + ARGS --qat_model_path ${qat_model_dir} + --fp32_model_save_path ${fp32_model_save_path} + --int8_model_save_path ${int8_model_save_path}) +endfunction() if(WIN32) list(REMOVE_ITEM TEST_OPS test_light_nas) @@ -171,6 +177,13 @@ if(LINUX AND WITH_MKLDNN) endif() inference_qat2_int8_test(test_qat2_int8_mobilenetv1_mkldnn ${QAT2_MOBILENETV1_MODEL_DIR}/MobileNet_qat_perf ${DATASET_DIR} ${MKLDNN_QAT_TEST_FILE_PATH} true) + # Save qat2 fp32 model or qat2 int8 model + + set(QAT2_INT8_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_int8") + set(QAT2_FP32_SAVE_PATH "${QAT_DATA_DIR}/ResNet50_qat2_fp32") + set(SAVE_QAT2_MODEL_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/save_qat_model.py") + save_qat_model_test(save_qat2_model_resnet50 ${QAT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QAT2_FP32_SAVE_PATH} ${QAT2_INT8_SAVE_PATH} ${SAVE_QAT2_MODEL_SCRIPT} true) + endif() # Since the test for QAT FP32 & INT8 comparison supports only testing on Linux diff --git a/python/paddle/fluid/contrib/slim/tests/save_qat_model.py b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py new file mode 100644 index 00000000000..03db63fc103 --- /dev/null +++ b/python/paddle/fluid/contrib/slim/tests/save_qat_model.py @@ -0,0 +1,87 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. 
+ +import unittest +import os +import sys +import argparse +import logging +import struct +import six +import numpy as np +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import FakeQAT2MkldnnINT8PerfPass +from paddle.fluid import core + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--qat_model_path', type=str, default='', help='A path to a QAT model.') + parser.add_argument( + '--fp32_model_save_path', + type=str, + default='', + help='Saved optimized fp32 model') + parser.add_argument( + '--int8_model_save_path', + type=str, + default='', + help='Saved optimized and quantized INT8 model') + + test_args, args = parser.parse_known_args(namespace=unittest) + return test_args, sys.argv[:1] + args + + +def transform_and_save_model(original_path, save_path, save_type): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + inference_scope = fluid.executor.global_scope() + with fluid.scope_guard(inference_scope): + if os.path.exists(os.path.join(original_path, '__model__')): + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(original_path, exe) + else: + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(original_path, exe, + 'model', 'params') + + transform_to_mkldnn_int8_pass = FakeQAT2MkldnnINT8PerfPass( + _scope=inference_scope, _place=place, _core=core) + + graph = IrGraph(core.Graph(inference_program.desc), for_test=True) + if save_type == 'FP32': + graph = transform_to_mkldnn_int8_pass.apply_fp32(graph) + elif save_type == 'INT8': + graph = transform_to_mkldnn_int8_pass.apply(graph) + inference_program = graph.to_program() + with fluid.scope_guard(inference_scope): + fluid.io.save_inference_model(save_path, feed_target_names, + fetch_targets, exe, inference_program) + print("Success! Transformed QAT_{0} model can be found at {1}\n".format( + save_type, save_path)) + + +if __name__ == '__main__': + global test_args + test_args, remaining_args = parse_args() + if test_args.fp32_model_save_path: + transform_and_save_model(test_args.qat_model_path, + test_args.fp32_model_save_path, 'FP32') + if test_args.int8_model_save_path: + transform_and_save_model(test_args.qat_model_path, + test_args.int8_model_save_path, 'INT8') -- GitLab
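Usage note (not part of the patch): the sketch below shows how the pieces added by this patch are meant to be chained together outside of CMake/CTest. It is a minimal, hypothetical driver, not code shipped with the patch: it assumes a patched Paddle build is installed, that save_qat_model.py and the test_analyzer_qat_image_classification gtest binary are available in the working directory, and that the model and data paths (placeholders here) point to a QAT-trained ResNet50 model and the binary ImageNet file expected by SetInput(). The flag names are the ones defined above.

# Hypothetical end-to-end driver; every path below is a placeholder.
import subprocess

QAT_MODEL = "/data/ResNet50_qat_perf/float"   # QAT-trained model containing fake-quantize ops
FP32_MODEL = "/data/ResNet50_qat2_fp32"       # optimized FP32 model written via apply_fp32()
INT8_MODEL = "/data/ResNet50_qat2_int8"       # optimized, fully quantized INT8 model written via apply()
DATA = "/data/imagenet_val.bin"               # binary image/label file read by the tester's SetInput()

# Step 1: strip fake-quantize ops and save the optimized FP32 and INT8 variants
# using the save_qat_model.py script added by this patch.
subprocess.check_call([
    "python", "save_qat_model.py",
    "--qat_model_path", QAT_MODEL,
    "--fp32_model_save_path", FP32_MODEL,
    "--int8_model_save_path", INT8_MODEL,
])

# Step 2: run the benchmark test added by this patch; it loads both models,
# runs single-threaded predictions on each and prints the latency summary
# (accuracy is compared only when --with_accuracy_layer=true and labels are present).
subprocess.check_call([
    "./test_analyzer_qat_image_classification",
    "--fp32_model=" + FP32_MODEL,
    "--int8_model=" + INT8_MODEL,
    "--infer_data=" + DATA,
    "--batch_size=50",
    "--iterations=2",
    "--with_accuracy_layer=false",
])

These are the same two steps that the new CMake targets (save_qat2_model_resnet50 and test_analyzer_qat_performance_benchmark) wire up on CI.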