Commit f615ba2f authored by luotao1

update the multi-thread unit-tests

Parent 35cff5e0
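This commit changes TestDituRNNPrediction to take explicit use_analysis and activate_ir flags plus a thread count, and the multi-thread path now builds one analysis predictor per thread. The per-thread Run loop itself is collapsed in the hunks below, so what follows is only a minimal sketch of that pattern: RunMultiThreaded is a hypothetical helper, and the header path and PaddlePredictor::Run signature are assumed from the inference API of this era rather than taken from the diff.

// Minimal sketch (not part of this commit): run inference from several
// threads, giving each thread its own predictor.
#include <memory>
#include <thread>
#include <vector>

#include "paddle/fluid/inference/api/paddle_inference_api.h"  // assumed header path

using namespace paddle;  // PaddleTensor, AnalysisConfig, PaddlePredictor live here

// Hypothetical helper illustrating the pattern used by the updated test.
void RunMultiThreaded(const AnalysisConfig &config,
                      const std::vector<PaddleTensor> &inputs,
                      int num_threads) {
  // One predictor per thread; the test avoids sharing or cloning a single
  // predictor because AttentionLSTM's hard-coded node ids would be damaged.
  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
  for (int tid = 0; tid < num_threads; ++tid) {
    predictors.emplace_back(
        CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
            config));
  }
  std::vector<std::thread> threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid] {
      std::vector<PaddleTensor> outputs;
      predictors[tid]->Run(inputs, &outputs);  // assumed Run(inputs, *outputs) signature
    });
  }
  for (auto &t : threads) t.join();
}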
@@ -255,8 +255,8 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
}
}
// Test with a really complicated model.
void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
int num_threads = FLAGS_num_threads) {
void TestDituRNNPrediction(bool use_analysis, bool activate_ir,
int num_threads) {
AnalysisConfig config;
config.prog_file = FLAGS_infer_ditu_rnn_model + "/__model__";
config.param_file = FLAGS_infer_ditu_rnn_model + "/param";
@@ -300,7 +300,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
// because AttentionLSTM's hard-coded node ids would be damaged.
for (int tid = 0; tid < num_threads; ++tid) {
predictors.emplace_back(
CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kAnalysis>(
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
config));
}
for (int tid = 0; tid < num_threads; ++tid) {
@@ -326,7 +326,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
}
LOG(INFO) << "=====================================";
if (use_analysis_and_activate_ir) {
if (use_analysis && activate_ir) {
AnalysisPredictor *analysis_predictor =
dynamic_cast<AnalysisPredictor *>(predictor.get());
auto &fuse_statis = analysis_predictor->analysis_argument()
@@ -353,15 +353,26 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
}
}
// basic unit-test of DituRNN, easy for profiling independently.
TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
// Inference with analysis and IR, easy for profiling independently.
TEST(Analyzer, DituRNN) {
TestDituRNNPrediction(true, true, FLAGS_num_threads);
}
// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
// multi-threads.
TEST(Analyzer, DituRNN_multi_thread) {
TestDituRNNPrediction(true, 1);
TestDituRNNPrediction(false, 4);
TestDituRNNPrediction(true, 4);
// Other unit-tests of DituRNN, covering the different combinations of
// use_analysis, activate_ir and multi-threading.
TEST(Analyzer, DituRNN_tests) {
int num_threads[2] = {1, 4};
for (auto i : num_threads) {
// Directly infer with the original model.
TestDituRNNPrediction(false, false, i);
// Inference with the original model with analysis turned on; the analysis
// module will transform the program into a data flow graph.
TestDituRNNPrediction(true, false, i);
// Inference with analysis and IR. The IR passes will fuse some operators
// into larger kernels.
TestDituRNNPrediction(true, true, i);
}
}
} // namespace analysis
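The flags the test reads (FLAGS_infer_ditu_rnn_model and FLAGS_num_threads) are defined elsewhere and do not appear in the hunks above. A sketch of what such definitions typically look like is below; the defaults and help strings are illustrative only, not copied from the source.

#include <gflags/gflags.h>

// Assumed flag definitions; defaults and descriptions are illustrative.
DEFINE_string(infer_ditu_rnn_model, "",
              "Directory of the DituRNN inference model.");
DEFINE_int32(num_threads, 1,
             "Number of threads used by the multi-thread tests.");

With these in place, one configuration can still be profiled in isolation by running the test binary with a gtest filter, for example --gtest_filter=Analyzer.DituRNN together with --num_threads.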