imperative.cc 44.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22

23
#include <memory>
24
#include <set>
J
Jiabin Yang 已提交
25
#include <string>
26
#include <unordered_map>
27
#include <unordered_set>
28
#include <utility>
J
Jiabin Yang 已提交
29
#include <vector>
30

31
#include "paddle/fluid/imperative/all_reduce.h"
32
#include "paddle/fluid/imperative/amp_auto_cast.h"
J
Jiabin Yang 已提交
33
#include "paddle/fluid/imperative/backward_strategy.h"
34
#include "paddle/fluid/imperative/basic_engine.h"
35
#include "paddle/fluid/imperative/data_loader.h"
36
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
37
#include "paddle/fluid/imperative/nccl_context.h"
38
#include "paddle/fluid/imperative/partial_grad_engine.h"
39
#include "paddle/fluid/imperative/profiler.h"
40
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
41
#include "paddle/fluid/imperative/type_defs.h"
42
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
43
#include "paddle/fluid/pybind/op_function.h"
44
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
45
#include "paddle/fluid/pybind/tensor_py.h"
46

47 48 49
namespace paddle {
namespace pybind {

50 51
namespace py = ::pybind11;

52 53 54 55
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

56 57 58 59
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
60
                      Forward, inputs);  // NOLINT
61 62 63
  }
};

L
Leo Chen 已提交
64 65 66 67 68
// Converts a Python place object into a platform::Place.
// Accepted Python types, checked in this order: CPUPlace, CUDAPlace,
// XPUPlace, CUDAPinnedPlace. Any other object raises InvalidArgument.
static const platform::Place PyObjectToPlace(const py::object &place_obj) {
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  }
  if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  }
  if (py::isinstance<platform::XPUPlace>(place_obj)) {
    return place_obj.cast<platform::XPUPlace>();
  }
  if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  }
  // None of the supported place types matched.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
}

// Placement-constructs `self` as a LOD_TENSOR VarBase and fills its tensor
// with the contents of `array` on `place`.
//
//   persistable   - forwarded to SetPersistable.
//   zero_copy     - forwarded to SetTensorFromPyArray.
//   name          - variable name; an empty string requests a unique
//                   "generated_tensor" name from the current tracer.
//   stop_gradient - -1 skips SetOverridedStopGradient; any other value is
//                   passed to it.
// Throws InvalidArgument if `place` is not a CPU/XPU/CUDA/CUDAPinned place.
static void InitTensorForVarBase(imperative::VarBase *self,
                                 const py::array &array,
                                 const platform::Place place,
                                 bool persistable = false,
                                 bool zero_copy = false, std::string name = "",
                                 int stop_gradient = -1) {
  if (name.empty()) {
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
  }
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;

  new (self) imperative::VarBase(name);
  auto *dst = self->MutableVar()->GetMutable<framework::LoDTensor>();

  // Dispatch on the runtime place kind to the matching template instance.
  if (platform::is_cpu_place(place)) {
    const auto &cpu_place = BOOST_GET_CONST(platform::CPUPlace, place);
    SetTensorFromPyArray<platform::CPUPlace>(dst, array, cpu_place, zero_copy);
  } else if (platform::is_xpu_place(place)) {
    const auto &xpu_place = BOOST_GET_CONST(platform::XPUPlace, place);
    SetTensorFromPyArray<platform::XPUPlace>(dst, array, xpu_place, zero_copy);
  } else if (platform::is_gpu_place(place)) {
    const auto &gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place);
    SetTensorFromPyArray<platform::CUDAPlace>(dst, array, gpu_place,
                                              zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    const auto &pinned_place =
        BOOST_GET_CONST(platform::CUDAPinnedPlace, place);
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(dst, array, pinned_place,
                                                    zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }

  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(dst->type());
}

// Placement-constructs `self` from keyword arguments by forwarding them to
// InitTensorForVarBase.
//
// Required kwarg:
//   value          - converted to py::array.
// Optional kwargs (default when omitted):
//   persistable    - false
//   zero_copy      - false
//   name           - "" (InitTensorForVarBase then generates a unique name)
//   stop_gradient  - -1 (SetOverridedStopGradient is not called)
//   place          - the current tracer's ExpectedPlace()
// An optional kwarg passed explicitly as None is treated as omitted. This
// matches the None-name handling of the (dtype, dims, name, ...) __init__.
// Throws NotFound when `value` is missing.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  VLOG(4) << "Init VarBase from kwargs: ";
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::NotFound(
          "The kwargs used to create Varbase misses argument: value"));

  // True when `key` was passed and is not None.
  auto has_arg = [&kwargs](const char *key) {
    return kwargs.contains(key) && !kwargs[key].is_none();
  };

  auto persistable =
      has_arg("persistable") ? kwargs["persistable"].cast<bool>() : false;
  // `value` is guaranteed to be present by the check above.
  auto array = kwargs["value"].cast<py::array>();
  auto zero_copy =
      has_arg("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
  std::string name =
      has_arg("name") ? kwargs["name"].cast<std::string>() : std::string();
  auto stop_gradient =
      has_arg("stop_gradient") ? kwargs["stop_gradient"].cast<int>() : -1;
  // Only consult the tracer for a default place when none was supplied.
  platform::Place place =
      has_arg("place") ? PyObjectToPlace(kwargs["place"])
                       : imperative::GetCurrentTracer()->ExpectedPlace();

  InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
                       stop_gradient);
}
143

144 145 146
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
L
Leo Chen 已提交
147 148
                                        bool persistable = false,
                                        bool zero_copy = false,
149 150 151 152 153
                                        std::string name = "",
                                        int stop_gradient = -1) {
  VLOG(4) << "Init VarBase from Arg: ";
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6:
  // stop_gradient
L
Leo Chen 已提交
154
  if (name == "") {
155 156
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
L
Leo Chen 已提交
157
  }
158 159 160
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;
L
Leo Chen 已提交
161
  new (self) imperative::VarBase(name);
162 163
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
164 165 166
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
167 168 169 170 171 172
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
173
                                               const py::array &array) {
174
  VLOG(4) << "Init VarBase from numpy: ";
L
Leo Chen 已提交
175 176
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  InitTensorForVarBase(self, array, place);
177
}
178

179 180 181 182 183
// Placement-constructs `self` from `tensor` as a non-persistable LOD_TENSOR
// VarBase with a generated name.
// - If `tensor` already lives on the tracer's expected place, the storage is
//   shared (ShareDataWith).
// - Otherwise the data is copied there with TensorCopy.
static void InitVarBaseFromTensorWithArgDefault(
    imperative::VarBase *self, const framework::LoDTensor &tensor) {
  VLOG(4) << "Init VarBase";
  const auto &tracer = imperative::GetCurrentTracer();
  auto place = tracer->ExpectedPlace();
  new (self)
      imperative::VarBase(tracer->GenerateUniqueName("generated_tensor"));
  self->SetPersistable(false);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor.type());

  auto *dst = self->MutableVar()->GetMutable<framework::LoDTensor>();
  const bool same_place = (place == tensor.place());
  if (same_place) {
    // Same place, share data directly.
    dst->ShareDataWith(tensor);
    VLOG(4) << "Same place, do ShareDataWith";
  } else {
    framework::TensorCopy(tensor, place, dst);
    VLOG(4) << "Different place, do TensorCopy";
  }
}

199 200 201 202 203
// Returns a printable type name for `var`:
// - "RAW" for RAW-typed variables;
// - "nullptr" when the underlying Variable is not initialized;
// - otherwise framework::ToTypeName of the held type.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) return "RAW";
  const auto &holder = var.Var();
  if (!holder.IsInitialized()) return "nullptr";
  return framework::ToTypeName(holder.Type());
}
L
Leo Chen 已提交
208

209
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
210 211 212 213 214 215

// Casts the borrowed PyObject `obj` to the C++ type T via pybind11.
// A failed cast is reported as InvalidArgument naming typeid(T).name(), which
// may be a compiler-mangled name.
template <typename T>
static T PyObjectCast(PyObject *obj) {
  const py::handle borrowed(obj);
  try {
    return borrowed.cast<T>();
  } catch (const py::cast_error &) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts `handle` into a list of VarBase pointers:
//   null / None     -> empty vector
//   list or tuple   -> one entry per element (a NULL element is rejected)
//   anything else   -> a single entry cast from the object itself
// Elements that are not VarBase raise InvalidArgument (via PyObjectCast).
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  using VarBasePtr = std::shared_ptr<imperative::VarBase>;
  PyObject *py_obj = handle.ptr();
  // Python None is a real (non-null) object, so both cases are checked.
  if (py_obj == nullptr || py_obj == Py_None) return {};

  std::vector<VarBasePtr> vars;
  // Validates one list/tuple element and appends the VarBase it holds.
  auto append_element = [&vars](PyObject *element) {
    PADDLE_ENFORCE_NOT_NULL(
        element, platform::errors::InvalidArgument("Python Object is NULL"));
    vars.emplace_back(PyObjectCast<VarBasePtr>(element));
  };

  if (PyList_Check(py_obj)) {
    const size_t count = PyList_GET_SIZE(py_obj);
    vars.reserve(count);
    for (size_t pos = 0; pos < count; ++pos) {
      append_element(PyList_GET_ITEM(py_obj, pos));
    }
  } else if (PyTuple_Check(py_obj)) {
    const size_t count = PyTuple_GET_SIZE(py_obj);
    vars.reserve(count);
    for (size_t pos = 0; pos < count; ++pos) {
      append_element(PyTuple_GET_ITEM(py_obj, pos));
    }
  } else {
    vars.emplace_back(PyObjectCast<VarBasePtr>(py_obj));
  }
  return vars;
}

J
Jiabin Yang 已提交
261 262 263
// Converts a name -> Python handle map into an imperative::NameVarBaseMap.
// Handles that yield no VarBase (null/None, empty list/tuple) are skipped.
// If a Python error is pending after the conversion, it is fetched (which
// clears the Python error indicator) and reported as InvalidArgument.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  if (PyErr_Occurred() != nullptr) {
    // PyErr_Occurred() only returns the exception *type*. Constructing
    // error_already_set fetches type, value and traceback, so what() carries
    // "ExceptionType: message". It also clears the indicator, so no stale
    // Python error is left behind once the C++ exception propagates.
    py::error_already_set py_err;
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python error occurred while converting inputs to VarBase: %s",
        py_err.what()));
  }
  return result;
}

277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341
// Returns true when `obj` is a Python integer. bool is a subclass of int in
// Python, so True/False are rejected explicitly. On Python 2 both `int` and
// `long` objects count as integers.
static bool PyCheckInteger(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyInt_Check(obj) || PyLong_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}

// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// Original PySlice_GetIndices return wrong result when
// slice_item contains long int, such as arr[:180L].
// NOT sure why this happens !!!
// Besides, PySlice_GetIndices cannot raise error when float in slice item.
// So, I make a revised version of PySlice_GetIndices, named to
// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
// PySlice_GetIndices in the future.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
                               Py_ssize_t *start, Py_ssize_t *stop,
                               Py_ssize_t *step) {
  /* XXX support long ints */
  if (r->step == Py_None) {
    *step = 1;
  } else {
    if (PyCheckInteger(r->step)) {
      *step = PyLong_AsLong(r->step);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->step)->tp_name)));
    }
  }
  if (r->start == Py_None) {
    *start = *step < 0 ? length - 1 : 0;
  } else {
    if (PyCheckInteger(r->start)) {
      *start = PyLong_AsLong(r->start);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->start)->tp_name)));
    }
    if (*start < 0) *start += length;
  }
  if (r->stop == Py_None) {
    *stop = *step < 0 ? -1 : length;
  } else {
    if (PyCheckInteger(r->stop)) {
      *stop = PyLong_AsLong(r->stop);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->stop)->tp_name)));
    }
    if (*stop < 0) *stop += length;
  }
  if (*stop > length) return -1;
  if (*start >= length) return -1;
  if (*step == 0) return -1;
  return 0;
}

S
songyouwei 已提交
342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365
// Translates a Python index expression applied to `tensor` into the
// attributes of a slice / strided_slice op.
//
// `_index` may be an integer, a slice, or a tuple of integers and slices.
// Ellipsis and None are not supported yet. For each indexed dimension `dim`:
//   * An integer i selects one element. A negative i has the dimension size
//     added. dim is appended to slice_axes and decrease_axis, with
//     [i, i + 1) and stride 1. An index outside [-dim_len, dim_len) raises
//     py::index_error (IndexError in Python).
//   * A slice is decoded with _PySlice_GetIndices and appended to
//     slice_axes/starts/ends/strides. A slice that selects the whole
//     dimension (start 0, end dim_len, step 1) is skipped. A zero step
//     raises py::value_error (ValueError in Python).
// infer_flags receives one 1 per indexed dimension.
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                               std::vector<int> *slice_axes,
                               std::vector<int> *slice_starts,
                               std::vector<int> *slice_ends,
                               std::vector<int> *slice_strides,
                               std::vector<int> *decrease_axis,
                               std::vector<int> *infer_flags) {
  // Wrap a bare index into a 1-tuple so both forms are handled uniformly.
  // py::tuple owns its reference, so it is released on every exit path,
  // including when one of the checks below throws.
  py::tuple index_tuple = PyTuple_Check(_index)
                              ? py::reinterpret_borrow<py::tuple>(_index)
                              : py::make_tuple(py::handle(_index));
  PyObject *index = index_tuple.ptr();
  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));
  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = PyTuple_GET_SIZE(index);
  PADDLE_ENFORCE_EQ(
      size <= rank, true,
      platform::errors::InvalidArgument(
          "too many indices (%d) for tensor of dimension %d", size, rank));
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index, dim);
    PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
                      true,
                      platform::errors::InvalidArgument(
                          "Currently, VarBase.__getitem__() only allows "
                          "indexing by Integers, Slices, and tuples of "
                          "these types, but received %s in %dth slice item",
                          std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
    infer_flags->push_back(1);
    int dim_len = shape[dim];
    if (PyCheckInteger(slice_item)) {
      // Integer index. PyLong_AsLong supports both int and long. It reports
      // overflow as -1 with an OverflowError pending; re-raise that error
      // rather than treating it as index -1.
      const int64_t s_t = PyLong_AsLong(slice_item);
      if (s_t == -1 && PyErr_Occurred() != nullptr) {
        throw py::error_already_set();
      }
      // Do the arithmetic and the bounds check in 64 bits so that large
      // indices are rejected instead of being truncated to int first.
      const int64_t start = s_t < 0 ? s_t + dim_len : s_t;
      if (start < 0 || start >= dim_len) {
        std::string str_error_message =
            "The starting index " + std::to_string(s_t) +
            " of slice is out of bounds in tensor " + std::to_string(dim) +
            "-th axis, it shound be in the range of [" +
            std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
        // py::index_error is corresponding to IndexError in Python
        // Used to indicate out of bounds access in __getitem__, __setitem__
        throw py::index_error(str_error_message);
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(static_cast<int>(start));
      slice_ends->push_back(static_cast<int>(start + 1));
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
    } else {
      // slice item
      Py_ssize_t start, end, step;
      PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
      _PySlice_GetIndices(p, dim_len, &start, &end, &step);
      // Match Python's behavior for sequences (e.g. list[::0]).
      if (step == 0) {
        throw py::value_error("slice step cannot be zero");
      }

      // :: or : or 0:dim_len:1 selects the whole dimension; nothing to do.
      if (start == 0 && end == dim_len && step == 1) {
        continue;
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(static_cast<int>(start));
      slice_ends->push_back(static_cast<int>(end));
      slice_strides->push_back(static_cast<int>(step));
    }
  }
}

414
// Bind Methods
J
Jiabin Yang 已提交
415
void BindImperative(py::module *m_ptr) {
416 417
  auto &m = *m_ptr;

418 419
  BindOpFunctions(&m);

420 421
#ifndef _WIN32
  // Dygraph DataLoader signal handler
422 423 424 425 426 427 428 429 430 431 432 433 434
  m.def("_set_process_pids", [](int64_t key, py::object &obj) {
    PADDLE_ENFORCE_EQ(
        py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj), true,
        platform::errors::InvalidArgument(
            "The subprocess ids set in DataLoader is illegal."
            "Expected data type is tuple or list, but received %s",
            obj.get_type()));
    py::list pids = py::cast<py::list>(obj);
    std::set<pid_t> pids_set = {};
    for (size_t i = 0; i < pids.size(); i++) {
      pids_set.insert(pids[i].cast<pid_t>());
    }
    imperative::SetLoadProcessPIDs(key, pids_set);
435
  });
436 437
  m.def("_erase_process_pids",
        [](int64_t key) { imperative::EraseLoadProcessPIDs(key); });
438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509
  m.def("_set_process_signal_handler",
        []() { imperative::SetLoadProcessSignalHandler(); });
  m.def("_throw_error_if_process_failed",
        []() { imperative::ThrowErrorIfLoadProcessFailed(); });

  // Dygraph DataLoader reader process & thread related functions
  m.def(
      "_convert_to_tensor_list",
      [](py::object &obj) -> py::list {
        // 0. input data check
        PADDLE_ENFORCE(
            py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj),
            platform::errors::InvalidArgument(
                "The batch data read into DataLoader is illegal."
                "Expected data type is tuple or list, but received %s",
                obj.get_type()));
        py::list batch = py::cast<py::list>(obj);
        py::list tensors;
        for (size_t i = 0; i < batch.size(); ++i) {
          // 1. cast to python array
          auto array = batch[i].cast<py::array>();
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Faild to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data causes this issue."));
          // 2. construcct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);
          // 6. append to result list
          tensors.append(t);
        }
        return tensors;
      },
      py::return_value_policy::take_ownership);

  m.def("_remove_tensor_list_mmap_fds", [](py::list &tensor_list) {
    for (size_t i = 0; i < tensor_list.size(); ++i) {
      auto t = tensor_list[i].cast<framework::LoDTensor>();
      auto *mmap_writer_allocation =
          dynamic_cast<memory::allocation::MemoryMapWriterAllocation *>(
              t.Holder().get());
      PADDLE_ENFORCE_NOT_NULL(
          mmap_writer_allocation,
          platform::errors::NotFound("The shared memory of LoDTensor in "
                                     "DataLoader's child process has been "
                                     "released."));
      memory::allocation::MemoryMapFdSet::Instance().Remove(
          mmap_writer_allocation->ipc_name());
    }
  });

  m.def("_cleanup_mmap_fds",
        []() { memory::allocation::MemoryMapFdSet::Instance().Clear(); });
#endif

510
  py::class_<imperative::detail::BackwardStrategy> backward_strategy(
511 512
      m, "BackwardStrategy", R"DOC(

J
Jiabin Yang 已提交
513
    BackwardStrategy is a descriptor of how to run the backward process.
514

J
Jiabin Yang 已提交
515
    **Note**:
T
tianshuo78520a 已提交
516
        **This API is only available in** `Dygraph <../../user_guides/howto/dygraph/DyGraph.html>`_ **Mode**
517

J
Jiabin Yang 已提交
518 519
    Attribute:
        **sort_sum_gradient**:
520

J
Jiabin Yang 已提交
521
        If framework will sum the gradient by the reverse order of trace. eg. x_var ( :ref:`api_guide_Variable` ) will be the input of multiple OP such as :ref:`api_fluid_layers_scale` , this attr will decide if framework will sum gradient of `x_var` by the reverse order.
L
lujun 已提交
522

J
Jiabin Yang 已提交
523
        By Default: False
L
lujun 已提交
524

J
Jiabin Yang 已提交
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542
        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    x_var = fluid.dygraph.to_variable(x)
                    sums_inputs = []
                    # x_var will be multi-scales' input here
                    for _ in range(10):
                        sums_inputs.append(fluid.layers.scale(x_var))
                    ret2 = fluid.layers.sums(sums_inputs)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
543
      )DOC");
544 545 546 547 548 549 550 551 552 553 554 555 556 557 558
  backward_strategy.def(py::init())
      .def_property("sort_sum_gradient",
                    [](const imperative::detail::BackwardStrategy &self) {
                      return self.sorted_sum_gradient_;
                    },
                    [](imperative::detail::BackwardStrategy &self,
                       bool sorted_sum_gradient) {
                      self.sorted_sum_gradient_ = sorted_sum_gradient;
                    });

  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
559 560 561
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
562 563 564 565
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &tracer) {
          imperative::SetCurrentTracer(tracer);
        });
Z
Zeng Jinle 已提交
566

567
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
568
      m, "VarBase", R"DOC()DOC")
Z
Zeng Jinle 已提交
569
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
J
Jiabin Yang 已提交
570
      .def("__init__",
571 572 573
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
574
             VLOG(4) << "Init VarBase";
575 576 577
             std::string act_name = "";
             if (!name.ptr() || name.ptr() == Py_None) {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
578
                   "generated_tensor");
579 580 581 582
             } else {
               act_name = name.cast<std::string>();
             }
             new (&self) imperative::VarBase(act_name);
J
Jiabin Yang 已提交
583 584 585 586 587 588 589 590 591
             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type == framework::proto::VarType::LOD_TENSOR) {
               auto *tensor =
                   self.MutableVar()->GetMutable<framework::LoDTensor>();
               tensor->Resize(framework::make_ddim(dims));
             }
           })
592 593
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
594 595
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
596 597 598 599
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
600 601
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
602 603
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
604 605
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
606 607
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
L
Leo Chen 已提交
608
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
609
      .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
610
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
611
      .def("__getitem__",
S
songyouwei 已提交
612
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
613
             std::vector<int> slice_axes, slice_starts, slice_ends,
S
songyouwei 已提交
614 615 616 617 618 619
                 slice_strides, decrease_axis, infer_flags;
             auto tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                &slice_starts, &slice_ends, &slice_strides,
                                &decrease_axis, &infer_flags);
620 621 622 623
             // release gil and do tracing
             py::gil_scoped_release release;
             const auto &tracer = imperative::GetCurrentTracer();
             if (slice_axes.empty()) {
S
songyouwei 已提交
624
               return self;
625
             } else {
S
songyouwei 已提交
626
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
               auto out = std::shared_ptr<imperative::VarBase>(
                   new imperative::VarBase(tracer->GenerateUniqueName()));
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               for (auto stride : slice_strides) {
                 if (stride != 1) {
                   op_type = "strided_slice";
                   attrs.insert({"strides", slice_strides});
                   attrs.erase("decrease_axis");
                   break;
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
               return out;
             }
           })
649 650 651 652 653 654 655
      // numpy(): exports the LoDTensor held by this VarBase as a numpy array.
      // Fails with InvalidArgument when the tensor has not been initialized.
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             // Only read access is needed, so use the const Var() accessor
             // rather than MutableVar().
             const auto &tensor = self.Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             // NOTE(review): the meaning of the `true` flag is defined by
             // TensorToPyArray (tensor_py.h); presumably it requests a deep
             // copy of the data -- confirm there.
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`

        Returns:
            ndarray: The numpy value of current Variable.

        Returns type:
            ndarray: dtype is same as current Variable

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import Linear
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    linear = Linear(32, 64)
                    data = to_variable(data)
                    x = linear(data)
                    print(x.numpy())

       )DOC")
      // detach(): returns a new VarBase created from this one's LoDTensor on
      // the tensor's own place. Fails with InvalidArgument if the tensor has
      // not been initialized.
      .def("detach",
           [](const imperative::VarBase &self) {
             const auto &tensor = self.Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
                               platform::errors::InvalidArgument(
                                   "%s has not been initialized", self.Name()));
             // Target is the source tensor's place; the second argument is the
             // `blocking` flag, passed as false here.
             return self.NewVarBase(tensor.place(), false);
           },
           py::return_value_policy::copy, R"DOC(
        **Notes**:
            **This API is ONLY available in Dygraph mode**

        Returns a new Variable, detached from the current graph.

        Returns:
             ( :ref:`api_guide_Variable_en` | dtype is same as current Variable): The detached Variable.


        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
                from paddle.fluid.dygraph import Linear
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
                    linear = Linear(32, 64)
                    data = to_variable(data)
                    x = linear(data)
                    y = x.detach()

       )DOC")
      // clear_gradient(): bound directly to VarBase::ClearGradient.
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        **Notes**:
        **1. This API is ONLY available in Dygraph mode**

        **2. Use it only Variable has gradient, normally we use this for Parameters since other temporal Variable will be deleted by Python's GC**

        Clear  (set to ``0`` ) the Gradient of Current Variable

        Returns:  None

        Examples:
             .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                x = np.ones([2, 2], np.float32)
                with fluid.dygraph.guard():
                    inputs2 = []
                    for _ in range(10):
                         tmp = fluid.dygraph.base.to_variable(x)
                         tmp.stop_gradient=False
                         inputs2.append(tmp)
                    ret2 = fluid.layers.sums(inputs2)
                    loss2 = fluid.layers.reduce_sum(ret2)
                    backward_strategy = fluid.dygraph.BackwardStrategy()
                    backward_strategy.sort_sum_gradient = True
                    loss2.backward(backward_strategy)
                    print(loss2.gradient())
                    loss2.clear_gradient()
                    print("After clear {}".format(loss2.gradient()))
      )DOC")
      // _run_backward(bckst, tracer, retain_graph): initializes the tracer's
      // backward engine with `self` as the starting variable and executes it.
      // The GIL is released for the whole call (see the call_guard).
      .def("_run_backward",
           [](imperative::VarBase &self,
              const imperative::detail::BackwardStrategy &bckst,
              const imperative::Tracer &tracer, bool retain_graph) {
             // TODO(jiabin): when we impl more backward execution we can
             // select them
             auto *engine = tracer.GetEngine();
             engine->Init(&self, bckst, retain_graph);
             VLOG(3) << "Start backward";
             engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      // _grad_value(): the LoDTensor stored in the gradient variable.
      // NOTE(review): the lambda's deduced return type is LoDTensor *by value*,
      // so pybind11 moves that temporary into Python and the `reference`
      // policy has no aliasing effect -- confirm this is intended.
      .def("_grad_value",
           [](imperative::VarBase &self) {
             return self.MutableGradVar()->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
      // _set_grad_type(type): sets the VarType of the gradient VarBase.
      .def("_set_grad_type",
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             self.MutableGradVarBase()->SetType(type);
           })
      // _grad_ivar(): returns the gradient VarBase when it exists and its
      // tensor (a LoDTensor, or a SelectedRows' value tensor) is initialized;
      // otherwise an empty shared_ptr, which pybind11 returns as None.
      .def("_grad_ivar",
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
             if (grad_var && grad_var->Var().IsInitialized()) {
               // Any holder that is not a LoDTensor is treated as SelectedRows.
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
                       ? grad_var->MutableVar()
                             ->GetMutable<framework::LoDTensor>()
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
               if (tensor->IsInitialized()) {
                 return grad_var;
               }
             }
             return std::shared_ptr<imperative::VarBase>(nullptr);
           },
           py::return_value_policy::copy)
      // _is_sparse(): true when the variable holds SelectedRows.
      .def("_is_sparse",
           [](imperative::VarBase &self) {
             return self.Var().IsType<framework::SelectedRows>();
           })
      // _allreduce(strategy): all-reduces this variable in place when
      // strategy.nranks_ > 1 and does nothing otherwise. It requires an NCCL
      // build. SelectedRows variables also need NCCL >= 2.2.12.
      // The GIL is released for the whole call.
      .def("_allreduce",
           [](imperative::VarBase &self,
              const imperative::ParallelStrategy &strategy) {
             if (strategy.nranks_ > 1) {
#ifdef PADDLE_WITH_NCCL
#if NCCL_VERSION_CODE >= 2212
               imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
               if (!self.Var().IsType<framework::SelectedRows>()) {
                 imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
               } else {
                 PADDLE_THROW(platform::errors::Unimplemented(
                     "Imperative SelectedRows allreduce is not supported when "
                     "paddle is compiled with NCCL version lower than v2.2.12. "
                     "You can set is_sparse=False for the Layer containing "
                     "this argument, such as Embedding(is_sparse=False)."));
               }
#endif  // NCCL_VERSION_CODE
#else
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative allreduce is not supported when paddle is "
                   "not compiled with NCCL."));
#endif  // PADDLE_WITH_NCCL
             }
           },
           py::call_guard<py::gil_scoped_release>())
      // _copy_to(place, blocking): returns self.NewVarBase(place, blocking),
      // with one overload per place type. The registration order is kept
      // because pybind11 tries overloads in the order they are defined.
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("_copy_to",
           [](const imperative::VarBase &self,
              const platform::CUDAPinnedPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::XPUPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      // value(): the underlying framework::Variable. It is returned with the
      // `reference` policy, so Python does not take ownership of it.
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
      // shape: the dims of a LoDTensor, or of a SelectedRows' value tensor.
      // For any other variable type it logs at VLOG(2) and returns [].
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::LoDTensor>().dims());
            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::SelectedRows>().value().dims());
            } else {
              VLOG(2) << "It is meaningless to get shape of variable type "
                      << GetTypeName(self);
              return std::vector<int>();
            }
          })
      .def_property_readonly(
          "place", [](imperative::VarBase &self) { return self.Place(); },
          py::return_value_policy::copy)
      .def_property_readonly("type", &imperative::VarBase::Type)
      .def_property_readonly("dtype", &imperative::VarBase::DataType);

  // Layer: binds imperative::Layer through the `Layer` trampoline class
  // (presumably so Python subclasses can override forward -- see the
  // trampoline's definition).
  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>())
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });

  // ProgramDescTracer: no Python constructor is registered here; instances
  // are handed out by reference from Tracer._get_program_desc_tracer.
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

  // Tracer: bound with a std::shared_ptr holder.
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer", R"DOC()DOC")
      // Default construction uses py::init<>(), pybind11's supported
      // constructor form. The placement-new "__init__" idiom has been
      // deprecated since pybind11 2.2.
      .def(py::init<>())
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
      .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
                    &imperative::Tracer::SetEnableAutoCast)
      .def_property("_train_mode", &imperative::Tracer::HasGrad,
                    &imperative::Tracer::SetHasGrad)
      // _expected_place:
      //   - The getter returns the tracer's ExpectedPlace() as a Python object.
      //   - The setter accepts CUDAPlace, XPUPlace, CPUPlace or
      //     CUDAPinnedPlace (checked in that order).
      //   - The setter raises InvalidArgument for any other type.
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::XPUPlace>(obj)) {
              auto p = obj.cast<platform::XPUPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace(*p);
            } else {
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
                  "CPUPlace, and CUDAPinnedPlace, but got Unknown Type!"));
            }
          })
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
      .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
           py::arg("key") = "eager_tmp")
      // _set_amp_op_list(allow_ops, block_ops): swaps the given sets into the
      // AmpOperators singleton's allow/block op sets.
      .def(
          "_set_amp_op_list",
          [](imperative::Tracer &self,
             std::unordered_set<std::string> &allow_ops,
             std::unordered_set<std::string> &block_ops) {
            // NOTE(zhiqiu): The automatic conversion in pybind11 between c++
            // STL and python set/list/dict involves a copy operation that
            // prevents pass-by-reference semantics, so it is ok to swap.
            // The reason why we do not directly pass
            // std::shared_ptr<std::unordered_set<std::string>>
            // is that pybind11 forbids shared_ptr<T> for non-custom types T.
            imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
            imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
          })
      // _get_amp_op_list(): returns copies of (allow_ops, block_ops).
      .def("_get_amp_op_list",
           [](imperative::Tracer &self) {
             return std::make_tuple(
                 *(imperative::AmpOperators::Instance().GetAllowOps()),
                 *(imperative::AmpOperators::Instance().GetBlockOps()));
           })
      // trace(type, ins, outs, attrs, place, trace_backward): one overload per
      // place type (XPUPlace, CUDAPlace, CPUPlace; the order is kept).
      // The Python in/out maps are converted while the GIL is held; TraceOp
      // itself runs with the GIL released.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::XPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });

  // define parallel context
989 990 991
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
992 993
      .def_property(
          "nranks",
994 995
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
996 997 998
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
999
                    [](const imperative::ParallelStrategy &self) {
1000 1001
                      return self.local_rank_;
                    },
1002
                    [](imperative::ParallelStrategy &self, int local_rank) {
1003 1004 1005 1006
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
1007
          [](const imperative::ParallelStrategy &self) {
1008 1009
            return self.trainer_endpoints_;
          },
1010
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
1011 1012 1013
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
1014
                    [](const imperative::ParallelStrategy &self) {
1015 1016
                      return self.current_endpoint_;
                    },
1017 1018
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; });
1019 1020 1021 1022 1023 1024 1025 1026 1027 1028

  // dygraph_partial_grad(...): builds a PartialGradEngine from the arguments,
  // runs it and returns engine.GetResult(). Presumably this is the gradient
  // of `output_targets` w.r.t. `input_targets` -- see partial_grad_engine.h.
  // The GIL is released for the whole call.
  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>>
             &output_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &output_grads,
         const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
         const platform::Place &place,
         const imperative::detail::BackwardStrategy &strategy,
         bool create_graph, bool retain_graph, bool allow_unused,
         bool only_inputs) {
        imperative::PartialGradEngine engine(
            input_targets, output_targets, output_grads, no_grad_vars, place,
            strategy, create_graph, retain_graph, allow_unused, only_inputs);
        engine.Execute();
        return engine.GetResult();
      },
      py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_NCCL)
  // NCCLParallelContext is only bound when Paddle is built with NCCL. It is
  // constructed from a ParallelStrategy and a CUDAPlace; init() forwards to
  // Init().
  py::class_<imperative::NCCLParallelContext> nccl_ctx(m,
                                                       "NCCLParallelContext");

  nccl_ctx
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
#endif
}

}  // namespace pybind
}  // namespace paddle