diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc
index 29c86bf787bea4f5ab1f647295e40ff7deabc84e..4cf26d3c70eafd951d14c26335416ec2c71c001d 100644
--- a/paddle/fluid/inference/analysis/analyzer_tester.cc
+++ b/paddle/fluid/inference/analysis/analyzer_tester.cc
@@ -255,8 +255,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 // Test with a really complicate model.
-void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
-                           int num_threads = FLAGS_num_threads) {
+void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
+                           int num_threads) {
   AnalysisConfig config;
   config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
   config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
@@ -300,7 +300,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
     // because AttentionLSTM's hard code nodeid will be damanged.
     for (int tid = 0; tid < num_threads; ++tid) {
       predictors.emplace_back(
-          CreatePaddlePredictor(
+          CreatePaddlePredictor(
               config));
     }
     for (int tid = 0; tid < num_threads; ++tid) {
@@ -326,7 +326,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
   LOG(INFO) << "=====================================";
 
-  if (use_analysis_and_activate_ir) {
+  if (use_analysis && activate_ir) {
     AnalysisPredictor *analysis_predictor =
         dynamic_cast<AnalysisPredictor *>(predictor.get());
     auto &fuse_statis = analysis_predictor->analysis_argument()
@@ -353,15 +353,26 @@
   }
 }
 
-// basic unit-test of DituRNN, easy for profiling independently.
-TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
+// Inference with analysis and IR, easy for profiling independently.
+TEST(Analyzer, DituRNN) {
+  TestDituRNNPrediction(true, true, FLAGS_num_threads);
+}
 
-// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
-// multi-threads.
-TEST(Analyzer, DituRNN_multi_thread) {
-  TestDituRNNPrediction(true, 1);
-  TestDituRNNPrediction(false, 4);
-  TestDituRNNPrediction(true, 4);
+// Other unit-tests of DituRNN, test different options of use_analysis,
+// activate_ir and multi-threads.
+TEST(Analyzer, DituRNN_tests) {
+  int num_threads[2] = {1, 4};
+  for (auto i : num_threads) {
+    // Directly infer with the original model.
+    TestDituRNNPrediction(false, false, i);
+    // Inference with the original model with the analysis turned on, the
+    // analysis
+    // module will transform the program to a data flow graph.
+    TestDituRNNPrediction(true, false, i);
+    // Inference with analysis and IR. The IR module will fuse some large
+    // kernels.
+    TestDituRNNPrediction(true, true, i);
+  }
 }
 
 }  // namespace analysis