// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/inference/api/analysis_predictor.h"
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/api/paddle_inference_api.h"
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
#if PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#endif
#include "paddle/fluid/inference/utils/singleton.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/cpu_helper.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(profile);

namespace paddle {

using contrib::AnalysisConfig;

namespace {
bool IsPersistable(const framework::VarDesc *var) {
  if (var->Persistable() &&
      var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
      var->GetType() != framework::proto::VarType::FETCH_LIST) {
    return true;
  }
  return false;
}
}  // namespace

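// Initializes the predictor: sets up the scope and executor, loads (and
// optionally optimizes) the inference program, and prepares the feed/fetch
// ops. parent_scope and program are only non-null on the Clone path.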
bool AnalysisPredictor::Init(
    const std::shared_ptr<framework::Scope> &parent_scope,
    const std::shared_ptr<framework::ProgramDesc> &program) {
  VLOG(3) << "Predictor::init()";
  if (FLAGS_profile) {
    LOG(WARNING) << "Profiler is activated, which might affect performance";
    LOG(INFO) << "You can turn it off by setting the gflag '-profile false'";
    auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll
                                             : platform::ProfilerState::kCPU;
    platform::EnableProfiler(tracking_device);
  }

  // no matter with or without MKLDNN
  paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());

  if (!PrepareScope(parent_scope)) {
    return false;
  }
  if (!CreateExecutor()) {
    return false;
  }
  if (!PrepareProgram(program)) {
    return false;
  }

  // Prepare the executor and create local variables.
  if (!PrepareExecutor()) {
    return false;
  }

  // Get the feed_target_names and fetch_target_names
  PrepareFeedFetch();

  return true;
}

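// Reuses the parent scope when cloning, otherwise creates a fresh scope, and
// in both cases creates the sub-scope that holds this predictor's local
// variables.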
bool AnalysisPredictor::PrepareScope(
    const std::shared_ptr<framework::Scope> &parent_scope) {
  if (parent_scope) {
    PADDLE_ENFORCE_NOT_NULL(
        parent_scope,
        "Both program and parent_scope should be set in Clone mode.");
    scope_ = parent_scope;
    status_is_cloned_ = true;
  } else {
    paddle::framework::InitDevices(false);
    scope_.reset(new paddle::framework::Scope());
    status_is_cloned_ = false;
  }
  sub_scope_ = &scope_->NewScope();
  return true;
}
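// Loads the program from disk (or adopts the one passed in when cloning),
// optionally runs IR optimization on it, and loads the parameters when needed.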
bool AnalysisPredictor::PrepareProgram(
    const std::shared_ptr<framework::ProgramDesc> &program) {
  if (!program) {
    if (!LoadProgramDesc()) return false;

    // Optimize the program, and load the parameters and modify them in the
    // scope_.
    // This will change the scope_ address.
    if (config_.ir_optim()) {
      status_ir_optim_enabled_ = true;
      OptimizeInferenceProgram();
    } else {
      // If the parent_scope is passed, we assert that the persistable
      // variables are already created, so we only create the non-persistable
      // variables here.

      // If not cloned, the parameters should be loaded by
      // OptimizeInferenceProgram.
      // So in both cases, only the local variables need to be created here,
      // not the parameters.
      executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);

      // Load parameters
      LOG(INFO) << "load parameters ";
      LoadParameters();
    }
  } else {
    // If the program is passed in from outside, there is no need to optimize
    // it; this logic is used in the clone scenario.
    inference_program_ = program;
  }

  executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);

  return true;
}
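// Chooses the place (CUDAPlace or CPUPlace) from the config and creates the
// NaiveExecutor on it.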
bool AnalysisPredictor::CreateExecutor() {
  if (config_.use_gpu_) {
    status_use_gpu_ = true;
    place_ = paddle::platform::CUDAPlace(config_.device_id_);
  } else {
    place_ = paddle::platform::CPUPlace();
  }
  executor_.reset(new paddle::framework::NaiveExecutor(place_));
  return true;
}
bool AnalysisPredictor::PrepareExecutor() {
  executor_->Prepare(sub_scope_, *inference_program_, 0,
                     config_.use_feed_fetch_ops_);

  PADDLE_ENFORCE_NOT_NULL(sub_scope_);

  return true;
}

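// Sets the MKL-DNN thread id for the calling thread; logs an error and does
// nothing when Paddle is built without MKL-DNN.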
void AnalysisPredictor::SetMkldnnThreadID(int tid) {
#ifdef PADDLE_WITH_MKLDNN
  platform::set_cur_thread_id(tid);
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
#endif
}

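// Runs one batch: copies the inputs into feed variables, executes the program
// with the NaiveExecutor, and copies the fetch variables back to output_data.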
bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
                            std::vector<PaddleTensor> *output_data,
                            int batch_size) {
  VLOG(3) << "Predictor::predict";
  inference::Timer timer;
  timer.tic();
  // set feed variable
  framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
  if (!SetFeed(inputs, scope)) {
    LOG(ERROR) << "fail to set feed";
    return false;
  }

  // Run the inference program.
  // If the variables are shared, we do not need to create them again.
  executor_->Run();

  // get fetch variable
  if (!GetFetch(output_data, scope)) {
    LOG(ERROR) << "fail to get fetches";
    return false;
  }
  VLOG(3) << "predict cost: " << timer.toc() << "ms";

  // All the containers in the scope are held across runs in inference, but the
  // operators assume that a container is reset after each batch.
  // Here is a bugfix: collect all the container variables and reset them to a
  // bool; next time, the operator will call MutableData and construct a new
  // container, so the container is empty for each batch.
  tensor_array_batch_cleaner_.CollectNoTensorVars(sub_scope_);
  tensor_array_batch_cleaner_.ResetNoTensorVars();
  return true;
}

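// Copies the user-provided PaddleTensors into the feed variables of the scope,
// resolving each input either by name or by its feed column index.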
bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
                                framework::Scope *scope) {
  VLOG(3) << "Predictor::set_feed";
  if (inputs.size() != feeds_.size()) {
    LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but got "
               << inputs.size();
    return false;
  }

  // Cache the inputs memory for better concurrency performance.
  feed_tensors_.resize(inputs.size());

  for (size_t i = 0; i < inputs.size(); ++i) {
    auto &input = feed_tensors_[i];
    framework::DDim ddim = framework::make_ddim(inputs[i].shape);
    void *input_ptr;
    if (inputs[i].dtype == PaddleDType::INT64) {
      input_ptr = input.mutable_data<int64_t>(ddim, place_);
    } else if (inputs[i].dtype == PaddleDType::FLOAT32) {
      input_ptr = input.mutable_data<float>(ddim, place_);
    } else {
      LOG(ERROR) << "unsupported feed type " << inputs[i].dtype;
      return false;
    }

    if (platform::is_cpu_place(place_)) {
      // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
      std::memcpy(static_cast<void *>(input_ptr), inputs[i].data.data(),
                  inputs[i].data.length());
    } else {
#ifdef PADDLE_WITH_CUDA
      platform::DeviceContextPool &pool =
          platform::DeviceContextPool::Instance();
      auto *dev_ctx =
          static_cast<const platform::CUDADeviceContext *>(pool.Get(place_));
      auto dst_gpu_place = boost::get<platform::CUDAPlace>(place_);
      memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
                   platform::CPUPlace(), inputs[i].data.data(),
                   inputs[i].data.length(), dev_ctx->stream());
#else
      PADDLE_THROW("Not compiled with CUDA, should not reach here.");
#endif
    }
    // TODO(Superjomn) Low performance, need optimization for heavy LoD copy.
    framework::LoD lod;
    for (auto &level : inputs[i].lod) {
      lod.emplace_back(level);
    }
    input.set_lod(lod);
    int idx = -1;
    if (config_.specify_input_name_) {
      auto name = inputs[i].name;
      if (feed_names_.find(name) == feed_names_.end()) {
        LOG(ERROR) << "feed names from the program do not contain name: ["
                   << name << "] from the specified input";
      }
      idx = feed_names_[name];
    } else {
      idx = boost::get<int>(feeds_[i]->GetAttr("col"));
    }
    framework::SetFeedVariable(scope, input, "feed", idx);
  }
  return true;
}

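// Copies one fetched LoDTensor (shape, data and LoD) into a PaddleTensor.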
template <typename T>
void AnalysisPredictor::GetFetchOne(const framework::LoDTensor &fetch,
                                    PaddleTensor *output) {
  // set shape.
  auto shape = framework::vectorize(fetch.dims());
  output->shape.assign(shape.begin(), shape.end());
  // set data.
  const T *data = fetch.data<T>();
  int num_elems = inference::VecReduceToInt(shape);
  output->data.Resize(num_elems * sizeof(T));
  // The tensor fetched by the fetch op should always be in CPU memory, so
  // just copy it.
  memcpy(output->data.data(), data, num_elems * sizeof(T));
  // set lod
  output->lod.clear();
  for (auto &level : fetch.lod()) {
    output->lod.emplace_back(level.begin(), level.end());
  }
}

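// Reads every fetch variable from the scope and converts it into a
// PaddleTensor in `outputs`, preserving the fetch column order.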
bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
                                 framework::Scope *scope) {
  VLOG(3) << "Predictor::get_fetch";
  outputs->resize(fetchs_.size());
  for (size_t i = 0; i < fetchs_.size(); ++i) {
    int idx = boost::get<int>(fetchs_[i]->GetAttr("col"));
    PADDLE_ENFORCE((size_t)idx == i);
    framework::LoDTensor &fetch =
        framework::GetFetchVariable(*scope, "fetch", idx);
    auto type = fetch.type();
    auto output = &(outputs->at(i));
    output->name = fetchs_[idx]->Input("X")[0];
    if (type == framework::proto::VarType::FP32) {
      GetFetchOne<float>(fetch, output);
      output->dtype = PaddleDType::FLOAT32;
    } else if (type == framework::proto::VarType::INT64) {
      GetFetchOne<int64_t>(fetch, output);
      output->dtype = PaddleDType::INT64;
    } else {
      LOG(ERROR) << "unknown type; only float32 and int64 are supported now.";
    }
  }
  return true;
}

// NOTE: All the members in AnalysisConfig should be copied to Argument.
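// Fills the Argument from the config, runs the Analyzer (IR passes, plus the
// TensorRT and MKL-DNN options when enabled), and replaces inference_program_
// with the analyzed program.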
void AnalysisPredictor::OptimizeInferenceProgram() {
  status_program_optimized_ = true;

  argument_.SetUseGPU(config_.use_gpu());
  argument_.SetGPUDeviceId(config_.gpu_device_id());
  argument_.SetModelFromMemory(config_.model_from_memory_);
  // Analyze inference_program
  if (!config_.model_dir().empty()) {
    argument_.SetModelDir(config_.model_dir());
  } else {
    PADDLE_ENFORCE(
        !config_.params_file().empty(),
        "Either model_dir or (param_file, prog_file) should be set.");
    PADDLE_ENFORCE(!config_.prog_file().empty());
    argument_.SetModelProgramPath(config_.prog_file());
    argument_.SetModelParamsPath(config_.params_file());
  }

  if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
    argument_.SetUseTensorRT(true);
    argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
    argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
    argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
  }

  if (config_.use_mkldnn_) {
    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
  }

  auto passes = config_.pass_builder()->AllPasses();
  if (!config_.ir_optim()) passes.clear();
  argument_.SetIrAnalysisPasses(passes);
  argument_.SetScopeNotOwned(const_cast<framework::Scope *>(scope_.get()));
  Analyzer().Run(&argument_);

  PADDLE_ENFORCE(argument_.scope_valid());
  VLOG(5) << "to prepare executor";
  ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
  inference_program_.reset(
      new framework::ProgramDesc(argument_.ir_analyzed_program()));
  LOG(INFO) << "== optimize end ==";
}

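// A minimal usage sketch (illustrative only; see paddle_inference_api.h for
// the exact AnalysisConfig constructors and setters available in this tree):
//   contrib::AnalysisConfig config(model_dir);
//   auto predictor = CreatePaddlePredictor<contrib::AnalysisConfig,
//                                          PaddleEngineKind::kAnalysis>(config);
//   std::vector<PaddleTensor> inputs, outputs;
//   predictor->Run(inputs, &outputs);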
template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<
    AnalysisConfig, PaddleEngineKind::kAnalysis>(const AnalysisConfig &config) {
  VLOG(3) << "create AnalysisConfig";
  if (config.use_gpu()) {
    // 1. GPU memory
    PADDLE_ENFORCE_GT(config.memory_pool_init_size_mb(), 0.f);
    PADDLE_ENFORCE_GE(config.gpu_device_id(), 0, "Invalid device id %d",
                      config.gpu_device_id());
    std::vector<std::string> flags;

    float fraction_of_gpu_memory = config.fraction_of_gpu_memory_for_pool();
    if (fraction_of_gpu_memory > 0.95f) {
      LOG(ERROR)
          << "Allocate too much memory for the GPU memory pool, assigned "
          << config.memory_pool_init_size_mb() << " MB";
      LOG(ERROR)
          << "Try to shink the value by setting AnalysisConfig::EnableGpu(...)";
    }

    if (fraction_of_gpu_memory >= 0.0f && fraction_of_gpu_memory <= 0.95f) {
      flags.push_back("dummy");
      std::string flag = "--fraction_of_gpu_memory_to_use=" +
                         std::to_string(fraction_of_gpu_memory);
      flags.push_back(flag);
      VLOG(3) << "set flag: " << flag;
      framework::InitGflags(flags);
    }
  }

  std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
  if (!dynamic_cast<AnalysisPredictor *>(predictor.get())->Init(nullptr)) {
    return nullptr;
  }
  return std::move(predictor);
}

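// Scans block 0 for feed and fetch ops and indexes them by their "col"
// attribute, so SetFeed/GetFetch can address inputs and outputs by position
// or by name.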
void AnalysisPredictor::PrepareFeedFetch() {
  PADDLE_ENFORCE_NOT_NULL(sub_scope_);
  CreateFeedFetchVar(sub_scope_);
  for (auto *op : inference_program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
      int idx = boost::get<int>(op->GetAttr("col"));
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      feed_names_[op->Output("Out")[0]] = idx;
    } else if (op->Type() == "fetch") {
      int idx = boost::get<int>(op->GetAttr("col"));
      if (fetchs_.size() <= static_cast<size_t>(idx)) {
        fetchs_.resize(idx + 1);
      }
      fetchs_[idx] = op;
    }
  }
}

void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
  PADDLE_ENFORCE_NOT_NULL(scope);
  auto *var = scope->Var("feed");
  var->GetMutable<framework::FeedFetchList>();
  var = scope->Var("fetch");
  var->GetMutable<framework::FeedFetchList>();
}

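// Zero-copy API: the returned ZeroCopyTensor wraps a variable in the
// executor's scope, letting callers fill inputs and read outputs without the
// extra copies performed by Run/SetFeed/GetFetch.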
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
    const std::string &name) {
  PADDLE_ENFORCE(executor_->scope()->FindVar(name), "no name called %s", name);
  std::unique_ptr<ZeroCopyTensor> res(
      new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
  res->input_or_output_ = true;
  res->SetName(name);
  return res;
}

std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
    const std::string &name) {
  PADDLE_ENFORCE(executor_->scope()->FindVar(name), "no name called %s", name);
  std::unique_ptr<ZeroCopyTensor> res(
      new ZeroCopyTensor(static_cast<void *>(executor_->scope())));
  res->input_or_output_ = false;
  res->SetName(name);
  return res;
}

bool AnalysisPredictor::ZeroCopyRun() {
  executor_->Run();
  // Fix the bug that a reused TensorArray is not cleaned between batches.
  tensor_array_batch_cleaner_.CollectTensorArrays(sub_scope_);
  tensor_array_batch_cleaner_.ResetTensorArray();
  return true;
}

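// Locates the serialized ProgramDesc (a __model__ file, a separate program
// file, or an in-memory buffer) and parses it into inference_program_.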
bool AnalysisPredictor::LoadProgramDesc() {
  // Initialize the inference program
  std::string filename;
  if (!config_.model_dir().empty()) {
    filename = config_.model_dir() + "/__model__";
  } else if (!config_.prog_file().empty() && !config_.params_file().empty()) {
    // All parameters are saved in a single file.
    // The file names should be consistent with those used in the Python API
    // `fluid.io.save_inference_model`.
    filename = config_.prog_file();
  } else {
    if (config_.model_dir().empty() && config_.prog_file().empty()) {
      LOG(ERROR)
          << "Either model_dir or (prog_file, param_file) should be set.";
      return false;
    }
    LOG(ERROR) << string::Sprintf(
        "Invalid model dir '%s' or parameter file path '%s'.",
        config_.model_dir(), config_.params_file());
    return false;
  }

  // Create ProgramDesc
  framework::proto::ProgramDesc proto;
  if (!config_.model_from_memory()) {
    std::string pb_content;
    // Read binary
    std::ifstream fin(filename, std::ios::in | std::ios::binary);
    PADDLE_ENFORCE(static_cast<bool>(fin.is_open()), "Cannot open file %s",
                   filename);
    fin.seekg(0, std::ios::end);
    pb_content.resize(fin.tellg());
    fin.seekg(0, std::ios::beg);
    fin.read(&(pb_content.at(0)), pb_content.size());
    fin.close();

    proto.ParseFromString(pb_content);
  } else {
    // When the model is loaded from memory, prog_file() is expected to hold
    // the serialized program buffer itself rather than a path.
    proto.ParseFromString(config_.prog_file());
  }
  inference_program_.reset(new framework::ProgramDesc(proto));
  return true;
}

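// Builds a temporary program containing only load / load_combine ops for the
// persistable variables and runs it once to fill the parameters in scope_.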
bool AnalysisPredictor::LoadParameters() {
  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
                          "The inference program should be loaded first.");

  const auto &global_block = inference_program_->MutableBlock(0);

  // create a temporary program to load parameters.

  std::unique_ptr<framework::ProgramDesc> load_program(
      new framework::ProgramDesc());
  framework::BlockDesc *load_block = load_program->MutableBlock(0);
  std::vector<std::string> params;

  for (auto *var : global_block->AllVars()) {
    if (IsPersistable(var)) {
      VLOG(3) << "persistable variable's name: " << var->Name();

      framework::VarDesc *new_var = load_block->Var(var->Name());
      new_var->SetShape(var->GetShape());
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

      if (!config_.params_file().empty()) {
        params.push_back(new_var->Name());
      } else {
        // append_op
        framework::OpDesc *op = load_block->AppendOp();
        op->SetType("load");
        op->SetOutput("Out", {new_var->Name()});
        op->SetAttr("file_path", {config_.model_dir() + "/" + new_var->Name()});
        op->CheckAttrs();
      }
    }
  }

  if (!config_.params_file().empty()) {
    // sort paramlist to have consistent ordering
    std::sort(params.begin(), params.end());
    // append just the load_combine op
    framework::OpDesc *op = load_block->AppendOp();
    op->SetType("load_combine");
    op->SetOutput("Out", params);
    op->SetAttr("file_path", {config_.params_file()});
    op->CheckAttrs();
  }

  // Use NaiveExecutor to load the parameters.
  framework::NaiveExecutor e(place_);
  e.Prepare(scope_.get(), *load_program, 0, false);
  e.Run();
  VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";

  return true;
}

AnalysisPredictor::~AnalysisPredictor() {
  if (FLAGS_profile) {
    platform::DisableProfiler(platform::EventSortingKey::kTotal,
                              "./profile.log");
  }
  if (sub_scope_) {
    scope_->DeleteScope(sub_scope_);
  }
}

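// Creates a new predictor that shares this predictor's scope and program;
// guarded by clone_mutex_ so Clone can be called concurrently.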
std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone() {
  std::lock_guard<std::mutex> lk(clone_mutex_);
  auto *x = new AnalysisPredictor(config_);
  x->Init(scope_, inference_program_);
  return std::unique_ptr<PaddlePredictor>(x);
}

template <>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<contrib::AnalysisConfig>(
    const contrib::AnalysisConfig &config) {
  return CreatePaddlePredictor<contrib::AnalysisConfig,
                               PaddleEngineKind::kAnalysis>(config);
}

}  // namespace paddle

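// Reference the TensorRT op converters below so their registrations are
// linked into the final binary and the TensorRT subgraph engine can find them.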
#if PADDLE_WITH_TENSORRT
USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(mul);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(sigmoid);
USE_TRT_CONVERTER(tanh);
USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
#endif