api_impl_tester.cc 12.3 KB
Newer Older
X
Xin Pan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <glog/logging.h>
#include <gtest/gtest.h>

L
Luo Tao 已提交
18
#include <thread>  // NOLINT
T
tensor-tang 已提交
19

X
Xin Pan 已提交
20
#include "gflags/gflags.h"
21
#include "paddle/fluid/framework/convert_utils.h"
L
Luo Tao 已提交
22
#include "paddle/fluid/inference/api/api_impl.h"
T
tianshuo78520a 已提交
23
#include "test/cpp/inference/test_helper.h"
X
Xin Pan 已提交
24

J
JiabinYang 已提交
25
// Absolute tolerance used when comparing predictor outputs with reference
// outputs; builds compiled with clang get a looser bound than other compilers.
#if defined(__clang__)
#define ACC_DIFF 4e-3
#else
#define ACC_DIFF 1e-3
#endif

31 32
// Command-line flags locating the inference models used by the tests below.
DEFINE_string(
    word2vec_dirname, "", "Directory of the word2vec inference model.");
DEFINE_string(
    book_dirname, "", "Directory of the book inference model.");
X
Xin Pan 已提交
35 36 37

namespace paddle {

38
// Wraps the buffer of `t` in a PaddleTensor for feeding a PaddlePredictor.
//
// pt.data is pointed at t->data() via PaddleBuf::Reset. NOTE(review): Reset is
// presumably non-owning, so `t` must outlive the returned tensor — confirm.
// Only INT64, FP32 and INT32 tensors are supported; any other dtype throws an
// Unimplemented error. pt.lod is not set by this function.
PaddleTensor LodTensorToPaddleTensor(phi::DenseTensor* t) {
  PaddleTensor pt;

  // Convert the dtype once instead of re-converting it in every branch.
  const auto dtype = framework::TransToProtoVarType(t->dtype());
  if (dtype == framework::proto::VarType::INT64) {
    pt.data.Reset(t->data(), t->numel() * sizeof(int64_t));
    pt.dtype = PaddleDType::INT64;
  } else if (dtype == framework::proto::VarType::FP32) {
    pt.data.Reset(t->data(), t->numel() * sizeof(float));
    pt.dtype = PaddleDType::FLOAT32;
  } else if (dtype == framework::proto::VarType::INT32) {
    pt.data.Reset(t->data(), t->numel() * sizeof(int32_t));
    pt.dtype = PaddleDType::INT32;
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported tensor data type. Now only supports INT64, FP32, INT32."));
  }
  pt.shape = phi::vectorize<int>(t->dims());
  return pt;
}

Y
Yan Chunwei 已提交
61 62
// Returns the NativeConfig shared by the tests in this file.
//
// It loads the model from --word2vec_dirname, uses device 0, and sets
// fraction_of_gpu_memory to 0.15. Callers override use_gpu/use_xpu and, for the
// image-classification tests, model_dir.
NativeConfig GetConfig() {
  NativeConfig cfg;
  cfg.model_dir = FLAGS_word2vec_dirname;
  LOG(INFO) << "dirname  " << cfg.model_dir;
  cfg.device = 0;
  cfg.fraction_of_gpu_memory = 0.15;
  return cfg;
}
X
Xin Pan 已提交
69

70
// Runs the word2vec model on `place` through a NativePaddlePredictor and
// checks that:
//   1. every output value lies strictly inside (-1, 1), and
//   2. the output matches, within ACC_DIFF, a reference computed on CPU by
//      TestInference from the same four input words.
void MainWord2Vec(const paddle::PaddlePlace& place) {
  NativeConfig config = GetConfig();
  // The device flags must be set BEFORE the predictor is created: the
  // predictor is built from `config` at creation time. Setting them afterwards
  // (as this function used to) built the GPU/XPU variants without the device
  // flags. This order matches the other Main* helpers.
  config.use_gpu = paddle::gpu_place_used(place);
  config.use_xpu = paddle::xpu_place_used(place);
  auto predictor = CreatePaddlePredictor<NativeConfig>(config);

  phi::DenseTensor first_word, second_word, third_word, fourth_word;
  framework::LoD lod{{0, 1}};
  int64_t dict_size = 2073;  // The size of dictionary

  SetupLoDTensor(&first_word, lod, static_cast<int64_t>(0), dict_size - 1);
  SetupLoDTensor(&second_word, lod, static_cast<int64_t>(0), dict_size - 1);
  SetupLoDTensor(&third_word, lod, static_cast<int64_t>(0), dict_size - 1);
  SetupLoDTensor(&fourth_word, lod, static_cast<int64_t>(0), dict_size - 1);

  std::vector<PaddleTensor> paddle_tensor_feeds;
  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&first_word));
  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&second_word));
  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&third_word));
  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&fourth_word));

  std::vector<PaddleTensor> outputs;
  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
  ASSERT_EQ(outputs.size(), 1UL);
  size_t len = outputs[0].data.length();
  float* data = static_cast<float*>(outputs[0].data.data());
  for (size_t j = 0; j < len / sizeof(float); ++j) {
    ASSERT_LT(data[j], 1.0);
    ASSERT_GT(data[j], -1.0);
  }

  // Reference result computed on CPU from the same inputs.
  std::vector<phi::DenseTensor*> cpu_feeds;
  cpu_feeds.push_back(&first_word);
  cpu_feeds.push_back(&second_word);
  cpu_feeds.push_back(&third_word);
  cpu_feeds.push_back(&fourth_word);

  framework::FetchType output1;
  std::vector<paddle::framework::FetchType*> cpu_fetchs1;
  cpu_fetchs1.push_back(&output1);

  TestInference<platform::CPUPlace>(config.model_dir, cpu_feeds, cpu_fetchs1);

  auto output1_tensor = PADDLE_GET(phi::DenseTensor, output1);
  // Guard the element-wise comparison: without this check a size mismatch
  // would read past the end of `data`.
  ASSERT_EQ(static_cast<size_t>(output1_tensor.numel()), len / sizeof(float));
  float* lod_data = output1_tensor.data<float>();
  for (int64_t i = 0; i < output1_tensor.numel(); ++i) {
    EXPECT_NEAR(lod_data[i], data[i], ACC_DIFF);
  }
}

121
// Runs the ResNet image-classification model (from --book_dirname) on `place`
// with a random batch of `batch_size` images whose pixels lie in [0, 1]. The
// predictor output is compared element-wise, within ACC_DIFF, against a
// reference computed on CPU by TestInference.
void MainImageClassification(const paddle::PaddlePlace& place) {
  int batch_size = 2;
  bool repeat = false;
  NativeConfig config = GetConfig();
  config.use_gpu = paddle::gpu_place_used(place);
  config.use_xpu = paddle::xpu_place_used(place);
  config.model_dir =
      FLAGS_book_dirname + "/image_classification_resnet.inference.model";

  const bool is_combined = false;
  std::vector<std::vector<int64_t>> feed_target_shapes =
      GetFeedTargetShapes(config.model_dir, is_combined);

  phi::DenseTensor input;
  // Use normalized image pixels as input data,
  // which should be in the range [0.0, 1.0].
  feed_target_shapes[0][0] = batch_size;
  framework::DDim input_dims = phi::make_ddim(feed_target_shapes[0]);
  SetupTensor<float>(
      &input, input_dims, static_cast<float>(0), static_cast<float>(1));
  std::vector<phi::DenseTensor*> cpu_feeds;
  cpu_feeds.push_back(&input);

  framework::FetchType output1;
  std::vector<framework::FetchType*> cpu_fetchs1;
  cpu_fetchs1.push_back(&output1);

  TestInference<platform::CPUPlace, false, true>(
      config.model_dir, cpu_feeds, cpu_fetchs1, repeat, is_combined);

  auto predictor = CreatePaddlePredictor(config);
  std::vector<PaddleTensor> paddle_tensor_feeds;
  paddle_tensor_feeds.push_back(LodTensorToPaddleTensor(&input));

  std::vector<PaddleTensor> outputs;
  ASSERT_TRUE(predictor->Run(paddle_tensor_feeds, &outputs));
  ASSERT_EQ(outputs.size(), 1UL);
  size_t len = outputs[0].data.length();
  float* data = static_cast<float*>(outputs[0].data.data());

  auto output1_tensor = PADDLE_GET(phi::DenseTensor, output1);
  // The loop below indexes the reference buffer with the predictor's output
  // length; make sure both have the same number of elements first.
  ASSERT_EQ(static_cast<size_t>(output1_tensor.numel()), len / sizeof(float));
  float* lod_data = output1_tensor.data<float>();
  for (size_t j = 0; j < len / sizeof(float); ++j) {
    EXPECT_NEAR(lod_data[j], data[j], ACC_DIFF);
  }
}

166
// Multi-threaded word2vec check.
//
// Builds `num_jobs` independent 4-word inputs and computes a CPU reference for
// each with TestInference. It then runs every job in its own std::thread with
// its own predictor, checking that each output lies in (-1, 1) and matches its
// reference within 2e-3.
void MainThreadsWord2Vec(const paddle::PaddlePlace& place) {
  NativeConfig config = GetConfig();
  config.use_gpu = paddle::gpu_place_used(place);
  config.use_xpu = paddle::xpu_place_used(place);
  // NOTE(review): main_predictor is not used below (each thread creates its own
  // predictor). It is kept because constructing it may have side effects the
  // test relies on — confirm before removing.
  auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);

  // prepare inputs data and reference results
  constexpr int num_jobs = 3;
  constexpr size_t words_per_job = 4;
  framework::LoD lod{{0, 1}};
  const int64_t dict_size = 2073;  // The size of dictionary
  std::vector<std::vector<phi::DenseTensor>> jobs(num_jobs);
  std::vector<std::vector<PaddleTensor>> paddle_tensor_feeds(num_jobs);
  std::vector<framework::FetchType> refs(num_jobs);
  for (size_t i = 0; i < jobs.size(); ++i) {
    // Size the vector up front: the PaddleTensors below point into these
    // tensors, so the vector must not reallocate afterwards.
    jobs[i].resize(words_per_job);
    for (size_t j = 0; j < words_per_job; ++j) {
      SetupLoDTensor(&jobs[i][j], lod, static_cast<int64_t>(0), dict_size - 1);
      paddle_tensor_feeds[i].push_back(LodTensorToPaddleTensor(&jobs[i][j]));
    }

    // get reference result of each job
    std::vector<phi::DenseTensor*> ref_feeds;
    std::vector<paddle::framework::FetchType*> ref_fetches(1, &refs[i]);
    for (auto& word : jobs[i]) {
      ref_feeds.push_back(&word);
    }
    TestInference<platform::CPUPlace>(config.model_dir, ref_feeds, ref_fetches);
  }

  // create threads and each thread run 1 job
  std::vector<std::thread> threads;
  threads.reserve(num_jobs);
  for (int tid = 0; tid < num_jobs; ++tid) {
    threads.emplace_back([&, tid]() {
      auto predictor = CreatePaddlePredictor(config);
      auto& local_inputs = paddle_tensor_feeds[tid];
      std::vector<PaddleTensor> local_outputs;
      ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));

      // check outputs range
      ASSERT_EQ(local_outputs.size(), 1UL);
      const size_t len = local_outputs[0].data.length();
      float* data = static_cast<float*>(local_outputs[0].data.data());
      for (size_t j = 0; j < len / sizeof(float); ++j) {
        ASSERT_LT(data[j], 1.0);
        ASSERT_GT(data[j], -1.0);
      }

      // check outputs correctness
      auto ref_tensor = PADDLE_GET(phi::DenseTensor, refs[tid]);
      float* ref_data = ref_tensor.data<float>();
      // ASSERT (not EXPECT): on a size mismatch the loop below would read past
      // the end of the shorter buffer.
      ASSERT_EQ(ref_tensor.numel(), static_cast<int64_t>(len / sizeof(float)));
      for (int64_t i = 0; i < ref_tensor.numel(); ++i) {
        EXPECT_NEAR(ref_data[i], data[i], 2e-3);
      }
    });
  }
  for (auto& worker : threads) {
    worker.join();
  }
}

228
// Multi-threaded image classification check.
//
// Builds `num_jobs` random single-image batches from the ResNet model's feed
// shape and computes a CPU reference for each with TestInference. It then runs
// every job in its own std::thread with its own predictor and compares each
// output with its reference within ACC_DIFF.
void MainThreadsImageClassification(const paddle::PaddlePlace& place) {
  constexpr int num_jobs = 4;  // each job run 1 batch
  constexpr int batch_size = 1;
  NativeConfig config = GetConfig();
  config.use_gpu = paddle::gpu_place_used(place);
  config.use_xpu = paddle::xpu_place_used(place);
  config.model_dir =
      FLAGS_book_dirname + "/image_classification_resnet.inference.model";

  // NOTE(review): main_predictor is not used below (each thread creates its own
  // predictor). It is kept because constructing it may have side effects the
  // test relies on — confirm before removing.
  auto main_predictor = CreatePaddlePredictor<NativeConfig>(config);
  std::vector<phi::DenseTensor> jobs(num_jobs);
  std::vector<std::vector<PaddleTensor>> paddle_tensor_feeds(num_jobs);
  std::vector<framework::FetchType> refs(num_jobs);

  // Every job uses the same input shape, so read the model's feed shapes once
  // instead of reloading them for each job.
  std::vector<std::vector<int64_t>> feed_target_shapes =
      GetFeedTargetShapes(config.model_dir, /*is_combined*/ false);
  feed_target_shapes[0][0] = batch_size;
  framework::DDim input_dims = phi::make_ddim(feed_target_shapes[0]);

  for (size_t i = 0; i < jobs.size(); ++i) {
    // prepare inputs
    SetupTensor<float>(&jobs[i], input_dims, 0.f, 1.f);
    paddle_tensor_feeds[i].push_back(LodTensorToPaddleTensor(&jobs[i]));

    // get reference result of each job
    std::vector<phi::DenseTensor*> ref_feeds(1, &jobs[i]);
    std::vector<framework::FetchType*> ref_fetches(1, &refs[i]);
    TestInference<platform::CPUPlace>(config.model_dir, ref_feeds, ref_fetches);
  }

  // create threads and each thread run 1 job
  std::vector<std::thread> threads;
  threads.reserve(num_jobs);
  for (int tid = 0; tid < num_jobs; ++tid) {
    threads.emplace_back([&, tid]() {
      auto predictor = CreatePaddlePredictor(config);
      auto& local_inputs = paddle_tensor_feeds[tid];
      std::vector<PaddleTensor> local_outputs;
      ASSERT_TRUE(predictor->Run(local_inputs, &local_outputs));

      // check outputs correctness
      ASSERT_EQ(local_outputs.size(), 1UL);
      const size_t len = local_outputs[0].data.length();
      float* data = static_cast<float*>(local_outputs[0].data.data());
      auto ref_tensor = PADDLE_GET(phi::DenseTensor, refs[tid]);
      float* ref_data = ref_tensor.data<float>();
      // ASSERT (not EXPECT): on a size mismatch the loop below would read past
      // the end of the shorter buffer.
      ASSERT_EQ(static_cast<size_t>(ref_tensor.numel()), len / sizeof(float));
      for (int64_t i = 0; i < ref_tensor.numel(); ++i) {
        EXPECT_NEAR(ref_data[i], data[i], ACC_DIFF);
      }
    });
  }
  for (auto& worker : threads) {
    worker.join();
  }
}

282 283 284
// CPU tests: each one runs a scenario helper defined in this file on the CPU
// place, in either single-threaded or multi-threaded form.
TEST(inference_api_native, word2vec_cpu) {
  MainWord2Vec(PaddlePlace::kCPU);
}

TEST(inference_api_native, word2vec_cpu_threads) {
  MainThreadsWord2Vec(PaddlePlace::kCPU);
}

TEST(inference_api_native, image_classification_cpu) {
  MainImageClassification(PaddlePlace::kCPU);
}

TEST(inference_api_native, image_classification_cpu_threads) {
  MainThreadsImageClassification(PaddlePlace::kCPU);
}

295 296 297 298 299 300 301 302 303
#if defined(PADDLE_WITH_XPU)
// XPU variants of the single-threaded scenarios.
TEST(inference_api_native, word2vec_xpu) {
  MainWord2Vec(PaddlePlace::kXPU);
}

TEST(inference_api_native, image_classification_xpu) {
  MainImageClassification(PaddlePlace::kXPU);
}
#endif  // PADDLE_WITH_XPU

304
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// GPU variants of the scenarios.
//
// The multi-threaded GPU tests are turned off temporarily because their
// results are unstable. They use GoogleTest's DISABLED_ prefix instead of being
// commented out, so they keep compiling and are reported as disabled. Run them
// with --gtest_also_run_disabled_tests.
TEST(inference_api_native, word2vec_gpu) {
  MainWord2Vec(paddle::PaddlePlace::kGPU);
}

TEST(inference_api_native, DISABLED_word2vec_gpu_threads) {
  MainThreadsWord2Vec(paddle::PaddlePlace::kGPU);
}

TEST(inference_api_native, image_classification_gpu) {
  MainImageClassification(paddle::PaddlePlace::kGPU);
}

TEST(inference_api_native, DISABLED_image_classification_gpu_threads) {
  MainThreadsImageClassification(paddle::PaddlePlace::kGPU);
}
#endif

321 322 323 324 325 326 327 328 329 330 331 332
#ifdef PADDLE_WITH_MKLDNN
namespace {
// Sets FLAGS_use_mkldnn for the lifetime of the guard and restores the
// previous value on destruction. Without it, enabling oneDNN in one test would
// leak into every test that runs afterwards, including under
// --gtest_shuffle.
class ScopedUseMkldnn {
 public:
  explicit ScopedUseMkldnn(bool enable) : saved_(FLAGS_use_mkldnn) {
    FLAGS_use_mkldnn = enable;
  }
  ~ScopedUseMkldnn() { FLAGS_use_mkldnn = saved_; }
  ScopedUseMkldnn(const ScopedUseMkldnn&) = delete;
  ScopedUseMkldnn& operator=(const ScopedUseMkldnn&) = delete;

 private:
  bool saved_;
};
}  // namespace

TEST(inference_api_native, image_classification_cpu_onednn) {
  ScopedUseMkldnn use_mkldnn(true);
  MainImageClassification(paddle::PaddlePlace::kCPU);
}

TEST(inference_api_native, word2vec_cpu_onednn) {
  ScopedUseMkldnn use_mkldnn(true);
  MainWord2Vec(paddle::PaddlePlace::kCPU);
}
#endif

333
// After DeletePass, the deleted pass must no longer appear in the pass list
// reported by the builder.
TEST(PassBuilder, Delete) {
  AnalysisConfig config;
  config.DisableGpu();
  auto* builder = config.pass_builder();
  builder->DeletePass("attention_lstm_fuse_pass");
  for (const auto& pass : builder->AllPasses()) {
    ASSERT_NE(pass, "attention_lstm_fuse_pass");
  }
}

X
Xin Pan 已提交
342
}  // namespace paddle