From fe21578a4467862c23e83fb71a9bec194acd28da Mon Sep 17 00:00:00 2001 From: Sylwester Fraczek Date: Wed, 20 Mar 2019 19:08:36 +0100 Subject: [PATCH] create test for quantized resnet50 test=develop --- .../fluid/inference/tests/api/CMakeLists.txt | 28 +++ .../tests/api/analyzer_bert_tester.cc | 13 -- ...alyzer_int8_image_classification_tester.cc | 189 ++++++++++++++++++ .../fluid/inference/tests/api/tester_helper.h | 51 +++++ 4 files changed, 268 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 2f17a44e0c..3eda73f47b 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -23,6 +23,12 @@ function(inference_analysis_api_test target install_dir filename) ARGS --infer_model=${install_dir}/model --infer_data=${install_dir}/data.txt) endfunction() +function(inference_analysis_api_int8_test target model_dir data_dir filename) + inference_analysis_test(${target} SRCS ${filename} + EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark + ARGS --infer_model=${model_dir}/model --infer_data=${data_dir}/data.bin --batch_size=100) +endfunction() + function(inference_analysis_api_test_with_fake_data target install_dir filename model_name) download_model(${install_dir} ${model_name}) inference_analysis_test(${target} SRCS ${filename} @@ -138,6 +144,28 @@ inference_analysis_api_test_with_fake_data(test_analyzer_resnet50 inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_conv "${INFERENCE_DEMO_INSTALL_DIR}/mobilenet_depthwise_conv" analyzer_resnet50_tester.cc "mobilenet_model.tar.gz" SERIAL) +# int8 image classification tests +if(WITH_MKLDNN) + set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8") + if (NOT EXISTS ${INT8_DATA_DIR}) + inference_download_and_uncompress(${INT8_DATA_DIR} "http://paddle-inference-dist.bj.bcebos.com/int8" "imagenet_val_100.bin.tar.gz") + endif() + + #resnet50 int8 + set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50") + if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR}) + inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "http://paddle-inference-dist.bj.bcebos.com/int8" "resnet50_int8_model.tar.gz" ) + endif() + inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) + + #mobilenet int8 + set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet") + if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR}) + inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "http://paddle-inference-dist.bj.bcebos.com/int8" "mobilenetv1_int8_model.tar.gz" ) + endif() + inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) +endif() + # bert, max_len=20, embedding_dim=128 set(BERT_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/bert_emb128") download_model_and_data(${BERT_INSTALL_DIR} "bert_emb128_model.tar.gz" "bert_data_len20.txt.tar.gz") diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc index f646fd6d91..e73358d882 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc @@ -53,19 +53,6 @@ void Split(const std::string &line, char sep, std::vector *v) { } } -template -constexpr paddle::PaddleDType GetPaddleDType(); - -template <> -constexpr paddle::PaddleDType GetPaddleDType() { - return paddle::PaddleDType::INT64; -} - -template <> -constexpr paddle::PaddleDType GetPaddleDType() { - return paddle::PaddleDType::FLOAT32; -} - // Parse tensor from string template bool ParseTensor(const std::string &field, paddle::PaddleTensor *tensor) { diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc new file mode 100644 index 0000000000..880aa6044c --- /dev/null +++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc @@ -0,0 +1,189 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/inference/api/paddle_analysis_config.h" +#include "paddle/fluid/inference/tests/api/tester_helper.h" + +DEFINE_int32(iterations, 0, "Number of iterations"); + +namespace paddle { +namespace inference { +namespace analysis { + +void SetConfig(AnalysisConfig *cfg) { + cfg->SetModel(FLAGS_infer_model); + cfg->SetProgFile("__model__"); + cfg->DisableGpu(); + cfg->SwitchIrOptim(); + cfg->SwitchSpecifyInputNames(false); + cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); + + cfg->EnableMKLDNN(); +} + +template +class TensorReader { + public: + TensorReader(std::ifstream &file, size_t beginning_offset, + std::vector shape, std::string name) + : file_(file), position(beginning_offset), shape_(shape), name_(name) { + numel = + std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); + } + + PaddleTensor NextBatch() { + PaddleTensor tensor; + tensor.name = name_; + tensor.shape = shape_; + tensor.dtype = GetPaddleDType(); + tensor.data.Resize(numel * sizeof(T)); + + file_.seekg(position); + file_.read(static_cast(tensor.data.data()), numel * sizeof(T)); + position = file_.tellg(); + + if (file_.eof()) LOG(ERROR) << name_ << ": reached end of stream"; + if (file_.fail()) + throw std::runtime_error(name_ + ": failed reading file."); + + return tensor; + } + + protected: + std::ifstream &file_; + size_t position; + std::vector shape_; + std::string name_; + size_t numel; +}; + +std::shared_ptr> GetWarmupData( + const std::vector> &test_data, int num_images) { + int test_data_batch_size = test_data[0][0].shape[0]; + CHECK_LE(static_cast(num_images), + test_data.size() * test_data_batch_size); + + PaddleTensor images; + images.name = "input"; + images.shape = {num_images, 3, 224, 224}; + images.dtype = PaddleDType::FLOAT32; + images.data.Resize(sizeof(float) * num_images * 3 * 224 * 224); + + PaddleTensor labels; + labels.name = "labels"; + labels.shape = {num_images, 1}; + labels.dtype = PaddleDType::INT64; + labels.data.Resize(sizeof(int64_t) * num_images); + + for (int i = 0; i < num_images; i++) { + auto batch = i / test_data_batch_size; + auto element_in_batch = i % test_data_batch_size; + std::copy_n(static_cast(test_data[batch][0].data.data()) + + element_in_batch * 3 * 224 * 224, + 3 * 224 * 224, + static_cast(images.data.data()) + i * 3 * 224 * 224); + + std::copy_n(static_cast(test_data[batch][1].data.data()) + + element_in_batch, + 1, static_cast(labels.data.data()) + i); + } + + auto warmup_data = std::make_shared>(2); + (*warmup_data)[0] = std::move(images); + (*warmup_data)[1] = std::move(labels); + return warmup_data; +} + +void SetInput(std::vector> *inputs, + int32_t batch_size = FLAGS_batch_size) { + std::ifstream file(FLAGS_infer_data, std::ios::binary); + if (!file) { + FAIL() << "Couldn't open file: " << FLAGS_infer_data; + } + + int64_t total_images{0}; + file.read(reinterpret_cast(&total_images), sizeof(total_images)); + LOG(INFO) << "Total images in file: " << total_images; + + std::vector image_batch_shape{batch_size, 3, 224, 224}; + std::vector label_batch_shape{batch_size, 1}; + auto labels_offset_in_file = + static_cast(file.tellg()) + + sizeof(float) * total_images * + std::accumulate(image_batch_shape.begin() + 1, + image_batch_shape.end(), 1, std::multiplies()); + + TensorReader image_reader(file, 0, image_batch_shape, "input"); + TensorReader label_reader(file, labels_offset_in_file, + label_batch_shape, "label"); + + auto iterations = total_images / batch_size; + if (FLAGS_iterations > 0 && FLAGS_iterations < iterations) + iterations = FLAGS_iterations; + for (auto i = 0; i < iterations; i++) { + auto images = image_reader.NextBatch(); + auto labels = label_reader.NextBatch(); + inputs->emplace_back( + std::vector{std::move(images), std::move(labels)}); + } +} + +TEST(Analyzer_int8_resnet50, quantization) { + AnalysisConfig cfg; + SetConfig(&cfg); + + AnalysisConfig q_cfg; + SetConfig(&q_cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all, 100); + + std::shared_ptr> warmup_data = + GetWarmupData(input_slots_all, 100); + + q_cfg.EnableMkldnnQuantizer(); + q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); + q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100); + + CompareQuantizedAndAnalysis( + reinterpret_cast(&cfg), + reinterpret_cast(&q_cfg), + input_slots_all); +} + +TEST(Analyzer_int8_resnet50, profile) { + AnalysisConfig cfg; + SetConfig(&cfg); + + std::vector> input_slots_all; + SetInput(&input_slots_all); + + std::shared_ptr> warmup_data = + GetWarmupData(input_slots_all, 100); + + cfg.EnableMkldnnQuantizer(); + cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); + cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100); + + std::vector outputs; + + TestPrediction(reinterpret_cast(&cfg), + input_slots_all, &outputs, FLAGS_num_threads); +} + +} // namespace analysis +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index a4881afe58..33f1d02548 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -50,6 +50,7 @@ DEFINE_bool(use_analysis, true, DEFINE_bool(record_benchmark, false, "Record benchmark after profiling the model"); DEFINE_double(accuracy, 1e-3, "Result Accuracy."); +DEFINE_double(quantized_accuracy, 1e-2, "Result Quantized Accuracy."); DEFINE_bool(zero_copy, false, "Use ZeroCopy to speedup Feed/Fetch."); DECLARE_bool(profile); @@ -58,6 +59,19 @@ DECLARE_int32(paddle_num_threads); namespace paddle { namespace inference { +template +constexpr paddle::PaddleDType GetPaddleDType(); + +template <> +constexpr paddle::PaddleDType GetPaddleDType() { + return paddle::PaddleDType::INT64; +} + +template <> +constexpr paddle::PaddleDType GetPaddleDType() { + return paddle::PaddleDType::FLOAT32; +} + void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) { const auto *analysis_config = reinterpret_cast(config); @@ -392,6 +406,32 @@ void TestPrediction(const PaddlePredictor::Config *config, } } +void CompareTopAccuracy(const std::vector &output_slots1, + const std::vector &output_slots2) { + // first output: avg_cost + if (output_slots1.size() == 0 || output_slots2.size() == 0) + throw std::invalid_argument( + "CompareTopAccuracy: output_slots vector is empty."); + PADDLE_ENFORCE(output_slots1.size() >= 2UL); + PADDLE_ENFORCE(output_slots2.size() >= 2UL); + + // second output: acc_top1 + if (output_slots1[1].lod.size() > 0 || output_slots2[1].lod.size() > 0) + throw std::invalid_argument( + "CompareTopAccuracy: top1 accuracy output has nonempty LoD."); + if (output_slots1[1].dtype != paddle::PaddleDType::FLOAT32 || + output_slots2[1].dtype != paddle::PaddleDType::FLOAT32) + throw std::invalid_argument( + "CompareTopAccuracy: top1 accuracy output is of a wrong type."); + float *top1_quantized = static_cast(output_slots1[1].data.data()); + float *top1_reference = static_cast(output_slots2[1].data.data()); + LOG(INFO) << "top1 INT8 accuracy: " << *top1_quantized; + LOG(INFO) << "top1 FP32 accuracy: " << *top1_reference; + LOG(INFO) << "Accepted accuracy drop threshold: " << FLAGS_quantized_accuracy; + CHECK_LE(std::abs(*top1_quantized - *top1_reference), + FLAGS_quantized_accuracy); +} + void CompareDeterministic( const PaddlePredictor::Config *config, const std::vector> &inputs) { @@ -421,6 +461,17 @@ void CompareNativeAndAnalysis( CompareResult(analysis_outputs, native_outputs); } +void CompareQuantizedAndAnalysis( + const PaddlePredictor::Config *config, + const PaddlePredictor::Config *qconfig, + const std::vector> &inputs) { + PrintConfig(config, true); + std::vector analysis_outputs, quantized_outputs; + TestOneThreadPrediction(config, inputs, &analysis_outputs, true); + TestOneThreadPrediction(qconfig, inputs, &quantized_outputs, true); + CompareTopAccuracy(quantized_outputs, analysis_outputs); +} + void CompareNativeAndAnalysis( PaddlePredictor *native_pred, PaddlePredictor *analysis_pred, const std::vector> &inputs) { -- GitLab