diff --git a/paddle/fluid/inference/tests/infer_ut/test_resnet50.cc b/paddle/fluid/inference/tests/infer_ut/test_resnet50.cc index a090f1a90189b868026f86a1ee583321ae169d9e..f497acc4b166ca0c4fd867f2670bb17dfc24bae5 100644 --- a/paddle/fluid/inference/tests/infer_ut/test_resnet50.cc +++ b/paddle/fluid/inference/tests/infer_ut/test_resnet50.cc @@ -127,6 +127,49 @@ TEST(test_resnet50, serial_diff_batch_trt_fp32) { std::cout << "finish test" << std::endl; } +TEST(test_resnet50, multi_thread4_trt_fp32_bz2) { + int thread_num = 4; + // init input data + std::map my_input_data_map; + my_input_data_map["inputs"] = PrepareInput(2); + // init output data + std::map infer_output_data, + truth_output_data; + // prepare groudtruth config + paddle_infer::Config config, config_no_ir; + config_no_ir.SetModel(FLAGS_modeldir + "/inference.pdmodel", + FLAGS_modeldir + "/inference.pdiparams"); + config_no_ir.SwitchIrOptim(false); + // prepare inference config + config.SetModel(FLAGS_modeldir + "/inference.pdmodel", + FLAGS_modeldir + "/inference.pdiparams"); + config.EnableUseGpu(100, 0); + config.EnableTensorRtEngine( + 1 << 20, 2, 3, paddle_infer::PrecisionType::kFloat32, false, false); + // get groudtruth by disbale ir + paddle_infer::services::PredictorPool pred_pool_no_ir(config_no_ir, 1); + SingleThreadPrediction(pred_pool_no_ir.Retrive(0), &my_input_data_map, + &truth_output_data, 1); + + // get infer results from multi threads + std::vector threads; + services::PredictorPool pred_pool(config, thread_num); + for (int i = 0; i < thread_num; ++i) { + threads.emplace_back(paddle::test::SingleThreadPrediction, + pred_pool.Retrive(i), &my_input_data_map, + &infer_output_data, 2); + } + + // thread join & check outputs + for (int i = 0; i < thread_num; ++i) { + LOG(INFO) << "join tid : " << i; + threads[i].join(); + CompareRecord(&truth_output_data, &infer_output_data); + } + + std::cout << "finish multi-thread test" << std::endl; +} + } // namespace paddle_infer int main(int argc, char** argv) { diff --git a/paddle/fluid/inference/tests/infer_ut/test_suite.h b/paddle/fluid/inference/tests/infer_ut/test_suite.h index c3c1b36a6e07a3bd9fc38c04047849618c37202b..0e116b01847bfb9c89d52ab49c2b2a7334de9a93 100644 --- a/paddle/fluid/inference/tests/infer_ut/test_suite.h +++ b/paddle/fluid/inference/tests/infer_ut/test_suite.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "gflags/gflags.h" @@ -117,5 +118,5 @@ void CompareRecord(std::map *truth_output_data, } } -} // namespace demo +} // namespace test } // namespace paddle