From 39ed1487144153b5a13cb8943c526b635d65d795 Mon Sep 17 00:00:00 2001 From: luotao1 Date: Tue, 4 Sep 2018 16:09:44 +0800 Subject: [PATCH] fix multi-thread hang temporary --- .../inference/analysis/analyzer_tester.cc | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/paddle/fluid/inference/analysis/analyzer_tester.cc b/paddle/fluid/inference/analysis/analyzer_tester.cc index d36c5bfb75..3aa28479af 100644 --- a/paddle/fluid/inference/analysis/analyzer_tester.cc +++ b/paddle/fluid/inference/analysis/analyzer_tester.cc @@ -260,11 +260,7 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false, LOG(INFO) << "===========profile result==========="; if (num_threads == 1) { - std::vector input_slots; // Prepare inputs. - DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size); - PrepareInputs(&input_slots, &data, batch_size); - Timer timer; timer.tic(); for (int i = 0; i < num_times; i++) { @@ -273,21 +269,20 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false, print_time(batch_size, num_times, 1, 0, timer.toc() / num_times); } else { std::vector threads; - std::vector input_slots; - // Prepare inputs. - PrepareInputs(&input_slots, &data, batch_size); - std::vector outputs; + std::vector> predictors; + // TODO(yanchunwei): Bug here, the analyzer phase can't be parallelled + // because AttentionLSTM's hard code nodeid will be damanged. + for (int tid = 0; tid < num_threads; ++tid) { + predictors.emplace_back( + CreatePaddlePredictor( + config)); + } for (int tid = 0; tid < num_threads; ++tid) { threads.emplace_back([&, tid]() { - auto predictor_tid = - CreatePaddlePredictor( - config); - DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size); - Timer timer; timer.tic(); for (int i = 0; i < num_times; i++) { - predictor_tid->Run(input_slots, &outputs); + predictors[tid]->Run(input_slots, &outputs); } print_time(batch_size, num_times, num_threads, tid, timer.toc() / num_times); @@ -348,8 +343,9 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false, } TEST(Analyzer, DituRNN) { - TestDituRNNPrediction(false, 1); - TestDituRNNPrediction(true, 1); + // default FLAGS_num_threads = 1 + TestDituRNNPrediction(false, FLAGS_num_threads); + TestDituRNNPrediction(true, FLAGS_num_threads); } TEST(Analyzer, DituRNN_multi_thread) { -- GitLab