analysis_config.cc 45.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cctype>
#include <fstream>
#include <iterator>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
#include "paddle/fluid/inference/utils/table_printer.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/utils/string/split.h"

#ifdef PADDLE_WITH_TENSORRT
#include "paddle/fluid/inference/tensorrt/helper.h"
#endif

31
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
32 33 34
DECLARE_uint64(initial_gpu_memory_in_mb);
#endif

35
namespace paddle {
W
wanghuancoder 已提交
36 37
struct MkldnnQuantizerConfig;

// Pass-name lists declared here and defined elsewhere; the TRT and Dlnne
// lists are used below when (re)building the pass pipeline.
extern const std::vector<std::string> kTRTSubgraphPasses;
extern const std::vector<std::string> kDlnneSubgraphPasses;
extern const std::vector<std::string> kLiteSubgraphPasses;
41

42
// Returns the pass strategy for this config, creating one lazily when none
// exists yet. The strategy type follows the selected device flag. If a
// strategy already exists but disagrees with use_gpu(), a warning is logged
// and the existing strategy is kept.
PassStrategy *AnalysisConfig::pass_builder() const {
  if (!pass_builder_.get()) {
    if (use_gpu_) {
      LOG(INFO) << "Create GPU IR passes";
      pass_builder_.reset(new GpuPassStrategy);
    } else if (use_xpu_) {
      pass_builder_.reset(new XpuPassStrategy);
    } else if (use_npu_) {
      pass_builder_.reset(new NpuPassStrategy);
    } else if (use_ipu_) {
      LOG(INFO) << "Create IPU IR passes";
      pass_builder_.reset(new IpuPassStrategy);
    } else if (use_custom_device_) {
      // Keep in sync with Update(), which builds a CustomDevicePassStrategy
      // for custom devices; without this branch a custom-device config would
      // silently receive the CPU passes here.
      LOG(INFO) << "Create CUSTOM DEVICE IR passes";
      pass_builder_.reset(new CustomDevicePassStrategy);
    } else {
      LOG(INFO) << "Create CPU IR passes";
      pass_builder_.reset(new CpuPassStrategy);
    }
  } else if (pass_builder_->use_gpu() ^ use_gpu()) {
    LOG(WARNING) << "The use_gpu flag is not compatible between Config and "
                    "PassBuilder, the flags are "
                 << use_gpu() << " " << pass_builder_->use_gpu();
    LOG(WARNING) << "Please make them compatible, still use the existing "
                    "PassBuilder.";
  }

  return pass_builder_.get();
}

69
AnalysisConfig::AnalysisConfig(const std::string &model_dir) {
  // Directory-based model: remember the path, then let Update() sync the
  // pass builder with the resulting configuration.
  model_dir_ = model_dir;
  Update();
}
74 75
AnalysisConfig::AnalysisConfig(const std::string &prog_file,
                               const std::string &params_file) {
  // Model given as separate program and parameter files.
  params_file_ = params_file;
  prog_file_ = prog_file;
  Update();
}
81 82
void AnalysisConfig::SetModel(const std::string &prog_file_path,
                              const std::string &params_file_path) {
83 84
  prog_file_ = prog_file_path;
  params_file_ = params_file_path;
Y
Yan Chunwei 已提交
85 86

  Update();
87
}
88

89
void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
90 91
                                  int device_id,
                                  Precision precision_mode) {
92
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
93 94
  use_gpu_ = true;
  memory_pool_init_size_mb_ = memory_pool_init_size_mb;
95
  FLAGS_initial_gpu_memory_in_mb = memory_pool_init_size_mb_;
96
  gpu_device_id_ = device_id;
97 98 99 100 101
  mixed_precision_mode_ = precision_mode;
  if (precision_mode == Precision::kFloat32) {
    // default
  } else if (precision_mode == Precision::kHalf ||
             precision_mode == Precision::kBf16) {
102
    enable_gpu_mixed_ = true;
103 104 105 106 107 108
  } else {
    LOG(ERROR)
        << "The Paddle-GPU inference currently only supports "
           "float32/float16/bfloat16 precision. Please check the parameters "
           "you specified in EnableUseGpu or enable_use_gpu function.";
  }
109
#else
110
  LOG(ERROR) << "Please use PaddlePaddle with GPU version.";
111
  use_gpu_ = false;
112
#endif
Y
Yan Chunwei 已提交
113 114 115

  Update();
}
116

117
void AnalysisConfig::SetExecStream(void *stream) {
W
Wilber 已提交
118 119 120
  PADDLE_ENFORCE_NOT_NULL(
      stream,
      platform::errors::InvalidArgument("`stream` should not be nullptr"));
121 122 123 124 125 126
  exec_stream_ = stream;
  use_external_stream_ = true;
  Update();
}

void *AnalysisConfig::GetExecStream() const {
W
Wilber 已提交
127 128 129
  PADDLE_ENFORCE_NOT_NULL(
      exec_stream_,
      platform::errors::InvalidArgument("`stream` should not be nullptr"));
130 131 132 133 134 135 136
  return exec_stream_;
}

// True once SetExecStream() has supplied an external execution stream.
bool AnalysisConfig::external_stream_enabled() const { return use_external_stream_; }

137
void AnalysisConfig::DisableGpu() {
Y
Yan Chunwei 已提交
138 139 140
  use_gpu_ = false;

  Update();
141 142
}

143 144 145 146 147 148
// Turns off FC padding and refreshes the configuration.
void AnalysisConfig::DisableFCPadding() {
  use_fc_padding_ = false;
  Update();
}

W
Wilber 已提交
149 150 151 152
void AnalysisConfig::EnableXpu(int l3_workspace_size,
                               bool locked,
                               bool autotune,
                               const std::string &autotune_file,
W
Wilber 已提交
153
                               const std::string &precision,
154 155
                               bool adaptive_seqlen,
                               bool enable_multi_stream) {
156 157
  use_xpu_ = true;
  xpu_l3_workspace_size_ = l3_workspace_size;
W
Wilber 已提交
158 159 160 161 162
  xpu_locked_ = locked;
  xpu_autotune_ = autotune;
  xpu_autotune_file_ = autotune_file;
  xpu_precision_ = precision;
  xpu_adaptive_seqlen_ = adaptive_seqlen;
163
  xpu_enable_multi_stream_ = enable_multi_stream;
164 165 166
  Update();
}

167
void AnalysisConfig::SetXpuDeviceId(int device_id) {
W
Wilber 已提交
168 169
  PADDLE_ENFORCE_EQ(use_xpu_,
                    true,
170 171 172 173 174 175
                    platform::errors::PreconditionNotMet(
                        "Should call EnableXpu before SetXpuDeviceId."));
  xpu_device_id_ = device_id;
  Update();
}

W
Wilber 已提交
176
void AnalysisConfig::EnableNpu(int device_id) {
S
shentanyue 已提交
177
#if defined(PADDLE_WITH_ASCEND_CL)
W
Wilber 已提交
178 179
  use_npu_ = true;
  npu_device_id_ = device_id;
S
shentanyue 已提交
180 181 182 183
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
  use_custom_device_ = true;
  custom_device_id_ = device_id;
  custom_device_type_ = "npu";
W
Wilber 已提交
184 185 186 187 188 189
#else
  LOG(ERROR) << "Please compile with npu to EnableNpu()";
  use_npu_ = false;
#endif
  Update();
}
190

191 192 193 194 195 196 197 198 199 200 201 202 203
// Routes inference to a plug-in device of type `device_type`. Only takes
// effect in builds with PADDLE_WITH_CUSTOM_DEVICE; otherwise an error is
// logged.
void AnalysisConfig::EnableCustomDevice(const std::string &device_type,
                                        int device_id) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  custom_device_type_ = device_type;
  custom_device_id_ = device_id;
  use_custom_device_ = true;
#else
  LOG(ERROR) << "Please compile with CustomDevice to EnableCustomDevice()";
  use_custom_device_ = false;
#endif
  Update();
}

W
Wilber 已提交
204 205
void AnalysisConfig::EnableIpu(int ipu_device_num,
                               int ipu_micro_batch_size,
206 207
                               bool ipu_enable_pipelining,
                               int ipu_batches_per_step) {
J
jianghaicheng 已提交
208 209 210
  enable_ir_optim_ = true;

  use_ipu_ = true;
211 212
  ipu_device_num_ = ipu_device_num;
  ipu_micro_batch_size_ = ipu_micro_batch_size;
J
jianghaicheng 已提交
213 214
  ipu_enable_pipelining_ = ipu_enable_pipelining;
  ipu_batches_per_step_ = ipu_batches_per_step;
215 216 217 218

  Update();
}

W
Wilber 已提交
219 220
void AnalysisConfig::SetIpuConfig(bool ipu_enable_fp16,
                                  int ipu_replica_num,
221
                                  float ipu_available_memory_proportion,
222 223
                                  bool ipu_enable_half_partial,
                                  bool ipu_enable_model_runtime_executor) {
224 225 226 227
  ipu_enable_fp16_ = ipu_enable_fp16;
  ipu_replica_num_ = ipu_replica_num;
  ipu_available_memory_proportion_ = ipu_available_memory_proportion;
  ipu_enable_half_partial_ = ipu_enable_half_partial;
228
  ipu_enable_model_runtime_executor_ = ipu_enable_model_runtime_executor;
J
jianghaicheng 已提交
229 230 231

  Update();
}
W
Wilber 已提交
232

233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302
// Sets the IPU custom-op table and the custom pattern switches.
//   ipu_custom_ops_info: copied as-is into ipu_custom_ops_info_.
//   ipu_custom_patterns: pattern name -> enabled. Stored as
//                        {name, "True"|"False"} rows, in map order.
// Both fields are replaced, so calling this twice does not accumulate
// duplicate pattern rows.
void AnalysisConfig::SetIpuCustomInfo(
    const std::vector<std::vector<std::string>> &ipu_custom_ops_info,
    const std::map<std::string, bool> &ipu_custom_patterns) {
  ipu_custom_ops_info_ = ipu_custom_ops_info;

  // Replace rather than append. This matches the ops info above and the way
  // LoadIpuConfig() assigns ipu_custom_patterns_.
  ipu_custom_patterns_.clear();
  for (const auto &pattern : ipu_custom_patterns) {
    ipu_custom_patterns_.push_back(std::vector<std::string>{
        pattern.first, pattern.second ? "True" : "False"});
  }

  Update();
}

void AnalysisConfig::LoadIpuConfig(const std::string &config_path) {
  std::ifstream fin(config_path, std::ios::in);
  PADDLE_ENFORCE_EQ(
      static_cast<bool>(fin.is_open()),
      true,
      platform::errors::NotFound(
          "Cannot open file %s, please confirm whether the file is normal.",
          config_path));
  std::string line;
  while (std::getline(fin, line)) {
    // remove all space
    line.erase(std::remove(line.begin(), line.end(), ' '), line.end());

    std::string key;
    std::string value;
    std::istringstream stream(line);
    // Split string to key and value based on the first `,`
    std::getline(stream, key, ',');
    std::getline(stream, value);

    auto string2bool = [](std::string s) {
      std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) {
        return ::tolower(c);
      });
      return s == "true" || s == "1";
    };

    // ipu_custom_ops_info:
    // [[paddle_op_name, popart_op_name, domain, version], [paddle_op_name,
    // popart_op_name, domain, version]...]
    // ipu_custom_patterns:
    // [[paddle_op_name, enable_pattern], [paddle_op_name, enable_pattern]...]
    auto string2vector = [](std::string s) {
      std::vector<std::vector<std::string>> custom_info;
      s.erase(0, 1);
      s.pop_back();

      std::string one;
      std::istringstream s_stream(s);
      while (std::getline(s_stream, one, ']')) {
        if (!one.empty()) {
          // remove `[`
          one.erase(0, 1);
          custom_info.push_back(paddle::string::Split(one, ','));
        }
      }
      return custom_info;
    };

    if (ipu_config_mapper_.find(key) == ipu_config_mapper_.end()) {
      PADDLE_THROW(platform::errors::InvalidArgument(
303
          "invalid key %s in IPU config: ", key));
304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
    }
    switch (ipu_config_mapper_.at(key)) {
      case ipu_config_code::ipu_device_num:
        ipu_device_num_ = std::stoi(value);
        break;
      case ipu_config_code::ipu_micro_batch_size:
        ipu_micro_batch_size_ = std::stoi(value);
        break;
      case ipu_config_code::ipu_enable_pipelining:
        ipu_enable_pipelining_ = string2bool(value);
        break;
      case ipu_config_code::ipu_batches_per_step:
        ipu_batches_per_step_ = std::stoi(value);
        break;
      case ipu_config_code::ipu_enable_fp16:
        ipu_enable_fp16_ = string2bool(value);
        break;
      case ipu_config_code::ipu_replica_num:
        ipu_replica_num_ = std::stoi(value);
        break;
      case ipu_config_code::ipu_available_memory_proportion:
        ipu_available_memory_proportion_ = std::stof(value);
        break;
      case ipu_config_code::ipu_enable_half_partial:
        ipu_enable_half_partial_ = string2bool(value);
        break;
      case ipu_config_code::ipu_custom_ops_info:
        ipu_custom_ops_info_ = string2vector(value);
        break;
      case ipu_config_code::ipu_custom_patterns:
        ipu_custom_patterns_ = string2vector(value);
        break;
336 337 338
      case ipu_config_code::ipu_enable_model_runtime_executor:
        ipu_enable_model_runtime_executor_ = string2bool(value);
        break;
339 340
      default:
        PADDLE_THROW(platform::errors::InvalidArgument(
341
            "invalid key %s in IPU config", key));
342 343 344 345 346 347 348
        break;
    }
  }

  Update();
}

349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375
// Requests the ONNXRuntime backend. Builds without PADDLE_WITH_ONNXRUNTIME
// log an error and keep it disabled.
void AnalysisConfig::EnableONNXRuntime() {
  bool ort_available = false;
#ifdef PADDLE_WITH_ONNXRUNTIME
  ort_available = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableONNXRuntime()";
#endif
  use_onnxruntime_ = ort_available;

  Update();
}

// Turns the ONNXRuntime backend off and refreshes the configuration.
void AnalysisConfig::DisableONNXRuntime() {
  use_onnxruntime_ = false;

  Update();
}

// Requests ONNXRuntime graph optimization. Only honoured in builds with
// PADDLE_WITH_ONNXRUNTIME; otherwise an error is logged.
void AnalysisConfig::EnableORTOptimization() {
  bool ort_available = false;
#ifdef PADDLE_WITH_ONNXRUNTIME
  ort_available = true;
#else
  LOG(ERROR) << "Please compile with onnxruntime to EnableORTOptimization()";
#endif
  enable_ort_optimization_ = ort_available;

  Update();
}

376
// Copy constructor. It works in three stages:
//   1. Copy every configuration member from `other`.
//   2. Clone `other`'s pass strategy using the same device precedence that
//      Update() uses to create it, so the static_cast below always targets
//      the strategy's real type.
//   3. Re-apply `other`'s pass list, including any passes the user deleted.
AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
#define CP_MEMBER(member__) member__ = other.member__;

  // Model related.
  CP_MEMBER(model_dir_);
  CP_MEMBER(model_from_memory_);  // the memory model reuses prog_file_ and
                                  // params_file_ fields.

  CP_MEMBER(opt_cache_dir_);
  CP_MEMBER(prog_file_);
  CP_MEMBER(params_file_);

  CP_MEMBER(use_fc_padding_);
  // GPU related.
  CP_MEMBER(use_gpu_);
  CP_MEMBER(use_external_stream_);
  CP_MEMBER(exec_stream_);
  CP_MEMBER(use_cudnn_);
  CP_MEMBER(gpu_device_id_);
  CP_MEMBER(memory_pool_init_size_mb_);

  // Mixed precision related.
  CP_MEMBER(mixed_black_list_);
  CP_MEMBER(enable_gpu_mixed_);
  CP_MEMBER(mixed_precision_mode_);

  CP_MEMBER(enable_memory_optim_);
  // TensorRT related.
  CP_MEMBER(use_tensorrt_);
  CP_MEMBER(tensorrt_workspace_size_);
  CP_MEMBER(tensorrt_max_batchsize_);
  CP_MEMBER(tensorrt_min_subgraph_size_);
  CP_MEMBER(tensorrt_precision_mode_);
  CP_MEMBER(trt_disabled_ops_);
  CP_MEMBER(trt_use_dla_);
  CP_MEMBER(trt_dla_core_);
  CP_MEMBER(trt_use_static_engine_);
  CP_MEMBER(trt_use_calib_mode_);
  CP_MEMBER(trt_use_varseqlen_);
  CP_MEMBER(trt_with_interleaved_);
  CP_MEMBER(tensorrt_transformer_posid_);
  CP_MEMBER(tensorrt_transformer_maskid_);
  CP_MEMBER(trt_tuned_dynamic_shape_);
  CP_MEMBER(trt_allow_build_at_runtime_);
  CP_MEMBER(collect_shape_range_info_);
  CP_MEMBER(shape_range_info_path_);
  CP_MEMBER(trt_use_inspector_);
  CP_MEMBER(trt_engine_memory_sharing_);
  // Set together with trt_engine_memory_sharing_ by EnableTensorRTMemoryOptim;
  // copying only one of the two would desynchronize them.
  CP_MEMBER(trt_engine_memory_sharing_identifier_);
  // Dlnne related
  CP_MEMBER(use_dlnne_);
  CP_MEMBER(dlnne_min_subgraph_size_);
  CP_MEMBER(dlnne_max_batchsize_);
  CP_MEMBER(dlnne_use_static_batch_);
  CP_MEMBER(dlnne_weight_share_mode_);
  CP_MEMBER(dlnne_use_calib_mode_);
  CP_MEMBER(dlnne_precision_mode_);
  CP_MEMBER(dlnne_disable_nodes_by_outputs_);
  CP_MEMBER(dlnne_input_shape_dict_);
  // MKLDNN related.
  CP_MEMBER(use_mkldnn_);
  CP_MEMBER(mkldnn_enabled_op_types_);
  CP_MEMBER(mkldnn_cache_capacity_);
  CP_MEMBER(disable_mkldnn_fc_passes_);
  // Bfloat16 related.
  CP_MEMBER(use_mkldnn_bfloat16_);
  CP_MEMBER(bfloat16_enabled_op_types_);
  // Quantization related.
  CP_MEMBER(use_mkldnn_int8_);
  CP_MEMBER(quantize_enabled_op_types_);
  CP_MEMBER(quantize_excluded_op_ids_);
  CP_MEMBER(use_mkldnn_quantizer_);
  CP_MEMBER(mkldnn_quantizer_config_);
  CP_MEMBER(min_input_shape_);
  CP_MEMBER(max_input_shape_);
  CP_MEMBER(optim_input_shape_);
  CP_MEMBER(disable_trt_plugin_fp16_);

  CP_MEMBER(use_lite_);
  CP_MEMBER(lite_precision_mode_);
  CP_MEMBER(lite_passes_filter_);
  CP_MEMBER(lite_ops_filter_);
  CP_MEMBER(lite_zero_copy_);

  // XPU related.
  CP_MEMBER(use_xpu_);
  CP_MEMBER(xpu_device_id_);
  CP_MEMBER(xpu_l3_workspace_size_);
  CP_MEMBER(xpu_locked_);
  CP_MEMBER(xpu_autotune_);
  CP_MEMBER(xpu_autotune_file_);
  CP_MEMBER(xpu_precision_);
  CP_MEMBER(xpu_adaptive_seqlen_);
  CP_MEMBER(xpu_enable_multi_stream_);

  // Lite OpenCL Related
  CP_MEMBER(use_opencl_);

  // NPU related.
  CP_MEMBER(use_npu_);
  CP_MEMBER(npu_device_id_);
  CP_MEMBER(nnadapter_config_);

  // ONNXRuntime related.
  CP_MEMBER(use_onnxruntime_);
  CP_MEMBER(enable_ort_optimization_);

  // profile related.
  CP_MEMBER(with_profile_);

  // cinn compiler related.
  CP_MEMBER(use_cinn_compiler_);

  // glog related.
  CP_MEMBER(with_glog_info_);

  // Ir related.
  CP_MEMBER(enable_ir_optim_);
  CP_MEMBER(use_feed_fetch_ops_);
  CP_MEMBER(ir_debug_);
  CP_MEMBER(specify_input_name_);

  CP_MEMBER(cpu_math_library_num_threads_);

  CP_MEMBER(serialized_info_cache_);

  CP_MEMBER(thread_local_stream_);

  // ipu related
  CP_MEMBER(use_ipu_);
  CP_MEMBER(ipu_device_num_);
  CP_MEMBER(ipu_micro_batch_size_);
  CP_MEMBER(ipu_enable_pipelining_);
  CP_MEMBER(ipu_batches_per_step_);
  CP_MEMBER(ipu_enable_fp16_);
  CP_MEMBER(ipu_replica_num_);
  CP_MEMBER(ipu_available_memory_proportion_);
  CP_MEMBER(ipu_enable_half_partial_);
  CP_MEMBER(ipu_enable_model_runtime_executor_);
  CP_MEMBER(ipu_custom_ops_info_);
  CP_MEMBER(ipu_custom_patterns_);

  // fleet exe related
  CP_MEMBER(dist_config_);

  // custom device related.
  CP_MEMBER(use_custom_device_);
  CP_MEMBER(custom_device_type_);
  CP_MEMBER(custom_device_id_);

  // JITLayer relate
  CP_MEMBER(apply_optim_);
  CP_MEMBER(skip_load_params_);

  // The branch order mirrors Update(), which decides the concrete strategy
  // type held by other.pass_builder(). The static_casts are only valid when
  // they name that exact type. The custom-device branch is required:
  // without it, a CustomDevicePassStrategy would be cast to CpuPassStrategy.
  if (use_gpu_) {
    PADDLE_ENFORCE_EQ(use_xpu_,
                      false,
                      platform::errors::InvalidArgument(
                          "Only one choice can be made between GPU and XPU."));
    pass_builder_.reset(new GpuPassStrategy(
        *static_cast<GpuPassStrategy *>(other.pass_builder())));
  } else if (use_ipu_) {
    pass_builder_.reset(new IpuPassStrategy(
        *static_cast<IpuPassStrategy *>(other.pass_builder())));
  } else if (use_xpu_) {
    pass_builder_.reset(new XpuPassStrategy(
        *static_cast<XpuPassStrategy *>(other.pass_builder())));
  } else if (use_npu_) {
    pass_builder_.reset(new NpuPassStrategy(
        *static_cast<NpuPassStrategy *>(other.pass_builder())));
  } else if (use_custom_device_) {
    pass_builder_.reset(new CustomDevicePassStrategy(
        *static_cast<CustomDevicePassStrategy *>(other.pass_builder())));
  } else {
    pass_builder_.reset(new CpuPassStrategy(
        *static_cast<CpuPassStrategy *>(other.pass_builder())));
  }

#undef CP_MEMBER

  Update();
  if (use_tensorrt_ || use_cinn_compiler_) {
    // Update() will reset all the passes, when some tensorRT pass is deleted in
    // other.pass_builder(), it will set again, so we just remove the
    // deleted_pass.
    pass_builder_->ClearPasses();
    auto other_passes = other.pass_builder()->AllPasses();
    for (const auto &pass : other_passes) {
      pass_builder_->AppendPass(pass);
    }
  }
  if (use_dlnne_) {
    auto all_passes = kDlnneSubgraphPasses;
    auto other_passes = other.pass_builder()->AllPasses();
    // We should sort them, because the user may call the SwitchIrDebug
    // interface, which will change the pass.
    std::sort(all_passes.begin(), all_passes.end());
    std::sort(other_passes.begin(), other_passes.end());
    std::vector<std::string> deleted_passes;
    std::set_difference(all_passes.begin(),
                        all_passes.end(),
                        other_passes.begin(),
                        other_passes.end(),
                        std::inserter(deleted_passes, deleted_passes.begin()));
    for (const auto &ps : deleted_passes) {
      pass_builder_->DeletePass(ps);
    }
  }

  for (auto &delete_pass : other.pass_builder()->GetAllDeletedPasses()) {
    pass_builder_->DeletePass(delete_pass);
  }
}

581
void AnalysisConfig::EnableCUDNN() {
582
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
583 584 585 586 587 588 589 590 591
  use_cudnn_ = use_gpu_;
#else
  LOG(ERROR) << "Please compile with CUDA first to use cuDNN";
  use_cudnn_ = false;
#endif

  Update();
}

592
void AnalysisConfig::EnableMKLDNN() {
593 594 595 596 597 598
#ifdef PADDLE_WITH_MKLDNN
  use_mkldnn_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MKLDNN";
  use_mkldnn_ = false;
#endif
Y
Yan Chunwei 已提交
599 600

  Update();
601 602
}

603 604 605 606 607 608 609 610 611
void AnalysisConfig::SetMkldnnCacheCapacity(int capacity) {
#ifdef PADDLE_WITH_MKLDNN
  mkldnn_cache_capacity_ = capacity;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to set MKLDNN Thread Id";
  mkldnn_cache_capacity_ = 0;
#endif
}

612 613 614 615 616 617 618 619 620 621 622 623 624
// Turns on the MKLDNN post-training quantizer. A quantizer config object is
// created on first use; an existing one is reused so earlier settings made
// through it are kept.
void AnalysisConfig::EnableMkldnnQuantizer() {
#ifdef PADDLE_WITH_MKLDNN
  if (mkldnn_quantizer_config_ == nullptr) {
    mkldnn_quantizer_config_.reset(new MkldnnQuantizerConfig());
  }
  use_mkldnn_quantizer_ = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnQuantizer";
  use_mkldnn_quantizer_ = false;
#endif

  Update();
}

625 626
// Enables MKLDNN bfloat16 inference when the CPU reports avx512_core.
// Whether avx512_bf16 is also present only changes the log message. Without
// avx512_core, or without MKLDNN support in the build, bf16 stays off.
void AnalysisConfig::EnableMkldnnBfloat16() {
#ifdef PADDLE_WITH_MKLDNN
  const bool has_avx512_core =
      platform::MayIUse(platform::cpu_isa_t::avx512_core);
  use_mkldnn_bfloat16_ = has_avx512_core;
  if (!has_avx512_core) {
    LOG(INFO) << "CPU does not support BFLOAT16 calculations";
  } else {
    const bool has_native_bf16 =
        platform::MayIUse(platform::cpu_isa_t::avx512_bf16);
    LOG(INFO) << "Hardware support for BFLOAT16"
              << (has_native_bf16 ? " is enabled"
                                  : " is disabled. Simulation will be used");
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnBfloat16";
  use_mkldnn_bfloat16_ = false;
#endif

  Update();
}

P
Paulina Gacek 已提交
645 646 647 648 649 650 651 652 653 654
// Requests that the MKLDNN FC passes be disabled. The request is only
// honoured in builds with PADDLE_WITH_MKLDNN.
void AnalysisConfig::DisableMkldnnFcPasses() {
  bool disable_fc = false;
#ifdef PADDLE_WITH_MKLDNN
  disable_fc = true;
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use DisableMkldnnFcPasses";
#endif
  disable_mkldnn_fc_passes_ = disable_fc;
  Update();
}

B
baoachun 已提交
655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683
// Enables MKLDNN INT8 inference.
//   op_list: optional subset of operators to quantize. Every entry must
//            already be in quantize_enabled_op_types_. When all entries are
//            valid, the enabled set is narrowed to exactly this list. If any
//            entry is unsupported, an error is logged and INT8 stays off.
// FC padding is turned off only when INT8 is actually enabled. A rejected
// request therefore leaves the rest of the configuration untouched.
void AnalysisConfig::EnableMkldnnInt8(
    const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_MKLDNN
  bool all_supported = true;
  for (const auto &type : op_list) {
    if (!quantize_enabled_op_types_.count(type)) {
      LOG(ERROR) << "There are unsupported operators in the configured "
                    "quantization operator list. The unsupported operator "
                    "is: "
                 << type;
      all_supported = false;
      break;
    }
  }

  use_mkldnn_int8_ = all_supported;
  if (all_supported) {
    use_fc_padding_ = false;
    if (!op_list.empty()) {
      quantize_enabled_op_types_.clear();
      quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
    }
  }
#else
  LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
  use_mkldnn_int8_ = false;
#endif

  Update();
}

684
// Returns the quantizer config created by EnableMkldnnQuantizer(). Raises
// PreconditionNotMet if the quantizer was never enabled.
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
  PADDLE_ENFORCE_NOT_NULL(
      mkldnn_quantizer_config_,
      platform::errors::PreconditionNotMet(
          "MkldnnQuantizer was not enabled yet."));
  return mkldnn_quantizer_config_.get();
}

691
void AnalysisConfig::EnableTensorRtEngine(
692
    int64_t workspace_size,
W
Wilber 已提交
693 694 695 696
    int max_batch_size,
    int min_subgraph_size,
    AnalysisConfig::Precision precision_mode,
    bool use_static,
697
    bool use_calib_mode) {
698
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
Y
Yan Chunwei 已提交
699
  if (!use_gpu()) {
700
    LOG(ERROR) << "To use TensorRT engine, please call EnableUseGpu() first";
Y
Yan Chunwei 已提交
701 702 703
    return;
  }

704 705 706
  use_tensorrt_ = true;
  tensorrt_workspace_size_ = workspace_size;
  tensorrt_max_batchsize_ = max_batch_size;
N
nhzlx 已提交
707
  tensorrt_min_subgraph_size_ = min_subgraph_size;
N
nhzlx 已提交
708
  tensorrt_precision_mode_ = precision_mode;
N
nhzlx 已提交
709
  trt_use_static_engine_ = use_static;
710
  trt_use_calib_mode_ = use_calib_mode;
Y
Yan Chunwei 已提交
711

712
  Update();
Y
Yan Chunwei 已提交
713 714 715 716
#else
  LOG(ERROR)
      << "To use TensorRT engine, please compile inference lib with GPU first.";
#endif
717 718
}

719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742
// Configures TensorRT engine memory sharing. Preconditions, checked in order:
//   - TensorRT must already be enabled.
//   - sharing_identifier must be >= 0.
//   - sharing_identifier must be exactly 0 when sharing is off.
// Violations raise InvalidArgument.
void AnalysisConfig::EnableTensorRTMemoryOptim(bool engine_memory_sharing,
                                               int sharing_identifier) {
  PADDLE_ENFORCE_EQ(
      use_tensorrt_,
      true,
      platform::errors::InvalidArgument(
          "To enable TensorRT memory optim, please call "
          "EnableTensorRtEngine or enable_tensorrt_engine first."));
  PADDLE_ENFORCE_GE(sharing_identifier,
                    0,
                    platform::errors::InvalidArgument(
                        "The value of sharing_identifier must be greater "
                        "than or equal to 0."));

  const bool sharing_disabled = !engine_memory_sharing;
  if (sharing_disabled) {
    PADDLE_ENFORCE_EQ(sharing_identifier,
                      0,
                      platform::errors::InvalidArgument(
                          "The value of sharing_identifier must be equal to 0 "
                          "when engine_memory_sharing is false."));
  }

  trt_engine_memory_sharing_ = engine_memory_sharing;
  trt_engine_memory_sharing_identifier_ = sharing_identifier;
}

D
denglin-github 已提交
743 744 745 746 747 748 749 750 751
void AnalysisConfig::EnableDlnne(
    int min_subgraph_size,
    int max_batch_size,
    bool use_static_batch,
    std::string weight_share_mode,
    std::unordered_set<std::string> disable_nodes_by_ouputs,
    std::map<std::string, std::vector<int64_t>> dlnne_input_shape_dict,
    bool use_calib_mode,
    AnalysisConfig::Precision precision_mode) {
D
denglin-github 已提交
752 753
  use_dlnne_ = true;
  dlnne_min_subgraph_size_ = min_subgraph_size;
D
denglin-github 已提交
754 755 756 757 758 759 760
  dlnne_max_batchsize_ = max_batch_size;
  dlnne_use_static_batch_ = use_static_batch;
  dlnne_weight_share_mode_ = weight_share_mode;
  dlnne_disable_nodes_by_outputs_ = disable_nodes_by_ouputs;
  dlnne_input_shape_dict_ = dlnne_input_shape_dict;
  dlnne_use_calib_mode_ = use_calib_mode;
  dlnne_precision_mode_ = precision_mode;
D
denglin-github 已提交
761 762 763
  Update();
}

764 765 766 767 768 769 770 771 772 773 774
// Records the TensorRT dynamic-shape ranges (min / max / optimal shape per
// input name) and whether TRT plugins must avoid fp16.
void AnalysisConfig::SetTRTDynamicShapeInfo(
    std::map<std::string, std::vector<int>> min_input_shape,
    std::map<std::string, std::vector<int>> max_input_shape,
    std::map<std::string, std::vector<int>> optim_input_shape,
    bool disable_trt_plugin_fp16) {
  // The maps are taken by value (sink parameters), so move them into the
  // members instead of copying each one twice.
  min_input_shape_ = std::move(min_input_shape);
  max_input_shape_ = std::move(max_input_shape);
  optim_input_shape_ = std::move(optim_input_shape);
  disable_trt_plugin_fp16_ = disable_trt_plugin_fp16;
}

775 776 777 778 779
// Runs TensorRT on the given DLA core.
void AnalysisConfig::EnableTensorRtDLA(int dla_core) {
  trt_dla_core_ = dla_core;
  trt_use_dla_ = true;
}

780 781
// Turns on the TensorRT engine inspector flag.
void AnalysisConfig::EnableTensorRtInspector() {
  trt_use_inspector_ = true;
}

782 783 784 785 786
// Experimental: appends `ops`, in order, to the list of op types that must
// not be offloaded to TensorRT. Existing entries are kept.
void AnalysisConfig::Exp_DisableTensorRtOPs(
    const std::vector<std::string> &ops) {
  for (const auto &op_type : ops) {
    trt_disabled_ops_.push_back(op_type);
  }
}

787
// Turns on TensorRT variable-sequence-length support.
void AnalysisConfig::EnableVarseqlen() {
  trt_use_varseqlen_ = true;
}
788

Y
Yan Chunwei 已提交
789
// TODO(Superjomn) refactor this, buggy.
790
void AnalysisConfig::Update() {
791
  auto &&info = SerializeInfoCache();
792 793
  if (info == serialized_info_cache_) return;

Y
Yan Chunwei 已提交
794
  // Transfer pass_builder and copy the existing compatible passes.
W
Wilber 已提交
795 796
  if (!pass_builder_ || ((use_gpu() ^ pass_builder_->use_gpu())) ||
      ((use_xpu() ^ pass_builder_->use_xpu())) ||
J
jianghaicheng 已提交
797
      ((use_npu() ^ pass_builder_->use_npu())) ||
798 799
      ((use_ipu() ^ pass_builder_->use_ipu())) ||
      ((use_custom_device() ^ pass_builder_->use_custom_device()))) {
Y
Yan Chunwei 已提交
800 801
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy);
J
jianghaicheng 已提交
802 803
    } else if (use_ipu()) {
      pass_builder_.reset(new IpuPassStrategy);
804 805
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
806 807
          use_gpu(),
          false,
808 809 810
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy);
W
Wilber 已提交
811 812
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
813 814
          use_gpu(),
          false,
W
Wilber 已提交
815 816 817
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy);
818 819
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
820 821
          use_gpu(),
          false,
822 823 824
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy);
Y
Yan Chunwei 已提交
825 826 827
    } else {
      pass_builder_.reset(new CpuPassStrategy);
    }
828

829
  } else {
Y
Yan Chunwei 已提交
830 831 832
    if (use_gpu()) {
      pass_builder_.reset(new GpuPassStrategy(
          *static_cast<GpuPassStrategy *>(pass_builder_.get())));
J
jianghaicheng 已提交
833 834 835 836
    } else if (use_ipu()) {
      VLOG(1) << "IpuPassStrategy has been used.";
      pass_builder_.reset(new IpuPassStrategy(
          *static_cast<IpuPassStrategy *>(pass_builder_.get())));
837 838
    } else if (use_xpu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
839 840
          use_gpu(),
          false,
841 842 843 844
          platform::errors::InvalidArgument(
              "Only one choice can be made between CPU and XPU."));
      pass_builder_.reset(new XpuPassStrategy(
          *static_cast<XpuPassStrategy *>(pass_builder_.get())));
W
Wilber 已提交
845 846
    } else if (use_npu()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
847 848
          use_gpu(),
          false,
W
Wilber 已提交
849 850 851 852
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and NPU."));
      pass_builder_.reset(new NpuPassStrategy(
          *static_cast<NpuPassStrategy *>(pass_builder_.get())));
853 854
    } else if (use_custom_device()) {
      PADDLE_ENFORCE_EQ(
W
Wilber 已提交
855 856
          use_gpu(),
          false,
857 858 859 860
          platform::errors::InvalidArgument(
              "Only one choice can be made between GPU and CustomDevice."));
      pass_builder_.reset(new CustomDevicePassStrategy(
          *static_cast<CustomDevicePassStrategy *>(pass_builder_.get())));
Y
Yan Chunwei 已提交
861 862 863 864
    } else {
      pass_builder_.reset(new CpuPassStrategy(
          *static_cast<CpuPassStrategy *>(pass_builder_.get())));
    }
865 866 867
  }

  if (use_tensorrt_) {
868 869
    pass_builder()->ClearPasses();
    for (const auto &pass : kTRTSubgraphPasses) {
870
      if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
871
          (pass == "conv_bn_fuse_pass")) {
872 873
        continue;
      }
874
      pass_builder()->AppendPass(pass);
875 876
    }
  }
877

878 879 880 881 882 883 884 885
  // TODO(wilber): An ugly method to update pass, need to be fixed.
  if (use_cinn_compiler_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kCINNCompilerPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

D
denglin-github 已提交
886 887 888 889 890 891 892
  if (use_dlnne_) {
    pass_builder()->ClearPasses();
    for (const auto &pass : kDlnneSubgraphPasses) {
      pass_builder()->AppendPass(pass);
    }
  }

893
  if (use_gpu() && use_cudnn_) {
894
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
895 896 897 898 899 900 901 902
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableCUDNN() only works when IR optimization is enabled.";
    } else {
      pass_builder()->EnableCUDNN();
    }
#endif
  }

903
  if (use_mkldnn_) {
W
Wojciech Uss 已提交
904
#ifdef PADDLE_WITH_MKLDNN
905 906 907
    if (!enable_ir_optim_) {
      LOG(ERROR)
          << "EnableMKLDNN() only works when IR optimization is enabled.";
W
Wojciech Uss 已提交
908 909
    } else {
      pass_builder()->EnableMKLDNN();
910 911 912 913
    }
#endif
  }

914 915 916 917 918
  // Quantization passes must come after all other optimization passes
  if (use_mkldnn_quantizer_) {
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnQuantizer() only works when IR optimization "
                    "is enabled.";
919 920
    }
#ifdef PADDLE_WITH_MKLDNN
921
    pass_builder()->EnableMkldnnQuantizer();
922 923 924
#endif
  }

925 926 927 928 929 930
  if (use_mkldnn_bfloat16_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->EnableMkldnnBfloat16();
#endif
  }

B
baoachun 已提交
931 932 933 934 935 936 937 938 939 940 941 942 943 944
  if (use_mkldnn_int8_) {
#ifdef PADDLE_WITH_MKLDNN
    if (!enable_ir_optim_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
                    "is enabled.";
    } else if (!use_mkldnn_) {
      LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
                    "is enabled.";
    } else {
      pass_builder()->EnableMkldnnInt8();
    }
#endif
  }

P
Paulina Gacek 已提交
945 946 947 948 949 950
  if (disable_mkldnn_fc_passes_) {
#ifdef PADDLE_WITH_MKLDNN
    pass_builder()->DisableMkldnnFcPasses();
#endif
  }

951 952
  // TODO(inference): When we enable memory_optimize and mkldnn, PaddleSeg model
  // fail.
Y
Yan Chunwei 已提交
953
  if (enable_memory_optim_) {
954 955 956 957 958 959 960 961 962 963 964
#ifdef PADDLE_WITH_MKLDNN
    if (use_mkldnn_) {
      enable_memory_optim_ = false;
      LOG_FIRST_N(WARNING, 1)
          << "It is detected that mkldnn and memory_optimize_pass are enabled "
             "at the same time, but they are not supported yet. Currently, "
             "memory_optimize_pass is explicitly disabled";
    } else {
      pass_builder()->AppendAnalysisPass("memory_optimize_pass");
    }
#else
965
    pass_builder()->AppendAnalysisPass("memory_optimize_pass");
966
#endif
Y
Yan Chunwei 已提交
967 968
  }

石晓伟 已提交
969 970 971 972 973 974 975
  if (use_lite_) {
#ifndef PADDLE_WITH_LITE
    LOG(WARNING) << "You tried to enable the lite subgraph "
                    "but did not have the option -DWITH_LITE compiled.";
#endif
    pass_builder()->ClearPasses();
    for (const auto &pass : kLiteSubgraphPasses) {
W
Wilber 已提交
976 977
      if (std::find(lite_passes_filter_.begin(),
                    lite_passes_filter_.end(),
石晓伟 已提交
978 979 980 981 982 983
                    pass) == lite_passes_filter_.end()) {
        pass_builder()->AppendPass(pass);
      }
    }
  }

984
  if (use_xpu_) {
985
#if (defined LITE_SUBGRAPH_WITH_XPU) || (defined PADDLE_WITH_XPU)
W
Wilber 已提交
986 987
    PADDLE_ENFORCE_EQ(use_gpu_,
                      false,
988 989 990
                      platform::errors::Unavailable(
                          "Currently, XPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
991 992 993 994 995
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an XPU device, but Paddle was not compiled "
        "with XPU-runtime."));
#endif
996 997
  }

W
Wilber 已提交
998
  if (use_npu_) {
999
#if defined(PADDLE_WITH_ASCEND_CL) || defined(LITE_SUBGRAPH_WITH_NPU)
W
Wilber 已提交
1000 1001
    PADDLE_ENFORCE_EQ(use_gpu_,
                      false,
W
Wilber 已提交
1002 1003 1004 1005 1006 1007 1008 1009 1010
                      platform::errors::Unavailable(
                          "Currently, NPU and GPU cannot be enabled in the "
                          "same analysis configuration."));
#else
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to use an NPU device, but Paddle was not compiled "
        "with NPU-runtime."));
#endif
  }
J
jianghaicheng 已提交
1011 1012 1013 1014 1015 1016 1017
  if (use_ipu_) {
#ifndef PADDLE_WITH_IPU
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the ipu "
        "but did not have the option -DWITH_IPU compiled."));
#endif
  }
1018 1019 1020 1021 1022 1023 1024
  if (use_custom_device_) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
    PADDLE_THROW(platform::errors::Unavailable(
        "You tried to enable the custom device "
        "but did not have the option -DWITH_CUSTOM_DEVICE compiled."));
#endif
  }
1025 1026
}

1027
std::string AnalysisConfig::SerializeInfoCache() {
1028
  std::stringstream ss;
Y
Yan Chunwei 已提交
1029 1030 1031 1032
  ss << model_dir_;
  ss << prog_file_;
  ss << params_file_;

1033
  ss << use_gpu_;
1034
  ss << enable_gpu_mixed_;
1035 1036
  ss << use_external_stream_;
  ss << exec_stream_;
1037
  ss << use_fc_padding_;
1038 1039
  ss << gpu_device_id_;
  ss << xpu_device_id_;
1040 1041 1042 1043 1044
  ss << memory_pool_init_size_mb_;

  ss << use_tensorrt_;
  ss << tensorrt_workspace_size_;
  ss << tensorrt_max_batchsize_;
Y
Yan Chunwei 已提交
1045 1046
  ss << tensorrt_min_subgraph_size_;

D
denglin-github 已提交
1047 1048 1049
  ss << use_dlnne_;
  ss << dlnne_min_subgraph_size_;

1050 1051 1052
  for (auto &op : trt_disabled_ops_) ss << op.c_str();
  ss << ";";

1053 1054 1055
  ss << trt_use_dla_;
  ss << trt_dla_core_;

Y
Yan Chunwei 已提交
1056
  ss << enable_memory_optim_;
1057
  ss << trt_engine_memory_sharing_;
1058 1059

  ss << use_mkldnn_;
1060
  ss << mkldnn_cache_capacity_;
Y
Yan Chunwei 已提交
1061 1062 1063
  for (auto &item : mkldnn_enabled_op_types_) ss << item;
  ss << ";";

1064
  ss << use_mkldnn_quantizer_;
1065
  ss << use_mkldnn_bfloat16_;
1066
  for (auto &item : bfloat16_enabled_op_types_) ss << item;
B
baoachun 已提交
1067 1068 1069
  ss << use_mkldnn_int8_;
  for (auto &item : quantize_enabled_op_types_) ss << item;
  for (auto &item : quantize_excluded_op_ids_) ss << item;
1070
  ss << ";";
Y
Yan Chunwei 已提交
1071 1072
  ss << model_from_memory_;

1073 1074
  ss << with_profile_;

1075 1076
  ss << with_glog_info_;

1077 1078 1079 1080
  ss << enable_ir_optim_;
  ss << use_feed_fetch_ops_;
  ss << ir_debug_;

Y
Yan Chunwei 已提交
1081 1082
  ss << specify_input_name_;
  ss << cpu_math_library_num_threads_;
石晓伟 已提交
1083 1084

  ss << use_lite_;
1085 1086
  ss << use_xpu_;
  ss << xpu_l3_workspace_size_;
W
Wilber 已提交
1087 1088 1089 1090 1091
  ss << xpu_locked_;
  ss << xpu_autotune_;
  ss << xpu_autotune_file_;
  ss << xpu_precision_;
  ss << xpu_adaptive_seqlen_;
1092
  ss << xpu_enable_multi_stream_;
1093

W
Wilber 已提交
1094 1095 1096
  ss << use_npu_;
  ss << npu_device_id_;

1097 1098
  ss << thread_local_stream_;

J
jianghaicheng 已提交
1099 1100
  ss << use_ipu_;
  ss << ipu_device_num_;
1101
  ss << ipu_micro_batch_size_;
J
jianghaicheng 已提交
1102 1103
  ss << ipu_enable_pipelining_;
  ss << ipu_batches_per_step_;
1104 1105 1106 1107
  ss << ipu_enable_fp16_;
  ss << ipu_replica_num_;
  ss << ipu_available_memory_proportion_;
  ss << ipu_enable_half_partial_;
1108
  ss << ipu_enable_model_runtime_executor_;
1109 1110 1111 1112 1113 1114
  for (auto custom_op : ipu_custom_ops_info_)
    for (auto attr : custom_op) ss << attr;
  ss << ";";
  for (auto pattern : ipu_custom_patterns_)
    for (auto attr : pattern) ss << attr;
  ss << ";";
1115
  for (auto &op : mixed_black_list_) ss << op.c_str();
1116 1117 1118
  return ss.str();
}

1119
void AnalysisConfig::SetCpuMathLibraryNumThreads(
1120 1121
    int cpu_math_library_num_threads) {
  cpu_math_library_num_threads_ = cpu_math_library_num_threads;
Y
Yan Chunwei 已提交
1122 1123

  Update();
1124 1125
}

1126
// Returns memory_pool_init_size_mb() as a fraction of the total memory of the
// configured GPU (gpu_device_id_). Returns 0 in builds without CUDA/HIP.
//
// NOTE(review): platform::SetDeviceId() is called here, so this const getter
// switches the current device as a side effect.
float AnalysisConfig::fraction_of_gpu_memory_for_pool() const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  // Byte counts are reported in MB in the log and in the ratio below.
  auto bytes_to_mb = [](size_t bytes) { return bytes / 1024. / 1024.; };

  size_t total_bytes = 0;
  size_t available_bytes = 0;
  platform::SetDeviceId(gpu_device_id_);
  platform::GpuMemoryUsage(&available_bytes, &total_bytes);

  const double total_mb = bytes_to_mb(total_bytes);
  const float fraction =
      static_cast<double>(memory_pool_init_size_mb()) / total_mb;
  VLOG(3) << "total_gpu_memory is " << total_mb << "M, gpu_available is "
          << bytes_to_mb(available_bytes) << "M, memory_pool_init_size is "
          << memory_pool_init_size_mb() << "M.";
  return fraction;
#else
  return 0.f;
#endif
}

1146 1147
void AnalysisConfig::EnableMemoryOptim(bool x) {
  enable_memory_optim_ = x;
Y
Yan Chunwei 已提交
1148 1149 1150
  Update();
}

1151
bool AnalysisConfig::enable_memory_optim() const {
Y
Yan Chunwei 已提交
1152 1153 1154
  return enable_memory_optim_;
}

1155 1156 1157 1158
bool AnalysisConfig::trt_engine_memory_sharing() const {
  return trt_engine_memory_sharing_;
}

1159 1160 1161 1162
void AnalysisConfig::SetModelBuffer(const char *prog_buffer,
                                    size_t prog_buffer_size,
                                    const char *param_buffer,
                                    size_t param_buffer_size) {
1163 1164
  prog_file_ = std::string(prog_buffer, prog_buffer + prog_buffer_size);
  params_file_ = std::string(param_buffer, param_buffer + param_buffer_size);
T
Tao Luo 已提交
1165
  model_from_memory_ = true;
T
Tao Luo 已提交
1166 1167
}

1168
// Converts this analysis config into the legacy NativeConfig, carrying over
// the model location, GPU selection and input-name setting. The GPU memory
// fraction comes from fraction_of_gpu_memory_for_pool() (0 without CUDA/HIP).
NativeConfig AnalysisConfig::ToNativeConfig() const {
  NativeConfig native;
  native.model_dir = model_dir_;
  native.prog_file = prog_file_;
  native.param_file = params_file_;
  native.use_gpu = use_gpu_;
  native.device = gpu_device_id_;
  native.fraction_of_gpu_memory = fraction_of_gpu_memory_for_pool();
  native.specify_input_name = specify_input_name_;
  return native;
}

Y
Yan Chunwei 已提交
1180 1181 1182 1183
void AnalysisConfig::SwitchIrDebug(int x) {
  ir_debug_ = x;
  Update();
}
1184 1185 1186 1187 1188 1189

void AnalysisConfig::EnableProfile() {
  with_profile_ = true;
  Update();
}

1190 1191 1192 1193 1194
void AnalysisConfig::DisableGlogInfo() {
  with_glog_info_ = false;
  Update();
}

石晓伟 已提交
1195
// Enables the Paddle-Lite subgraph engine.
//
// precision_mode / zero_copy are stored for the Lite engine; passes_filter
// lists Lite subgraph passes that Update() must NOT append, and ops_filter is
// stored as lite_ops_filter_. Re-runs Update() afterwards.
void AnalysisConfig::EnableLiteEngine(
    AnalysisConfig::Precision precision_mode,
    bool zero_copy,
    const std::vector<std::string> &passes_filter,
    const std::vector<std::string> &ops_filter) {
  lite_precision_mode_ = precision_mode;
  lite_zero_copy_ = zero_copy;
  lite_passes_filter_ = passes_filter;
  lite_ops_filter_ = ops_filter;
  use_lite_ = true;
  Update();
}

1208 1209 1210 1211 1212
void AnalysisConfig::EnableOpenCL() {
  use_opencl_ = true;
  Update();
}

1213 1214 1215 1216 1217 1218 1219
// Empties prog_file_ and params_file_ and releases their storage. This
// matters most when the model was loaded from memory (SetModelBuffer), where
// these strings hold the full program/parameter buffers.
//
// Each string is swapped with an empty temporary instead of using clear() +
// shrink_to_fit(). shrink_to_fit() is only a non-binding request, while the
// swap moves the old storage into the temporary, which frees it on
// destruction.
void AnalysisConfig::PartiallyRelease() {
  std::string().swap(prog_file_);
  std::string().swap(params_file_);
}

1220 1221
void AnalysisConfig::EnableGpuMultiStream() { thread_local_stream_ = true; }

1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232
std::string AnalysisConfig::Summary() {
  const std::vector<std::string> header{"Option", "Value"};
  paddle::inference::TablePrinter os(header);

  if (!model_dir_.empty()) {
    os.InsertRow({"model_dir", model_dir_});
  }
  if (!(prog_file_.empty() && params_file_.empty())) {
    os.InsertRow({"model_file", prog_file_});
    os.InsertRow({"params_file", params_file_});
  }
1233

1234 1235 1236 1237 1238 1239 1240 1241
  if (model_from_memory_) {
    os.InsertRow({"model_from_memory", params_file_});
  }
  os.InsetDivider();

  // cpu info
  os.InsertRow(
      {"cpu_math_thread", std::to_string(cpu_math_library_num_threads_)});
1242
  os.InsertRow({"enable_mkldnn", use_mkldnn_ ? "true" : "false"});
1243 1244 1245 1246 1247 1248 1249 1250
  os.InsertRow(
      {"mkldnn_cache_capacity", std::to_string(mkldnn_cache_capacity_)});
  os.InsetDivider();

  // gpu info
  os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
  if (use_gpu_) {
    os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
1251
    os.InsertRow({"enable_gpu_mixed_", std::to_string(enable_gpu_mixed_)});
1252 1253
    os.InsertRow({"memory_pool_init_size",
                  std::to_string(memory_pool_init_size_mb_) + "MB"});
1254 1255
    os.InsertRow(
        {"use_external_stream", use_external_stream_ ? "true" : "false"});
1256 1257 1258 1259 1260
    os.InsertRow(
        {"thread_local_stream", thread_local_stream_ ? "true" : "false"});

    os.InsertRow({"use_tensorrt", use_tensorrt_ ? "true" : "false"});
    if (use_tensorrt_) {
1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287
#ifdef PADDLE_WITH_TENSORRT
      auto Precision2String =
          [](paddle::AnalysisConfig::Precision prec) -> std::string {
        if (prec == Precision::kFloat32)
          return "fp32";
        else if (prec == Precision::kHalf)
          return "fp16";
        else if (prec == Precision::kInt8)
          return "int8";
        else
          return "None";
      };
      auto version2string =
          [](const std::tuple<int, int, int> &ver) -> std::string {
        std::ostringstream os;
        int major = std::get<0>(ver);
        int minor = std::get<1>(ver);
        int patch = std::get<2>(ver);
        os << major << "." << minor << "." << patch;
        return os.str();
      };
      os.InsertRow(
          {"trt_compile_version",
           version2string(inference::tensorrt::GetTrtCompileVersion())});
      os.InsertRow(
          {"trt_runtime_version",
           version2string(inference::tensorrt::GetTrtRuntimeVersion())});
1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303
      os.InsertRow({"tensorrt_precision_mode",
                    Precision2String(tensorrt_precision_mode_)});
      os.InsertRow({"tensorrt_workspace_size",
                    std::to_string(tensorrt_workspace_size_)});
      os.InsertRow(
          {"tensorrt_max_batch_size", std::to_string(tensorrt_max_batchsize_)});
      os.InsertRow({"tensorrt_min_subgraph_size",
                    std::to_string(tensorrt_min_subgraph_size_)});
      os.InsertRow({"tensorrt_use_static_engine",
                    trt_use_static_engine_ ? "true" : "false"});
      os.InsertRow(
          {"tensorrt_use_calib_mode", trt_use_calib_mode_ ? "true" : "false"});

      // dynamic_shape
      os.InsertRow({"tensorrt_enable_dynamic_shape",
                    min_input_shape_.empty() ? "false" : "true"});
W
Wilber 已提交
1304 1305 1306
      os.InsertRow(
          {"tensorrt_tuned_dynamic_shape",
           trt_tuned_dynamic_shape_ ? shape_range_info_path_ : "false"});
1307

1308 1309
      os.InsertRow(
          {"tensorrt_use_varseqlen", trt_use_varseqlen_ ? "true" : "false"});
1310 1311
      os.InsertRow({"tensorrt_with_interleaved",
                    trt_with_interleaved_ ? "true" : "false"});
1312 1313 1314
      os.InsertRow({"tensorrt_transformer_posid", tensorrt_transformer_posid_});
      os.InsertRow(
          {"tensorrt_transformer_maskid", tensorrt_transformer_maskid_});
1315 1316 1317 1318
      os.InsertRow({"tensorrt_use_dla", trt_use_dla_ ? "true" : "false"});
      if (trt_use_dla_) {
        os.InsertRow({"tensorrt_dla_core", std::to_string(trt_dla_core_)});
      }
1319 1320
      os.InsertRow({"trt_engine_memory_sharing",
                    trt_engine_memory_sharing_ ? "true" : "false"});
1321
#endif
1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338
    }
  }
  os.InsetDivider();

  // xpu info
  os.InsertRow({"use_xpu", use_xpu_ ? "true" : "false"});
  if (use_xpu_) {
    os.InsertRow({"xpu_device_id", std::to_string(xpu_device_id_)});
    os.InsertRow(
        {"xpu_l3_workspace_size", std::to_string(xpu_l3_workspace_size_)});
  }
  os.InsetDivider();

  if (use_lite_) {
    os.InsertRow({"use_lite", use_lite_ ? "true" : "false"});
  }

1339 1340 1341
  // cinn compiler
  os.InsertRow({"use_cinn_compiler", use_cinn_compiler_ ? "true" : "false"});

1342 1343 1344 1345 1346 1347
  // ir info
  os.InsertRow({"ir_optim", enable_ir_optim_ ? "true" : "false"});
  os.InsertRow({"ir_debug", ir_debug_ ? "true" : "false"});
  os.InsertRow({"memory_optim", enable_memory_optim_ ? "true" : "false"});
  os.InsertRow({"enable_profile", with_profile_ ? "true" : "false"});
  os.InsertRow({"enable_log", with_glog_info_ ? "true" : "false"});
1348 1349
  os.InsertRow({"collect_shape_range_info",
                collect_shape_range_info_ ? shape_range_info_path_ : "false"});
1350 1351 1352 1353

  return os.PrintTable();
}

1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374
// Stores the NNAdapter device names; returns *this for call chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetDeviceNames(
    const std::vector<std::string> &names) {
  this->nnadapter_device_names = names;
  return *this;
}

// Stores the NNAdapter context properties string; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetContextProperties(
    const std::string &properties) {
  this->nnadapter_context_properties = properties;
  return *this;
}

// Stores the NNAdapter model cache directory; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheDir(
    const std::string &dir) {
  this->nnadapter_model_cache_dir = dir;
  return *this;
}

// Registers a model cache buffer under model_cache_token.
//
// Throws InvalidArgument if the token or the buffer is empty, or if the token
// has already been registered. Returns *this for call chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetModelCacheBuffers(
    const std::string &model_cache_token,
    const std::vector<char> &model_cache_buffer) {
  PADDLE_ENFORCE_EQ(model_cache_token.empty(),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_token should not be empty."));
  PADDLE_ENFORCE_EQ(model_cache_buffer.empty(),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_buffer should not be empty."));
  PADDLE_ENFORCE_EQ(nnadapter_model_cache_buffers.count(model_cache_token),
                    false,
                    platform::errors::InvalidArgument(
                        "model_cache_token has already been set."));

  // The token is known to be absent here, so emplace always inserts.
  nnadapter_model_cache_buffers.emplace(model_cache_token, model_cache_buffer);
  return *this;
}

// Stores the subgraph partition config file path; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath(
    const std::string &path) {
  this->nnadapter_subgraph_partition_config_path = path;
  return *this;
}

// Stores the subgraph partition config contents; returns *this for chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer(
    const std::string &buffer) {
  this->nnadapter_subgraph_partition_config_buffer = buffer;
  return *this;
}
// Sets use_nnadapter; returns *this for call chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Enable() {
  this->use_nnadapter = true;
  return *this;
}
// Clears use_nnadapter; returns *this for call chaining.
LiteNNAdapterConfig &LiteNNAdapterConfig::Disable() {
  this->use_nnadapter = false;
  return *this;
}

1412 1413 1414 1415 1416 1417 1418
void AnalysisConfig::CollectShapeRangeInfo(
    const std::string &shape_range_info_path) {
  LOG(INFO) << "In CollectShapeInfo mode, we will disable optimizations and "
               "collect the shape information of "
            << "all intermediate tensors in the compute graph and calculate "
               "the min_shape, max_shape and opt_shape.";
  collect_shape_range_info_ = true;
W
Wilber 已提交
1419 1420
  PADDLE_ENFORCE_EQ(shape_range_info_path.empty(),
                    false,
1421 1422 1423 1424 1425 1426
                    platform::errors::InvalidArgument(
                        "The shape_range_info_path should not be empty, please "
                        "re-check the argument."));
  shape_range_info_path_ = shape_range_info_path;
}

1427
const std::string &AnalysisConfig::shape_range_info_path() const {
1428 1429 1430
  return shape_range_info_path_;
}

1431
bool AnalysisConfig::shape_range_info_collected() const {
1432 1433 1434 1435 1436 1437 1438 1439 1440 1441
  return collect_shape_range_info_;
}

// Enables TensorRT tuned dynamic shape using the shape ranges stored at
// shape_range_info_path; allow_build_at_runtime is recorded as
// trt_allow_build_at_runtime_. The path is not validated here.
void AnalysisConfig::EnableTunedTensorRtDynamicShape(
    const std::string &shape_range_info_path, bool allow_build_at_runtime) {
  trt_tuned_dynamic_shape_ = true;
  trt_allow_build_at_runtime_ = allow_build_at_runtime;
  shape_range_info_path_ = shape_range_info_path;
}

1442
bool AnalysisConfig::tuned_tensorrt_dynamic_shape() const {
1443 1444 1445
  return trt_tuned_dynamic_shape_;
}

1446
bool AnalysisConfig::trt_allow_build_at_runtime() const {
1447 1448
  return trt_allow_build_at_runtime_;
}
1449

1450
void AnalysisConfig::Exp_DisableMixedPrecisionOps(
1451 1452 1453 1454
    const std::unordered_set<std::string> &black_list) {
  mixed_black_list_ = black_list;
}

1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469
// Experimental: sets use_cinn_compiler_ and re-runs Update(). When that flag
// is set, Update() replaces the pass list with kCINNCompilerPasses.
// Throws Unavailable if Paddle was built without PADDLE_WITH_CINN.
void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifndef PADDLE_WITH_CINN
  PADDLE_THROW(platform::errors::Unavailable(
      "You tried to use CINN compiler, but Paddle was not compiled "
      "with CINN."));
#else
  this->use_cinn_compiler_ = true;
  this->Update();
#endif
}

bool AnalysisConfig::cinn_compiler_enabled() const {
  return use_cinn_compiler_;
}

1470
}  // namespace paddle