diff --git a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
index 290fb007d8ba94a2d121947fe67c6474586ac0e0..050f267fff3b82e60e6ab5c17be0f251161bbfab 100644
--- a/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_resnet50_tester.cc
@@ -20,13 +20,16 @@ namespace paddle {
 namespace inference {
 namespace analysis {
 
-void SetConfig(AnalysisConfig *cfg) {
+void SetConfig(AnalysisConfig *cfg, bool _use_mkldnn = FLAGS__use_mkldnn) {
   cfg->param_file = FLAGS_infer_model + "/params";
   cfg->prog_file = FLAGS_infer_model + "/model";
   cfg->use_gpu = false;
   cfg->device = 0;
   cfg->enable_ir_optim = true;
   cfg->specify_input_name = true;
+#ifdef PADDLE_WITH_MKLDNN
+  cfg->_use_mkldnn = _use_mkldnn;
+#endif
 }
 
 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
@@ -89,6 +92,14 @@ TEST(Analyzer_resnet50, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
   CompareNativeAndAnalysis(cfg, input_slots_all);
+#ifdef PADDLE_WITH_MKLDNN
+  // since default config._use_mkldnn=true in this case,
+  // we should compare analysis_outputs in config._use_mkldnn=false
+  // with native_outputs as well.
+  AnalysisConfig cfg1;
+  SetConfig(&cfg1, false);
+  CompareNativeAndAnalysis(cfg1, input_slots_all);
+#endif
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
index 305b8bfe158150d5dfd8bdaee2c0a89afe264de4..07398ed26c745de46309a15675da07686eca3d9e 100644
--- a/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
+++ b/paddle/fluid/inference/tests/api/analyzer_vis_tester.cc
@@ -50,7 +50,7 @@ Record ProcessALine(const std::string &line) {
   return record;
 }
 
-void SetConfig(AnalysisConfig *cfg) {
+void SetConfig(AnalysisConfig *cfg, bool _use_mkldnn = FLAGS__use_mkldnn) {
   cfg->param_file = FLAGS_infer_model + "/__params__";
   cfg->prog_file = FLAGS_infer_model + "/__model__";
   cfg->use_gpu = false;
@@ -60,7 +60,7 @@ void SetConfig(AnalysisConfig *cfg) {
   // TODO(TJ): fix fusion gru
   cfg->ir_passes.push_back("fc_gru_fuse_pass");
 #ifdef PADDLE_WITH_MKLDNN
-  cfg->_use_mkldnn = true;
+  cfg->_use_mkldnn = _use_mkldnn;
 #endif
 }
 
@@ -125,6 +125,14 @@ TEST(Analyzer_vis, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);
   CompareNativeAndAnalysis(cfg, input_slots_all);
+#ifdef PADDLE_WITH_MKLDNN
+  // since default config._use_mkldnn=true in this case,
+  // we should compare analysis_outputs in config._use_mkldnn=false
+  // with native_outputs as well.
+  AnalysisConfig cfg1;
+  SetConfig(&cfg1, false);
+  CompareNativeAndAnalysis(cfg1, input_slots_all);
+#endif
 }
 
 }  // namespace analysis
diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h
index 8603d09cbd09e4cee254faf52c2ccbe8d661994d..fe3ee5bcd712d288cc73d77c9772b7ba237271af 100644
--- a/paddle/fluid/inference/tests/api/tester_helper.h
+++ b/paddle/fluid/inference/tests/api/tester_helper.h
@@ -35,6 +35,8 @@ DEFINE_bool(test_all_data, false, "Test the all dataset in data file.");
 DEFINE_int32(num_threads, 1, "Running the inference program in multi-threads.");
 DEFINE_bool(use_analysis, true,
             "Running the inference program in analysis mode.");
+DEFINE_bool(_use_mkldnn, true,
+            "Running the inference program with mkldnn library.");
 
 namespace paddle {
 namespace inference {
@@ -165,7 +167,8 @@ void TestPrediction(const AnalysisConfig &config,
                     const std::vector<std::vector<PaddleTensor>> &inputs,
                     std::vector<PaddleTensor> *outputs, int num_threads,
                     bool use_analysis = FLAGS_use_analysis) {
-  LOG(INFO) << "use_analysis: " << use_analysis;
+  LOG(INFO) << "use_analysis: " << use_analysis
+            << ", use_mkldnn: " << config._use_mkldnn;
   if (num_threads == 1) {
     TestOneThreadPrediction(config, inputs, outputs, use_analysis);
   } else {
@@ -177,6 +180,7 @@ void TestPrediction(const AnalysisConfig &config,
 void CompareNativeAndAnalysis(
     const AnalysisConfig &config,
     const std::vector<std::vector<PaddleTensor>> &inputs) {
+  LOG(INFO) << "use_mkldnn: " << config._use_mkldnn;
   std::vector<PaddleTensor> native_outputs, analysis_outputs;
   TestOneThreadPrediction(config, inputs, &native_outputs, false);
   TestOneThreadPrediction(config, inputs, &analysis_outputs, true);