// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <memory>
#include <string>
#include <thread>  // NOLINT
#include <unordered_map>
#include <vector>
#ifdef WITH_GPERFTOOLS
#include <gperftools/profiler.h>
#endif

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/analyzer.h"
#include "paddle/fluid/inference/analysis/ut_helper.h"
#include "paddle/fluid/inference/api/analysis_predictor.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#include "paddle/fluid/inference/tests/api/config_printer.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/inference/utils/benchmark.h"
#include "paddle/fluid/platform/profiler.h"

DEFINE_string(model_name, "", "model name");
DEFINE_string(infer_model, "", "model path");
DEFINE_string(fp32_model, "", "FP32 model path");
DEFINE_string(int8_model, "", "INT8 model path");
DEFINE_string(infer_data, "", "data file");
DEFINE_string(refer_result, "", "reference result for comparison");
DEFINE_int32(batch_size, 1, "batch size");
DEFINE_bool(ernie_large, false, "Test ernie large");
DEFINE_bool(with_accuracy_layer, true,
            "Calculate the accuracy while the label is in the input");
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
DEFINE_bool(enable_int8, true, "Enable INT8 type prediction");
DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
// setting iterations to 0 means processing the whole dataset
DEFINE_int32(iterations, 0, "number of batches to process");
DEFINE_int32(repeat, 1, "Repeat the inference program this many times.");
DEFINE_bool(test_all_data, false, "Test the whole dataset in the data file.");
DEFINE_int32(num_threads, 1, "Run the inference program in multiple threads.");
DEFINE_bool(use_analysis, true,
            "Run the inference program in analysis mode.");
DEFINE_bool(record_benchmark, false,
            "Record benchmark after profiling the model");
DEFINE_double(accuracy, 1e-3, "Result Accuracy.");
DEFINE_double(quantized_accuracy, 1e-2, "Result Quantized Accuracy.");
DEFINE_bool(zero_copy, false, "Use ZeroCopy to speed up Feed/Fetch.");
DEFINE_bool(warmup, false,
            "Use warmup to calculate elapsed_time more accurately. "
            "To reduce CI time, it is set to false by default.");

DECLARE_bool(profile);
DECLARE_int32(paddle_num_threads);

namespace paddle {
namespace inference {

using paddle::framework::proto::VarType;

template <typename T>
constexpr paddle::PaddleDType GetPaddleDType();

template <>
constexpr paddle::PaddleDType GetPaddleDType<int64_t>() {
  return paddle::PaddleDType::INT64;
}

template <>
constexpr paddle::PaddleDType GetPaddleDType<float>() {
  return paddle::PaddleDType::FLOAT32;
}
void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
  const auto *analysis_config =
      reinterpret_cast<const AnalysisConfig *>(config);
  if (use_analysis) {
    LOG(INFO) << *analysis_config;
    return;
  }
  LOG(INFO) << analysis_config->ToNativeConfig();
}

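// Check that a value matches its reference within FLAGS_accuracy: relative
// error when |data_ref| > 1, absolute error otherwise.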
void CheckError(float data_ref, float data) {
  if (std::abs(data_ref) > 1) {
    CHECK_LE(std::abs((data_ref - data) / data_ref), FLAGS_accuracy);
  } else {
    CHECK_LE(std::abs(data_ref - data), FLAGS_accuracy);
  }
}

// Compare result between two PaddleTensor
void CompareResult(const std::vector<PaddleTensor> &outputs,
                   const std::vector<PaddleTensor> &ref_outputs) {
  EXPECT_GT(outputs.size(), 0UL);
  EXPECT_EQ(outputs.size(), ref_outputs.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    auto &out = outputs[i];
    auto &ref_out = ref_outputs[i];
    size_t size = VecReduceToInt(out.shape);
    size_t ref_size = VecReduceToInt(ref_out.shape);
    EXPECT_GT(size, 0UL);
    EXPECT_EQ(size, ref_size);
    EXPECT_EQ(out.dtype, ref_out.dtype);
    switch (out.dtype) {
      case PaddleDType::INT64: {
        int64_t *pdata = static_cast<int64_t *>(out.data.data());
        int64_t *pdata_ref = static_cast<int64_t *>(ref_out.data.data());
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::FLOAT32: {
        float *pdata = static_cast<float *>(out.data.data());
        float *pdata_ref = static_cast<float *>(ref_out.data.data());
        for (size_t j = 0; j < size; ++j) {
          CheckError(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::INT32: {
        int32_t *pdata = static_cast<int32_t *>(out.data.data());
        int32_t *pdata_ref = static_cast<int32_t *>(ref_out.data.data());
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::UINT8: {
        uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
        uint8_t *pdata_ref = static_cast<uint8_t *>(ref_out.data.data());
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
    }
  }
}

// Compare result between a PaddleTensor and a ZeroCopyTensor
void CompareResult(const std::vector<PaddleTensor> &outputs,
                   const std::vector<ZeroCopyTensor> &ref_outputs) {
  EXPECT_GT(outputs.size(), 0UL);
  EXPECT_EQ(outputs.size(), ref_outputs.size());
  for (size_t i = 0; i < outputs.size(); i++) {
    auto &out = outputs[i];
    auto &ref_out = ref_outputs[i];
    size_t size = VecReduceToInt(out.shape);
    EXPECT_GT(size, 0UL);
    int ref_size = 0;  // this is the number of elements, not memory size
    PaddlePlace place;
    switch (out.dtype) {
      case PaddleDType::INT64: {
        int64_t *pdata = static_cast<int64_t *>(out.data.data());
        int64_t *pdata_ref = ref_out.data<int64_t>(&place, &ref_size);
        EXPECT_EQ(size, static_cast<size_t>(ref_size));
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::FLOAT32: {
        float *pdata = static_cast<float *>(out.data.data());
        float *pdata_ref = ref_out.data<float>(&place, &ref_size);
        EXPECT_EQ(size, static_cast<size_t>(ref_size));
        for (size_t j = 0; j < size; ++j) {
          CheckError(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::INT32: {
        int32_t *pdata = static_cast<int32_t *>(out.data.data());
        int32_t *pdata_ref = ref_out.data<int32_t>(&place, &ref_size);
        EXPECT_EQ(size, static_cast<size_t>(ref_size));
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
      case PaddleDType::UINT8: {
        uint8_t *pdata = static_cast<uint8_t *>(out.data.data());
        uint8_t *pdata_ref = ref_out.data<uint8_t>(&place, &ref_size);
        EXPECT_EQ(size, static_cast<size_t>(ref_size));
        for (size_t j = 0; j < size; ++j) {
          EXPECT_EQ(pdata_ref[j], pdata[j]);
        }
        break;
      }
    }
  }
}

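// Build an AnalysisConfig-based predictor, or fall back to the equivalent
// NativeConfig predictor when use_analysis is false.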
std::unique_ptr<PaddlePredictor> CreateTestPredictor(
    const PaddlePredictor::Config *config, bool use_analysis = true) {
  const auto *analysis_config =
      reinterpret_cast<const AnalysisConfig *>(config);
  if (use_analysis) {
    return CreatePaddlePredictor<AnalysisConfig>(*analysis_config);
  }
  auto native_config = analysis_config->ToNativeConfig();
  return CreatePaddlePredictor<NativeConfig>(native_config);
}

size_t GetSize(const PaddleTensor &out) { return VecReduceToInt(out.shape); }

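// Collect the fusion statistics recorded by the analysis passes and report
// the number of operators in the optimized graph through *num_ops.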
std::unordered_map<std::string, int> GetFuseStatis(PaddlePredictor *predictor,
                                                   int *num_ops) {
  std::unordered_map<std::string, int> res;
  auto *analysis_predictor = static_cast<AnalysisPredictor *>(predictor);
  auto *fusion_status =
      analysis_predictor->analysis_argument().fusion_statis_ptr();
  if (!fusion_status) {
    return res;
  }
  for (auto &item : *fusion_status) {
    LOG(INFO) << "fused " << item.first << " " << item.second;
  }
  int num = 0;
  for (auto &node :
       analysis_predictor->analysis_argument().main_graph().Nodes()) {
    if (node->IsOp()) {
      ++num;
    }
  }
  *num_ops = num;
  return *fusion_status;
}

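// Fill *inputs with one batch of deterministic fake image data, shaped
// according to the feed targets of the model under dirname.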
void SetFakeImageInput(std::vector<std::vector<PaddleTensor>> *inputs,
                       const std::string &dirname, bool is_combined = true,
                       std::string model_filename = "model",
                       std::string params_filename = "params",
                       const std::vector<std::string> *feed_names = nullptr,
                       const int continuous_input_index = 0) {
  // Set fake_image_data
  PADDLE_ENFORCE_EQ(FLAGS_test_all_data, 0,
                    platform::errors::InvalidArgument(
                        "In SetFakeImageInput, expected test_all_data = false, "
                        "but now test_all_data = %d",
                        FLAGS_test_all_data));
  std::vector<std::vector<int64_t>> feed_target_shapes = GetFeedTargetShapes(
      dirname, is_combined, model_filename, params_filename);
  std::ostringstream os;
  for (size_t i = 0; i < feed_target_shapes.size(); ++i) {
    os << "feed target " << i << ": {" << feed_target_shapes[i][0];
    for (size_t j = 1; j < feed_target_shapes[i].size(); ++j) {
      os << ", " << feed_target_shapes[i][j];
    }
    os << "}\n";
  }
  LOG(INFO) << os.str();
  if (feed_names) {
    PADDLE_ENFORCE_EQ(
        feed_names->size(), feed_target_shapes.size(),
        platform::errors::InvalidArgument(
            "The size of feed_names and the size of "
            "feed_target_shapes must be equal, but now feed_names "
            "size is %d and feed_target_shapes size is %d",
            feed_names->size(), feed_target_shapes.size()));
  }
  std::vector<PaddleTensor> input_slots(feed_target_shapes.size());
  for (size_t i = 0; i < feed_target_shapes.size(); ++i) {
    const auto &feed_shape = feed_target_shapes[i];
    auto &input = input_slots[i];
    std::vector<int> shape({FLAGS_batch_size});
    for (size_t s = 1; s < feed_shape.size(); ++s) {
      shape.push_back(static_cast<int>(feed_shape[s]));
    }
    if (feed_names) {
      input.name = (*feed_names)[i];
    }
    input.shape = shape;
    input.dtype = PaddleDType::FLOAT32;
    size_t len = std::accumulate(shape.begin(), shape.end(), size_t{1},
                                 [](int a, int b) { return a * b; });
    input.data.Resize(len * sizeof(float));
    input.lod.assign({{0, static_cast<size_t>(FLAGS_batch_size)}});
    float *input_data = static_cast<float *>(input.data.data());
    // Fill the input data; to make profiling easier, do not use random data.
    for (size_t j = 0; j < len; ++j) {
      *(input_data + j) =
          static_cast<float>((j + continuous_input_index) % len) / len;
    }
  }
  (*inputs).emplace_back(input_slots);
}

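// Copy the slice [batch_iter, batch_end) of in into *out and build the
// matching level-of-detail (LoD) offsets.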
void GetInputPerBatch(const std::vector<std::vector<int64_t>> &in,
                      std::vector<std::vector<int64_t>> *out,
                      std::vector<size_t> *lod, size_t batch_iter,
                      size_t batch_end) {
  lod->clear();
  lod->push_back(0);
  for (auto it = in.begin() + batch_iter; it < in.begin() + batch_end; it++) {
    out->push_back(*it);
    lod->push_back(lod->back() + (*it).size());  // calculate lod
  }
}

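// Feed PaddleTensor inputs into the predictor's ZeroCopy input tensors.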
void ConvertPaddleTensorToZeroCopyTensor(
    PaddlePredictor *predictor, const std::vector<PaddleTensor> &inputs) {
  for (size_t i = 0; i < inputs.size(); i++) {
    auto input = inputs[i];
    auto tensor = predictor->GetInputTensor(input.name);
    tensor->Reshape(input.shape);
    tensor->SetLoD({input.lod});
    if (input.dtype == PaddleDType::INT64) {
      ZeroCopyTensorAssignData<int64_t>(tensor.get(), input.data);
    } else if (input.dtype == PaddleDType::FLOAT32) {
      ZeroCopyTensorAssignData<float>(tensor.get(), input.data);
    } else if (input.dtype == PaddleDType::INT32) {
      ZeroCopyTensorAssignData<int32_t>(tensor.get(), input.data);
    } else if (input.dtype == PaddleDType::UINT8) {
      ZeroCopyTensorAssignData<uint8_t>(tensor.get(), input.data);
    } else {
      LOG(ERROR) << "unsupported feed type " << input.dtype;
    }
  }
}

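// Run one warm-up pass (timed and logged separately) and reset the profiler
// so the measured run starts from a clean state.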
void PredictionWarmUp(PaddlePredictor *predictor,
                      const std::vector<std::vector<PaddleTensor>> &inputs,
                      std::vector<std::vector<PaddleTensor>> *outputs,
                      int num_threads, int tid,
                      const VarType::Type data_type = VarType::FP32) {
  int batch_size = FLAGS_batch_size;
  LOG(INFO) << "Running thread " << tid << ", warm up run...";
  if (FLAGS_zero_copy) {
    ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]);
  }
  outputs->resize(1);
  Timer warmup_timer;
  warmup_timer.tic();
  if (!FLAGS_zero_copy) {
    predictor->Run(inputs[0], &(*outputs)[0], batch_size);
  } else {
    predictor->ZeroCopyRun();
  }
  PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1, data_type);
  if (FLAGS_profile) {
    paddle::platform::ResetProfiler();
  }
}

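// The timed prediction loop: runs FLAGS_repeat times over the chosen number
// of batches, logs the average batch latency, and optionally reports the
// per-sample latency and a benchmark record.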
void PredictionRun(PaddlePredictor *predictor,
                   const std::vector<std::vector<PaddleTensor>> &inputs,
                   std::vector<std::vector<PaddleTensor>> *outputs,
                   int num_threads, int tid,
                   const VarType::Type data_type = VarType::FP32,
                   float *sample_latency = nullptr) {
  int num_times = FLAGS_repeat;
  int iterations = inputs.size();  // process the whole dataset ...
  if (FLAGS_iterations > 0 &&
      FLAGS_iterations < static_cast<int64_t>(inputs.size()))
    iterations =
        FLAGS_iterations;  // ... unless the number of iterations is set
  outputs->resize(iterations);
  LOG(INFO) << "Thread " << tid << ", number of threads " << num_threads
            << ", run " << num_times << " times...";
  Timer run_timer;
  double elapsed_time = 0;
#ifdef WITH_GPERFTOOLS
  ProfilerStart("paddle_inference.prof");
#endif
  int predicted_num = 0;
  if (!FLAGS_zero_copy) {
    for (int i = 0; i < iterations; i++) {
      run_timer.tic();
      for (int j = 0; j < num_times; j++) {
        predictor->Run(inputs[i], &(*outputs)[i], FLAGS_batch_size);
      }
      elapsed_time += run_timer.toc();

      predicted_num += FLAGS_batch_size;
      if (predicted_num % 100 == 0) {
        LOG(INFO) << predicted_num << " samples";
      }
    }
  } else {
    for (int i = 0; i < iterations; i++) {
      ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[i]);
      run_timer.tic();
      for (int j = 0; j < num_times; j++) {
        predictor->ZeroCopyRun();
      }
      elapsed_time += run_timer.toc();

      predicted_num += FLAGS_batch_size;
      if (predicted_num % 100 == 0) {
        LOG(INFO) << predicted_num << " samples";
      }
    }
  }

#ifdef WITH_GPERFTOOLS
  ProfilerStop();
#endif

  auto batch_latency = elapsed_time / (iterations * num_times);
  PrintTime(FLAGS_batch_size, num_times, num_threads, tid, batch_latency,
            iterations, data_type);

  if (sample_latency != nullptr)
    *sample_latency = batch_latency / FLAGS_batch_size;

  if (FLAGS_record_benchmark) {
    Benchmark benchmark;
    benchmark.SetName(FLAGS_model_name);
    benchmark.SetBatchSize(FLAGS_batch_size);
    benchmark.SetLatency(batch_latency);
    benchmark.PersistToFile("benchmark_record.txt");
  }
}

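// Single-threaded prediction: optional warm-up followed by the timed run.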
void TestOneThreadPrediction(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    std::vector<std::vector<PaddleTensor>> *outputs, bool use_analysis = true,
    const VarType::Type data_type = VarType::FP32,
    float *sample_latency = nullptr) {
  auto predictor = CreateTestPredictor(config, use_analysis);
  if (FLAGS_warmup) {
    PredictionWarmUp(predictor.get(), inputs, outputs, 1, 0, data_type);
  }
  PredictionRun(predictor.get(), inputs, outputs, 1, 0, data_type,
                sample_latency);
}

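// Multi-threaded prediction: the first predictor is created from the config
// and the remaining threads use clones; every thread runs the same inputs
// with thread-local outputs.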
void TestMultiThreadPrediction(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    std::vector<std::vector<PaddleTensor>> *outputs, int num_threads,
    bool use_analysis = true) {
  std::vector<std::thread> threads;
  std::vector<std::unique_ptr<PaddlePredictor>> predictors;
  predictors.emplace_back(CreateTestPredictor(config, use_analysis));
  for (int tid = 1; tid < num_threads; tid++) {
    predictors.emplace_back(predictors.front()->Clone());
  }

  for (int tid = 0; tid < num_threads; ++tid) {
    threads.emplace_back([&, tid]() {
      // Each thread should have local inputs and outputs.
      // The inputs of each thread are all the same.
      std::vector<std::vector<PaddleTensor>> outputs_tid;
      auto &predictor = predictors[tid];
      if (FLAGS_warmup) {
        PredictionWarmUp(predictor.get(), inputs, &outputs_tid, num_threads,
                         tid);
      }
      PredictionRun(predictor.get(), inputs, &outputs_tid, num_threads, tid);
    });
  }
  for (int i = 0; i < num_threads; ++i) {
    threads[i].join();
  }
}

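// Main entry point for the prediction tests: dispatches to the single- or
// multi-threaded runner.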
void TestPrediction(const PaddlePredictor::Config *config,
                    const std::vector<std::vector<PaddleTensor>> &inputs,
                    std::vector<std::vector<PaddleTensor>> *outputs,
                    int num_threads, bool use_analysis = FLAGS_use_analysis) {
  PrintConfig(config, use_analysis);
  if (num_threads == 1) {
    TestOneThreadPrediction(config, inputs, outputs, use_analysis);
  } else {
    TestMultiThreadPrediction(config, inputs, outputs, num_threads,
                              use_analysis);
  }
}

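// Log the FP32/INT8 accuracy summary; compared_idx = 1 selects top1 accuracy,
// compared_idx = 2 selects top5 accuracy or mAP.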
void SummarizeAccuracy(float avg_acc_fp32, float avg_acc_int8,
                       int compared_idx) {
  PADDLE_ENFORCE_LE(
      compared_idx, 2,
      platform::errors::InvalidArgument(
          "The compared_idx should be <= 2. But received compared_idx = %d. "
          "For top1 accuracy, set compared_idx = 1; For top5 accuracy or mean "
          "Average Precision (mAP), set compared_idx = 2.",
          compared_idx));
  PADDLE_ENFORCE_GE(
      compared_idx, 1,
      platform::errors::InvalidArgument(
          "The compared_idx should be >= 1. But received compared_idx = %d. "
          "For top1 accuracy, set compared_idx = 1; For top5 accuracy or mean "
          "Average Precision (mAP), set compared_idx = 2.",
          compared_idx));
  std::string prefix = (compared_idx == 1) ? "top1_accuracy " : "mAP ";
  LOG(INFO) << "--- Accuracy summary --- ";
  LOG(INFO) << "Accepted " << prefix
            << "drop threshold: " << FLAGS_quantized_accuracy
            << ". (condition: (FP32_" << prefix << " - INT8_" << prefix
            << ") <= threshold)";
  LOG(INFO) << "FP32: avg " << prefix << std::fixed << std::setw(6)
            << std::setprecision(4) << avg_acc_fp32;
  LOG(INFO) << "INT8: avg " << prefix << std::fixed << std::setw(6)
            << std::setprecision(4) << avg_acc_int8;
}

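// Log average throughput (fps) and latency (ms) for one precision; the
// two-argument overload reports FP32 and INT8 together.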
void SummarizePerformance(const char *title, float sample) {
  CHECK_GT(sample, 0.0);
  auto throughput = 1000.0 / sample;
  LOG(INFO) << title << ": avg fps: " << std::fixed << std::setw(6)
            << std::setprecision(4) << throughput << ", avg latency: " << sample
            << " ms";
}

void SummarizePerformance(float sample_latency_fp32,
                          float sample_latency_int8) {
  if (FLAGS_enable_fp32) SummarizePerformance("FP32", sample_latency_fp32);
  if (FLAGS_enable_int8) SummarizePerformance("INT8", sample_latency_int8);
}

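// Average the accuracy-layer output at compared_idx over all batches.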
float CompareAccuracyOne(
    const std::vector<std::vector<PaddleTensor>> &output_slots,
    int compared_idx) {
  PADDLE_ENFORCE_GT(output_slots.size(), 0,
                    platform::errors::InvalidArgument(
                        "The accuracy vector is empty. The accuracy vector "
                        "size should be bigger than 0"));

  float total_accs{0};

  for (size_t i = 0; i < output_slots.size(); ++i) {
    switch (compared_idx) {
      case 1:
        PADDLE_ENFORCE_GE(
            output_slots[i].size(), 2UL,
            platform::errors::InvalidArgument(
                "To achieve top 1 accuracy, output_slots size "
                "must be bigger than or equal to 2, but now the size is %d",
                output_slots[i].size()));
        break;
      case 2:
        PADDLE_ENFORCE_GE(
            output_slots[i].size(), 3UL,
            platform::errors::InvalidArgument(
                "To achieve top 5 accuracy or mean Average "
                "Precision (mAP), output_slots size must be "
                "bigger than or equal to 3, but now the size is %d",
                output_slots[i].size()));
        break;
      default:
        throw std::invalid_argument(
            "CompareAccuracy: compared_idx is out of range.");
    }

    if (output_slots[i][compared_idx].lod.size() > 0)
      throw std::invalid_argument("CompareAccuracy: output has nonempty LoD.");

    if (output_slots[i][compared_idx].dtype != paddle::PaddleDType::FLOAT32)
      throw std::invalid_argument(
          "CompareAccuracy: output is of a wrong type.");

    total_accs +=
        *static_cast<float *>(output_slots[i][compared_idx].data.data());
  }

  return total_accs / output_slots.size();
}

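// Compare averaged FP32 and INT8 accuracies; the INT8 drop must stay within
// FLAGS_quantized_accuracy.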
void CompareAccuracy(
    const std::vector<std::vector<PaddleTensor>> &output_slots_quant,
    const std::vector<std::vector<PaddleTensor>> &output_slots_ref,
    int compared_idx) {
  if ((FLAGS_enable_fp32 && FLAGS_enable_int8) &&
      (output_slots_quant.size() == 0 || output_slots_ref.size() == 0))
    throw std::invalid_argument(
        "CompareAccuracy: output_slots vector is empty.");

  float avg_acc_quant = 0.0;
  float avg_acc_ref = 0.0;

  if (FLAGS_enable_int8)
    avg_acc_quant = CompareAccuracyOne(output_slots_quant, compared_idx);

  if (FLAGS_enable_fp32)
    avg_acc_ref = CompareAccuracyOne(output_slots_ref, compared_idx);

  SummarizeAccuracy(avg_acc_ref, avg_acc_quant, compared_idx);

  if (FLAGS_enable_fp32) CHECK_GT(avg_acc_ref, 0.0);

  if (FLAGS_enable_int8) CHECK_GT(avg_acc_quant, 0.0);

  if (FLAGS_enable_fp32 && FLAGS_enable_int8)
    CHECK_LE(avg_acc_ref - avg_acc_quant, FLAGS_quantized_accuracy);
}

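// Verify that repeated runs over the same inputs produce identical outputs.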
void CompareDeterministic(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs) {
  int batch_size = FLAGS_batch_size;
  int num_times = FLAGS_repeat;
  auto predictor = CreateTestPredictor(config, FLAGS_use_analysis);

  std::vector<PaddleTensor> warmup_outputs, outputs;
  // Run num_times to compare deterministic results.
  for (size_t j = 0; j < inputs.size(); j++) {
    // warmup run
    predictor->Run(inputs[j], &warmup_outputs, batch_size);
    for (int i = 0; i < num_times; i++) {
      predictor->Run(inputs[j], &outputs, batch_size);
      CompareResult(outputs, warmup_outputs);
    }
  }
}

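// Run the same inputs through a native and an analysis predictor and compare
// the final outputs.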
void CompareNativeAndAnalysis(
    const PaddlePredictor::Config *config,
    const std::vector<std::vector<PaddleTensor>> &inputs) {
  PrintConfig(config, true);
  std::vector<std::vector<PaddleTensor>> native_outputs, analysis_outputs;
  TestOneThreadPrediction(config, inputs, &native_outputs, false);
  TestOneThreadPrediction(config, inputs, &analysis_outputs, true);
  PADDLE_ENFORCE_GT(native_outputs.size(), 0,
                    platform::errors::InvalidArgument(
                        "The native outputs are empty. The native outputs "
                        "vector size must be bigger than 0"));
  PADDLE_ENFORCE_GT(analysis_outputs.size(), 0,
                    platform::errors::InvalidArgument(
                        "The analysis outputs are empty. The analysis outputs "
                        "vector size must be bigger than 0"));
  CompareResult(analysis_outputs.back(), native_outputs.back());
}

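// Run FP32 (analysis) and INT8 (quantized) predictions on the same inputs,
// summarize performance, and check the accuracy drop.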
void CompareQuantizedAndAnalysis(
    const AnalysisConfig *config, const AnalysisConfig *qconfig,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    const int compared_idx = 1) {
  PADDLE_ENFORCE_EQ(
      inputs[0][0].shape[0], FLAGS_batch_size,
      platform::errors::InvalidArgument(
          "Input data has to be packed batch by batch. The batch size is set "
          "to %d, but the real input is packed with batch size = %d",
          FLAGS_batch_size, inputs[0][0].shape[0]));
  LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
            << ", warmup batch size " << FLAGS_warmup_batch_size << ".";

  LOG(INFO) << "--- FP32 prediction start ---";
  auto *cfg = reinterpret_cast<const PaddlePredictor::Config *>(config);
  PrintConfig(cfg, true);
  std::vector<std::vector<PaddleTensor>> analysis_outputs;
  float sample_latency_fp32{-1};

  if (FLAGS_enable_fp32) {
    TestOneThreadPrediction(cfg, inputs, &analysis_outputs, true, VarType::FP32,
                            &sample_latency_fp32);
  }

  LOG(INFO) << "--- INT8 prediction start ---";
  auto *qcfg = reinterpret_cast<const PaddlePredictor::Config *>(qconfig);
  PrintConfig(qcfg, true);
  std::vector<std::vector<PaddleTensor>> quantized_outputs;
  float sample_latency_int8{-1};

  if (FLAGS_enable_int8) {
    TestOneThreadPrediction(qcfg, inputs, &quantized_outputs, true,
                            VarType::INT8, &sample_latency_int8);
  }
  SummarizePerformance(sample_latency_fp32, sample_latency_int8);

  CompareAccuracy(quantized_outputs, analysis_outputs, compared_idx);
}

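// Compare two analysis configs (typically FP32 vs INT8); accuracy is checked
// only when the model contains an accuracy layer.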
void CompareAnalysisAndAnalysis(
    const AnalysisConfig *config1, const AnalysisConfig *config2,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    const bool with_accuracy_layer = FLAGS_with_accuracy_layer,
    const int compared_idx = 1) {
  PADDLE_ENFORCE_EQ(
      inputs[0][0].shape[0], FLAGS_batch_size,
      platform::errors::InvalidArgument(
          "Input data has to be packed batch by batch. The batchsize is set to "
          "%d, but the real input is packed with batchsize = %d",
          FLAGS_batch_size, inputs[0][0].shape[0]));

  LOG(INFO) << "FP32 & INT8 prediction run: batch_size " << FLAGS_batch_size
            << ", warmup batch size " << FLAGS_warmup_batch_size << ".";

  LOG(INFO) << "--- FP32 prediction start ---";
  auto *cfg1 = reinterpret_cast<const PaddlePredictor::Config *>(config1);
  PrintConfig(cfg1, true);
  std::vector<std::vector<PaddleTensor>> analysis_outputs;
  float sample_latency_fp32{-1};

  if (FLAGS_enable_fp32) {
    TestOneThreadPrediction(cfg1, inputs, &analysis_outputs, true,
                            VarType::FP32, &sample_latency_fp32);
  }

  LOG(INFO) << "--- INT8 prediction start ---";
  auto *cfg2 = reinterpret_cast<const PaddlePredictor::Config *>(config2);
  PrintConfig(cfg2, true);
  std::vector<std::vector<PaddleTensor>> int8_outputs;
  float sample_latency_int8{-1};

  if (FLAGS_enable_int8) {
    TestOneThreadPrediction(cfg2, inputs, &int8_outputs, true, VarType::INT8,
                            &sample_latency_int8);
  }
  SummarizePerformance(sample_latency_fp32, sample_latency_int8);
  if (with_accuracy_layer) {
    CompareAccuracy(int8_outputs, analysis_outputs, compared_idx);
  }
}

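// Overload taking prebuilt predictors; compares outputs on the first batch.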
void CompareNativeAndAnalysis(
    PaddlePredictor *native_pred, PaddlePredictor *analysis_pred,
    const std::vector<std::vector<PaddleTensor>> &inputs) {
  int batch_size = FLAGS_batch_size;
  std::vector<PaddleTensor> native_outputs, analysis_outputs;
  native_pred->Run(inputs[0], &native_outputs, batch_size);
  analysis_pred->Run(inputs[0], &analysis_outputs, batch_size);
  CompareResult(analysis_outputs, native_outputs);
}

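// Compare regular feed/fetch prediction with ZeroCopy prediction on the
// first batch.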
void CompareAnalysisAndZeroCopy(
    PaddlePredictor::Config *config, PaddlePredictor::Config *config1,
    const std::vector<std::vector<PaddleTensor>> &inputs,
    const std::vector<std::string> &outputs_name) {
  int batch_size = FLAGS_batch_size;
  // analysis
  std::vector<PaddleTensor> analysis_outputs;
  auto predictor = CreateTestPredictor(config, true);
  predictor->Run(inputs[0], &analysis_outputs, batch_size);
  // analysis + zero_copy
  std::vector<ZeroCopyTensor> zerocopy_outputs;
  reinterpret_cast<AnalysisConfig *>(config1)->SwitchUseFeedFetchOps(false);
  predictor = CreateTestPredictor(config1, true);
  ConvertPaddleTensorToZeroCopyTensor(predictor.get(), inputs[0]);
  predictor->ZeroCopyRun();
  for (size_t i = 0; i < outputs_name.size(); i++) {
    ZeroCopyTensor zerocopy_output =
        *predictor->GetOutputTensor(outputs_name[i]).get();
    zerocopy_outputs.emplace_back(zerocopy_output);
    LOG(INFO) << "ZeroCopy output: " << DescribeZeroCopyTensor(zerocopy_output);
  }
  // compare
  CompareResult(analysis_outputs, zerocopy_outputs);
}

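// Save the optimized (post-analysis) program and parameters to dstPath.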
void SaveOptimModel(AnalysisConfig *cfg, const std::string &dstPath) {
  auto predictor = CreateTestPredictor(
      reinterpret_cast<const PaddlePredictor::Config *>(cfg),
      FLAGS_use_analysis);
  (static_cast<AnalysisPredictor *>(predictor.get()))->SaveOptimModel(dstPath);
}

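// Debug helpers: summarize a LoDTensor and compare two LoDTensors by LoD,
// shape, and element-wise data.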
template <typename T>
std::string LoDTensorSummary(const framework::LoDTensor &tensor) {
  std::stringstream ss;
  ss << "\n---- tensor ---" << '\n';
  ss << "lod: [";
  for (const auto &level : tensor.lod()) {
    ss << "[ ";
    for (auto i : level) {
      ss << i << ", ";
    }
    ss << "]";
  }
  ss << "]\n";

  ss << "shape: [";
  int size = 1;
  for (int i = 0; i < tensor.dims().size(); i++) {
    int dim = tensor.dims()[i];
    ss << dim << ", ";
    size *= dim;
  }
  ss << "]\n";

  ss << "data: ";
  for (int i = 0; i < std::min(20, size); i++) {
    ss << tensor.data<T>()[i] << " ";
  }
  ss << "\n";

  return ss.str();
}

static bool CompareLoD(const framework::LoD &a, const framework::LoD &b) {
  if (a.size() != b.size()) {
    LOG(ERROR) << string::Sprintf("lod size not match %d != %d", a.size(),
                                  b.size());
    return false;
  }
  for (size_t i = 0; i < a.size(); i++) {
    auto &al = a[i];
    auto &bl = b[i];
    if (al.size() != bl.size()) {
      LOG(ERROR) << string::Sprintf("level size %d != %d", al.size(),
                                    bl.size());
      return false;
    }
  }
  return true;
}

static bool CompareShape(const std::vector<int64_t> &a,
                         const std::vector<int64_t> &b) {
  if (a.size() != b.size()) {
    LOG(ERROR) << string::Sprintf("shape size not match %d != %d", a.size(),
                                  b.size());
    return false;
  }
  for (size_t i = 0; i < a.size(); i++) {
    if (a[i] != b[i]) {
      LOG(ERROR) << string::Sprintf("shape %d-th element not match %d != %d", i,
                                    a[i], b[i]);
      return false;
    }
  }
  return true;
}

static bool CompareTensorData(const framework::LoDTensor &a,
                              const framework::LoDTensor &b) {
  auto a_shape = framework::vectorize(a.dims());
  auto b_shape = framework::vectorize(b.dims());
  size_t a_size = std::accumulate(a_shape.begin(), a_shape.end(), size_t{1},
                                  [](int a, int b) { return a * b; });
  size_t b_size = std::accumulate(b_shape.begin(), b_shape.end(), size_t{1},
                                  [](int a, int b) { return a * b; });
  if (a_size != b_size) {
    LOG(ERROR) << string::Sprintf("tensor data size not match, %d != %d",
                                  a_size, b_size);
  }

  for (size_t i = 0; i < a_size; i++) {
    if (a.type() == VarType::FP32) {
      const auto *a_data = a.data<float>();
      const auto *b_data = b.data<float>();
      if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
        LOG(ERROR) << string::Sprintf(
            "tensor data %d-th element not match, %f != %f", i, a_data[i],
            b_data[i]);
        return false;
      }
    } else if (a.type() == VarType::INT64) {
      const auto *a_data = a.data<int64_t>();
      const auto *b_data = b.data<int64_t>();
      if (std::abs(a_data[i] - b_data[i]) > 1e-3) {
        LOG(ERROR) << string::Sprintf(
            "tensor data %d-th element not match, %f != %f", i, a_data[i],
            b_data[i]);
        return false;
      }
    }
  }

  return true;
}

static bool CompareTensor(const framework::LoDTensor &a,
                          const framework::LoDTensor &b) {
  if (!CompareLoD(a.lod(), b.lod())) {
    return false;
  }
  if (!CompareShape(framework::vectorize(a.dims()),
                    framework::vectorize(b.dims()))) {
    return false;
  }

  if (!CompareTensorData(a, b)) {
    return false;
  }

  return true;
}

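// Example usage (a minimal sketch, not part of the original helpers; the
// model path comes from the hypothetical --infer_model flag value and the
// test name is illustrative):
//
//   TEST(Analyzer, profile) {
//     AnalysisConfig cfg;
//     cfg.SetModel(FLAGS_infer_model);
//     std::vector<std::vector<PaddleTensor>> inputs, outputs;
//     SetFakeImageInput(&inputs, FLAGS_infer_model);
//     TestPrediction(reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
//                    inputs, &outputs, FLAGS_num_threads);
//   }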
}  // namespace inference
}  // namespace paddle