/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/pybind/protobuf.h"
Q
qijun 已提交
16

Q
QI JUN 已提交
17
#include <mutex>  // for call_once
18
#include <unordered_map>
Y
Yi Wang 已提交
19
#include "paddle/fluid/framework/backward.h"
20
#include "paddle/fluid/framework/channel.h"
Y
Yi Wang 已提交
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/cond_op.h"
#include "paddle/fluid/operators/net_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/pybind/tensor_py.h"
39
#include "paddle/fluid/string/to_string.h"
40

D
Dong Zhihong 已提交
41
#ifdef PADDLE_WITH_CUDA
Y
Yi Wang 已提交
42 43 44
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
D
Dong Zhihong 已提交
45 46
#endif

Q
Qiao Longfei 已提交
47 48 49
// Disable pybind11's automatic conversion of LoDTensorArray to a Python list.
// The type is instead exposed through the explicit "LoDTensorArray" class
// binding registered in PYBIND11_PLUGIN(core) below.
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

50
namespace paddle {
51
namespace pybind {
52 53 54
// Returns the next value of a process-wide counter associated with `prefix`.
//
// Each distinct prefix has its own counter starting at 0; every call returns
// the current value and then increments it, so successive calls with the same
// prefix yield 0, 1, 2, ...
//
// The lookup-or-insert on the shared map and the increment are performed
// under a single mutex. An atomic counter alone would not be sufficient:
// inserting a new prefix can rehash the map while another thread is reading
// it.
static size_t UniqueIntegerGenerator(const std::string &prefix) {
  static std::mutex generators_mutex;
  static std::unordered_map<std::string, size_t> generators;
  std::lock_guard<std::mutex> guard(generators_mutex);
  // operator[] value-initialises a missing counter to 0; post-increment
  // returns the value before the bump.
  return generators[prefix]++;
}

57
// Reports whether this translation unit was built with PADDLE_WITH_CUDA
// defined. The answer is fixed at compile time.
bool IsCompiledWithCUDA() {
#ifdef PADDLE_WITH_CUDA
  return true;
#else
  return false;
#endif
}

65 66
// Entry point of the `core` Python extension module.
//
// Registers every class, enum and free function that the module exposes and
// returns the underlying module object to the Python runtime.
PYBIND11_PLUGIN(core) {
  py::module m("core", "C++ core of PaddlePaddle");

  // using framework in this function. Since it is inside a function, it will
  // not cause namespace pollution.
  using namespace paddle::framework;  // NOLINT

  BindException(m);

  // Tensor: exposed with the Python buffer protocol. Allocation helpers are
  // overloaded on the place type (CPUPlace / CUDAPlace).
  py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def("get_dims",
           [](const Tensor &self) { return vectorize(self.dims()); })
      .def("set_dims",
           [](Tensor &self, const std::vector<int64_t> &dim) {
             self.Resize(make_ddim(dim));
           })
      .def("set_layout",
           [](Tensor &self, const std::string &layout) {
             self.set_layout(StringToDataLayout(layout));
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("set", PyCPUTensorSetFromArray<float>)
      .def("set", PyCPUTensorSetFromArray<int>)
      .def("set", PyCPUTensorSetFromArray<double>)
      .def("set", PyCPUTensorSetFromArray<int64_t>)
      .def("set", PyCPUTensorSetFromArray<bool>)
#ifdef PADDLE_WITH_CUDA
      .def("set", PyCUDATensorSetFromArray<float>)
      .def("set", PyCUDATensorSetFromArray<int>)
      .def("set", PyCUDATensorSetFromArray<double>)
      .def("set", PyCUDATensorSetFromArray<int64_t>)
      .def("set", PyCUDATensorSetFromArray<bool>)
#endif
      .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
      .def("set_float_element", TensorSetElement<float>)
      .def("get_float_element", TensorGetElement<float>)
      .def("set_double_element", TensorSetElement<double>)
      .def("get_double_element", TensorGetElement<double>)
      .def("dtype", [](Tensor &self) { return ToDataType(self.type()); });

  // LoDTensor: a Tensor subclass that additionally carries LoD information.
  // The LoD crosses the Python boundary as a nested std::vector<size_t> and
  // is copied element-wise into / out of the framework's LoD type.
  py::class_<LoDTensor, Tensor>(m, "LoDTensor")
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def(
          "__init__",
          [](LoDTensor &instance, const std::vector<std::vector<size_t>> &lod) {
            LoD new_lod;
            new_lod.reserve(lod.size());
            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
            new (&instance) LoDTensor(new_lod);
          })
      .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
      .def("set_lod",
           [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             self.set_lod(new_lod);
           })
      .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
        auto lod = self.lod();
        std::vector<std::vector<size_t>> new_lod;
        new_lod.reserve(lod.size());
        std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
        return new_lod;
      });

  // SelectedRows: a list of row indices plus a height and a value tensor.
  py::class_<SelectedRows>(m, "SelectedRows")
      .def("__init__",
           [](SelectedRows &instance) { new (&instance) SelectedRows(); })
      .def("__init__",
           // `rows` is taken by const reference to avoid an extra copy.
           [](SelectedRows &instance, const std::vector<int64_t> &rows,
              const int64_t &height) {
             new (&instance) SelectedRows(rows, height);
           })
      .def("get_tensor",
           [](SelectedRows &self) { return self.mutable_value(); },
           py::return_value_policy::reference)
      .def("set_height", &SelectedRows::set_height)
      .def("height", &SelectedRows::height)
      .def("set_rows",
           [](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
             self.set_rows(rows);
#else
             // CUDA builds go through the framework Vector type.
             Vector<int64_t> new_rows(rows);
             self.set_rows(new_rows);
#endif
           })
      .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
        return self.rows();
#else
        // Copy the rows into a std::vector before returning them to Python.
        auto rows = self.rows();
        std::vector<int64_t> new_rows;
        new_rows.reserve(rows.size());
        std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
        return new_rows;
#endif
      });

  // Variable: typed accessors. Every `get_*` accessor that returns a pointer
  // uses the `reference` policy, i.e. Python does not take ownership of the
  // returned object.
  py::class_<Variable>(m, "Variable", R"DOC(Variable Class.

All parameter, weight, gradient are variables in Paddle.
)DOC")
      .def("is_int", [](const Variable &var) { return var.IsType<int>(); })
      .def("set_int",
           [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
      .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
      .def("is_float", [](const Variable &var) { return var.IsType<float>(); })
      .def("set_float",
           [](Variable &var, float val) -> void {
             *var.GetMutable<float>() = val;
           })
      .def("get_float",
           [](const Variable &var) -> float { return var.Get<float>(); })
      .def("get_tensor",
           [](Variable &self) -> LoDTensor * {
             return self.GetMutable<LoDTensor>();
           },
           py::return_value_policy::reference)
      .def("get_lod_rank_table",
           [](Variable &self) { return self.GetMutable<LoDRankTable>(); },
           py::return_value_policy::reference)
      .def("get_selected_rows",
           [](Variable &self) -> SelectedRows * {
             return self.GetMutable<SelectedRows>();
           },
           py::return_value_policy::reference)
      .def("get_lod_tensor_array",
           [](Variable &self) { return self.GetMutable<LoDTensorArray>(); },
           py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
      .def("get_communicator",
           [](Variable &self) -> platform::Communicator * {
             return self.GetMutable<platform::Communicator>();
           },
           py::return_value_policy::reference)
#endif
      .def("get_net",
           [](Variable &self) -> operators::NetOp * {
             return self.GetMutable<operators::NetOp>();
           },
           py::return_value_policy::reference);

  // Scope: variable lookup / creation and child-scope management. Returned
  // Variable and Scope pointers use the `reference` policy.
  py::class_<Scope>(m, "Scope", "")
      .def("var",
           [](Scope &self, const std::string &name) -> Variable * {
             return self.Var(name);
           },
           py::return_value_policy::reference)
      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
      .def(py::init<>())
      .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
           py::return_value_policy::reference)
      .def("drop_kids", &Scope::DropKids);

  //! @note: Be careful! pybind11 returns std::string as a unicode object, not
  //! a Python str. If you want a str object, cast the result in Python.
  //! The serialized protos below are therefore returned as py::bytes.
  m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
    std::vector<py::bytes> ret_values;
    for (auto &iter : OpInfoMap::Instance().map()) {
      auto &info = iter.second;
      // Only operators that registered a proto and checker are reported.
      if (info.HasOpProtoAndChecker()) {
        std::string str;
        PADDLE_ENFORCE(
            info.Proto().SerializeToString(&str),
            "Serialize OpProto Error. This could be a bug of Paddle.");
        ret_values.emplace_back(str);
      }
    }
    return ret_values;
  });

  // Runs the grad-op maker registered for `op_desc`'s type and returns a pair
  // of (gradient OpDesc pointers, gradient-name -> variable-name map).
  m.def(
      "get_grad_op_desc", [](const OpDesc &op_desc,
                             const std::unordered_set<std::string> &no_grad_set,
                             const std::vector<BlockDesc *> &grad_sub_block) {
        std::unordered_map<std::string, std::string> grad_to_var;
        std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
            framework::OpInfoMap::Instance()
                .Get(op_desc.Type())
                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                               grad_sub_block);
        // Release the unique_ptrs so that raw pointers can be handed to
        // pybind11 in the returned pair.
        std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
        std::transform(grad_op_descs.begin(), grad_op_descs.end(),
                       grad_op_desc_ptrs.begin(),
                       [](std::unique_ptr<OpDesc> &p) { return p.release(); });
        return std::make_pair(grad_op_desc_ptrs, grad_to_var);
      });

  // Copies `origin`, marks each (block index, op index) pair in `targets` as
  // a target, prunes the copy and returns a newly allocated ProgramDesc built
  // from the pruned proto.
  m.def("prune", [](const ProgramDesc &origin,
                    const std::vector<std::array<size_t, 2>> &targets) {
    ProgramDesc prog_with_targets(origin);
    for (const auto &t : targets) {
      prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
    }
    proto::ProgramDesc pruned_desc;
    Prune(*prog_with_targets.Proto(), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });

  // Runs InferenceOptimize on `origin`'s proto and returns a newly allocated
  // ProgramDesc built from the result.
  m.def("inference_optimize", [](ProgramDesc &origin) {
    proto::ProgramDesc pruned_desc;
    InferenceOptimize(*(origin.Proto()), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });

  // Predefined variable-name constants, available both as module-level
  // functions and through the `var_names` submodule.
  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });
  m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
      .def("empty", []() { return kEmptyVarName; })
      .def("temp", []() { return kTempVarName; });

  // DeviceContext factory, overloaded on the place type. In a build without
  // PADDLE_WITH_CUDA the CUDAPlace overload always throws.
  // clang-format off
  py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
      .def_static("create",
                  [](paddle::platform::CPUPlace& place)
                      -> paddle::platform::DeviceContext* {
                    return new paddle::platform::CPUDeviceContext();
                  })
      .def_static("create",
                  [](paddle::platform::CUDAPlace& place)
                      -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                    PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
                    return new paddle::platform::CUDADeviceContext(place);
#endif
                  });
// clang-format on

#ifdef PADDLE_WITH_CUDA
  py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif

  // Place types. `Place` can be assigned from either concrete place via the
  // overloaded `set_place`.
  py::class_<platform::CUDAPlace>(m, "CUDAPlace")
      .def(py::init<int>())
      .def("__str__", string::to_string<const platform::CUDAPlace &>);

  py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
      .def(py::init<>())
      .def("__str__", string::to_string<const platform::CPUPlace &>);

  py::class_<platform::Place>(m, "Place")
      .def(py::init<>())
      .def("set_place",
           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
             self = cpu_place;
           })
      .def("set_place",
           [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
             self = gpu_place;
           });

  // Operator: created from a serialized proto::OpDesc supplied by Python.
  py::class_<OperatorBase>(m, "Operator")
      .def_static("create",
                  [](py::bytes protobin) {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    return OpRegistry::CreateOp(desc);
                  })
      .def("backward",
           [](const OperatorBase &forwardOp,
              const std::unordered_set<std::string> &no_grad_vars) {
             return Backward(forwardOp, no_grad_vars).release();
           })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CPUPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPlace &place) { self.Run(scope, place); })
      .def("type",
           [](const OperatorBase &op) -> std::string { return op.Type(); })
      .def("outputs",
           [](const OperatorBase &op)
               -> std::map<std::string, std::vector<std::string>> {
                 return op.Outputs();
               })
      .def("output_vars",
           [](const OperatorBase &op) { return op.OutputVars(true); })
      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
      .def("__str__", &OperatorBase::DebugString)
      .def("no_intermediate_outputs",
           [](const OperatorBase &op) { return op.OutputVars(false); })
      .def("support_gpu", &OperatorBase::SupportGPU);

  // Net: `create` returns a NetOp whose type is set to "plain_net".
  py::class_<operators::NetOp, OperatorBase>(m, "Net")
      .def_static("create",
                  []() -> operators::NetOp * {
                    auto *retv = new operators::NetOp;
                    retv->SetType("plain_net");
                    return retv;
                  })
      .def("append_op", [](operators::NetOp &self,
                           const OperatorBase &op) { self.AppendOp(op); })
      .def("complete_add_op", &operators::NetOp::CompleteAddOp)
      .def("complete_add_op", [](std::shared_ptr<operators::NetOp> &self) {
        self->CompleteAddOp();
      });

  // cond_op
  py::class_<operators::CondOp, OperatorBase>(m, "CondOp")
      .def_static("create",
                  [](py::bytes protobin) -> operators::CondOp * {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    auto op = OpRegistry::CreateOp(desc);
                    // The registry creates whatever operator the desc names,
                    // so verify the dynamic type before handing out a
                    // CondOp pointer instead of blindly downcasting.
                    // NOTE(review): assumes OperatorBase is polymorphic.
                    auto *cond_op =
                        dynamic_cast<operators::CondOp *>(op.get());
                    PADDLE_ENFORCE(cond_op != nullptr,
                                   "User OpDesc does not describe a CondOp");
                    // Give up ownership only after the type check succeeded;
                    // on failure the unique_ptr still frees the operator.
                    op.release();
                    return cond_op;
                  })
      .def("set_truenet",
           [](operators::CondOp &self, const operators::NetOp &net) -> void {
             self.set_truenet(net.Clone());
           })
      .def("set_falsenet",
           [](operators::CondOp &self, const operators::NetOp &net) -> void {
             self.set_falsenet(net.Clone());
           });

  // Executor: the cast selects the Run overload taking
  // (const ProgramDesc &, Scope *, int, bool, bool).
  py::class_<framework::Executor>(m, "Executor")
      .def(py::init<const platform::Place &>())
      .def("run",
           (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
               Executor::Run);

  m.def("unique_integer", UniqueIntegerGenerator);
  m.def("init_gflags", framework::InitGflags);
  m.def("init_glog", framework::InitGLOG);
  m.def("init_devices", &framework::InitDevices);

  m.def("is_compiled_with_cuda", IsCompiledWithCUDA);

  m.def("set_feed_variable", framework::SetFeedVariable);
  m.def("get_fetch_variable", framework::GetFetchVariable);

  // Bindings implemented in other translation units.
  BindProgramDesc(m);
  BindBlockDesc(m);
  BindVarDsec(m);
  BindOpDesc(m);
  BindConstValue(m);

  // LoDRankTable: `items` returns the table as (index, length) pairs.
  py::class_<framework::LoDRankTable>(m, "LodRankTable")
      .def("items", [](framework::LoDRankTable &table) {
        std::vector<std::pair<size_t, size_t>> res;
        for (auto &item : table.items()) {
          res.push_back({item.index, item.length});
        }
        return res;
      });

  // LoDTensorArray: opaque sequence of LoDTensor. Stored elements share data
  // with the tensor passed in (ShareDataWith) and take over its LoD.
  py::class_<LoDTensorArray>(m, "LoDTensorArray")
      .def("__getitem__",
           [](LoDTensorArray &self, size_t i) { return &self.at(i); },
           py::return_value_policy::reference)
      .def("__len__", [](LoDTensorArray &self) { return self.size(); })
      .def("__setitem__",
           [](LoDTensorArray &self, size_t i, const LoDTensor &t) {
             PADDLE_ENFORCE_LT(i, self.size());
             self[i].ShareDataWith(t);
             self[i].set_lod(t.lod());
           })
      .def("append", [](LoDTensorArray &self, const LoDTensor &t) {
        self.emplace_back();
        self.back().ShareDataWith(t);
        self.back().set_lod(t.lod());
      });

  m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
  m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

  m.def("nvprof_init", platform::CudaProfilerInit);
  m.def("nvprof_start", platform::CudaProfilerStart);
  m.def("nvprof_stop", platform::CudaProfilerStop);
#endif

  // Profiler enums and controls.
  py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
      .value("kDisabled", platform::ProfilerState::kDisabled)
      .value("kCPU", platform::ProfilerState::kCPU)
      .value("kCUDA", platform::ProfilerState::kCUDA)
      .export_values();

  py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
      .value("kDefault", platform::EventSortingKey::kDefault)
      .value("kCalls", platform::EventSortingKey::kCalls)
      .value("kTotal", platform::EventSortingKey::kTotal)
      .value("kMin", platform::EventSortingKey::kMin)
      .value("kMax", platform::EventSortingKey::kMax)
      .value("kAve", platform::EventSortingKey::kAve)
      .export_values();

  m.def("enable_profiler", platform::EnableProfiler);
  m.def("disable_profiler", platform::DisableProfiler);
  m.def("reset_profiler", platform::ResetProfiler);
  return m.ptr();
}
484
}  // namespace pybind
485
}  // namespace paddle