// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/inference/capi/c_api_internal.h"
#include "paddle/fluid/inference/capi/paddle_c_api.h"
#include "paddle/fluid/platform/enforce.h"
using paddle::ConvertToACPrecision;
using paddle::ConvertToPaddleDType;
using paddle::ConvertToPDDataType;

extern "C" {

// Allocates a new analysis configuration on the heap and returns it.
// The caller owns the result and should release it with
// PD_DeleteAnalysisConfig().
PD_AnalysisConfig* PD_NewAnalysisConfig() {
  PD_AnalysisConfig* const created = new PD_AnalysisConfig;
  return created;
}

// Destroys a configuration previously returned by PD_NewAnalysisConfig().
//
// config: object to delete; nullptr is accepted and ignored.
//
// The pointer is taken by value, so the caller's copy is NOT cleared by this
// call. The caller must stop using it afterwards.
void PD_DeleteAnalysisConfig(PD_AnalysisConfig* config) {
  if (config == nullptr) {
    return;
  }
  delete config;
  VLOG(3) << "PD_AnalysisConfig delete successfully. ";
}

// Points the configuration at a model on disk.
//
// config:      configuration to modify; must not be nullptr.
// model_dir:   NUL-terminated path; must not be nullptr.
// params_path: optional NUL-terminated path. When nullptr, the one-argument
//              AnalysisConfig::SetModel(model_dir) overload is used.
//              Otherwise the two-argument overload receives
//              (model_dir, params_path).
//
// Both arguments are validated before any use. Previously `model_dir` was
// streamed to LOG(INFO) before the config check and was never null-checked,
// and constructing std::string from nullptr is undefined behavior.
void PD_SetModel(PD_AnalysisConfig* config, const char* model_dir,
                 const char* params_path) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_NOT_NULL(
      model_dir, paddle::platform::errors::InvalidArgument(
                     "The model directory path shouldn't be nullptr"));
  // Debug-level trace instead of the previous duplicated INFO-level logging.
  VLOG(3) << "PD_SetModel model_dir: " << model_dir;
  if (params_path == nullptr) {
    config->config.SetModel(std::string(model_dir));
  } else {
    config->config.SetModel(std::string(model_dir), std::string(params_path));
  }
}

// Sets the program file path via AnalysisConfig::SetProgFile.
//
// config: configuration to modify; must not be nullptr.
// x:      NUL-terminated path; must not be nullptr (constructing a
//         std::string from a null pointer is undefined behavior).
void PD_SetProgFile(PD_AnalysisConfig* config, const char* x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_NOT_NULL(
      x, paddle::platform::errors::InvalidArgument(
             "The program file path shouldn't be nullptr"));
  config->config.SetProgFile(std::string(x));
}

// Sets the parameters file path via AnalysisConfig::SetParamsFile.
//
// config: configuration to modify; must not be nullptr.
// x:      NUL-terminated path; must not be nullptr (constructing a
//         std::string from a null pointer is undefined behavior).
void PD_SetParamsFile(PD_AnalysisConfig* config, const char* x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_NOT_NULL(
      x, paddle::platform::errors::InvalidArgument(
             "The parameters file path shouldn't be nullptr"));
  config->config.SetParamsFile(std::string(x));
}

// Sets the optimization cache directory via AnalysisConfig::SetOptimCacheDir.
//
// config:        configuration to modify; must not be nullptr.
// opt_cache_dir: NUL-terminated directory path; must not be nullptr
//                (constructing a std::string from nullptr is undefined).
void PD_SetOptimCacheDir(PD_AnalysisConfig* config, const char* opt_cache_dir) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_NOT_NULL(
      opt_cache_dir,
      paddle::platform::errors::InvalidArgument(
          "The optimization cache directory shouldn't be nullptr"));
  config->config.SetOptimCacheDir(std::string(opt_cache_dir));
}

const char* PD_ModelDir(const PD_AnalysisConfig* config) {
82 83 84 85
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
86 87 88 89
  return config->config.model_dir().c_str();
}

const char* PD_ProgFile(const PD_AnalysisConfig* config) {
90 91 92 93
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
94 95 96 97
  return config->config.prog_file().c_str();
}

const char* PD_ParamsFile(const PD_AnalysisConfig* config) {
98 99 100 101
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
102 103 104
  return config->config.params_file().c_str();
}

// Enables GPU execution on the configuration.
//
// config:                   configuration to modify; must not be nullptr.
// memory_pool_init_size_mb: initial memory pool size, forwarded to
//                           AnalysisConfig::EnableUseGpu as uint64_t. Must be
//                           non-negative: a negative int would otherwise wrap
//                           around to an enormous unsigned value.
// device_id:                device index, forwarded unchanged.
void PD_EnableUseGpu(PD_AnalysisConfig* config, int memory_pool_init_size_mb,
                     int device_id) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_GE(memory_pool_init_size_mb, 0,
                    paddle::platform::errors::InvalidArgument(
                        "The initial memory pool size shouldn't be negative, "
                        "but received %d.",
                        memory_pool_init_size_mb));
  config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb),
                              device_id);
}

// Enables XPU execution, forwarding `l3_workspace_size` unchanged to
// AnalysisConfig::EnableXpu. `config` must not be nullptr.
void PD_EnableXpu(PD_AnalysisConfig* config, int l3_workspace_size) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableXpu(l3_workspace_size);
}

// Turns GPU execution off via AnalysisConfig::DisableGpu.
// `config` must not be nullptr.
void PD_DisableGpu(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.DisableGpu();
}

bool PD_UseGpu(const PD_AnalysisConfig* config) {
132 133 134 135
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
136 137 138
  return config->config.use_gpu();
}

// Reports whether XPU execution is enabled (AnalysisConfig::use_xpu).
// `config` must not be nullptr.
bool PD_UseXpu(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto xpu_enabled = config->config.use_xpu();
  return xpu_enabled;
}

int PD_GpuDeviceId(const PD_AnalysisConfig* config) {
148 149 150 151
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
152 153 154
  return config->config.gpu_device_id();
}

// Returns the XPU device id held by the configuration
// (AnalysisConfig::xpu_device_id). `config` must not be nullptr.
int PD_XpuDeviceId(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto device_id = config->config.xpu_device_id();
  return device_id;
}

int PD_MemoryPoolInitSizeMb(const PD_AnalysisConfig* config) {
164 165 166 167
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
168 169 170 171
  return config->config.memory_pool_init_size_mb();
}

float PD_FractionOfGpuMemoryForPool(const PD_AnalysisConfig* config) {
172 173 174 175
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
176 177 178 179
  return config->config.fraction_of_gpu_memory_for_pool();
}

// Enables cuDNN via AnalysisConfig::EnableCUDNN. `config` must not be nullptr.
void PD_EnableCUDNN(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableCUDNN();
}

bool PD_CudnnEnabled(const PD_AnalysisConfig* config) {
188 189 190 191
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
192 193 194 195
  return config->config.cudnn_enabled();
}

// Turns IR optimization on (`x == true`) or off via
// AnalysisConfig::SwitchIrOptim. `config` must not be nullptr.
void PD_SwitchIrOptim(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchIrOptim(x);
}

bool PD_IrOptim(const PD_AnalysisConfig* config) {
204 205 206 207
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
208 209 210 211
  return config->config.ir_optim();
}

// Toggles use of feed/fetch ops via AnalysisConfig::SwitchUseFeedFetchOps.
// `config` must not be nullptr.
void PD_SwitchUseFeedFetchOps(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchUseFeedFetchOps(x);
}

bool PD_UseFeedFetchOpsEnabled(const PD_AnalysisConfig* config) {
220 221 222 223
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
224 225 226 227
  return config->config.use_feed_fetch_ops_enabled();
}

// Toggles the "specify input names" option via
// AnalysisConfig::SwitchSpecifyInputNames. `config` must not be nullptr.
void PD_SwitchSpecifyInputNames(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchSpecifyInputNames(x);
}

bool PD_SpecifyInputName(const PD_AnalysisConfig* config) {
236 237 238 239
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
240 241 242 243 244 245 246
  return config->config.specify_input_name();
}

// Enables the TensorRT engine via AnalysisConfig::EnableTensorRtEngine.
//
// The C-level `precision` enum is translated with
// paddle::ConvertToACPrecision. All other arguments are forwarded unchanged.
// `config` must not be nullptr.
void PD_EnableTensorRtEngine(PD_AnalysisConfig* config, int workspace_size,
                             int max_batch_size, int min_subgraph_size,
                             Precision precision, bool use_static,
                             bool use_calib_mode) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto trt_precision = paddle::ConvertToACPrecision(precision);
  auto& analysis_config = config->config;
  analysis_config.EnableTensorRtEngine(workspace_size, max_batch_size,
                                       min_subgraph_size, trt_precision,
                                       use_static, use_calib_mode);
}

bool PD_TensorrtEngineEnabled(const PD_AnalysisConfig* config) {
257 258 259 260
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
261 262 263
  return config->config.tensorrt_engine_enabled();
}

// Enables DLNNE via AnalysisConfig::EnableDlnne, forwarding
// `min_subgraph_size` unchanged. `config` must not be nullptr.
void PD_EnableDlnne(PD_AnalysisConfig* config, int min_subgraph_size) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableDlnne(min_subgraph_size);
}

// Reports whether DLNNE is enabled (AnalysisConfig::dlnne_enabled).
// `config` must not be nullptr.
bool PD_DlnneEnabled(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto dlnne_on = config->config.dlnne_enabled();
  return dlnne_on;
}

// Toggles IR debugging via AnalysisConfig::SwitchIrDebug.
// `config` must not be nullptr.
void PD_SwitchIrDebug(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchIrDebug(x);
}

// Enables MKLDNN via AnalysisConfig::EnableMKLDNN. `config` must not be
// nullptr.
void PD_EnableMKLDNN(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMKLDNN();
}

// Sets the MKLDNN cache capacity via AnalysisConfig::SetMkldnnCacheCapacity,
// forwarding `capacity` unchanged. `config` must not be nullptr.
void PD_SetMkldnnCacheCapacity(PD_AnalysisConfig* config, int capacity) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetMkldnnCacheCapacity(capacity);
}

bool PD_MkldnnEnabled(const PD_AnalysisConfig* config) {
305 306 307 308
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
309 310 311 312 313
  return config->config.mkldnn_enabled();
}

// Sets the CPU math library thread count via
// AnalysisConfig::SetCpuMathLibraryNumThreads, forwarding the value unchanged.
// `config` must not be nullptr.
void PD_SetCpuMathLibraryNumThreads(PD_AnalysisConfig* config,
                                    int cpu_math_library_num_threads) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetCpuMathLibraryNumThreads(cpu_math_library_num_threads);
}

int PD_CpuMathLibraryNumThreads(const PD_AnalysisConfig* config) {
322 323 324 325
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
326 327 328 329
  return config->config.cpu_math_library_num_threads();
}

// Enables the MKLDNN quantizer via AnalysisConfig::EnableMkldnnQuantizer.
// `config` must not be nullptr.
void PD_EnableMkldnnQuantizer(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMkldnnQuantizer();
}

bool PD_MkldnnQuantizerEnabled(const PD_AnalysisConfig* config) {
338 339 340 341
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
342 343 344
  return config->config.mkldnn_quantizer_enabled();
}

// Enables MKLDNN bfloat16 via AnalysisConfig::EnableMkldnnBfloat16.
// `config` must not be nullptr.
void PD_EnableMkldnnBfloat16(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMkldnnBfloat16();
}

bool PD_MkldnnBfloat16Enabled(const PD_AnalysisConfig* config) {
354 355 356 357
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
358 359 360
  return config->config.mkldnn_bfloat16_enabled();
}

// Supplies the model as in-memory buffers via AnalysisConfig::SetModelBuffer.
//
// The two (pointer, size) pairs are forwarded unchanged. Whether the buffers
// are copied or must outlive `config` is decided by AnalysisConfig and is not
// visible here. NOTE(review): confirm before freeing them early.
// `config` must not be nullptr.
void PD_SetModelBuffer(PD_AnalysisConfig* config, const char* prog_buffer,
                       size_t prog_buffer_size, const char* params_buffer,
                       size_t params_buffer_size) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetModelBuffer(prog_buffer, prog_buffer_size,
                                 params_buffer, params_buffer_size);
}

bool PD_ModelFromMemory(const PD_AnalysisConfig* config) {
373 374 375 376
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
377 378 379 380
  return config->config.model_from_memory();
}

// Enables memory optimization via AnalysisConfig::EnableMemoryOptim.
// `config` must not be nullptr.
void PD_EnableMemoryOptim(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMemoryOptim();
}

bool PD_MemoryOptimEnabled(const PD_AnalysisConfig* config) {
389 390 391 392
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
393 394 395 396
  return config->config.enable_memory_optim();
}

// Enables profiling via AnalysisConfig::EnableProfile.
// `config` must not be nullptr.
void PD_EnableProfile(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableProfile();
}

bool PD_ProfileEnabled(const PD_AnalysisConfig* config) {
405 406 407 408
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
409 410 411 412
  return config->config.profile_enabled();
}

// Marks the configuration as invalid via AnalysisConfig::SetInValid.
// `config` must not be nullptr.
void PD_SetInValid(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetInValid();
}

bool PD_IsValid(const PD_AnalysisConfig* config) {
421 422 423 424
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
425 426
  return config->config.is_valid();
}

// Calls AnalysisConfig::DisableGlogInfo on the configuration.
//
// config: configuration to modify; must not be nullptr. This was the only
//         setter that dereferenced `config` without the null check every
//         other entry point in this file performs.
void PD_DisableGlogInfo(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  config->config.DisableGlogInfo();
}

// Removes the pass named `pass_name` from the configuration's pass builder
// (config->config.pass_builder()->DeletePass).
//
// config:    configuration to modify; must not be nullptr.
// pass_name: NUL-terminated pass name; must not be nullptr (constructing a
//            std::string from nullptr is undefined behavior). Declared as
//            non-const `char*` only to keep the existing C signature; it is
//            never modified.
void PD_DeletePass(PD_AnalysisConfig* config, char* pass_name) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  PADDLE_ENFORCE_NOT_NULL(
      pass_name, paddle::platform::errors::InvalidArgument(
                     "The pass name shouldn't be nullptr"));
  config->config.pass_builder()->DeletePass(std::string(pass_name));
}
}  // extern "C"