imperative.cc 24.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22
#include <memory>
J
Jiabin Yang 已提交
23
#include <string>
24 25
#include <unordered_map>
#include <utility>
J
Jiabin Yang 已提交
26 27
#include <vector>
#include "paddle/fluid/imperative/backward_strategy.h"
28
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
29
#include "paddle/fluid/imperative/nccl_context.h"
30
#include "paddle/fluid/imperative/profiler.h"
31
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
32
#include "paddle/fluid/imperative/type_defs.h"
33
#include "paddle/fluid/pybind/op_function.h"
34
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
35
#include "paddle/fluid/pybind/tensor_py.h"
36

37 38 39
namespace paddle {
namespace pybind {

40 41
namespace py = ::pybind11;

42 43 44 45
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

46 47 48 49
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
50
                      Forward, inputs);  // NOLINT
51 52 53
  }
};

L
Leo Chen 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// Converts a Python place object into a platform::Place.
// Accepts CPUPlace, CUDAPlace and CUDAPinnedPlace; any other object raises an
// InvalidArgument error. Returned by plain value (not `const`) so callers can
// move from the result.
static platform::Place PyObjectToPlace(const py::object &place_obj) {
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  }
  if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  }
  if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
}

// Placement-constructs a VarBase at `self` and fills its LoDTensor from a
// numpy array.
//
//   array       - source numpy data, handed to SetTensorFromPyArray.
//   place       - must be a CPUPlace, CUDAPlace or CUDAPinnedPlace; any other
//                 place raises InvalidArgument.
//   persistable - value passed to SetPersistable.
//   zero_copy   - forwarded unchanged to SetTensorFromPyArray.
//   name        - variable name; when empty, a unique "generated_var" name is
//                 obtained from the current tracer.
//
// On success the VarBase has type LOD_TENSOR and the tensor's data type.
static void InitTensorForVarBase(imperative::VarBase *self,
                                 const py::array &array,
                                 const platform::Place place,
                                 bool persistable = false,
                                 bool zero_copy = false,
                                 std::string name = "") {
  if (name.empty()) {
    name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
  }
  new (self) imperative::VarBase(name);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  // Dispatch on the runtime place kind so SetTensorFromPyArray is instantiated
  // with the matching concrete place type.
  if (platform::is_cpu_place(place)) {
    SetTensorFromPyArray<platform::CPUPlace>(
        tensor, array, boost::get<platform::CPUPlace>(place), zero_copy);
  } else if (platform::is_gpu_place(place)) {
    SetTensorFromPyArray<platform::CUDAPlace>(
        tensor, array, boost::get<platform::CUDAPlace>(place), zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(
        tensor, array, boost::get<platform::CUDAPinnedPlace>(place), zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

// Keyword-argument constructor behind `VarBase(**kwargs)`.
//
// Recognised keys:
//   value       (required) numpy array with the initial data
//   persistable (optional, default false)
//   zero_copy   (optional, default false)
//   name        (optional, default "" -> a generated unique name)
//   place       (optional, default: the current tracer's expected place)
// Raises InvalidArgument when `value` is missing or `place` is not a
// supported place type.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::InvalidArgument("Missing argument: value"));

  auto persistable = kwargs.contains("persistable")
                         ? kwargs["persistable"].cast<bool>()
                         : false;
  // `value` is guaranteed present by the check above, so no fallback needed.
  auto array = kwargs["value"].cast<py::array>();
  auto zero_copy =
      kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
  std::string name =
      kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
  // Only consult the tracer when no explicit place was supplied.
  platform::Place place =
      kwargs.contains("place")
          ? PyObjectToPlace(kwargs["place"])
          : platform::Place(imperative::GetCurrentTracer()->ExpectedPlace());
  InitTensorForVarBase(self, array, place, persistable, zero_copy, name);
}

// Positional-argument constructor for VarBase, templated on the concrete
// place type P. One instantiation per place type (CPUPlace, CUDAPlace,
// CUDAPinnedPlace) is registered as a separate `__init__` overload, so the
// overload is selected by the Python type of `place`.
//
// An empty `name` is replaced by a unique "generated_var" name from the
// current tracer. The resulting VarBase holds a LOD_TENSOR filled from
// `array`.
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
                                        bool persistable = false,
                                        bool zero_copy = false,
                                        std::string name = "") {
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name
  if (name.empty()) {
    name = imperative::GetCurrentTracer()->GenerateUniqueName("generated_var");
  }
  new (self) imperative::VarBase(name);
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
135 136 137
                                               const py::array &array) {
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  InitTensorForVarBase(self, array, place);
138
}
139

140 141 142 143 144
// Human-readable name of the variable type held by `var`, used in log
// messages. Returns "RAW" for RAW variables, "nullptr" when the underlying
// Variable is not initialized, and framework::ToTypeName otherwise.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) {
    return "RAW";
  }
  if (!var.Var().IsInitialized()) {
    return "nullptr";
  }
  return framework::ToTypeName(var.Var().Type());
}

// Name -> Python handle map, as received for the `ins`/`outs` arguments of
// Tracer.trace. Each handle is None, a VarBase, or a list/tuple of VarBases.
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;

// Casts a borrowed PyObject* to T via pybind11. A pybind11 cast failure is
// re-raised as an InvalidArgument error naming the expected C++ type, in the
// same structured-error style used elsewhere in this file.
template <typename T>
static T PyObjectCast(PyObject *obj) {
  try {
    return py::cast<T>(py::handle(obj));
  } catch (const py::cast_error &) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts `handle` into a vector of VarBase pointers:
//   - nullptr or Python None -> empty vector
//   - list / tuple           -> one entry per element, in order
//   - any other object       -> a single-element vector
// Raises InvalidArgument if a list/tuple item is NULL, and PyObjectCast's
// error if any element is not a VarBase.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  PyObject *py_obj = handle.ptr();  // get underlying PyObject
  // Python None is not nullptr in C++!
  if (!py_obj || py_obj == Py_None) {
    return {};
  }

  std::vector<std::shared_ptr<imperative::VarBase>> result;

  if (PyList_Check(py_obj)) {  // List of VarBase
    const Py_ssize_t len = PyList_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject *py_ivar = PyList_GET_ITEM(py_obj, i);  // borrowed reference
      PADDLE_ENFORCE_NOT_NULL(
          py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
      result.emplace_back(
          PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
    }
  } else if (PyTuple_Check(py_obj)) {  // Tuple of VarBase
    const Py_ssize_t len = PyTuple_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject *py_ivar = PyTuple_GET_ITEM(py_obj, i);  // borrowed reference
      PADDLE_ENFORCE_NOT_NULL(
          py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
      result.emplace_back(
          PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
    }
  } else {  // VarBase
    result.emplace_back(
        PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
  }

  return result;
}

// Converts every Python value of `map` into a VarBase vector (see
// GetVarBaseListFromPyHandle). Entries whose vector comes out empty (None or
// an empty list/tuple) are omitted from the result. Fails if a Python error
// is pending after the conversion.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  PADDLE_ENFORCE_EQ(PyErr_Occurred() == nullptr, true,
                    py::str(py::handle(PyErr_Occurred())));
  return result;
}

// Bind Methods
J
Jiabin Yang 已提交
217
void BindImperative(py::module *m_ptr) {
218 219
  auto &m = *m_ptr;

220 221
  BindOpFunctions(&m);

222
  py::class_<imperative::detail::BackwardStrategy> backward_strategy(
223 224
      m, "BackwardStrategy", R"DOC(

J
Jiabin Yang 已提交
225
    BackwardStrategy is a descriptor of how to run the backward process.
226

J
Jiabin Yang 已提交
227 228
    **Note**:
        **This API is only avaliable in** `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ **Mode**
229

J
Jiabin Yang 已提交
230 231
    Attribute:
        **sort_sum_gradient**:
232

J
Jiabin Yang 已提交
233
        If framework will sum the gradient by the reverse order of trace. eg. x_var ( :ref:`api_guide_Variable` ) will be the input of multiple OP such as :ref:`api_fluid_layers_scale` , this attr will decide if framework will sum gradient of `x_var` by the reverse order.
L
lujun 已提交
234

J
Jiabin Yang 已提交
235
        By Default: False
L
lujun 已提交
236

J
Jiabin Yang 已提交
237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    x_var = fluid.dygraph.to_variable(x)
                    sums_inputs = []
                    # x_var will be multi-scales' input here
                    for _ in range(10):
                        sums_inputs.append(fluid.layers.scale(x_var))
                    ret2 = fluid.layers.sums(sums_inputs)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
255
      )DOC");
256 257 258 259 260 261 262 263 264 265 266 267 268 269 270
  backward_strategy.def(py::init())
      .def_property("sort_sum_gradient",
                    [](const imperative::detail::BackwardStrategy &self) {
                      return self.sorted_sum_gradient_;
                    },
                    [](imperative::detail::BackwardStrategy &self,
                       bool sorted_sum_gradient) {
                      self.sorted_sum_gradient_ = sorted_sum_gradient;
                    });

  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
271 272 273
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
274 275 276 277
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &tracer) {
          imperative::SetCurrentTracer(tracer);
        });
Z
Zeng Jinle 已提交
278

279
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
J
Jiabin Yang 已提交
280 281
      m, "VarBase",
      R"DOC()DOC")
Z
Zeng Jinle 已提交
282
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
J
Jiabin Yang 已提交
283
      .def("__init__",
284 285 286 287 288 289 290 291 292 293 294
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
             std::string act_name = "";
             if (!name.ptr() || name.ptr() == Py_None) {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
                   "generated_var");
             } else {
               act_name = name.cast<std::string>();
             }
             new (&self) imperative::VarBase(act_name);
J
Jiabin Yang 已提交
295 296 297 298 299 300 301 302 303
             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type == framework::proto::VarType::LOD_TENSOR) {
               auto *tensor =
                   self.MutableVar()->GetMutable<framework::LoDTensor>();
               tensor->Resize(framework::make_ddim(dims));
             }
           })
304 305
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
L
Leo Chen 已提交
306
           py::arg("zero_copy") = false, py::arg("name") = "")
307 308
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
L
Leo Chen 已提交
309
           py::arg("zero_copy") = false, py::arg("name") = "")
310 311
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
L
Leo Chen 已提交
312 313
           py::arg("zero_copy") = false, py::arg("name") = "")
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "%s is Empty, Please check if it has no data in",
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        **Notes**:
            **This API is ONLY avaliable in Dygraph mode**

        Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`

        Returns:
            ndarray: The numpy value of current Variable.

        Returns type:
            ndarray: dtype is same as current Variable

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import FC
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    fc = FC("fc", 64, num_flatten_dims=2)
                    data = to_variable(data)
                    x = fc(data)
                    print(x.numpy())

       )DOC")
      .def("detach",
           [](const imperative::VarBase &self) {
             const auto &tensor = self.Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
                               platform::errors::InvalidArgument(
                                   "%s has not been initialized", self.Name()));
             return self.NewVarBase(tensor.place(), false);
           },
           py::return_value_policy::copy, R"DOC(
        **Notes**:
            **This API is ONLY avaliable in Dygraph mode**

        Returns a new Variable, detached from the current graph.

        Returns:
             ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.


        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import FC
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    fc = FC("fc", 64, num_flatten_dims=2)
                    data = to_variable(data)
                    x = fc(data)
                    y = x.detach()

       )DOC")
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        **Notes**:
        **1. This API is ONLY avaliable in Dygraph mode**

        **2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC**

        Clear  (set to ``0`` ) the Gradient of Current Variable

        Returns:  None

        Examples:
             .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    inputs2 = []
                    for _ in range(10):
                         tmp = fluid.dygraph.base.to_variable(x)
                         tmp.stop_gradient=False
                         inputs2.append(tmp)
                    ret2 = fluid.layers.sums(inputs2)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
                    print(loss2.gradient())
                    loss2.clear_gradient()
                    print("After clear {}".format(loss2.gradient()))
      )DOC")
L
Leo Chen 已提交
421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440
      .def("_run_backward",
           [](imperative::VarBase &self,
              const imperative::detail::BackwardStrategy &bckst,
              const imperative::Tracer &tracer) {
             // TODO(jiabin): when we impl more backward execution we can select
             // them

             imperative::Engine *engine = tracer.GetDefaultEngine();
             VLOG(3) << "Start backward";
             engine->Init(&self, bckst);
             engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      .def("_grad_value",
           [](imperative::VarBase &self) {
             return self.MutableGradVar()->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
441
      .def("_grad_ivar",
J
Jiabin Yang 已提交
442 443
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
444 445 446 447 448 449 450 451 452 453 454
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
                       ? grad_var->MutableVar()
                             ->GetMutable<framework::LoDTensor>()
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
               if (tensor->IsInitialized()) {
                 return grad_var;
               }
J
Jiabin Yang 已提交
455
             }
456
             return std::shared_ptr<imperative::VarBase>(nullptr);
J
Jiabin Yang 已提交
457 458
           },
           py::return_value_policy::copy)
459 460
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
J
Jiabin Yang 已提交
461 462
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
463 464
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
J
Jiabin Yang 已提交
465 466 467
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
468 469 470
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
L
Leo Chen 已提交
471 472 473 474 475
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
J
Jiabin Yang 已提交
476 477 478 479
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
480
              return framework::vectorize<int>(
J
Jiabin Yang 已提交
481
                  self.Var().Get<framework::LoDTensor>().dims());
482 483 484
            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::SelectedRows>().value().dims());
J
Jiabin Yang 已提交
485 486 487 488 489 490 491
            } else {
              VLOG(2) << "It is meaningless to get shape of variable type "
                      << GetTypeName(self);
              return std::vector<int>();
            }
          })
      .def_property_readonly("type", &imperative::VarBase::Type)
L
Leo Chen 已提交
492
      .def_property_readonly("dtype", &imperative::VarBase::DataType);
493 494 495

  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>())
496 497 498 499 500
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });
501

502 503 504 505 506
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

507 508 509
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer",
      R"DOC()DOC")
510
      .def("__init__",
J
Jiabin Yang 已提交
511
           [](imperative::Tracer &self) { new (&self) imperative::Tracer(); })
512 513 514
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
      .def_property("_train_mode", &imperative::Tracer::NoGrad,
                    &imperative::Tracer::SetNoGrad)
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace<platform::CUDAPlace>(*p);
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace<platform::CPUPlace>(*p);
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace<platform::CUDAPinnedPlace>(*p);
            } else {
              PADDLE_THROW(
                  "Incompatible Place Type: supports CUDAPlace, CPUPlace, "
                  "CUDAPinnedPlace, "
                  "but got Unknown Type!");
            }
          })
539 540 541
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
M
minqiyang 已提交
542
      .def("trace",
J
Jiabin Yang 已提交
543 544 545 546 547 548
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
549 550
             {
               py::gil_scoped_release release;
J
Jiabin Yang 已提交
551 552
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
553
             }
M
minqiyang 已提交
554
           })
J
Jiabin Yang 已提交
555 556 557 558 559 560 561 562 563 564 565 566 567
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });
568 569

  // define parallel context
570 571 572
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
573 574
      .def_property(
          "nranks",
575 576
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
577 578 579
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
580
                    [](const imperative::ParallelStrategy &self) {
581 582
                      return self.local_rank_;
                    },
583
                    [](imperative::ParallelStrategy &self, int local_rank) {
584 585 586 587
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
588
          [](const imperative::ParallelStrategy &self) {
589 590
            return self.trainer_endpoints_;
          },
591
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
592 593 594
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
595
                    [](const imperative::ParallelStrategy &self) {
596 597
                      return self.current_endpoint_;
                    },
598 599
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; });
600
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
601 602
  py::class_<imperative::NCCLParallelContext> nccl_ctx(m,
                                                       "NCCLParallelContext");
603 604

  nccl_ctx
605 606 607
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
608
#endif
609 610 611 612
}

}  // namespace pybind
}  // namespace paddle