// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
#endif
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(profile);

namespace paddle {

using contrib::AnalysisConfig;
using inference::Singleton;
using inference::tensorrt::TRTInt8Calibrator;
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;

namespace {
bool IsPersistable(const framework::VarDesc *var) {
  if (var->Persistable() &&
      var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
      var->GetType() != framework::proto::VarType::FETCH_LIST) {
    return true;
  }
  return false;
}
}  // namespace

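// Initialization: prepare the scope and the executor, load (and optionally
// optimize) the inference program, create the local variables, and cache the
// feed/fetch ops.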
bool AnalysisPredictor::Init(
    const std::shared_ptr<framework::Scope> &parent_scope,
    const std::shared_ptr<framework::ProgramDesc> &program) {
  VLOG(3) << "Predictor::init()";
  if (FLAGS_profile) {
    LOG(WARNING) << "Profiler is activated, which might affect performance";
    LOG(INFO) << "You can turn it off by setting the gflag '-profile false'";
    auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll
                                             : platform::ProfilerState::kCPU;
    platform::EnableProfiler(tracking_device);
  }

  // Set the number of math-library threads (applies with or without MKLDNN).
  paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());

  if (!PrepareScope(parent_scope)) {
    return false;
  }
  if (!CreateExecutor()) {
    return false;
  }
  if (!PrepareProgram(program)) {
    return false;
  }

  // Prepare the executor and create local variables.
  if (!PrepareExecutor()) {
    return false;
  }

  // Get the feed_target_names and fetch_target_names
  PrepareFeedFetch();

  return true;
}

bool AnalysisPredictor::PrepareScope(
    const std::shared_ptr<framework::Scope> &parent_scope) {
  if (parent_scope) {
    PADDLE_ENFORCE_NOT_NULL(
        parent_scope,
        "Both program and parent_scope should be set in Clone mode.");
    scope_ = parent_scope;
    status_is_cloned_ = true;
  } else {
    paddle::framework::InitDevices(false);
    scope_.reset(new paddle::framework::Scope());
    status_is_cloned_ = false;
  }
  sub_scope_ = &scope_->NewScope();
  return true;
}
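// Loads the program (unless one is passed in, as in the clone scenario),
// optionally runs the IR optimization passes, and creates the variables the
// program needs in the scope.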
bool AnalysisPredictor::PrepareProgram(
    const std::shared_ptr<framework::ProgramDesc> &program) {
  if (!program) {
    if (!LoadProgramDesc()) return false;

    // Optimize the program, and load parameters and modify them in the
    // scope_.
    // This will change the scope_ address.
    if (config_.ir_optim()) {
      status_ir_optim_enabled_ = true;
      OptimizeInferenceProgram();
    } else {
      // If the parent_scope is passed, we assert that the persistable
      // variables are already created, so only the non-persistable variables
      // need to be created here.

      // If not cloned, the parameters should be loaded by
      // OptimizeInferenceProgram.
      // So in both cases, only the local variables need to be loaded, not the
      // parameters.
      executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);

      // Load parameters
      LOG(INFO) << "load parameters ";
      LoadParameters();
    }
  } else {
    // If the program is passed in externally, there is no need to optimize
    // it; this path is used in the clone scenario.
    inference_program_ = program;
  }

  executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);

  return true;
}
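// Chooses the place (CUDAPlace when the config enables the GPU, CPUPlace
// otherwise) and builds a NaiveExecutor on it.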
bool AnalysisPredictor::CreateExecutor() {
  if (config_.use_gpu_) {
    status_use_gpu_ = true;
    place_ = paddle::platform::CUDAPlace(config_.device_id_);
  } else {
    place_ = paddle::platform::CPUPlace();
  }
  executor_.reset(new paddle::framework::NaiveExecutor(place_));
  return true;
}
bool AnalysisPredictor::PrepareExecutor() {
  executor_->Prepare(sub_scope_, *inference_program_, 0,
                     config_.use_feed_fetch_ops_);

  PADDLE_ENFORCE_NOT_NULL(sub_scope_);

  return true;
}

void AnalysisPredictor::SetMkldnnThreadID(int tid) {
#ifdef PADDLE_WITH_MKLDNN
  platform::set_cur_thread_id(tid);
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
#endif
}

bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
                            std::vector<PaddleTensor> *output_data,
                            int batch_size) {
  VLOG(3) << "Predictor::predict";
  inference::Timer timer;
  timer.tic();
  // set feed variable
  framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
  if (!SetFeed(inputs, scope)) {
    LOG(ERROR) << "fail to set feed";
    return false;
  }

  // Run the inference program.
  // If the variables are shared with a parent scope, we need not create them.
  executor_->Run();

  // get fetch variable
  if (!GetFetch(output_data, scope)) {
    LOG(ERROR) << "fail to get fetches";
    return false;
  }
  VLOG(3) << "predict cost: " << timer.toc() << "ms";

  // All the containers in the scope are held across inference runs, but the
  // operators assume that a container is reset after each batch.
  // As a workaround, collect all the container variables and reset them to a
  // bool; the next time, the operator will call MutableData and construct a
  // new container, so the container is empty for each batch.
  tensor_array_batch_cleaner_.CollectNoTensorVars(sub_scope_);
  tensor_array_batch_cleaner_.ResetNoTensorVars();
  return true;
}

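// Copies the user-provided PaddleTensors into the feed LoDTensors in the
// scope, including a host-to-device copy when the predictor runs on the GPU.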
bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
                                framework::Scope *scope) {
  VLOG(3) << "Predictor::set_feed";
  if (inputs.size() != feeds_.size()) {
    LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
               << inputs.size();
    return false;
  }

  // Cache the inputs memory for better concurrency performance.
  feed_tensors_.resize(inputs.size());

  for (size_t i = 0; i < inputs.size(); ++i) {
    auto &input = feed_tensors_[i];
    framework::DDim ddim = framework::make_ddim(inputs[i].shape);
    void *input_ptr;
    if (inputs[i].dtype == PaddleDType::INT64) {
      input_ptr = input.mutable_data<int64_t>(ddim, place_);
    } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
      input_ptr = input.mutable_data<float>(ddim, place_);
    } else {
      LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
      return false;
    }

    if (platform::is_cpu_place(place_)) {
      // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
      std::memcpy(static_cast<void *>(input_ptr), inputs[i].data.data(),
                  inputs[i].data.length());
    } else {
#ifdef PADDLE_WITH_CUDA
      platform::DeviceContextPool &pool =
          platform::DeviceContextPool::Instance();
      auto *dev_ctx =
          static_cast<const platform::CUDADeviceContext *>(pool.Get(place_));
      auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_);
      memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
                   platform::CPUPlace(), inputs[i].data.data(),
                   inputs[i].data.length(), dev_ctx->stream());
#else
      PADDLE_THROW("Not compile with CUDA, should not reach here.");
#endif
    }
    // TODO(Superjomn) Low performance, need optimization for heavy LoD copy.
    framework::LoD lod;
    for (auto &level : inputs[i].lod) {
      lod.emplace_back(level);
    }
    input.set_lod(lod);
    int idx = -1;
    if (config_.specify_input_name_) {
      auto name = inputs[i].name;
      if (feed_names_.find(name) == feed_names_.end()) {
        LOG(ERROR) << "feed names from program do not have name: [" << name
                   << "] from specified input";
      }
      idx = feed_names_[name];
    } else {
      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
    }
    framework::SetFeedVariable(scope, input, "feed", idx);
  }
  return true;
}

template <typename T>
void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
                                    PaddleTensor *output) {
  // set shape.
  auto shape = framework::vectorize(fetch.dims());
  output->shape.assign(shape.begin(), shape.end());
  // set data.
  const T *data = fetch.data<T>();
  int num_elems = inference::VecReduceToInt(shape);
  output->data.Resize(num_elems * sizeof(T));
  // The tensor fetched by the fetch op should always be in CPU memory, so
  // just copy it.
  memcpy(output->data.data(), data, num_elems * sizeof(T));
  // set lod
  output->lod.clear();
  for (auto &level : fetch.lod()) {
    output->lod.emplace_back(level.begin(), level.end());
  }
}

bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
                                 framework::Scope *scope) {
  VLOG(3) << "Predictor::get_fetch";
  outputs->resize(fetchs_.size());
  for (size_t i = 0; i < fetchs_.size(); ++i) {
    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
    PADDLE_ENFORCE((size_t)idx == i);
    framework::LoDTensor &fetch =
        framework::GetFetchVariable(*scope, "fetch", idx);
    auto type = fetch.type();
    auto output = &(outputs->at(i));
    output->name = fetchs_[idx]->Input("X")[0];
    if (type == framework::proto::VarType::FP32) {
      GetFetchOne<float>(fetch, output);
      output->dtype = PaddleDType::FLOAT32;
    } else if (type == framework::proto::VarType::INT64) {
      GetFetchOne<int64_t>(fetch, output);
      output->dtype = PaddleDType::INT64;
    } else {
      LOG(ERROR) << "unknown type, only support float32 and int64 now.";
    }
  }
  return true;
}

// NOTE All the members in AnalysisConfig should be copied to Argument.
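// Builds an Argument from the config (model path, TensorRT/MKLDNN options and
// the pass list), runs the Analyzer over it, and replaces inference_program_
// with the analyzed program.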
void AnalysisPredictor::OptimizeInferenceProgram() {
  status_program_optimized_ = true;

  argument_.SetUseGPU(config_.use_gpu());
  argument_.SetGPUDeviceId(config_.gpu_device_id());
  argument_.SetModelFromMemory(config_.model_from_memory_);
  // Analyze inference_program
  if (!config_.model_dir().empty()) {
    argument_.SetModelDir(config_.model_dir());
    argument_.SetModelPath(config_.model_dir());
  } else {
    PADDLE_ENFORCE(
        !config_.params_file().empty(),
        "Either model_dir or (param_file, prog_file) should be set.");
    PADDLE_ENFORCE(!config_.prog_file().empty());
    std::string dir = inference::analysis::GetDirRoot(config_.prog_file());

    argument_.SetModelPath(dir);
    argument_.SetModelProgramPath(config_.prog_file());
    argument_.SetModelParamsPath(config_.params_file());
  }

  if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
    argument_.SetUseTensorRT(true);
    argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
    argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
    argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
    argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
  }

  if (config_.use_mkldnn_) {
    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
  }

  auto passes = config_.pass_builder()->AllPasses();
  if (!config_.ir_optim()) passes.clear();
  argument_.SetIrAnalysisPasses(passes);
  argument_.SetScopeNotOwned(const_cast<framework::Scope *>(scope_.get()));
  Analyzer().Run(&argument_);

  PADDLE_ENFORCE(argument_.scope_valid());
  VLOG(5) << "to prepare executor";
  ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
  inference_program_.reset(
      new framework::ProgramDesc(argument_.ir_analyzed_program()));
  LOG(INFO) << "== optimize end ==";
}

template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
    AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
  VLOG(3) << "create AnalysisConfig";
  if (config.use_gpu()) {
    // 1. GPU memory
    PADDLE_ENFORCE_GT(config.memory_pool_init_size_mb(), 0.f);
    PADDLE_ENFORCE_GE(config.gpu_device_id(), 0, "Invalid device id %d",
                      config.gpu_device_id());
    std::vector<std::string> flags;

    float fraction_of_gpu_memory = config.fraction_of_gpu_memory_for_pool();
    if (fraction_of_gpu_memory > 0.95f) {
      LOG(ERROR)
          << "Too much memory is allocated for the GPU memory pool, assigned "
          << config.memory_pool_init_size_mb() << " MB";
      LOG(ERROR) << "Try to shrink the value by setting "
                    "AnalysisConfig::EnableGpu(...)";
    }

    if (fraction_of_gpu_memory >= 0.0f && fraction_of_gpu_memory <= 0.95f) {
      flags.push_back("dummpy");
      std::string flag = "--fraction_of_gpu_memory_to_use=" +
                         std::to_string(fraction_of_gpu_memory);
      flags.push_back(flag);
      VLOG(3) << "set flag: " << flag;
      framework::InitGflags(flags);
    }
  }

  std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
  if (!dynamic_cast<AnalysisPredictor *>(predictor.get())->Init(nullptr)) {
    return nullptr;
  }
  return std::move(predictor);
}

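// Scans block 0 for feed/fetch ops and indexes them by their "col" attribute,
// so inputs and outputs can later be addressed by index or by name.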
void AnalysisPredictor::PrepareFeedFetch() {
  PADDLE_ENFORCE_NOT_NULL(sub_scope_);
  CreateFeedFetchVar(sub_scope_);
  for (auto *op : inference_program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      int idx = boost::get<int>(op->GetAttr("col"));
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      feed_names_[op->Output("Out")[0]] = idx;
    } else if (op->Type() == "fetch") {
      int idx = boost::get<int>(op->GetAttr("col"));
      if (fetchs_.size() <= static_cast<size_t>(idx)) {
        fetchs_.resize(idx + 1);
      }
      fetchs_[idx] = op;
    }
  }
}

void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
  PADDLE_ENFORCE_NOT_NULL(scope);
  auto *var = scope->Var("feed");
  var->GetMutable<framework::FeedFetchList>();
  var = scope->Var("fetch");
  var->GetMutable<framework::FeedFetchList>();
}

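// Zero-copy interface: the returned ZeroCopyTensor wraps a variable living in
// the executor's scope, so data can be written and read in place, avoiding
// the feed/fetch copies performed by Run().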
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
    const std::string &name) {
  PADDLE_ENFORCE(executor_->scope()->FindVar(name), "no name called %s", name);
  std::unique_ptr<ZeroCopyTensor> res(
      new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
  res->input_or_output_ = true;
  res->SetName(name);
  return res;
}

std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
    const std::string &name) {
  PADDLE_ENFORCE(executor_->scope()->FindVar(name), "no name called %s", name);
  std::unique_ptr<ZeroCopyTensor> res(
      new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
  res->input_or_output_ = false;
  res->SetName(name);
  return res;
}

bool AnalysisPredictor::ZeroCopyRun() {
  executor_->Run();
  // Fix the bug that reused TensorArray variables are not cleaned between
  // batches.
  tensor_array_batch_cleaner_.CollectTensorArrays(sub_scope_);
  tensor_array_batch_cleaner_.ResetTensorArray();
  return true;
}

bool AnalysisPredictor::LoadProgramDesc() {
  // Initialize the inference program
  std::string filename;
  if (!config_.model_dir().empty()) {
    filename = config_.model_dir() + "/__model__";
  } else if (!config_.prog_file().empty() && !config_.params_file().empty()) {
    // All parameters are saved in a single file.
    // The file names should be consistent with that used
    // in Python API `fluid.io.save_inference_model`.
    filename = config_.prog_file();
  } else {
    if (config_.model_dir().empty() && config_.prog_file().empty()) {
      LOG(ERROR)
          << "Either model_dir or (prog_file, param_file) should be set.";
      return false;
    }
    LOG(ERROR) << string::Sprintf(
        "Not a valid model path '%s' or params path '%s'.",
        config_.model_dir(), config_.params_file());
    return false;
  }

  // Create ProgramDesc
  framework::proto::ProgramDesc proto;
  if (!config_.model_from_memory()) {
    std::string pb_content;
    // Read binary
    std::ifstream fin(filename, std::ios::in | std::ios::binary);
    PADDLE_ENFORCE(static_cast<bool>(fin.is_open()), "Cannot open file %s",
                   filename);
    fin.seekg(0, std::ios::end);
    pb_content.resize(fin.tellg());
    fin.seekg(0, std::ios::beg);
    fin.read(&(pb_content.at(0)), pb_content.size());
    fin.close();

    proto.ParseFromString(pb_content);
  } else {
    proto.ParseFromString(config_.prog_file());
  }
  inference_program_.reset(new framework::ProgramDesc(proto));
  return true;
}

bool AnalysisPredictor::LoadParameters() {
  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
                          "The inference program should be loaded first.");

  const auto &global_block = inference_program_->MutableBlock(0);

  // create a temporary program to load parameters.

  std::unique_ptr<framework::ProgramDesc> load_program(
      new framework::ProgramDesc());
  framework::BlockDesc *load_block = load_program->MutableBlock(0);
  std::vector<std::string> params;

  for (auto *var : global_block->AllVars()) {
    if (IsPersistable(var)) {
      VLOG(3) << "persistable variable's name: " << var->Name();

      framework::VarDesc *new_var = load_block->Var(var->Name());
      new_var->SetShape(var->GetShape());
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

      if (!config_.params_file().empty()) {
        params.push_back(new_var->Name());
      } else {
        // append_op
        framework::OpDesc *op = load_block->AppendOp();
        op->SetType("load");
        op->SetOutput("Out", {new_var->Name()});
        op->SetAttr("file_path", {config_.model_dir() + "/" + new_var->Name()});
        op->CheckAttrs();
      }
    }
  }

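  // If params_file is set, all parameters were saved into a single combined
  // file, so one load_combine op (with a sorted variable list for a
  // deterministic order) restores them; otherwise each variable above gets
  // its own load op reading <model_dir>/<var_name>.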
  if (!config_.params_file().empty()) {
    // sort paramlist to have consistent ordering
    std::sort(params.begin(), params.end());
    // append just the load_combine op
    framework::OpDesc *op = load_block->AppendOp();
    op->SetType("load_combine");
    op->SetOutput("Out", params);
    op->SetAttr("file_path", {config_.params_file()});
    op->CheckAttrs();
  }

  // Use NaiveExecutor to load the parameters.
  framework::NaiveExecutor e(place_);
  e.Prepare(scope_.get(), *load_program, 0, false);
  e.Run();
  VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";

  return true;
}

#if PADDLE_WITH_TENSORRT
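// Waits for the TensorRT INT8 calibration threads to finish and writes each
// engine's calibration table to a path derived from the model path, so that
// later runs can build INT8 engines from it.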
bool AnalysisPredictor::SaveTrtCalibToDisk() {
  PADDLE_ENFORCE(config_.tensorrt_engine_enabled(),
                 "This function can only be invoked in TensorRT mode.");
  auto &block = inference_program_->Block(0);
  for (auto &op_desc : block.AllOps()) {
    if (op_desc->Type() == "tensorrt_engine") {
      std::string engine_name =
          boost::get<std::string>(op_desc->GetAttr("engine_key"));
      if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
        LOG(ERROR) << "You should run the predictor (with TRT) on real data "
                      "to generate the calibration info.";
        return false;
      }
      TRTCalibratorEngine *calib_engine =
          Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
      LOG(INFO) << "Wait for calib threads done.";
      calib_engine->calib_->waitAndSetDone();
      LOG(INFO) << "Finish wait.";
      calib_engine->thr_->join();
      std::string calibration_table_data =
          calib_engine->calib_->getCalibrationTableAsString();

      if (calibration_table_data.empty()) {
        LOG(ERROR) << "the calibration table is empty.";
        return false;
      }

      std::string calibration_table_data_path =
          inference::analysis::GetTrtCalibPath(argument_.model_path(),
                                               engine_name);

      std::ofstream ofile(calibration_table_data_path, std::ios::out);
      LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
                << calibration_table_data_path;
      ofile << calibration_table_data;
      ofile.close();
    }
  }
  // Free all calibrator resources.
  Singleton<TRTCalibratorEngineManager>::Global().DeleteALL();
  return true;
}
#endif

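// The destructor persists TensorRT INT8 calibration data if it was collected,
// stops the profiler if it was enabled, and releases the sub-scope.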
AnalysisPredictor::~AnalysisPredictor() {
#if PADDLE_WITH_TENSORRT
  if (config_.tensorrt_engine_enabled() &&
      config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
      Singleton<TRTCalibratorEngineManager>::Global().Has()) {
    SaveTrtCalibToDisk();
  }
#endif
  if (FLAGS_profile) {
    platform::DisableProfiler(platform::EventSortingKey::kTotal,
                              "./profile.log");
  }
  if (sub_scope_) {
    scope_->DeleteScope(sub_scope_);
  }
}

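// Clone shares scope_ and inference_program_ with the new predictor; see
// Init/PrepareScope for how the shared state is set up.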
std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() {
  auto *x = new AnalysisPredictor(config_);
  x->Init(scope_, inference_program_);
  return std::unique_ptr<PaddlePredictor>(x);
}

template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<contrib::AnalysisConfig>(
    const contrib::AnalysisConfig &config) {
  return CreatePaddlePredictor<contrib::AnalysisConfig,
                               PaddleEngineKind::kAnalysis>(config);
}

}  // namespace paddle

#if PADDLE_WITH_TENSORRT
USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(mul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
USE_TRT_CONVERTER(tanh);
USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
#endif