提交 54afcb7e 编写于 作者: T tensor-tang

Add a test that compares the zero-copy output with the native predictor's result

test=develop
上级 13706013
...@@ -168,15 +168,13 @@ TEST(Analyzer_seq_pool1, compare) { ...@@ -168,15 +168,13 @@ TEST(Analyzer_seq_pool1, compare) {
reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all); reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
} }
// Check the fuse status void analysis_fuse_statis(bool use_zerocopy) {
TEST(Analyzer_seq_pool1, fuse_statis) {
AnalysisConfig cfg; AnalysisConfig cfg;
SetConfig(&cfg); SetConfig(&cfg);
cfg.SwitchUseFeedFetchOps(!use_zerocopy);
int num_ops; int num_ops;
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg); auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
auto fuse_statis = GetFuseStatis( auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
ASSERT_TRUE(fuse_statis.count("fc_fuse")); ASSERT_TRUE(fuse_statis.count("fc_fuse"));
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_EQ(fuse_statis.at("fc_fuse"), 10);
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse"));
...@@ -185,6 +183,9 @@ TEST(Analyzer_seq_pool1, fuse_statis) { ...@@ -185,6 +183,9 @@ TEST(Analyzer_seq_pool1, fuse_statis) {
EXPECT_EQ(num_ops, 195); EXPECT_EQ(num_ops, 195);
} }
// Check the fuse status
TEST(Analyzer_seq_pool1, fuse_statis) { analysis_fuse_statis(false); }
void PrepareZeroCopyInputs( void PrepareZeroCopyInputs(
const std::unique_ptr<PaddlePredictor> &predictor, const std::unique_ptr<PaddlePredictor> &predictor,
std::vector<std::unique_ptr<ZeroCopyTensor>> *inputs) { std::vector<std::unique_ptr<ZeroCopyTensor>> *inputs) {
...@@ -202,7 +203,8 @@ void PrepareZeroCopyInputs( ...@@ -202,7 +203,8 @@ void PrepareZeroCopyInputs(
} }
} }
std::unique_ptr<ZeroCopyTensor> zerocopy_profile(int repeat_times) { // return the output values
std::vector<float> zerocopy_profile(int repeat_times) {
AnalysisConfig config; AnalysisConfig config;
SetConfig(&config); SetConfig(&config);
config.SwitchUseFeedFetchOps(false); config.SwitchUseFeedFetchOps(false);
...@@ -225,23 +227,40 @@ std::unique_ptr<ZeroCopyTensor> zerocopy_profile(int repeat_times) { ...@@ -225,23 +227,40 @@ std::unique_ptr<ZeroCopyTensor> zerocopy_profile(int repeat_times) {
} }
PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times, PrintTime(FLAGS_batch_size, repeat_times, 1, 0, timer.toc() / repeat_times,
1); 1);
return output_tensor;
VLOG(3) << "ZeroCopy output: " << DescribeZeroCopyTensor(*output_tensor);
PaddlePlace place;
int output_size{0};
auto *pdata = output_tensor->data<float>(&place, &output_size);
std::vector<float> res(output_size);
for (int i = 0; i < output_size; ++i) {
res[i] = pdata[i];
}
return res;
} }
TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); } TEST(Analyzer_seq_pool1, zerocopy_profile) { zerocopy_profile(FLAGS_repeat); }
TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { TEST(Analyzer_seq_pool1, zerocopy_fuse_statis) { analysis_fuse_statis(true); }
TEST(Analyzer_seq_pool1, zerocopy_compare_native) {
AnalysisConfig config; AnalysisConfig config;
SetConfig(&config); SetConfig(&config);
config.SwitchUseFeedFetchOps(false); config.SwitchUseFeedFetchOps(true);
auto predictor = CreatePaddlePredictor<AnalysisConfig>(config); auto predictor = CreatePaddlePredictor<NativeConfig>(config.ToNativeConfig());
int num_ops; std::vector<PaddleTensor> native_outputs;
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops); std::vector<std::vector<PaddleTensor>> input_slots_all;
ASSERT_TRUE(fuse_statis.count("fc_fuse")); SetInput(&input_slots_all);
ASSERT_EQ(fuse_statis.at("fc_fuse"), 10); ASSERT_TRUE(predictor->Run(input_slots_all[0], &native_outputs));
ASSERT_TRUE(fuse_statis.count("seqpool_concat_fuse")); EXPECT_EQ(native_outputs.size(), 1UL);
EXPECT_EQ(fuse_statis.at("seqpool_concat_fuse"), 2);
ASSERT_EQ(num_ops, 195); auto zerocopy_output = zerocopy_profile(1);
EXPECT_EQ(zerocopy_output.size() * sizeof(float),
native_outputs.front().data.length());
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
for (size_t i = 0; i < zerocopy_output.size(); ++i) {
EXPECT_NEAR(zerocopy_output[i], native_data[i], 1e-3);
}
} }
} // namespace analysis } // namespace analysis
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册