Commit f615ba2f authored by luotao1

update the multi-thread unit-tests

Parent 35cff5e0
@@ -255,8 +255,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
   }
 }
 // Test with a really complicated model.
-void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
-                           int num_threads = FLAGS_num_threads) {
+void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
+                           int num_threads) {
   AnalysisConfig config;
   config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
   config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
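
Note on the hunk above: the old single flag use_analysis_and_activate_ir could not express "analysis on, IR off", so it is split into two explicit booleans, and the thread count loses its default so that every call site states it. Because the two booleans are positional and easy to mix up, call sites read better with inline parameter comments; a minimal illustrative sketch (not part of the patch):

    TestDituRNNPrediction(/*use_analysis=*/true, /*activate_ir=*/false,
                          /*num_threads=*/4);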
@@ -300,7 +300,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   // because AttentionLSTM's hard-coded node id will be damaged.
   for (int tid = 0; tid < num_threads; ++tid) {
     predictors.emplace_back(
-        CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
+        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
             config));
   }
   for (int tid = 0; tid < num_threads; ++tid) {
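
Note on the hunk above: the first template argument is corrected from NativeConfig to AnalysisConfig, so it now matches both PaddleEngineKind::kAnalysis and the AnalysisConfig object actually passed in. For context, a hedged sketch of the one-predictor-per-thread pattern this loop sets up, assuming std::thread workers and the input_slots prepared elsewhere in the test:

    #include <thread>
    #include <vector>

    // All predictors are created up front on the main thread (creation is
    // not assumed thread-safe here); each worker uses only its own predictor.
    std::vector<std::thread> threads;
    for (int tid = 0; tid < num_threads; ++tid) {
      threads.emplace_back([&predictors, &input_slots, tid] {
        std::vector<PaddleTensor> outputs;
        predictors[tid]->Run(input_slots, &outputs);
      });
    }
    for (auto &t : threads) t.join();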
@@ -326,7 +326,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
   LOG(INFO) << "=====================================";
-  if (use_analysis_and_activate_ir) {
+  if (use_analysis && activate_ir) {
     AnalysisPredictor *analysis_predictor =
         dynamic_cast<AnalysisPredictor *>(predictor.get());
     auto &fuse_statis = analysis_predictor->analysis_argument()
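
Note on the hunk above: fuse_statis is produced by the analysis/IR passes, so this block is now gated on both flags rather than the old combined one. Since dynamic_cast on a pointer yields nullptr on a mismatched type, the downcast could also be guarded; a small hardening sketch (suggested here, not in the patch):

    auto *analysis_predictor = dynamic_cast<AnalysisPredictor *>(predictor.get());
    ASSERT_TRUE(analysis_predictor);  // fail fast if analysis was not enabled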
@@ -353,15 +353,26 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
 }

-// basic unit-test of DituRNN, easy for profiling independently.
-TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
+// Inference with analysis and IR, easy to profile independently.
+TEST(Analyzer, DituRNN) {
+  TestDituRNNPrediction(true, true, FLAGS_num_threads);
+}

-// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
-// multi-threads.
-TEST(Analyzer, DituRNN_multi_thread) {
-  TestDituRNNPrediction(true, 1);
-  TestDituRNNPrediction(false, 4);
-  TestDituRNNPrediction(true, 4);
+// Other DituRNN unit-tests, covering the combinations of use_analysis,
+// activate_ir and the number of threads.
+TEST(Analyzer, DituRNN_tests) {
+  int num_threads[2] = {1, 4};
+  for (auto i : num_threads) {
+    // Directly infer with the original model.
+    TestDituRNNPrediction(false, false, i);
+    // Inference on the original model with analysis turned on; the analysis
+    // module transforms the program into a data flow graph.
+    TestDituRNNPrediction(true, false, i);
+    // Inference with analysis and IR; the IR passes fuse some large kernels.
+    TestDituRNNPrediction(true, true, i);
+  }
 }

 }  // namespace analysis
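
Note on the hunk above: DituRNN_tests now sweeps two thread counts across the three meaningful modes, replacing the old hand-picked calls. The same 2x3 coverage can also be written as an explicit matrix; a sketch, where kModes is a name introduced here purely for illustration:

    #include <utility>

    // The fourth flag combination (use_analysis=false, activate_ir=true) is
    // not a meaningful mode, so it is deliberately absent from the matrix.
    const std::pair<bool, bool> kModes[] = {
        {false, false},  // native run on the original program
        {true, false},   // analysis only: program -> data flow graph
        {true, true},    // analysis + IR: fuses some large kernels
    };
    for (int num_threads : {1, 4}) {
      for (const auto &mode : kModes) {
        TestDituRNNPrediction(mode.first, mode.second, num_threads);
      }
    }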