inference_api.cc 38.3 KB
Newer Older
F
flame 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/inference_api.h"
16

17
#include <pybind11/numpy.h>
F
flame 已提交
18
#include <pybind11/stl.h>
19

F
flame 已提交
20
#include <cstring>
21
#include <functional>
F
flame 已提交
22
#include <iostream>
23
#include <iterator>
24
#include <map>
25
#include <memory>
F
flame 已提交
26
#include <string>
27
#include <type_traits>
28
#include <unordered_set>
29
#include <utility>
F
flame 已提交
30
#include <vector>
31

F
flame 已提交
32
#include "paddle/fluid/inference/api/analysis_predictor.h"
33
#include "paddle/fluid/inference/api/helper.h"
34
#include "paddle/fluid/inference/api/paddle_infer_contrib.h"
F
flame 已提交
35
#include "paddle/fluid/inference/api/paddle_inference_api.h"
36
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
37
#include "paddle/fluid/inference/utils/io_utils.h"
F
flame 已提交
38

39 40 41 42
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif

F
flame 已提交
43 44
namespace py = pybind11;

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
namespace pybind11 {
namespace detail {

// NumPy type numbers used below to look up dtype descriptors by number.
// Note: use same enum number of float16 in numpy.
// import numpy as np
// print np.dtype(np.float16).num  # 23
constexpr int NPY_FLOAT16_ = 23;
// Presumably the NumPy type number of uint16 (np.dtype(np.uint16).num);
// it is not referenced in the visible part of this file.
constexpr int NPY_UINT16_ = 4;

// Note: Since float16 is not a builtin type in C++, we register
// paddle_infer::float16 as numpy.float16 by specializing pybind11's
// npy_format_descriptor for it.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<paddle_infer::float16> {
  // Returns the numpy dtype object looked up by the float16 type number.
  // NOTE(review): the handle is wrapped with reinterpret_borrow; confirm
  // against the NumPy C-API whether PyArray_DescrFromType hands back a new
  // reference (in which case one reference is intentionally never released).
  static py::dtype dtype() {
    handle ptr = npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_);
    return reinterpret_borrow<py::dtype>(ptr);
  }
  // Returns the struct-module style format character for this type.
  static std::string format() {
    // Note: "e" represents float16.
    // Details at:
    // https://docs.python.org/3/library/struct.html#format-characters.
    return "e";
  }
  // Human-readable type name used by pybind11 in generated signatures.
  static constexpr auto name = _("float16");
};

}  // namespace detail
}  // namespace pybind11

F
flame 已提交
75 76
namespace paddle {
namespace pybind {
77 78 79
using paddle::AnalysisPredictor;
using paddle::NativeConfig;
using paddle::NativePaddlePredictor;
F
flame 已提交
80
using paddle::PaddleBuf;
81
using paddle::PaddleDataLayout;
82
using paddle::PaddleDType;
83
using paddle::PaddlePassBuilder;
F
flame 已提交
84 85
using paddle::PaddlePlace;
using paddle::PaddlePredictor;
86 87 88
using paddle::PaddleTensor;
using paddle::PassStrategy;
using paddle::ZeroCopyTensor;
F
flame 已提交
89

90 91
namespace {
void BindPaddleDType(py::module *m);
92
void BindPaddleDataLayout(py::module *m);
93 94 95 96 97 98
void BindPaddleBuf(py::module *m);
void BindPaddleTensor(py::module *m);
void BindPaddlePlace(py::module *m);
void BindPaddlePredictor(py::module *m);
void BindNativeConfig(py::module *m);
void BindNativePredictor(py::module *m);
99
void BindLiteNNAdapterConfig(py::module *m);
100 101
void BindAnalysisConfig(py::module *m);
void BindAnalysisPredictor(py::module *m);
102 103
void BindZeroCopyTensor(py::module *m);
void BindPaddlePassBuilder(py::module *m);
W
Wilber 已提交
104 105 106
void BindPaddleInferPredictor(py::module *m);
void BindPaddleInferTensor(py::module *m);
void BindPredictorPool(py::module *m);
F
flame 已提交
107

108
#ifdef PADDLE_WITH_MKLDNN
109
void BindMkldnnQuantizerConfig(py::module *m);
110
#endif
111 112

template <typename T>
113 114
PaddleBuf PaddleBufCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
115
  PaddleBuf buf(data.size() * sizeof(T));
116
  std::copy_n(static_cast<const T *>(data.data()), data.size(),
117 118 119 120 121
              static_cast<T *>(buf.data()));
  return buf;
}

template <typename T>
122 123 124
void PaddleBufReset(
    PaddleBuf &buf,                                                    // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {  // NOLINT
125
  buf.Resize(data.size() * sizeof(T));
126
  std::copy_n(static_cast<const T *>(data.data()), data.size(),
127 128 129 130 131
              static_cast<T *>(buf.data()));
}

/// \brief Build a PaddleTensor from a numpy array.
/// \param data C-contiguous numpy array (force-cast to element type T).
/// \param name Name stored on the tensor.
/// \param lod  Level-of-detail information stored on the tensor.
/// \param copy When true the elements are copied into a freshly sized
///             PaddleBuf; when false the PaddleBuf is constructed from the
///             array's mutable data pointer and byte length (presumably
///             wrapping that memory without owning it -- confirm against
///             PaddleBuf).
/// \return The tensor with data, dtype, name, lod and shape filled in.
template <typename T>
PaddleTensor PaddleTensorCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data,
    const std::string name = "",
    const std::vector<std::vector<size_t>> &lod = {}, bool copy = true) {
  const size_t num_bytes = data.size() * sizeof(T);
  PaddleTensor result;

  if (!copy) {
    result.data = PaddleBuf(data.mutable_data(), num_bytes);
  } else {
    PaddleBuf owned(num_bytes);
    const T *src = static_cast<const T *>(data.data());
    std::copy_n(src, data.size(), static_cast<T *>(owned.data()));
    result.data = std::move(owned);
  }

  result.dtype = inference::PaddleTensorGetDType<T>();
  result.name = name;
  result.lod = lod;

  // Mirror the numpy shape, one entry per dimension.
  const auto *dims = data.shape();
  result.shape.assign(dims, dims + data.ndim());

  return result;
}

155
/// \brief Map a PaddleDType to the matching numpy dtype.
/// \param dtype The Paddle element type.
/// \return The numpy dtype for INT32, INT64, FLOAT32, FLOAT16, UINT8 or INT8.
/// \throws An Unimplemented error (via PADDLE_THROW) for any other type.
py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
  py::dtype dt;
  switch (dtype) {
    case PaddleDType::INT32:
      dt = py::dtype::of<int32_t>();
      break;
    case PaddleDType::INT64:
      dt = py::dtype::of<int64_t>();
      break;
    case PaddleDType::FLOAT32:
      dt = py::dtype::of<float>();
      break;
    case PaddleDType::UINT8:
      dt = py::dtype::of<uint8_t>();
      break;
    case PaddleDType::INT8:
      // The tensor-to-numpy helpers in this file have an INT8 branch; without
      // this case they could never reach it because this lookup runs first.
      dt = py::dtype::of<int8_t>();
      break;
    case PaddleDType::FLOAT16:
      dt = py::dtype::of<paddle_infer::float16>();
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }

  return dt;
}

/// \brief Expose a PaddleTensor's contents as a numpy array.
/// \param tensor Tensor whose dtype, shape and data pointer are used.
/// \return A py::array built from the tensor's dtype, shape and data pointer
///         (no base object is passed, so pybind11 presumably copies the
///         data -- confirm against the pybind11 array documentation).
py::array PaddleTensorGetData(PaddleTensor &tensor) {  // NOLINT
  py::array::ShapeContainer dims(tensor.shape.begin(), tensor.shape.end());
  return py::array(PaddleDTypeToNumpyDType(tensor.dtype), std::move(dims),
                   tensor.data.data());
}

template <typename T>
188 189 190
void ZeroCopyTensorCreate(
    ZeroCopyTensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
191 192 193 194 195 196
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.copy_from_cpu(static_cast<const T *>(data.data()));
}

S
Steffy-zxf 已提交
197 198 199 200 201 202 203 204 205 206 207 208
/// \brief Experimental interface.
/// Fill a Strings tensor from the given text data.
/// \param tensor The tensor to populate; it is reshaped to hold as many
/// entries as \p data contains and then receives a copy of \p data.
/// \param data The input text.
void ZeroCopyStringTensorCreate(ZeroCopyTensor &tensor,  // NOLINT
                                const paddle_infer::Strings *data) {
  size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.copy_strings_from_cpu(data);
}

W
Wilber 已提交
209
template <typename T>
210 211 212
void PaddleInferTensorCreate(
    paddle_infer::Tensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
W
Wilber 已提交
213 214 215 216 217 218
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.CopyFromCpu(static_cast<const T *>(data.data()));
}

219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246
/// \brief Translate a phi allocation type into a paddle_infer place.
/// \param allocation_type Where the tensor memory lives.
/// \return kGPU for GPU allocations; kCPU for CPU and for every other
///         allocation type.
paddle_infer::PlaceType ToPaddleInferPlace(
    phi::AllocationType allocation_type) {
  // Only GPU maps to a distinct place; everything else is treated as CPU.
  return allocation_type == phi::AllocationType::GPU
             ? paddle_infer::PlaceType::kGPU
             : paddle_infer::PlaceType::kCPU;
}

/// \brief Make a paddle_infer::Tensor share the memory of a framework tensor.
/// \param tensor       Inference tensor that will reference the external data.
/// \param input_tensor Source tensor providing the shape, dtype, place and
///                     data pointer.
/// \throws An Unimplemented error (via PADDLE_THROW) when the source dtype is
///         neither FLOAT32 nor FLOAT16. Previously such inputs were silently
///         ignored, leaving \p tensor unshared with no indication of failure.
void PaddleInferShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
                                  framework::Tensor input_tensor) {
  const auto &dims = input_tensor.dims();
  std::vector<int> shape;
  shape.reserve(dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    shape.push_back(static_cast<int>(dims[i]));
  }

  const auto dtype = input_tensor.dtype();
  const auto place = ToPaddleInferPlace(input_tensor.place().GetType());
  if (dtype == phi::DataType::FLOAT32) {
    tensor.ShareExternalData(static_cast<float *>(input_tensor.data()), shape,
                             place);
  } else if (dtype == phi::DataType::FLOAT16) {
    tensor.ShareExternalData(
        static_cast<paddle::platform::float16 *>(input_tensor.data()), shape,
        place);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported data type. Now ShareExternalData only supports FLOAT32 "
        "and FLOAT16."));
  }
}

S
Steffy-zxf 已提交
247 248 249 250 251 252 253 254 255 256 257 258 259
/// \brief Experimental interface.
/// Fill a Strings tensor from the given text data.
/// \param tensor The tensor to populate; it is reshaped to hold as many
/// entries as \p data contains and then receives a copy of \p data.
/// \param data The input text.
void PaddleInferStringTensorCreate(paddle_infer::Tensor &tensor,  // NOLINT
                                   const paddle_infer::Strings *data) {
  VLOG(3) << "Create PaddleInferTensor, dtype = Strings ";
  size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.CopyStringsFromCpu(data);
}

260 261 262 263 264 265 266 267 268 269 270 271 272
/// \brief Size in bytes of one element of the given Paddle data type.
/// \param dt The Paddle element type.
/// \return sizeof the matching C++ type for INT32, INT64 and FLOAT32.
/// \throws An Unimplemented error (via PADDLE_THROW) for any other type.
size_t PaddleGetDTypeSize(PaddleDType dt) {
  size_t num_bytes = 0;
  if (dt == PaddleDType::INT32) {
    num_bytes = sizeof(int32_t);
  } else if (dt == PaddleDType::INT64) {
    num_bytes = sizeof(int64_t);
  } else if (dt == PaddleDType::FLOAT32) {
    num_bytes = sizeof(float);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported data type. Now only supports INT32, INT64 and "
        "FLOAT32."));
  }
  return num_bytes;
}

/// \brief Copy the contents of a ZeroCopyTensor into a new numpy array.
/// \param tensor Source tensor; its type and shape determine the array's
///               dtype and dimensions.
/// \return A freshly allocated py::array that the tensor was asked to copy
///         its data into.
/// \throws An Unimplemented error (via PADDLE_THROW) for element types other
///         than INT32, INT64, FLOAT32, FLOAT16, UINT8 and INT8.
py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) {  // NOLINT
  const PaddleDType tensor_type = tensor.type();
  py::dtype dt = PaddleDTypeToNumpyDType(tensor_type);
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  // Dispatch on the element type so copy_to_cpu receives a typed pointer.
  switch (tensor_type) {
    case PaddleDType::INT32:
      tensor.copy_to_cpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.copy_to_cpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.copy_to_cpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.copy_to_cpu<int8_t>(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }
  return array;
}
313

W
Wilber 已提交
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
/// \brief Copy the contents of a paddle_infer::Tensor into a new numpy array.
/// \param tensor Source tensor; its type and shape determine the array's
///               dtype and dimensions.
/// \return A freshly allocated py::array that the tensor was asked to copy
///         its data into.
/// \throws An Unimplemented error (via PADDLE_THROW) for element types other
///         than INT32, INT64, FLOAT32, FLOAT16, UINT8 and INT8.
py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) {  // NOLINT
  const PaddleDType tensor_type = tensor.type();
  py::dtype dt = PaddleDTypeToNumpyDType(tensor_type);
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  // Dispatch on the element type so CopyToCpu receives a typed pointer.
  switch (tensor_type) {
    case PaddleDType::INT32:
      tensor.CopyToCpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.CopyToCpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.CopyToCpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }
  return array;
}

348 349 350 351 352
/// \brief Serialize a PaddleTensor and return the result as Python bytes.
/// \param tensor Tensor passed to SerializePDTensorToStream.
/// \return The bytes written to the stream by the serializer.
py::bytes SerializePDTensorToBytes(PaddleTensor &tensor) {  // NOLINT
  std::stringstream stream;
  paddle::inference::SerializePDTensorToStream(&stream, tensor);
  const std::string serialized = stream.str();
  return py::bytes(serialized);
}
353

354
void CopyPaddleInferTensor(paddle_infer::Tensor &dst,  // NOLINT
355 356 357 358
                           const paddle_infer::Tensor &src) {
  return paddle_infer::contrib::TensorUtils::CopyTensor(&dst, src);
}

359
}  // namespace
360

F
flame 已提交
361 362
/// \brief Register the whole inference API on the given pybind11 module.
/// Runs the per-type binding helpers in a fixed order and then adds the
/// module-level free functions.
/// \param m The Python module to populate.
void BindInferenceApi(py::module *m) {
  // Type registrations. The order is kept exactly as-is: classes that derive
  // from or reference other bound classes rely on those being registered.
  BindPaddleDType(m);
  BindPaddleDataLayout(m);
  BindPaddleBuf(m);
  BindPaddleTensor(m);
  BindPaddlePlace(m);
  BindPaddlePredictor(m);
  BindNativeConfig(m);
  BindNativePredictor(m);
  BindLiteNNAdapterConfig(m);
  BindAnalysisConfig(m);
  BindAnalysisPredictor(m);
  BindPaddleInferPredictor(m);
  BindZeroCopyTensor(m);
  BindPaddleInferTensor(m);
  BindPaddlePassBuilder(m);
  BindPredictorPool(m);
#ifdef PADDLE_WITH_MKLDNN
  BindMkldnnQuantizerConfig(m);
#endif

  py::module &mod = *m;

  // Predictor factories: two overloads of create_paddle_predictor, one per
  // configuration type, plus the paddle_infer factory.
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<AnalysisConfig>, py::arg("config"));
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<NativeConfig>, py::arg("config"));
  mod.def("create_predictor", [](const paddle_infer::Config &config) {
    return std::unique_ptr<paddle_infer::Predictor>(
        new paddle_infer::Predictor(config));
  });

  // Tensor utilities.
  mod.def("copy_tensor", &CopyPaddleInferTensor);
  mod.def("paddle_dtype_size", &paddle::PaddleDtypeSize);
  mod.def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);

  // Version and data-type queries.
  mod.def("get_version", &paddle_infer::GetVersion);
  mod.def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
  mod.def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
  mod.def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);
}

401
namespace {
F
flame 已提交
402 403 404
/// \brief Expose the PaddleDType enumeration to Python as "PaddleDType".
/// Besides FLOAT32/INT64/INT32, the FLOAT16, UINT8 and INT8 members are also
/// registered: the conversion helpers in this file handle tensors of those
/// types, so Python code needs named values to compare a tensor's type
/// against.
/// \param m The Python module to populate.
void BindPaddleDType(py::module *m) {
  py::enum_<PaddleDType>(*m, "PaddleDType")
      .value("FLOAT32", PaddleDType::FLOAT32)
      .value("INT64", PaddleDType::INT64)
      .value("INT32", PaddleDType::INT32)
      .value("FLOAT16", PaddleDType::FLOAT16)
      .value("UINT8", PaddleDType::UINT8)
      .value("INT8", PaddleDType::INT8);
}

409 410 411 412 413 414 415 416
/// \brief Expose the PaddleDataLayout enumeration to Python.
/// \param m The Python module to populate.
void BindPaddleDataLayout(py::module *m) {
  py::enum_<PaddleDataLayout> layout_enum(*m, "PaddleDataLayout");
  layout_enum.value("UNK", PaddleDataLayout::kUNK);
  layout_enum.value("Any", PaddleDataLayout::kAny);
  layout_enum.value("NHWC", PaddleDataLayout::kNHWC);
  layout_enum.value("NCHW", PaddleDataLayout::kNCHW);
}

F
flame 已提交
417 418 419 420 421 422
void BindPaddleBuf(py::module *m) {
  py::class_<PaddleBuf>(*m, "PaddleBuf")
      .def(py::init<size_t>())
      .def(py::init([](std::vector<float> &data) {
        auto buf = PaddleBuf(data.size() * sizeof(float));
        std::memcpy(buf.data(), static_cast<void *>(data.data()), buf.length());
G
Gabor Buella 已提交
423
        return buf;
F
flame 已提交
424
      }))
425 426 427
      .def(py::init(&PaddleBufCreate<int32_t>))
      .def(py::init(&PaddleBufCreate<int64_t>))
      .def(py::init(&PaddleBufCreate<float>))
F
flame 已提交
428 429 430 431 432 433
      .def("resize", &PaddleBuf::Resize)
      .def("reset",
           [](PaddleBuf &self, std::vector<float> &data) {
             self.Resize(data.size() * sizeof(float));
             std::memcpy(self.data(), data.data(), self.length());
           })
434 435 436
      .def("reset", &PaddleBufReset<int32_t>)
      .def("reset", &PaddleBufReset<int64_t>)
      .def("reset", &PaddleBufReset<float>)
437
      .def("empty", &PaddleBuf::empty)
438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453
      .def("tolist",
           [](PaddleBuf &self, const std::string &dtype) -> py::list {
             py::list l;
             if (dtype == "int32") {
               auto *data = static_cast<int32_t *>(self.data());
               auto size = self.length() / sizeof(int32_t);
               l = py::cast(std::vector<int32_t>(data, data + size));
             } else if (dtype == "int64") {
               auto *data = static_cast<int64_t *>(self.data());
               auto size = self.length() / sizeof(int64_t);
               l = py::cast(std::vector<int64_t>(data, data + size));
             } else if (dtype == "float32") {
               auto *data = static_cast<float *>(self.data());
               auto size = self.length() / sizeof(float);
               l = py::cast(std::vector<float>(data, data + size));
             } else {
454 455 456
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Unsupported data type. Now only supports INT32, INT64 and "
                   "FLOAT32."));
457 458 459
             }
             return l;
           })
F
flame 已提交
460 461 462 463 464 465 466 467 468 469
      .def("float_data",
           [](PaddleBuf &self) -> std::vector<float> {
             auto *data = static_cast<float *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
           })
      .def("int64_data",
           [](PaddleBuf &self) -> std::vector<int64_t> {
             int64_t *data = static_cast<int64_t *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
           })
470 471 472 473
      .def("int32_data",
           [](PaddleBuf &self) -> std::vector<int32_t> {
             int32_t *data = static_cast<int32_t *>(self.data());
             return {data, data + self.length() / sizeof(*data)};
F
flame 已提交
474 475 476 477 478 479 480
           })
      .def("length", &PaddleBuf::length);
}

/// \brief Expose PaddleTensor to Python as "PaddleTensor".
/// Registers a default constructor, constructors from int32/int64/float numpy
/// arrays (with optional name, lod and copy arguments), as_ndarray, and
/// read/write access to the tensor's fields.
/// \param m The Python module to populate.
void BindPaddleTensor(py::module *m) {
  using LoD = std::vector<std::vector<size_t>>;

  py::class_<PaddleTensor> tensor_class(*m, "PaddleTensor");

  // Constructors, one numpy-backed overload per supported element type.
  tensor_class.def(py::init<>());
  tensor_class.def(py::init(&PaddleTensorCreate<int32_t>), py::arg("data"),
                   py::arg("name") = "", py::arg("lod") = LoD(),
                   py::arg("copy") = true);
  tensor_class.def(py::init(&PaddleTensorCreate<int64_t>), py::arg("data"),
                   py::arg("name") = "", py::arg("lod") = LoD(),
                   py::arg("copy") = true);
  tensor_class.def(py::init(&PaddleTensorCreate<float>), py::arg("data"),
                   py::arg("name") = "", py::arg("lod") = LoD(),
                   py::arg("copy") = true);

  tensor_class.def("as_ndarray", &PaddleTensorGetData);

  // Plain data members.
  tensor_class.def_readwrite("name", &PaddleTensor::name);
  tensor_class.def_readwrite("shape", &PaddleTensor::shape);
  tensor_class.def_readwrite("data", &PaddleTensor::data);
  tensor_class.def_readwrite("dtype", &PaddleTensor::dtype);
  tensor_class.def_readwrite("lod", &PaddleTensor::lod);
}

/// \brief Expose the PaddlePlace enumeration to Python.
/// \param m The Python module to populate.
void BindPaddlePlace(py::module *m) {
  py::enum_<PaddlePlace> place_enum(*m, "PaddlePlace");
  place_enum.value("UNK", PaddlePlace::kUNK);
  place_enum.value("CPU", PaddlePlace::kCPU);
  place_enum.value("GPU", PaddlePlace::kGPU);
  place_enum.value("XPU", PaddlePlace::kXPU);
  place_enum.value("NPU", PaddlePlace::kNPU);
}

/// \brief Expose PaddlePredictor and its nested Config class to Python.
/// \param m The Python module to populate.
void BindPaddlePredictor(py::module *m) {
  py::class_<PaddlePredictor> predictor_class(*m, "PaddlePredictor");

  // run() collects the outputs into a fresh vector and returns it, so Python
  // callers do not have to pass an output argument.
  predictor_class.def(
      "run",
      [](PaddlePredictor &self, const std::vector<PaddleTensor> &inputs) {
        std::vector<PaddleTensor> outputs;
        self.Run(inputs, &outputs);
        return outputs;
      });
  predictor_class.def("get_input_tensor", &PaddlePredictor::GetInputTensor);
  predictor_class.def("get_output_tensor", &PaddlePredictor::GetOutputTensor);
  predictor_class.def("get_input_names", &PaddlePredictor::GetInputNames);
  predictor_class.def("get_output_names", &PaddlePredictor::GetOutputNames);
  predictor_class.def("zero_copy_run", &PaddlePredictor::ZeroCopyRun);
  predictor_class.def("clone", &PaddlePredictor::Clone);
  predictor_class.def("get_serialized_program",
                      &PaddlePredictor::GetSerializedProgram);

  // Nested class: PaddlePredictor.Config.
  py::class_<PaddlePredictor::Config> config_class(predictor_class, "Config");
  config_class.def(py::init<>());
  config_class.def_readwrite("model_dir", &PaddlePredictor::Config::model_dir);
}

/// \brief Expose NativeConfig (derived from PaddlePredictor::Config) to
/// Python as "NativeConfig".
/// \param m The Python module to populate.
void BindNativeConfig(py::module *m) {
  py::class_<NativeConfig, PaddlePredictor::Config> config_class(
      *m, "NativeConfig");
  config_class.def(py::init<>());

  // Device selection fields.
  config_class.def_readwrite("use_gpu", &NativeConfig::use_gpu);
  config_class.def_readwrite("use_xpu", &NativeConfig::use_xpu);
  config_class.def_readwrite("use_npu", &NativeConfig::use_npu);
  config_class.def_readwrite("device", &NativeConfig::device);
  config_class.def_readwrite("fraction_of_gpu_memory",
                             &NativeConfig::fraction_of_gpu_memory);

  // Model location fields.
  config_class.def_readwrite("prog_file", &NativeConfig::prog_file);
  config_class.def_readwrite("param_file", &NativeConfig::param_file);
  config_class.def_readwrite("specify_input_name",
                             &NativeConfig::specify_input_name);

  // CPU math library thread count: setter and getter.
  config_class.def("set_cpu_math_library_num_threads",
                   &NativeConfig::SetCpuMathLibraryNumThreads);
  config_class.def("cpu_math_library_num_threads",
                   &NativeConfig::cpu_math_library_num_threads);
}

/// \brief Expose NativePaddlePredictor (derived from PaddlePredictor) to
/// Python as "NativePaddlePredictor".
/// \param m The Python module to populate.
void BindNativePredictor(py::module *m) {
  py::class_<NativePaddlePredictor, PaddlePredictor> predictor_class(
      *m, "NativePaddlePredictor");

  predictor_class.def(py::init<const NativeConfig &>());
  predictor_class.def("init", &NativePaddlePredictor::Init);

  // run() collects the outputs into a fresh vector and returns it.
  predictor_class.def("run",
                      [](NativePaddlePredictor &self,
                         const std::vector<PaddleTensor> &inputs) {
                        std::vector<PaddleTensor> outputs;
                        self.Run(inputs, &outputs);
                        return outputs;
                      });

  predictor_class.def("get_input_tensor",
                      &NativePaddlePredictor::GetInputTensor);
  predictor_class.def("get_output_tensor",
                      &NativePaddlePredictor::GetOutputTensor);
  predictor_class.def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun);
  predictor_class.def("clone", &NativePaddlePredictor::Clone);

  // scope() is returned with the reference policy, i.e. Python does not take
  // ownership of the returned object.
  predictor_class.def("scope", &NativePaddlePredictor::scope,
                      py::return_value_policy::reference);
}

void BindAnalysisConfig(py::module *m) {
571 572 573 574 575
  py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");

  py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
      .value("Float32", AnalysisConfig::Precision::kFloat32)
      .value("Int8", AnalysisConfig::Precision::kInt8)
Z
Zhaolong Xing 已提交
576
      .value("Half", AnalysisConfig::Precision::kHalf)
577 578
      .export_values();

579 580
  analysis_config.def(py::init<>())
      .def(py::init<const AnalysisConfig &>())
F
flame 已提交
581 582
      .def(py::init<const std::string &>())
      .def(py::init<const std::string &, const std::string &>())
583
      .def("summary", &AnalysisConfig::Summary)
584
      .def("set_model", (void(AnalysisConfig::*)(const std::string &)) &
F
flame 已提交
585
                            AnalysisConfig::SetModel)
586 587 588
      .def("set_model",
           (void(AnalysisConfig::*)(const std::string &, const std::string &)) &
               AnalysisConfig::SetModel)
F
flame 已提交
589 590 591 592 593 594 595
      .def("set_prog_file", &AnalysisConfig::SetProgFile)
      .def("set_params_file", &AnalysisConfig::SetParamsFile)
      .def("model_dir", &AnalysisConfig::model_dir)
      .def("prog_file", &AnalysisConfig::prog_file)
      .def("params_file", &AnalysisConfig::params_file)
      .def("enable_use_gpu", &AnalysisConfig::EnableUseGpu,
           py::arg("memory_pool_init_size_mb"), py::arg("device_id") = 0)
596 597 598
      .def("exp_enable_use_gpu_fp16", &AnalysisConfig::Exp_EnableUseGpuFp16,
           py::arg("gpu_fp16_disabled_op_types") =
               std::unordered_set<std::string>({}))
599
      .def("enable_xpu", &AnalysisConfig::EnableXpu,
W
Wilber 已提交
600 601 602 603
           py::arg("l3_workspace_size") = 16 * 1024 * 1024,
           py::arg("locked") = false, py::arg("autotune") = true,
           py::arg("autotune_file") = "", py::arg("precision") = "int16",
           py::arg("adaptive_seqlen") = false)
604 605
      .def("set_xpu_device_id", &AnalysisConfig::SetXpuDeviceId,
           py::arg("device_id") = 0)
W
Wilber 已提交
606
      .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
607 608 609 610 611 612 613 614
      .def("enable_ipu", &AnalysisConfig::EnableIpu,
           py::arg("ipu_device_num") = 1, py::arg("ipu_micro_batch_size") = 1,
           py::arg("ipu_enable_pipelining") = false,
           py::arg("ipu_batches_per_step") = 1)
      .def("set_ipu_config", &AnalysisConfig::SetIpuConfig,
           py::arg("ipu_enable_fp16") = false, py::arg("ipu_replica_num") = 1,
           py::arg("ipu_available_memory_proportion") = 1.0,
           py::arg("ipu_enable_half_partial") = false)
F
flame 已提交
615
      .def("disable_gpu", &AnalysisConfig::DisableGpu)
616 617 618 619
      .def("enable_onnxruntime", &AnalysisConfig::EnableONNXRuntime)
      .def("disable_onnxruntime", &AnalysisConfig::DisableONNXRuntime)
      .def("onnxruntime_enabled", &AnalysisConfig::use_onnxruntime)
      .def("enable_ort_optimization", &AnalysisConfig::EnableORTOptimization)
F
flame 已提交
620
      .def("use_gpu", &AnalysisConfig::use_gpu)
621
      .def("use_xpu", &AnalysisConfig::use_xpu)
W
Wilber 已提交
622
      .def("use_npu", &AnalysisConfig::use_npu)
F
flame 已提交
623
      .def("gpu_device_id", &AnalysisConfig::gpu_device_id)
624
      .def("xpu_device_id", &AnalysisConfig::xpu_device_id)
W
Wilber 已提交
625
      .def("npu_device_id", &AnalysisConfig::npu_device_id)
F
flame 已提交
626 627 628 629 630 631 632
      .def("memory_pool_init_size_mb",
           &AnalysisConfig::memory_pool_init_size_mb)
      .def("fraction_of_gpu_memory_for_pool",
           &AnalysisConfig::fraction_of_gpu_memory_for_pool)
      .def("switch_ir_optim", &AnalysisConfig::SwitchIrOptim,
           py::arg("x") = true)
      .def("ir_optim", &AnalysisConfig::ir_optim)
633 634
      .def("enable_memory_optim", &AnalysisConfig::EnableMemoryOptim,
           py::arg("x") = true)
635
      .def("enable_profile", &AnalysisConfig::EnableProfile)
636
      .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo)
637
      .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled)
638
      .def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir)
F
flame 已提交
639 640 641 642 643 644 645 646 647
      .def("switch_use_feed_fetch_ops", &AnalysisConfig::SwitchUseFeedFetchOps,
           py::arg("x") = true)
      .def("use_feed_fetch_ops_enabled",
           &AnalysisConfig::use_feed_fetch_ops_enabled)
      .def("switch_specify_input_names",
           &AnalysisConfig::SwitchSpecifyInputNames, py::arg("x") = true)
      .def("specify_input_name", &AnalysisConfig::specify_input_name)
      .def("enable_tensorrt_engine", &AnalysisConfig::EnableTensorRtEngine,
           py::arg("workspace_size") = 1 << 20, py::arg("max_batch_size") = 1,
648
           py::arg("min_subgraph_size") = 3,
N
nhzlx 已提交
649
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
650
           py::arg("use_static") = false, py::arg("use_calib_mode") = true)
651
      .def("tensorrt_precision_mode", &AnalysisConfig::tensorrt_precision_mode)
652 653
      .def("set_trt_dynamic_shape_info",
           &AnalysisConfig::SetTRTDynamicShapeInfo,
654 655 656 657 658
           py::arg("min_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("max_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("optim_input_shape") =
659 660
               std::map<std::string, std::vector<int>>({}),
           py::arg("disable_trt_plugin_fp16") = false)
661 662
      .def("tensorrt_dynamic_shape_enabled",
           &AnalysisConfig::tensorrt_dynamic_shape_enabled)
663 664 665
      .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
      .def("tensorrt_varseqlen_enabled",
           &AnalysisConfig::tensorrt_varseqlen_enabled)
666 667 668 669 670 671 672 673 674 675
      .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
      .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
      .def("shape_range_info_collected",
           &AnalysisConfig::shape_range_info_collected)
      .def("enable_tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::EnableTunedTensorRtDynamicShape)
      .def("tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::tuned_tensorrt_dynamic_shape)
      .def("trt_allow_build_at_runtime",
           &AnalysisConfig::trt_allow_build_at_runtime)
676
      .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
677 678 679
      .def("enable_tensorrt_dla", &AnalysisConfig::EnableTensorRtDLA,
           py::arg("dla_core") = 0)
      .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
680 681 682 683
      .def("enable_tensorrt_inspector",
           &AnalysisConfig::EnableTensorRtInspector)
      .def("tensorrt_inspector_enabled",
           &AnalysisConfig::tensorrt_inspector_enabled)
F
flame 已提交
684
      .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
D
denglin-github 已提交
685 686
      .def("enable_dlnne", &AnalysisConfig::EnableDlnne,
           py::arg("min_subgraph_size") = 3)
687 688
      .def("enable_lite_engine", &AnalysisConfig::EnableLiteEngine,
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
W
Wilber 已提交
689
           py::arg("zero_copy") = false,
690 691 692
           py::arg("passes_filter") = std::vector<std::string>(),
           py::arg("ops_filter") = std::vector<std::string>())
      .def("lite_engine_enabled", &AnalysisConfig::lite_engine_enabled)
F
flame 已提交
693 694 695 696 697 698 699 700 701
      .def("switch_ir_debug", &AnalysisConfig::SwitchIrDebug,
           py::arg("x") = true)
      .def("enable_mkldnn", &AnalysisConfig::EnableMKLDNN)
      .def("mkldnn_enabled", &AnalysisConfig::mkldnn_enabled)
      .def("set_cpu_math_library_num_threads",
           &AnalysisConfig::SetCpuMathLibraryNumThreads)
      .def("cpu_math_library_num_threads",
           &AnalysisConfig::cpu_math_library_num_threads)
      .def("to_native_config", &AnalysisConfig::ToNativeConfig)
702
      .def("enable_quantizer", &AnalysisConfig::EnableMkldnnQuantizer)
703
      .def("enable_mkldnn_bfloat16", &AnalysisConfig::EnableMkldnnBfloat16)
704 705 706
#ifdef PADDLE_WITH_MKLDNN
      .def("quantizer_config", &AnalysisConfig::mkldnn_quantizer_config,
           py::return_value_policy::reference)
707 708
      .def("set_mkldnn_cache_capacity", &AnalysisConfig::SetMkldnnCacheCapacity,
           py::arg("capacity") = 0)
709
      .def("set_bfloat16_op", &AnalysisConfig::SetBfloat16Op)
B
baoachun 已提交
710 711 712 713
      .def("enable_mkldnn_int8", &AnalysisConfig::EnableMkldnnInt8,
           py::arg("mkldnn_int8_enabled_op_types") =
               std::unordered_set<std::string>({}))
      .def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
714
#endif
F
flame 已提交
715 716 717
      .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
      .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
      .def("model_from_memory", &AnalysisConfig::model_from_memory)
718 719 720 721
      .def("delete_pass",
           [](AnalysisConfig &self, const std::string &pass) {
             self.pass_builder()->DeletePass(pass);
           })
722 723 724 725 726 727
      .def(
          "pass_builder",
          [](AnalysisConfig &self) {
            return dynamic_cast<PaddlePassBuilder *>(self.pass_builder());
          },
          py::return_value_policy::reference)
728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745
      .def("nnadapter", &AnalysisConfig::NNAdapter)
      .def("set_dist_config", &AnalysisConfig::SetDistConfig)
      .def("dist_config", &AnalysisConfig::dist_config);

  py::class_<DistConfig>(*m, "DistConfig")
      .def(py::init<>())
      .def("set_carrier_id", &DistConfig::SetCarrierId)
      .def("set_comm_init_config", &DistConfig::SetCommInitConfig)
      .def("set_endpoints", &DistConfig::SetEndpoints)
      .def("set_ranks", &DistConfig::SetRanks)
      .def("enable_dist_model", &DistConfig::EnableDistModel)
      .def("carrier_id", &DistConfig::carrier_id)
      .def("current_endpoint", &DistConfig::current_endpoint)
      .def("trainer_endpoints", &DistConfig::trainer_endpoints)
      .def("nranks", &DistConfig::nranks)
      .def("rank", &DistConfig::rank)
      .def("comm_init_config", &DistConfig::comm_init_config)
      .def("use_dist_model", &DistConfig::use_dist_model);
746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763
}

// Registers the LiteNNAdapterConfig class on module `m` under the Python name
// "LiteNNAdapterConfig" and exposes its configuration setters plus the
// enable/disable switches.
//
// No py::init is bound here, so this binding does not give Python a way to
// construct the class directly.
// NOTE(review): instances are presumably obtained through the "nnadapter"
// accessor bound on AnalysisConfig -- confirm against the Python callers.
void BindLiteNNAdapterConfig(py::module *m) {
  py::class_<LiteNNAdapterConfig> lite_nnadapter_config(*m,
                                                        "LiteNNAdapterConfig");

  lite_nnadapter_config
      .def("set_device_names", &LiteNNAdapterConfig::SetDeviceNames)
      .def("set_context_properties", &LiteNNAdapterConfig::SetContextProperties)
      .def("set_model_cache_dir", &LiteNNAdapterConfig::SetModelCacheDir)
      .def("set_model_cache_buffers",
           &LiteNNAdapterConfig::SetModelCacheBuffers)
      .def("set_subgraph_partition_config_path",
           &LiteNNAdapterConfig::SetSubgraphPartitionConfigPath)
      .def("set_subgraph_partition_config_buffer",
           &LiteNNAdapterConfig::SetSubgraphPartitionConfigBuffer)
      .def("enable", &LiteNNAdapterConfig::Enable)
      .def("disable", &LiteNNAdapterConfig::Disable);
}

766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
#ifdef PADDLE_WITH_MKLDNN
// Registers the MkldnnQuantizerConfig class on module `m` under the Python
// name "MkldnnQuantizerConfig".
//
// Two constructors are exposed, in this registration order: a copy
// constructor and a default constructor. The Python-facing method names use
// "quant" wording while the bound C++ members use "Warmup" wording
// (set_quant_data -> SetWarmupData, set_quant_batch_size ->
// SetWarmupBatchSize).
void BindMkldnnQuantizerConfig(py::module *m) {
  py::class_<MkldnnQuantizerConfig> quantizer_config(*m,
                                                     "MkldnnQuantizerConfig");
  quantizer_config.def(py::init<const MkldnnQuantizerConfig &>())
      .def(py::init<>())
      .def("set_quant_data",
           [](MkldnnQuantizerConfig &self,
              const std::vector<PaddleTensor> &data) {
             // SetWarmupData is handed a shared_ptr to a vector, so the
             // incoming tensors are copied into a freshly allocated vector
             // owned by that shared_ptr.
             self.SetWarmupData(
                 std::make_shared<std::vector<PaddleTensor>>(data));
           })
      .def("set_quant_batch_size", &MkldnnQuantizerConfig::SetWarmupBatchSize)
      .def("set_enabled_op_types", &MkldnnQuantizerConfig::SetEnabledOpTypes);
}
#endif

F
flame 已提交
785 786 787 788 789 790 791 792 793 794 795 796 797
// Registers the AnalysisPredictor class on module `m` under the Python name
// "AnalysisPredictor", declared to pybind11 with PaddlePredictor as its base.
//
// The predictor is constructible from an AnalysisConfig. Most entries forward
// directly to the member function of the same meaning; "run" is the only one
// implemented with an adapter lambda.
void BindAnalysisPredictor(py::module *m) {
  py::class_<AnalysisPredictor, PaddlePredictor>(*m, "AnalysisPredictor")
      .def(py::init<const AnalysisConfig &>())
      .def("init", &AnalysisPredictor::Init)
      .def(
          "run",
          [](AnalysisPredictor &self, const std::vector<PaddleTensor> &inputs) {
            // Turn Run()'s out-parameter into a Python return value.
            // NOTE(review): whatever Run() itself returns is discarded, so a
            // failed run is indistinguishable from one that produced no
            // outputs -- confirm whether that is intended.
            std::vector<PaddleTensor> outputs;
            self.Run(inputs, &outputs);
            return outputs;
          })
      .def("get_input_tensor", &AnalysisPredictor::GetInputTensor)
      .def("get_output_tensor", &AnalysisPredictor::GetOutputTensor)
      .def("get_input_names", &AnalysisPredictor::GetInputNames)
      .def("get_output_names", &AnalysisPredictor::GetOutputNames)
      .def("get_input_tensor_shape", &AnalysisPredictor::GetInputTensorShape)
      .def("zero_copy_run", &AnalysisPredictor::ZeroCopyRun)
      .def("clear_intermediate_tensor",
           &AnalysisPredictor::ClearIntermediateTensor)
      .def("try_shrink_memory", &AnalysisPredictor::TryShrinkMemory)
      .def("create_feed_fetch_var", &AnalysisPredictor::CreateFeedFetchVar)
      .def("prepare_feed_fetch", &AnalysisPredictor::PrepareFeedFetch)
      .def("prepare_argument", &AnalysisPredictor::PrepareArgument)
      .def("optimize_inference_program",
           &AnalysisPredictor::OptimizeInferenceProgram)
      // The accessors below use the `reference` return policy, i.e. pybind11
      // wraps the returned object without taking ownership of it.
      .def("analysis_argument", &AnalysisPredictor::analysis_argument,
           py::return_value_policy::reference)
      .def("clone", &AnalysisPredictor::Clone)
      .def("scope", &AnalysisPredictor::scope,
           py::return_value_policy::reference)
      .def("program", &AnalysisPredictor::program,
           py::return_value_policy::reference)
      .def("get_serialized_program", &AnalysisPredictor::GetSerializedProgram)
      .def("mkldnn_quantize", &AnalysisPredictor::MkldnnQuantize)
      // Unlike every other entry, this Python name is CamelCase; it is kept
      // verbatim because it is the name Python callers see.
      .def("SaveOptimModel", &AnalysisPredictor::SaveOptimModel,
           py::arg("dir"));
}
822

W
Wilber 已提交
823 824 825 826 827 828 829
// Registers paddle_infer::Predictor on module `m` under the Python name
// "PaddleInferPredictor". The class is constructible from a
// paddle_infer::Config.
void BindPaddleInferPredictor(py::module *m) {
  py::class_<paddle_infer::Predictor>(*m, "PaddleInferPredictor")
      .def(py::init<const paddle_infer::Config &>())
      .def("get_input_names", &paddle_infer::Predictor::GetInputNames)
      .def("get_output_names", &paddle_infer::Predictor::GetOutputNames)
      .def("get_input_handle", &paddle_infer::Predictor::GetInputHandle)
      .def("get_output_handle", &paddle_infer::Predictor::GetOutputHandle)
      .def("run",
           [](paddle_infer::Predictor &self) {
             // Only builds with PADDLE_WITH_ASCEND_CL release the GIL around
             // Run(); in every other build the GIL stays held for the call.
#ifdef PADDLE_WITH_ASCEND_CL
             pybind11::gil_scoped_release release;
#endif
             self.Run();
           })
      .def("clone", &paddle_infer::Predictor::Clone)
      .def("try_shrink_memory", &paddle_infer::Predictor::TryShrinkMemory)
      .def("clear_intermediate_tensor",
           &paddle_infer::Predictor::ClearIntermediateTensor);
}

843 844
// Registers the ZeroCopyTensor class on module `m` under the Python name
// "ZeroCopyTensor".
//
// No constructor is bound, so this binding does not let Python create
// instances directly. "reshape" and "copy_from_cpu" are overload sets;
// pybind11 tries overloads in registration order, so the order of the .def
// calls below is significant and must be preserved.
void BindZeroCopyTensor(py::module *m) {
  py::class_<ZeroCopyTensor>(*m, "ZeroCopyTensor")
      // reshape(list[int]) forwards to Reshape; reshape(size_t) forwards to
      // ReshapeStrings, which is named through paddle_infer::Tensor.
      .def("reshape", py::overload_cast<const std::vector<int> &>(
                          &ZeroCopyTensor::Reshape))
      .def("reshape", py::overload_cast<const std::size_t &>(
                          &paddle_infer::Tensor::ReshapeStrings))
      // One copy_from_cpu overload per supported element type, followed by
      // the string-tensor variant.
      .def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<float>)
      .def("copy_from_cpu", &ZeroCopyTensorCreate<paddle_infer::float16>)
      .def("copy_from_cpu", &ZeroCopyStringTensorCreate)
      .def("copy_to_cpu", &ZeroCopyTensorToNumpy)
      .def("shape", &ZeroCopyTensor::shape)
      .def("set_lod", &ZeroCopyTensor::SetLoD)
      .def("lod", &ZeroCopyTensor::lod)
      .def("type", &ZeroCopyTensor::type);
}

W
Wilber 已提交
861 862
// Registers paddle_infer::Tensor on module `m` under the Python name
// "PaddleInferTensor".
//
// No constructor is bound, so this binding does not let Python create
// instances directly. "reshape" and "copy_from_cpu_bind" are overload sets;
// pybind11 tries overloads in registration order, so the order of the .def
// calls below is significant and must be preserved.
void BindPaddleInferTensor(py::module *m) {
  py::class_<paddle_infer::Tensor>(*m, "PaddleInferTensor")
      // reshape(list[int]) forwards to Reshape; reshape(size_t) forwards to
      // ReshapeStrings.
      .def("reshape", py::overload_cast<const std::vector<int> &>(
                          &paddle_infer::Tensor::Reshape))
      .def("reshape", py::overload_cast<const std::size_t &>(
                          &paddle_infer::Tensor::ReshapeStrings))
      // One copy_from_cpu_bind overload per supported element type, followed
      // by the string-tensor variant.
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>)
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>)
      .def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>)
      .def("copy_from_cpu_bind",
           &PaddleInferTensorCreate<paddle_infer::float16>)
      .def("copy_from_cpu_bind", &PaddleInferStringTensorCreate)
      .def("share_external_data_bind", &PaddleInferShareExternalData)
      .def("copy_to_cpu", &PaddleInferTensorToNumpy)
      .def("shape", &paddle_infer::Tensor::shape)
      .def("set_lod", &paddle_infer::Tensor::SetLoD)
      .def("lod", &paddle_infer::Tensor::lod)
      .def("type", &paddle_infer::Tensor::type);
}

// Registers paddle_infer::services::PredictorPool on module `m` under the
// Python name "PredictorPool".
//
// The pool is constructible from a paddle_infer::Config and a size_t.
// "retrive" (spelling matches the bound C++ member Retrive) uses the
// `reference` return policy, so pybind11 does not take ownership of the
// object it returns.
void BindPredictorPool(py::module *m) {
  using Pool = paddle_infer::services::PredictorPool;

  py::class_<Pool> pool_binding(*m, "PredictorPool");
  pool_binding.def(py::init<const paddle_infer::Config &, size_t>());
  pool_binding.def("retrive", &Pool::Retrive,
                   py::return_value_policy::reference);
}

888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915
// Registers the pass-builder class hierarchy on module `m`:
//   PaddlePassBuilder
//     └─ PassStrategy
//          ├─ CpuPassStrategy
//          └─ GpuPassStrategy
// Each derived class is declared to pybind11 with its base, and the
// CPU/GPU strategies re-bind the enable_* methods through their own member
// pointers.
void BindPaddlePassBuilder(py::module *m) {
  py::class_<PaddlePassBuilder>(*m, "PaddlePassBuilder")
      .def(py::init<const std::vector<std::string> &>())
      .def("set_passes",
           [](PaddlePassBuilder &self, const std::vector<std::string> &passes) {
             // Replace the current pass list: clear it, then append the
             // given passes in order. Iterating by const reference avoids
             // making a throw-away copy of every pass name.
             self.ClearPasses();
             for (const auto &pass : passes) {
               self.AppendPass(pass);
             }
           })
      .def("append_pass", &PaddlePassBuilder::AppendPass)
      .def("insert_pass", &PaddlePassBuilder::InsertPass)
      .def("delete_pass",
           [](PaddlePassBuilder &self, const std::string &pass_type) {
             self.DeletePass(pass_type);
           })
      .def("append_analysis_pass", &PaddlePassBuilder::AppendAnalysisPass)
      .def("turn_on_debug", &PaddlePassBuilder::TurnOnDebug)
      .def("debug_string", &PaddlePassBuilder::DebugString)
      // `reference` policy: pybind11 does not take ownership of the returned
      // object.
      .def("all_passes", &PaddlePassBuilder::AllPasses,
           py::return_value_policy::reference)
      .def("analysis_passes", &PaddlePassBuilder::AnalysisPasses);

  py::class_<PassStrategy, PaddlePassBuilder>(*m, "PassStrategy")
      .def(py::init<const std::vector<std::string> &>())
      .def("enable_cudnn", &PassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &PassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &PassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &PassStrategy::EnableMkldnnBfloat16)
      .def("use_gpu", &PassStrategy::use_gpu);

  // Constructors are registered default-first, then copy.
  py::class_<CpuPassStrategy, PassStrategy>(*m, "CpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const CpuPassStrategy &>())
      .def("enable_cudnn", &CpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &CpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &CpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &CpuPassStrategy::EnableMkldnnBfloat16);

  // Constructors are registered default-first, then copy.
  py::class_<GpuPassStrategy, PassStrategy>(*m, "GpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const GpuPassStrategy &>())
      .def("enable_cudnn", &GpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &GpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &GpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &GpuPassStrategy::EnableMkldnnBfloat16);
}
935
}  // namespace
F
flame 已提交
936 937
}  // namespace pybind
}  // namespace paddle