/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"

#include <Python.h>
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>

#include <memory>
#include <string>
#include <typeinfo>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/imperative/backward_strategy.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/pybind/pybind_boost_headers.h"

namespace paddle {
namespace pybind {

namespace py = ::pybind11;

// Helpers defined in another translation unit of the pybind module.
//
// Fills `self` from the Python object `obj` (expected to be a numpy array),
// placing the data on `place`. NOTE(review): the exact meaning of
// `zero_copy` (presumably "share the numpy buffer instead of copying") is
// defined by the implementation — confirm there.
template <typename P>
void SetTensorFromPyArray(framework::Tensor *self, const py::object &obj,
                          const P &place, bool zero_copy);

// Converts `tensor` into a numpy array. NOTE(review): `need_deep_copy`
// presumably forces a copy of the underlying buffer — confirm against the
// implementation.
py::array TensorToPyArray(const framework::Tensor &tensor,
                          bool need_deep_copy = false);

class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

50 51 52 53
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
54
                      Forward, inputs);  // NOLINT
55 56 57
  }
};

// Placement-constructs a LoDTensor VarBase in `self` (its name is generated
// by the current tracer) and fills it from `array` via SetTensorFromPyArray,
// forwarding `zero_copy` unchanged.
//
// Destination place:
//   * is_default == true : the current tracer's ExpectedPlace();
//   * is_default == false: `obj`, which must be a CPUPlace, CUDAPlace or
//     CUDAPinnedPlace Python object.
// Any other place raises InvalidArgument.
static void InitTensorForVarBase(imperative::VarBase *self, bool persistable,
                                 bool is_default, const py::array &array,
                                 const py::object &obj = py::object(),
                                 bool zero_copy = false) {
  new (self) imperative::VarBase(
      imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_"));
  self->SetPersistable(persistable);
  auto *dst = self->MutableVar()->GetMutable<framework::LoDTensor>();

  if (!is_default) {
    // Explicit place object supplied by the Python caller.
    if (py::isinstance<platform::CPUPlace>(obj)) {
      SetTensorFromPyArray<platform::CPUPlace>(
          dst, array, obj.cast<platform::CPUPlace>(), zero_copy);
    } else if (py::isinstance<platform::CUDAPlace>(obj)) {
      SetTensorFromPyArray<platform::CUDAPlace>(
          dst, array, obj.cast<platform::CUDAPlace>(), zero_copy);
    } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
      SetTensorFromPyArray<platform::CUDAPinnedPlace>(
          dst, array, obj.cast<platform::CUDAPinnedPlace>(), zero_copy);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
    }
  } else {
    // No explicit place: use the one the current tracer expects.
    const auto expected = imperative::GetCurrentTracer()->ExpectedPlace();
    if (platform::is_cpu_place(expected)) {
      SetTensorFromPyArray<platform::CPUPlace>(
          dst, array, boost::get<platform::CPUPlace>(expected), zero_copy);
    } else if (platform::is_gpu_place(expected)) {
      SetTensorFromPyArray<platform::CUDAPlace>(
          dst, array, boost::get<platform::CUDAPlace>(expected), zero_copy);
    } else if (platform::is_cuda_pinned_place(expected)) {
      SetTensorFromPyArray<platform::CUDAPinnedPlace>(
          dst, array, boost::get<platform::CUDAPinnedPlace>(expected),
          zero_copy);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Place should be one of CPUPlace/CUDAPlace/CUDAPinnedPlace"));
    }
  }

  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(dst->type());
}

// Keyword-argument constructor:
//   VarBase(value=..., place=..., persistable=..., zero_copy=...)
//
// `value` is required (InvalidArgument otherwise). `persistable` and
// `zero_copy` are optional and default to false. When `place` is absent or
// None, the current tracer's expected place is used; otherwise `place` must
// be a CPUPlace, CUDAPlace or CUDAPinnedPlace object.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::InvalidArgument("Missing argument: value"));

  // Optional boolean flags: a missing key means false. (Previously a missing
  // `zero_copy` raised a KeyError from kwargs["zero_copy"], unlike
  // `persistable`, which was already optional.)
  auto get_flag = [&kwargs](const char *key) {
    return kwargs.contains(key) ? kwargs[key].cast<bool>() : false;
  };
  const bool persistable = get_flag("persistable");
  const bool zero_copy = get_flag("zero_copy");
  const auto value = kwargs["value"].cast<py::array>();

  if (kwargs.contains("place") && !kwargs["place"].is_none()) {
    InitTensorForVarBase(self, persistable, /*is_default=*/false, value,
                         kwargs["place"], zero_copy);
  } else {
    InitTensorForVarBase(self, persistable, /*is_default=*/true, value,
                         py::object(), zero_copy);
  }
}

// Positional-argument constructor, bound once per place type P
// (CPUPlace / CUDAPlace / CUDAPinnedPlace):
//   VarBase(value, place, persistable=False, zero_copy=False)
// Placement-constructs a LoDTensor VarBase in `self` with a tracer-generated
// name and fills it from `array` on `place`.
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
                                        bool persistable, bool zero_copy) {
  const auto var_name =
      imperative::GetCurrentTracer()->GenerateUniqueName("generated_var_");
  new (self) imperative::VarBase(var_name);
  self->SetPersistable(persistable);

  auto *lod_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  SetTensorFromPyArray<P>(lod_tensor, array, place, zero_copy);

  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(lod_tensor->type());
}

// Constructor VarBase(value, persistable=False): places the data on the
// current tracer's expected place, with zero_copy disabled.
static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
                                               const py::array &array,
                                               bool persistable) {
  constexpr bool kUseTracerPlace = true;
  InitTensorForVarBase(self, persistable, kUseTracerPlace, array,
                       py::object(), /*zero_copy=*/false);
}

// Diagnostic type name for `var`: "RAW" for RAW-typed vars, "nullptr" when
// the underlying Variable is not initialized, otherwise the framework type
// name of the held object.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) return "RAW";

  const auto &variable = var.Var();
  if (!variable.IsInitialized()) return "nullptr";

  return framework::ToTypeName(variable.Type());
}
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172

template <typename T>
static T PyObjectCast(PyObject *obj) {
  try {
    return py::cast<T>(py::handle(obj));
  } catch (py::cast_error &) {
    PADDLE_THROW("Python object is not type of %s", typeid(T).name());
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts `handle` into a vector of VarBase pointers. Accepted inputs are
// a null handle or None (-> empty vector), a list or tuple of VarBase, or a
// single VarBase. A null element raises InvalidArgument; an element that is
// not a VarBase raises InvalidArgument via PyObjectCast.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  PyObject *py_obj = handle.ptr();  // get underlying PyObject
  // Python None is not nullptr in C++!
  if (!py_obj || py_obj == Py_None) {
    return {};
  }

  std::vector<std::shared_ptr<imperative::VarBase>> result;

  if (PyList_Check(py_obj) || PyTuple_Check(py_obj)) {
    // List or tuple of VarBase. The PySequence_Fast_* macros dispatch to the
    // PyList_/PyTuple_ accessors, so one loop serves both. Items are
    // borrowed references.
    const Py_ssize_t len = PySequence_Fast_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject *py_ivar = PySequence_Fast_GET_ITEM(py_obj, i);
      PADDLE_ENFORCE_NOT_NULL(
          py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
      result.emplace_back(
          PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
    }
  } else {  // VarBase
    result.emplace_back(
        PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
  }

  return result;
}

// Converts the Python-side {slot name: None | VarBase | list/tuple of VarBase}
// map into an imperative::NameVarBaseMap. Slots that yield no VarBase (None or
// an empty sequence) are omitted from the result.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  // Defensive check for a pending Python error. PyErr_Occurred() returns only
  // the exception *type*, so stringifying it (as before) lost the message and
  // left the error indicator set. py::error_already_set fetches and clears
  // the pending error and exposes its description through what().
  if (PyErr_Occurred() != nullptr) {
    py::error_already_set py_err;
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python error occurred when converting inputs/outputs: %s",
        py_err.what()));
  }
  return result;
}

// Bind methods: register the imperative (dygraph) Python API.
void BindImperative(py::module *m_ptr) {
218 219 220
  auto &m = *m_ptr;

  py::class_<imperative::detail::BackwardStrategy> backward_strategy(
221 222
      m, "BackwardStrategy", R"DOC(

J
Jiabin Yang 已提交
223
    BackwardStrategy is a descriptor of how to run the backward process.
224

J
Jiabin Yang 已提交
225 226
    **Note**:
        **This API is only avaliable in** `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ **Mode**
227

J
Jiabin Yang 已提交
228 229
    Attribute:
        **sort_sum_gradient**:
230

J
Jiabin Yang 已提交
231
        If framework will sum the gradient by the reverse order of trace. eg. x_var ( :ref:`api_guide_Variable` ) will be the input of multiple OP such as :ref:`api_fluid_layers_scale` , this attr will decide if framework will sum gradient of `x_var` by the reverse order.
L
lujun 已提交
232

J
Jiabin Yang 已提交
233
        By Default: False
L
lujun 已提交
234

J
Jiabin Yang 已提交
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    x_var = fluid.dygraph.to_variable(x)
                    sums_inputs = []
                    # x_var will be multi-scales' input here
                    for _ in range(10):
                        sums_inputs.append(fluid.layers.scale(x_var))
                    ret2 = fluid.layers.sums(sums_inputs)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
253
      )DOC");
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
  backward_strategy.def(py::init())
      .def_property("sort_sum_gradient",
                    [](const imperative::detail::BackwardStrategy &self) {
                      return self.sorted_sum_gradient_;
                    },
                    [](imperative::detail::BackwardStrategy &self,
                       bool sorted_sum_gradient) {
                      self.sorted_sum_gradient_ = sorted_sum_gradient;
                    });

  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
269 270 271
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
272 273 274 275
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &tracer) {
          imperative::SetCurrentTracer(tracer);
        });
Z
Zeng Jinle 已提交
276

277
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
J
Jiabin Yang 已提交
278 279
      m, "VarBase",
      R"DOC()DOC")
Z
Zeng Jinle 已提交
280
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
J
Jiabin Yang 已提交
281
      .def("__init__",
282 283 284 285 286 287 288 289 290 291 292
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
             std::string act_name = "";
             if (!name.ptr() || name.ptr() == Py_None) {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
                   "generated_var");
             } else {
               act_name = name.cast<std::string>();
             }
             new (&self) imperative::VarBase(act_name);
J
Jiabin Yang 已提交
293 294 295 296 297 298 299 300 301
             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type == framework::proto::VarType::LOD_TENSOR) {
               auto *tensor =
                   self.MutableVar()->GetMutable<framework::LoDTensor>();
               tensor->Resize(framework::make_ddim(dims));
             }
           })
302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false)
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"),
           py::arg("persistable") = false)
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "%s is Empty, Please check if it has no data in",
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        **Notes**:
            **This API is ONLY avaliable in Dygraph mode**

        Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`

        Returns:
            ndarray: The numpy value of current Variable.

        Returns type:
            ndarray: dtype is same as current Variable

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import FC
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    fc = FC("fc", 64, num_flatten_dims=2)
                    data = to_variable(data)
                    x = fc(data)
                    print(x.numpy())

       )DOC")
      .def("detach",
           [](const imperative::VarBase &self) {
             const auto &tensor = self.Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
                               platform::errors::InvalidArgument(
                                   "%s has not been initialized", self.Name()));
             return self.NewVarBase(tensor.place(), false);
           },
           py::return_value_policy::copy, R"DOC(
        **Notes**:
            **This API is ONLY avaliable in Dygraph mode**

        Returns a new Variable, detached from the current graph.

        Returns:
             ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.


        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import FC
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    fc = FC("fc", 64, num_flatten_dims=2)
                    data = to_variable(data)
                    x = fc(data)
                    y = x.detach()

       )DOC")
387 388
      .def("_run_backward",
           [](imperative::VarBase &self,
J
Jiabin Yang 已提交
389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406
              const imperative::detail::BackwardStrategy &bckst,
              const imperative::Tracer &tracer) {
             // TODO(jiabin): when we impl more backward execution we can select
             // them

             imperative::Engine *engine = tracer.GetDefaultEngine();
             VLOG(3) << "Start backward";
             engine->Init(&self, bckst);
             engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      .def("_grad_value",
           [](imperative::VarBase &self) {
             return self.MutableGradVar()->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        **Notes**:
        **1. This API is ONLY avaliable in Dygraph mode**

        **2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC**

        Clear  (set to ``0`` ) the Gradient of Current Variable

        Returns:  None

        Examples:
             .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    inputs2 = []
                    for _ in range(10):
                         tmp = fluid.dygraph.base.to_variable(x)
                         tmp.stop_gradient=False
                         inputs2.append(tmp)
                    ret2 = fluid.layers.sums(inputs2)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
                    print(loss2.gradient())
                    loss2.clear_gradient()
                    print("After clear {}".format(loss2.gradient()))
      )DOC")
440
      .def("_grad_ivar",
J
Jiabin Yang 已提交
441 442
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
443 444 445 446 447 448 449 450 451 452 453
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
                       ? grad_var->MutableVar()
                             ->GetMutable<framework::LoDTensor>()
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
               if (tensor->IsInitialized()) {
                 return grad_var;
               }
J
Jiabin Yang 已提交
454
             }
455
             return std::shared_ptr<imperative::VarBase>(nullptr);
J
Jiabin Yang 已提交
456 457
           },
           py::return_value_policy::copy)
458 459
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
J
Jiabin Yang 已提交
460 461
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
462 463
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
J
Jiabin Yang 已提交
464 465 466
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
467 468 469
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
J
Jiabin Yang 已提交
470 471 472 473
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
474
              return framework::vectorize<int>(
J
Jiabin Yang 已提交
475
                  self.Var().Get<framework::LoDTensor>().dims());
476 477 478
            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::SelectedRows>().value().dims());
J
Jiabin Yang 已提交
479 480 481 482 483 484 485
            } else {
              VLOG(2) << "It is meaningless to get shape of variable type "
                      << GetTypeName(self);
              return std::vector<int>();
            }
          })
      .def_property_readonly("type", &imperative::VarBase::Type)
486
      .def_property_readonly("dtype", &imperative::VarBase::DataType)
J
Jiabin Yang 已提交
487
      .def_property("persistable", &imperative::VarBase::Persistable,
488
                    &imperative::VarBase::SetPersistable)
489 490 491
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient);
492 493 494

  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>())
495 496 497 498 499
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });
500

501 502 503 504 505
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

506 507 508
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer",
      R"DOC()DOC")
509
      .def("__init__",
J
Jiabin Yang 已提交
510
           [](imperative::Tracer &self) { new (&self) imperative::Tracer(); })
511 512 513
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537
      .def_property("_train_mode", &imperative::Tracer::NoGrad,
                    &imperative::Tracer::SetNoGrad)
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace<platform::CUDAPlace>(*p);
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace<platform::CPUPlace>(*p);
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace<platform::CUDAPinnedPlace>(*p);
            } else {
              PADDLE_THROW(
                  "Incompatible Place Type: supports CUDAPlace, CPUPlace, "
                  "CUDAPinnedPlace, "
                  "but got Unknown Type!");
            }
          })
538 539 540
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
M
minqiyang 已提交
541
      .def("trace",
J
Jiabin Yang 已提交
542 543 544 545 546 547
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
548 549
             {
               py::gil_scoped_release release;
J
Jiabin Yang 已提交
550 551
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
552
             }
M
minqiyang 已提交
553
           })
J
Jiabin Yang 已提交
554 555 556 557 558 559 560 561 562 563 564 565 566
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });
567 568

  // define parallel context
569 570 571
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
572 573
      .def_property(
          "nranks",
574 575
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
576 577 578
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
579
                    [](const imperative::ParallelStrategy &self) {
580 581
                      return self.local_rank_;
                    },
582
                    [](imperative::ParallelStrategy &self, int local_rank) {
583 584 585 586
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
587
          [](const imperative::ParallelStrategy &self) {
588 589
            return self.trainer_endpoints_;
          },
590
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
591 592 593
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
594
                    [](const imperative::ParallelStrategy &self) {
595 596
                      return self.current_endpoint_;
                    },
597 598
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; });
599
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
600 601
  py::class_<imperative::NCCLParallelContext> nccl_ctx(m,
                                                       "NCCLParallelContext");
602 603

  nccl_ctx
604 605 606
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
607
#endif
608 609 610 611
}

}  // namespace pybind
}  // namespace paddle