analysis_predictor.cc 93.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

Y
Yan Chunwei 已提交
15
#include "paddle/fluid/inference/api/analysis_predictor.h"
16

17
#include <glog/logging.h>
18

19
#include <algorithm>
N
nhzlx 已提交
20
#include <fstream>
21
#include <memory>
22
#include <set>
23
#include <string>
24
#include <utility>
25
#include <vector>
26

W
Wilber 已提交
27
#include "paddle/fluid//platform/device/gpu/gpu_types.h"
28
#include "paddle/fluid/framework/feed_fetch_method.h"
29
#include "paddle/fluid/framework/feed_fetch_type.h"
30
#include "paddle/fluid/framework/generator.h"
Y
Yan Chunwei 已提交
31
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
32
#include "paddle/fluid/framework/ir/pass.h"
33
#include "paddle/fluid/framework/naive_executor.h"
34
#include "paddle/fluid/framework/op_proto_maker.h"
35
#include "paddle/fluid/framework/operator.h"
36
#include "paddle/fluid/framework/scope.h"
J
JingZhuangzhuang 已提交
37
#include "paddle/fluid/framework/transfer_scope_cache.h"
Y
Yan Chunwei 已提交
38
#include "paddle/fluid/framework/var_type_traits.h"
39
#include "paddle/fluid/framework/version.h"
40
#include "paddle/fluid/inference/analysis/helper.h"
41
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
Y
Yan Chunwei 已提交
42
#include "paddle/fluid/inference/analysis/passes/memory_optimize_pass.h"
43
#include "paddle/fluid/inference/api/helper.h"
44
#include "paddle/fluid/inference/api/infer_context.h"
45
#include "paddle/fluid/inference/api/paddle_analysis_config.h"
46
#include "paddle/fluid/inference/api/paddle_inference_api.h"
L
luotao1 已提交
47
#include "paddle/fluid/inference/api/paddle_inference_pass.h"
W
Wilber 已提交
48
#include "paddle/fluid/inference/api/resource_manager.h"
49
#include "paddle/fluid/inference/utils/io_utils.h"
50
#include "paddle/fluid/inference/utils/model_utils.h"
51
#include "paddle/fluid/inference/utils/singleton.h"
52
#include "paddle/fluid/memory/memcpy.h"
53
#include "paddle/fluid/platform/cpu_helper.h"
54
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
55
#include "paddle/fluid/platform/device_context.h"
56
#include "paddle/fluid/platform/place.h"
T
tensor-tang 已提交
57
#include "paddle/fluid/platform/profiler.h"
58
#include "paddle/phi/api/ext/op_meta_info.h"
59 60
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"
W
Wilber 已提交
61
#include "paddle/phi/common/place.h"
W
Wilber 已提交
62
#include "paddle/phi/core/enforce.h"
63 64
#include "paddle/utils/string/split.h"

65
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
66 67 68 69
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#endif
T
tensor-tang 已提交
70

71 72 73 74
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif

75 76 77 78
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/inference/api/mkldnn_quantizer.h"
#endif

79 80 81 82
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif

83
#ifdef PADDLE_WITH_TENSORRT
Y
Yan Chunwei 已提交
84
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
85
#include "paddle/fluid/inference/tensorrt/helper.h"
86
#include "paddle/fluid/inference/tensorrt/trt_int8_calibrator.h"
Y
Yan Chunwei 已提交
87 88
#endif

89 90 91 92
#ifdef PADDLE_WITH_IPU
#include "paddle/fluid/platform/device/ipu/paddle_ipu_handler.h"
#endif

93 94
namespace paddle {

N
nhzlx 已提交
95
using inference::Singleton;
96
#ifdef PADDLE_WITH_TENSORRT
N
nhzlx 已提交
97 98
using inference::tensorrt::TRTCalibratorEngine;
using inference::tensorrt::TRTCalibratorEngineManager;
99
using inference::tensorrt::TRTInt8Calibrator;
N
nhzlx 已提交
100
#endif
101

102 103
// Out-of-class definition of the static clone counter, initialized to 1.
// NOTE(review): presumably advanced by Clone(); that code is not visible here.
int AnalysisPredictor::clone_num_{1};

104 105 106 107
namespace {
bool IsPersistable(const framework::VarDesc *var) {
  if (var->Persistable() &&
      var->GetType() != framework::proto::VarType::FEED_MINIBATCH &&
108 109
      var->GetType() != framework::proto::VarType::FETCH_LIST &&
      var->GetType() != framework::proto::VarType::RAW) {
110 111 112 113
    return true;
  }
  return false;
}
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132

phi::DataType ConvertPrecision(AnalysisConfig::Precision precision) {
  switch (precision) {
    case AnalysisConfig::Precision::kFloat32:
      return phi::DataType::FLOAT32;
    case AnalysisConfig::Precision::kHalf:
      return phi::DataType::FLOAT16;
    case AnalysisConfig::Precision::kBf16:
      return phi::DataType::BFLOAT16;
    case AnalysisConfig::Precision::kInt8:
      return phi::DataType::INT8;
    default:
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Paddle Inference not support precision. We now only support "
          "Float32, Half, Bfloat16 and Int8"));
      return phi::DataType::FLOAT32;
  }
}

133
phi::Backend ConvertBackend(paddle_infer::PlaceType backend) {
134
  switch (backend) {
135
    case paddle_infer::PlaceType::kGPU:
136 137
      // NOTE: phi also support phi::Backend::GPUDNN.
      return phi::Backend::GPU;
138
    case paddle_infer::PlaceType::kNPU:
139
      return phi::Backend::NPU;
140
    case paddle_infer::PlaceType::kXPU:
141
      return phi::Backend::XPU;
142
    case paddle_infer::PlaceType::kCPU:
143
      return phi::Backend::CPU;
144 145
    case paddle_infer::PlaceType::kIPU:
      return phi::Backend::IPU;
146 147 148 149 150 151 152
    default:
      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
          "Paddle Inference not support backend, we now only support GPU, XPU, "
          "NPU and CPU."));
      return phi::Backend::CPU;
  }
}
153 154
}  // namespace

C
ccrrong 已提交
155
bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
156
                             phi::DenseTensor *t,
157
                             const platform::Place &place) {
158
  framework::DDim ddim = phi::make_ddim(pt.shape);
159 160 161 162 163 164 165
  void *input_ptr;
  if (pt.dtype == PaddleDType::INT64) {
    input_ptr = t->mutable_data<int64_t>(ddim, place);
  } else if (pt.dtype == PaddleDType::FLOAT32) {
    input_ptr = t->mutable_data<float>(ddim, place);
  } else if (pt.dtype == PaddleDType::INT32) {
    input_ptr = t->mutable_data<int32_t>(ddim, place);
166 167
  } else if (pt.dtype == PaddleDType::FLOAT16) {
    input_ptr = t->mutable_data<float16>(ddim, place);
168 169 170 171
  } else {
    LOG(ERROR) << "unsupported feed type " << pt.dtype;
    return false;
  }
172 173 174
  // NOTE(Aurelius84): Some kernels support zero shape input
  // without memory holder, we should skip enforce logic.
  bool has_zero_dim = (phi::product(ddim) == 0);
175 176 177
  VLOG(3) << "Found zero dim: " << has_zero_dim
          << " from input with ddim: " << ddim;
  if (!has_zero_dim) {
178 179 180 181 182 183 184 185
    PADDLE_ENFORCE_NOT_NULL(
        input_ptr,
        paddle::platform::errors::Fatal(
            "Cannot convert to LoDTensor because LoDTensor creation failed."));
    PADDLE_ENFORCE_NOT_NULL(
        pt.data.data(),
        paddle::platform::errors::InvalidArgument(
            "The data contained in the input PaddleTensor is illegal."));
186 187 188 189 190
    PADDLE_ENFORCE_EQ(
        pt.data.length(),
        t->numel() * paddle::experimental::SizeOf(t->dtype()),
        paddle::platform::errors::InvalidArgument(
            "The data contained in the input PaddleTensor had wrong length."));
191
  }
192 193 194

  if (platform::is_cpu_place(place)) {
    // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
195 196 197 198
    if (input_ptr != nullptr) {
      std::memcpy(
          static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
    }
J
jianghaicheng 已提交
199 200
  } else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
C
ccrrong 已提交
201 202
    std::memcpy(
        static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
J
jianghaicheng 已提交
203 204 205 206
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Not compile with WITH_IPU, should not reach here."));
#endif
207
  } else if (platform::is_gpu_place(place)) {
C
ccrrong 已提交
208 209
    PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
                      false,
210 211
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between CPU and XPU."));
212
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
213
    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
L
Leo Chen 已提交
214
    auto *dev_ctx = static_cast<const phi::GPUContext *>(pool.Get(place));
215
    auto dst_gpu_place = place;
C
ccrrong 已提交
216 217 218 219 220
    memory::Copy(dst_gpu_place,
                 static_cast<void *>(input_ptr),
                 platform::CPUPlace(),
                 pt.data.data(),
                 pt.data.length(),
221 222 223 224 225
                 dev_ctx->stream());
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Not compile with CUDA, should not reach here."));
#endif
226 227
  } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
228
    auto dst_xpu_place = place;
C
ccrrong 已提交
229 230 231 232 233
    memory::Copy(dst_xpu_place,
                 static_cast<void *>(input_ptr),
                 platform::CPUPlace(),
                 pt.data.data(),
                 pt.data.length());
234 235 236 237 238 239 240
#else
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "Not compile with XPU, should not reach here."));
#endif
  } else {
    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
        "The analysis predictor supports CPU, GPU and XPU now."));
241 242 243 244 245 246 247 248 249 250
  }
  // TODO(Superjomn) Low performance, need optimization for heavy LoD copy.
  framework::LoD lod;
  for (auto &level : pt.lod) {
    lod.emplace_back(level);
  }
  t->set_lod(lod);
  return true;
}

Y
Yan Chunwei 已提交
251
bool AnalysisPredictor::Init(
252 253
    const std::shared_ptr<framework::Scope> &parent_scope,
    const std::shared_ptr<framework::ProgramDesc> &program) {
M
minqiyang 已提交
254
  VLOG(3) << "Predictor::init()";
255 256
  if (config_.with_profile_) {
    LOG(WARNING) << "Profiler is activated, which might affect the performance";
257 258
    auto tracking_device = config_.use_gpu() ? platform::ProfilerState::kAll
                                             : platform::ProfilerState::kCPU;
T
tensor-tang 已提交
259
    platform::EnableProfiler(tracking_device);
260
  } else {
261 262
    VLOG(2) << "Profiler is deactivated, and no profiling report will be "
               "generated.";
T
tensor-tang 已提交
263 264
  }

265
  // no matter with or without MKLDNN
L
luotao1 已提交
266
  paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
267

268 269 270
  if (!PrepareScope(parent_scope)) {
    return false;
  }
271 272 273

  InitPlace();

274 275 276 277 278 279 280
  if (!CreateExecutor()) {
    return false;
  }
  if (!PrepareProgram(program)) {
    return false;
  }

281 282 283
  // Get the feed_target_names and fetch_target_names
  PrepareFeedFetch();

284 285 286
  // Prepare executor, create local variables.
  if (!PrepareExecutor()) {
    return true;
Y
Yan Chunwei 已提交
287
  }
288

289 290 291 292 293 294 295 296 297 298 299 300 301
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // TODO(inference): Now only gpu with external stream support private
  // device_context.
  if (config_.use_gpu_ && config_.use_external_stream_) {
    private_context_ = true;
  }
  if (private_context_) {
    if (!status_is_cloned_) {
      predictor_stream_ = config_.GetExecStream();
    }
    // NOTE: If the external_stream equals to global_device_contexts's stream,
    // then fallback.
    auto global_stream =
L
Leo Chen 已提交
302
        static_cast<phi::GPUContext *>(
303 304 305 306 307 308
            platform::DeviceContextPool::Instance().Get(place_))
            ->stream();
    if (predictor_stream_ != global_stream) {
      InitResourceManager(predictor_stream_);
      InitDeviceContexts();
    }
Y
Yan Chunwei 已提交
309
  }
310
#endif
311
  inference::DisplayMemoryInfo(place_, "Init predictor");
312 313
  return true;
}
314

315
void AnalysisPredictor::InitPlace() {
316
  if (config_.use_gpu()) {
C
ccrrong 已提交
317 318
    PADDLE_ENFORCE_EQ(config_.use_xpu(),
                      false,
319 320
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between CPU and XPU."));
321
    place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
322
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
323
    if (config_.thread_local_stream_enabled()) {
W
Wilber 已提交
324 325
      LOG_FIRST_N(WARNING, 1) << "We will remove this interface in the future. "
                                 "Please use config.SetExecStream instead.";
326 327
    }
#endif
328
  } else if (config_.use_xpu()) {
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
    if (config_.lite_engine_enabled()) {
#ifdef LITE_SUBGRAPH_WITH_XPU
      // Currently, Paddle-Lite's XPU user interface only supports the transfer
      // of Host data pointers. If it is currently used as a subgraph, execution
      // efficiency will be sacrificed, so it is temporarily set to cpu place.
      // And, the current lite engine of xpu must execute all parts of the
      // model.
      place_ = paddle::platform::CPUPlace();
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "You tried to use an XPU lite engine, but Paddle was not compiled "
          "with it."));
#endif  // LITE_SUBGRAPH_WITH_XPU
    } else {
#ifdef PADDLE_WITH_XPU
      place_ = paddle::platform::XPUPlace(config_.xpu_device_id());
#else
      PADDLE_THROW(platform::errors::Unavailable(
          "You tried to use XPU forward propagation (inference without lite "
          "engine), but Paddle was not compiled "
          "with WITH_XPU."));
#endif  // PADDLE_WITH_XPU
    }
W
Wilber 已提交
352 353 354 355 356 357 358 359
  } else if (config_.use_npu()) {
#ifdef PADDLE_WITH_ASCEND_CL
    place_ = paddle::platform::NPUPlace(config_.npu_device_id());
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use NPU forward propagation, but Paddle was not compiled "
        "with WITH_ASCEND_CL."));
#endif
360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
  } else if (config_.NNAdapter().use_nnadapter) {
    if (config_.lite_engine_enabled()) {
      place_ = paddle::platform::CPUPlace();
#ifndef LITE_SUBGRAPH_WITH_NNADAPTER
      PADDLE_THROW(
          platform::errors::Unavailable("You tried to use an NNAdapter lite "
                                        "engine, but Paddle was not compiled "
                                        "with it."));
#endif  // LITE_SUBGRAPH_WITH_NNADAPTER
    } else {
      PADDLE_THROW(
          platform::errors::Unavailable("You tried to use NNadapter forward "
                                        "propagation (inference without lite "
                                        "engine), but Paddle was not compiled "
                                        "with LITE_WITH_NNADAPTER."));
    }
J
jianghaicheng 已提交
376 377 378 379 380 381 382
  } else if (config_.use_ipu()) {
#ifdef PADDLE_WITH_IPU
    place_ = paddle::platform::IPUPlace();
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use IPU forward propagation, but Paddle was not compiled "
        "with WITH_IPU."));
383 384 385 386 387 388 389 390 391
#endif
  } else if (config_.use_custom_device()) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
    place_ = paddle::platform::CustomPlace(config_.custom_device_type());
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use CustomDevice forward propagation, but Paddle was not "
        "compiled "
        "with WITH_CUSTOM_DEVICE."));
J
jianghaicheng 已提交
392
#endif
393 394 395
  } else {
    place_ = paddle::platform::CPUPlace();
  }
396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412
}

// Registers `stream` for `place_` with the ResourceManager singleton. The
// stream handle it hands back becomes this predictor's execution stream.
// Does nothing unless built with CUDA or HIP.
void AnalysisPredictor::InitResourceManager(void *stream) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  auto managed_stream =
      ResourceManager::Instance().InitGPUResource(place_, stream);
  predictor_stream_ = managed_stream;
#endif
}

// Registers a lazily-built private device context for `place_` in
// `device_contexts_`.
//
// Only GPU places are handled, and only with CUDA/HIP. The context is built
// by a std::launch::deferred task, so construction happens when the future is
// first waited on. It is wired to the GPU resource associated with
// `predictor_stream_`.
void AnalysisPredictor::InitDeviceContexts() {
// Init GPUContext.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (place_.GetType() == phi::AllocationType::GPU) {
    // Explicit `this` capture: the task reads place_ and predictor_stream_
    // from this predictor when it eventually runs.
    device_contexts_.emplace(
        place_, std::async(std::launch::deferred, [this] {
          auto *gpu_resource =
              ResourceManager::Instance().GetGPUResource(predictor_stream_);
          // Owned from the start, so the context is released if any of the
          // setup calls below throws.
          auto gpu_context = std::make_unique<InferGPUContext>(place_);
          gpu_context->SetAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(place_, gpu_resource->GetStream())
                  .get());
          gpu_context->SetPinnedAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(paddle::platform::CUDAPinnedPlace())
                  .get());
          gpu_context->SetHostAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetAllocator(platform::CPUPlace())
                  .get());
          gpu_context->SetZeroAllocator(
              memory::allocation::AllocatorFacade::Instance()
                  .GetZeroAllocator(place_)
                  .get());
          gpu_context->SetGenerator(
              framework::DefaultCUDAGenerator(place_.GetDeviceId()).get());
          gpu_context->SetHostGenerator(framework::DefaultCPUGenerator().get());

          gpu_context->SetStream(gpu_resource->GetStream());
          gpu_context->SetBlasHandle(gpu_resource->GetBlasHandleCreator());
          gpu_context->SetBlasTensorCoreHandle(
              gpu_resource->GetBlasTensorCoreHandleCreator());
          gpu_context->SetBlasTF32Handle(
              gpu_resource->GetBlasTF32TensorCoreHandleCreator());
          gpu_context->SetDnnHandle(gpu_resource->GetDnnHandleCreator());
          gpu_context->SetSolverHandle(
              gpu_resource->GetSolverDnHandleCreator());
          gpu_context->SetSparseHandle(gpu_resource->GetSparseHandleCreator());
          gpu_context->SetEigenDevice(gpu_resource->GetGpuEigenDeviceCreator());
          gpu_context->SetComputeCapability(
              gpu_resource->GetGpuComputeCapability());
          gpu_context->SetMaxThreadsPerBlock(
              gpu_resource->GetGpuMaxThreadsPerBlock());
          gpu_context->SetMaxThreadsPerMultiProcessor(
              gpu_resource->GetGpuMaxThreadsPerMp());
          gpu_context->SetMaxGridDimSize(gpu_resource->GetGpuMaxGridDimSize());
          gpu_context->SetMultiProcessors(
              gpu_resource->GetGPUMultiProcessors());
          gpu_context->SetDriverVersion(gpu_resource->GetGpuDriverVersion());
          gpu_context->SetRuntimeVersion(gpu_resource->GetGpuRuntimeVersion());
          VLOG(1) << "thread id is " << std::this_thread::get_id()
                  << ", stream id is "
                  << reinterpret_cast<void *>(gpu_resource->GetStream())
                  << ", allotor ptr is "
                  << reinterpret_cast<void *>(
                         memory::allocation::AllocatorFacade::Instance()
                             .GetAllocator(place_, gpu_resource->GetStream())
                             .get());
          return std::unique_ptr<phi::DeviceContext>(std::move(gpu_context));
        }));
  }
#endif
  // TODO(Inference): Support other backends.
}

// Returns the stream this predictor executes on, or nullptr for non-GPU
// places and non-CUDA/HIP builds.
//
// With a private context, this is `predictor_stream_`. Otherwise it is the
// stream of the global DeviceContextPool entry for `place_`.
void *AnalysisPredictor::GetExecStream() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (place_.GetType() != phi::AllocationType::GPU) {
    return nullptr;
  }
  if (private_context_) {
    return predictor_stream_;
  }
  paddle::platform::DeviceContextPool &pool =
      paddle::platform::DeviceContextPool::Instance();
  // Base-to-derived pointer conversion: static_cast (as used in Init) is the
  // correct cast. reinterpret_cast ignores any base-subobject offset.
  return static_cast<const phi::GPUContext *>(pool.Get(place_))->stream();
#else
  // TODO(inference): Support other backends.
  return nullptr;
#endif
}

// Returns a type-erased pointer to the device-context map in use. This is
// the predictor-private `device_contexts_` when `private_context_` is set,
// and otherwise the map owned by the global DeviceContextPool.
const void *AnalysisPredictor::GetDeviceContexts() const {
  if (!private_context_) {
    const auto &global_contexts =
        paddle::platform::DeviceContextPool::Instance().device_contexts();
    return &global_contexts;
  }
  return &device_contexts_;
}

// Sets up `scope_` and a fresh child `sub_scope_`.
//
// Without `parent_scope`, devices and the default kernel signature map are
// initialized and a brand-new root scope is created. With it, the predictor
// is marked as cloned and shares the parent's scope. Always returns true.
bool AnalysisPredictor::PrepareScope(
    const std::shared_ptr<framework::Scope> &parent_scope) {
  if (!parent_scope) {
    paddle::framework::InitDevices();
    paddle::framework::InitDefaultKernelSignatureMap();
    // TODO(wilber): we need to release memory occupied by weights.
    scope_.reset(new paddle::framework::Scope());
    status_is_cloned_ = false;
  } else {
    PADDLE_ENFORCE_NOT_NULL(
        parent_scope,
        platform::errors::PreconditionNotMet(
            "Both program and parent_scope should be set in Clone mode."));
    scope_ = parent_scope;
    status_is_cloned_ = true;
  }
  sub_scope_ = &scope_->NewScope();
  return true;
}

// Establishes `inference_program_` and creates its variables in `sub_scope_`.
//
// When `program` is given (clone scenario), it is adopted as-is and optimized
// only if `config_.apply_optim_` is set. Otherwise the program is loaded from
// the config: persistable variables are created first, the model precision is
// recorded, and the program is optimized. Returns false only when
// LoadProgramDesc() fails.
bool AnalysisPredictor::PrepareProgram(
    const std::shared_ptr<framework::ProgramDesc> &program) {
  if (program) {
    // If the program is passed from external, no need to optimize it, this
    // logic is used in the clone scenario.
    inference_program_ = program;
    if (config_.apply_optim_) {
      VLOG(3)
          << "apply_optim is enabled, will call OptimizeInferenceProgram().";
      OptimizeInferenceProgram();
    }
  } else {
    if (!LoadProgramDesc()) {
      return false;
    }
    // Parameters are loaded either by OptimizeInferenceProgram() (ir_optim
    // on) or by LoadParameters() (ir_optim off). Neither path creates the
    // remaining persistable variables (such as RAW type vars), so create all
    // persistable variables up front.
    executor_->CreateVariables(*inference_program_, 0, true, sub_scope_);

    // With enable_ir_optim_ off, the analysis passes (op fuse, graph
    // analysis, trt subgraph, mkldnn etc) are skipped.
    model_precision_ =
        paddle::inference::GetModelPrecision(*inference_program_);
    OptimizeInferenceProgram();
  }

  executor_->CreateVariables(*inference_program_, 0, false, sub_scope_);
  return true;
}

bool AnalysisPredictor::CreateExecutor() {
559 560 561
  executor_.reset(new paddle::framework::NaiveExecutor(place_));
  return true;
}
W
wenbin 已提交
562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580

// Returns true for op types after which the prepare-data optimization must be
// disabled ("conditional_block_infer" and "select_input").
//
// Why: consider an op that follows a conditional_block, where only branch 1
// needs that op to prepare its data. If the predictor first runs branch 2,
// the optimization takes effect and the op skips data preparation. A later
// run through branch 1 then fails, because the op lost its chance to prepare
// data.
static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
  const std::string op_type = op->Type();
  return op_type == "conditional_block_infer" || op_type == "select_input";
}

static void DisablePrepareDataOpt(
C
ccrrong 已提交
581 582
    std::shared_ptr<framework::ProgramDesc> inference_program,
    int block,
W
wenbin 已提交
583 584 585 586 587 588 589 590 591
    bool pre_disable_opt) {
  bool disable_opt = false;
  auto &infer_block = inference_program->Block(block);
  for (auto *op : infer_block.AllOps()) {
    if (disable_opt || pre_disable_opt) {
      op->SetAttr("inference_force_prepare_data", true);
    }
    if (op->HasAttr("sub_block")) {
      int blockID = op->GetBlockAttrId("sub_block");
C
ccrrong 已提交
592 593
      DisablePrepareDataOpt(
          inference_program, blockID, disable_opt || pre_disable_opt);
W
wenbin 已提交
594 595
    }
    // disable prepare data if unfriendly op is found
W
wenbin 已提交
596 597 598
    if (!disable_opt) {
      disable_opt = IsPrepareDataOptTargetOp(op);
    }
W
wenbin 已提交
599 600 601
  }
}

602
bool AnalysisPredictor::PrepareExecutor() {
603
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
604 605 606 607 608
  if (config_.dist_config().use_dist_model()) {
    VLOG(3) << "use_dist_model is enabled, will init FleetExecutor.";
    return PrepareFleetExecutor();
  }
#endif
W
wenbin 已提交
609 610
  DisablePrepareDataOpt(inference_program_, 0, false);

C
ccrrong 已提交
611 612
  executor_->Prepare(
      sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);
613

614 615 616
  PADDLE_ENFORCE_NOT_NULL(sub_scope_,
                          platform::errors::PreconditionNotMet(
                              "The sub_scope should not be nullptr."));
Y
Yan Chunwei 已提交
617

618 619 620
  return true;
}

621
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657
bool AnalysisPredictor::PrepareFleetExecutor() {
  VLOG(3) << "AnalysisPredictor::PrepareFleetExecutor()";
  if (config_.dist_config().nranks() > 1 && !CommInit()) {
    return false;
  }
  task_node_.reset(new distributed::TaskNode(inference_program_.get(),
                                             config_.dist_config().rank()));
  // With auto cut, there is no concept of pp, no need to add dependency.
  task_node_->SetType("Compute");
  task_node_->Init(config_.use_feed_fetch_ops_enabled());
  executor_desc_ = distributed::FleetExecutorDesc();
  executor_desc_.set_cur_rank(config_.dist_config().rank());
  std::unordered_map<int64_t, int64_t> id_to_rank;
  for (int i = 0; i < config_.dist_config().nranks(); ++i) {
    distributed::RankInfo *rank_info = executor_desc_.add_cluster_info();
    rank_info->set_rank(i);
    rank_info->set_ip_port(config_.dist_config().trainer_endpoints()[i]);
    id_to_rank.insert({i, i});
  }
  fleet_exe_.reset(new distributed::FleetExecutor(executor_desc_));
  // NOTE: Vars of feed fetch ops are not persistable,
  // which will result in that those vars will be created in
  // the subscope (microscope) in fleet executor. This will
  // cause that the GetInputTensor/GetOutputTensor funct
  // in analysis predictor cannot find those vars in the scope
  // returned by the DistModel, since DistModel only return the
  // root scope. So, those vars must  to be created in the root
  // scope instead of in the microscope
  std::vector<std::string> feed_fetch_vars;
  for (auto pair : idx2feeds_) {
    feed_fetch_vars.emplace_back(pair.second);
  }
  for (auto pair : idx2fetches_) {
    feed_fetch_vars.emplace_back(pair.second);
  }
  fleet_exe_->Init(config_.dist_config().carrier_id(),
C
ccrrong 已提交
658 659 660 661 662 663 664
                   *(inference_program_.get()),
                   scope_.get(),
                   place_,
                   1,
                   {task_node_.get()},
                   id_to_rank,
                   feed_fetch_vars);
665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700
  return true;
}

bool AnalysisPredictor::CommInit() {
  std::map<int64_t, std::vector<int64_t>> ring_id_to_ranks{};
  std::map<int64_t, std::vector<int64_t>> rank_to_ring_ids{};
  if (!LoadConverterConfig(&ring_id_to_ranks, &rank_to_ring_ids)) {
    VLOG(3) << "Load converter config failed, DistModel init failed.";
    return false;
  }
  std::unique_ptr<framework::ProgramDesc> comm_init_program(
      new framework::ProgramDesc());
  framework::BlockDesc *comm_init_block = comm_init_program->MutableBlock(0);
  std::vector<int64_t> &ring_ids =
      rank_to_ring_ids[config_.dist_config().rank()];
  int64_t order = 0;
  std::string var_name_base = "comm_init_";
  for (int64_t ring_id : ring_ids) {
    VLOG(3) << "Init comm for ring id: " << ring_id;
    int64_t ranks_in_group = ring_id_to_ranks[ring_id].size();
    int64_t rank_in_group = 0;
    std::vector<int64_t> &ranks = ring_id_to_ranks[ring_id];
    for (int64_t rank : ranks) {
      if (config_.dist_config().rank() == rank) {
        break;
      }
      rank_in_group += 1;
    }
    std::vector<std::string> peer_endpoints;
    for (int64_t rank : ranks) {
      if (config_.dist_config().rank() == rank) {
        continue;
      }
      peer_endpoints.emplace_back(
          config_.dist_config().trainer_endpoints()[rank]);
    }
C
ccrrong 已提交
701 702 703 704 705 706
    InsertCommOp(var_name_base + std::to_string(order),
                 ranks_in_group,
                 rank_in_group,
                 peer_endpoints,
                 comm_init_block,
                 ring_id);
707 708 709 710 711 712 713 714 715 716 717
    order += 1;
  }
  framework::NaiveExecutor e(place_);
  e.CreateVariables(*comm_init_program, 0, true, scope_.get());
  e.Prepare(scope_.get(), *comm_init_program, 0, false);
  e.Run();
  VLOG(3) << "Comm init successful.";
  return true;
}

void AnalysisPredictor::InsertCommOp(
C
ccrrong 已提交
718 719 720 721 722
    std::string tmp_var_name,
    int nranks,
    int rank,
    const std::vector<std::string> &peer_endpoints,
    framework::BlockDesc *block,
723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778
    int ring_id) {
  /*
   * tmp_var_name: the var name for var comm_id
   * nranks: number of total ranks
   * rank: the rank of local rank in the comm group
   * peer_endpoints: peer's endpoints
   * block: the block where to insert the comm ops
   * ring_id: the ring_id to be inited
   */
  const std::string &endpoint = config_.dist_config().current_endpoint();
  std::stringstream ss;
  ss << "Init comm with tmp var: " << tmp_var_name
     << ". The ring id is: " << ring_id << ". The group has: " << nranks
     << " ranks. Current rank in the group is: " << rank
     << ". The endpoint is: " << endpoint << ". Peer endpoints are: ";
  for (auto ep : peer_endpoints) {
    ss << ep << ", ";
  }
  VLOG(3) << ss.str();
  if (config_.use_gpu()) {
    framework::VarDesc *new_var = block->Var(tmp_var_name);
    new_var->SetType(framework::proto::VarType::RAW);
    new_var->SetPersistable(true);
    framework::OpDesc *gen_nccl_id_op = block->AppendOp();
    gen_nccl_id_op->SetType("c_gen_nccl_id");
    gen_nccl_id_op->SetOutput("Out", {tmp_var_name});
    gen_nccl_id_op->SetAttr("rank", rank);
    gen_nccl_id_op->SetAttr("endpoint",
                            config_.dist_config().current_endpoint());
    gen_nccl_id_op->SetAttr("other_endpoints", peer_endpoints);
    gen_nccl_id_op->SetAttr("ring_id", ring_id);
    gen_nccl_id_op->SetAttr("op_role",
                            static_cast<int>(framework::OpRole::kForward));
    gen_nccl_id_op->CheckAttrs();
    framework::OpDesc *comm_init_op = block->AppendOp();
    comm_init_op->SetType("c_comm_init");
    comm_init_op->SetInput("X", {tmp_var_name});
    comm_init_op->SetAttr("rank", rank);
    comm_init_op->SetAttr("nranks", nranks);
    comm_init_op->SetAttr("ring_id", ring_id);
    comm_init_op->SetAttr("op_role",
                          static_cast<int>(framework::OpRole::kForward));
    comm_init_op->CheckAttrs();
  } else {
    LOG(WARNING) << "DistModelInf doesn't init comm.";
    // TODO(fleet exe dev): comm init for more devices
  }
}

bool AnalysisPredictor::LoadConverterConfig(
    std::map<int64_t, std::vector<int64_t>> *ring_id_to_ranks,
    std::map<int64_t, std::vector<int64_t>> *rank_to_ring_ids) {
  VLOG(3) << "Going to load converter config from: "
          << config_.dist_config().comm_init_config() << "\n";
  std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
  PADDLE_ENFORCE_EQ(
C
ccrrong 已提交
779 780
      static_cast<bool>(fin.is_open()),
      true,
781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852
      platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file is normal.",
          config_.dist_config().comm_init_config()));
  std::string line;
  bool ring_to_rank{true};
  // Reading config from file, the config file should like these format
  //  [ring_id -> ranks]
  //  0,0,1,2,3
  //  1,0,1
  //  2,2,3
  //  21,0,1
  //  22,1,2
  //  23,2,3
  //  [rank -> ring_ids]
  //  0,0,1,21
  //  1,0,1,21,22
  //  2,0,2,22,23
  //  3,0,2,23
  while (std::getline(fin, line)) {
    std::vector<std::string> one_line = paddle::string::Split(line, ',');
    if (one_line.size() == 1) {
      // start a new section of the config
      if (line == "[ring_id -> ranks]") {
        ring_to_rank = true;
      } else if (line == "[rank -> ring_ids]") {
        ring_to_rank = false;
      }
    } else {
      // parse key - values pairs in one section
      int64_t key = std::stoll(one_line[0]);
      for (size_t i = 1; i < one_line.size(); ++i) {
        int64_t val = std::stoll(one_line[i]);
        if (ring_to_rank) {
          if (ring_id_to_ranks->find(key) == ring_id_to_ranks->end()) {
            ring_id_to_ranks->insert({key, std::vector<int64_t>()});
          }
          ring_id_to_ranks->at(key).emplace_back(val);
        } else {
          if (rank_to_ring_ids->find(key) == rank_to_ring_ids->end()) {
            rank_to_ring_ids->insert({key, std::vector<int64_t>()});
          }
          rank_to_ring_ids->at(key).emplace_back(val);
        }
        // NOTE: add more configuration sections here
      }
    }
  }
  std::stringstream ss;
  ss << "Loaded the following converter config:\n";
  ss << "ring_id_to_ranks:\n";
  for (auto pair : *ring_id_to_ranks) {
    int64_t key = pair.first;
    ss << "\t" << key << "\t->\t";
    for (auto value : pair.second) {
      ss << value << "\t";
    }
    ss << "\n";
  }
  ss << "rank_to_ring_ids:\n";
  for (auto pair : *rank_to_ring_ids) {
    int64_t key = pair.first;
    ss << "\t" << key << "\t->\t";
    for (auto value : pair.second) {
      ss << value << "\t";
    }
    ss << "\n";
  }
  VLOG(3) << ss.str();
  return true;
}
#endif

853 854
void AnalysisPredictor::MkldnnPreSet(const std::vector<PaddleTensor> &inputs) {
#ifdef PADDLE_WITH_MKLDNN
  // Reduce the tensors to their shapes and delegate to the shape-based
  // overload, which performs the actual oneDNN cache setup.
  std::vector<std::vector<int>> shapes;
  shapes.reserve(inputs.size());
  for (const auto &tensor : inputs) {
    shapes.push_back(tensor.shape);
  }
  MkldnnPreSet(shapes);
#endif
}

void AnalysisPredictor::MkldnnPreSet(
    const std::vector<std::vector<int>> &inputs_shape) {
#ifdef PADDLE_WITH_MKLDNN
  // Configures thread-local oneDNN state before a run. In cache-clearing mode
  // (mkldnn_cache_capacity_ > 0) it switches the session id and records the
  // current input shapes as a string key. The cache capacity is always
  // forwarded.
  VLOG(2) << "AnalysisPredictor::ZeroCopyRun get_cur_mkldnn_session_id="
          << platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id();
  const bool cache_clearing = config_.mkldnn_cache_capacity_ > 0;
  if (cache_clearing) {
    VLOG(2) << "In mkldnn cache clear mode.";
    platform::MKLDNNDeviceContext::tls().set_cur_mkldnn_session_id(
        platform::MKLDNNDeviceContextThreadLocals::
            kMKLDNNSessionID_CacheClearing);
    // Build the key used for caching dynamic shapes: every dim of every
    // input, each followed by '-'.
    // NOTE(review): dims of different inputs are not separated, so shapes
    // {[2,3],[4]} and {[2],[3,4]} yield the same key "2-3-4-".
    std::stringstream shape_key;
    for (const auto &shape : inputs_shape) {
      for (int dim : shape) {
        shape_key << dim << "-";
      }
    }
    const std::string key = shape_key.str();
    VLOG(2) << "Set input shape=" << key;
    platform::MKLDNNDeviceContext::tls().set_cur_input_shape_str(key);
  }
  platform::MKLDNNDeviceContext::tls().set_cur_input_shape_cache_capacity(
      config_.mkldnn_cache_capacity_);
#endif
}

void AnalysisPredictor::MkldnnPostReset() {
#ifdef PADDLE_WITH_MKLDNN
  // Only meaningful in cache-clearing mode.
  if (config_.mkldnn_cache_capacity_ <= 0) {
    return;
  }
  auto *cpu_ctx = static_cast<platform::MKLDNNDeviceContext *>(
      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
  // With VLOG(2) enabled and something cached, verify the shape cache has
  // not grown past the configured capacity.
  if (cpu_ctx->GetCachedObjectsNumber() > 0 && VLOG_IS_ON(2)) {
    auto shape_blob_size = cpu_ctx->GetShapeBlobSize();
    CHECK_LE(shape_blob_size,
             static_cast<size_t>(config_.mkldnn_cache_capacity_));
  }
  // We cannot reset to the default cache settings
  // as there maybe CopyToCPU method used and oneDNN
  // primitives are used there so cache would grow
#endif
}

912 913 914
bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
                            std::vector<PaddleTensor> *output_data,
                            int batch_size) {
915
  paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
916 917 918
#ifdef PADDLE_WITH_MKLDNN
  if (config_.use_mkldnn_) MkldnnPreSet(inputs);
#endif
M
minqiyang 已提交
919
  VLOG(3) << "Predictor::predict";
920 921 922 923
  inference::Timer timer;
  timer.tic();
  // set feed variable
  framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
C
ccrrong 已提交
924 925 926
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::PreconditionNotMet("The scope should not be nullptr."));
927 928
  if (!SetFeed(inputs, scope)) {
    LOG(ERROR) << "fail to set feed";
Y
Yan Chunwei 已提交
929
    return false;
930
  }
M
Michal Gallus 已提交
931

932 933 934 935 936 937 938 939 940
#ifdef PADDLE_WITH_TENSORRT
  if (config_.tensorrt_engine_enabled()) {
    inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
        predictor_id_;
    VLOG(3) << "thread_local var predictor_id in TensorRTEngine is set to: "
            << inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
  }
#endif

941 942 943
  // Run the inference program
  // if share variables, we need not create variables
  executor_->Run();
944

945 946 947 948
  // get fetch variable
  if (!GetFetch(output_data, scope)) {
    LOG(ERROR) << "fail to get fetches";
    return false;
T
tensor-tang 已提交
949
  }
Y
Yan Chunwei 已提交
950

M
minqiyang 已提交
951
  VLOG(3) << "predict cost: " << timer.toc() << "ms";
Y
Yan Chunwei 已提交
952

Y
Yan Chunwei 已提交
953 954 955 956 957
  // All the containers in the scope will be hold in inference, but the
  // operators assume that the container will be reset after each batch.
  // Here is a bugfix, collect all the container variables, and reset then to a
  // bool; the next time, the operator will call MutableData and construct a new
  // container again, so that the container will be empty for each batch.
958 959 960
  if (sub_scope_) {
    tensor_array_batch_cleaner_.CollectNoTensorVars(sub_scope_);
  }
Y
Yan Chunwei 已提交
961
  tensor_array_batch_cleaner_.ResetNoTensorVars();
962 963 964 965

  // recover the cpu_math_library_num_threads to 1, in order to avoid thread
  // conflict when integrating it into deployment service.
  paddle::platform::SetNumThreads(1);
966 967
#ifdef PADDLE_WITH_MKLDNN
  if (config_.use_mkldnn_) MkldnnPostReset();
T
Tao Luo 已提交
968
#endif
969
#if defined(PADDLE_WITH_MKLML)
T
Tao Luo 已提交
970 971 972 973
  // Frees unused memory allocated by the Intel® MKL Memory Allocator to
  // avoid memory leak. See:
  // https://software.intel.com/en-us/mkl-developer-reference-c-mkl-free-buffers
  platform::dynload::MKL_Free_Buffers();
974
#endif
975 976
  return true;
}
977

978 979
bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
                                framework::Scope *scope) {
M
minqiyang 已提交
980
  VLOG(3) << "Predictor::set_feed";
981 982 983 984 985 986 987 988 989 990
  if (inputs.size() != feeds_.size()) {
    LOG(ERROR) << "wrong feed input size, need " << feeds_.size() << " but get "
               << inputs.size();
    return false;
  }

  // Cache the inputs memory for better concurrency performance.
  feed_tensors_.resize(inputs.size());

  for (size_t i = 0; i < inputs.size(); ++i) {
991
    phi::DenseTensor *input = &feed_tensors_[i];
992
    if (!PaddleTensorToLoDTensor(inputs[i], input, place_)) {
993 994 995
      return false;
    }
    int idx = -1;
996
    if (config_.specify_input_name_) {
T
tensor-tang 已提交
997 998
      auto name = inputs[i].name;
      if (feed_names_.find(name) == feed_names_.end()) {
T
tensor-tang 已提交
999 1000
        LOG(ERROR) << "feed names from program do not have name: [" << name
                   << "] from specified input";
T
tensor-tang 已提交
1001 1002
      }
      idx = feed_names_[name];
1003
    } else {
R
Ruibiao Chen 已提交
1004
      idx = PADDLE_GET_CONST(int, feeds_[i]->GetAttr("col"));
1005
    }
1006
    framework::SetFeedVariable(scope, *input, "feed", idx);
1007 1008 1009 1010 1011
  }
  return true;
}

template <typename T>
1012
void AnalysisPredictor::GetFetchOne(const phi::DenseTensor &fetch,
1013 1014
                                    PaddleTensor *output) {
  // set shape.
1015
  auto shape = phi::vectorize(fetch.dims());
1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032
  output->shape.assign(shape.begin(), shape.end());
  // set data.
  const T *data = fetch.data<T>();
  int num_elems = inference::VecReduceToInt(shape);
  output->data.Resize(num_elems * sizeof(T));
  // The fetched tensor output by fetch op, should always in CPU memory, so just
  // copy.
  memcpy(output->data.data(), data, num_elems * sizeof(T));
  // set lod
  output->lod.clear();
  for (auto &level : fetch.lod()) {
    output->lod.emplace_back(level.begin(), level.end());
  }
}

bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
                                 framework::Scope *scope) {
M
minqiyang 已提交
1033
  VLOG(3) << "Predictor::get_fetch";
Y
Yan Chunwei 已提交
1034 1035
  outputs->resize(fetches_.size());
  for (size_t i = 0; i < fetches_.size(); ++i) {
R
Ruibiao Chen 已提交
1036
    int idx = PADDLE_GET_CONST(int, fetches_[i]->GetAttr("col"));
1037
    PADDLE_ENFORCE_EQ(
C
ccrrong 已提交
1038 1039
        static_cast<size_t>(idx),
        i,
1040
        platform::errors::InvalidArgument(
C
ccrrong 已提交
1041 1042
            "Fetch op's col attr(%d) should be equal to the index(%d)",
            idx,
1043
            i));
1044
    framework::FetchType &fetch_var =
1045
        framework::GetFetchVariable(*scope, "fetch", idx);
1046
    auto &fetch = PADDLE_GET(phi::DenseTensor, fetch_var);
1047
    auto type = framework::TransToProtoVarType(fetch.dtype());
1048
    auto output = &(outputs->at(i));
Y
Yan Chunwei 已提交
1049
    output->name = fetches_[idx]->Input("X")[0];
Y
Yu Yang 已提交
1050
    if (type == framework::proto::VarType::FP32) {
1051 1052
      GetFetchOne<float>(fetch, output);
      output->dtype = PaddleDType::FLOAT32;
Y
Yu Yang 已提交
1053
    } else if (type == framework::proto::VarType::INT64) {
1054 1055
      GetFetchOne<int64_t>(fetch, output);
      output->dtype = PaddleDType::INT64;
1056 1057 1058
    } else if (type == framework::proto::VarType::INT32) {
      GetFetchOne<int32_t>(fetch, output);
      output->dtype = PaddleDType::INT32;
1059 1060 1061
    } else if (type == framework::proto::VarType::FP16) {
      GetFetchOne<float16>(fetch, output);
      output->dtype = PaddleDType::FLOAT16;
1062
    } else {
1063 1064
      LOG(ERROR) << "unknown type, only support float32, float16, int64 and "
                    "int32 now.";
1065 1066
    }
  }
Y
Yan Chunwei 已提交
1067 1068
  return true;
}
1069

1070
void AnalysisPredictor::PrepareArgument() {
1071
  argument_.SetUseGPU(config_.use_gpu());
1072
  argument_.SetUseFcPadding(config_.use_fc_padding());
1073
  argument_.SetGPUDeviceId(config_.gpu_device_id());
1074
  argument_.SetEnableAnalysisOptim(config_.enable_ir_optim_);
1075
  argument_.SetEnableMemoryOptim(config_.enable_memory_optim());
T
Tao Luo 已提交
1076
  argument_.SetModelFromMemory(config_.model_from_memory_);
Y
Yan Chunwei 已提交
1077
  // Analyze inference_program
1078
  argument_.SetPredictorID(predictor_id_);
1079
  argument_.SetOptimCacheDir(config_.opt_cache_dir_);
1080 1081
  if (!config_.model_dir().empty()) {
    argument_.SetModelDir(config_.model_dir());
T
Tao Luo 已提交
1082
  } else {
C
ccrrong 已提交
1083 1084
    PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
                      false,
1085 1086
                      platform::errors::PreconditionNotMet(
                          "Either model_dir or prog_file should be set."));
N
nhzlx 已提交
1087

1088 1089
    argument_.SetModelProgramPath(config_.prog_file());
    argument_.SetModelParamsPath(config_.params_file());
Y
Yan Chunwei 已提交
1090
  }
1091 1092
  // For JITLayer
  argument_.SetSkipLoadParams(config_.skip_load_params_);
1093

1094
  argument_.SetTensorRtPrecisionMode(config_.tensorrt_precision_mode_);
1095
  argument_.SetTensorRtUseOSS(config_.trt_use_varseqlen_);
1096
  argument_.SetTensorRtWithInterleaved(config_.trt_with_interleaved_);
1097 1098
  argument_.SetTensorRtTransformerPosid(config_.tensorrt_transformer_posid_);
  argument_.SetTensorRtTransformerMaskid(config_.tensorrt_transformer_maskid_);
1099 1100 1101 1102 1103
  argument_.SetMinInputShape(config_.min_input_shape_);
  argument_.SetMaxInputShape(config_.max_input_shape_);
  argument_.SetOptimInputShape(config_.optim_input_shape_);
  argument_.SetTensorRtTunedDynamicShape(
      config_.tuned_tensorrt_dynamic_shape());
1104
  if (config_.use_gpu() && config_.tensorrt_engine_enabled()) {
Y
Yan Chunwei 已提交
1105
    LOG(INFO) << "TensorRT subgraph engine is enabled";
1106 1107 1108
    argument_.SetUseTensorRT(true);
    argument_.SetTensorRtWorkspaceSize(config_.tensorrt_workspace_size_);
    argument_.SetTensorRtMaxBatchSize(config_.tensorrt_max_batchsize_);
1109
    argument_.SetTensorRtMinSubgraphSize(config_.tensorrt_min_subgraph_size_);
1110
    argument_.SetTensorRtDisabledOPs(config_.trt_disabled_ops_);
1111 1112
    argument_.SetTensorRtUseDLA(config_.trt_use_dla_);
    argument_.SetTensorRtDLACore(config_.trt_dla_core_);
N
nhzlx 已提交
1113
    argument_.SetTensorRtUseStaticEngine(config_.trt_use_static_engine_);
1114
    argument_.SetTensorRtUseCalibMode(config_.trt_use_calib_mode_);
1115
    argument_.SetCloseTrtPluginFp16(config_.disable_trt_plugin_fp16_);
1116 1117 1118
    argument_.SetTensorRtShapeRangeInfoPath(config_.shape_range_info_path());
    argument_.SetTensorRtAllowBuildAtRuntime(
        config_.trt_allow_build_at_runtime());
1119
    argument_.SetTensorRtUseInspector(config_.trt_use_inspector_);
1120
    argument_.SetTrtEngineMemorySharing(config_.trt_engine_memory_sharing());
W
Wojciech Uss 已提交
1121
  }
1122

D
denglin-github 已提交
1123 1124 1125 1126
  if (config_.dlnne_enabled()) {
    LOG(INFO) << "Dlnne subgraph is enabled";
    argument_.SetUseDlnne(true);
    argument_.SetDlnneMinSubgraphSize(config_.dlnne_min_subgraph_size_);
D
denglin-github 已提交
1127 1128 1129 1130 1131 1132 1133 1134
    argument_.SetDlnneMaxBatchSize(config_.dlnne_max_batchsize_);
    argument_.SetDlnneUseStaticBatch(config_.dlnne_use_static_batch_);
    argument_.SetDlnneWeightShareMode(config_.dlnne_weight_share_mode_);
    argument_.SetDlnneDisableNodesByOutputs(
        config_.dlnne_disable_nodes_by_outputs_);
    argument_.SetDlnneInputShapeDict(config_.dlnne_input_shape_dict_);
    argument_.SetDlnneUseCalibMode(config_.dlnne_use_calib_mode_);
    argument_.SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
D
denglin-github 已提交
1135 1136
  }

石晓伟 已提交
1137
  if (config_.lite_engine_enabled()) {
W
Wilber 已提交
1138 1139
    argument_.SetCpuMathLibraryNumThreads(
        config_.cpu_math_library_num_threads());
石晓伟 已提交
1140 1141 1142
    argument_.SetLitePrecisionMode(config_.lite_precision_mode_);
    argument_.SetLitePassesFilter(config_.lite_passes_filter_);
    argument_.SetLiteOpsFilter(config_.lite_ops_filter_);
1143 1144 1145
    argument_.SetLiteZeroCopy(config_.lite_zero_copy_);
    argument_.SetUseXpu(config_.use_xpu_);
    argument_.SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
W
Wilber 已提交
1146 1147 1148 1149 1150
    argument_.SetXpuLocked(config_.xpu_locked_);
    argument_.SetXpuAutotune(config_.xpu_autotune_);
    argument_.SetXpuAutotuneFile(config_.xpu_autotune_file_);
    argument_.SetXpuPrecision(config_.xpu_precision_);
    argument_.SetXpuAdaptiveSeqlen(config_.xpu_adaptive_seqlen_);
1151
    argument_.SetXpuDeviceId(config_.xpu_device_id_);
1152
    argument_.SetXpuEnableMultiStream(config_.xpu_enable_multi_stream_);
1153
    argument_.SetUseOpenCL(config_.use_opencl_);
1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173
    // NNAdapter related
    argument_.SetUseNNAdapter(config_.NNAdapter().use_nnadapter);
    argument_.SetNNAdapterDeviceNames(
        config_.NNAdapter().nnadapter_device_names);
    argument_.SetNNAdapterContextProperties(
        config_.NNAdapter().nnadapter_context_properties);
    argument_.SetNNAdapterModelCacheDir(
        config_.NNAdapter().nnadapter_model_cache_dir);
    argument_.SetNNAdapterSubgraphPartitionConfigBuffer(
        config_.NNAdapter().nnadapter_subgraph_partition_config_buffer);
    argument_.SetNNAdapterSubgraphPartitionConfigPath(
        config_.NNAdapter().nnadapter_subgraph_partition_config_path);
    std::vector<std::string> buffer_keys;
    std::vector<std::vector<char>> buffer_vals;
    for (auto it : config_.NNAdapter().nnadapter_model_cache_buffers) {
      buffer_keys.emplace_back(it.first);
      buffer_vals.emplace_back(it.second);
    }
    argument_.SetNNAdapterModelCacheToken(buffer_keys);
    argument_.SetNNAdapterModelCacheBuffer(buffer_vals);
石晓伟 已提交
1174 1175 1176
    LOG(INFO) << "Lite subgraph engine is enabled";
  }

1177
#ifdef PADDLE_WITH_IPU
J
jianghaicheng 已提交
1178 1179
  argument_.SetUseIpu(config_.use_ipu_);
  argument_.SetIpuDeviceNum(config_.ipu_device_num());
1180
  argument_.SetIpuMicroBatchSize(config_.ipu_micro_batch_size_);
J
jianghaicheng 已提交
1181 1182
  argument_.SetIpuEnablePipelining(config_.ipu_enable_pipelining_);
  argument_.SetIpuBatchesPerStep(config_.ipu_batches_per_step_);
1183 1184 1185 1186 1187
  argument_.SetIpuEnableFp16(config_.ipu_enable_fp16_);
  argument_.SetIpuReplicaNum(config_.ipu_replica_num_);
  argument_.SetIpuAvailableMemoryProportion(
      config_.ipu_available_memory_proportion_);
  argument_.SetIpuEnableHalfPartial(config_.ipu_enable_half_partial_);
1188 1189
  argument_.SetIpuEnableModelRuntimeExecutor(
      config_.ipu_enable_model_runtime_executor_);
1190 1191
  argument_.SetIpuCustomOpsInfo(config_.ipu_custom_ops_info_);
  argument_.SetIpuCustomPatterns(config_.ipu_custom_patterns_);
1192
#endif
J
jianghaicheng 已提交
1193

1194 1195 1196
  argument_.SetUseNpu(config_.use_npu_);
  argument_.SetNPUDeviceId(config_.npu_device_id());

1197
  if (config_.use_mkldnn_) {
Y
Yan Chunwei 已提交
1198
    LOG(INFO) << "MKLDNN is enabled";
1199 1200 1201
    argument_.SetMKLDNNEnabledOpTypes(config_.mkldnn_enabled_op_types_);
  }

1202 1203 1204 1205 1206 1207 1208 1209
#ifdef PADDLE_WITH_MKLDNN
  if (config_.mkldnn_quantizer_enabled()) {
    LOG(INFO) << "Quantization is enabled";
    argument_.SetQuantizeEnabledOpTypes(
        config_.mkldnn_quantizer_config()->enabled_op_types());
    argument_.SetQuantizeExcludedOpIds(
        config_.mkldnn_quantizer_config()->excluded_op_ids());
  }
1210 1211 1212 1213
  if (config_.use_mkldnn_bfloat16_) {
    LOG(INFO) << "Bfloat16 is enabled";
    argument_.SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
  }
B
baoachun 已提交
1214 1215 1216 1217 1218 1219 1220

  if (config_.use_mkldnn_int8_) {
    LOG(INFO) << "Int8 is enabled";
    argument_.SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
    argument_.SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
    argument_.SetQuantVarScales({});
  }
1221 1222
#endif

1223
  auto passes = config_.pass_builder()->AllPasses();
1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257
  if (model_precision_ != phi::DataType::FLOAT32) {
    LOG(INFO) << "Model is mixed precision type with " << model_precision_
              << ", we will use a new PassStrategy. Note that only the GPU "
                 "backend is supported for now.";
    passes.clear();
    if (config_.tensorrt_engine_enabled()) {
      for (const auto &pass : kTrtLowerPrecisionPasses) {
        passes.push_back(pass);
      }
    } else if (config_.use_gpu()) {
      for (const auto &pass : kGpuLowerPrecisionPasses) {
        passes.push_back(pass);
      }
    }

    const auto &deleted_passes = config_.pass_builder()->GetAllDeletedPasses();
    for (const auto &it : deleted_passes) {
      auto iterator = std::find(passes.begin(), passes.end(), it);
      if (iterator != passes.end()) {
        passes.erase(iterator);
      }
    }

    if (config_.ir_debug_) {
      auto it = std::begin(passes);
      while (it != std::end(passes)) {
        if (*it != "graph_viz_pass") {
          it = passes.insert(it + 1, "graph_viz_pass");
        } else {
          ++it;
        }
      }
    }
  }
Y
Yan Chunwei 已提交
1258 1259 1260 1261
  if (!config_.ir_optim()) {
    passes.clear();
    LOG(INFO) << "ir_optim is turned off, no IR pass will be executed";
  }
1262
  argument_.SetDisableLogs(config_.glog_info_disabled());
1263
  argument_.SetIrAnalysisPasses(passes);
Y
Yan Chunwei 已提交
1264
  argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
1265
  argument_.SetScopeNotOwned(scope_.get());
1266

1267
  // mixed precison.
1268
  argument_.SetModelPrecision(static_cast<int>(model_precision_));
1269
  argument_.SetMixedBlackList(config_.mixed_black_list_);
1270 1271 1272 1273 1274
}

// NOTE All the members in AnalysisConfig should be copied to Argument.
void AnalysisPredictor::OptimizeInferenceProgram() {
  PrepareArgument();
1275 1276 1277 1278 1279 1280 1281 1282 1283 1284

#ifdef PADDLE_WITH_TENSORRT
  if (config_.tensorrt_engine_enabled()) {
    inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
        predictor_id_;
    VLOG(3) << "thread_local var predictor_id in TensorRTEngine is set to: "
            << inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
  }
#endif

1285 1286
  Analyzer().Run(&argument_);

1287
  PADDLE_ENFORCE_EQ(
C
ccrrong 已提交
1288 1289
      argument_.scope_valid(),
      true,
1290
      platform::errors::InvalidArgument("The argument scope should be valid."));
1291 1292
  VLOG(5) << "to prepare executor";
  ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
Y
Yan Chunwei 已提交
1293
  inference_program_.reset(
1294 1295 1296 1297
      new framework::ProgramDesc(argument_.ir_analyzed_program()),
      [](framework::ProgramDesc *prog) {
// Note, please do NOT use any member variables, because member variables may
// have been destructed in multiple threads.
1298
#ifdef PADDLE_WITH_TENSORRT
W
Wilber 已提交
1299 1300 1301 1302
        auto &block = prog->Block(0);
        for (auto &op_desc : block.AllOps()) {
          if (op_desc->Type() == "tensorrt_engine") {
            std::string engine_key =
R
Ruibiao Chen 已提交
1303
                PADDLE_GET_CONST(std::string, op_desc->GetAttr("engine_key"));
W
Wilber 已提交
1304
            int engine_predictor_id =
R
Ruibiao Chen 已提交
1305
                PADDLE_GET_CONST(int, op_desc->GetAttr("predictor_id"));
W
Wilber 已提交
1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316
            std::string engine_name =
                engine_key + std::to_string(engine_predictor_id);
            if (paddle::inference::Singleton<
                    inference::tensorrt::TRTEngineManager>::Global()
                    .Has(engine_name)) {
              paddle::inference::Singleton<
                  inference::tensorrt::TRTEngineManager>::Global()
                  .DeleteKey(engine_name);
            }
          }
        }
1317 1318 1319
#endif
        delete prog;
      });
1320 1321 1322 1323
  // The config and argument take a lot of storage,
  // when the predictor settings are complete, we release these stores.
  argument_.PartiallyRelease();
  config_.PartiallyRelease();
1324
  LOG(INFO) << "======= optimize end =======";
Y
Yan Chunwei 已提交
1325
}
1326 1327

template <>
1328 1329 1330
std::unique_ptr<PaddlePredictor>
CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
    const AnalysisConfig &config) {
W
Wilber 已提交
1331 1332
  // TODO(NHZlX): Should add the link to the doc of
  // paddle_infer::CreatePredictor<paddle_infer::Config>
P
Pei Yang 已提交
1333 1334 1335 1336
  if (config.glog_info_disabled()) {
    FLAGS_logtostderr = 1;
    FLAGS_minloglevel = 2;  // GLOG_ERROR
  }
M
minqiyang 已提交
1337
  VLOG(3) << "create AnalysisConfig";
1338
  PADDLE_ENFORCE_EQ(
C
ccrrong 已提交
1339 1340
      config.is_valid(),
      true,
1341 1342
      platform::errors::InvalidArgument(
          "Note: Each config can only be used for one predictor."));
1343

1344 1345 1346 1347
  // Register custom operators compiled by the user.
  // This function can only be executed once per process.
  static std::once_flag custom_operators_registered;
  std::call_once(custom_operators_registered,
1348
                 []() { inference::RegisterAllCustomOperator(); });
1349

1350
  if (config.use_gpu()) {
1351 1352 1353 1354 1355 1356
    static std::once_flag gflags_initialized;
    static bool process_level_allocator_enabled;

    std::call_once(gflags_initialized, [&]() {
      std::vector<std::string> gflags;
      PADDLE_ENFORCE_GE(
C
ccrrong 已提交
1357 1358
          config.memory_pool_init_size_mb(),
          0.f,
1359 1360 1361
          platform::errors::InvalidArgument(
              "The size of memory pool should be greater than 0."));
      PADDLE_ENFORCE_GE(
C
ccrrong 已提交
1362 1363
          config.gpu_device_id(),
          0,
1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376
          platform::errors::InvalidArgument(
              "Invalid device id (%d). The device id should be greater than 0.",
              config.gpu_device_id()));
      gflags.push_back("dummy");

      float fraction_of_gpu_memory = config.fraction_of_gpu_memory_for_pool();
      if (fraction_of_gpu_memory > 0.95f) {
        LOG(ERROR)
            << "Allocate too much memory for the GPU memory pool, assigned "
            << config.memory_pool_init_size_mb() << " MB";
        LOG(ERROR) << "Try to shink the value by setting "
                      "AnalysisConfig::EnableGpu(...)";
      }
1377

1378 1379 1380 1381 1382 1383 1384
      if (fraction_of_gpu_memory >= 0.0f || fraction_of_gpu_memory <= 0.95f) {
        std::string flag = "--fraction_of_gpu_memory_to_use=" +
                           std::to_string(fraction_of_gpu_memory);
        VLOG(3) << "set flag: " << flag;
        gflags.push_back(flag);
      }

1385 1386 1387 1388 1389 1390 1391 1392 1393
      // TODO(Shixiaowei02): Add a mandatory scheme to use the thread local
      // allocator when multi-stream is enabled.
      if (config.thread_local_stream_enabled()) {
        gflags.push_back("--allocator_strategy=thread_local");
        process_level_allocator_enabled = false;
      } else {
        process_level_allocator_enabled = true;
      }

W
Wilber 已提交
1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404
      // support set flags from enviorment.
      const platform::ExportedFlagInfoMap &env_map =
          platform::GetExportedFlagInfoMap();
      std::ostringstream os;
      os << "--tryfromenv=";
      for (auto &pair : env_map) {
        os << pair.second.name << ",";
      }
      auto tryfromenv_str = os.str();
      gflags.push_back(os.str().substr(0, tryfromenv_str.size() - 1));

1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419
      if (framework::InitGflags(gflags)) {
        VLOG(3) << "The following gpu analysis configurations only take effect "
                   "for the first predictor: ";
        for (size_t i = 1; i < gflags.size(); ++i) {
          VLOG(3) << gflags[i];
        }
      } else {
        LOG(WARNING) << "The one-time configuration of analysis predictor "
                        "failed, which may be due to native predictor called "
                        "first and its configurations taken effect.";
      }
    });

    if (config.thread_local_stream_enabled() &&
        process_level_allocator_enabled) {
1420 1421 1422 1423 1424 1425
      PADDLE_THROW(platform::errors::Fatal(
          "When binding threads and streams, the use of "
          "process-level allocators will result in undefined result "
          "errors due to memory asynchronous operations."
          "The thread and stream binding configuration of all "
          "predictors should be the same in a single process."));
1426 1427 1428 1429
    }
  }

  std::unique_ptr<PaddlePredictor> predictor(new AnalysisPredictor(config));
1430 1431
  // Each config can only be used for one predictor.
  config.SetInValid();
1432 1433
  auto predictor_p = dynamic_cast<AnalysisPredictor *>(predictor.get());

1434 1435 1436 1437
#ifdef PADDLE_WITH_TENSORRT
  paddle::framework::ir::patterns::KeyCounter::Instance().CleanCounter();
#endif

1438 1439 1440 1441 1442
  if (!predictor_p->Init(nullptr)) {
    return nullptr;
  }

  if (config.mkldnn_quantizer_enabled() && !predictor_p->MkldnnQuantize()) {
1443 1444
    return nullptr;
  }
1445

G
Gabor Buella 已提交
1446
  return predictor;
1447 1448
}

1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460
bool AnalysisPredictor::MkldnnQuantize() {
  // Runs the oneDNN quantizer over this predictor and returns its result.
  // Use #ifdef like the rest of this file: `#if` disagrees with `#ifdef` when
  // the macro is defined as 0 and fails to preprocess if it is defined empty.
#ifdef PADDLE_WITH_MKLDNN
  // Created lazily on first call and reused afterwards. NOTE(review): raw
  // owning pointer; presumably released in the destructor -- not visible here.
  if (!mkldnn_quantizer_) {
    mkldnn_quantizer_ = new AnalysisPredictor::MkldnnQuantizer(
        *this, config_.mkldnn_quantizer_config());
  }
  return mkldnn_quantizer_->Quantize();
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  return false;
#endif
}

1461
void AnalysisPredictor::PrepareFeedFetch() {
1462 1463 1464
  PADDLE_ENFORCE_NOT_NULL(sub_scope_,
                          platform::errors::InvalidArgument(
                              "The sub_scope should not be nullptr."));
1465
  CreateFeedFetchVar(sub_scope_);
1466 1467
  for (auto *op : inference_program_->Block(0).AllOps()) {
    if (op->Type() == "feed") {
R
Ruibiao Chen 已提交
1468
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
1469 1470 1471 1472 1473
      if (feeds_.size() <= static_cast<size_t>(idx)) {
        feeds_.resize(idx + 1);
      }
      feeds_[idx] = op;
      feed_names_[op->Output("Out")[0]] = idx;
N
nhzlx 已提交
1474
      idx2feeds_[idx] = op->Output("Out")[0];
1475
    } else if (op->Type() == "fetch") {
R
Ruibiao Chen 已提交
1476
      int idx = PADDLE_GET_CONST(int, op->GetAttr("col"));
Y
Yan Chunwei 已提交
1477 1478
      if (fetches_.size() <= static_cast<size_t>(idx)) {
        fetches_.resize(idx + 1);
1479
      }
Y
Yan Chunwei 已提交
1480
      fetches_[idx] = op;
N
nhzlx 已提交
1481
      idx2fetches_[idx] = op->Input("X")[0];
1482 1483 1484 1485
    }
  }
}

1486
// Ensures `scope` holds a "feed" variable of type FeedList and a "fetch"
// variable of type FetchList. Throws InvalidArgument if `scope` is null.
void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::InvalidArgument("The scope should not be nullptr."));
  scope->Var("feed")->GetMutable<framework::FeedList>();
  scope->Var("fetch")->GetMutable<framework::FetchList>();
}

N
nhzlx 已提交
1496 1497 1498 1499 1500 1501 1502 1503
// Returns the names of the program's input variables, in the iteration
// order of idx2feeds_.
std::vector<std::string> AnalysisPredictor::GetInputNames() {
  std::vector<std::string> names;
  names.reserve(idx2feeds_.size());
  for (const auto &entry : idx2feeds_) {
    names.push_back(entry.second);
  }
  return names;
}

1504 1505 1506 1507 1508 1509
std::map<std::string, std::vector<int64_t>>
AnalysisPredictor::GetInputTensorShape() {
  std::map<std::string, std::vector<int64_t>> input_shapes;
  std::vector<std::string> names = GetInputNames();
  for (std::string name : names) {
    auto *var = inference_program_->Block(0).FindVar(name);
C
ccrrong 已提交
1510 1511 1512
    PADDLE_ENFORCE_NOT_NULL(
        var,
        platform::errors::PreconditionNotMet("Input %s does not exist.", name));
1513 1514 1515 1516 1517
    input_shapes[name] = var->GetShape();
  }
  return input_shapes;
}

1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548
std::map<std::string, paddle_infer::DataType>
AnalysisPredictor::GetInputTypes() {
  std::map<std::string, paddle_infer::DataType> input_type;
  std::vector<std::string> names = GetInputNames();
  for (const auto &name : names) {
    auto *var = inference_program_->Block(0).FindVar(name);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        platform::errors::PreconditionNotMet(
            "Input %s does not exist inference_program_.", name));
    auto dtype = var->GetDataType();
    if (dtype == paddle::framework::proto::VarType::FP32) {
      input_type[name] = paddle_infer::DataType::FLOAT32;
    } else if (dtype == paddle::framework::proto::VarType::FP16) {
      input_type[name] = paddle_infer::DataType::FLOAT16;
    } else if (dtype == paddle::framework::proto::VarType::INT64) {
      input_type[name] = paddle_infer::DataType::INT64;
    } else if (dtype == paddle::framework::proto::VarType::INT32) {
      input_type[name] = paddle_infer::DataType::INT32;
    } else if (dtype == paddle::framework::proto::VarType::UINT8) {
      input_type[name] = paddle_infer::DataType::UINT8;
    } else if (dtype == paddle::framework::proto::VarType::INT8) {
      input_type[name] = paddle_infer::DataType::INT8;
    } else {
      PADDLE_THROW(paddle::platform::errors::Unimplemented(
          "Unsupported data type `%s` when get input dtype ", dtype));
    }
  }
  return input_type;
}

N
nhzlx 已提交
1549 1550 1551 1552 1553 1554 1555 1556
// Returns the names of the variables consumed by fetch ops, in the iteration
// order of idx2fetches_.
std::vector<std::string> AnalysisPredictor::GetOutputNames() {
  std::vector<std::string> names;
  names.reserve(idx2fetches_.size());
  for (const auto &entry : idx2fetches_) {
    names.push_back(entry.second);
  }
  return names;
}

1557 1558
std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetInputTensor(
    const std::string &name) {
1559
  framework::Scope *scope;
1560
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
1561 1562 1563
  if (config_.dist_config().use_dist_model()) {
    scope = scope_.get();
  } else {
1564
    scope = executor_->GetScope();
1565 1566
  }
#else
1567
  scope = executor_->GetScope();
1568
#endif
1569
  PADDLE_ENFORCE_NOT_NULL(
1570
      scope->FindVar(name),
1571
      platform::errors::PreconditionNotMet(
1572
          "The variable named %s is not found in the scope of the executor.",
1573
          name));
1574 1575
  std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(
      static_cast<void *>(scope), this->GetDeviceContexts()));
1576 1577
  res->input_or_output_ = true;
  res->SetName(name);
N
nhzlx 已提交
1578 1579
  if (platform::is_cpu_place(place_)) {
    res->SetPlace(PaddlePlace::kCPU);
J
jianghaicheng 已提交
1580 1581 1582 1583
  } else if (platform::is_ipu_place(place_)) {
    // Currently, IPUPlace's tensor copy between cpu and ipu has been set in
    // IpuBackend.
    res->SetPlace(PaddlePlace::kCPU);
1584
  } else if (platform::is_xpu_place(place_)) {
1585 1586 1587 1588 1589 1590 1591 1592
    if (config_.lite_engine_enabled()) {
      // Currently, Paddle-Lite's XPU user interface only supports the transfer
      // of host data pointers. If it is currently used as a subgraph, execution
      // efficiency will be sacrificed, so it is temporarily set to cpu place.
      // And, the current lite engine of xpu must execute all parts of the
      // model.
      res->SetPlace(PaddlePlace::kCPU);
    } else {
1593
      auto xpu_place = place_;
1594 1595
      res->SetPlace(PaddlePlace::kXPU, xpu_place.GetDeviceId());
    }
W
Wilber 已提交
1596
  } else if (platform::is_npu_place(place_)) {
1597
    auto npu_place = place_;
W
Wilber 已提交
1598
    res->SetPlace(PaddlePlace::kNPU, npu_place.GetDeviceId());
1599 1600 1601 1602 1603 1604
  } else if (platform::is_custom_place(place_)) {
    auto custom_place = place_;
    auto paddleplace = static_cast<PaddlePlace>(
        static_cast<size_t>(PaddlePlace::kCUSTOM) +
        phi::GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
    res->SetPlace(paddleplace, custom_place.GetDeviceId());
N
nhzlx 已提交
1605
  } else {
1606
    auto gpu_place = place_;
N
nhzlx 已提交
1607 1608
    res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
  }
1609 1610 1611 1612 1613
  return res;
}

std::unique_ptr<ZeroCopyTensor> AnalysisPredictor::GetOutputTensor(
    const std::string &name) {
1614
  framework::Scope *scope;
1615
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
1616 1617 1618
  if (config_.dist_config().use_dist_model()) {
    scope = scope_.get();
  } else {
1619
    scope = executor_->GetScope();
1620 1621
  }
#else
1622
  scope = executor_->GetScope();
1623
#endif
1624
  PADDLE_ENFORCE_NOT_NULL(
1625
      scope->FindVar(name),
1626
      platform::errors::PreconditionNotMet(
1627
          "The variable named %s is not found in the scope of the executor.",
1628
          name));
1629 1630
  std::unique_ptr<ZeroCopyTensor> res(new ZeroCopyTensor(
      static_cast<void *>(scope), this->GetDeviceContexts()));
1631 1632
  res->input_or_output_ = false;
  res->SetName(name);
N
nhzlx 已提交
1633 1634
  if (platform::is_cpu_place(place_)) {
    res->SetPlace(PaddlePlace::kCPU);
J
jianghaicheng 已提交
1635 1636 1637 1638
  } else if (platform::is_ipu_place(place_)) {
    // Currently, IPUPlace's tensor copy between cpu and ipu has been set in
    // IpuBackend.
    res->SetPlace(PaddlePlace::kCPU);
1639
  } else if (platform::is_xpu_place(place_)) {
1640 1641 1642 1643 1644 1645 1646 1647
    if (config_.lite_engine_enabled()) {
      // Currently, Paddle-Lite's XPU user interface only supports the transfer
      // of host data pointers. If it is currently used as a subgraph, execution
      // efficiency will be sacrificed, so it is temporarily set to cpu place.
      // And, the current lite engine of xpu must execute all parts of the
      // model.
      res->SetPlace(PaddlePlace::kCPU);
    } else {
1648
      auto xpu_place = place_;
1649 1650
      res->SetPlace(PaddlePlace::kXPU, xpu_place.GetDeviceId());
    }
W
Wilber 已提交
1651
  } else if (platform::is_npu_place(place_)) {
1652
    auto npu_place = place_;
W
Wilber 已提交
1653
    res->SetPlace(PaddlePlace::kNPU, npu_place.GetDeviceId());
1654 1655 1656 1657 1658 1659
  } else if (platform::is_custom_place(place_)) {
    auto custom_place = place_;
    auto paddleplace = static_cast<PaddlePlace>(
        static_cast<size_t>(PaddlePlace::kCUSTOM) +
        phi::GetOrRegisterGlobalDeviceTypeId(place_.GetDeviceType()));
    res->SetPlace(paddleplace, custom_place.GetDeviceId());
N
nhzlx 已提交
1660
  } else {
1661
    auto gpu_place = place_;
N
nhzlx 已提交
1662 1663
    res->SetPlace(PaddlePlace::kGPU, gpu_place.GetDeviceId());
  }
1664 1665 1666 1667
  return res;
}

bool AnalysisPredictor::ZeroCopyRun() {
1668
  inference::DisplayMemoryInfo(place_, "before run");
1669
#if defined(PADDLE_WITH_DISTRIBUTE) && defined(PADDLE_WITH_PSCORE)
1670 1671 1672 1673 1674 1675 1676 1677 1678 1679
  if (config_.dist_config().use_dist_model()) {
    VLOG(3) << "ZeroCopyRun will use the fleet executor.";
    inference::Timer timer;
    timer.tic();
    fleet_exe_->Run(config_.dist_config().carrier_id());
    VLOG(3) << "Fleet executor inf runs once use: "
            << std::to_string(timer.toc()) << "ms";
    return true;
  }
#endif
1680 1681 1682
  if (private_context_) {
    paddle::platform::DeviceContextPool::SetDeviceContexts(&device_contexts_);
  }
1683
  paddle::platform::SetNumThreads(config_.cpu_math_library_num_threads());
W
Wilber 已提交
1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694
#ifdef PADDLE_WITH_MKLDNN
  if (config_.use_mkldnn_) {
    std::vector<std::vector<int>> shape_vector;
    auto names = GetInputNames();
    for (size_t i = 0; i < names.size(); ++i) {
      auto in_tensor = GetInputTensor(names[i]);
      shape_vector.emplace_back(in_tensor->shape());
    }
    MkldnnPreSet(shape_vector);
  }
#endif
1695 1696 1697 1698 1699 1700 1701 1702 1703 1704

#ifdef PADDLE_WITH_TENSORRT
  if (config_.tensorrt_engine_enabled()) {
    inference::tensorrt::TensorRTEngine::predictor_id_per_thread =
        predictor_id_;
    VLOG(3) << "thread_local var predictor_id in TensorRTEngine is set to: "
            << inference::tensorrt::TensorRTEngine::predictor_id_per_thread;
  }
#endif

1705
  executor_->Run();
1706
  inference::DisplayMemoryInfo(place_, "after run");
1707 1708 1709 1710 1711

  if (config_.shape_range_info_collected()) {
    CollectShapeRangeInfo();
  }

Y
Yan Chunwei 已提交
1712
  // Fix TensorArray reuse not cleaned bug.
Y
Yan Chunwei 已提交
1713
  tensor_array_batch_cleaner_.CollectTensorArrays(sub_scope_);
Y
Yan Chunwei 已提交
1714
  tensor_array_batch_cleaner_.ResetTensorArray();
1715 1716 1717 1718

  // recover the cpu_math_library_num_threads to 1, in order to avoid thread
  // conflict when integrating it into deployment service.
  paddle::platform::SetNumThreads(1);
1719 1720 1721
  if (private_context_) {
    paddle::platform::DeviceContextPool::SetDeviceContexts(nullptr);
  }
W
Wilber 已提交
1722 1723 1724
#ifdef PADDLE_WITH_MKLDNN
  if (config_.use_mkldnn_) MkldnnPostReset();
#endif
1725
#if defined(PADDLE_WITH_MKLML)
T
Tao Luo 已提交
1726 1727 1728 1729 1730
  // Frees unused memory allocated by the Intel® MKL Memory Allocator to
  // avoid memory leak. See:
  // https://software.intel.com/en-us/mkl-developer-reference-c-mkl-free-buffers
  platform::dynload::MKL_Free_Buffers();
#endif
1731 1732 1733
  return true;
}

W
Wilber 已提交
1734 1735
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
// Runs the predictor on the caller-provided GPU `stream`.
// Requires private GPU resources (config.SetExecStream); otherwise throws
// Fatal. When `stream` differs from the current predictor_stream_:
//  1. Synchronize the old stream; failures now raise instead of being ignored.
//  2. Rebind the GPU resources to `stream`.
//  3. Point this predictor's InferGPUContext at `stream`.
// Then delegates to ZeroCopyRun().
bool AnalysisPredictor::ExpRunWithExternalStream(const gpuStream_t stream) {
  if (!private_context_) {
    PADDLE_THROW(platform::errors::Fatal(
        "Please use config.SetExecStream to init gpu resources, and then we "
        "will bind gpu resources to execution stream."));
  }

  if (stream != predictor_stream_) {
    // Wait for work queued on the old stream before switching streams, and
    // surface any asynchronous error reported by the synchronization.
#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipStreamSynchronize(static_cast<gpuStream_t>(predictor_stream_)));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaStreamSynchronize(static_cast<gpuStream_t>(predictor_stream_)));
#endif
    ResourceManager::Instance().GpuResourceReBindStream(predictor_stream_,
                                                        stream);
    predictor_stream_ = stream;

    auto *dev_ctxs = reinterpret_cast<const std::map<
        phi::Place,
        std::shared_future<std::unique_ptr<phi::DeviceContext>>> *>(
        this->GetDeviceContexts());
    auto *dev_ctx =
        static_cast<InferGPUContext *>(dev_ctxs->at(place_).get().get());
    dev_ctx->SetStream(stream);
  }

  return ZeroCopyRun();
}
#endif

1765 1766 1767 1768 1769 1770
void AnalysisPredictor::CollectShapeRangeInfo() {
  // if use gpu, sync first.
  if (config_.use_gpu()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    paddle::platform::DeviceContextPool &pool =
        paddle::platform::DeviceContextPool::Instance();
1771
    auto gpu_place = place_;
L
Leo Chen 已提交
1772
    auto *dev_ctx = static_cast<const phi::GPUContext *>(pool.Get(gpu_place));
1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783
#ifdef PADDLE_WITH_HIP
    hipStreamSynchronize(dev_ctx->stream());
#else
    cudaStreamSynchronize(dev_ctx->stream());
#endif
#endif
  }

  std::vector<std::string> var_names = sub_scope_->LocalVarNames();
  for (const auto &name : var_names) {
    auto *var = sub_scope_->GetVar(name);
1784
    if (!var->IsType<phi::DenseTensor>()) {
1785 1786
      continue;
    }
1787 1788
    auto tensor = var->Get<phi::DenseTensor>();
    framework::DDim dim = tensor.dims();
1789 1790 1791
    std::vector<int32_t> shape(dim.size());
    for (size_t i = 0; i < shape.size(); ++i) shape[i] = dim[i];
    shape_info_[name].emplace_back(shape);
1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819

    // We need collect value range for shape tensor for Paddle-TRT's use.
    // To be noticed, this method to identify all shape tensors is based on
    // assumption that all shape tensors in the model have numbers <= 7.
    // This is a simple method to identify all shape tensors with some
    // mistakes, but it doesn't matter.
    auto is_shape_tensor = tensor.numel() <= 7 && tensor.numel() >= 1;
    if (tensor.dtype() == paddle::experimental::DataType::INT32 &&
        is_shape_tensor) {
      std::vector<int> int32_host(tensor.numel());
      if (tensor.place() == platform::CPUPlace()) {
        paddle::memory::Copy(platform::CPUPlace(),
                             int32_host.data(),
                             platform::CPUPlace(),
                             tensor.data<int>(),
                             tensor.numel() * sizeof(int));
      } else if (tensor.place() == platform::CUDAPlace()) {
#if defined(PADDLE_WITH_CUDA)
        paddle::memory::Copy(platform::CPUPlace(),
                             int32_host.data(),
                             platform::CUDAPlace(),
                             tensor.data<int>(),
                             tensor.numel() * sizeof(int),
                             nullptr);
#endif
      }
      shape_tensor_value_[name].emplace_back(int32_host);
    }
1820 1821 1822 1823 1824 1825 1826
  }
}

// Reduces the per-run samples collected in shape_info_ and shape_tensor_value_
// to per-variable min / max / most-frequent ("opt") vectors, then serializes
// them to config_.shape_range_info_path().
//
// Changes from the previous version:
//  - The sample maps and their entries are read by reference, not copied.
//  - Samples whose length differs from the variable's first sample are
//    skipped instead of being indexed out of bounds. Shape-tensor values can
//    vary in length between runs.
//  - The "opt" value breaks frequency ties deterministically by taking the
//    smallest value, instead of depending on std::sort's unstable order.
void AnalysisPredictor::StatisticShapeRangeInfo() {
  std::map<std::string, std::vector<int32_t>> min_shapes;
  std::map<std::string, std::vector<int32_t>> max_shapes;
  std::map<std::string, std::vector<int32_t>> opt_shapes;
  std::map<std::string, std::vector<int32_t>> min_values;
  std::map<std::string, std::vector<int32_t>> max_values;
  std::map<std::string, std::vector<int32_t>> opt_values;

  // Returns the key with the highest count. std::max_element returns the
  // first maximum, i.e. the smallest key among ties (map keys are ordered).
  auto most_frequent =
      [](const std::map<int32_t, int32_t> &counter) -> int32_t {
    auto best = std::max_element(
        counter.begin(),
        counter.end(),
        [](const std::pair<const int32_t, int32_t> &a,
           const std::pair<const int32_t, int32_t> &b) {
          return a.second < b.second;
        });
    return best->first;
  };

  auto extract_min_max_opt =
      [&most_frequent](std::map<std::string, std::vector<int32_t>> &min_data,
                       decltype(min_data) max_data,
                       decltype(min_data) opt_data,
                       const decltype(shape_info_) &shape_data) {
        for (const auto &entry : shape_data) {
          const std::string &name = entry.first;
          const auto &samples = entry.second;
          if (samples.empty()) continue;

          const auto &first = samples[0];
          const size_t rank = first.size();
          std::vector<int32_t> min_shape(first.begin(), first.end());
          std::vector<int32_t> max_shape(min_shape);
          std::vector<int32_t> opt_shape(min_shape);

          for (size_t d = 0; d < rank; ++d) {
            std::map<int32_t, int32_t> counter;
            for (const auto &sample : samples) {
              // Only same-length samples are comparable per dimension.
              if (sample.size() != rank) continue;
              const int32_t value = sample[d];
              counter[value] += 1;
              min_shape[d] = std::min(min_shape[d], value);
              max_shape[d] = std::max(max_shape[d], value);
            }
            // counter always contains at least the first sample's value.
            opt_shape[d] = most_frequent(counter);
          }

          min_data[name] = std::move(min_shape);
          max_data[name] = std::move(max_shape);
          opt_data[name] = std::move(opt_shape);
        }
      };
  extract_min_max_opt(min_shapes, max_shapes, opt_shapes, shape_info_);
  extract_min_max_opt(min_values, max_values, opt_values, shape_tensor_value_);

  inference::SerializeShapeRangeInfo(config_.shape_range_info_path(),
                                     min_shapes,
                                     max_shapes,
                                     opt_shapes,
                                     min_values,
                                     max_values,
                                     opt_values);
}

1884 1885
bool AnalysisPredictor::LoadProgramDesc() {
  // Initialize the inference program
1886
  std::string filename;
1887 1888
  if (!config_.model_dir().empty()) {
    filename = config_.model_dir() + "/__model__";
1889
  } else if (!config_.prog_file().empty()) {
1890 1891 1892
    // All parameters are saved in a single file.
    // The file names should be consistent with that used
    // in Python API `fluid.io.save_inference_model`.
1893
    filename = config_.prog_file();
1894
  } else {
1895
    if (config_.model_dir().empty() && config_.prog_file().empty()) {
1896 1897 1898 1899
      LOG(ERROR)
          << "Either model_dir or (prog_file, param_file) should be set.";
      return false;
    }
1900
    LOG(ERROR) << string::Sprintf(
C
ccrrong 已提交
1901 1902
        "not valid model path '%s' or program path '%s'.",
        config_.model_dir(),
1903
        config_.params_file());
1904 1905
    return false;
  }
1906 1907 1908

  // Create ProgramDesc
  framework::proto::ProgramDesc proto;
T
Tao Luo 已提交
1909
  if (!config_.model_from_memory()) {
T
Tao Luo 已提交
1910 1911 1912
    std::string pb_content;
    // Read binary
    std::ifstream fin(filename, std::ios::in | std::ios::binary);
1913
    PADDLE_ENFORCE_EQ(
C
ccrrong 已提交
1914 1915
        static_cast<bool>(fin.is_open()),
        true,
1916 1917 1918
        platform::errors::NotFound(
            "Cannot open file %s, please confirm whether the file is normal.",
            filename));
T
Tao Luo 已提交
1919 1920 1921 1922 1923 1924 1925 1926
    fin.seekg(0, std::ios::end);
    pb_content.resize(fin.tellg());
    fin.seekg(0, std::ios::beg);
    fin.read(&(pb_content.at(0)), pb_content.size());
    fin.close();

    proto.ParseFromString(pb_content);
  } else {
1927
    proto.ParseFromString(config_.prog_file());
T
Tao Luo 已提交
1928
  }
1929 1930 1931 1932 1933 1934
  inference_program_.reset(new framework::ProgramDesc(proto));
  return true;
}

bool AnalysisPredictor::LoadParameters() {
  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
1935 1936
                          platform::errors::PreconditionNotMet(
                              "The inference program should be loaded first."));
T
Tao Luo 已提交
1937

1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957
  const auto &global_block = inference_program_->MutableBlock(0);

  // create a temporary program to load parameters.

  std::unique_ptr<framework::ProgramDesc> load_program(
      new framework::ProgramDesc());
  framework::BlockDesc *load_block = load_program->MutableBlock(0);
  std::vector<std::string> params;

  for (auto *var : global_block->AllVars()) {
    if (IsPersistable(var)) {
      VLOG(3) << "persistable variable's name: " << var->Name();

      framework::VarDesc *new_var = load_block->Var(var->Name());
      new_var->SetShape(var->GetShape());
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

1958
      if (!config_.params_file().empty()) {
1959 1960 1961 1962 1963 1964
        params.push_back(new_var->Name());
      } else {
        // append_op
        framework::OpDesc *op = load_block->AppendOp();
        op->SetType("load");
        op->SetOutput("Out", {new_var->Name()});
1965
        op->SetAttr("file_path", {config_.model_dir() + "/" + new_var->Name()});
1966 1967 1968 1969 1970
        op->CheckAttrs();
      }
    }
  }

1971
  if (!config_.params_file().empty()) {
1972 1973 1974 1975 1976 1977
    // sort paramlist to have consistent ordering
    std::sort(params.begin(), params.end());
    // append just the load_combine op
    framework::OpDesc *op = load_block->AppendOp();
    op->SetType("load_combine");
    op->SetOutput("Out", params);
1978
    op->SetAttr("file_path", {config_.params_file()});
1979 1980 1981 1982
    op->CheckAttrs();
  }

  // Use NaiveExecutor to Load parameters.
S
superjomn 已提交
1983
  framework::NaiveExecutor e(place_);
1984 1985 1986 1987
  e.Prepare(scope_.get(), *load_program, 0, false);
  e.Run();
  VLOG(3) << "get " << scope_->LocalVarNames().size() << " vars after load";

1988 1989
  return true;
}
1990

1991 1992 1993 1994 1995
// Clears intermediate tensors, then asks the memory module to release memory
// for place_. Returns the value reported by paddle::memory::Release.
uint64_t AnalysisPredictor::TryShrinkMemory() {
  ClearIntermediateTensor();
  const uint64_t released = paddle::memory::Release(place_);
  return released;
}

1996 1997 1998 1999 2000 2001 2002 2003
void AnalysisPredictor::ClearIntermediateTensor() {
  PADDLE_ENFORCE_NOT_NULL(inference_program_.get(),
                          platform::errors::PreconditionNotMet(
                              "The inference program should be loaded first."));
  const auto &global_block = inference_program_->MutableBlock(0);
  for (auto *var : global_block->AllVars()) {
    if (!IsPersistable(var)) {
      const std::string name = var->Name();
2004
      auto *variable = executor_->GetScope()->FindVar(name);
2005
      if (variable != nullptr && variable->IsType<phi::DenseTensor>() &&
2006 2007
          name != "feed" && name != "fetch") {
        VLOG(3) << "Clear Intermediate Tensor: " << name;
2008
        auto *t = variable->GetMutable<phi::DenseTensor>();
2009 2010 2011 2012 2013 2014
        t->clear();
      }
    }
  }
}

2015
#ifdef PADDLE_WITH_TENSORRT
// Writes the INT8 calibration table of every tensorrt_engine op to the model
// optimization cache directory, then frees all calibrator resources.
// Returns false (with an error log) in any of these cases:
//  - an engine has no calibration data;
//  - a calibration table is empty;
//  - the output file cannot be opened. This was previously unchecked, so a
//    failed write lost the table silently.
// Throws PreconditionNotMet when TensorRT is not enabled.
bool AnalysisPredictor::SaveTrtCalibToDisk() {
  PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
                    true,
                    platform::errors::PreconditionNotMet(
                        "This func can be invoked only in trt mode"));
  auto &block = inference_program_->Block(0);
  for (auto &op_desc : block.AllOps()) {
    if (op_desc->Type() == "tensorrt_engine") {
      std::string engine_name = PADDLE_GET_CONST(
          std::string, op_desc->GetAttr("calibration_engine_key"));
      if (!Singleton<TRTCalibratorEngineManager>::Global().Has(engine_name)) {
        LOG(ERROR) << "You should run the predictor(with trt) on the real data "
                      "to generate calibration info";
        return false;
      }
      TRTCalibratorEngine *calib_engine =
          Singleton<TRTCalibratorEngineManager>::Global().Get(engine_name);
      LOG(INFO) << "Wait for calib threads done.";
      calib_engine->calib_->waitAndSetDone();
      LOG(INFO) << "Generating TRT Calibration table data, this may cost a lot "
                   "of time...";
      calib_engine->thr_->join();
      std::string calibration_table_data =
          calib_engine->calib_->getCalibrationTableAsString();

      if (calibration_table_data.empty()) {
        LOG(ERROR) << "the calibration table is empty.";
        return false;
      }

      std::string model_opt_cache_dir =
          argument_.Has("model_dir")
              ? argument_.model_dir()
              : inference::analysis::GetDirRoot(argument_.model_program_path());

      std::string calibration_table_data_path =
          inference::analysis::GetTrtCalibPath(
              inference::analysis::GetOrCreateModelOptCacheDir(
                  model_opt_cache_dir),
              engine_name);

      std::ofstream ofile(calibration_table_data_path, std::ios::out);
      if (!ofile.is_open()) {
        LOG(ERROR) << "Failed to open " << calibration_table_data_path
                   << " to write the Paddle-TRT INT8 calibration table.";
        return false;
      }
      LOG(INFO) << "Write Paddle-TRT INT8 calibration table data to file "
                << calibration_table_data_path;
      ofile << calibration_table_data;
      ofile.close();
    }
  }
  // Free all calibrator resources.
  Singleton<TRTCalibratorEngineManager>::Global().DeleteALL();
  return true;
}
#endif
N
nhzlx 已提交
2069

2070
// Tears down the predictor, in order:
//  1. Flush a pending INT8 calibration table (TRT builds).
//  2. Stop the profiler.
//  3. Drop transfer-cache entries keyed on sub_scope_ and delete sub_scope_.
//  4. Free the MKLDNN quantizer.
//  5. Serialize collected shape ranges.
//  6. Destroy GPU stream resources and release memory held for place_.
//  7. Clear device contexts and release shared TRT context memory.
AnalysisPredictor::~AnalysisPredictor() {
#ifdef PADDLE_WITH_TENSORRT
  if (config_.tensorrt_engine_enabled() &&
      config_.tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
      Singleton<TRTCalibratorEngineManager>::Global().Has()) {
    SaveTrtCalibToDisk();
  }
#endif
  if (config_.with_profile_) {
    platform::DisableProfiler(platform::EventSortingKey::kTotal,
                              "./profile.log");
  }
  if (sub_scope_) {
    // Single lookup, no copy of the key set. The previous code did find()
    // then operator[] and copied the whole set.
    auto &scope_keys = framework::global_transfer_scope_key();
    auto found = scope_keys.find(sub_scope_);
    if (found != scope_keys.end()) {
      auto &transfer_cache = framework::global_transfer_data_cache();
      for (const auto &cache_key : found->second) {
        transfer_cache.erase(cache_key);
      }
      scope_keys.erase(found);
    }
    scope_->DeleteScope(sub_scope_);
  }

#if PADDLE_WITH_MKLDNN
  if (mkldnn_quantizer_) {
    delete mkldnn_quantizer_;
    mkldnn_quantizer_ = nullptr;
  }
#endif

  if (config_.shape_range_info_collected()) {
    StatisticShapeRangeInfo();
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (predictor_stream_ != nullptr) {
    ResourceManager::Instance().DestroyGPUResource(predictor_stream_);
  }
#endif
  if (place_.GetType() != phi::AllocationType::UNDEFINED) {
    memory::Release(place_);
  }
  device_contexts_.clear();

#ifdef PADDLE_WITH_TENSORRT
  if (config_.trt_engine_memory_sharing()) {
    inference::Singleton<inference::tensorrt::TRTEngineManager>::Global()
        .releaseContextMemory(predictor_id_);
  }
#endif
}

2123
std::unique_ptr<PaddlePredictor> AnalysisPredictor::Clone(void *stream) {
Y
Yan Chunwei 已提交
2124
  std::lock_guard<std::mutex> lk(clone_mutex_);
2125
  auto *x = new AnalysisPredictor(config_);
2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136
  x->status_is_cloned_ = true;
  if (config_.use_external_stream_ && stream == nullptr) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "config has been configured to use external stream, but the Clone "
        "function has not received a valid stream parameter."));
  } else if (!config_.use_external_stream_ && stream != nullptr) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "config has not been configured to use external stream, but the Clone "
        "function has received a stream parameter."));
  }
  x->predictor_stream_ = stream;
2137
  x->Init(scope_, inference_program_);
2138
  x->executor_->ResetTrtOps(++AnalysisPredictor::clone_num_);
2139 2140 2141
  return std::unique_ptr<PaddlePredictor>(x);
}

2142
std::string AnalysisPredictor::GetSerializedProgram() const {
Y
Yan Chunwei 已提交
2143 2144 2145
  return inference_program_->Proto()->SerializeAsString();
}

2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184
// Add SaveOptimModel
void AnalysisPredictor::SaveOptimModel(const std::string &dir) {
  // save model
  std::string model_name = dir + "/model";
  std::ofstream outfile;
  outfile.open(model_name, std::ios::out | std::ios::binary);
  std::string inference_prog_desc = GetSerializedProgram();
  outfile << inference_prog_desc;
  // save params
  framework::ProgramDesc save_program;
  auto *save_block = save_program.MutableBlock(0);

  const framework::ProgramDesc &main_program = program();
  const framework::BlockDesc &global_block = main_program.Block(0);
  std::vector<std::string> save_var_list;
  for (framework::VarDesc *var : global_block.AllVars()) {
    if (IsPersistable(var)) {
      framework::VarDesc *new_var = save_block->Var(var->Name());
      new_var->SetShape(var->GetShape());
      new_var->SetDataType(var->GetDataType());
      new_var->SetType(var->GetType());
      new_var->SetLoDLevel(var->GetLoDLevel());
      new_var->SetPersistable(true);

      save_var_list.push_back(new_var->Name());
    }
  }
  std::sort(save_var_list.begin(), save_var_list.end());
  auto *op = save_block->AppendOp();
  op->SetType("save_combine");
  op->SetInput("X", save_var_list);
  op->SetAttr("file_path", dir + "/params");
  op->CheckAttrs();

  platform::CPUPlace place;
  framework::Executor exe(place);
  exe.Run(save_program, scope(), 0, true, true);
}

2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211
// Registers `hookfunc` to be invoked after each op for every initialized
// DenseTensor output, as (op type, var name, output tensor handle).
// Does nothing (with a warning) when memory optimization is enabled.
//
// The executor-level hook is installed once per predictor: when the first
// hookfunc is added to this instance. The previous function-local static
// once_flag was shared by all predictors, so only the first predictor ever
// installed it, and hooks registered on other predictors never ran.
void AnalysisPredictor::RegisterOutputHook(const Exp_OutputHookFunc &hookfunc) {
  if (config_.enable_memory_optim()) {
    LOG(WARNING) << "If you want to run output hook function, you should "
                    "use config.EnableMemoryOptim(false) to turn off memory "
                    "reuse!";
    return;
  }
  if (hookfuncs_.empty()) {
    executor_->RegisterOutputHook([this](framework::OperatorBase *op) {
      for (auto &output : op->Outputs()) {
        for (auto &var_name : output.second) {
          auto *var = this->sub_scope_->FindVar(var_name);
          if (!var || !var->IsType<phi::DenseTensor>()) continue;
          // Bind by reference; copying the tensor is unnecessary here.
          const auto &dense_tensor = var->Get<phi::DenseTensor>();
          if (!dense_tensor.initialized()) continue;
          auto tensor = this->GetOutputTensor(var_name);
          for (auto &hook : this->hookfuncs_) {
            hook(op->Type(), var_name, *tensor);
          }
        }
      }
    });
  }
  hookfuncs_.push_back(hookfunc);
}

Y
Yan Chunwei 已提交
2212
template <>
2213 2214
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor<AnalysisConfig>(
    const AnalysisConfig &config) {
W
Wilber 已提交
2215
  LOG(WARNING) << "Deprecated. Please use CreatePredictor instead.";
2216 2217
  return CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
      config);
Y
Yan Chunwei 已提交
2218 2219
}

2220
}  // namespace paddle
2221

2222
#ifdef PADDLE_WITH_TENSORRT
// TensorRT op converters referenced by this translation unit.
// USE_TRT_CONVERTER is defined outside this file; each use names one
// converter. Every use is written with a trailing ';' for uniformity (the file
// already mixes terminated and unterminated uses, so the expansion is
// self-terminating and the extra ';' is an empty declaration).
// The original order is kept unchanged.

// elementwise_* (weight operand variants)
USE_TRT_CONVERTER(elementwise_add_weight);
USE_TRT_CONVERTER(elementwise_sub_weight);
USE_TRT_CONVERTER(elementwise_mul_weight);
USE_TRT_CONVERTER(elementwise_div_weight);
USE_TRT_CONVERTER(elementwise_min_weight);
USE_TRT_CONVERTER(elementwise_max_weight);
USE_TRT_CONVERTER(elementwise_pow_weight);
USE_TRT_CONVERTER(elementwise_floordiv_weight);

// elementwise_* (tensor operand variants)
USE_TRT_CONVERTER(elementwise_add_tensor);
USE_TRT_CONVERTER(elementwise_sub_tensor);
USE_TRT_CONVERTER(elementwise_div_tensor);
USE_TRT_CONVERTER(elementwise_mul_tensor);
USE_TRT_CONVERTER(elementwise_max_tensor);
USE_TRT_CONVERTER(elementwise_min_tensor);
USE_TRT_CONVERTER(elementwise_pow_tensor);
USE_TRT_CONVERTER(elementwise_floordiv_tensor);

USE_TRT_CONVERTER(transpose);
USE_TRT_CONVERTER(transpose2);
USE_TRT_CONVERTER(flatten);
USE_TRT_CONVERTER(flatten_contiguous_range);
USE_TRT_CONVERTER(matmul);
USE_TRT_CONVERTER(matmul_v2);
USE_TRT_CONVERTER(bmm);
USE_TRT_CONVERTER(conv2d);
USE_TRT_CONVERTER(relu);
USE_TRT_CONVERTER(exp);
USE_TRT_CONVERTER(log);
USE_TRT_CONVERTER(sigmoid);
USE_TRT_CONVERTER(tanh);
USE_TRT_CONVERTER(fc);
USE_TRT_CONVERTER(pool2d);
USE_TRT_CONVERTER(softmax);
USE_TRT_CONVERTER(batch_norm);
USE_TRT_CONVERTER(concat);
USE_TRT_CONVERTER(dropout);
USE_TRT_CONVERTER(pad);
USE_TRT_CONVERTER(hard_sigmoid);
USE_TRT_CONVERTER(hard_swish);
USE_TRT_CONVERTER(split);
USE_TRT_CONVERTER(fill_any_like);
USE_TRT_CONVERTER(prelu);
USE_TRT_CONVERTER(conv2d_transpose);
USE_TRT_CONVERTER(leaky_relu);
USE_TRT_CONVERTER(shuffle_channel);
USE_TRT_CONVERTER(where);
USE_TRT_CONVERTER(swish);
USE_TRT_CONVERTER(silu);
USE_TRT_CONVERTER(group_norm);
USE_TRT_CONVERTER(instance_norm);
USE_TRT_CONVERTER(layer_norm);
USE_TRT_CONVERTER(gelu);
USE_TRT_CONVERTER(multihead_matmul);
USE_TRT_CONVERTER(multihead_matmul_roformer);
USE_TRT_CONVERTER(skip_layernorm);
USE_TRT_CONVERTER(slice);
USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(anchor_generator);
USE_TRT_CONVERTER(yolo_box);
USE_TRT_CONVERTER(yolo_box_head);
USE_TRT_CONVERTER(arg_max);
USE_TRT_CONVERTER(roi_align);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(multiclass_nms3);
USE_TRT_CONVERTER(nearest_interp);
USE_TRT_CONVERTER(nearest_interp_v2);
USE_TRT_CONVERTER(bilinear_interp_v2);
USE_TRT_CONVERTER(reshape);
USE_TRT_CONVERTER(reshape2);
USE_TRT_CONVERTER(reduce_sum);
USE_TRT_CONVERTER(gather_nd);
USE_TRT_CONVERTER(reduce_mean);
USE_TRT_CONVERTER(tile);
USE_TRT_CONVERTER(conv3d);
USE_TRT_CONVERTER(conv3d_transpose);
USE_TRT_CONVERTER(mish);
USE_TRT_CONVERTER(deformable_conv);
USE_TRT_CONVERTER(pool3d);

// Not referenced on Windows builds.
#ifndef _WIN32
USE_TRT_CONVERTER(fused_preln_embedding_eltwise_layernorm);
USE_TRT_CONVERTER(fused_embedding_eltwise_layernorm);
#endif

USE_TRT_CONVERTER(preln_skip_layernorm);
USE_TRT_CONVERTER(preln_residual_bias);
USE_TRT_CONVERTER(c_allreduce_sum);
USE_TRT_CONVERTER(roll);
USE_TRT_CONVERTER(strided_slice);
USE_TRT_CONVERTER(rnn);
USE_TRT_CONVERTER(fill_constant_batch_size_like);
USE_TRT_CONVERTER(transformer_input_convert);
USE_TRT_CONVERTER(cast);
USE_TRT_CONVERTER(recover_padding);
USE_TRT_CONVERTER(remove_padding);
USE_TRT_CONVERTER(equal);
USE_TRT_CONVERTER(top_k);
USE_TRT_CONVERTER(top_k_v2);
USE_TRT_CONVERTER(squeeze2);
USE_TRT_CONVERTER(unsqueeze2);
USE_TRT_CONVERTER(sum);
USE_TRT_CONVERTER(shape);
USE_TRT_CONVERTER(fill_constant);
USE_TRT_CONVERTER(fused_token_prune);
USE_TRT_CONVERTER(celu);
USE_TRT_CONVERTER(layernorm_shift_partition);
USE_TRT_CONVERTER(preln_layernorm_shift_partition);
USE_TRT_CONVERTER(merge_layernorm);
USE_TRT_CONVERTER(skip_merge_layernorm);
USE_TRT_CONVERTER(generic_plugin_creater);
USE_TRT_CONVERTER(custom_plugin_creater);
USE_TRT_CONVERTER(tanh_shrink);
USE_TRT_CONVERTER(logsigmoid);
USE_TRT_CONVERTER(lookup_table);
USE_TRT_CONVERTER(expand_v2);

// Sparse converters are only referenced with cuSPARSELt and TensorRT >= 8.0.
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc);
USE_TRT_CONVERTER(sparse_multihead_matmul);
#endif

#endif  // PADDLE_WITH_TENSORRT
W
Wilber 已提交
2345 2346 2347 2348 2349 2350

namespace paddle_infer {

Predictor::Predictor(const Config &config) {
  const_cast<Config *>(&config)->SwitchUseFeedFetchOps(false);
  // The second parameter indicates that the discard log is not printed
2351 2352 2353 2354 2355 2356 2357 2358 2359 2360
  if (config.use_onnxruntime()) {
#ifdef PADDLE_WITH_ONNXRUNTIME
    if (config.use_gpu()) {
      LOG(WARNING) << "The current ONNXRuntime backend doesn't support GPU,"
                      "and it falls back to use Paddle Inference.";
    } else if (!paddle::CheckConvertToONNX(config)) {
      LOG(WARNING)
          << "Paddle2ONNX do't support convert the Model, fall back to using "
             "Paddle Inference.";
    } else {
C
ccrrong 已提交
2361 2362 2363 2364
      predictor_ =
          paddle::CreatePaddlePredictor<Config,
                                        paddle::PaddleEngineKind::kONNXRuntime>(
              config);
2365 2366 2367 2368 2369 2370 2371 2372 2373
      return;
    }
#else
    LOG(WARNING)
        << "The onnxruntime backend isn't enabled,"
           " and please re-compile Paddle with WITH_ONNXRUNTIME option,"
           "fall back to using Paddle Inference.";
#endif
  }
C
ccrrong 已提交
2374 2375 2376 2377
  predictor_ =
      paddle::CreatePaddlePredictor<Config,
                                    paddle::PaddleEngineKind::kAnalysis>(
          config);
W
Wilber 已提交
2378 2379 2380 2381 2382
}

std::vector<std::string> Predictor::GetInputNames() {
  return predictor_->GetInputNames();
}
2383 2384 2385 2386

std::map<std::string, DataType> Predictor::GetInputTypes() {
  return predictor_->GetInputTypes();
}
W
Wilber 已提交
2387 2388

std::unique_ptr<Tensor> Predictor::GetInputHandle(const std::string &name) {
2389
  return predictor_->GetInputTensor(name);
W
Wilber 已提交
2390 2391 2392 2393 2394 2395 2396
}

std::vector<std::string> Predictor::GetOutputNames() {
  return predictor_->GetOutputNames();
}

std::unique_ptr<Tensor> Predictor::GetOutputHandle(const std::string &name) {
2397
  return predictor_->GetOutputTensor(name);
W
Wilber 已提交
2398 2399 2400 2401
}

bool Predictor::Run() { return predictor_->ZeroCopyRun(); }

2402 2403
std::unique_ptr<Predictor> Predictor::Clone(void *stream) {
  auto analysis_pred = predictor_->Clone(stream);
W
Wilber 已提交
2404 2405 2406 2407 2408 2409 2410 2411
  std::unique_ptr<Predictor> pred(new Predictor(std::move(analysis_pred)));
  return pred;
}

void Predictor::ClearIntermediateTensor() {
  predictor_->ClearIntermediateTensor();
}

2412 2413
uint64_t Predictor::TryShrinkMemory() { return predictor_->TryShrinkMemory(); }

2414 2415 2416 2417
void Predictor::RegisterOutputHook(const Exp_OutputHookFunc &hookfunc) {
  predictor_->RegisterOutputHook(hookfunc);
}

2418 2419
void *Predictor::GetExecStream() const { return predictor_->GetExecStream(); }

W
Wilber 已提交
2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437
int GetNumBytesOfDataType(DataType dtype) {
  switch (dtype) {
    case DataType::FLOAT32:
      return sizeof(float);
    case DataType::INT64:
      return sizeof(int64_t);
    case DataType::INT32:
      return sizeof(int32_t);
    case DataType::UINT8:
      return sizeof(uint8_t);
    default:
      assert(false);
      return -1;
  }
}

std::string GetVersion() { return paddle::get_version(); }

2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450 2451 2452 2453
std::tuple<int, int, int> GetTrtCompileVersion() {
#ifdef PADDLE_WITH_TENSORRT
  return paddle::inference::tensorrt::GetTrtCompileVersion();
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

std::tuple<int, int, int> GetTrtRuntimeVersion() {
#ifdef PADDLE_WITH_TENSORRT
  return paddle::inference::tensorrt::GetTrtRuntimeVersion();
#else
  return std::tuple<int, int, int>{0, 0, 0};
#endif
}

W
Wilber 已提交
2454 2455 2456 2457
std::string UpdateDllFlag(const char *name, const char *value) {
  return paddle::UpdateDllFlag(name, value);
}

2458 2459 2460 2461 2462
void ConvertToMixedPrecision(const std::string &model_file,
                             const std::string &params_file,
                             const std::string &mixed_model_file,
                             const std::string &mixed_params_file,
                             PrecisionType mixed_precision,
2463
                             paddle_infer::PlaceType backend,
2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477
                             bool keep_io_types,
                             std::unordered_set<std::string> black_list) {
  auto phi_backend = paddle::ConvertBackend(backend);
  auto phi_precision = paddle::ConvertPrecision(mixed_precision);
  paddle::inference::analysis::ConvertToMixedPrecision(model_file,
                                                       params_file,
                                                       mixed_model_file,
                                                       mixed_params_file,
                                                       phi_precision,
                                                       phi_backend,
                                                       keep_io_types,
                                                       black_list);
}

W
Wilber 已提交
2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488
}  // namespace paddle_infer

namespace paddle_infer {
std::shared_ptr<Predictor> CreatePredictor(const Config &config) {  // NOLINT
  std::shared_ptr<Predictor> predictor(new Predictor(config));
  return predictor;
}

namespace services {
PredictorPool::PredictorPool(const Config &config, size_t size) {
  PADDLE_ENFORCE_GE(
C
ccrrong 已提交
2489 2490
      size,
      1UL,
W
Wilber 已提交
2491 2492 2493 2494 2495 2496 2497 2498
      paddle::platform::errors::InvalidArgument(
          "The predictor pool size should be greater than 1, but it's (%d)",
          size));
  Config copy_config(config);
  main_pred_.reset(new Predictor(config));
  for (size_t i = 0; i < size - 1; i++) {
    if (config.tensorrt_engine_enabled()) {
      Config config_tmp(copy_config);
2499
      preds_.emplace_back(new Predictor(config_tmp));
W
Wilber 已提交
2500
    } else {
2501
      preds_.emplace_back(main_pred_->Clone());
W
Wilber 已提交
2502 2503 2504 2505 2506 2507
    }
  }
}

Predictor *PredictorPool::Retrive(size_t idx) {
  PADDLE_ENFORCE_LT(
C
ccrrong 已提交
2508 2509
      idx,
      preds_.size() + 1,
W
Wilber 已提交
2510
      paddle::platform::errors::InvalidArgument(
C
ccrrong 已提交
2511 2512
          "There are (%d) predictors in the pool, but the idx is (%d)",
          idx,
W
Wilber 已提交
2513 2514 2515 2516 2517 2518 2519
          preds_.size() + 1));
  if (idx == 0) {
    return main_pred_.get();
  }
  return preds_[idx - 1].get();
}
}  // namespace services
W
Wilber 已提交
2520 2521 2522 2523 2524 2525 2526 2527 2528 2529 2530 2531 2532 2533 2534 2535 2536 2537 2538 2539

namespace experimental {

// Note: Can only be used under thread_local semantics.
bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p,
                                          cudaStream_t stream) {
#ifdef PADDLE_WITH_CUDA
  auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
  return pred->ExpRunWithExternalStream(stream);
#endif
  return false;
}
bool InternalUtils::RunWithExternalStream(paddle_infer::Predictor *p,
                                          hipStream_t stream) {
#ifdef PADDLE_WITH_HIP
  auto pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
  return pred->ExpRunWithExternalStream(stream);
#endif
  return false;
}
W
Wilber 已提交
2540

2541 2542 2543 2544 2545 2546
void InternalUtils::UpdateConfigInterleaved(paddle_infer::Config *c,
                                            bool with_interleaved) {
#ifdef PADDLE_WITH_CUDA
  c->trt_with_interleaved_ = with_interleaved;
#endif
}
W
Wilber 已提交
2547

2548 2549 2550 2551 2552 2553 2554 2555 2556 2557 2558 2559 2560 2561
void InternalUtils::SetTransformerPosid(
    paddle_infer::Config *c, const std::string &tensorrt_transformer_posid) {
#ifdef PADDLE_WITH_CUDA
  c->tensorrt_transformer_posid_ = tensorrt_transformer_posid;
#endif
}

void InternalUtils::SetTransformerMaskid(
    paddle_infer::Config *c, const std::string &tensorrt_transformer_maskid) {
#ifdef PADDLE_WITH_CUDA
  c->tensorrt_transformer_maskid_ = tensorrt_transformer_maskid;
#endif
}

W
Wilber 已提交
2562 2563 2564 2565 2566
void InternalUtils::SyncStream(paddle_infer::Predictor *p) {
#ifdef PADDLE_WITH_CUDA
  auto *pred = dynamic_cast<paddle::AnalysisPredictor *>(p->predictor_.get());
  paddle::platform::DeviceContextPool &pool =
      paddle::platform::DeviceContextPool::Instance();
L
Leo Chen 已提交
2567
  auto *dev_ctx = reinterpret_cast<phi::GPUContext *>(pool.Get(pred->place_));
W
Wilber 已提交
2568 2569 2570 2571 2572 2573 2574 2575 2576
  cudaStreamSynchronize(dev_ctx->stream());
#endif
}
void InternalUtils::SyncStream(cudaStream_t stream) {
#ifdef PADDLE_WITH_CUDA
  cudaStreamSynchronize(stream);
#endif
}

W
Wilber 已提交
2577
}  // namespace experimental
W
Wilber 已提交
2578
}  // namespace paddle_infer