pd_config.cc 15.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <limits>
#include <map>
#include <memory>
#include <string>
#include <vector>
21

22
#include "paddle/fluid/inference/capi/c_api_internal.h"
F
flame 已提交
23
#include "paddle/fluid/inference/capi/paddle_c_api.h"
24
#include "paddle/fluid/platform/enforce.h"
25

F
flame 已提交
26
using paddle::ConvertToACPrecision;
27 28 29 30 31 32 33 34 35 36 37
using paddle::ConvertToPaddleDType;
using paddle::ConvertToPDDataType;

extern "C" {

// Allocates a default-initialized PD_AnalysisConfig on the heap.
// The caller owns the returned pointer and releases it with
// PD_DeleteAnalysisConfig (which calls `delete`).
PD_AnalysisConfig* PD_NewAnalysisConfig() {
  auto* fresh_config = new PD_AnalysisConfig;
  return fresh_config;
}

// Destroys a PD_AnalysisConfig obtained from PD_NewAnalysisConfig.
// Passing nullptr is a no-op. The caller's pointer variable is NOT reset
// (the argument is passed by value), so the caller must not use it afterwards.
void PD_DeleteAnalysisConfig(PD_AnalysisConfig* config) {
  if (config == nullptr) {
    return;
  }
  delete config;
  VLOG(3) << "PD_AnalysisConfig delete successfully. ";
}

42 43
// Points `config` at a model on disk.
//   model_dir   - required path; forwarded to config->config.SetModel.
//   params_path - optional; when nullptr only model_dir is forwarded,
//                 otherwise both paths are forwarded to SetModel.
// `config` and `model_dir` must both be non-null; otherwise
// PADDLE_ENFORCE_NOT_NULL reports an InvalidArgument error.
void PD_SetModel(PD_AnalysisConfig* config,
                 const char* model_dir,
                 const char* params_path) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // Streaming or constructing std::string from a null char* is undefined
  // behavior, so the required path is validated before any use.
  PADDLE_ENFORCE_NOT_NULL(
      model_dir,
      paddle::platform::errors::InvalidArgument(
          "The model directory shouldn't be nullptr"));
  if (params_path == nullptr) {
    config->config.SetModel(std::string(model_dir));
  } else {
    config->config.SetModel(std::string(model_dir), std::string(params_path));
  }
}

// Forwards `x` as the program file path to config->config.SetProgFile.
// Both `config` and `x` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL
// reports an InvalidArgument error.
void PD_SetProgFile(PD_AnalysisConfig* config, const char* x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // std::string(nullptr) is undefined behavior; reject it explicitly.
  PADDLE_ENFORCE_NOT_NULL(
      x,
      paddle::platform::errors::InvalidArgument(
          "The program file path shouldn't be nullptr"));
  config->config.SetProgFile(std::string(x));
}

// Forwards `x` as the parameters file path to config->config.SetParamsFile.
// Both `config` and `x` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL
// reports an InvalidArgument error.
void PD_SetParamsFile(PD_AnalysisConfig* config, const char* x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // std::string(nullptr) is undefined behavior; reject it explicitly.
  PADDLE_ENFORCE_NOT_NULL(
      x,
      paddle::platform::errors::InvalidArgument(
          "The parameters file path shouldn't be nullptr"));
  config->config.SetParamsFile(std::string(x));
}

// Forwards `opt_cache_dir` to config->config.SetOptimCacheDir.
// Both `config` and `opt_cache_dir` must be non-null; otherwise
// PADDLE_ENFORCE_NOT_NULL reports an InvalidArgument error.
void PD_SetOptimCacheDir(PD_AnalysisConfig* config, const char* opt_cache_dir) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // std::string(nullptr) is undefined behavior; reject it explicitly.
  PADDLE_ENFORCE_NOT_NULL(
      opt_cache_dir,
      paddle::platform::errors::InvalidArgument(
          "The optimization cache directory shouldn't be nullptr"));
  config->config.SetOptimCacheDir(std::string(opt_cache_dir));
}

const char* PD_ModelDir(const PD_AnalysisConfig* config) {
83 84 85 86
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
87 88 89 90
  return config->config.model_dir().c_str();
}

const char* PD_ProgFile(const PD_AnalysisConfig* config) {
91 92 93 94
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
95 96 97 98
  return config->config.prog_file().c_str();
}

const char* PD_ParamsFile(const PD_AnalysisConfig* config) {
99 100 101 102
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
103 104 105
  return config->config.params_file().c_str();
}

106 107
// Enables GPU execution by forwarding to config->config.EnableUseGpu.
//   memory_pool_init_size_mb - initial memory pool size (MB, per the name);
//                              must be >= 0 because it is cast to uint64_t.
//   device_id                - forwarded unchanged.
// A null `config` or a negative pool size is reported as InvalidArgument.
void PD_EnableUseGpu(PD_AnalysisConfig* config,
                     int memory_pool_init_size_mb,
                     int device_id) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // Without this check a negative int would wrap to an enormous uint64_t.
  PADDLE_ENFORCE_GE(
      memory_pool_init_size_mb,
      0,
      paddle::platform::errors::InvalidArgument(
          "The initial size of the GPU memory pool shouldn't be negative, "
          "but received %d.",
          memory_pool_init_size_mb));
  config->config.EnableUseGpu(static_cast<uint64_t>(memory_pool_init_size_mb),
                              device_id);
}

117 118 119 120 121 122 123 124
// Forwards `l3_workspace_size` to config->config.EnableXpu.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableXpu(PD_AnalysisConfig* config, int l3_workspace_size) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableXpu(l3_workspace_size);
}

125
// Calls config->config.DisableGpu().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_DisableGpu(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.DisableGpu();
}

bool PD_UseGpu(const PD_AnalysisConfig* config) {
134 135 136 137
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
138 139 140
  return config->config.use_gpu();
}

141 142 143 144 145 146 147 148
// Returns config->config.use_xpu().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
bool PD_UseXpu(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto& analysis_config = config->config;
  return analysis_config.use_xpu();
}

149
int PD_GpuDeviceId(const PD_AnalysisConfig* config) {
150 151 152 153
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
154 155 156
  return config->config.gpu_device_id();
}

157 158 159 160 161 162 163 164
// Returns config->config.xpu_device_id().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
int PD_XpuDeviceId(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto& analysis_config = config->config;
  return analysis_config.xpu_device_id();
}

165
int PD_MemoryPoolInitSizeMb(const PD_AnalysisConfig* config) {
166 167 168 169
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
170 171 172 173
  return config->config.memory_pool_init_size_mb();
}

float PD_FractionOfGpuMemoryForPool(const PD_AnalysisConfig* config) {
174 175 176 177
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
178 179 180 181
  return config->config.fraction_of_gpu_memory_for_pool();
}

// Calls config->config.EnableCUDNN().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableCUDNN(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableCUDNN();
}

bool PD_CudnnEnabled(const PD_AnalysisConfig* config) {
190 191 192 193
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
194 195 196 197
  return config->config.cudnn_enabled();
}

// Forwards the flag `x` to config->config.SwitchIrOptim.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SwitchIrOptim(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchIrOptim(x);
}

bool PD_IrOptim(const PD_AnalysisConfig* config) {
206 207 208 209
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
210 211 212 213
  return config->config.ir_optim();
}

// Forwards the flag `x` to config->config.SwitchUseFeedFetchOps.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SwitchUseFeedFetchOps(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchUseFeedFetchOps(x);
}

bool PD_UseFeedFetchOpsEnabled(const PD_AnalysisConfig* config) {
222 223 224 225
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
226 227 228 229
  return config->config.use_feed_fetch_ops_enabled();
}

// Forwards the flag `x` to config->config.SwitchSpecifyInputNames.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SwitchSpecifyInputNames(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchSpecifyInputNames(x);
}

bool PD_SpecifyInputName(const PD_AnalysisConfig* config) {
238 239 240 241
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
242 243 244
  return config->config.specify_input_name();
}

245
void PD_EnableTensorRtEngine(PD_AnalysisConfig* config,
246
                             int64_t workspace_size,
247 248 249 250
                             int max_batch_size,
                             int min_subgraph_size,
                             Precision precision,
                             bool use_static,
251
                             bool use_calib_mode) {
252 253 254 255
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
256 257 258 259 260 261
  config->config.EnableTensorRtEngine(workspace_size,
                                      max_batch_size,
                                      min_subgraph_size,
                                      paddle::ConvertToACPrecision(precision),
                                      use_static,
                                      use_calib_mode);
262 263 264
}

bool PD_TensorrtEngineEnabled(const PD_AnalysisConfig* config) {
265 266 267 268
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
269 270 271
  return config->config.tensorrt_engine_enabled();
}

D
denglin-github 已提交
272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287
// Forwards `min_subgraph_size` to config->config.EnableDlnne.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableDlnne(PD_AnalysisConfig* config, int min_subgraph_size) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableDlnne(min_subgraph_size);
}

// Returns config->config.dlnne_enabled().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
bool PD_DlnneEnabled(const PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  const auto& analysis_config = config->config;
  return analysis_config.dlnne_enabled();
}

288
// Forwards the flag `x` to config->config.SwitchIrDebug.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SwitchIrDebug(PD_AnalysisConfig* config, bool x) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SwitchIrDebug(x);
}

// Calls config->config.EnableMKLDNN().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableMKLDNN(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMKLDNN();
}

// Forwards `capacity` to config->config.SetMkldnnCacheCapacity.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SetMkldnnCacheCapacity(PD_AnalysisConfig* config, int capacity) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetMkldnnCacheCapacity(capacity);
}

bool PD_MkldnnEnabled(const PD_AnalysisConfig* config) {
313 314 315 316
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
317 318 319 320 321
  return config->config.mkldnn_enabled();
}

// Forwards `cpu_math_library_num_threads` to
// config->config.SetCpuMathLibraryNumThreads.
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SetCpuMathLibraryNumThreads(PD_AnalysisConfig* config,
                                    int cpu_math_library_num_threads) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetCpuMathLibraryNumThreads(cpu_math_library_num_threads);
}

int PD_CpuMathLibraryNumThreads(const PD_AnalysisConfig* config) {
330 331 332 333
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
334 335 336 337
  return config->config.cpu_math_library_num_threads();
}

// Calls config->config.EnableMkldnnQuantizer().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableMkldnnQuantizer(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMkldnnQuantizer();
}

bool PD_MkldnnQuantizerEnabled(const PD_AnalysisConfig* config) {
346 347 348 349
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
350 351 352
  return config->config.mkldnn_quantizer_enabled();
}

353
// Calls config->config.EnableMkldnnBfloat16().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableMkldnnBfloat16(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMkldnnBfloat16();
}

bool PD_MkldnnBfloat16Enabled(const PD_AnalysisConfig* config) {
362 363 364 365
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
366 367 368
  return config->config.mkldnn_bfloat16_enabled();
}

369 370 371 372
void PD_SetModelBuffer(PD_AnalysisConfig* config,
                       const char* prog_buffer,
                       size_t prog_buffer_size,
                       const char* params_buffer,
373
                       size_t params_buffer_size) {
374 375 376 377
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
378 379
  config->config.SetModelBuffer(
      prog_buffer, prog_buffer_size, params_buffer, params_buffer_size);
380 381 382
}

bool PD_ModelFromMemory(const PD_AnalysisConfig* config) {
383 384 385 386
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
387 388 389 390
  return config->config.model_from_memory();
}

// Calls config->config.EnableMemoryOptim().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableMemoryOptim(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableMemoryOptim();
}

bool PD_MemoryOptimEnabled(const PD_AnalysisConfig* config) {
399 400 401 402
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
403 404 405 406
  return config->config.enable_memory_optim();
}

// Calls config->config.EnableProfile().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_EnableProfile(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.EnableProfile();
}

bool PD_ProfileEnabled(const PD_AnalysisConfig* config) {
415 416 417 418
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
419 420 421 422
  return config->config.profile_enabled();
}

// Calls config->config.SetInValid().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error.
void PD_SetInValid(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  auto& analysis_config = config->config;
  analysis_config.SetInValid();
}

bool PD_IsValid(const PD_AnalysisConfig* config) {
431 432 433 434
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
435 436
  return config->config.is_valid();
}
F
flame 已提交
437 438 439 440 441 442 443 444

// Calls config->config.DisableGlogInfo().
// `config` must be non-null; otherwise PADDLE_ENFORCE_NOT_NULL reports an
// InvalidArgument error (same contract as every other setter in this file).
void PD_DisableGlogInfo(PD_AnalysisConfig* config) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  config->config.DisableGlogInfo();
}

// Removes the pass named `pass_name` by calling
// config->config.pass_builder()->DeletePass.
// `pass_name` stays `char*` (not `const char*`) to match the existing C-API
// declaration; it is only read here.
// Both `config` and `pass_name` must be non-null; otherwise
// PADDLE_ENFORCE_NOT_NULL reports an InvalidArgument error.
void PD_DeletePass(PD_AnalysisConfig* config, char* pass_name) {
  PADDLE_ENFORCE_NOT_NULL(
      config,
      paddle::platform::errors::InvalidArgument(
          "The pointer of analysis configuration shouldn't be nullptr"));
  // std::string(nullptr) is undefined behavior; reject it explicitly.
  PADDLE_ENFORCE_NOT_NULL(
      pass_name,
      paddle::platform::errors::InvalidArgument(
          "The name of the pass to delete shouldn't be nullptr"));
  config->config.pass_builder()->DeletePass(std::string(pass_name));
}
445
}  // extern "C"