Commit 1a373fbb authored by luotao1

add result check for multi-thread UT

Parent 2dc23ffa
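
This commit factors the output check that previously ran only in the single-thread path out into a CompareResult() helper and calls it from both the single-thread and multi-thread branches, so every thread now verifies its outputs against base_outputs; print_time() is also renamed to PrintTime(). A minimal, Paddle-free sketch of the same element-wise check (the Tensor struct here is a hypothetical stand-in for PaddleTensor; only gtest and the standard library are assumed):

```cpp
#include <gtest/gtest.h>

#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for PaddleTensor: a shape plus a flat float buffer.
struct Tensor {
  std::vector<int> shape;
  std::vector<float> data;
};

// Element-wise check of outputs against reference outputs, mirroring the
// CompareResult() helper added by this commit.
void CompareResult(const std::vector<Tensor> &outputs,
                   const std::vector<Tensor> &base_outputs) {
  ASSERT_GT(outputs.size(), 0u);
  ASSERT_EQ(outputs.size(), base_outputs.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    // The element count is the product of the shape's extents.
    size_t size =
        std::accumulate(outputs[i].shape.begin(), outputs[i].shape.end(),
                        size_t{1}, std::multiplies<size_t>());
    ASSERT_EQ(size, base_outputs[i].data.size());
    ASSERT_GT(size, 0u);
    for (size_t j = 0; j < size; j++) {
      EXPECT_NEAR(outputs[i].data[j], base_outputs[i].data[j], 1e-3);
    }
  }
}
```

One subtlety carried over from the diff: std::accumulate's accumulation type follows its initial value, so the original's plain int 1 multiplies the extents in int before converting to size_t; the sketch passes size_t{1} to keep the product in size_t throughout.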
@@ -234,6 +234,26 @@ const float ditu_rnn_target_data[] = {
     10.7286, 12.0595, 10.6672, 0, 0, 0, 0, 0,
     93.5771, 3.84641, 0, 0, 0, 0, 0, 0,
     169.426, 0, 0, 0, 0, 0, 0, 0};
+void CompareResult(const std::vector<PaddleTensor> &outputs,
+                   const std::vector<PaddleTensor> &base_outputs) {
+  PADDLE_ENFORCE_GT(outputs.size(), 0);
+  PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
+  for (size_t i = 0; i < outputs.size(); i++) {
+    auto &out = outputs[i];
+    auto &base_out = base_outputs[i];
+    size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
+                                  [](int a, int b) { return a * b; });
+    size_t size1 = std::accumulate(base_out.shape.begin(), base_out.shape.end(),
+                                   1, [](int a, int b) { return a * b; });
+    PADDLE_ENFORCE_EQ(size, size1);
+    PADDLE_ENFORCE_GT(size, 0);
+    float *data = static_cast<float *>(out.data.data());
+    float *base_data = static_cast<float *>(base_out.data.data());
+    for (size_t i = 0; i < size; i++) {
+      EXPECT_NEAR(data[i], base_data[i], 1e-3);
+    }
+  }
+}
 // Test with a really complicate model.
 void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
                            int num_threads = FLAGS_num_threads) {
@@ -266,7 +286,8 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
     for (int i = 0; i < num_times; i++) {
       predictor->Run(input_slots, &outputs);
     }
-    print_time(batch_size, num_times, 1, 0, timer.toc() / num_times);
+    PrintTime(batch_size, num_times, 1, 0, timer.toc() / num_times);
+    CompareResult(outputs, base_outputs);
   } else {
     std::vector<std::thread> threads;
     std::vector<std::unique_ptr<PaddlePredictor>> predictors;
@@ -279,13 +300,19 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
     }
     for (int tid = 0; tid < num_threads; ++tid) {
      threads.emplace_back([&, tid]() {
+        // Each thread should have local input_slots and outputs.
+        std::vector<PaddleTensor> input_slots;
+        DataRecord data(FLAGS_infer_ditu_rnn_data, batch_size);
+        PrepareInputs(&input_slots, &data, batch_size);
+        std::vector<PaddleTensor> outputs;
         Timer timer;
         timer.tic();
         for (int i = 0; i < num_times; i++) {
           predictors[tid]->Run(input_slots, &outputs);
         }
-        print_time(batch_size, num_times, num_threads, tid,
-                   timer.toc() / num_times);
+        PrintTime(batch_size, num_times, num_threads, tid,
+                  timer.toc() / num_times);
+        CompareResult(outputs, base_outputs);
       });
     }
     for (int i = 0; i < num_threads; ++i) {
@@ -294,27 +321,6 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
   LOG(INFO) << "=====================================";
-  if (num_threads == 1) {
-    PADDLE_ENFORCE_GT(outputs.size(), 0);
-    PADDLE_ENFORCE_EQ(outputs.size(), base_outputs.size());
-    for (size_t i = 0; i < outputs.size(); i++) {
-      auto &out = outputs[i];
-      auto &base_out = base_outputs[i];
-      size_t size = std::accumulate(out.shape.begin(), out.shape.end(), 1,
-                                    [](int a, int b) { return a * b; });
-      size_t size1 =
-          std::accumulate(base_out.shape.begin(), base_out.shape.end(), 1,
-                          [](int a, int b) { return a * b; });
-      PADDLE_ENFORCE_EQ(size, size1);
-      PADDLE_ENFORCE_GT(size, 0);
-      float *data = static_cast<float *>(out.data.data());
-      float *base_data = static_cast<float *>(base_out.data.data());
-      for (size_t i = 0; i < size; i++) {
-        EXPECT_NEAR(data[i], base_data[i], 1e-3);
-      }
-    }
-  }
   if (use_analysis_and_activate_ir) {
     AnalysisPredictor *analysis_predictor =
         dynamic_cast<AnalysisPredictor *>(predictor.get());
@@ -342,13 +348,13 @@ void TestDituRNNPrediction(bool use_analysis_and_activate_ir = false,
   }
 }
-TEST(Analyzer, DituRNN) {
-  // default FLAGS_num_threads = 1
-  TestDituRNNPrediction(false, FLAGS_num_threads);
-  TestDituRNNPrediction(true, FLAGS_num_threads);
-}
+// basic unit-test of DituRNN, easy for profiling independently.
+TEST(Analyzer, DituRNN) { TestDituRNNPrediction(false, FLAGS_num_threads); }
+// advance unit-test of DituRNN, test use_analysis_and_activate_ir and
+// multi-threads.
+TEST(Analyzer, DituRNN_multi_thread) {
+  TestDituRNNPrediction(true, 1);
+  TestDituRNNPrediction(false, 4);
+  TestDituRNNPrediction(true, 4);
+}
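
The substance of the multi-thread fix is the added comment "Each thread should have local input_slots and outputs.": the worker lambda captures by reference ([&, tid]), so judging by the diff the threads previously shared a single outputs vector from the enclosing scope and raced on it. Each thread now builds its own DataRecord, input_slots, and outputs before running and checking results. A minimal standalone sketch of that pattern (Predictor is a hypothetical stand-in for PaddlePredictor):

```cpp
#include <thread>
#include <vector>

// Hypothetical stand-in for PaddlePredictor.
struct Predictor {
  void Run(const std::vector<float> &inputs, std::vector<float> *outputs) {
    *outputs = inputs;  // dummy "inference"
  }
};

int main() {
  const int num_threads = 4;
  const int num_times = 10;
  std::vector<Predictor> predictors(num_threads);  // one predictor per thread
  std::vector<std::thread> threads;
  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid]() {
      // Everything a thread mutates lives inside the lambda body; through
      // the [&] capture each thread touches only its own predictors[tid].
      std::vector<float> input_slots = {1.0f, 2.0f, 3.0f};
      std::vector<float> outputs;
      for (int i = 0; i < num_times; i++) {
        predictors[tid].Run(input_slots, &outputs);
      }
    });
  }
  for (auto &t : threads) t.join();
  return 0;
}
```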
......
@@ -122,8 +122,8 @@ std::string DescribeTensor(const PaddleTensor &tensor) {
   return os.str();
 }
-void print_time(int batch_size, int repeat, int num_threads, int tid,
-                double latency) {
+void PrintTime(int batch_size, int repeat, int num_threads, int tid,
+               double latency) {
   LOG(INFO) << "batch_size: " << batch_size << ", repeat: " << repeat
             << ", threads: " << num_threads << ", thread id: " << tid
             << ", latency: " << latency << "ms";
......
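
For reference, PrintTime logs the average per-batch latency, which the tests compute as timer.toc() / num_times. Paddle's Timer itself is not part of this diff; a minimal tic/toc sketch with std::chrono, assuming toc() returns elapsed milliseconds as a double (matching the "latency: ...ms" log line above):

```cpp
#include <chrono>

// Minimal tic/toc timer sketch; toc() is assumed to return elapsed
// milliseconds since the last tic().
class Timer {
 public:
  void tic() { start_ = std::chrono::high_resolution_clock::now(); }
  double toc() const {
    auto end = std::chrono::high_resolution_clock::now();
    return std::chrono::duration<double, std::milli>(end - start_).count();
  }

 private:
  std::chrono::high_resolution_clock::time_point start_;
};
```

With such a timer, timer.toc() / num_times after the run loop is the mean latency of a single Run() call over num_times repetitions.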