From bd77460182d083cb0b3cd8277623181ef3473145 Mon Sep 17 00:00:00 2001 From: Tao Luo Date: Tue, 16 Oct 2018 11:07:02 +0800 Subject: [PATCH] refine mkldnn test in analyzer_tests test=develop --- .../inference/tests/api/analyzer_resnet50_tester.cc | 13 ++++++++++++- .../inference/tests/api/analyzer_vis_tester.cc | 12 ++++++++++-- paddle/fluid/inference/tests/api/tester_helper.h | 6 +++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc index 290fb007d8..050f267fff 100644 --- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc @@ -20,13 +20,16 @@ namespace paddle { namespace inference { namespace analysis { -void SetConfig(AnalysisConfig *cfg) { +void SetConfig(AnalysisConfig *cfg, bool _use_mkldnn = FLAGS__use_mkldnn) { cfg->param_file = FLAGS_infer_model + "/params"; cfg->prog_file = FLAGS_infer_model + "/model"; cfg->use_gpu = false; cfg->device = 0; cfg->enable_ir_optim = true; cfg->specify_input_name = true; +#ifdef PADDLE_WITH_MKLDNN + cfg->_use_mkldnn = _use_mkldnn; +#endif } void SetInput(std::vector> *inputs) { @@ -89,6 +92,14 @@ TEST(Analyzer_resnet50, compare) { std::vector> input_slots_all; SetInput(&input_slots_all); CompareNativeAndAnalysis(cfg, input_slots_all); +#ifdef PADDLE_WITH_MKLDNN + // since default config._use_mkldnn=true in this case, + // we should compare analysis_outputs in config._use_mkldnn=false + // with native_outputs as well. + AnalysisConfig cfg1; + SetConfig(&cfg1, false); + CompareNativeAndAnalysis(cfg1, input_slots_all); +#endif } } // namespace analysis diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc index 305b8bfe15..07398ed26c 100644 --- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc +++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc @@ -50,7 +50,7 @@ Record ProcessALine(const std::string &line) { return record; } -void SetConfig(AnalysisConfig *cfg) { +void SetConfig(AnalysisConfig *cfg, bool _use_mkldnn = FLAGS__use_mkldnn) { cfg->param_file = FLAGS_infer_model + "/__params__"; cfg->prog_file = FLAGS_infer_model + "/__model__"; cfg->use_gpu = false; @@ -60,7 +60,7 @@ void SetConfig(AnalysisConfig *cfg) { // TODO(TJ): fix fusion gru cfg->ir_passes.push_back("fc_gru_fuse_pass"); #ifdef PADDLE_WITH_MKLDNN - cfg->_use_mkldnn = true; + cfg->_use_mkldnn = _use_mkldnn; #endif } @@ -125,6 +125,14 @@ TEST(Analyzer_vis, compare) { std::vector> input_slots_all; SetInput(&input_slots_all); CompareNativeAndAnalysis(cfg, input_slots_all); +#ifdef PADDLE_WITH_MKLDNN + // since default config._use_mkldnn=true in this case, + // we should compare analysis_outputs in config._use_mkldnn=false + // with native_outputs as well. + AnalysisConfig cfg1; + SetConfig(&cfg1, false); + CompareNativeAndAnalysis(cfg1, input_slots_all); +#endif } } // namespace analysis diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 8603d09cbd..fe3ee5bcd7 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -35,6 +35,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file."); DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads."); DEFINE_bool(use_analysis, true, "Running the inference program in analysis mode."); +DEFINE_bool(_use_mkldnn, true, + "Running the inference program with mkldnn library."); namespace paddle { namespace inference { @@ -165,7 +167,8 @@ void TestPrediction(const AnalysisConfig &config, const std::vector> &inputs, std::vector *outputs, int num_threads, bool use_analysis = FLAGS_use_analysis) { - LOG(INFO) << "use_analysis: " << use_analysis; + LOG(INFO) << "use_analysis: " << use_analysis + << ", use_mkldnn: " << config._use_mkldnn; if (num_threads == 1) { TestOneThreadPrediction(config, inputs, outputs, use_analysis); } else { @@ -177,6 +180,7 @@ void TestPrediction(const AnalysisConfig &config, void CompareNativeAndAnalysis( const AnalysisConfig &config, const std::vector> &inputs) { + LOG(INFO) << "use_mkldnn: " << config._use_mkldnn; std::vector native_outputs, analysis_outputs; TestOneThreadPrediction(config, inputs, &native_outputs, false); TestOneThreadPrediction(config, inputs, &analysis_outputs, true); -- GitLab