inference_api.cc 38.6 KB
Newer Older
F
flame 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/inference_api.h"
16

17
#include <pybind11/numpy.h>
F
flame 已提交
18
#include <pybind11/stl.h>
19

F
flame 已提交
20
#include <cstring>
21
#include <functional>
F
flame 已提交
22
#include <iostream>
23
#include <iterator>
24
#include <map>
25
#include <memory>
F
flame 已提交
26
#include <string>
27
#include <type_traits>
28
#include <unordered_set>
29
#include <utility>
F
flame 已提交
30
#include <vector>
31

F
flame 已提交
32
#include "paddle/fluid/inference/api/analysis_predictor.h"
33
#include "paddle/fluid/inference/api/helper.h"
34
#include "paddle/fluid/inference/api/paddle_infer_contrib.h"
F
flame 已提交
35
#include "paddle/fluid/inference/api/paddle_inference_api.h"
36
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
37
#include "paddle/fluid/inference/utils/io_utils.h"
F
flame 已提交
38

39 40 41 42
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif

F
flame 已提交
43 44
namespace py = pybind11;

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
namespace pybind11 {
namespace detail {

// NumPy type numbers, hard-coded here so a dtype can be looked up by number.
// The float16 value matches what NumPy itself reports:
//   import numpy as np
//   print(np.dtype(np.float16).num)  # 23
constexpr int NPY_FLOAT16_ = 23;
// NOTE(review): not referenced in the visible part of this file; presumably
// NumPy's type number for uint16 -- confirm before relying on it.
constexpr int NPY_UINT16_ = 4;

// Note: float16 is not a builtin C++ type, so this specialization registers
// paddle_infer::float16 with pybind11 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<paddle_infer::float16> {
  // Returns the NumPy dtype obtained from the float16 type number.
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
    return reinterpret_borrow<py::dtype>(ptr);
  }
  // Returns the struct-module style format character for this type.
  static std::string format() {
    // Note: "e" represents float16.
    // Details at:
    // https://docs.python.org/3/library/struct.html#format-characters.
    return "e";
  }
  // Human-readable type name handed to pybind11.
  static constexpr auto name = _("float16");
};

}  // namespace detail
}  // namespace pybind11

F
flame 已提交
75 76
namespace paddle {
namespace pybind {
77 78 79
using paddle::AnalysisPredictor;
using paddle::NativeConfig;
using paddle::NativePaddlePredictor;
F
flame 已提交
80
using paddle::PaddleBuf;
81
using paddle::PaddleDataLayout;
82
using paddle::PaddleDType;
83
using paddle::PaddlePassBuilder;
F
flame 已提交
84 85
using paddle::PaddlePlace;
using paddle::PaddlePredictor;
86 87 88
using paddle::PaddleTensor;
using paddle::PassStrategy;
using paddle::ZeroCopyTensor;
F
flame 已提交
89

90 91
namespace {
void BindPaddleDType(py::module *m);
92
void BindPaddleDataLayout(py::module *m);
93 94 95 96 97 98
void BindPaddleBuf(py::module *m);
void BindPaddleTensor(py::module *m);
void BindPaddlePlace(py::module *m);
void BindPaddlePredictor(py::module *m);
void BindNativeConfig(py::module *m);
void BindNativePredictor(py::module *m);
99
void BindLiteNNAdapterConfig(py::module *m);
100 101
void BindAnalysisConfig(py::module *m);
void BindAnalysisPredictor(py::module *m);
102 103
void BindZeroCopyTensor(py::module *m);
void BindPaddlePassBuilder(py::module *m);
W
Wilber 已提交
104 105 106
void BindPaddleInferPredictor(py::module *m);
void BindPaddleInferTensor(py::module *m);
void BindPredictorPool(py::module *m);
F
flame 已提交
107

108
#ifdef PADDLE_WITH_MKLDNN
109
void BindMkldnnQuantizerConfig(py::module *m);
110
#endif
111 112

template <typename T>
113 114
PaddleBuf PaddleBufCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
115
  PaddleBuf buf(data.size() * sizeof(T));
W
Wilber 已提交
116 117
  std::copy_n(static_cast<const T *>(data.data()),
              data.size(),
118 119 120 121 122
              static_cast<T *>(buf.data()));
  return buf;
}

template <typename T>
123 124 125
void PaddleBufReset(
    PaddleBuf &buf,                                                    // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {  // NOLINT
126
  buf.Resize(data.size() * sizeof(T));
W
Wilber 已提交
127 128
  std::copy_n(static_cast<const T *>(data.data()),
              data.size(),
129 130 131 132 133
              static_cast<T *>(buf.data()));
}

/// Builds a PaddleTensor from a numpy array.
/// \param data C-contiguous numpy array of element type T.
/// \param name Name stored on the tensor.
/// \param lod Level-of-detail information stored on the tensor.
/// \param copy When true the elements are copied into a new PaddleBuf; when
/// false the PaddleBuf is constructed directly over the array's memory.
/// \return The tensor with data, dtype, name, lod and shape filled in.
template <typename T>
PaddleTensor PaddleTensorCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data,
    const std::string name = "",
    const std::vector<std::vector<size_t>> &lod = {},
    bool copy = true) {
  PaddleTensor result;

  if (!copy) {
    // NOTE(review): presumably a non-owning view, so the numpy memory would
    // have to outlive the tensor -- confirm against PaddleBuf.
    result.data = PaddleBuf(data.mutable_data(), data.size() * sizeof(T));
  } else {
    const auto elem_count = data.size();
    PaddleBuf owned(elem_count * sizeof(T));
    const T *src = static_cast<const T *>(data.data());
    std::copy(src, src + elem_count, static_cast<T *>(owned.data()));
    result.data = std::move(owned);
  }

  result.dtype = inference::PaddleTensorGetDType<T>();
  result.name = name;
  result.lod = lod;

  // Mirror the numpy shape, one entry per dimension.
  const auto rank = data.ndim();
  result.shape.resize(rank);
  std::copy(data.shape(), data.shape() + rank, result.shape.begin());

  return result;
}

159
/// Maps a PaddleDType to the equivalent numpy dtype.
/// \param dtype The Paddle element type.
/// \return The numpy dtype describing the same element type.
/// Raises an Unimplemented error (via PADDLE_THROW) for any type not listed
/// in the switch below.
py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
  py::dtype dt;
  switch (dtype) {
    case PaddleDType::INT32:
      dt = py::dtype::of<int32_t>();
      break;
    case PaddleDType::INT64:
      dt = py::dtype::of<int64_t>();
      break;
    case PaddleDType::FLOAT32:
      dt = py::dtype::of<float>();
      break;
    case PaddleDType::UINT8:
      dt = py::dtype::of<uint8_t>();
      break;
    // INT8 is needed here: the tensor-to-numpy converters in this file call
    // this function before their own switch, so without this case their INT8
    // branches could never be reached.
    case PaddleDType::INT8:
      dt = py::dtype::of<int8_t>();
      break;
    case PaddleDType::FLOAT16:
      dt = py::dtype::of<paddle_infer::float16>();
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, UINT8, "
          "INT8, FLOAT32 and FLOAT16."));
  }

  return dt;
}

/// Exposes a PaddleTensor's contents as a numpy array with the tensor's shape
/// and the numpy dtype matching tensor.dtype.
/// \param tensor The source tensor.
/// \return A numpy array constructed from the tensor's data pointer.
py::array PaddleTensorGetData(PaddleTensor &tensor) {  // NOLINT
  py::array::ShapeContainer dims(tensor.shape.begin(), tensor.shape.end());
  // No base object is supplied to the array constructor.
  return py::array(PaddleDTypeToNumpyDType(tensor.dtype),
                   std::move(dims),
                   tensor.data.data());
}

template <typename T>
192 193 194
void ZeroCopyTensorCreate(
    ZeroCopyTensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
195 196 197 198 199 200
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.copy_from_cpu(static_cast<const T *>(data.data()));
}

S
Steffy-zxf 已提交
201 202 203 204 205 206 207 208 209 210 211 212
/// \brief Experimental interface.
/// Fills a Strings tensor from the given text data.
/// \param tensor The tensor to fill; it is reshaped to hold data->size()
/// strings and then receives a copy of them.
/// \param data The input text. It is dereferenced unconditionally.
void ZeroCopyStringTensorCreate(ZeroCopyTensor &tensor,  // NOLINT
                                const paddle_infer::Strings *data) {
  const size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.copy_strings_from_cpu(data);
}

W
Wilber 已提交
213
template <typename T>
214 215 216
void PaddleInferTensorCreate(
    paddle_infer::Tensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
W
Wilber 已提交
217 218 219 220 221 222
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.CopyFromCpu(static_cast<const T *>(data.data()));
}

223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
/// Translates a phi allocation type into a paddle_infer place type.
/// \param allocation_type The allocation type to translate.
/// \return kGPU for GPU allocations; kCPU for CPU and for every other
/// allocation type.
paddle_infer::PlaceType ToPaddleInferPlace(
    phi::AllocationType allocation_type) {
  // Only GPU is distinguished; everything else, CPU included, maps to kCPU.
  const bool on_gpu = (allocation_type == phi::AllocationType::GPU);
  return on_gpu ? paddle_infer::PlaceType::kGPU
                : paddle_infer::PlaceType::kCPU;
}

/// Makes `tensor` share the memory of `input_tensor` instead of copying it.
/// \param tensor The paddle_infer tensor that will reference the data.
/// \param input_tensor The framework tensor providing data, dims and place.
/// Only FLOAT32 and FLOAT16 inputs are supported; any other dtype raises an
/// Unimplemented error rather than silently leaving `tensor` untouched.
void PaddleInferShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
                                  framework::Tensor input_tensor) {
  // Convert the dims into the std::vector<int> shape ShareExternalData takes.
  const auto &dims = input_tensor.dims();
  std::vector<int> shape;
  shape.reserve(dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    shape.push_back(dims[i]);
  }

  if (input_tensor.dtype() == phi::DataType::FLOAT32) {
    tensor.ShareExternalData(
        static_cast<float *>(input_tensor.data()),
        shape,
        ToPaddleInferPlace(input_tensor.place().GetType()));
  } else if (input_tensor.dtype() == phi::DataType::FLOAT16) {
    tensor.ShareExternalData(
        static_cast<paddle::platform::float16 *>(input_tensor.data()),
        shape,
        ToPaddleInferPlace(input_tensor.place().GetType()));
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported data type. Now share_external_data only supports "
        "FLOAT32 and FLOAT16."));
  }
}

S
Steffy-zxf 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265
/// \brief Experimental interface.
/// Fills a Strings tensor from the given text data.
/// \param tensor The tensor to fill; it is reshaped to hold data->size()
/// strings and then receives a copy of them.
/// \param data The input text. It is dereferenced unconditionally.
void PaddleInferStringTensorCreate(paddle_infer::Tensor &tensor,  // NOLINT
                                   const paddle_infer::Strings *data) {
  VLOG(3) << "Create PaddleInferTensor, dtype = Strings ";
  const size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.CopyStringsFromCpu(data);
}

266 267 268 269 270 271 272 273 274 275 276 277 278
/// Returns the size in bytes of one element of the given PaddleDType.
/// \param dt The element type.
/// \return sizeof the matching C++ type for INT32, INT64 and FLOAT32.
/// Raises an Unimplemented error (via PADDLE_THROW) for every other type.
size_t PaddleGetDTypeSize(PaddleDType dt) {
  switch (dt) {
    case PaddleDType::INT32:
      return sizeof(int32_t);
    case PaddleDType::INT64:
      return sizeof(int64_t);
    case PaddleDType::FLOAT32:
      return sizeof(float);
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64 and "
          "FLOAT32."));
  }
  // Not reached: the default case throws.
  return 0;
}

/// Copies a ZeroCopyTensor into a newly allocated numpy array of the same
/// shape and element type.
/// \param tensor The source tensor.
/// \return The numpy array holding a copy of the tensor's data.
/// Raises an Unimplemented error for element types the switch below does not
/// handle. The dtype is first mapped by PaddleDTypeToNumpyDType, which can
/// itself reject a type before the switch runs.
py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) {  // NOLINT
  const PaddleDType dtype = tensor.type();
  py::dtype dt = PaddleDTypeToNumpyDType(dtype);
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  // Dispatch on the element type so copy_to_cpu gets a typed destination.
  switch (dtype) {
    case PaddleDType::INT32:
      tensor.copy_to_cpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.copy_to_cpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.copy_to_cpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.copy_to_cpu<int8_t>(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, UINT8, "
          "INT8, FLOAT32 and FLOAT16."));
  }
  return array;
}
319

W
Wilber 已提交
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
/// Copies a paddle_infer::Tensor into a newly allocated numpy array of the
/// same shape and element type.
/// \param tensor The source tensor.
/// \return The numpy array holding a copy of the tensor's data.
/// Raises an Unimplemented error for element types the switch below does not
/// handle. The dtype is first mapped by PaddleDTypeToNumpyDType, which can
/// itself reject a type before the switch runs.
py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) {  // NOLINT
  const PaddleDType dtype = tensor.type();
  py::dtype dt = PaddleDTypeToNumpyDType(dtype);
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  // Dispatch on the element type so CopyToCpu gets a typed destination.
  switch (dtype) {
    case PaddleDType::INT32:
      tensor.CopyToCpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.CopyToCpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.CopyToCpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, UINT8, "
          "INT8, FLOAT32 and FLOAT16."));
  }
  return array;
}

354 355 356 357 358
/// Serializes a PaddleTensor into a Python bytes object.
/// \param tensor The tensor to serialize.
/// \return The bytes written by SerializePDTensorToStream.
py::bytes SerializePDTensorToBytes(PaddleTensor &tensor) {  // NOLINT
  std::stringstream buffer;
  paddle::inference::SerializePDTensorToStream(&buffer, tensor);
  const std::string serialized = buffer.str();
  return py::bytes(serialized);
}
359

360
void CopyPaddleInferTensor(paddle_infer::Tensor &dst,  // NOLINT
361 362 363 364
                           const paddle_infer::Tensor &src) {
  return paddle_infer::contrib::TensorUtils::CopyTensor(&dst, src);
}

365
}  // namespace
366

F
flame 已提交
367 368
/// Entry point that registers the inference API on the given Python module:
/// first every class/enum binding, then the module-level free functions.
/// \param m The Python module to populate.
void BindInferenceApi(py::module *m) {
  // Type bindings. The order is kept exactly as it was.
  BindPaddleDType(m);
  BindPaddleDataLayout(m);
  BindPaddleBuf(m);
  BindPaddleTensor(m);
  BindPaddlePlace(m);
  BindPaddlePredictor(m);
  BindNativeConfig(m);
  BindNativePredictor(m);
  BindLiteNNAdapterConfig(m);
  BindAnalysisConfig(m);
  BindAnalysisPredictor(m);
  BindPaddleInferPredictor(m);
  BindZeroCopyTensor(m);
  BindPaddleInferTensor(m);
  BindPaddlePassBuilder(m);
  BindPredictorPool(m);
#ifdef PADDLE_WITH_MKLDNN
  BindMkldnnQuantizerConfig(m);
#endif

  py::module &mod = *m;

  // Two overloads of create_paddle_predictor, one per config type.
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<AnalysisConfig>,
          py::arg("config"));
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<NativeConfig>,
          py::arg("config"));

  // Constructs a paddle_infer::Predictor owned by a unique_ptr.
  mod.def("create_predictor",
          [](const paddle_infer::Config &config)
              -> std::unique_ptr<paddle_infer::Predictor> {
            return std::unique_ptr<paddle_infer::Predictor>(
                new paddle_infer::Predictor(config));
          });

  // Tensor utilities and version/dtype queries.
  mod.def("copy_tensor", &CopyPaddleInferTensor);
  mod.def("paddle_dtype_size", &paddle::PaddleDtypeSize);
  mod.def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);
  mod.def("get_version", &paddle_infer::GetVersion);
  mod.def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
  mod.def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
  mod.def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}

409
namespace {
F
flame 已提交
410 411 412
/// Registers the PaddleDType enum on the Python module.
/// \param m The Python module to populate.
/// FLOAT16, UINT8 and INT8 are exposed alongside the original three values
/// because other code in this file handles tensors of those types; without
/// named entries Python would have no name for such a dtype value.
void BindPaddleDType(py::module *m) {
  py::enum_<PaddleDType>(*m, "PaddleDType")
      .value("FLOAT32", PaddleDType::FLOAT32)
      .value("FLOAT16", PaddleDType::FLOAT16)
      .value("INT64", PaddleDType::INT64)
      .value("INT32", PaddleDType::INT32)
      .value("UINT8", PaddleDType::UINT8)
      .value("INT8", PaddleDType::INT8);
}

417 418 419 420 421 422 423 424
/// Registers the PaddleDataLayout enum (UNK, Any, NHWC, NCHW) on the Python
/// module.
/// \param m The Python module to populate.
void BindPaddleDataLayout(py::module *m) {
  py::enum_<PaddleDataLayout> layout(*m, "PaddleDataLayout");
  layout.value("UNK", PaddleDataLayout::kUNK);
  layout.value("Any", PaddleDataLayout::kAny);
  layout.value("NHWC", PaddleDataLayout::kNHWC);
  layout.value("NCHW", PaddleDataLayout::kNCHW);
}

F
flame 已提交
425 426 427 428 429 430
void BindPaddleBuf(py::module *m) {
  py::class_<PaddleBuf>(*m, "PaddleBuf")
      .def(py::init<size_t>())
      .def(py::init([](std::vector<float> &data) {
        auto buf = PaddleBuf(data.size() * sizeof(float));
        std::memcpy(buf.data(), static_cast<void *>(data.data()), buf.length());
G
Gabor Buella 已提交
431
        return buf;
F
flame 已提交
432
      }))
433 434 435
      .def(py::init(&PaddleBufCreate<int32_t>))
      .def(py::init(&PaddleBufCreate<int64_t>))
      .def(py::init(&PaddleBufCreate<float>))
F
flame 已提交
436 437 438 439 440 441
      .def("resize", &PaddleBuf::Resize)
      .def("reset",
           [](PaddleBuf &self, std::vector<float> &data) {
             self.Resize(data.size() * sizeof(float));
             std::memcpy(self.data(), data.data(), self.length());
           })
442 443 444
      .def("reset", &PaddleBufReset<int32_t>)
      .def("reset", &PaddleBufReset<int64_t>)
      .def("reset", &PaddleBufReset<float>)
445
      .def("empty", &PaddleBuf::empty)
446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461
      .def("tolist",
           [](PaddleBuf &self, const std::string &dtype) -> py::list {
             py::list l;
             if (dtype == "int32") {
               auto *data = static_cast<int32_t *>(self.data());
               auto size = self.length() / sizeof(int32_t);
               l = py::cast(std::vector<int32_t>(data, data + size));
             } else if (dtype == "int64") {
               auto *data = static_cast<int64_t *>(self.data());
               auto size = self.length() / sizeof(int64_t);
               l = py::cast(std::vector<int64_t>(data, data + size));
             } else if (dtype == "float32") {
               auto *data = static_cast<float *>(self.data());
               auto size = self.length() / sizeof(float);
               l = py::cast(std::vector<float>(data, data + size));
             } else {
462 463 464
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Unsupported data type. Now only supports INT32, INT64 and "
                   "FLOAT32."));
465 466 467
             }
             return l;
           })
F
flame 已提交
468 469 470 471 472 473 474 475 476 477
      .def("float_data",
           [](PaddleBuf &self) -> std::vector<float> {
             auto *data = static_cast<float *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
           })
      .def("int64_data",
           [](PaddleBuf &self) -> std::vector<int64_t> {
             int64_t *data = static_cast<int64_t *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
           })
478 479 480 481
      .def("int32_data",
           [](PaddleBuf &self) -> std::vector<int32_t> {
             int32_t *data = static_cast<int32_t *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
F
flame 已提交
482 483 484 485 486 487 488
           })
      .def("length", &PaddleBuf::length);
}

/// Registers the PaddleTensor class on the Python module: a default
/// constructor, numpy-based constructors for int32/int64/float, an ndarray
/// accessor and read/write access to the tensor's fields.
/// \param m The Python module to populate.
void BindPaddleTensor(py::module *m) {
  py::class_<PaddleTensor> tensor_cls(*m, "PaddleTensor");

  tensor_cls.def(py::init<>());

  // One numpy-based constructor per element type, with identical keyword
  // arguments and defaults.
  tensor_cls.def(py::init(&PaddleTensorCreate<int32_t>),
                 py::arg("data"),
                 py::arg("name") = "",
                 py::arg("lod") = std::vector<std::vector<size_t>>(),
                 py::arg("copy") = true);
  tensor_cls.def(py::init(&PaddleTensorCreate<int64_t>),
                 py::arg("data"),
                 py::arg("name") = "",
                 py::arg("lod") = std::vector<std::vector<size_t>>(),
                 py::arg("copy") = true);
  tensor_cls.def(py::init(&PaddleTensorCreate<float>),
                 py::arg("data"),
                 py::arg("name") = "",
                 py::arg("lod") = std::vector<std::vector<size_t>>(),
                 py::arg("copy") = true);

  tensor_cls.def("as_ndarray", &PaddleTensorGetData);

  // Data members exposed as read/write attributes.
  tensor_cls.def_readwrite("name", &PaddleTensor::name);
  tensor_cls.def_readwrite("shape", &PaddleTensor::shape);
  tensor_cls.def_readwrite("data", &PaddleTensor::data);
  tensor_cls.def_readwrite("dtype", &PaddleTensor::dtype);
  tensor_cls.def_readwrite("lod", &PaddleTensor::lod);
}

/// Registers the PaddlePlace enum (UNK, CPU, GPU, XPU, NPU) on the Python
/// module.
/// \param m The Python module to populate.
void BindPaddlePlace(py::module *m) {
  py::enum_<PaddlePlace> place(*m, "PaddlePlace");
  place.value("UNK", PaddlePlace::kUNK);
  place.value("CPU", PaddlePlace::kCPU);
  place.value("GPU", PaddlePlace::kGPU);
  place.value("XPU", PaddlePlace::kXPU);
  place.value("NPU", PaddlePlace::kNPU);
}

/// Registers the PaddlePredictor class and its nested Config class on the
/// Python module.
/// \param m The Python module to populate.
void BindPaddlePredictor(py::module *m) {
  py::class_<PaddlePredictor> predictor_cls(*m, "PaddlePredictor");

  // run() takes the inputs and returns the outputs as a new list.
  predictor_cls.def(
      "run",
      [](PaddlePredictor &self, const std::vector<PaddleTensor> &inputs) {
        std::vector<PaddleTensor> results;
        self.Run(inputs, &results);
        return results;
      });
  predictor_cls.def("get_input_tensor", &PaddlePredictor::GetInputTensor);
  predictor_cls.def("get_output_tensor", &PaddlePredictor::GetOutputTensor);
  predictor_cls.def("get_input_names", &PaddlePredictor::GetInputNames);
  predictor_cls.def("get_output_names", &PaddlePredictor::GetOutputNames);
  predictor_cls.def("zero_copy_run", &PaddlePredictor::ZeroCopyRun);
  predictor_cls.def("clone", &PaddlePredictor::Clone);
  predictor_cls.def("get_serialized_program",
                    &PaddlePredictor::GetSerializedProgram);

  // Config is registered in the scope of the predictor class.
  py::class_<PaddlePredictor::Config> config_cls(predictor_cls, "Config");
  config_cls.def(py::init<>());
  config_cls.def_readwrite("model_dir", &PaddlePredictor::Config::model_dir);
}

/// Registers the NativeConfig class (derived from PaddlePredictor::Config)
/// on the Python module.
/// \param m The Python module to populate.
void BindNativeConfig(py::module *m) {
  py::class_<NativeConfig, PaddlePredictor::Config> config_cls(*m,
                                                               "NativeConfig");
  config_cls.def(py::init<>());

  // Data members exposed as read/write attributes.
  config_cls.def_readwrite("use_gpu", &NativeConfig::use_gpu);
  config_cls.def_readwrite("use_xpu", &NativeConfig::use_xpu);
  config_cls.def_readwrite("use_npu", &NativeConfig::use_npu);
  config_cls.def_readwrite("device", &NativeConfig::device);
  config_cls.def_readwrite("fraction_of_gpu_memory",
                           &NativeConfig::fraction_of_gpu_memory);
  config_cls.def_readwrite("prog_file", &NativeConfig::prog_file);
  config_cls.def_readwrite("param_file", &NativeConfig::param_file);
  config_cls.def_readwrite("specify_input_name",
                           &NativeConfig::specify_input_name);

  // CPU math library thread count: setter and getter.
  config_cls.def("set_cpu_math_library_num_threads",
                 &NativeConfig::SetCpuMathLibraryNumThreads);
  config_cls.def("cpu_math_library_num_threads",
                 &NativeConfig::cpu_math_library_num_threads);
}

/// Registers the NativePaddlePredictor class (derived from PaddlePredictor)
/// on the Python module.
/// \param m The Python module to populate.
void BindNativePredictor(py::module *m) {
  py::class_<NativePaddlePredictor, PaddlePredictor> predictor_cls(
      *m, "NativePaddlePredictor");

  predictor_cls.def(py::init<const NativeConfig &>());
  predictor_cls.def("init", &NativePaddlePredictor::Init);

  // run() takes the inputs and returns the outputs as a new list.
  predictor_cls.def("run",
                    [](NativePaddlePredictor &self,
                       const std::vector<PaddleTensor> &inputs) {
                      std::vector<PaddleTensor> results;
                      self.Run(inputs, &results);
                      return results;
                    });
  predictor_cls.def("get_input_tensor",
                    &NativePaddlePredictor::GetInputTensor);
  predictor_cls.def("get_output_tensor",
                    &NativePaddlePredictor::GetOutputTensor);
  predictor_cls.def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun);
  predictor_cls.def("clone", &NativePaddlePredictor::Clone);

  // scope() is returned with the `reference` policy.
  predictor_cls.def("scope",
                    &NativePaddlePredictor::scope,
                    py::return_value_policy::reference);
}

void BindAnalysisConfig(py::module *m) {
583 584 585 586 587
  py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");

  py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
      .value("Float32", AnalysisConfig::Precision::kFloat32)
      .value("Int8", AnalysisConfig::Precision::kInt8)
Z
Zhaolong Xing 已提交
588
      .value("Half", AnalysisConfig::Precision::kHalf)
589 590
      .export_values();

591 592
  analysis_config.def(py::init<>())
      .def(py::init<const AnalysisConfig &>())
F
flame 已提交
593 594
      .def(py::init<const std::string &>())
      .def(py::init<const std::string &, const std::string &>())
595
      .def("summary", &AnalysisConfig::Summary)
W
Wilber 已提交
596 597 598
      .def("set_model",
           (void(AnalysisConfig::*)(const std::string &)) &
               AnalysisConfig::SetModel)
599 600 601
      .def("set_model",
           (void(AnalysisConfig::*)(const std::string &, const std::string &)) &
               AnalysisConfig::SetModel)
F
flame 已提交
602 603 604 605 606
      .def("set_prog_file", &AnalysisConfig::SetProgFile)
      .def("set_params_file", &AnalysisConfig::SetParamsFile)
      .def("model_dir", &AnalysisConfig::model_dir)
      .def("prog_file", &AnalysisConfig::prog_file)
      .def("params_file", &AnalysisConfig::params_file)
W
Wilber 已提交
607 608 609 610 611 612
      .def("enable_use_gpu",
           &AnalysisConfig::EnableUseGpu,
           py::arg("memory_pool_init_size_mb"),
           py::arg("device_id") = 0)
      .def("enable_xpu",
           &AnalysisConfig::EnableXpu,
W
Wilber 已提交
613
           py::arg("l3_workspace_size") = 16 * 1024 * 1024,
W
Wilber 已提交
614 615 616 617
           py::arg("locked") = false,
           py::arg("autotune") = true,
           py::arg("autotune_file") = "",
           py::arg("precision") = "int16",
W
Wilber 已提交
618
           py::arg("adaptive_seqlen") = false)
W
Wilber 已提交
619 620
      .def("set_xpu_device_id",
           &AnalysisConfig::SetXpuDeviceId,
621
           py::arg("device_id") = 0)
W
Wilber 已提交
622
      .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
W
Wilber 已提交
623 624 625 626
      .def("enable_ipu",
           &AnalysisConfig::EnableIpu,
           py::arg("ipu_device_num") = 1,
           py::arg("ipu_micro_batch_size") = 1,
627 628
           py::arg("ipu_enable_pipelining") = false,
           py::arg("ipu_batches_per_step") = 1)
W
Wilber 已提交
629 630 631 632
      .def("set_ipu_config",
           &AnalysisConfig::SetIpuConfig,
           py::arg("ipu_enable_fp16") = false,
           py::arg("ipu_replica_num") = 1,
633 634
           py::arg("ipu_available_memory_proportion") = 1.0,
           py::arg("ipu_enable_half_partial") = false)
F
flame 已提交
635
      .def("disable_gpu", &AnalysisConfig::DisableGpu)
636 637 638 639
      .def("enable_onnxruntime", &AnalysisConfig::EnableONNXRuntime)
      .def("disable_onnxruntime", &AnalysisConfig::DisableONNXRuntime)
      .def("onnxruntime_enabled", &AnalysisConfig::use_onnxruntime)
      .def("enable_ort_optimization", &AnalysisConfig::EnableORTOptimization)
F
flame 已提交
640
      .def("use_gpu", &AnalysisConfig::use_gpu)
641
      .def("use_xpu", &AnalysisConfig::use_xpu)
W
Wilber 已提交
642
      .def("use_npu", &AnalysisConfig::use_npu)
F
flame 已提交
643
      .def("gpu_device_id", &AnalysisConfig::gpu_device_id)
644
      .def("xpu_device_id", &AnalysisConfig::xpu_device_id)
W
Wilber 已提交
645
      .def("npu_device_id", &AnalysisConfig::npu_device_id)
F
flame 已提交
646 647 648 649
      .def("memory_pool_init_size_mb",
           &AnalysisConfig::memory_pool_init_size_mb)
      .def("fraction_of_gpu_memory_for_pool",
           &AnalysisConfig::fraction_of_gpu_memory_for_pool)
W
Wilber 已提交
650 651
      .def("switch_ir_optim",
           &AnalysisConfig::SwitchIrOptim,
F
flame 已提交
652 653
           py::arg("x") = true)
      .def("ir_optim", &AnalysisConfig::ir_optim)
W
Wilber 已提交
654 655
      .def("enable_memory_optim",
           &AnalysisConfig::EnableMemoryOptim,
656
           py::arg("x") = true)
657
      .def("enable_profile", &AnalysisConfig::EnableProfile)
658
      .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo)
659
      .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled)
660
      .def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir)
W
Wilber 已提交
661 662
      .def("switch_use_feed_fetch_ops",
           &AnalysisConfig::SwitchUseFeedFetchOps,
F
flame 已提交
663 664 665 666
           py::arg("x") = true)
      .def("use_feed_fetch_ops_enabled",
           &AnalysisConfig::use_feed_fetch_ops_enabled)
      .def("switch_specify_input_names",
W
Wilber 已提交
667 668
           &AnalysisConfig::SwitchSpecifyInputNames,
           py::arg("x") = true)
F
flame 已提交
669
      .def("specify_input_name", &AnalysisConfig::specify_input_name)
W
Wilber 已提交
670 671 672 673
      .def("enable_tensorrt_engine",
           &AnalysisConfig::EnableTensorRtEngine,
           py::arg("workspace_size") = 1 << 20,
           py::arg("max_batch_size") = 1,
674
           py::arg("min_subgraph_size") = 3,
N
nhzlx 已提交
675
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
W
Wilber 已提交
676 677
           py::arg("use_static") = false,
           py::arg("use_calib_mode") = true)
678
      .def("tensorrt_precision_mode", &AnalysisConfig::tensorrt_precision_mode)
679 680
      .def("set_trt_dynamic_shape_info",
           &AnalysisConfig::SetTRTDynamicShapeInfo,
681 682 683 684 685
           py::arg("min_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("max_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("optim_input_shape") =
686 687
               std::map<std::string, std::vector<int>>({}),
           py::arg("disable_trt_plugin_fp16") = false)
688 689
      .def("tensorrt_dynamic_shape_enabled",
           &AnalysisConfig::tensorrt_dynamic_shape_enabled)
690 691 692
      .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
      .def("tensorrt_varseqlen_enabled",
           &AnalysisConfig::tensorrt_varseqlen_enabled)
693 694 695 696 697 698 699 700 701 702
      .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
      .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
      .def("shape_range_info_collected",
           &AnalysisConfig::shape_range_info_collected)
      .def("enable_tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::EnableTunedTensorRtDynamicShape)
      .def("tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::tuned_tensorrt_dynamic_shape)
      .def("trt_allow_build_at_runtime",
           &AnalysisConfig::trt_allow_build_at_runtime)
703
      .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
W
Wilber 已提交
704 705
      .def("enable_tensorrt_dla",
           &AnalysisConfig::EnableTensorRtDLA,
706 707
           py::arg("dla_core") = 0)
      .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
708 709 710 711
      .def("enable_tensorrt_inspector",
           &AnalysisConfig::EnableTensorRtInspector)
      .def("tensorrt_inspector_enabled",
           &AnalysisConfig::tensorrt_inspector_enabled)
F
flame 已提交
712
      .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
W
Wilber 已提交
713 714
      .def("enable_dlnne",
           &AnalysisConfig::EnableDlnne,
D
denglin-github 已提交
715
           py::arg("min_subgraph_size") = 3)
W
Wilber 已提交
716 717
      .def("enable_lite_engine",
           &AnalysisConfig::EnableLiteEngine,
718
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
W
Wilber 已提交
719
           py::arg("zero_copy") = false,
720 721 722
           py::arg("passes_filter") = std::vector<std::string>(),
           py::arg("ops_filter") = std::vector<std::string>())
      .def("lite_engine_enabled", &AnalysisConfig::lite_engine_enabled)
W
Wilber 已提交
723 724
      .def("switch_ir_debug",
           &AnalysisConfig::SwitchIrDebug,
F
flame 已提交
725 726 727 728 729 730 731 732
           py::arg("x") = true)
      .def("enable_mkldnn", &AnalysisConfig::EnableMKLDNN)
      .def("mkldnn_enabled", &AnalysisConfig::mkldnn_enabled)
      .def("set_cpu_math_library_num_threads",
           &AnalysisConfig::SetCpuMathLibraryNumThreads)
      .def("cpu_math_library_num_threads",
           &AnalysisConfig::cpu_math_library_num_threads)
      .def("to_native_config", &AnalysisConfig::ToNativeConfig)
733
      .def("enable_quantizer", &AnalysisConfig::EnableMkldnnQuantizer)
734
      .def("enable_mkldnn_bfloat16", &AnalysisConfig::EnableMkldnnBfloat16)
735
#ifdef PADDLE_WITH_MKLDNN
W
Wilber 已提交
736 737
      .def("quantizer_config",
           &AnalysisConfig::mkldnn_quantizer_config,
738
           py::return_value_policy::reference)
W
Wilber 已提交
739 740
      .def("set_mkldnn_cache_capacity",
           &AnalysisConfig::SetMkldnnCacheCapacity,
741
           py::arg("capacity") = 0)
742
      .def("set_bfloat16_op", &AnalysisConfig::SetBfloat16Op)
W
Wilber 已提交
743 744
      .def("enable_mkldnn_int8",
           &AnalysisConfig::EnableMkldnnInt8,
B
baoachun 已提交
745 746 747
           py::arg("mkldnn_int8_enabled_op_types") =
               std::unordered_set<std::string>({}))
      .def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
748
#endif
F
flame 已提交
749 750 751
      .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
      .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
      .def("model_from_memory", &AnalysisConfig::model_from_memory)
752 753 754 755
      .def("delete_pass",
           [](AnalysisConfig &self, const std::string &pass) {
             self.pass_builder()->DeletePass(pass);
           })
756 757 758 759 760 761
      .def(
          "pass_builder",
          [](AnalysisConfig &self) {
            return dynamic_cast<PaddlePassBuilder *>(self.pass_builder());
          },
          py::return_value_policy::reference)
762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779
      .def("nnadapter", &AnalysisConfig::NNAdapter)
      .def("set_dist_config", &AnalysisConfig::SetDistConfig)
      .def("dist_config", &AnalysisConfig::dist_config);

  py::class_<DistConfig>(*m, "DistConfig")
      .def(py::init<>())
      .def("set_carrier_id", &DistConfig::SetCarrierId)
      .def("set_comm_init_config", &DistConfig::SetCommInitConfig)
      .def("set_endpoints", &DistConfig::SetEndpoints)
      .def("set_ranks", &DistConfig::SetRanks)
      .def("enable_dist_model", &DistConfig::EnableDistModel)
      .def("carrier_id", &DistConfig::carrier_id)
      .def("current_endpoint", &DistConfig::current_endpoint)
      .def("trainer_endpoints", &DistConfig::trainer_endpoints)
      .def("nranks", &DistConfig::nranks)
      .def("rank", &DistConfig::rank)
      .def("comm_init_config", &DistConfig::comm_init_config)
      .def("use_dist_model", &DistConfig::use_dist_model);
780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797
}

// Registers the Python class "LiteNNAdapterConfig" on module `m`.
//
// Each C++ member of LiteNNAdapterConfig listed below is exposed under a
// snake_case Python name. No py::init is registered in this function, so this
// binding by itself does not add a Python-side constructor.
void BindLiteNNAdapterConfig(py::module *m) {
  py::class_<LiteNNAdapterConfig> lite_nnadapter_config(*m,
                                                        "LiteNNAdapterConfig");

  lite_nnadapter_config
      .def("set_device_names", &LiteNNAdapterConfig::SetDeviceNames)
      .def("set_context_properties", &LiteNNAdapterConfig::SetContextProperties)
      .def("set_model_cache_dir", &LiteNNAdapterConfig::SetModelCacheDir)
      .def("set_model_cache_buffers",
           &LiteNNAdapterConfig::SetModelCacheBuffers)
      .def("set_subgraph_partition_config_path",
           &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath)
      .def("set_subgraph_partition_config_buffer",
           &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer)
      .def("enable", &LiteNNAdapterConfig::Enable)
      .def("disable", &LiteNNAdapterConfig::Disable);
}

800 801 802 803 804 805 806 807 808 809 810 811 812 813 814
#ifdef PADDLE_WITH_MKLDNN
// Registers the Python class "MkldnnQuantizerConfig" on module `m`.
//
// Exposes a copy constructor and a default constructor, plus:
//   set_quant_data(list[PaddleTensor])  -> SetWarmupData
//   set_quant_batch_size                -> SetWarmupBatchSize
//   set_enabled_op_types                -> SetEnabledOpTypes
void BindMkldnnQuantizerConfig(py::module *m) {
  py::class_<MkldnnQuantizerConfig> quantizer_config(*m,
                                                     "MkldnnQuantizerConfig");
  quantizer_config.def(py::init<const MkldnnQuantizerConfig &>())
      .def(py::init<>())
      .def("set_quant_data",
           [](MkldnnQuantizerConfig &self,
              const std::vector<PaddleTensor> &data) {
             // SetWarmupData takes a shared_ptr, so the tensors handed over
             // from Python are copied into a freshly allocated vector first.
             self.SetWarmupData(
                 std::make_shared<std::vector<PaddleTensor>>(data));
           })
      .def("set_quant_batch_size", &MkldnnQuantizerConfig::SetWarmupBatchSize)
      .def("set_enabled_op_types", &MkldnnQuantizerConfig::SetEnabledOpTypes);
}
#endif

F
flame 已提交
819 820 821 822 823 824 825 826 827 828 829 830 831
// Registers the Python class "AnalysisPredictor" (derived from the bound
// PaddlePredictor) on module `m`.
//
// Most methods are direct bindings of the AnalysisPredictor members under
// snake_case names. Exceptions worth noting:
//   * "run" wraps Run(inputs, &outputs) and returns the output vector.
//   * "analysis_argument", "scope" and "program" use
//     return_value_policy::reference, i.e. Python does not take ownership of
//     the returned object.
//   * "SaveOptimModel" keeps its CamelCase Python name and takes the keyword
//     argument "dir".
void BindAnalysisPredictor(py::module *m) {
  py::class_<AnalysisPredictor, PaddlePredictor>(*m, "AnalysisPredictor")
      .def(py::init<const AnalysisConfig &>())
      .def("init", &AnalysisPredictor::Init)
      .def(
          "run",
          [](AnalysisPredictor &self, const std::vector<PaddleTensor> &inputs) {
            std::vector<PaddleTensor> outputs;
            // NOTE(review): whatever Run returns is discarded here, so a
            // failure status (if any) is not surfaced to Python.
            self.Run(inputs, &outputs);
            return outputs;
          })
      .def("get_input_tensor", &AnalysisPredictor::GetInputTensor)
      .def("get_output_tensor", &AnalysisPredictor::GetOutputTensor)
      .def("get_input_names", &AnalysisPredictor::GetInputNames)
      .def("get_output_names", &AnalysisPredictor::GetOutputNames)
      .def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape)
      .def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
      .def("clear_intermediate_tensor",
           &AnalysisPredictor::ClearIntermediateTensor)
      .def("try_shrink_memory", &AnalysisPredictor::TryShrinkMemory)
      .def("create_feed_fetch_var", &AnalysisPredictor::CreateFeedFetchVar)
      .def("prepare_feed_fetch", &AnalysisPredictor::PrepareFeedFetch)
      .def("prepare_argument", &AnalysisPredictor::PrepareArgument)
      .def("optimize_inference_program",
           &AnalysisPredictor::OptimizeInferenceProgram)
      .def("analysis_argument",
           &AnalysisPredictor::analysis_argument,
           py::return_value_policy::reference)
      .def("clone", &AnalysisPredictor::Clone)
      .def("scope",
           &AnalysisPredictor::scope,
           py::return_value_policy::reference)
      .def("program",
           &AnalysisPredictor::program,
           py::return_value_policy::reference)
      .def("get_serialized_program", &AnalysisPredictor::GetSerializedProgram)
      .def("mkldnn_quantize", &AnalysisPredictor::MkldnnQuantize)
      .def(
          "SaveOptimModel", &AnalysisPredictor::SaveOptimModel, py::arg("dir"));
}
859

W
Wilber 已提交
860 861 862 863 864 865 866
// Registers the Python class "PaddleInferPredictor" for
// paddle_infer::Predictor on module `m`.
//
// The predictor is constructed from a paddle_infer::Config. All methods are
// direct member bindings except "run", which is wrapped in a lambda so the
// GIL can be released around the call in PADDLE_WITH_ASCEND_CL builds.
void BindPaddleInferPredictor(py::module *m) {
  py::class_<paddle_infer::Predictor>(*m, "PaddleInferPredictor")
      .def(py::init<const paddle_infer::Config &>())
      .def("get_input_names", &paddle_infer::Predictor::GetInputNames)
      .def("get_output_names", &paddle_infer::Predictor::GetOutputNames)
      .def("get_input_handle", &paddle_infer::Predictor::GetInputHandle)
      .def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle)
      .def("run",
           [](paddle_infer::Predictor &self) {
#ifdef PADDLE_WITH_ASCEND_CL
             // Only this build configuration drops the GIL while Run()
             // executes; in every other build the GIL stays held.
             pybind11::gil_scoped_release release;
#endif
             self.Run();
           })
      .def("clone", &paddle_infer::Predictor::Clone)
      .def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory)
      .def("clear_intermediate_tensor",
           &paddle_infer::Predictor::ClearIntermediateTensor);
}

880 881
// Registers the Python class "ZeroCopyTensor" on module `m`.
//
// Overload sets:
//   reshape        - a std::vector<int> shape (ZeroCopyTensor::Reshape) or a
//                    single std::size_t (paddle_infer::Tensor::ReshapeStrings).
//   copy_from_cpu  - int32, int64, float and float16 variants of
//                    ZeroCopyTensorCreate, plus ZeroCopyStringTensorCreate.
// pybind11 tries overloads in registration order, so the order below is
// significant and must be preserved.
void BindZeroCopyTensor(py::module *m) {
  py::class_<ZeroCopyTensor>(*m, "ZeroCopyTensor")
      .def(
          "reshape",
          py::overload_cast<const std::vector<int> &>(&ZeroCopyTensor::Reshape))
      .def("reshape",
           py::overload_cast<const std::size_t &>(
               &paddle_infer::Tensor::ReshapeStrings))
      .def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<float>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<paddle_infer::float16>)
      .def("copy_from_cpu", &ZeroCopyStringTensorCreate)
      .def("copy_to_cpu", &ZeroCopyTensorToNumpy)
      .def("shape", &ZeroCopyTensor::shape)
      .def("set_lod", &ZeroCopyTensor::SetLoD)
      .def("lod", &ZeroCopyTensor::lod)
      .def("type", &ZeroCopyTensor::type);
}

W
Wilber 已提交
900 901
// Registers the Python class "PaddleInferTensor" for paddle_infer::Tensor on
// module `m`.
//
// Overload sets:
//   reshape             - a std::vector<int> shape (Tensor::Reshape) or a
//                         single std::size_t (Tensor::ReshapeStrings).
//   copy_from_cpu_bind  - int32, int64, float and float16 variants of
//                         PaddleInferTensorCreate, plus
//                         PaddleInferStringTensorCreate.
// pybind11 tries overloads in registration order, so the order below is
// significant and must be preserved.
void BindPaddleInferTensor(py::module *m) {
  py::class_<paddle_infer::Tensor>(*m, "PaddleInferTensor")
      .def("reshape",
           py::overload_cast<const std::vector<int> &>(
               &paddle_infer::Tensor::Reshape))
      .def("reshape",
           py::overload_cast<const std::size_t &>(
               &paddle_infer::Tensor::ReshapeStrings))
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
      .def("copy_from_cpu_bind",
           &PaddleInferTensorCreate<paddle_infer::float16>)
      .def("copy_from_cpu_bind", &PaddleInferStringTensorCreate)
      .def("share_external_data_bind", &PaddleInferShareExternalData)
      .def("copy_to_cpu", &PaddleInferTensorToNumpy)
      .def("shape", &paddle_infer::Tensor::shape)
      .def("set_lod", &paddle_infer::Tensor::SetLoD)
      .def("lod", &paddle_infer::Tensor::lod)
      .def("type", &paddle_infer::Tensor::type);
}

// Registers the Python class "PredictorPool" for
// paddle_infer::services::PredictorPool on module `m`.
//
// The pool is constructed from a paddle_infer::Config and a size_t.
// "retrive" (spelling matches the C++ member Retrive and is part of the
// Python API, so it must not be "corrected" here) is bound with
// return_value_policy::reference: Python does not take ownership of the
// returned predictor.
void BindPredictorPool(py::module *m) {
  py::class_<paddle_infer::services::PredictorPool>(*m, "PredictorPool")
      .def(py::init<const paddle_infer::Config &, size_t>())
      .def("retrive",
           &paddle_infer::services::PredictorPool::Retrive,
           py::return_value_policy::reference);
}

930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948
// Registers the pass-builder class hierarchy on module `m`:
//
//   PaddlePassBuilder
//     └── PassStrategy
//           ├── CpuPassStrategy
//           └── GpuPassStrategy
//
// PaddlePassBuilder and PassStrategy are constructed from a list of pass
// names; the Cpu/Gpu strategies expose a default and a copy constructor.
void BindPaddlePassBuilder(py::module *m) {
  py::class_<PaddlePassBuilder>(*m, "PaddlePassBuilder")
      .def(py::init<const std::vector<std::string> &>())
      .def("set_passes",
           [](PaddlePassBuilder &self, const std::vector<std::string> &passes) {
             // Replace the current pass list wholesale, keeping the order
             // given by the caller.
             self.ClearPasses();
             for (const auto &pass : passes) {
               self.AppendPass(pass);
             }
           })
      .def("append_pass", &PaddlePassBuilder::AppendPass)
      .def("insert_pass", &PaddlePassBuilder::InsertPass)
      .def("delete_pass",
           [](PaddlePassBuilder &self, const std::string &pass_type) {
             self.DeletePass(pass_type);
           })
      .def("append_analysis_pass", &PaddlePassBuilder::AppendAnalysisPass)
      .def("turn_on_debug", &PaddlePassBuilder::TurnOnDebug)
      .def("debug_string", &PaddlePassBuilder::DebugString)
      .def("all_passes",
           &PaddlePassBuilder::AllPasses,
           py::return_value_policy::reference)
      .def("analysis_passes", &PaddlePassBuilder::AnalysisPasses);

  py::class_<PassStrategy, PaddlePassBuilder>(*m, "PassStrategy")
      .def(py::init<const std::vector<std::string> &>())
      .def("enable_cudnn", &PassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &PassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &PassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &PassStrategy::EnableMkldnnBfloat16)
      .def("use_gpu", &PassStrategy::use_gpu);

  py::class_<CpuPassStrategy, PassStrategy>(*m, "CpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const CpuPassStrategy &>())
      .def("enable_cudnn", &CpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &CpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &CpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &CpuPassStrategy::EnableMkldnnBfloat16);

  py::class_<GpuPassStrategy, PassStrategy>(*m, "GpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const GpuPassStrategy &>())
      .def("enable_cudnn", &GpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &GpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &GpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &GpuPassStrategy::EnableMkldnnBfloat16);
}
978
}  // namespace
F
flame 已提交
979 980
}  // namespace pybind
}  // namespace paddle