diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc index aee94e12340597e981ac385a01335d2ffa069191..85910c10e7409ac629c8fb5265550a5cfa6bc2c2 100644 --- a/paddle/fluid/inference/api/analysis_config.cc +++ b/paddle/fluid/inference/api/analysis_config.cc @@ -141,7 +141,6 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) { void AnalysisConfig::EnableMKLDNN() { #ifdef PADDLE_WITH_MKLDNN - pass_builder()->EnableMKLDNN(); use_mkldnn_ = true; #else LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN"; @@ -234,16 +233,13 @@ void AnalysisConfig::Update() { } if (use_mkldnn_) { +#ifdef PADDLE_WITH_MKLDNN if (!enable_ir_optim_) { LOG(ERROR) << "EnableMKLDNN() only works when IR optimization is enabled."; + } else { + pass_builder()->EnableMKLDNN(); } -#ifdef PADDLE_WITH_MKLDNN - pass_builder()->EnableMKLDNN(); - use_mkldnn_ = true; -#else - LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN"; - use_mkldnn_ = false; #endif } @@ -255,9 +251,6 @@ void AnalysisConfig::Update() { } #ifdef PADDLE_WITH_MKLDNN pass_builder()->EnableMkldnnQuantizer(); -#else - LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer"; - use_mkldnn_quantizer_ = false; #endif } diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h index 258a79fa4e884177490fab79778151ae52537aa0..c89dd41e0a6283e0723e2925f28c0372cda6a2b2 100644 --- a/paddle/fluid/inference/api/helper.h +++ b/paddle/fluid/inference/api/helper.h @@ -27,6 +27,7 @@ #include #include #include "paddle/fluid/inference/api/paddle_inference_api.h" +#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/port.h" #include "paddle/fluid/string/printf.h" @@ -266,17 +267,17 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) { } static void PrintTime(int batch_size, int repeat, int num_threads, int tid, - double latency, int epoch = 1) { - LOG(INFO) << "====== batch_size: " << batch_size << ", repeat: " << repeat - << ", threads: " << num_threads << ", thread id: " << tid - << ", latency: " << latency << "ms, fps: " << 1 / (latency / 1000.f) + double batch_latency, int epoch = 1) { + PADDLE_ENFORCE(batch_size > 0, "Non-positive batch size."); + double sample_latency = batch_latency / batch_size; + LOG(INFO) << "====== threads: " << num_threads << ", thread id: " << tid << " ======"; - if (epoch > 1) { - int samples = batch_size * epoch; - LOG(INFO) << "====== sample number: " << samples - << ", average latency of each sample: " << latency / samples - << "ms ======"; - } + LOG(INFO) << "====== batch_size: " << batch_size << ", iterations: " << epoch + << ", repetitions: " << repeat << " ======"; + LOG(INFO) << "====== batch latency: " << batch_latency + << "ms, number of samples: " << batch_size * epoch + << ", sample latency: " << sample_latency + << "ms, fps: " << 1000.f / sample_latency << " ======"; } static bool IsFileExists(const std::string &path) { diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 1d1d39e44096b9f50e5bc9603fa12aba92b0e8e2..87e02a02caebd93d701dfd9e51c35fb974c770ed 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -64,10 +64,12 @@ void PaddlePassBuilder::DeletePass(size_t idx) { passes_.erase(std::begin(passes_) + idx); } -void GpuPassStrategy::EnableMKLDNN() { - LOG(ERROR) << "GPU not support MKLDNN yet"; +void 
PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { + analysis_passes_.push_back(pass); } +void PaddlePassBuilder::ClearPasses() { passes_.clear(); } + // The following passes works for Anakin sub-graph engine. const std::vector kAnakinSubgraphPasses({ "infer_clean_graph_pass", // @@ -102,12 +104,12 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { use_gpu_ = true; } -void GpuPassStrategy::EnableMkldnnQuantizer() { - LOG(ERROR) << "GPU not support MKL-DNN quantization"; +void GpuPassStrategy::EnableMKLDNN() { + LOG(ERROR) << "GPU not support MKLDNN yet"; } -void PaddlePassBuilder::AppendAnalysisPass(const std::string &pass) { - analysis_passes_.push_back(pass); +void GpuPassStrategy::EnableMkldnnQuantizer() { + LOG(ERROR) << "GPU not support MKL-DNN quantization"; } CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { @@ -135,5 +137,39 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) { }); use_gpu_ = false; } -void PaddlePassBuilder::ClearPasses() { passes_.clear(); } + +void CpuPassStrategy::EnableMKLDNN() { +// TODO(Superjomn) Consider the way to mix CPU with GPU. +#ifdef PADDLE_WITH_MKLDNN + if (!use_mkldnn_) { + passes_.insert(passes_.begin(), "mkldnn_placement_pass"); + + for (auto &pass : std::vector( + {"depthwise_conv_mkldnn_pass", // + "conv_bn_fuse_pass", // Execute BN passes again to + "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order + "conv_bias_mkldnn_fuse_pass", // + "conv3d_bias_mkldnn_fuse_pass", // + "conv_elementwise_add_mkldnn_fuse_pass", + "conv_relu_mkldnn_fuse_pass"})) { + passes_.push_back(pass); + } + } + use_mkldnn_ = true; +#else + use_mkldnn_ = false; +#endif +} + +void CpuPassStrategy::EnableMkldnnQuantizer() { +#ifdef PADDLE_WITH_MKLDNN + if (!use_mkldnn_quantizer_) { + passes_.push_back("cpu_quantize_placement_pass"); + } + use_mkldnn_quantizer_ = true; +#else + use_mkldnn_quantizer_ = false; +#endif +} + } // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.h b/paddle/fluid/inference/api/paddle_pass_builder.h index 48da8c156f426477011bcc060260c812ad94df23..09ef195d5e66aff0cef17f1594de34c656187a35 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.h +++ b/paddle/fluid/inference/api/paddle_pass_builder.h @@ -109,43 +109,16 @@ class CpuPassStrategy : public PassStrategy { CpuPassStrategy(); explicit CpuPassStrategy(const CpuPassStrategy &other) - : PassStrategy(other.AllPasses()) {} + : PassStrategy(other.AllPasses()) { + use_gpu_ = other.use_gpu_; + use_mkldnn_ = other.use_mkldnn_; + use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_; + } virtual ~CpuPassStrategy() = default; - void EnableMKLDNN() override { -// TODO(Superjomn) Consider the way to mix CPU with GPU. 
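// Editor's note: a minimal usage sketch, not part of the patch, showing how the
// relocated CpuPassStrategy::EnableMKLDNN()/EnableMkldnnQuantizer() are reached
// after this refactor. The caller only toggles flags on AnalysisConfig; the
// MKL-DNN passes are appended later, inside AnalysisConfig::Update(), and only
// when IR optimization is enabled. The model path below is a placeholder.
#include "paddle/fluid/inference/api/paddle_inference_api.h"

int main() {
  paddle::AnalysisConfig cfg;
  cfg.SetModel("/path/to/model");  // placeholder model directory
  cfg.SwitchIrOptim(true);         // without IR optimization, Update() only logs an error
  cfg.EnableMKLDNN();              // sets use_mkldnn_; passes are added during Update()
  auto predictor = paddle::CreatePaddlePredictor(cfg);
  return predictor != nullptr ? 0 : 1;
}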
-#ifdef PADDLE_WITH_MKLDNN - if (!use_mkldnn_) { - passes_.insert(passes_.begin(), "mkldnn_placement_pass"); - - for (auto &pass : std::vector( - {"depthwise_conv_mkldnn_pass", // - "conv_bn_fuse_pass", // Execute BN passes again to - "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order - "conv_bias_mkldnn_fuse_pass", // - "conv3d_bias_mkldnn_fuse_pass", // - "conv_relu_mkldnn_fuse_pass", // - "conv_elementwise_add_mkldnn_fuse_pass"})) { - passes_.push_back(pass); - } - } - use_mkldnn_ = true; -#else - use_mkldnn_ = false; -#endif - } - - void EnableMkldnnQuantizer() override { -#ifdef PADDLE_WITH_MKLDNN - if (!use_mkldnn_quantizer_) { - passes_.push_back("cpu_quantize_placement_pass"); - } - use_mkldnn_quantizer_ = true; -#else - use_mkldnn_quantizer_ = false; -#endif - } + void EnableMKLDNN() override; + void EnableMkldnnQuantizer() override; protected: bool use_mkldnn_quantizer_{false}; diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index 6a31185b097bc0ddf93a6e32e61ac0a9f2d04cfd..d3d278822bacee93c19b5c4d0015652c8ef2a6b5 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -26,7 +26,11 @@ endfunction() function(inference_analysis_api_int8_test target model_dir data_dir filename) inference_analysis_test(${target} SRCS ${filename} EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} benchmark - ARGS --infer_model=${model_dir}/model --infer_data=${data_dir}/data.bin --batch_size=100) + ARGS --infer_model=${model_dir}/model + --infer_data=${data_dir}/data.bin + --warmup_batch_size=100 + --batch_size=50 + --iterations=2) endfunction() function(inference_analysis_api_test_with_fake_data target install_dir filename model_name) @@ -146,22 +150,22 @@ inference_analysis_api_test_with_fake_data(test_analyzer_mobilenet_depthwise_con # int8 image classification tests if(WITH_MKLDNN) - set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8") + set(INT8_DATA_DIR "${INFERENCE_DEMO_INSTALL_DIR}/int8v2") if (NOT EXISTS ${INT8_DATA_DIR}) - inference_download_and_uncompress(${INT8_DATA_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "imagenet_val_100.tar.gz") + inference_download_and_uncompress(${INT8_DATA_DIR} "${INFERENCE_URL}/int8" "imagenet_val_100_tail.tar.gz") endif() #resnet50 int8 set(INT8_RESNET50_MODEL_DIR "${INT8_DATA_DIR}/resnet50") if (NOT EXISTS ${INT8_RESNET50_MODEL_DIR}) - inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "resnet50_int8_model.tar.gz" ) + inference_download_and_uncompress(${INT8_RESNET50_MODEL_DIR} "${INFERENCE_URL}/int8" "resnet50_int8_model.tar.gz" ) endif() inference_analysis_api_int8_test(test_analyzer_int8_resnet50 ${INT8_RESNET50_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) #mobilenet int8 set(INT8_MOBILENET_MODEL_DIR "${INT8_DATA_DIR}/mobilenet") if (NOT EXISTS ${INT8_MOBILENET_MODEL_DIR}) - inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "https://paddle-inference-dist.bj.bcebos.com/int8" "mobilenetv1_int8_model.tar.gz" ) + inference_download_and_uncompress(${INT8_MOBILENET_MODEL_DIR} "${INFERENCE_URL}/int8" "mobilenetv1_int8_model.tar.gz" ) endif() inference_analysis_api_int8_test(test_analyzer_int8_mobilenet ${INT8_MOBILENET_MODEL_DIR} ${INT8_DATA_DIR} analyzer_int8_image_classification_tester.cc SERIAL) endif() diff --git a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc 
b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc index e73358d8827a40786beb05fad931267b0dd88f6b..9b2e74ec16eb3b6e98bfcc8cc546ed74a7966f33 100644 --- a/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_bert_tester.cc @@ -154,7 +154,7 @@ void profile(bool use_mkldnn = false) { config.EnableMKLDNN(); } - std::vector outputs; + std::vector> outputs; std::vector> inputs; LoadInputData(&inputs); TestPrediction(reinterpret_cast(&config), diff --git a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc index 735e4fb563788438ee49ff6308d11f4dbe4962be..e10d239a5d1b30e089a110c6155520e3b035860a 100644 --- a/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_dam_tester.cc @@ -197,7 +197,7 @@ void profile(bool use_mkldnn = false) { cfg.SetMKLDNNOp(op_list); } - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -206,9 +206,11 @@ void profile(bool use_mkldnn = false) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { PADDLE_ENFORCE_GT(outputs.size(), 0); - size_t size = GetSize(outputs[0]); + auto output = outputs.back(); + PADDLE_ENFORCE_GT(output.size(), 0); + size_t size = GetSize(output[0]); PADDLE_ENFORCE_GT(size, 0); - float *result = static_cast(outputs[0].data.data()); + float *result = static_cast(output[0].data.data()); for (size_t i = 0; i < size; i++) { EXPECT_NEAR(result[i], result_data[i], 1e-3); } diff --git a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc index 5a4f9a31a164a8fca3f80ce2fe2e6065fd04b340..ece094717b8076321c68d7fdd29f07c4da6b0ed4 100644 --- a/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_int8_image_classification_tester.cc @@ -17,8 +17,6 @@ limitations under the License. 
*/ #include "paddle/fluid/inference/api/paddle_analysis_config.h" #include "paddle/fluid/inference/tests/api/tester_helper.h" -DEFINE_int32(iterations, 0, "Number of iterations"); - namespace paddle { namespace inference { namespace analysis { @@ -30,8 +28,13 @@ void SetConfig(AnalysisConfig *cfg) { cfg->SwitchIrOptim(); cfg->SwitchSpecifyInputNames(false); cfg->SetCpuMathLibraryNumThreads(FLAGS_paddle_num_threads); - cfg->EnableMKLDNN(); + cfg->pass_builder()->SetPasses( + {"infer_clean_graph_pass", "mkldnn_placement_pass", + "depthwise_conv_mkldnn_pass", "conv_bn_fuse_pass", + "conv_eltwiseadd_bn_fuse_pass", "conv_bias_mkldnn_fuse_pass", + "conv_elementwise_add_mkldnn_fuse_pass", "conv_relu_mkldnn_fuse_pass", + "fc_fuse_pass", "is_test_pass"}); } template @@ -40,8 +43,8 @@ class TensorReader { TensorReader(std::ifstream &file, size_t beginning_offset, std::vector shape, std::string name) : file_(file), position(beginning_offset), shape_(shape), name_(name) { - numel = - std::accumulate(shape_.begin(), shape_.end(), 1, std::multiplies()); + numel = std::accumulate(shape_.begin(), shape_.end(), size_t{1}, + std::multiplies()); } PaddleTensor NextBatch() { @@ -71,10 +74,14 @@ class TensorReader { }; std::shared_ptr> GetWarmupData( - const std::vector> &test_data, int num_images) { + const std::vector> &test_data, + int num_images = FLAGS_warmup_batch_size) { int test_data_batch_size = test_data[0][0].shape[0]; - CHECK_LE(static_cast(num_images), - test_data.size() * test_data_batch_size); + auto iterations_max = test_data.size(); + PADDLE_ENFORCE( + static_cast(num_images) <= iterations_max * test_data_batch_size, + "The requested quantization warmup data size " + + std::to_string(num_images) + " is bigger than all test data size."); PaddleTensor images; images.name = "input"; @@ -120,20 +127,17 @@ void SetInput(std::vector> *inputs, std::vector image_batch_shape{batch_size, 3, 224, 224}; std::vector label_batch_shape{batch_size, 1}; + auto images_offset_in_file = static_cast(file.tellg()); auto labels_offset_in_file = - static_cast(file.tellg()) + - sizeof(float) * total_images * - std::accumulate(image_batch_shape.begin() + 1, - image_batch_shape.end(), 1, std::multiplies()); + images_offset_in_file + sizeof(float) * total_images * 3 * 224 * 224; - TensorReader image_reader(file, 0, image_batch_shape, "input"); + TensorReader image_reader(file, images_offset_in_file, + image_batch_shape, "input"); TensorReader label_reader(file, labels_offset_in_file, label_batch_shape, "label"); - auto iterations = total_images / batch_size; - if (FLAGS_iterations > 0 && FLAGS_iterations < iterations) - iterations = FLAGS_iterations; - for (auto i = 0; i < iterations; i++) { + auto iterations_max = total_images / batch_size; + for (auto i = 0; i < iterations_max; i++) { auto images = image_reader.NextBatch(); auto labels = label_reader.NextBatch(); inputs->emplace_back( @@ -148,20 +152,21 @@ TEST(Analyzer_int8_resnet50, quantization) { AnalysisConfig q_cfg; SetConfig(&q_cfg); + // read data from file and prepare batches with test data std::vector> input_slots_all; - SetInput(&input_slots_all, 100); + SetInput(&input_slots_all); + // prepare warmup batch from input data read earlier + // warmup batch size can be different than batch size std::shared_ptr> warmup_data = - GetWarmupData(input_slots_all, 100); + GetWarmupData(input_slots_all); + // configure quantizer q_cfg.EnableMkldnnQuantizer(); q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); - 
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(100);
+  q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(FLAGS_warmup_batch_size);
 
-  CompareQuantizedAndAnalysis(
-      reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
-      reinterpret_cast<const PaddlePredictor::Config *>(&q_cfg),
-      input_slots_all);
+  CompareQuantizedAndAnalysis(&cfg, &q_cfg, input_slots_all);
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
index 347672eaae314aa42096d48a3b044014f2ddbf84..142905dcd8d9964d93d0c5f7444823eef2b84900 100644
--- a/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_lac_tester.cc
@@ -124,7 +124,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 TEST(Analyzer_LAC, profile) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
@@ -137,11 +137,13 @@ TEST(Analyzer_LAC, profile) {
         24, 25, 25, 25, 38, 30, 31, 14, 15, 44, 24, 25, 25, 25, 25, 25, 44, 24,
         25, 25, 25, 36, 42, 43, 44, 14, 15, 44, 14, 15, 44, 14, 15, 44, 38, 39,
         14, 15, 44, 22, 23, 23, 23, 23, 23, 23, 23};
-    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-    size_t size = GetSize(outputs[0]);
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    auto output = outputs.back();
+    PADDLE_ENFORCE_EQ(output.size(), 1UL);
+    size_t size = GetSize(output[0]);
     size_t batch1_size = sizeof(lac_ref_data) / sizeof(int64_t);
     PADDLE_ENFORCE_GE(size, batch1_size);
-    int64_t *pdata = static_cast<int64_t *>(outputs[0].data.data());
+    int64_t *pdata = static_cast<int64_t *>(output[0].data.data());
     for (size_t i = 0; i < batch1_size; ++i) {
       EXPECT_EQ(pdata[i], lac_ref_data[i]);
     }
diff --git a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
index 089f655c180d784af66af60277bdbf32a6019599..2eb347a44b394a55706d5aa88bee7fe1fcc7838e 100644
--- a/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_mm_dnn_tester.cc
@@ -96,7 +96,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 void profile(bool use_mkldnn = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
 
   if (use_mkldnn) {
     cfg.EnableMKLDNN();
@@ -108,8 +108,9 @@ void profile(bool use_mkldnn = false) {
                  input_slots_all, &outputs, FLAGS_num_threads);
 
   if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) {
-    PADDLE_ENFORCE_EQ(outputs.size(), 2UL);
-    for (auto &output : outputs) {
+    PADDLE_ENFORCE_GT(outputs.size(), 0);
+    PADDLE_ENFORCE_EQ(outputs.back().size(), 2UL);
+    for (auto &output : outputs.back()) {
       size_t size = GetSize(output);
       PADDLE_ENFORCE_GT(size, 0);
       float *result = static_cast<float *>(output.data.data());
diff --git a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
index a70aa7a6ac41121a0c8ea397ebc7e24e4b206d12..36e07d5f55600dc7aa9
6227289f707fb19f92d56 100644
--- a/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_ner_tester.cc
@@ -106,7 +106,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
 void profile(bool memory_load = false) {
   AnalysisConfig cfg;
   SetConfig(&cfg, memory_load);
-  std::vector<PaddleTensor> outputs;
+  std::vector<std::vector<PaddleTensor>> outputs;
 
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
@@ -117,10 +117,12 @@ void profile(bool memory_load = false) {
     // the first inference result
     const int chinese_ner_result_data[] = {30, 45, 41, 48, 17, 26,
                                            48, 39, 38, 16, 25};
-    PADDLE_ENFORCE_EQ(outputs.size(), 1UL);
-    size_t size 
= GetSize(outputs[0]); + PADDLE_ENFORCE_GT(outputs.size(), 0); + auto output = outputs.back(); + PADDLE_ENFORCE_EQ(output.size(), 1UL); + size_t size = GetSize(output[0]); PADDLE_ENFORCE_GT(size, 0); - int64_t *result = static_cast(outputs[0].data.data()); + int64_t *result = static_cast(output[0].data.data()); for (size_t i = 0; i < std::min(11UL, size); i++) { EXPECT_EQ(result[i], chinese_ner_result_data[i]); } diff --git a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc index 5157bd280d0f3ee327d5cee7799477b5e6fd3f71..9443b08063b8f61d3d6b291a7217d645d8825c54 100644 --- a/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_pyramid_dnn_tester.cc @@ -127,7 +127,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_Pyramid_DNN, profile) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -135,10 +135,12 @@ TEST(Analyzer_Pyramid_DNN, profile) { input_slots_all, &outputs, FLAGS_num_threads); if (FLAGS_num_threads == 1 && !FLAGS_test_all_data && !FLAGS_zero_copy) { - PADDLE_ENFORCE_EQ(outputs.size(), 1UL); - size_t size = GetSize(outputs[0]); + PADDLE_ENFORCE_GT(outputs.size(), 0); + auto output = outputs.back(); + PADDLE_ENFORCE_EQ(output.size(), 1UL); + size_t size = GetSize(output[0]); PADDLE_ENFORCE_GT(size, 0); - float *result = static_cast(outputs[0].data.data()); + float *result = static_cast(output[0].data.data()); // output is probability, which is in (0, 1). for (size_t i = 0; i < size; i++) { EXPECT_GT(result[i], 0); diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index 629981d565f1b6eeabc192287cb9f892df21b8e4..d4330e6cddf8818ace01be2f13a4c18a192c46e1 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -40,7 +40,7 @@ void profile(bool use_mkldnn = false) { if (use_mkldnn) { cfg.EnableMKLDNN(); } - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc index dcf4b38ce8a9230148738cfd0840ca96b0c7cf8c..54fd3a4a4caba52110ab636e6d44ee2a473f0cb0 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc @@ -229,7 +229,7 @@ TEST(Analyzer_rnn1, profile) { SetConfig(&cfg); cfg.DisableGpu(); cfg.SwitchIrDebug(); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -280,7 +280,7 @@ TEST(Analyzer_rnn1, compare_determine) { TEST(Analyzer_rnn1, multi_thread) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); diff --git a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc index 007f9f0b66a7b276f5f2e8500a3001788ad41e79..9ccbf58cbd2bbaab9b1a132c27e50356e1a5df37 100644 --- a/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_rnn2_tester.cc @@ -126,7 +126,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_rnn2, profile) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector 
outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -136,9 +136,11 @@ TEST(Analyzer_rnn2, profile) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { // the first inference result PADDLE_ENFORCE_GT(outputs.size(), 0); - size_t size = GetSize(outputs[0]); + auto output = outputs.back(); + PADDLE_ENFORCE_GT(output.size(), 0); + size_t size = GetSize(output[0]); PADDLE_ENFORCE_GT(size, 0); - float *result = static_cast(outputs[0].data.data()); + float *result = static_cast(output[0].data.data()); for (size_t i = 0; i < size; i++) { EXPECT_NEAR(result[i], result_data[i], 1e-3); } diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc index 47c1d7375843e4bad212c1d7d621c9e6d45e5982..9f23b9f037bcaeb758312d011067ae29c82e73cd 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc @@ -110,7 +110,7 @@ void SetInput(std::vector> *inputs) { TEST(Analyzer_seq_conv1, profile) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -119,10 +119,12 @@ TEST(Analyzer_seq_conv1, profile) { if (FLAGS_num_threads == 1 && !FLAGS_test_all_data) { // the first inference result - PADDLE_ENFORCE_EQ(outputs.size(), 1UL); - size_t size = GetSize(outputs[0]); + PADDLE_ENFORCE_GT(outputs.size(), 0); + auto output = outputs.back(); + PADDLE_ENFORCE_EQ(output.size(), 1UL); + size_t size = GetSize(output[0]); PADDLE_ENFORCE_GT(size, 0); - float *result = static_cast(outputs[0].data.data()); + float *result = static_cast(output[0].data.data()); // output is probability, which is in (0, 1). for (size_t i = 0; i < size; i++) { EXPECT_GT(result[i], 0); diff --git a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc index 19fa5528da4d11d2eb1a2f932f60a84c3f5468e7..d6f7f468a6c83bd6c4ac087931d0c6b7cac3cc1c 100644 --- a/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_seq_pool1_tester.cc @@ -156,7 +156,7 @@ void profile(bool use_mkldnn = false) { AnalysisConfig cfg; SetConfig(&cfg, use_mkldnn); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); TestPrediction(reinterpret_cast(&cfg), diff --git a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc index 2003be82019333ca97b9fa8ef83668825fe5710d..54492dbc238bbaf25f86b300fdd6585f74365088 100644 --- a/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc @@ -70,7 +70,7 @@ TEST(Analyzer_Text_Classification, profile) { AnalysisConfig cfg; SetConfig(&cfg); cfg.SwitchIrDebug(); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -79,8 +79,9 @@ TEST(Analyzer_Text_Classification, profile) { if (FLAGS_num_threads == 1) { // Get output - LOG(INFO) << "get outputs " << outputs.size(); - for (auto &output : outputs) { + PADDLE_ENFORCE_GT(outputs.size(), 0); + LOG(INFO) << "get outputs " << outputs.back().size(); + for (auto &output : outputs.back()) { LOG(INFO) << "output.shape: " << to_string(output.shape); // no lod ? 
CHECK_EQ(output.lod.size(), 0UL); diff --git a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc index a925da312cde30380b4997b8b76a0d425a71e817..bd4f1b61973fb0de06dcc288e329c94756d5ed47 100644 --- a/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_transformer_tester.cc @@ -186,7 +186,7 @@ void SetInput(std::vector> *inputs) { void profile(bool use_mkldnn = false) { AnalysisConfig cfg; SetConfig(&cfg); - std::vector outputs; + std::vector> outputs; if (use_mkldnn) { cfg.EnableMKLDNN(); } diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc index ca04c1365cbbffcb4a2786cde9ab240cc20aa3d8..fb47048cd0ccc887927cb4b533d45df11ef633eb 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc @@ -87,7 +87,7 @@ void profile(bool use_mkldnn = false) { cfg.EnableMKLDNN(); } // cfg.pass_builder()->TurnOnDebug(); - std::vector outputs; + std::vector> outputs; std::vector> input_slots_all; SetInput(&input_slots_all); @@ -100,7 +100,8 @@ void profile(bool use_mkldnn = false) { auto refer = ProcessALine(line); file.close(); - auto &output = outputs.front(); + PADDLE_ENFORCE_GT(outputs.size(), 0); + auto &output = outputs.back().front(); size_t numel = output.data.length() / PaddleDtypeSize(output.dtype); CHECK_EQ(numel, refer.data.size()); for (size_t i = 0; i < numel; ++i) { diff --git a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py index 4d968c83d9c9bf9d947204d73f4460e62039cdda..842865933f2b4741aea034b19952d4c59344ba06 100644 --- a/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py +++ b/paddle/fluid/inference/tests/api/full_ILSVRC2012_val_preprocess.py @@ -1,5 +1,4 @@ # copyright (c) 2019 paddlepaddle authors. all rights reserved. -# # licensed under the apache license, version 2.0 (the "license"); # you may not use this file except in compliance with the license. # you may obtain a copy of the license at @@ -11,6 +10,7 @@ # without warranties or conditions of any kind, either express or implied. # see the license for the specific language governing permissions and # limitations under the license. 
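// Editor's note: an illustrative C++ sketch, not part of the patch, of the
// binary layout that the preprocessing script below writes and that the
// TensorReader in analyzer_int8_image_classification_tester.cc reads back:
//   int64_t num_images; float images[num_images][3][224][224]; int64_t labels[num_images];
// The file name and sample index are placeholders.
#include <cstdint>
#include <fstream>
#include <vector>

int main() {
  const size_t kImageNumel = 3 * 224 * 224;
  std::ifstream file("int8_full_val.bin", std::ios::binary);  // placeholder path

  int64_t num_images = 0;
  file.read(reinterpret_cast<char *>(&num_images), sizeof(num_images));

  const size_t idx = 0;  // placeholder sample index
  std::vector<float> image(kImageNumel);
  file.seekg(sizeof(int64_t) + sizeof(float) * kImageNumel * idx);
  file.read(reinterpret_cast<char *>(image.data()), sizeof(float) * kImageNumel);

  int64_t label = 0;
  file.seekg(sizeof(int64_t) + sizeof(float) * kImageNumel * num_images +
             sizeof(int64_t) * idx);
  file.read(reinterpret_cast<char *>(&label), sizeof(label));
  return 0;
}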
+import hashlib import unittest import os import numpy as np @@ -21,16 +21,20 @@ import functools import contextlib from PIL import Image, ImageEnhance import math -from paddle.dataset.common import download +from paddle.dataset.common import download, md5file +import tarfile random.seed(0) np.random.seed(0) DATA_DIM = 224 - SIZE_FLOAT32 = 4 SIZE_INT64 = 8 - +FULL_SIZE_BYTES = 30106000008 +FULL_IMAGES = 50000 +DATA_DIR_NAME = 'ILSVRC2012' +IMG_DIR_NAME = 'var' +TARGET_HASH = '8dc592db6dcc8d521e4d5ba9da5ca7d2' img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) @@ -70,19 +74,9 @@ def process_image(img_path, mode, color_jitter, rotate): return img -def download_unzip(): - int8_download = 'int8/download' - - target_name = 'data' - - cache_folder = os.path.expanduser('~/.cache/paddle/dataset/' + - int8_download) - - target_folder = os.path.join(cache_folder, target_name) - +def download_concat(cache_folder, zip_path): data_urls = [] data_md5s = [] - data_urls.append( 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partaa' ) @@ -91,72 +85,138 @@ def download_unzip(): 'https://paddle-inference-dist.bj.bcebos.com/int8/ILSVRC2012_img_val.tar.gz.partab' ) data_md5s.append('1e9f15f64e015e58d6f9ec3210ed18b5') - file_names = [] - + print("Downloading full ImageNet Validation dataset ...") for i in range(0, len(data_urls)): download(data_urls[i], cache_folder, data_md5s[i]) - file_names.append(data_urls[i].split('/')[-1]) - - zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') - + file_name = os.path.join(cache_folder, data_urls[i].split('/')[-1]) + file_names.append(file_name) + print("Downloaded part {0}\n".format(file_name)) if not os.path.exists(zip_path): - cat_command = 'cat' - for file_name in file_names: - cat_command += ' ' + os.path.join(cache_folder, file_name) - cat_command += ' > ' + zip_path - os.system(cat_command) - print('Data is downloaded at {0}\n').format(zip_path) - - if not os.path.exists(target_folder): - cmd = 'mkdir {0} && tar xf {1} -C {0}'.format(target_folder, zip_path) - os.system(cmd) - print('Data is unzipped at {0}\n'.format(target_folder)) - - data_dir = os.path.join(target_folder, 'ILSVRC2012') - print('ILSVRC2012 full val set at {0}\n'.format(data_dir)) - return data_dir + with open(zip_path, "w+") as outfile: + for fname in file_names: + with open(fname) as infile: + outfile.write(infile.read()) + + +def extract(zip_path, extract_folder): + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + img_dir = os.path.join(data_dir, IMG_DIR_NAME) + print("Extracting...\n") + + if not (os.path.exists(img_dir) and + len(os.listdir(img_dir)) == FULL_IMAGES): + tar = tarfile.open(zip_path) + tar.extractall(path=extract_folder) + tar.close() + print('Extracted. Full Imagenet Validation dataset is located at {0}\n'. + format(data_dir)) + + +def print_processbar(done, total): + done_filled = done * '=' + empty_filled = (total - done) * ' ' + percentage_done = done * 100 / total + sys.stdout.write("\r[%s%s]%d%%" % + (done_filled, empty_filled, percentage_done)) + sys.stdout.flush() + + +def check_integrity(filename, target_hash): + print('\nThe binary file exists. 
Checking file integrity...\n') + md = hashlib.md5() + count = 0 + total_parts = 50 + chunk_size = 8192 + onepart = FULL_SIZE_BYTES / chunk_size / total_parts + with open(filename) as ifs: + while True: + buf = ifs.read(8192) + if count % onepart == 0: + done = count / onepart + print_processbar(done, total_parts) + count = count + 1 + if not buf: + break + md.update(buf) + hash1 = md.hexdigest() + if hash1 == target_hash: + return True + else: + return False -def reader(): - data_dir = download_unzip() - file_list = os.path.join(data_dir, 'val_list.txt') - output_file = os.path.join(data_dir, 'int8_full_val.bin') +def convert(file_list, data_dir, output_file): + print('Converting 50000 images to binary file ...\n') with open(file_list) as flist: lines = [line.strip() for line in flist] num_images = len(lines) - if not os.path.exists(output_file): - print( - 'Preprocessing to binary file......\n' - ) - with open(output_file, "w+b") as of: - #save num_images(int64_t) to file - of.seek(0) - num = np.array(int(num_images)).astype('int64') - of.write(num.tobytes()) - for idx, line in enumerate(lines): - img_path, label = line.split() - img_path = os.path.join(data_dir, img_path) - if not os.path.exists(img_path): - continue - - #save image(float32) to file - img = process_image( - img_path, 'val', color_jitter=False, rotate=False) - np_img = np.array(img) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * idx) - of.write(np_img.astype('float32').tobytes()) - - #save label(int64_t) to file - label_int = (int)(label) - np_label = np.array(label_int) - of.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 - * num_images + idx * SIZE_INT64) - of.write(np_label.astype('int64').tobytes()) - - print('The preprocessed binary file path {}\n'.format(output_file)) + with open(output_file, "w+b") as ofs: + #save num_images(int64_t) to file + ofs.seek(0) + num = np.array(int(num_images)).astype('int64') + ofs.write(num.tobytes()) + per_parts = 1000 + full_parts = FULL_IMAGES / per_parts + print_processbar(0, full_parts) + for idx, line in enumerate(lines): + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + if not os.path.exists(img_path): + continue + + #save image(float32) to file + img = process_image( + img_path, 'val', color_jitter=False, rotate=False) + np_img = np.array(img) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + idx) + ofs.write(np_img.astype('float32').tobytes()) + ofs.flush() + + #save label(int64_t) to file + label_int = (int)(label) + np_label = np.array(label_int) + ofs.seek(SIZE_INT64 + SIZE_FLOAT32 * DATA_DIM * DATA_DIM * 3 * + num_images + idx * SIZE_INT64) + ofs.write(np_label.astype('int64').tobytes()) + ofs.flush() + if (idx + 1) % per_parts == 0: + done = (idx + 1) / per_parts + print_processbar(done, full_parts) + print("Conversion finished.") + + +def run_convert(): + print('Start to download and convert 50000 images to binary file...') + cache_folder = os.path.expanduser('~/.cache/paddle/dataset/int8/download') + extract_folder = os.path.join(cache_folder, 'full_data') + data_dir = os.path.join(extract_folder, DATA_DIR_NAME) + file_list = os.path.join(data_dir, 'val_list.txt') + zip_path = os.path.join(cache_folder, 'full_imagenet_val.tar.gz') + output_file = os.path.join(cache_folder, 'int8_full_val.bin') + retry = 0 + try_limit = 3 + + while not (os.path.exists(output_file) and + os.path.getsize(output_file) == FULL_SIZE_BYTES and + check_integrity(output_file, TARGET_HASH)): + if 
os.path.exists(output_file): + sys.stderr.write( + "\n\nThe existing binary file is broken. Start to generate new one...\n\n". + format(output_file)) + os.remove(output_file) + if retry < try_limit: + retry = retry + 1 + else: + raise RuntimeError( + "Can not convert the dataset to binary file with try limit {0}". + format(try_limit)) + download_concat(cache_folder, zip_path) + extract(zip_path, extract_folder) + convert(file_list, data_dir, output_file) + print("\nSuccess! The binary file can be found at {0}".format(output_file)) if __name__ == '__main__': - reader() + run_convert() diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 33f1d0254858814be20eee1a6c2faaf00c2e8178..9a0dcc722cf00984b8c0e3ac20f13849e2904102 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -41,7 +41,10 @@ DEFINE_string(model_name, "", "model name"); DEFINE_string(infer_model, "", "model path"); DEFINE_string(infer_data, "", "data file"); DEFINE_string(refer_result, "", "reference result for comparison"); -DEFINE_int32(batch_size, 1, "batch size."); +DEFINE_int32(batch_size, 1, "batch size"); +DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup"); +// setting iterations to 0 means processing the whole dataset +DEFINE_int32(iterations, 0, "number of batches to process"); DEFINE_int32(repeat, 1, "Running the inference program repeat times."); DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); @@ -239,7 +242,7 @@ void SetFakeImageInput(std::vector> *inputs, } input.shape = shape; input.dtype = PaddleDType::FLOAT32; - size_t len = std::accumulate(shape.begin(), shape.end(), 1, + size_t len = std::accumulate(shape.begin(), shape.end(), size_t{1}, [](int a, int b) { return a * b; }); input.data.Resize(len * sizeof(float)); input.lod.assign({{0, static_cast(FLAGS_batch_size)}}); @@ -286,17 +289,18 @@ void ConvertPaddleTensorToZeroCopyTensor( void PredictionWarmUp(PaddlePredictor *predictor, const std::vector> &inputs, - std::vector *outputs, int num_threads, - int tid) { + std::vector> *outputs, + int num_threads, int tid) { int batch_size = FLAGS_batch_size; LOG(INFO) << "Running thread " << tid << ", warm up run..."; if (FLAGS_zero_copy) { ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]); } + outputs->resize(1); Timer warmup_timer; warmup_timer.tic(); if (!FLAGS_zero_copy) { - predictor->Run(inputs[0], outputs, batch_size); + predictor->Run(inputs[0], &(*outputs)[0], batch_size); } else { predictor->ZeroCopyRun(); } @@ -308,11 +312,16 @@ void PredictionWarmUp(PaddlePredictor *predictor, void PredictionRun(PaddlePredictor *predictor, const std::vector> &inputs, - std::vector *outputs, int num_threads, - int tid) { - int batch_size = FLAGS_batch_size; + std::vector> *outputs, + int num_threads, int tid) { int num_times = FLAGS_repeat; - LOG(INFO) << "Thread " << tid << " run " << num_times << " times..."; + int iterations = inputs.size(); // process the whole dataset ... + if (FLAGS_iterations > 0 && FLAGS_iterations < inputs.size()) + iterations = + FLAGS_iterations; // ... 
unless the number of iterations is set + outputs->resize(iterations); + LOG(INFO) << "Thread " << tid << ", number of threads " << num_threads + << ", run " << num_times << " times..."; Timer run_timer; double elapsed_time = 0; #ifdef WITH_GPERFTOOLS @@ -320,14 +329,14 @@ void PredictionRun(PaddlePredictor *predictor, #endif if (!FLAGS_zero_copy) { run_timer.tic(); - for (size_t i = 0; i < inputs.size(); i++) { + for (size_t i = 0; i < iterations; i++) { for (int j = 0; j < num_times; j++) { - predictor->Run(inputs[i], outputs, batch_size); + predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size); } } elapsed_time = run_timer.toc(); } else { - for (size_t i = 0; i < inputs.size(); i++) { + for (size_t i = 0; i < iterations; i++) { ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]); run_timer.tic(); for (int j = 0; j < num_times; j++) { @@ -340,13 +349,14 @@ void PredictionRun(PaddlePredictor *predictor, ProfilerStop(); #endif - PrintTime(batch_size, num_times, num_threads, tid, elapsed_time / num_times, - inputs.size()); + auto batch_latency = elapsed_time / (iterations * num_times); + PrintTime(FLAGS_batch_size, num_times, num_threads, tid, batch_latency, + iterations); if (FLAGS_record_benchmark) { Benchmark benchmark; benchmark.SetName(FLAGS_model_name); - benchmark.SetBatchSize(batch_size); - benchmark.SetLatency(elapsed_time / num_times); + benchmark.SetBatchSize(FLAGS_batch_size); + benchmark.SetLatency(batch_latency); benchmark.PersistToFile("benchmark_record.txt"); } } @@ -354,16 +364,17 @@ void PredictionRun(PaddlePredictor *predictor, void TestOneThreadPrediction( const PaddlePredictor::Config *config, const std::vector> &inputs, - std::vector *outputs, bool use_analysis = true) { + std::vector> *outputs, bool use_analysis = true) { auto predictor = CreateTestPredictor(config, use_analysis); - PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0); - PredictionRun(predictor.get(), inputs, outputs, 1, 0); + PredictionWarmUp(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, + 0); + PredictionRun(predictor.get(), inputs, outputs, FLAGS_paddle_num_threads, 0); } void TestMultiThreadPrediction( const PaddlePredictor::Config *config, const std::vector> &inputs, - std::vector *outputs, int num_threads, + std::vector> *outputs, int num_threads, bool use_analysis = true) { std::vector threads; std::vector> predictors; @@ -376,7 +387,7 @@ void TestMultiThreadPrediction( threads.emplace_back([&, tid]() { // Each thread should have local inputs and outputs. // The inputs of each thread are all the same. 
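// Editor's note: a standalone sketch, not part of the patch, of the timing
// arithmetic used by the reworked PredictionRun()/PrintTime(): the timer spans
// iterations * repeat batch runs, so batch latency divides the elapsed time by
// that product, sample latency divides further by the batch size, and fps is
// samples per second. All numbers below are placeholders.
#include <iostream>

int main() {
  double elapsed_ms = 12000.0;  // placeholder: total time measured across all runs
  int iterations = 20, repeat = 3, batch_size = 50;
  double batch_latency = elapsed_ms / (iterations * repeat);
  double sample_latency = batch_latency / batch_size;
  double fps = 1000.0 / sample_latency;
  std::cout << "batch latency: " << batch_latency
            << " ms, sample latency: " << sample_latency << " ms, fps: " << fps
            << std::endl;
  return 0;
}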
- std::vector outputs_tid; + std::vector> outputs_tid; auto &predictor = predictors[tid]; #ifdef PADDLE_WITH_MKLDNN if (use_analysis) { @@ -384,8 +395,8 @@ void TestMultiThreadPrediction( ->SetMkldnnThreadID(static_cast(tid) + 1); } #endif - PredictionWarmUp(predictor.get(), inputs, outputs, num_threads, tid); - PredictionRun(predictor.get(), inputs, outputs, num_threads, tid); + PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads, tid); + PredictionRun(predictor.get(), inputs, &outputs_tid, num_threads, tid); }); } for (int i = 0; i < num_threads; ++i) { @@ -395,8 +406,8 @@ void TestMultiThreadPrediction( void TestPrediction(const PaddlePredictor::Config *config, const std::vector> &inputs, - std::vector *outputs, int num_threads, - bool use_analysis = FLAGS_use_analysis) { + std::vector> *outputs, + int num_threads, bool use_analysis = FLAGS_use_analysis) { PrintConfig(config, use_analysis); if (num_threads == 1) { TestOneThreadPrediction(config, inputs, outputs, use_analysis); @@ -406,30 +417,41 @@ void TestPrediction(const PaddlePredictor::Config *config, } } -void CompareTopAccuracy(const std::vector &output_slots1, - const std::vector &output_slots2) { - // first output: avg_cost - if (output_slots1.size() == 0 || output_slots2.size() == 0) +void CompareTopAccuracy( + const std::vector> &output_slots_quant, + const std::vector> &output_slots_ref) { + if (output_slots_quant.size() == 0 || output_slots_ref.size() == 0) throw std::invalid_argument( "CompareTopAccuracy: output_slots vector is empty."); - PADDLE_ENFORCE(output_slots1.size() >= 2UL); - PADDLE_ENFORCE(output_slots2.size() >= 2UL); - // second output: acc_top1 - if (output_slots1[1].lod.size() > 0 || output_slots2[1].lod.size() > 0) - throw std::invalid_argument( - "CompareTopAccuracy: top1 accuracy output has nonempty LoD."); - if (output_slots1[1].dtype != paddle::PaddleDType::FLOAT32 || - output_slots2[1].dtype != paddle::PaddleDType::FLOAT32) - throw std::invalid_argument( - "CompareTopAccuracy: top1 accuracy output is of a wrong type."); - float *top1_quantized = static_cast(output_slots1[1].data.data()); - float *top1_reference = static_cast(output_slots2[1].data.data()); - LOG(INFO) << "top1 INT8 accuracy: " << *top1_quantized; - LOG(INFO) << "top1 FP32 accuracy: " << *top1_reference; + float total_accs1_quant{0}; + float total_accs1_ref{0}; + for (size_t i = 0; i < output_slots_quant.size(); ++i) { + PADDLE_ENFORCE(output_slots_quant[i].size() >= 2UL); + PADDLE_ENFORCE(output_slots_ref[i].size() >= 2UL); + // second output: acc_top1 + if (output_slots_quant[i][1].lod.size() > 0 || + output_slots_ref[i][1].lod.size() > 0) + throw std::invalid_argument( + "CompareTopAccuracy: top1 accuracy output has nonempty LoD."); + if (output_slots_quant[i][1].dtype != paddle::PaddleDType::FLOAT32 || + output_slots_ref[i][1].dtype != paddle::PaddleDType::FLOAT32) + throw std::invalid_argument( + "CompareTopAccuracy: top1 accuracy output is of a wrong type."); + total_accs1_quant += + *static_cast(output_slots_quant[i][1].data.data()); + total_accs1_ref += + *static_cast(output_slots_ref[i][1].data.data()); + } + float avg_acc1_quant = total_accs1_quant / output_slots_quant.size(); + float avg_acc1_ref = total_accs1_ref / output_slots_ref.size(); + + LOG(INFO) << "Avg top1 INT8 accuracy: " << std::fixed << std::setw(6) + << std::setprecision(4) << avg_acc1_quant; + LOG(INFO) << "Avg top1 FP32 accuracy: " << std::fixed << std::setw(6) + << std::setprecision(4) << avg_acc1_ref; LOG(INFO) << "Accepted accuracy drop 
threshold: " << FLAGS_quantized_accuracy; - CHECK_LE(std::abs(*top1_quantized - *top1_reference), - FLAGS_quantized_accuracy); + CHECK_LE(std::abs(avg_acc1_quant - avg_acc1_ref), FLAGS_quantized_accuracy); } void CompareDeterministic( @@ -455,20 +477,35 @@ void CompareNativeAndAnalysis( const PaddlePredictor::Config *config, const std::vector> &inputs) { PrintConfig(config, true); - std::vector native_outputs, analysis_outputs; + std::vector> native_outputs, analysis_outputs; TestOneThreadPrediction(config, inputs, &native_outputs, false); TestOneThreadPrediction(config, inputs, &analysis_outputs, true); - CompareResult(analysis_outputs, native_outputs); + PADDLE_ENFORCE(native_outputs.size() > 0, "Native output is empty."); + PADDLE_ENFORCE(analysis_outputs.size() > 0, "Analysis output is empty."); + CompareResult(analysis_outputs.back(), native_outputs.back()); } void CompareQuantizedAndAnalysis( - const PaddlePredictor::Config *config, - const PaddlePredictor::Config *qconfig, + const AnalysisConfig *config, const AnalysisConfig *qconfig, const std::vector> &inputs) { - PrintConfig(config, true); - std::vector analysis_outputs, quantized_outputs; - TestOneThreadPrediction(config, inputs, &analysis_outputs, true); - TestOneThreadPrediction(qconfig, inputs, &quantized_outputs, true); + PADDLE_ENFORCE_EQ(inputs[0][0].shape[0], FLAGS_batch_size, + "Input data has to be packed batch by batch."); + LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size + << ", warmup batch size " << FLAGS_warmup_batch_size << "."; + + LOG(INFO) << "--- FP32 prediction start ---"; + auto *cfg = reinterpret_cast(config); + PrintConfig(cfg, true); + std::vector> analysis_outputs; + TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true); + + LOG(INFO) << "--- INT8 prediction start ---"; + auto *qcfg = reinterpret_cast(qconfig); + PrintConfig(qcfg, true); + std::vector> quantized_outputs; + TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true); + + LOG(INFO) << "--- comparing outputs --- "; CompareTopAccuracy(quantized_outputs, analysis_outputs); } @@ -578,9 +615,9 @@ static bool CompareTensorData(const framework::LoDTensor &a, const framework::LoDTensor &b) { auto a_shape = framework::vectorize(a.dims()); auto b_shape = framework::vectorize(b.dims()); - size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), 1, + size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), size_t{1}, [](int a, int b) { return a * b; }); - size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), 1, + size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), size_t{1}, [](int a, int b) { return a * b; }); if (a_size != b_size) { LOG(ERROR) << string::Sprintf("tensor data size not match, %d != %d", diff --git a/paddle/fluid/inference/tests/api/trt_models_tester.cc b/paddle/fluid/inference/tests/api/trt_models_tester.cc index cb668a4174134ba3ce9517955ff740ada568e97b..98ce225a0476b38c021b0b81489f69d7953ae456 100644 --- a/paddle/fluid/inference/tests/api/trt_models_tester.cc +++ b/paddle/fluid/inference/tests/api/trt_models_tester.cc @@ -74,7 +74,7 @@ void profile(std::string model_dir, bool use_analysis, bool use_tensorrt) { SetFakeImageInput(&inputs_all, model_dir, false, "__model__", ""); } - std::vector outputs; + std::vector> outputs; if (use_analysis || use_tensorrt) { AnalysisConfig config; config.EnableUseGpu(100, 0);
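// Editor's note: a simplified sketch, not part of the patch, of the accuracy
// check now performed by CompareTopAccuracy(): the second output tensor
// (acc_top1) of every iteration is averaged separately for the INT8 and FP32
// runs, and the absolute difference must stay within the accepted accuracy
// drop threshold (FLAGS_quantized_accuracy). The values below are placeholders.
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  std::vector<float> top1_int8 = {0.71f, 0.69f, 0.70f};  // placeholder per-iteration acc_top1
  std::vector<float> top1_fp32 = {0.72f, 0.70f, 0.71f};
  const float accepted_drop = 0.01f;  // stands in for FLAGS_quantized_accuracy

  float avg_int8 = 0.f, avg_fp32 = 0.f;
  for (float v : top1_int8) avg_int8 += v;
  for (float v : top1_fp32) avg_fp32 += v;
  avg_int8 /= top1_int8.size();
  avg_fp32 /= top1_fp32.size();

  assert(std::abs(avg_int8 - avg_fp32) <= accepted_drop);
  return 0;
}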