inference_api.cc 39.3 KB
Newer Older
F
flame 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pybind/inference_api.h"
16

17
#include <pybind11/numpy.h>
F
flame 已提交
18
#include <pybind11/stl.h>
19

F
flame 已提交
20
#include <cstring>
21
#include <functional>
F
flame 已提交
22
#include <iostream>
23
#include <iterator>
24
#include <map>
25
#include <memory>
F
flame 已提交
26
#include <string>
27
#include <type_traits>
28
#include <unordered_set>
29
#include <utility>
F
flame 已提交
30
#include <vector>
31

F
flame 已提交
32
#include "paddle/fluid/inference/api/analysis_predictor.h"
33
#include "paddle/fluid/inference/api/helper.h"
34
#include "paddle/fluid/inference/api/paddle_infer_contrib.h"
F
flame 已提交
35
#include "paddle/fluid/inference/api/paddle_inference_api.h"
36
#include "paddle/fluid/inference/api/paddle_pass_builder.h"
37
#include "paddle/fluid/inference/utils/io_utils.h"
F
flame 已提交
38

39 40 41 42
#ifdef PADDLE_WITH_ONNXRUNTIME
#include "paddle/fluid/inference/api/onnxruntime_predictor.h"
#endif

F
flame 已提交
43 44
namespace py = pybind11;

45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
namespace pybind11 {
namespace detail {

// NumPy type numbers, mirrored here so they can be handed to
// PyArray_DescrFromType_. The float16 value can be checked from Python:
//   import numpy as np
//   print(np.dtype(np.float16).num)  # 23
constexpr int NPY_FLOAT16_ = 23;
constexpr int NPY_UINT16_ = 4;

// float16 is not a builtin C++ type, so this specialization registers
// paddle_infer::float16 with pybind11 as numpy.float16.
// Ref: https://github.com/pybind/pybind11/issues/1776
template <>
struct npy_format_descriptor<paddle_infer::float16> {
  // Looks up the numpy descriptor for type number NPY_FLOAT16_.
  static py::dtype dtype() {
    return reinterpret_borrow<py::dtype>(
        npy_api::get().PyArray_DescrFromType_(NPY_FLOAT16_));
  }
  // "e" is the format character for float16, see
  // https://docs.python.org/3/library/struct.html#format-characters.
  static std::string format() { return "e"; }
  static constexpr auto name = _("float16");
};

}  // namespace detail
}  // namespace pybind11

F
flame 已提交
75 76
namespace paddle {
namespace pybind {
77 78 79
using paddle::AnalysisPredictor;
using paddle::NativeConfig;
using paddle::NativePaddlePredictor;
F
flame 已提交
80
using paddle::PaddleBuf;
81
using paddle::PaddleDataLayout;
82
using paddle::PaddleDType;
83
using paddle::PaddlePassBuilder;
F
flame 已提交
84 85
using paddle::PaddlePlace;
using paddle::PaddlePredictor;
86 87 88
using paddle::PaddleTensor;
using paddle::PassStrategy;
using paddle::ZeroCopyTensor;
F
flame 已提交
89

90 91
namespace {
void BindPaddleDType(py::module *m);
92
void BindPaddleDataLayout(py::module *m);
93 94 95 96 97 98
void BindPaddleBuf(py::module *m);
void BindPaddleTensor(py::module *m);
void BindPaddlePlace(py::module *m);
void BindPaddlePredictor(py::module *m);
void BindNativeConfig(py::module *m);
void BindNativePredictor(py::module *m);
99
void BindLiteNNAdapterConfig(py::module *m);
100 101
void BindAnalysisConfig(py::module *m);
void BindAnalysisPredictor(py::module *m);
102 103
void BindZeroCopyTensor(py::module *m);
void BindPaddlePassBuilder(py::module *m);
W
Wilber 已提交
104 105 106
void BindPaddleInferPredictor(py::module *m);
void BindPaddleInferTensor(py::module *m);
void BindPredictorPool(py::module *m);
F
flame 已提交
107

108
#ifdef PADDLE_WITH_MKLDNN
109
void BindMkldnnQuantizerConfig(py::module *m);
110
#endif
111 112

template <typename T>
113 114
PaddleBuf PaddleBufCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
115
  PaddleBuf buf(data.size() * sizeof(T));
W
Wilber 已提交
116 117
  std::copy_n(static_cast<const T *>(data.data()),
              data.size(),
118 119 120 121 122
              static_cast<T *>(buf.data()));
  return buf;
}

template <typename T>
123 124 125
void PaddleBufReset(
    PaddleBuf &buf,                                                    // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {  // NOLINT
126
  buf.Resize(data.size() * sizeof(T));
W
Wilber 已提交
127 128
  std::copy_n(static_cast<const T *>(data.data()),
              data.size(),
129 130 131 132 133
              static_cast<T *>(buf.data()));
}

/// \brief Build a PaddleTensor from a numpy array.
/// \param data Source array, declared with the c_style | forcecast flags.
/// \param name Name stored on the tensor.
/// \param lod LoD information stored on the tensor.
/// \param copy When true the elements are copied into a new PaddleBuf;
/// when false the PaddleBuf is constructed from the array's own pointer.
/// \return The tensor with data, dtype, name, lod and shape filled in.
template <typename T>
PaddleTensor PaddleTensorCreate(
    py::array_t<T, py::array::c_style | py::array::forcecast> data,
    const std::string name = "",
    const std::vector<std::vector<size_t>> &lod = {},
    bool copy = true) {
  const size_t num_bytes = data.size() * sizeof(T);

  PaddleTensor result;
  if (!copy) {
    // NOTE(review): this wraps the array's memory; presumably the buffer does
    // not own it, so the array must outlive the tensor -- confirm.
    result.data = PaddleBuf(data.mutable_data(), num_bytes);
  } else {
    PaddleBuf owned(num_bytes);
    std::copy_n(data.data(), data.size(), static_cast<T *>(owned.data()));
    result.data = std::move(owned);
  }

  result.dtype = inference::PaddleTensorGetDType<T>();
  result.name = name;
  result.lod = lod;
  result.shape.assign(data.shape(), data.shape() + data.ndim());

  return result;
}

159
/// \brief Map a PaddleDType onto the matching numpy dtype.
/// \param dtype The Paddle data type to translate.
/// \return The numpy dtype for INT32, INT64, FLOAT32, FLOAT16, UINT8 or INT8.
/// \throws Unimplemented (via PADDLE_THROW) for any other data type.
py::dtype PaddleDTypeToNumpyDType(PaddleDType dtype) {
  py::dtype dt;
  switch (dtype) {
    case PaddleDType::INT32:
      dt = py::dtype::of<int32_t>();
      break;
    case PaddleDType::INT64:
      dt = py::dtype::of<int64_t>();
      break;
    case PaddleDType::FLOAT32:
      dt = py::dtype::of<float>();
      break;
    case PaddleDType::UINT8:
      dt = py::dtype::of<uint8_t>();
      break;
    case PaddleDType::INT8:
      // The tensor-to-numpy converters in this file look up the dtype through
      // this function before copying, so INT8 has to be mapped here for their
      // INT8 branches to be reachable.
      dt = py::dtype::of<int8_t>();
      break;
    case PaddleDType::FLOAT16:
      dt = py::dtype::of<paddle_infer::float16>();
      break;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }

  return dt;
}

/// \brief Expose a PaddleTensor's buffer as a numpy array.
/// \param tensor Tensor whose dtype, shape and data pointer are used.
/// \return An array built from the tensor's dtype, shape and data pointer.
py::array PaddleTensorGetData(PaddleTensor &tensor) {  // NOLINT
  py::dtype numpy_dtype = PaddleDTypeToNumpyDType(tensor.dtype);
  py::array::ShapeContainer dims(tensor.shape.begin(), tensor.shape.end());
  return py::array(std::move(numpy_dtype), std::move(dims), tensor.data.data());
}

template <typename T>
192 193 194
void ZeroCopyTensorCreate(
    ZeroCopyTensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
195 196 197 198 199 200
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.copy_from_cpu(static_cast<const T *>(data.data()));
}

S
Steffy-zxf 已提交
201 202 203 204 205 206 207 208 209 210 211 212
/// \brief Experimental interface.
/// Fill a Strings tensor from the given text data.
/// \param tensor The tensor to fill; it is reshaped to hold data->size()
/// strings and then receives the strings from data.
/// \param data The input text.
void ZeroCopyStringTensorCreate(ZeroCopyTensor &tensor,  // NOLINT
                                const paddle_infer::Strings *data) {
  const size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.copy_strings_from_cpu(data);
}

W
Wilber 已提交
213
template <typename T>
214 215 216
void PaddleInferTensorCreate(
    paddle_infer::Tensor &tensor,  // NOLINT
    py::array_t<T, py::array::c_style | py::array::forcecast> data) {
W
Wilber 已提交
217 218 219 220 221 222
  std::vector<int> shape;
  std::copy_n(data.shape(), data.ndim(), std::back_inserter(shape));
  tensor.Reshape(std::move(shape));
  tensor.CopyFromCpu(static_cast<const T *>(data.data()));
}

223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
/// \brief Translate a phi allocation type into a paddle_infer place.
/// \param allocation_type The allocation type to translate.
/// \return kGPU for GPU; kCPU for CPU and for every other allocation type.
paddle_infer::PlaceType ToPaddleInferPlace(
    phi::AllocationType allocation_type) {
  const bool on_gpu = (allocation_type == phi::AllocationType::GPU);
  return on_gpu ? paddle_infer::PlaceType::kGPU
                : paddle_infer::PlaceType::kCPU;
}

/// \brief Make a paddle_infer::Tensor share the memory of a framework tensor.
/// \param tensor The inference tensor that will reference the external data.
/// \param input_tensor The tensor whose data pointer, dims and place are
/// passed to ShareExternalData.
/// \throws Unimplemented (via PADDLE_THROW) when input_tensor's dtype is
/// neither FLOAT32 nor FLOAT16, so the caller is not left with a tensor that
/// was silently not shared.
void PaddleInferShareExternalData(paddle_infer::Tensor &tensor,  // NOLINT
                                  framework::Tensor input_tensor) {
  const auto &dims = input_tensor.dims();
  std::vector<int> shape;
  shape.reserve(dims.size());
  for (int i = 0; i < dims.size(); ++i) {
    shape.push_back(dims[i]);
  }

  if (input_tensor.dtype() == phi::DataType::FLOAT32) {
    tensor.ShareExternalData(
        static_cast<float *>(input_tensor.data()),
        shape,
        ToPaddleInferPlace(input_tensor.place().GetType()));
  } else if (input_tensor.dtype() == phi::DataType::FLOAT16) {
    tensor.ShareExternalData(
        static_cast<paddle::platform::float16 *>(input_tensor.data()),
        shape,
        ToPaddleInferPlace(input_tensor.place().GetType()));
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported data type. Now share_external_data only supports "
        "FLOAT32 and FLOAT16."));
  }
}

S
Steffy-zxf 已提交
253 254 255 256 257 258 259 260 261 262 263 264 265
/// \brief Experimental interface.
/// Fill a Strings tensor from the given text data.
/// \param tensor The tensor to fill; it is reshaped to hold data->size()
/// strings and then receives the strings from data.
/// \param data The input text.
void PaddleInferStringTensorCreate(paddle_infer::Tensor &tensor,  // NOLINT
                                   const paddle_infer::Strings *data) {
  VLOG(3) << "Create PaddleInferTensor, dtype = Strings ";
  const size_t num_strings = data->size();
  tensor.ReshapeStrings(num_strings);
  tensor.CopyStringsFromCpu(data);
}

266 267 268 269 270 271 272 273 274 275 276 277 278
/// \brief Size in bytes of one element of the given data type.
/// \param dt The data type; only INT32, INT64 and FLOAT32 are handled.
/// \return sizeof the matching C++ type.
/// \throws Unimplemented (via PADDLE_THROW) for any other data type.
size_t PaddleGetDTypeSize(PaddleDType dt) {
  size_t num_bytes{0};
  if (dt == PaddleDType::INT32) {
    num_bytes = sizeof(int32_t);
  } else if (dt == PaddleDType::INT64) {
    num_bytes = sizeof(int64_t);
  } else if (dt == PaddleDType::FLOAT32) {
    num_bytes = sizeof(float);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported data type. Now only supports INT32, INT64 and "
        "FLOAT32."));
  }
  return num_bytes;
}

/// \brief Copy the contents of a ZeroCopyTensor into a new numpy array.
/// \param tensor The tensor to read; its type() and shape() size the array.
/// \return A numpy array with the tensor's shape, filled via copy_to_cpu.
/// \throws Unimplemented (via PADDLE_THROW) when the tensor's data type is not
/// one of INT32, INT64, FLOAT32, FLOAT16, UINT8 or INT8. The dtype lookup at
/// the top runs first and may throw on its own for types it does not map.
py::array ZeroCopyTensorToNumpy(ZeroCopyTensor &tensor) {  // NOLINT
  py::dtype dt = PaddleDTypeToNumpyDType(tensor.type());
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  switch (tensor.type()) {
    case PaddleDType::INT32:
      tensor.copy_to_cpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.copy_to_cpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.copy_to_cpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.copy_to_cpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.copy_to_cpu<uint8_t>(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.copy_to_cpu<int8_t>(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled by the cases above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }
  return array;
}
319

W
Wilber 已提交
320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335
/// \brief Copy the contents of a paddle_infer::Tensor into a new numpy array.
/// \param tensor The tensor to read; its type() and shape() size the array.
/// \return A numpy array with the tensor's shape, filled via CopyToCpu.
/// \throws Unimplemented (via PADDLE_THROW) when the tensor's data type is not
/// one of INT32, INT64, FLOAT32, FLOAT16, UINT8 or INT8. The dtype lookup at
/// the top runs first and may throw on its own for types it does not map.
py::array PaddleInferTensorToNumpy(paddle_infer::Tensor &tensor) {  // NOLINT
  py::dtype dt = PaddleDTypeToNumpyDType(tensor.type());
  auto tensor_shape = tensor.shape();
  py::array::ShapeContainer shape(tensor_shape.begin(), tensor_shape.end());
  py::array array(dt, std::move(shape));

  switch (tensor.type()) {
    case PaddleDType::INT32:
      tensor.CopyToCpu(static_cast<int32_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT64:
      tensor.CopyToCpu(static_cast<int64_t *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT32:
      tensor.CopyToCpu<float>(static_cast<float *>(array.mutable_data()));
      break;
    case PaddleDType::FLOAT16:
      tensor.CopyToCpu<paddle::platform::float16>(
          static_cast<paddle::platform::float16 *>(array.mutable_data()));
      break;
    case PaddleDType::UINT8:
      tensor.CopyToCpu(static_cast<uint8_t *>(array.mutable_data()));
      break;
    case PaddleDType::INT8:
      tensor.CopyToCpu(static_cast<int8_t *>(array.mutable_data()));
      break;
    default:
      // The message lists every type handled by the cases above.
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported data type. Now only supports INT32, INT64, FLOAT32, "
          "FLOAT16, UINT8 and INT8."));
  }
  return array;
}

354 355 356 357 358
/// \brief Serialize a PaddleTensor and return the result as Python bytes.
/// \param tensor The tensor handed to SerializePDTensorToStream.
/// \return The serialized stream contents.
py::bytes SerializePDTensorToBytes(PaddleTensor &tensor) {  // NOLINT
  std::stringstream stream;
  paddle::inference::SerializePDTensorToStream(&stream, tensor);
  return py::bytes(stream.str());
}
359

360
void CopyPaddleInferTensor(paddle_infer::Tensor &dst,  // NOLINT
361 362 363 364
                           const paddle_infer::Tensor &src) {
  return paddle_infer::contrib::TensorUtils::CopyTensor(&dst, src);
}

365
}  // namespace
366

F
flame 已提交
367 368
/// \brief Register the inference bindings on the given Python module.
///
/// Runs every Bind* helper of this file on the module, then defines the
/// module-level free functions.
/// \param m The module that receives the bindings.
void BindInferenceApi(py::module *m) {
  // Enum and class bindings.
  BindPaddleDType(m);
  BindPaddleDataLayout(m);
  BindPaddleBuf(m);
  BindPaddleTensor(m);
  BindPaddlePlace(m);
  BindPaddlePredictor(m);
  BindNativeConfig(m);
  BindNativePredictor(m);
  BindLiteNNAdapterConfig(m);
  BindAnalysisConfig(m);
  BindAnalysisPredictor(m);
  BindPaddleInferPredictor(m);
  BindZeroCopyTensor(m);
  BindPaddleInferTensor(m);
  BindPaddlePassBuilder(m);
  BindPredictorPool(m);
#ifdef PADDLE_WITH_MKLDNN
  BindMkldnnQuantizerConfig(m);
#endif

  py::module &mod = *m;

  // Predictor factories; "create_paddle_predictor" is overloaded on the
  // config type.
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<AnalysisConfig>,
          py::arg("config"));
  mod.def("create_paddle_predictor",
          &paddle::CreatePaddlePredictor<NativeConfig>,
          py::arg("config"));
  mod.def("create_predictor",
          [](const paddle_infer::Config &config)
              -> std::unique_ptr<paddle_infer::Predictor> {
            return std::unique_ptr<paddle_infer::Predictor>(
                new paddle_infer::Predictor(config));
          });

  // Tensor helpers.
  mod.def("copy_tensor", &CopyPaddleInferTensor);
  mod.def("paddle_dtype_size", &paddle::PaddleDtypeSize);
  mod.def("paddle_tensor_to_bytes", &SerializePDTensorToBytes);

  // Version and data-type queries.
  mod.def("get_version", &paddle_infer::GetVersion);
  mod.def("get_trt_compile_version", &paddle_infer::GetTrtCompileVersion);
  mod.def("get_trt_runtime_version", &paddle_infer::GetTrtRuntimeVersion);
  mod.def("get_num_bytes_of_data_type", &paddle_infer::GetNumBytesOfDataType);

  // Mixed-precision model conversion.
  mod.def("convert_to_mixed_precision_bind",
          &paddle_infer::ConvertToMixedPrecision,
          py::arg("model_file"),
          py::arg("params_file"),
          py::arg("mixed_model_file"),
          py::arg("mixed_params_file"),
          py::arg("mixed_precision"),
          py::arg("backend"),
          py::arg("keep_io_types") = true,
          py::arg("black_list") = std::unordered_set<std::string>());
}

419
namespace {
F
flame 已提交
420 421 422
/// \brief Expose the PaddleDType enum to Python as "PaddleDType".
/// Only FLOAT32, INT64 and INT32 are registered.
void BindPaddleDType(py::module *m) {
  py::enum_<PaddleDType> dtype_enum(*m, "PaddleDType");
  dtype_enum.value("FLOAT32", PaddleDType::FLOAT32);
  dtype_enum.value("INT64", PaddleDType::INT64);
  dtype_enum.value("INT32", PaddleDType::INT32);
}

427 428 429 430 431 432 433 434
/// \brief Expose the PaddleDataLayout enum to Python as "PaddleDataLayout".
void BindPaddleDataLayout(py::module *m) {
  py::enum_<PaddleDataLayout> layout_enum(*m, "PaddleDataLayout");
  layout_enum.value("UNK", PaddleDataLayout::kUNK);
  layout_enum.value("Any", PaddleDataLayout::kAny);
  layout_enum.value("NHWC", PaddleDataLayout::kNHWC);
  layout_enum.value("NCHW", PaddleDataLayout::kNCHW);
}

F
flame 已提交
435 436 437 438 439 440
/// \brief Expose PaddleBuf to Python as "PaddleBuf".
///
/// Registers constructors, resize/reset/empty/length, and accessors that
/// reinterpret the raw bytes as int32, int64 or float32 values.
void BindPaddleBuf(py::module *m) {
  py::class_<PaddleBuf> buf_class(*m, "PaddleBuf");

  // Constructors: from a byte count, a list of floats, or a numpy array.
  buf_class.def(py::init<size_t>());
  buf_class.def(py::init([](std::vector<float> &data) {
    PaddleBuf buf(data.size() * sizeof(float));
    std::memcpy(buf.data(), static_cast<void *>(data.data()), buf.length());
    return buf;
  }));
  buf_class.def(py::init(&PaddleBufCreate<int32_t>));
  buf_class.def(py::init(&PaddleBufCreate<int64_t>));
  buf_class.def(py::init(&PaddleBufCreate<float>));

  // Resizing and refilling an existing buffer.
  buf_class.def("resize", &PaddleBuf::Resize);
  buf_class.def("reset", [](PaddleBuf &self, std::vector<float> &data) {
    self.Resize(data.size() * sizeof(float));
    std::memcpy(self.data(), data.data(), self.length());
  });
  buf_class.def("reset", &PaddleBufReset<int32_t>);
  buf_class.def("reset", &PaddleBufReset<int64_t>);
  buf_class.def("reset", &PaddleBufReset<float>);
  buf_class.def("empty", &PaddleBuf::empty);

  // tolist(dtype): interpret the bytes as the named element type.
  buf_class.def(
      "tolist", [](PaddleBuf &self, const std::string &dtype) -> py::list {
        py::list values;
        if (dtype == "int32") {
          auto *first = static_cast<int32_t *>(self.data());
          const auto count = self.length() / sizeof(int32_t);
          values = py::cast(std::vector<int32_t>(first, first + count));
        } else if (dtype == "int64") {
          auto *first = static_cast<int64_t *>(self.data());
          const auto count = self.length() / sizeof(int64_t);
          values = py::cast(std::vector<int64_t>(first, first + count));
        } else if (dtype == "float32") {
          auto *first = static_cast<float *>(self.data());
          const auto count = self.length() / sizeof(float);
          values = py::cast(std::vector<float>(first, first + count));
        } else {
          PADDLE_THROW(platform::errors::Unimplemented(
              "Unsupported data type. Now only supports INT32, INT64 and "
              "FLOAT32."));
        }
        return values;
      });

  // Typed accessors returning a copy of the contents.
  buf_class.def("float_data", [](PaddleBuf &self) -> std::vector<float> {
    auto *first = static_cast<float *>(self.data());
    const auto count = self.length() / sizeof(float);
    return std::vector<float>(first, first + count);
  });
  buf_class.def("int64_data", [](PaddleBuf &self) -> std::vector<int64_t> {
    auto *first = static_cast<int64_t *>(self.data());
    const auto count = self.length() / sizeof(int64_t);
    return std::vector<int64_t>(first, first + count);
  });
  buf_class.def("int32_data", [](PaddleBuf &self) -> std::vector<int32_t> {
    auto *first = static_cast<int32_t *>(self.data());
    const auto count = self.length() / sizeof(int32_t);
    return std::vector<int32_t>(first, first + count);
  });
  buf_class.def("length", &PaddleBuf::length);
}

/// \brief Expose PaddleTensor to Python as "PaddleTensor".
///
/// Registers a default constructor, constructors from int32/int64/float
/// numpy arrays, as_ndarray, and read/write access to the tensor's fields.
void BindPaddleTensor(py::module *m) {
  // Default value shared by the "lod" argument of every array constructor.
  const std::vector<std::vector<size_t>> empty_lod;

  py::class_<PaddleTensor> tensor_class(*m, "PaddleTensor");
  tensor_class.def(py::init<>());
  tensor_class.def(py::init(&PaddleTensorCreate<int32_t>),
                   py::arg("data"),
                   py::arg("name") = "",
                   py::arg("lod") = empty_lod,
                   py::arg("copy") = true);
  tensor_class.def(py::init(&PaddleTensorCreate<int64_t>),
                   py::arg("data"),
                   py::arg("name") = "",
                   py::arg("lod") = empty_lod,
                   py::arg("copy") = true);
  tensor_class.def(py::init(&PaddleTensorCreate<float>),
                   py::arg("data"),
                   py::arg("name") = "",
                   py::arg("lod") = empty_lod,
                   py::arg("copy") = true);
  tensor_class.def("as_ndarray", &PaddleTensorGetData);

  // Plain data members exposed as read/write attributes.
  tensor_class.def_readwrite("name", &PaddleTensor::name);
  tensor_class.def_readwrite("shape", &PaddleTensor::shape);
  tensor_class.def_readwrite("data", &PaddleTensor::data);
  tensor_class.def_readwrite("dtype", &PaddleTensor::dtype);
  tensor_class.def_readwrite("lod", &PaddleTensor::lod);
}

/// \brief Expose the PaddlePlace enum to Python as "PaddlePlace".
void BindPaddlePlace(py::module *m) {
  py::enum_<PaddlePlace> place_enum(*m, "PaddlePlace");
  place_enum.value("UNK", PaddlePlace::kUNK);
  place_enum.value("CPU", PaddlePlace::kCPU);
  place_enum.value("GPU", PaddlePlace::kGPU);
  place_enum.value("XPU", PaddlePlace::kXPU);
  place_enum.value("NPU", PaddlePlace::kNPU);
}

/// \brief Expose PaddlePredictor and its nested Config to Python.
///
/// "PaddlePredictor" gets run, tensor/name getters, zero_copy_run, clone and
/// get_serialized_program; the nested "Config" class gets a default
/// constructor and a read/write model_dir attribute.
void BindPaddlePredictor(py::module *m) {
  py::class_<PaddlePredictor> predictor_class(*m, "PaddlePredictor");

  // run(inputs) returns the outputs as a list; the status returned by Run is
  // not surfaced to Python.
  predictor_class.def(
      "run",
      [](PaddlePredictor &self, const std::vector<PaddleTensor> &inputs) {
        std::vector<PaddleTensor> results;
        self.Run(inputs, &results);
        return results;
      });
  predictor_class.def("get_input_tensor", &PaddlePredictor::GetInputTensor);
  predictor_class.def("get_output_tensor", &PaddlePredictor::GetOutputTensor);
  predictor_class.def("get_input_names", &PaddlePredictor::GetInputNames);
  predictor_class.def("get_output_names", &PaddlePredictor::GetOutputNames);
  predictor_class.def("zero_copy_run", &PaddlePredictor::ZeroCopyRun);
  predictor_class.def("clone", &PaddlePredictor::Clone);
  predictor_class.def("get_serialized_program",
                      &PaddlePredictor::GetSerializedProgram);

  // Config is registered in the scope of the PaddlePredictor class.
  py::class_<PaddlePredictor::Config> config_class(predictor_class, "Config");
  config_class.def(py::init<>());
  config_class.def_readwrite("model_dir", &PaddlePredictor::Config::model_dir);
}

/// \brief Expose NativeConfig to Python as "NativeConfig".
/// The class is registered with PaddlePredictor::Config as its base.
void BindNativeConfig(py::module *m) {
  py::class_<NativeConfig, PaddlePredictor::Config> config_class(
      *m, "NativeConfig");
  config_class.def(py::init<>());

  // Plain data members exposed as read/write attributes.
  config_class.def_readwrite("use_gpu", &NativeConfig::use_gpu);
  config_class.def_readwrite("use_xpu", &NativeConfig::use_xpu);
  config_class.def_readwrite("use_npu", &NativeConfig::use_npu);
  config_class.def_readwrite("device", &NativeConfig::device);
  config_class.def_readwrite("fraction_of_gpu_memory",
                             &NativeConfig::fraction_of_gpu_memory);
  config_class.def_readwrite("prog_file", &NativeConfig::prog_file);
  config_class.def_readwrite("param_file", &NativeConfig::param_file);
  config_class.def_readwrite("specify_input_name",
                             &NativeConfig::specify_input_name);

  // CPU math library thread count: setter and getter.
  config_class.def("set_cpu_math_library_num_threads",
                   &NativeConfig::SetCpuMathLibraryNumThreads);
  config_class.def("cpu_math_library_num_threads",
                   &NativeConfig::cpu_math_library_num_threads);
}

/// \brief Expose NativePaddlePredictor to Python as "NativePaddlePredictor".
/// The class is registered with PaddlePredictor as its base.
void BindNativePredictor(py::module *m) {
  py::class_<NativePaddlePredictor, PaddlePredictor> predictor_class(
      *m, "NativePaddlePredictor");
  predictor_class.def(py::init<const NativeConfig &>());
  predictor_class.def("init", &NativePaddlePredictor::Init);

  // run(inputs) returns the outputs as a list; the status returned by Run is
  // not surfaced to Python.
  predictor_class.def("run",
                      [](NativePaddlePredictor &self,
                         const std::vector<PaddleTensor> &inputs) {
                        std::vector<PaddleTensor> results;
                        self.Run(inputs, &results);
                        return results;
                      });
  predictor_class.def("get_input_tensor",
                      &NativePaddlePredictor::GetInputTensor);
  predictor_class.def("get_output_tensor",
                      &NativePaddlePredictor::GetOutputTensor);
  predictor_class.def("zero_copy_run", &NativePaddlePredictor::ZeroCopyRun);
  predictor_class.def("clone", &NativePaddlePredictor::Clone);
  predictor_class.def("scope",
                      &NativePaddlePredictor::scope,
                      py::return_value_policy::reference);
}

void BindAnalysisConfig(py::module *m) {
593 594 595 596 597
  py::class_<AnalysisConfig> analysis_config(*m, "AnalysisConfig");

  py::enum_<AnalysisConfig::Precision>(analysis_config, "Precision")
      .value("Float32", AnalysisConfig::Precision::kFloat32)
      .value("Int8", AnalysisConfig::Precision::kInt8)
Z
Zhaolong Xing 已提交
598
      .value("Half", AnalysisConfig::Precision::kHalf)
599 600 601 602 603 604 605 606
      .value("Bfloat16", AnalysisConfig::Precision::kBf16)
      .export_values();

  py::enum_<AnalysisConfig::Backend>(analysis_config, "Backend")
      .value("CPU", AnalysisConfig::Backend::kCPU)
      .value("GPU", AnalysisConfig::Backend::kGPU)
      .value("NPU", AnalysisConfig::Backend::kNPU)
      .value("XPU", AnalysisConfig::Backend::kXPU)
607 608
      .export_values();

609 610
  analysis_config.def(py::init<>())
      .def(py::init<const AnalysisConfig &>())
F
flame 已提交
611 612
      .def(py::init<const std::string &>())
      .def(py::init<const std::string &, const std::string &>())
613
      .def("summary", &AnalysisConfig::Summary)
W
Wilber 已提交
614 615 616
      .def("set_model",
           (void(AnalysisConfig::*)(const std::string &)) &
               AnalysisConfig::SetModel)
617 618 619
      .def("set_model",
           (void(AnalysisConfig::*)(const std::string &, const std::string &)) &
               AnalysisConfig::SetModel)
F
flame 已提交
620 621 622 623 624
      .def("set_prog_file", &AnalysisConfig::SetProgFile)
      .def("set_params_file", &AnalysisConfig::SetParamsFile)
      .def("model_dir", &AnalysisConfig::model_dir)
      .def("prog_file", &AnalysisConfig::prog_file)
      .def("params_file", &AnalysisConfig::params_file)
W
Wilber 已提交
625 626 627 628 629 630
      .def("enable_use_gpu",
           &AnalysisConfig::EnableUseGpu,
           py::arg("memory_pool_init_size_mb"),
           py::arg("device_id") = 0)
      .def("enable_xpu",
           &AnalysisConfig::EnableXpu,
W
Wilber 已提交
631
           py::arg("l3_workspace_size") = 16 * 1024 * 1024,
W
Wilber 已提交
632 633 634 635
           py::arg("locked") = false,
           py::arg("autotune") = true,
           py::arg("autotune_file") = "",
           py::arg("precision") = "int16",
W
Wilber 已提交
636
           py::arg("adaptive_seqlen") = false)
W
Wilber 已提交
637 638
      .def("set_xpu_device_id",
           &AnalysisConfig::SetXpuDeviceId,
639
           py::arg("device_id") = 0)
W
Wilber 已提交
640
      .def("enable_npu", &AnalysisConfig::EnableNpu, py::arg("device_id") = 0)
W
Wilber 已提交
641 642 643 644
      .def("enable_ipu",
           &AnalysisConfig::EnableIpu,
           py::arg("ipu_device_num") = 1,
           py::arg("ipu_micro_batch_size") = 1,
645 646
           py::arg("ipu_enable_pipelining") = false,
           py::arg("ipu_batches_per_step") = 1)
W
Wilber 已提交
647 648 649 650
      .def("set_ipu_config",
           &AnalysisConfig::SetIpuConfig,
           py::arg("ipu_enable_fp16") = false,
           py::arg("ipu_replica_num") = 1,
651 652
           py::arg("ipu_available_memory_proportion") = 1.0,
           py::arg("ipu_enable_half_partial") = false)
F
flame 已提交
653
      .def("disable_gpu", &AnalysisConfig::DisableGpu)
654 655 656 657
      .def("enable_onnxruntime", &AnalysisConfig::EnableONNXRuntime)
      .def("disable_onnxruntime", &AnalysisConfig::DisableONNXRuntime)
      .def("onnxruntime_enabled", &AnalysisConfig::use_onnxruntime)
      .def("enable_ort_optimization", &AnalysisConfig::EnableORTOptimization)
F
flame 已提交
658
      .def("use_gpu", &AnalysisConfig::use_gpu)
659
      .def("use_xpu", &AnalysisConfig::use_xpu)
W
Wilber 已提交
660
      .def("use_npu", &AnalysisConfig::use_npu)
F
flame 已提交
661
      .def("gpu_device_id", &AnalysisConfig::gpu_device_id)
662
      .def("xpu_device_id", &AnalysisConfig::xpu_device_id)
W
Wilber 已提交
663
      .def("npu_device_id", &AnalysisConfig::npu_device_id)
F
flame 已提交
664 665 666 667
      .def("memory_pool_init_size_mb",
           &AnalysisConfig::memory_pool_init_size_mb)
      .def("fraction_of_gpu_memory_for_pool",
           &AnalysisConfig::fraction_of_gpu_memory_for_pool)
W
Wilber 已提交
668 669
      .def("switch_ir_optim",
           &AnalysisConfig::SwitchIrOptim,
F
flame 已提交
670 671
           py::arg("x") = true)
      .def("ir_optim", &AnalysisConfig::ir_optim)
W
Wilber 已提交
672 673
      .def("enable_memory_optim",
           &AnalysisConfig::EnableMemoryOptim,
674
           py::arg("x") = true)
675
      .def("enable_profile", &AnalysisConfig::EnableProfile)
676
      .def("disable_glog_info", &AnalysisConfig::DisableGlogInfo)
677
      .def("glog_info_disabled", &AnalysisConfig::glog_info_disabled)
678
      .def("set_optim_cache_dir", &AnalysisConfig::SetOptimCacheDir)
W
Wilber 已提交
679 680
      .def("switch_use_feed_fetch_ops",
           &AnalysisConfig::SwitchUseFeedFetchOps,
F
flame 已提交
681 682 683 684
           py::arg("x") = true)
      .def("use_feed_fetch_ops_enabled",
           &AnalysisConfig::use_feed_fetch_ops_enabled)
      .def("switch_specify_input_names",
W
Wilber 已提交
685 686
           &AnalysisConfig::SwitchSpecifyInputNames,
           py::arg("x") = true)
F
flame 已提交
687
      .def("specify_input_name", &AnalysisConfig::specify_input_name)
W
Wilber 已提交
688 689
      .def("enable_tensorrt_engine",
           &AnalysisConfig::EnableTensorRtEngine,
690
           py::arg("workspace_size") = 1 << 30,
W
Wilber 已提交
691
           py::arg("max_batch_size") = 1,
692
           py::arg("min_subgraph_size") = 3,
N
nhzlx 已提交
693
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
W
Wilber 已提交
694 695
           py::arg("use_static") = false,
           py::arg("use_calib_mode") = true)
696
      .def("tensorrt_precision_mode", &AnalysisConfig::tensorrt_precision_mode)
697 698
      .def("set_trt_dynamic_shape_info",
           &AnalysisConfig::SetTRTDynamicShapeInfo,
699 700 701 702 703
           py::arg("min_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("max_input_shape") =
               std::map<std::string, std::vector<int>>({}),
           py::arg("optim_input_shape") =
704 705
               std::map<std::string, std::vector<int>>({}),
           py::arg("disable_trt_plugin_fp16") = false)
706 707
      .def("tensorrt_dynamic_shape_enabled",
           &AnalysisConfig::tensorrt_dynamic_shape_enabled)
708 709 710
      .def("enable_tensorrt_varseqlen", &AnalysisConfig::EnableVarseqlen)
      .def("tensorrt_varseqlen_enabled",
           &AnalysisConfig::tensorrt_varseqlen_enabled)
711 712 713 714 715 716 717 718 719 720
      .def("collect_shape_range_info", &AnalysisConfig::CollectShapeRangeInfo)
      .def("shape_range_info_path", &AnalysisConfig::shape_range_info_path)
      .def("shape_range_info_collected",
           &AnalysisConfig::shape_range_info_collected)
      .def("enable_tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::EnableTunedTensorRtDynamicShape)
      .def("tuned_tensorrt_dynamic_shape",
           &AnalysisConfig::tuned_tensorrt_dynamic_shape)
      .def("trt_allow_build_at_runtime",
           &AnalysisConfig::trt_allow_build_at_runtime)
721
      .def("exp_disable_tensorrt_ops", &AnalysisConfig::Exp_DisableTensorRtOPs)
W
Wilber 已提交
722 723
      .def("enable_tensorrt_dla",
           &AnalysisConfig::EnableTensorRtDLA,
724 725
           py::arg("dla_core") = 0)
      .def("tensorrt_dla_enabled", &AnalysisConfig::tensorrt_dla_enabled)
726 727 728 729
      .def("enable_tensorrt_inspector",
           &AnalysisConfig::EnableTensorRtInspector)
      .def("tensorrt_inspector_enabled",
           &AnalysisConfig::tensorrt_inspector_enabled)
F
flame 已提交
730
      .def("tensorrt_engine_enabled", &AnalysisConfig::tensorrt_engine_enabled)
W
Wilber 已提交
731 732
      .def("enable_dlnne",
           &AnalysisConfig::EnableDlnne,
D
denglin-github 已提交
733
           py::arg("min_subgraph_size") = 3)
W
Wilber 已提交
734 735
      .def("enable_lite_engine",
           &AnalysisConfig::EnableLiteEngine,
736
           py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32,
W
Wilber 已提交
737
           py::arg("zero_copy") = false,
738 739 740
           py::arg("passes_filter") = std::vector<std::string>(),
           py::arg("ops_filter") = std::vector<std::string>())
      .def("lite_engine_enabled", &AnalysisConfig::lite_engine_enabled)
W
Wilber 已提交
741 742
      .def("switch_ir_debug",
           &AnalysisConfig::SwitchIrDebug,
F
flame 已提交
743 744 745 746 747 748 749 750
           py::arg("x") = true)
      .def("enable_mkldnn", &AnalysisConfig::EnableMKLDNN)
      .def("mkldnn_enabled", &AnalysisConfig::mkldnn_enabled)
      .def("set_cpu_math_library_num_threads",
           &AnalysisConfig::SetCpuMathLibraryNumThreads)
      .def("cpu_math_library_num_threads",
           &AnalysisConfig::cpu_math_library_num_threads)
      .def("to_native_config", &AnalysisConfig::ToNativeConfig)
751
      .def("enable_quantizer", &AnalysisConfig::EnableMkldnnQuantizer)
752
      .def("enable_mkldnn_bfloat16", &AnalysisConfig::EnableMkldnnBfloat16)
753
#ifdef PADDLE_WITH_MKLDNN
W
Wilber 已提交
754 755
      .def("quantizer_config",
           &AnalysisConfig::mkldnn_quantizer_config,
756
           py::return_value_policy::reference)
W
Wilber 已提交
757 758
      .def("set_mkldnn_cache_capacity",
           &AnalysisConfig::SetMkldnnCacheCapacity,
759
           py::arg("capacity") = 0)
760
      .def("set_bfloat16_op", &AnalysisConfig::SetBfloat16Op)
W
Wilber 已提交
761 762
      .def("enable_mkldnn_int8",
           &AnalysisConfig::EnableMkldnnInt8,
B
baoachun 已提交
763 764 765
           py::arg("mkldnn_int8_enabled_op_types") =
               std::unordered_set<std::string>({}))
      .def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
766
#endif
F
flame 已提交
767 768 769
      .def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
      .def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
      .def("model_from_memory", &AnalysisConfig::model_from_memory)
770 771 772 773
      .def("delete_pass",
           [](AnalysisConfig &self, const std::string &pass) {
             self.pass_builder()->DeletePass(pass);
           })
774 775 776 777 778 779
      .def(
          "pass_builder",
          [](AnalysisConfig &self) {
            return dynamic_cast<PaddlePassBuilder *>(self.pass_builder());
          },
          py::return_value_policy::reference)
780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797
      .def("nnadapter", &AnalysisConfig::NNAdapter)
      .def("set_dist_config", &AnalysisConfig::SetDistConfig)
      .def("dist_config", &AnalysisConfig::dist_config);

  py::class_<DistConfig>(*m, "DistConfig")
      .def(py::init<>())
      .def("set_carrier_id", &DistConfig::SetCarrierId)
      .def("set_comm_init_config", &DistConfig::SetCommInitConfig)
      .def("set_endpoints", &DistConfig::SetEndpoints)
      .def("set_ranks", &DistConfig::SetRanks)
      .def("enable_dist_model", &DistConfig::EnableDistModel)
      .def("carrier_id", &DistConfig::carrier_id)
      .def("current_endpoint", &DistConfig::current_endpoint)
      .def("trainer_endpoints", &DistConfig::trainer_endpoints)
      .def("nranks", &DistConfig::nranks)
      .def("rank", &DistConfig::rank)
      .def("comm_init_config", &DistConfig::comm_init_config)
      .def("use_dist_model", &DistConfig::use_dist_model);
798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815
}

// Exposes LiteNNAdapterConfig to Python as "LiteNNAdapterConfig" on module
// `m`. Only member functions are registered; no py::init is bound in this
// function.
void BindLiteNNAdapterConfig(py::module *m) {
  using NNAdapter = LiteNNAdapterConfig;
  py::class_<NNAdapter> cls(*m, "LiteNNAdapterConfig");

  // Device / context configuration.
  cls.def("set_device_names", &NNAdapter::SetDeviceNames);
  cls.def("set_context_properties", &NNAdapter::SetContextProperties);

  // Model cache configuration.
  cls.def("set_model_cache_dir", &NNAdapter::SetModelCacheDir);
  cls.def("set_model_cache_buffers", &NNAdapter::SetModelCacheBuffers);

  // Subgraph partition configuration (by path or by in-memory buffer).
  cls.def("set_subgraph_partition_config_path",
          &NNAdapter::SetSubgraphPartitionConfigPath);
  cls.def("set_subgraph_partition_config_buffer",
          &NNAdapter::SetSubgraphPartitionConfigBuffer);

  // On/off switches.
  cls.def("enable", &NNAdapter::Enable);
  cls.def("disable", &NNAdapter::Disable);
}

818 819 820 821 822 823 824 825 826 827 828 829 830 831 832
#ifdef PADDLE_WITH_MKLDNN
// Exposes MkldnnQuantizerConfig to Python as "MkldnnQuantizerConfig" on
// module `m`. Both a copy constructor and a default constructor are bound.
void BindMkldnnQuantizerConfig(py::module *m) {
  using Quantizer = MkldnnQuantizerConfig;
  py::class_<Quantizer> cls(*m, "MkldnnQuantizerConfig");

  cls.def(py::init<const Quantizer &>());
  cls.def(py::init<>());

  // "set_quant_data" copies the incoming tensors into a freshly allocated
  // shared vector and hands that vector to SetWarmupData.
  cls.def("set_quant_data",
          [](Quantizer &self, const std::vector<PaddleTensor> &data) {
            auto shared_tensors =
                std::make_shared<std::vector<PaddleTensor>>(data);
            self.SetWarmupData(shared_tensors);
          });

  cls.def("set_quant_batch_size", &Quantizer::SetWarmupBatchSize);
  cls.def("set_enabled_op_types", &Quantizer::SetEnabledOpTypes);
}
#endif

F
flame 已提交
837 838 839 840 841 842 843 844 845 846 847 848 849
// Exposes AnalysisPredictor to Python as "AnalysisPredictor" on module `m`,
// declared to pybind11 as a subclass of PaddlePredictor.
void BindAnalysisPredictor(py::module *m) {
  using Predictor = AnalysisPredictor;
  py::class_<Predictor, PaddlePredictor> cls(*m, "AnalysisPredictor");

  // Construction / initialisation.
  cls.def(py::init<const AnalysisConfig &>());
  cls.def("init", &Predictor::Init);

  // "run" returns the output tensors by value instead of filling an
  // out-parameter.
  // NOTE(review): whatever Run() returns is discarded here, so a failed run
  // is not distinguishable from Python — confirm this is intended.
  cls.def("run",
          [](Predictor &self, const std::vector<PaddleTensor> &inputs) {
            std::vector<PaddleTensor> results;
            self.Run(inputs, &results);
            return results;
          });

  // Input / output access.
  cls.def("get_input_tensor", &Predictor::GetInputTensor);
  cls.def("get_output_tensor", &Predictor::GetOutputTensor);
  cls.def("get_input_names", &Predictor::GetInputNames);
  cls.def("get_output_names", &Predictor::GetOutputNames);
  cls.def("get_input_tensor_shape", &Predictor::GetInputTensorShape);

  // Execution and memory management.
  cls.def("zero_copy_run", &Predictor::ZeroCopyRun);
  cls.def("clear_intermediate_tensor", &Predictor::ClearIntermediateTensor);
  cls.def("try_shrink_memory", &Predictor::TryShrinkMemory);

  // Preparation / optimisation steps.
  cls.def("create_feed_fetch_var", &Predictor::CreateFeedFetchVar);
  cls.def("prepare_feed_fetch", &Predictor::PrepareFeedFetch);
  cls.def("prepare_argument", &Predictor::PrepareArgument);
  cls.def("optimize_inference_program", &Predictor::OptimizeInferenceProgram);

  // Accessors bound with the `reference` policy: Python does not take
  // ownership of the returned object.
  cls.def("analysis_argument",
          &Predictor::analysis_argument,
          py::return_value_policy::reference);
  cls.def("clone", &Predictor::Clone);
  cls.def("scope", &Predictor::scope, py::return_value_policy::reference);
  cls.def("program", &Predictor::program, py::return_value_policy::reference);

  // Serialisation and quantisation.
  cls.def("get_serialized_program", &Predictor::GetSerializedProgram);
  cls.def("mkldnn_quantize", &Predictor::MkldnnQuantize);
  cls.def("SaveOptimModel", &Predictor::SaveOptimModel, py::arg("dir"));
}
877

W
Wilber 已提交
878 879 880 881 882 883 884
// Exposes paddle_infer::Predictor to Python as "PaddleInferPredictor" on
// module `m`.
void BindPaddleInferPredictor(py::module *m) {
  using Predictor = paddle_infer::Predictor;
  py::class_<Predictor> cls(*m, "PaddleInferPredictor");

  cls.def(py::init<const paddle_infer::Config &>());

  // Input / output discovery and handles.
  cls.def("get_input_names", &Predictor::GetInputNames);
  cls.def("get_output_names", &Predictor::GetOutputNames);
  cls.def("get_input_handle", &Predictor::GetInputHandle);
  cls.def("get_output_handle", &Predictor::GetOutputHandle);

  // "run" calls Run() and returns nothing to Python. The GIL is released
  // for the duration of the call only in PADDLE_WITH_ASCEND_CL builds.
  cls.def("run", [](Predictor &self) {
#ifdef PADDLE_WITH_ASCEND_CL
    pybind11::gil_scoped_release release;
#endif
    self.Run();
  });

  cls.def("clone", &Predictor::Clone);
  cls.def("try_shrink_memory", &Predictor::TryShrinkMemory);
  cls.def("clear_intermediate_tensor", &Predictor::ClearIntermediateTensor);
}

898 899
// Exposes ZeroCopyTensor to Python as "ZeroCopyTensor" on module `m`.
// Overloads sharing a Python name are registered in the same order as
// before, since pybind11 tries overloads in registration order.
void BindZeroCopyTensor(py::module *m) {
  py::class_<ZeroCopyTensor> cls(*m, "ZeroCopyTensor");

  // "reshape": one overload takes a vector of ints, the other a single
  // size value.
  // NOTE(review): the second overload is named through paddle_infer::Tensor
  // rather than ZeroCopyTensor — presumably inherited; confirm.
  cls.def("reshape",
          py::overload_cast<const std::vector<int> &>(
              &ZeroCopyTensor::Reshape));
  cls.def("reshape",
          py::overload_cast<const std::size_t &>(
              &paddle_infer::Tensor::ReshapeStrings));

  // "copy_from_cpu": one overload per supported element type, plus the
  // string-tensor variant.
  cls.def("copy_from_cpu", &ZeroCopyTensorCreate<int32_t>);
  cls.def("copy_from_cpu", &ZeroCopyTensorCreate<int64_t>);
  cls.def("copy_from_cpu", &ZeroCopyTensorCreate<float>);
  cls.def("copy_from_cpu", &ZeroCopyTensorCreate<paddle_infer::float16>);
  cls.def("copy_from_cpu", &ZeroCopyStringTensorCreate);

  cls.def("copy_to_cpu", &ZeroCopyTensorToNumpy);

  // Metadata accessors.
  cls.def("shape", &ZeroCopyTensor::shape);
  cls.def("set_lod", &ZeroCopyTensor::SetLoD);
  cls.def("lod", &ZeroCopyTensor::lod);
  cls.def("type", &ZeroCopyTensor::type);
}

W
Wilber 已提交
918 919
// Exposes paddle_infer::Tensor to Python as "PaddleInferTensor" on module
// `m`. Overloads sharing a Python name are registered in the same order as
// before, since pybind11 tries overloads in registration order.
void BindPaddleInferTensor(py::module *m) {
  using Tensor = paddle_infer::Tensor;
  py::class_<Tensor> cls(*m, "PaddleInferTensor");

  // "reshape": one overload takes a vector of ints, the other a single
  // size value (bound to ReshapeStrings).
  cls.def("reshape",
          py::overload_cast<const std::vector<int> &>(&Tensor::Reshape));
  cls.def("reshape",
          py::overload_cast<const std::size_t &>(&Tensor::ReshapeStrings));

  // "copy_from_cpu_bind": one overload per supported element type, plus the
  // string-tensor variant.
  cls.def("copy_from_cpu_bind", &PaddleInferTensorCreate<int32_t>);
  cls.def("copy_from_cpu_bind", &PaddleInferTensorCreate<int64_t>);
  cls.def("copy_from_cpu_bind", &PaddleInferTensorCreate<float>);
  cls.def("copy_from_cpu_bind",
          &PaddleInferTensorCreate<paddle_infer::float16>);
  cls.def("copy_from_cpu_bind", &PaddleInferStringTensorCreate);

  cls.def("share_external_data_bind", &PaddleInferShareExternalData);
  cls.def("copy_to_cpu", &PaddleInferTensorToNumpy);

  // Metadata accessors.
  cls.def("shape", &Tensor::shape);
  cls.def("set_lod", &Tensor::SetLoD);
  cls.def("lod", &Tensor::lod);
  cls.def("type", &Tensor::type);
}

// Exposes paddle_infer::services::PredictorPool to Python as "PredictorPool"
// on module `m`. The pool is constructed from a paddle_infer::Config and a
// size_t.
void BindPredictorPool(py::module *m) {
  py::class_<paddle_infer::services::PredictorPool>(*m, "PredictorPool")
      .def(py::init<const paddle_infer::Config &, size_t>())
      // The Python name "retrive" (sic) is public API and is kept as is.
      //
      // reference_internal behaves like `reference` (Python does not take
      // ownership of the returned predictor) and additionally keeps the pool
      // object alive for as long as the returned Python object exists. With
      // plain `reference`, a predictor handle held in Python could outlive
      // the pool it came from.
      // NOTE(review): presumably the pool owns the predictors it hands out —
      // confirm against PredictorPool's implementation.
      .def("retrive",
           &paddle_infer::services::PredictorPool::Retrive,
           py::return_value_policy::reference_internal);
}

948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966
// Exposes the pass-builder class family to Python on module `m`:
//   PaddlePassBuilder
//   PassStrategy     (declared as a subclass of PaddlePassBuilder)
//   CpuPassStrategy  (declared as a subclass of PassStrategy)
//   GpuPassStrategy  (declared as a subclass of PassStrategy)
void BindPaddlePassBuilder(py::module *m) {
  py::class_<PaddlePassBuilder>(*m, "PaddlePassBuilder")
      .def(py::init<const std::vector<std::string> &>())
      // "set_passes" replaces the current pass list: it clears the existing
      // passes and appends the given ones in order.
      .def("set_passes",
           [](PaddlePassBuilder &self, const std::vector<std::string> &passes) {
             self.ClearPasses();
             // Iterate by const reference: the elements are only read, so
             // copying each string first is unnecessary.
             for (const auto &pass : passes) {
               self.AppendPass(pass);
             }
           })
      .def("append_pass", &PaddlePassBuilder::AppendPass)
      .def("insert_pass", &PaddlePassBuilder::InsertPass)
      .def("delete_pass",
           [](PaddlePassBuilder &self, const std::string &pass_type) {
             self.DeletePass(pass_type);
           })
      .def("append_analysis_pass", &PaddlePassBuilder::AppendAnalysisPass)
      .def("turn_on_debug", &PaddlePassBuilder::TurnOnDebug)
      .def("debug_string", &PaddlePassBuilder::DebugString)
      .def("all_passes",
           &PaddlePassBuilder::AllPasses,
           py::return_value_policy::reference)
      .def("analysis_passes", &PaddlePassBuilder::AnalysisPasses);

  py::class_<PassStrategy, PaddlePassBuilder>(*m, "PassStrategy")
      .def(py::init<const std::vector<std::string> &>())
      .def("enable_cudnn", &PassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &PassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &PassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &PassStrategy::EnableMkldnnBfloat16)
      .def("use_gpu", &PassStrategy::use_gpu);

  py::class_<CpuPassStrategy, PassStrategy>(*m, "CpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const CpuPassStrategy &>())
      .def("enable_cudnn", &CpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &CpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &CpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &CpuPassStrategy::EnableMkldnnBfloat16);

  py::class_<GpuPassStrategy, PassStrategy>(*m, "GpuPassStrategy")
      .def(py::init<>())
      .def(py::init<const GpuPassStrategy &>())
      .def("enable_cudnn", &GpuPassStrategy::EnableCUDNN)
      .def("enable_mkldnn", &GpuPassStrategy::EnableMKLDNN)
      .def("enable_mkldnn_quantizer", &GpuPassStrategy::EnableMkldnnQuantizer)
      .def("enable_mkldnn_bfloat16", &GpuPassStrategy::EnableMkldnnBfloat16);
}
996
}  // namespace
F
flame 已提交
997 998
}  // namespace pybind
}  // namespace paddle