/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/protobuf.h"

#include <mutex>  // for call_once
#include <unordered_map>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/cond_op.h"
#include "paddle/fluid/operators/net_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif

// Disable automatic conversion of LoDTensorArray to a Python list; it is
// exposed below as an opaque, mutable-in-place type.
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

namespace paddle {
namespace pybind {
bool IsCompiledWithCUDA() {
#ifndef PADDLE_WITH_CUDA
  return false;
#else
  return true;
#endif
}

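// Module entry point: PYBIND11_PLUGIN generates the init function for the
// "core" extension module; every binding below is registered on `m`.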
PYBIND11_PLUGIN(core) {
  py::module m("core", "C++ core of PaddlePaddle");

  // Import paddle::framework inside this function only; a function-local
  // using-directive does not pollute the global namespace.
  using namespace paddle::framework;  // NOLINT

  BindException(m);

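  // Tensor opts into the pybind11 buffer protocol, so Python consumers such
  // as numpy can construct array views over tensor memory; CastToPyBuffer
  // builds the py::buffer_info describing the underlying storage.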
  py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def("get_dims",
           [](const Tensor &self) { return vectorize(self.dims()); })
      .def("set_dims",
           [](Tensor &self, const std::vector<int64_t> &dim) {
             self.Resize(make_ddim(dim));
           })
      .def("set_layout",
           [](Tensor &self, const std::string &layout) {
             self.set_layout(StringToDataLayout(layout));
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("set", PyCPUTensorSetFromArray<float>)
      .def("set", PyCPUTensorSetFromArray<int>)
      .def("set", PyCPUTensorSetFromArray<double>)
      .def("set", PyCPUTensorSetFromArray<int64_t>)
      .def("set", PyCPUTensorSetFromArray<bool>)
#ifdef PADDLE_WITH_CUDA
      .def("set", PyCUDATensorSetFromArray<float>)
      .def("set", PyCUDATensorSetFromArray<int>)
      .def("set", PyCUDATensorSetFromArray<double>)
      .def("set", PyCUDATensorSetFromArray<int64_t>)
      .def("set", PyCUDATensorSetFromArray<bool>)
#endif
      .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
      .def("set_float_element", TensorSetElement<float>)
      .def("get_float_element", TensorGetElement<float>)
      .def("set_double_element", TensorSetElement<double>)
      .def("get_double_element", TensorGetElement<double>)
      .def("dtype", [](Tensor &self) { return ToDataType(self.type()); });

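  // LoDTensor extends Tensor with level-of-detail (LoD) information that
  // marks variable-length sequence boundaries; the bindings convert the C++
  // LoD type to and from nested Python lists of offsets.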
  py::class_<LoDTensor, Tensor>(m, "LoDTensor")
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def(
          "__init__",
          [](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
            LoD new_lod;
            new_lod.reserve(lod.size());
            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
            new (&instance) LoDTensor(new_lod);
          })
      .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
      .def("set_lod",
           [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             self.set_lod(new_lod);
           })
      .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
        auto lod = self.lod();
        std::vector<std::vector<size_t>> new_lod;
        new_lod.reserve(lod.size());
        std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
        return new_lod;
      });

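  // SelectedRows is a sparse representation: a list of row indices plus a
  // dense value tensor holding only those rows.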
  py::class_<SelectedRows>(m, "SelectedRows")
      .def("__init__",
           [](SelectedRows &instance) { new (&instance) SelectedRows(); })
      .def("__init__",
           [](SelectedRows &instance, const std::vector<int64_t> rows,
              const int64_t &height) {
             new (&instance) SelectedRows(rows, height);
           })
      .def("get_tensor",
           [](SelectedRows &self) { return self.mutable_value(); },
           py::return_value_policy::reference)
      .def("set_height", &SelectedRows::set_height)
      .def("height", &SelectedRows::height)
      .def("set_rows",
           [](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
             self.set_rows(rows);
#else
             Vector<int64_t> new_rows(rows);
             self.set_rows(new_rows);
#endif
           })
      .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
        return self.rows();
#else
        auto rows = self.rows();
        std::vector<int64_t> new_rows;
        new_rows.reserve(rows.size());
        std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
        return new_rows;
#endif
      });

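  // A Variable is a typed container holding a single value (int, float,
  // LoDTensor, SelectedRows, ...); GetMutable<T>() default-constructs the
  // value on first access.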
  py::class_<Variable>(m, "Variable", R"DOC(Variable Class.

All parameters, weights, and gradients are variables in Paddle.
)DOC")
      .def("is_int", [](const Variable &var) { return var.IsType<int>(); })
      .def("set_int",
           [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
      .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
      .def("is_float", [](const Variable &var) { return var.IsType<float>(); })
      .def("set_float",
           [](Variable &var, float val) -> void {
             *var.GetMutable<float>() = val;
           })
      .def("get_float",
           [](const Variable &var) -> float { return var.Get<float>(); })
      .def("get_tensor",
           [](Variable &self) -> LoDTensor * {
             return self.GetMutable<LoDTensor>();
           },
           py::return_value_policy::reference)
      .def("get_lod_rank_table",
           [](Variable &self) { return self.GetMutable<LoDRankTable>(); },
           py::return_value_policy::reference)
      .def("get_selected_rows",
           [](Variable &self) -> SelectedRows * {
             return self.GetMutable<SelectedRows>();
           },
           py::return_value_policy::reference)
      .def("get_lod_tensor_array",
           [](Variable &self) { return self.GetMutable<LoDTensorArray>(); },
           py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
      .def("get_communicator",
           [](Variable &self) -> platform::Communicator * {
             return self.GetMutable<platform::Communicator>();
           },
           py::return_value_policy::reference)
#endif
      .def("get_net",
           [](Variable &self) -> operators::NetOp * {
             return self.GetMutable<operators::NetOp>();
           },
           py::return_value_policy::reference);

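  // A Scope owns Variables looked up by name. Scopes form a tree: new_scope
  // creates a child scope, and drop_kids destroys all children.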
  py::class_<Scope>(m, "Scope", "")
      .def("var",
           [](Scope &self, const std::string &name) -> Variable * {
             return self.Var(name);
           },
           py::return_value_policy::reference)
      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
      .def(py::init<>())
      .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
           py::return_value_policy::reference)
      .def("drop_kids", &Scope::DropKids);

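  // Expose every registered operator's OpProto as serialized bytes so the
  // Python layer can discover the available operators and their signatures.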
  //! @note: Be careful! PyBind returns std::string as a Python unicode
  //! object, not str. If you want a str object, cast it on the Python side.
  m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
    std::vector<py::bytes> ret_values;
    for (auto &iter : OpInfoMap::Instance().map()) {
      auto &info = iter.second;
      if (info.HasOpProtoAndChecker()) {
        std::string str;
        PADDLE_ENFORCE(
            info.Proto().SerializeToString(&str),
            "Serialize OpProto Error. This could be a bug of Paddle.");
        ret_values.emplace_back(str);
      }
    }
    return ret_values;
  });
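  // Build the gradient op descriptions for a forward op via its registered
  // GradOpMaker, skipping inputs named in no_grad_set; ownership of the new
  // OpDesc pointers is released to the caller.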
  m.def(
      "get_grad_op_desc", [](const OpDesc &op_desc,
                             const std::unordered_set<std::string> &no_grad_set,
                             const std::vector<BlockDesc *> &grad_sub_block) {
        std::unordered_map<std::string, std::string> grad_to_var;
        std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
            framework::OpInfoMap::Instance()
                .Get(op_desc.Type())
                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                               grad_sub_block);
        std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
        std::transform(grad_op_descs.begin(), grad_op_descs.end(),
                       grad_op_desc_ptrs.begin(),
                       [](std::unique_ptr<OpDesc> &p) { return p.release(); });
        return std::make_pair(grad_op_desc_ptrs, grad_to_var);
      });
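  // Prune a program down to the ops required by the given targets; each
  // target is addressed by a (block index, op index) pair.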
  m.def("prune", [](const ProgramDesc &origin,
                    const std::vector<std::array<size_t, 2>> &targets) {
    ProgramDesc prog_with_targets(origin);
    for (const auto &t : targets) {
      prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
    }
    proto::ProgramDesc pruned_desc;
    Prune(*prog_with_targets.Proto(), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });
  m.def("inference_optimize", [](ProgramDesc &origin) {
    proto::ProgramDesc pruned_desc;
    InferenceOptimize(*(origin.Proto()), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });
  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
  m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
      .def("empty", []() { return kEmptyVarName; })
      .def("temp", []() { return kTempVarName; });
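  // DeviceContext::create is overloaded on the place type; requesting a CUDA
  // context from a build without PADDLE_WITH_CUDA throws at runtime.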
  // clang-format off
  py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
      .def_static("create",
                  [](paddle::platform::CPUPlace& place)
                      -> paddle::platform::DeviceContext* {
                    return new paddle::platform::CPUDeviceContext();
                  })
      .def_static("create",
                  [](paddle::platform::CUDAPlace& place)
                      -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                    PADDLE_THROW("CUDAPlace is not supported in CPU-only builds.");
#else
                    return new paddle::platform::CUDADeviceContext(place);
#endif
                  });
// clang-format on

#ifdef PADDLE_WITH_CUDA
  py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
  py::class_<platform::CUDAPlace>(m, "CUDAPlace")
      .def(py::init<int>())
      .def("__str__", string::to_string<const platform::CUDAPlace &>);

  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
      .def(py::init<>())
      .def("__str__", string::to_string<const platform::CPUPlace &>);

  py::class_<platform::Place>(m, "Place")
      .def(py::init<>())
      .def("set_place",
           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
             self = cpu_place;
           })
      .def("set_place",
           [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
             self = gpu_place;
           });

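  // Operators are created from a serialized proto::OpDesc rather than from
  // Python arguments; "backward" derives the gradient network of a forward
  // op, excluding the variables listed in no_grad_vars.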
  py::class_<OperatorBase>(m, "Operator")
      .def_static("create",
                  [](py::bytes protobin) {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    return OpRegistry::CreateOp(desc);
                  })
      .def("backward",
           [](const OperatorBase &forwardOp,
              const std::unordered_set<std::string> &no_grad_vars) {
             return Backward(forwardOp, no_grad_vars).release();
           })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CPUPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPlace &place) { self.Run(scope, place); })
      .def("type",
           [](const OperatorBase &op) -> std::string { return op.Type(); })
      .def("outputs",
           [](const OperatorBase &op)
               -> std::map<std::string, std::vector<std::string>> {
                 return op.Outputs();
               })
      .def("output_vars",
           [](const OperatorBase &op) { return op.OutputVars(true); })
      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
      .def("__str__", &OperatorBase::DebugString)
      .def("no_intermediate_outputs",
           [](const OperatorBase &op) { return op.OutputVars(false); })
      .def("support_gpu", &OperatorBase::SupportGPU);

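  // A Net is itself an operator that runs a sequence of sub-operators;
  // complete_add_op finalizes the net's inputs and outputs once all ops have
  // been appended.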
  py::class_<operators::NetOp, OperatorBase>(m, "Net")
      .def_static("create",
                  []() -> operators::NetOp * {
                    auto *retv = new operators::NetOp;
                    retv->SetType("plain_net");
                    return retv;
                  })
      .def("append_op", [](operators::NetOp &self,
                           const OperatorBase &op) { self.AppendOp(op); })
      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
        self->CompleteAddOp();
      });

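  // CondOp executes one of two sub-networks depending on a condition; the
  // true and false branches are installed as cloned NetOps.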
  // cond_op
  py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
      .def_static("create",
                  [](py::bytes protobin) -> operators::CondOp * {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    auto cond_op = OpRegistry::CreateOp(desc);
                    return static_cast<operators::CondOp *>(cond_op.release());
                  })
      .def("set_truenet",
           [](operators::CondOp &self, const operators::NetOp &net) -> void {
             self.set_truenet(net.Clone());
           })
      .def("set_falsenet",
           [](operators::CondOp &self, const operators::NetOp &net) -> void {
             self.set_falsenet(net.Clone());
           });

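  // The explicit member-function-pointer cast selects the Executor::Run
  // overload with signature (const ProgramDesc &, Scope *, int, bool, bool),
  // where the int picks the block to execute.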
  py::class_<framework::Executor>(m, "Executor")
      .def(py::init<const platform::Place &>())
      .def("run",
           (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
               Executor::Run);

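  // Process-wide initialization (gflags, glog, devices) and the feed/fetch
  // helpers that move data between Python and variables in a Scope.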
  m.def("init_gflags", framework::InitGflags);
  m.def("init_glog", framework::InitGLOG);
  m.def("init_devices", &framework::InitDevices);

  m.def("is_compiled_with_cuda", IsCompiledWithCUDA);

  m.def("set_feed_variable", framework::SetFeedVariable);
  m.def("get_fetch_variable", framework::GetFetchVariable);

  BindProgramDesc(m);
  BindBlockDesc(m);
  BindVarDsec(m);
  BindOpDesc(m);
  BindConstValue(m);

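  // A LoDRankTable stores (index, length) items describing variable-length
  // sequences; items() copies them out as a list of Python pairs.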
  py::class_<framework::LoDRankTable>(m, "LodRankTable")
      .def("items", [](framework::LoDRankTable &table) {
        std::vector<std::pair<size_t, size_t>> res;
        for (auto &item : table.items()) {
          res.push_back({item.index, item.length});
        }
        return res;
      });

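  // LoDTensorArray was declared opaque above, so these bindings mutate the
  // underlying C++ vector in place; __setitem__ and append share data with
  // the source tensor rather than copying it.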
  py::class_<LoDTensorArray>(m, "LoDTensorArray")
      .def("__getitem__",
           [](LoDTensorArray &self, size_t i) { return &self.at(i); },
           py::return_value_policy::reference)
      .def("__len__", [](LoDTensorArray &self) { return self.size(); })
      .def("__setitem__",
           [](LoDTensorArray &self, size_t i, const LoDTensor &t) {
             PADDLE_ENFORCE_LT(i, self.size());
             self[i].ShareDataWith(t);
             self[i].set_lod(t.lod());
           })
      .def("append", [](LoDTensorArray &self, const LoDTensor &t) {
        self.emplace_back();
        self.back().ShareDataWith(t);
        self.back().set_lod(t.lod());
      });

  m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
  m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

  m.def("nvprof_init", platform::CudaProfilerInit);
  m.def("nvprof_start", platform::CudaProfilerStart);
  m.def("nvprof_stop", platform::CudaProfilerStop);
#endif

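  // Profiler controls: ProfilerState selects which device timelines are
  // recorded, and EventSortingKey chooses how the collected events are
  // sorted when results are reported.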
  py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
      .value("kDisabled", platform::ProfilerState::kDisabled)
      .value("kCPU", platform::ProfilerState::kCPU)
      .value("kCUDA", platform::ProfilerState::kCUDA)
      .value("kAll", platform::ProfilerState::kAll)
      .export_values();

  py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
      .value("kDefault", platform::EventSortingKey::kDefault)
      .value("kCalls", platform::EventSortingKey::kCalls)
      .value("kTotal", platform::EventSortingKey::kTotal)
      .value("kMin", platform::EventSortingKey::kMin)
      .value("kMax", platform::EventSortingKey::kMax)
      .value("kAve", platform::EventSortingKey::kAve)
      .export_values();

  m.def("enable_profiler", platform::EnableProfiler);
  m.def("disable_profiler", platform::DisableProfiler);
  m.def("reset_profiler", platform::ResetProfiler);
  return m.ptr();
}
}  // namespace pybind
}  // namespace paddle