imperative.cc 95.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22

23
#include <algorithm>
24
#include <memory>
25
#include <set>
J
Jiabin Yang 已提交
26
#include <string>
27
#include <unordered_map>
28
#include <unordered_set>
29
#include <utility>
J
Jiabin Yang 已提交
30
#include <vector>
31

32
#include "paddle/fluid/framework/scope_guard.h"
33
#include "paddle/fluid/imperative/all_reduce.h"
34
#include "paddle/fluid/imperative/amp_auto_cast.h"
35
#include "paddle/fluid/imperative/basic_engine.h"
36
#include "paddle/fluid/imperative/bkcl_context.h"
37
#include "paddle/fluid/imperative/data_loader.h"
38
#include "paddle/fluid/imperative/gloo_context.h"
39
#include "paddle/fluid/imperative/hooks.h"
40
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
41
#include "paddle/fluid/imperative/nccl_context.h"
42
#include "paddle/fluid/imperative/partial_grad_engine.h"
43
#include "paddle/fluid/imperative/profiler.h"
44
#include "paddle/fluid/imperative/py_layer_fwd.h"
45
#include "paddle/fluid/imperative/reducer.h"
46
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
47
#include "paddle/fluid/imperative/type_defs.h"
48
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
49
#include "paddle/fluid/operators/utils.h"
50
#include "paddle/fluid/pybind/op_function.h"
51
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
52
#include "paddle/fluid/pybind/tensor_py.h"
53

54 55 56
namespace paddle {
namespace pybind {

57 58
// Python type object of the bound VarBase (Tensor) class. Starts as nullptr;
// presumably assigned when the class is registered in BindImperative —
// TODO confirm.
PyTypeObject *g_varbase_pytype = nullptr;

59 60
// Short alias for the pybind11 namespace used throughout this file.
namespace py = ::pybind11;

61 62 63 64
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

65 66 67 68
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
69
                      Forward, inputs);  // NOLINT
70 71 72
  }
};

73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
// Casts a borrowed Python object to the C++ type T through pybind11.
// A failed cast is reported as a Paddle InvalidArgument error naming T.
template <typename T>
static T PyObjectCast(PyObject *obj) {
  const py::handle borrowed(obj);
  try {
    return py::cast<T>(borrowed);
  } catch (const py::cast_error &) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
  }
}

class PyVariableWrapperHook : public imperative::VariableWrapperHook {
 public:
  explicit PyVariableWrapperHook(PyObject *func) : py_func_(func) {
    Py_INCREF(py_func_);
  }

  ~PyVariableWrapperHook() {
    py::gil_scoped_acquire gil;
    Py_DECREF(py_func_);
  }

  std::shared_ptr<imperative::VariableWrapper> operator()(
      const std::shared_ptr<imperative::VariableWrapper> &var) override {
    py::gil_scoped_acquire gil;
    VLOG(3) << "Call PyVariableWrapperHook for var " << var->Name();

    // 1. unpack temp VarBase from VariableWrapper
    std::shared_ptr<imperative::VarBase> tmp_varbase =
        std::make_shared<imperative::VarBase>(var);

    // 2. call hook and return
    PyObject *res = nullptr;
    try {
      res = PyObject_CallFunctionObjArgs(py_func_, py::cast(tmp_varbase).ptr(),
                                         nullptr);
    } catch (platform::EnforceNotMet &e) {
      throw std::move(e);
    } catch (std::exception &e) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Hook function of Tensor raises an exception: %s.", e.what()));
    } catch (...) {
      PADDLE_THROW(platform::errors::Fatal(
          "Hook function of Tensor raises an unknown exception."));
    }

    PADDLE_ENFORCE_NOT_NULL(res,
                            platform::errors::Unavailable(
                                "Hook function of Tensor return a nullptr."));
    if (res == Py_None) {
      return var;
    }

    return PyObjectCast<std::shared_ptr<imperative::VarBase>>(res)->SharedVar();
  }

 private:
  PyObject *py_func_;
};

L
Leo Chen 已提交
132 133 134 135 136
// Converts a Python place object (CPUPlace, CUDAPlace, XPUPlace,
// CUDAPinnedPlace, NPUPlace or the generic Place) into a platform::Place.
// Any other object raises InvalidArgument.
static const platform::Place PyObjectToPlace(const py::object &place_obj) {
  // Concrete place types are probed before the generic Place wrapper.
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  }
  if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  }
  if (py::isinstance<platform::XPUPlace>(place_obj)) {
    return place_obj.cast<platform::XPUPlace>();
  }
  if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  }
  if (py::isinstance<platform::NPUPlace>(place_obj)) {
    return place_obj.cast<platform::NPUPlace>();
  }
  if (py::isinstance<platform::Place>(place_obj)) {
    return place_obj.cast<platform::Place>();
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Place should be one of "
      "Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
}

L
Leo Chen 已提交
152 153 154 155 156 157 158 159 160 161
// Initializes only the VarBase (placement-new into `self`), not its tensor.
// An empty `name` is replaced by a tracer-generated unique name, and a
// `stop_gradient` of -1 leaves the VarBase's stop-gradient setting untouched.
// The variable type is set to LOD_TENSOR.
static void InitVarBaseOnly(imperative::VarBase *self, const std::string &name,
                            bool persistable = false, int stop_gradient = -1) {
  std::string var_name = name;
  if (var_name.empty()) {
    var_name = imperative::GetCurrentTracer()->GenerateUniqueName(
        "generated_tensor");
  }

  VLOG(5) << "Init Tensor as: / name: " << var_name
          << " / persistable: " << persistable
          << " / stop_gradient: " << stop_gradient;

  new (self) imperative::VarBase(var_name);
  const bool override_stop_gradient = (stop_gradient != -1);
  if (override_stop_gradient) {
    self->SetOverridedStopGradient(stop_gradient);
  }
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
}

// Initializes the VarBase (see InitVarBaseOnly) and then fills its LoDTensor
// from the numpy `array` on `place`, forwarding `zero_copy` to
// SetTensorFromPyArray. The VarBase's data type is taken from the filled
// tensor. An unsupported place raises InvalidArgument.
static void InitVarBaseAndTensor(
    imperative::VarBase *self, const py::array &array,
    const platform::Place &place, const std::string &name,
    bool persistable = false, bool zero_copy = false, int stop_gradient = -1) {
  InitVarBaseOnly(self, name, persistable, stop_gradient);
  framework::LoDTensor *dst_tensor =
      self->MutableVar()->GetMutable<framework::LoDTensor>();
  VLOG(4) << "zero_copy: " << zero_copy;

  // Dispatch on the runtime place kind to the matching template instance.
  if (platform::is_cpu_place(place)) {
    SetTensorFromPyArray<platform::CPUPlace>(
        dst_tensor, array, BOOST_GET_CONST(platform::CPUPlace, place),
        zero_copy);
  } else if (platform::is_xpu_place(place)) {
    SetTensorFromPyArray<platform::XPUPlace>(
        dst_tensor, array, BOOST_GET_CONST(platform::XPUPlace, place),
        zero_copy);
  } else if (platform::is_gpu_place(place)) {
    SetTensorFromPyArray<platform::CUDAPlace>(
        dst_tensor, array, BOOST_GET_CONST(platform::CUDAPlace, place),
        zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(
        dst_tensor, array, BOOST_GET_CONST(platform::CUDAPinnedPlace, place),
        zero_copy);
  } else if (platform::is_npu_place(place)) {
    SetTensorFromPyArray<platform::NPUPlace>(
        dst_tensor, array, BOOST_GET_CONST(platform::NPUPlace, place),
        zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of "
        "CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/NPUPlace"));
  }
  self->SetDataType(dst_tensor->type());
}

// Initializes a VarBase from Python keyword arguments. Recognized keys:
// "persistable" (bool, default false), "zero_copy" (bool, default false),
// "name" (str, default ""), "stop_gradient" (int, default -1), "value"
// (numpy array) and "place". Without "value" only the VarBase itself is
// initialized; "place" and "zero_copy" then have no effect.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  VLOG(4) << "Init VarBase from kwargs: ";

  bool persistable = false;
  if (kwargs.contains("persistable")) {
    persistable = kwargs["persistable"].cast<bool>();
  }
  bool zero_copy = false;
  if (kwargs.contains("zero_copy")) {
    zero_copy = kwargs["zero_copy"].cast<bool>();
  }
  std::string name;
  if (kwargs.contains("name")) {
    name = kwargs["name"].cast<std::string>();
  }
  int stop_gradient = -1;
  if (kwargs.contains("stop_gradient")) {
    stop_gradient = kwargs["stop_gradient"].cast<int>();
  }
  auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();

  if (!kwargs.contains("value")) {
    InitVarBaseOnly(self, name, persistable, stop_gradient);
    return;
  }

  auto array = kwargs["value"].cast<py::array>();
  // "place" only matters when an array is given.
  platform::Place place = default_place;
  if (kwargs.contains("place")) {
    place = PyObjectToPlace(kwargs["place"]);
  }
  InitVarBaseAndTensor(self, array, place, name, persistable, zero_copy,
                       stop_gradient);
}
229

230 231 232
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
L
Leo Chen 已提交
233 234
                                        bool persistable = false,
                                        bool zero_copy = false,
235 236 237 238 239
                                        std::string name = "",
                                        int stop_gradient = -1) {
  VLOG(4) << "Init VarBase from Arg: ";
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6:
  // stop_gradient
L
Leo Chen 已提交
240
  if (name == "") {
241 242
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
L
Leo Chen 已提交
243
  }
244 245
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
246
          << " / stop_gradient: " << stop_gradient << " / at " << place;
L
Leo Chen 已提交
247
  new (self) imperative::VarBase(name);
248 249
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
250 251 252
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
253 254 255 256 257 258
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
259 260
                                               const py::array &array) {
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
261
  VLOG(4) << "Init VarBase from numpy at " << place;
L
Leo Chen 已提交
262
  InitVarBaseAndTensor(self, array, place, "");
263
}
264

265
static void InitVarBaseFromTensorWithArgDefault(
266
    imperative::VarBase *self, const framework::Tensor &tensor) {
267 268 269
  VLOG(4) << "Init VarBase";
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  new (self) imperative::VarBase(
270
      imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
271 272 273 274 275 276 277 278 279 280 281 282 283 284
  self->SetPersistable(false);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor.type());
  auto *new_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  // Same place,share data directly
  if (place == tensor.place()) {
    new_tensor->ShareDataWith(tensor);
    VLOG(4) << "Same place, do ShareDataWith";
  } else {
    framework::TensorCopy(tensor, place, new_tensor);
    VLOG(4) << "Different place, do TensorCopy";
  }
}

285 286 287 288 289
// Returns a printable type name for `var`: "RAW" for RAW variables, "nullptr"
// when the underlying Variable is not initialized, and otherwise the name
// produced by framework::ToTypeName for the held type.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) return "RAW";
  if (!var.Var().IsInitialized()) return "nullptr";
  return framework::ToTypeName(var.Var().Type());
}
L
Leo Chen 已提交
294

295
// Maps a name (presumably an operator input/output slot — TODO confirm) to
// the Python object bound to it, which ConvertToNameVarBaseMap expects to be
// None, a VarBase, or a list/tuple of VarBase. py::handle does not own a
// reference to the object.
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
296 297 298 299 300 301 302 303 304 305 306 307 308

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts None (or a null handle), a VarBase, or a list/tuple of VarBase into
// a vector of VarBase pointers. None yields an empty vector; any element that
// is not a VarBase raises InvalidArgument via PyObjectCast.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  using VarBasePtr = std::shared_ptr<imperative::VarBase>;

  PyObject *py_obj = handle.ptr();
  // Python None is not nullptr in C++, so both must be checked.
  if (py_obj == nullptr || py_obj == Py_None) {
    return {};
  }

  std::vector<VarBasePtr> vars;
  const bool is_list = PyList_Check(py_obj);
  const bool is_tuple = !is_list && PyTuple_Check(py_obj);
  if (!is_list && !is_tuple) {  // a single VarBase
    vars.emplace_back(PyObjectCast<VarBasePtr>(py_obj));
    return vars;
  }

  // List or tuple of VarBase: both are read through borrowed item pointers.
  const size_t count =
      is_list ? PyList_GET_SIZE(py_obj) : PyTuple_GET_SIZE(py_obj);
  vars.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    PyObject *item =
        is_list ? PyList_GET_ITEM(py_obj, idx) : PyTuple_GET_ITEM(py_obj, idx);
    PADDLE_ENFORCE_NOT_NULL(
        item, platform::errors::InvalidArgument("Python Object is NULL"));
    vars.emplace_back(PyObjectCast<VarBasePtr>(item));
  }
  return vars;
}
336 337 338 339 340 341 342 343
// Returns true when `obj` is one of the numpy integer scalar types
// int64/longlong/int32/int16, identified by the type's tp_name.
static bool IsNumpyType(PyObject *obj) {
  // It is not a good way to judge the type of obj by its type'name. Maybe using
  // `PyArray_IsScalar` will be better. However, this interface cannot be used
  // by including pybind11, and it needs to compile with numpy.
  static const char *const kNumpyIntTypeNames[] = {
      "numpy.int64", "numpy.longlong", "numpy.int32", "numpy.int16"};
  const std::string type_name(Py_TYPE(obj)->tp_name);
  for (const char *candidate : kNumpyIntTypeNames) {
    if (type_name == candidate) {
      return true;
    }
  }
  return false;
}
344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388

// Returns true when `obj` is an instance of the bound VarBase (Tensor) class.
static bool PyCheckTensor(PyObject *obj) {
  const py::handle candidate(obj);
  return py::isinstance<imperative::VarBase>(candidate);
}

// Casts a numpy array from element type S to element type T. When S and T are
// the same type, the input array is returned as-is. Otherwise py::vectorize
// produces a new array whose elements are static_cast<T> of the inputs, so
// this may allocate new memory.
template <class T, class S>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
  if (std::is_same<T, S>::value) {
    return array;
  }
  // py::vectorize allocates its own output with the input's shape. The former
  // hand-built shape vector and the `py::array_t<T> result(...)` allocation
  // were never used and only wasted a full-size buffer per call.
  return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}

// Converts a numpy array whose dtype is exactly float32, float64, int32,
// int64 or bool into py::array_t<T> (see CastNumpyType). Any other object
// raises InvalidArgument.
template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
  if (py::isinstance<py::array_t<float>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<float>>());
  }
  if (py::isinstance<py::array_t<double>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<double>>());
  }
  if (py::isinstance<py::array_t<int32_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int32_t>>());
  }
  if (py::isinstance<py::array_t<int64_t>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
  }
  if (py::isinstance<py::array_t<bool>>(array)) {
    return CastNumpyType<T>(array.cast<py::array_t<bool>>());
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Value type error. The assign numpy value allows integer, float, "
      "double and bool, "
      "but received %s.",
      Py_TYPE(array.ptr())->tp_name));
  // Unreachable: PADDLE_THROW always throws.
  return py::array_t<T>();
}

J
Jiabin Yang 已提交
389 390 391
// Converts a {name -> Python handle} map into an imperative::NameVarBaseMap.
// Each handle is expanded with GetVarBaseListFromPyHandle; names that resolve
// to no variables (e.g. None) are omitted from the result.
// Raises InvalidArgument if a Python error is pending after the conversion.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  // PyErr_Occurred() returns only the exception *type*. The old code printed
  // that type and used it as a format string, leaving the error set.
  // error_already_set fetches (and clears) the pending error and exposes its
  // full message, which is passed as an argument rather than a format.
  if (PyErr_Occurred() != nullptr) {
    py::error_already_set py_err;
    PADDLE_THROW(platform::errors::InvalidArgument("%s", py_err.what()));
  }
  return result;
}

405 406 407 408 409 410 411 412
// Returns true for Python integers but not for bool (a subclass of int).
// On Python 2 both `int` and `long` are accepted.
static bool PyCheckInteger(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyLong_Check(obj) || PyInt_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}

413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433
// Reads a slice index from a single-element int32/int64 Tensor.
// Raises InvalidArgument when the tensor has more or fewer than one element
// or a different data type.
static Py_ssize_t GetSliceIndexFromTensor(
    const std::shared_ptr<imperative::VarBase> &tensor_index) {
  const auto &index_tensor = tensor_index->Var().Get<framework::LoDTensor>();
  if (index_tensor.numel() != 1) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Currently, tensor in slice indices only allows 1 element, "
        "but received %d.",
        index_tensor.numel()));
  }

  const auto dtype = index_tensor.type();
  if (dtype == framework::proto::VarType::INT32) {
    return static_cast<Py_ssize_t>(
        operators::GetValue<int32_t>(&index_tensor));
  }
  if (dtype == framework::proto::VarType::INT64) {
    return static_cast<Py_ssize_t>(
        operators::GetValue<int64_t>(&index_tensor));
  }
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Currently, the type of tensor in slice indices only allows "
      "int32 and int64, please check the type of index tensor."));
}

434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449
// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// The original PySlice_GetIndices returned wrong results when a slice item
// contained a long int (e.g. arr[:180L]), and it could not raise an error
// for a float slice item. This revision also accepts numpy integer scalars
// and single-element int32/int64 Tensors as slice components. Consider
// _PySlice_Unpack, which is more robust, in the future.
//
// Resolves `r` against an axis of `length` into *start, *stop and *step.
// Returns -1 when stop > length, start >= length or step == 0, else 0.
// Unsupported component types raise InvalidArgument.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
                               Py_ssize_t *start, Py_ssize_t *stop,
                               Py_ssize_t *step) {
  // Converts one non-None slice component (int, numpy int or int Tensor).
  auto component_to_index = [](PyObject *item) -> Py_ssize_t {
    if (PyCheckInteger(item) || IsNumpyType(item)) {
      return PyLong_AsLong(item);
    }
    if (PyCheckTensor(item)) {
      return GetSliceIndexFromTensor(
          py::cast<std::shared_ptr<imperative::VarBase>>(item));
    }
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Currently, slice indices only allows None, integers, "
        "tensor(int) and numpy(int) in slice item, but received %s.",
        std::string(Py_TYPE(item)->tp_name)));
  };

  // The step is resolved first because start/stop defaults depend on its sign.
  *step = (r->step == Py_None) ? 1 : component_to_index(r->step);

  if (r->start == Py_None) {
    *start = *step < 0 ? length - 1 : 0;
  } else {
    *start = component_to_index(r->start);
    if (*start < 0) *start += length;
    *start = std::max(*start, static_cast<Py_ssize_t>(0));
  }

  if (r->stop == Py_None) {
    *stop = *step < 0 ? -1 : length;
  } else {
    *stop = component_to_index(r->stop);
    // Negative stops are wrapped only for positive steps.
    if (0 < *step && *stop < 0) *stop += length;
    *stop = std::min(*stop, length);
  }

  const bool invalid = (*stop > length) || (*start >= length) || (*step == 0);
  return invalid ? -1 : 0;
}

Z
zyfncg 已提交
502 503 504 505 506 507 508 509 510
// Parses a Python index (the argument of Tensor.__getitem__/__setitem__) into
// slice attributes for `tensor`.
//
// Accepted index items: integers (including numpy integer scalars), slices,
// Ellipsis, None, tuples of these, and a single list of bools or integers.
// Outputs (appended to):
//   slice_axes/slice_starts/slice_ends/slice_strides - one entry per axis that
//       is integer-indexed or sliced (full `:` slices are skipped);
//   decrease_axis - axes indexed by an integer;
//   none_axes     - axis positions at which a None item appears;
//   infer_flags   - a `1` for every index item;
//   list_select_idxs / list_select_flag - selection when indexing with a list.
// Raises py::index_error for an out-of-range integer index and
// InvalidArgument for unsupported items or too many indices.
static void ParseIndexingSlice(
    framework::LoDTensor *tensor, PyObject *_index,
    std::vector<int> *slice_axes, std::vector<int> *slice_starts,
    std::vector<int> *slice_ends, std::vector<int> *slice_strides,
    std::vector<int> *decrease_axis, std::vector<int> *none_axes,
    std::vector<int> *infer_flags, std::vector<int> *list_select_idxs,
    bool *list_select_flag) {
  // We allow indexing by Integers, Slices, Ellipsis, None, tuples of those
  // types, and list of Bool and Integers.
  // A non-tuple index is wrapped into a 1-tuple.

  // NOTE(zhiqiu): PyTuple_Pack increases refcount.
  PyObject *index = !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  DEFINE_PADDLE_SCOPE_GUARD([index, _index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index);
      VLOG(4) << "Call Py_DECREF";
    }
  });
  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));
  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = PyTuple_GET_SIZE(index);

  // specified_dims is the number of dimensions indexed by integers or slices.
  // Numpy integer scalars consume an axis exactly like Python integers in the
  // loop below, so they are counted too; otherwise an Ellipsis would expand
  // over one axis too many (e.g. x[..., np.int64(0)]).
  int specified_dims = 0;
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index, dim);
    if (PyCheckInteger(slice_item) || IsNumpyType(slice_item) ||
        PySlice_Check(slice_item)) {
      specified_dims++;
    }
  }

  // Length of axis `dim`, read only for items that consume an axis. Reading
  // shape[dim] with dim >= rank is invalid, so that case is rejected.
  auto axis_len = [&](int dim, int item_idx) -> int {
    PADDLE_ENFORCE_LT(
        dim, rank,
        platform::errors::InvalidArgument(
            "Too many indices for tensor of dimension %d: the %dth slice "
            "item refers to axis %d.",
            rank, item_idx + 1, dim));
    return static_cast<int>(shape[dim]);
  };

  for (int i = 0, dim = 0; i < size; ++i) {
    PyObject *slice_item = PyTuple_GetItem(index, i);

    infer_flags->push_back(1);
    if (PyCheckInteger(slice_item) || IsNumpyType(slice_item)) {
      // integer, PyLong_AsLong supports both int and long
      const int dim_len = axis_len(dim, i);
      int start = static_cast<int>(PyLong_AsLong(slice_item));
      auto s_t = start;
      start = start < 0 ? start + dim_len : start;
      if (start >= dim_len || start < 0) {
        std::string str_error_message =
            "The starting index " + std::to_string(s_t) +
            " of slice is out of bounds in tensor " + std::to_string(dim) +
            "-th axis, it shound be in the range of [" +
            std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
        // py::index_error is corresponding to IndexError in Python
        // Used to indicate out of bounds access in __getitem__, __setitem__
        throw py::index_error(str_error_message);
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(start + 1);
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
      dim++;
    } else if (PySlice_Check(slice_item)) {
      // slice item
      const int dim_len = axis_len(dim, i);
      Py_ssize_t start, end, step;
      PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
      _PySlice_GetIndices(p, dim_len, &start, &end, &step);

      // :: or : or 0:dim_len:1 selects the whole axis; nothing to record.
      if (start == 0 && end == dim_len && step == 1) {
        dim++;
        continue;
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(end);
      slice_strides->push_back(step);
      dim++;
    } else if (slice_item == Py_Ellipsis) {
      // Ellipsis spans every axis not claimed by an integer or slice item.
      dim += rank - specified_dims;
    } else if (slice_item == Py_None) {
      none_axes->push_back(dim);
    } else if (PyList_Check(slice_item)) {
      *list_select_flag = true;
      PADDLE_ENFORCE_EQ(
          size, 1,
          platform::errors::InvalidArgument(
              "When index contains a list, its length is excepted to 1, "
              "but received %d",
              size));
      bool all_bool = true;
      int list_size = PyList_GET_SIZE(slice_item);
      for (int j = 0; j < list_size; ++j) {
        PyObject *list_item = PyList_GetItem(slice_item, j);
        if (PyCheckInteger(list_item)) {
          all_bool = false;
        } else if (!PyBool_Check(list_item)) {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "Only support int or bool in index list."));
        }
      }
      if (all_bool) {
        // A boolean mask must cover axis 0 exactly.
        PADDLE_ENFORCE_EQ(
            list_size, shape[0],
            platform::errors::InvalidArgument(
                "The dimension of bool index doesn't match indexed array along "
                "dimension 0, the target dimension is %d, but received %d.",
                shape[0], list_size));

        for (int j = 0; j < list_size; ++j) {
          PyObject *list_item = PyList_GetItem(slice_item, j);
          if (list_item == Py_True) {
            list_select_idxs->push_back(j);
          }
        }
      } else {
        // Mixed list: integers are used as-is, True -> 1, False -> 0.
        for (int j = 0; j < list_size; ++j) {
          PyObject *list_item = PyList_GetItem(slice_item, j);
          if (PyCheckInteger(list_item)) {
            list_select_idxs->push_back(
                static_cast<int>(PyLong_AsLong(list_item)));
          } else if (list_item == Py_True) {
            list_select_idxs->push_back(1);
          } else {
            list_select_idxs->push_back(0);
          }
        }
      }

    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, Tensor.__indices__() only allows indexing "
          "by Integers, Slices, Ellipsis, None, tuples of these types "
          "and list of Bool and Integers, but received "
          "%s in %dth slice item",
          std::string(Py_TYPE(slice_item)->tp_name), i + 1));
    }
  }

  // valid_indexs is the number of index items excluding None.
  const int valid_indexs = size - none_axes->size();
  PADDLE_ENFORCE_EQ(valid_indexs <= rank, true,
                    platform::errors::InvalidArgument(
                        "Too many indices (%d) for tensor of dimension %d.",
                        valid_indexs, rank));
}

649
template <typename P>
650 651 652
static void VarBaseCopy(std::shared_ptr<imperative::VarBase> &src,  // NOLINT
                        imperative::VarBase &dst,                   // NOLINT
                        const P &dst_device, const bool blocking) {
653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704
  if (dst.SharedVar()->IsEmpty()) {
    VLOG(3) << "deep copy Variable from " << src->Name() << " to "
            << dst.Name();
    dst.SetPersistable(src->Persistable());
    dst.SetDataType(src->DataType());
    dst.SetType(src->Type());
    dst.SetOverridedStopGradient(src->OverridedStopGradient());
    if (!src->SharedVar()->IsEmpty()) {
      if (src->Var().IsType<framework::LoDTensor>()) {
        auto &src_tensor = src->Var().Get<framework::LoDTensor>();
        auto *dst_tensor = dst.MutableVar()->GetMutable<framework::LoDTensor>();
        dst_tensor->set_lod(src_tensor.lod());
        framework::TensorCopy(src_tensor, dst_device, dst_tensor);
        if (blocking) {
          platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
          auto src_device = src_tensor.place();
          if (!(src_device == dst_device)) {
            platform::DeviceContextPool::Instance().Get(src_device)->Wait();
          }
        }
      } else if (src->Var().IsType<framework::SelectedRows>()) {
        auto &src_selected_rows = src->Var().Get<framework::SelectedRows>();
        auto *dst_selected_rows =
            dst.MutableVar()->GetMutable<framework::SelectedRows>();
        dst_selected_rows->set_height(src_selected_rows.height());
        dst_selected_rows->set_rows(src_selected_rows.rows());
        framework::TensorCopy(src_selected_rows.value(), dst_device,
                              dst_selected_rows->mutable_value());
        if (blocking) {
          platform::DeviceContextPool::Instance().Get(dst_device)->Wait();
          auto src_device = src_selected_rows.value().place();
          if (!(src_device == dst_device)) {
            platform::DeviceContextPool::Instance().Get(src_device)->Wait();
          }
        }
      }

      if (!blocking) {
        IncreaseVarbaseReferenceCountUntilCopyComplete(src, dst_device);
      }

    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "The source Tensor(%s) can not copy when it is empty.", src->Name()));
    }
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "The destion Tensor(%s) can not copy when it is not empty.",
        dst.Name()));
  }
}

705
// Bind Methods
J
Jiabin Yang 已提交
706
void BindImperative(py::module *m_ptr) {
707 708
  auto &m = *m_ptr;

709 710
  BindOpFunctions(&m);

711 712
#ifndef _WIN32
  // Dygraph DataLoader signal handler
713 714 715 716 717 718 719 720 721 722 723 724 725
  // Registers the pids of DataLoader worker processes under `key` by passing
  // them, as a std::set<pid_t>, to imperative::SetLoadProcessPIDs.
  // `obj` must be a Python tuple or list of integers; anything else raises
  // InvalidArgument.
  m.def("_set_process_pids", [](int64_t key, py::object &obj) {
    PADDLE_ENFORCE_EQ(
        py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj), true,
        platform::errors::InvalidArgument(
            "The subprocess ids set in DataLoader is illegal. "
            "Expected data type is tuple or list, but received %s",
            obj.get_type()));
    // Cast to py::list so that a tuple and a list are iterated the same way.
    py::list pids = py::cast<py::list>(obj);
    std::set<pid_t> pids_set = {};
    for (auto pid : pids) {
      pids_set.insert(pid.cast<pid_t>());
    }
    imperative::SetLoadProcessPIDs(key, pids_set);
  });
727 728
  // Forwards `key` to imperative::EraseLoadProcessPIDs (counterpart of
  // _set_process_pids).
  m.def("_erase_process_pids", [](int64_t key) {
    imperative::EraseLoadProcessPIDs(key);
  });
729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780
  // Thin Python entry points for the DataLoader worker-process helpers.
  m.def("_set_process_signal_handler", []() -> void {
    imperative::SetLoadProcessSignalHandler();
  });
  m.def("_throw_error_if_process_failed", []() -> void {
    imperative::ThrowErrorIfLoadProcessFailed();
  });

  // Dygraph DataLoader reader process & thread related functions
  m.def(
      "_convert_to_tensor_list",
      [](py::object &obj) -> py::list {
        // Converts one DataLoader batch (a tuple or list of array-like items)
        // into a Python list of CPU LoDTensors. Each tensor's data is copied
        // into a memory-map writer allocation, and that allocation's ipc name
        // is recorded in MemoryMapFdSet.
        // 0. input data check
        PADDLE_ENFORCE_EQ(
            py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj),
            true,
            platform::errors::InvalidArgument(
                "The batch data read into DataLoader is illegal. "
                "Expected data type is tuple or list, but received %s",
                obj.get_type()));
        py::list batch = py::cast<py::list>(obj);
        py::list tensors;
        for (size_t i = 0; i < batch.size(); ++i) {
          // 1. cast to python array
          auto array = batch[i].cast<py::array>();
          // An "object" dtype means numpy could not build a regular array
          // from the item (typically ragged nested lists, see message).
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Failed to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data that causes this issue."));
          // 2. construct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);
          // 6. append to result list
          tensors.append(t);
        }
        return tensors;
      },
      py::return_value_policy::take_ownership);

K
Kaipeng Deng 已提交
781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813
  m.def("_array_to_share_memory_tensor",
        [](py::object &obj) {
          // Copies one array-like object into a CPU LoDTensor whose data lives
          // in a memory-map writer allocation; the allocation's ipc name is
          // recorded in MemoryMapFdSet. Returns the tensor by value.
          // 1. cast to python array
          auto array = obj.cast<py::array>();
          // An "object" dtype means numpy could not build a regular array
          // from the input (typically ragged nested lists, see message).
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Failed to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data that causes this issue."));
          // 2. construct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);

          return t;
        },
        py::return_value_policy::take_ownership);

814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833
  // For each LoDTensor in `tensor_list`, removes the ipc name of its
  // MemoryMapWriterAllocation holder from MemoryMapFdSet. Raises NotFound if a
  // tensor's holder is not a MemoryMapWriterAllocation.
  m.def("_remove_tensor_list_mmap_fds", [](py::list &tensor_list) {
    for (auto item : tensor_list) {
      auto lod_tensor = item.cast<framework::LoDTensor>();
      auto *writer =
          dynamic_cast<memory::allocation::MemoryMapWriterAllocation *>(
              lod_tensor.Holder().get());
      PADDLE_ENFORCE_NOT_NULL(
          writer,
          platform::errors::NotFound("The shared memory of LoDTensor in "
                                     "DataLoader's child process has been "
                                     "released."));
      const auto &name_to_remove = writer->ipc_name();
      memory::allocation::MemoryMapFdSet::Instance().Remove(name_to_remove);
    }
  });

  // Calls MemoryMapFdSet::Clear() on the process-wide fd set instance.
  m.def("_cleanup_mmap_fds", []() -> void {
    memory::allocation::MemoryMapFdSet::Instance().Clear();
  });
#endif

834 835 836 837 838
  // Start/stop the imperative profiler via imperative::StartProfile and
  // imperative::StopProfile.
  m.def("start_imperative_gperf_profiler", []() -> void {
    imperative::StartProfile();
  });

  m.def("stop_imperative_gperf_profiler", []() -> void {
    imperative::StopProfile();
  });

Z
Zeng Jinle 已提交
839 840 841
  // Expose the dygraph debug switches reported by imperative::IsDebugEnabled
  // and imperative::GetDebugLevel.
  m.def("_is_dygraph_debug_enabled", []() {
    return imperative::IsDebugEnabled();
  });
  m.def("_dygraph_debug_level", []() {
    return imperative::GetDebugLevel();
  });
842 843 844 845
  // Installs the given tracer via imperative::SetCurrentTracer.
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &new_tracer) {
          imperative::SetCurrentTracer(new_tracer);
        });
Z
Zeng Jinle 已提交
846

847 848 849 850
  // Python class "VarBase", held on the C++ side by
  // std::shared_ptr<imperative::VarBase>.
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>> varbase(
      m, "VarBase", R"DOC()DOC");
  // Remember the Python type object of VarBase in g_varbase_pytype.
  g_varbase_pytype = reinterpret_cast<PyTypeObject *>(varbase.ptr());
  varbase.def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
851 852 853 854 855 856 857
      // Default constructor: builds a VarBase whose name is generated by the
      // current tracer with the "generated_tensor" prefix.
      .def("__init__",
           [](imperative::VarBase &self) {
             auto *current_tracer = imperative::GetCurrentTracer().get();
             std::string generated_name =
                 current_tracer->GenerateUniqueName("generated_tensor");
             new (&self) imperative::VarBase(generated_name);
           })
J
Jiabin Yang 已提交
858
      // Constructor taking dtype, dims, name, var type and persistable flag.
      // A null/None `name` makes the current tracer generate one; only a
      // LOD_TENSOR var gets its tensor resized to `dims`.
      .def("__init__",
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
             VLOG(4) << "Init VarBase";
             const bool name_given =
                 name.ptr() != nullptr && name.ptr() != Py_None;
             std::string act_name =
                 name_given
                     ? name.cast<std::string>()
                     : imperative::GetCurrentTracer()->GenerateUniqueName(
                           "generated_tensor");
             new (&self) imperative::VarBase(act_name);

             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type != framework::proto::VarType::LOD_TENSOR) {
               return;
             }
             self.MutableVar()
                 ->GetMutable<framework::LoDTensor>()
                 ->Resize(framework::make_ddim(dims));
           })
880 881
      // Numpy-array constructors, one per place type. pybind11 tries
      // overloads in registration order, so the order below is significant.
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"),
           py::arg("place"),
           py::arg("persistable") = false,
           py::arg("zero_copy") = false,
           py::arg("name") = "",
           py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
           py::arg("value"),
           py::arg("place"),
           py::arg("persistable") = false,
           py::arg("zero_copy") = false,
           py::arg("name") = "",
           py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"),
           py::arg("place"),
           py::arg("persistable") = false,
           py::arg("zero_copy") = false,
           py::arg("name") = "",
           py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"),
           py::arg("place"),
           py::arg("persistable") = false,
           py::arg("zero_copy") = false,
           py::arg("name") = "",
           py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::NPUPlace>,
           py::arg("value"),
           py::arg("place"),
           py::arg("persistable") = false,
           py::arg("zero_copy") = false,
           py::arg("name") = "",
           py::arg("stop_gradient") = -1)
      // Single-argument and keyword-argument constructors.
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault,
           py::arg("value"))
      .def("__init__", &InitVarBaseFromTensorWithArgDefault,
           py::arg("tensor"))
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100
      .def(
          "__setitem_varbase__",
          [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
             py::object &value_obj) {
            // Implements `self[_index] = value_obj`, modifying `self` in place.
            // If every index item is an integer, slice, Ellipsis or None, a
            // `set_value` op is traced with `self` as both Input and Out;
            // otherwise the tensor is copied into a numpy array, assigned
            // there, and copied back into `self`'s tensor.
            VLOG(4) << "Call __setitem_varbase__";

            auto self_tensor =
                self->MutableVar()->GetMutable<framework::LoDTensor>();
            // NOTE(zhiqiu): a non-tuple index is wrapped into a 1-tuple.
            // PyTuple_Pack returns a new reference, which must be released
            // (the scope guard below does so), while a tuple index is only
            // borrowed from the caller and must not be released here.
            // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
            PyObject *index_ptr = !PyTuple_Check(_index.ptr())
                                      ? PyTuple_Pack(1, _index.ptr())
                                      : _index.ptr();
            DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
              if (!PyTuple_Check(_index.ptr())) {
                Py_DECREF(index_ptr);
                VLOG(4) << "Call Py_DECREF";
              }
            });

            auto is_tensor = [](py::handle var) {
              if (!var.ptr() || var.ptr() == Py_None) {
                return false;
              }
              try {
                py::cast<std::shared_ptr<imperative::VarBase>>(var);
                return true;
              } catch (py::cast_error &) {
                return false;
              }
            };

            // 1. Check arguments
            bool parse_index = true;

            // Check whether _index can be parsed.
            const int size = PyTuple_GET_SIZE(index_ptr);
            for (int dim = 0; dim < size; ++dim) {
              PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
              if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
                    slice_item == Py_Ellipsis || slice_item == Py_None)) {
                parse_index = false;
                break;
              }
            }

            // 2. Call op set_value to speed up if the condition is met,
            // otherwise call TensorToPyArray.
            // TODO(liym27): Try not to call TensorToPyArray because it always
            // copies data to cpu place, which reduces performance.
            if (parse_index) {
              std::vector<int> axes, starts, ends, steps, decrease_axes,
                  none_axes, infer_flags, list_select_idxs;
              // if index is a list, list_select_flag will be true
              bool list_select_flag = false;
              ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
                                 &steps, &decrease_axes, &none_axes,
                                 &infer_flags, &list_select_idxs,
                                 &list_select_flag);

              framework::AttributeMap attrs = {{"axes", axes},
                                               {"starts", starts},
                                               {"ends", ends},
                                               {"steps", steps},
                                               {"decrease_axes", decrease_axes},
                                               {"none_axes", none_axes}};

              imperative::NameVarBaseMap ins = {{"Input", {self}}};
              imperative::NameVarBaseMap outs = {{"Out", {self}}};

              const auto &tracer = imperative::GetCurrentTracer();

              // While gradients are recorded, an inplace write into a leaf
              // Tensor that does not stop gradient is rejected.
              if (tracer->HasGrad()) {
                PADDLE_ENFORCE_EQ(
                    self->IsLeaf() && !self->OverridedStopGradient(), false,
                    platform::errors::InvalidArgument(
                        "Leaf Tensor (%s) that doesn't stop gradient can't use "
                        "inplace strategy.",
                        self->Name()));
              }

              if (PyCheckTensor(value_obj.ptr())) {
                // Tensor value: pass it directly as ValueTensor.
                auto value_tensor =
                    value_obj.cast<std::shared_ptr<imperative::VarBase>>();
                ins.insert({"ValueTensor", {value_tensor}});
              } else if (py::isinstance<py::array>(value_obj)) {
                // numpy value: cast to self's dtype if needed, then copy it
                // into a new ValueTensor on self's place.
                auto value_tensor = std::shared_ptr<imperative::VarBase>(
                    new imperative::VarBase(false,
                                            tracer->GenerateUniqueName()));
                py::object value = value_obj;
                if (self->DataType() == framework::proto::VarType::FP32) {
                  if (!py::isinstance<py::array_t<float>>(value_obj)) {
                    value = CastNumpyArray<float>(value_obj);
                  }
                } else if (self->DataType() ==
                           framework::proto::VarType::FP64) {
                  if (!py::isinstance<py::array_t<double>>(value_obj)) {
                    value = CastNumpyArray<double>(value_obj);
                  }
                } else if (self->DataType() ==
                           framework::proto::VarType::INT32) {
                  if (!py::isinstance<py::array_t<int32_t>>(value_obj)) {
                    value = CastNumpyArray<int32_t>(value_obj);
                  }
                } else if (self->DataType() ==
                           framework::proto::VarType::INT64) {
                  if (!py::isinstance<py::array_t<int64_t>>(value_obj)) {
                    value = CastNumpyArray<int64_t>(value_obj);
                  }
                } else if (self->DataType() ==
                           framework::proto::VarType::BOOL) {
                  if (!py::isinstance<py::array_t<bool>>(value_obj)) {
                    value = CastNumpyArray<bool>(value_obj);
                  }
                } else {
                  PADDLE_THROW(platform::errors::InvalidArgument(
                      "When assign a numpy.np value to a paddle.Tensor, "
                      "the data type of the paddle.Tensor must be bool, "
                      "float32, float64, int32 or int64, "
                      "please check the type of tensor."));
                }

                SetTensorFromPyArray(value_tensor->MutableVar()
                                         ->GetMutable<framework::LoDTensor>(),
                                     value, self->Place(), false);
                ins.insert({"ValueTensor", {value_tensor}});

              } else {
                // Python scalar value: pass it as a 1-element attribute whose
                // name matches self's dtype.
                if (py::isinstance<py::float_>(value_obj) ||
                    py::isinstance<py::int_>(value_obj) ||
                    py::isinstance<py::bool_>(value_obj)) {
                  if (self->DataType() == framework::proto::VarType::FP32) {
                    attrs["fp32_values"] =
                        std::vector<float>{value_obj.cast<float>()};
                  } else if (self->DataType() ==
                             framework::proto::VarType::FP64) {
                    attrs["fp64_values"] =
                        std::vector<double>{value_obj.cast<double>()};
                  } else if (self->DataType() ==
                             framework::proto::VarType::INT32) {
                    attrs["int32_values"] =
                        std::vector<int32_t>{value_obj.cast<int32_t>()};
                  } else if (self->DataType() ==
                             framework::proto::VarType::INT64) {
                    attrs["int64_values"] =
                        std::vector<int64_t>{value_obj.cast<int64_t>()};
                  } else if (self->DataType() ==
                             framework::proto::VarType::BOOL) {
                    attrs["bool_values"] =
                        std::vector<int>{value_obj.cast<bool>()};
                  } else {
                    PADDLE_THROW(platform::errors::InvalidArgument(
                        "When assign a value to a paddle.Tensor, "
                        "the data type of the paddle.Tensor must be bool, "
                        "float32, float64, int32 or int64, "
                        "please check the type of tensor."));
                  }
                  attrs["shape"] = std::vector<int64_t>{1};

                } else {
                  // Py_TYPE yields a PyTypeObject*; report its tp_name rather
                  // than the pointer value.
                  PADDLE_THROW(platform::errors::InvalidArgument(
                      "Value type error. The assign value allows "
                      "numpy.ndarray, integer, float or bool, "
                      "but received %s.",
                      Py_TYPE(value_obj.ptr())->tp_name));
                }
              }

              {
                // Release gil and do tracing
                py::gil_scoped_release release;
                tracer->TraceOp("set_value", ins, outs, std::move(attrs),
                                {{"Input", "Out"}});
              }
            } else {
              // Slow path: assign through a numpy copy of self.
              auto self_numpy = TensorToPyArray(*self_tensor);
              VLOG(4) << "parse_index is false";
              if (is_tensor(_index)) {
                VLOG(4) << "index is tensor";
                auto index_var =
                    py::cast<std::shared_ptr<imperative::VarBase>>(_index);
                auto index_tensor =
                    index_var->MutableVar()->GetMutable<framework::LoDTensor>();
                auto index_numpy = TensorToPyArray(*index_tensor);
                self_numpy[index_numpy] = value_obj;
              } else {
                VLOG(4) << "index is not tensor";
                self_numpy[_index] = value_obj;
              }
              SetTensorFromPyArray(self_tensor, self_numpy,
                                   self_tensor->place(), false);
            }
            // NOTE(liym27):
            // Increase the version of VarBase self because __setitem__ is an
            // inplace operator for the VarBase self.
            self->BumpInplaceVersion();
          })
1101
      .def("_getitem_index_not_tensor",
S
songyouwei 已提交
1102
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
1103
             // Implements `self[_index]` for an index that is not a Tensor.
             // ParseIndexingSlice decodes the index; then a slice or
             // strided_slice op is traced for sliced axes, unsqueeze2 for None
             // axes, and index_select for a list index.
             VLOG(4) << "Call _getitem_index_not_tensor";
1104
             std::vector<int> slice_axes, slice_starts, slice_ends,
Z
zyfncg 已提交
1105 1106 1107 1108
                 slice_strides, decrease_axis, none_axes, infer_flags,
                 list_select_idxs;
             // if index is a list, list_select_flag will be true
             bool list_select_flag = false;
S
songyouwei 已提交
1109 1110 1111 1112
             auto tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                &slice_starts, &slice_ends, &slice_strides,
Z
zyfncg 已提交
1113 1114
                                &decrease_axis, &none_axes, &infer_flags,
                                &list_select_idxs, &list_select_flag);
1115 1116 1117
             // release gil and do tracing
             // (the GIL stays released until this lambda returns)
             py::gil_scoped_release release;
             const auto &tracer = imperative::GetCurrentTracer();
1118

Z
zyfncg 已提交
1119
             // With neither slicing nor list selection, `out` is `self`
             // itself; otherwise it is a new VarBase written by the ops below.
             auto out = slice_axes.empty() && !list_select_flag
1120 1121 1122 1123
                            ? self
                            : std::shared_ptr<imperative::VarBase>(
                                  new imperative::VarBase(
                                      tracer->GenerateUniqueName()));
Z
zyfncg 已提交
1124

1125
             if (!slice_axes.empty()) {
S
songyouwei 已提交
1126
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               // Any stride other than 1 switches to strided_slice, which is
               // given "strides" and not "decrease_axis".
               for (auto stride : slice_strides) {
                 if (stride != 1) {
                   op_type = "strided_slice";
                   attrs.insert({"strides", slice_strides});
                   attrs.erase("decrease_axis");
                   break;
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
             }
1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172
             if (!none_axes.empty()) {
               // Deal with cases when all axes are decreased.
               // After slice, the shape of out is [1], which should have been
               // [], but Paddle doesn't support scalar.
               // In order to ensure the correctness of the final shape of out,
               // one dimension of out needs to be decreased.
               // For example:
               // # x.shape: (2,3,4)
               // out = x[0, 1, 1, None] # out.shape : (1)
               if (static_cast<int>(decrease_axis.size()) ==
                   tensor->dims().size()) {
                 none_axes.pop_back();
               }
               if (!none_axes.empty()) {
                 // Deal with cases that decrease_axes is not empty
                 // For example:
                 // # x.shape: (2,3,4)
                 // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
                 // Shift each None axis left by the number of decreased axes
                 // that precede it.
                 for (auto &axis : none_axes) {
                   int len = 0;
                   for (int da : decrease_axis) {
                     if (da < axis) {
                       len++;
                     }
                   }
                   axis -= len;
                 }

1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195
                 // Deal with cases that there are more than one
                 // prefix none index, For example:
                 // [None, None, :, :, None]
                 // the none_axes in the return of ParseIndexingSlice is:
                 // [0,    0,          2   ]
                 // according to the interface of "unsqueeze2",
                 // we should convert it to:
                 // [0,    0,          4   ]
                 int prefix_zero_cnt = 0;
                 for (const auto &axis : none_axes) {
                   if (axis == 0) {
                     prefix_zero_cnt++;
                   } else {
                     break;
                   }
                 }
                 if (prefix_zero_cnt > 0) {
                   int none_axes_num = static_cast<int>(none_axes.size());
                   for (int i = prefix_zero_cnt; i < none_axes_num; ++i) {
                     none_axes[i] += prefix_zero_cnt;
                   }
                 }

1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209
                 imperative::NameVarBaseMap ins = {{"X", {out}}};
                 framework::AttributeMap attrs = {{"axes", none_axes}};
                 auto new_out = std::shared_ptr<imperative::VarBase>(
                     new imperative::VarBase(tracer->GenerateUniqueName()));
                 auto out_xshape = std::shared_ptr<imperative::VarBase>(
                     new imperative::VarBase(tracer->GenerateUniqueName()));
                 imperative::NameVarBaseMap outs = {{"Out", {new_out}},
                                                    {"XShape", {out_xshape}}};
                 tracer->TraceOp("unsqueeze2", ins, outs, std::move(attrs));

                 // Returns here, so the list-selection step below is not
                 // reached on this path. NOTE(review): assumes
                 // ParseIndexingSlice never reports None axes together with a
                 // list index — confirm.
                 return new_out;
               }
             }

Z
zyfncg 已提交
1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225
             // the index is a list
             // Build an index tensor from list_select_idxs using the device
             // context of tracer->ExpectedPlace(), then select from `self`
             // along dim 0 into `out`.
             if (list_select_flag) {
               auto select_index = std::shared_ptr<imperative::VarBase>(
                   new imperative::VarBase(tracer->GenerateUniqueName()));
               auto *idx_tensor = select_index->MutableVar()
                                      ->GetMutable<framework::LoDTensor>();
               auto *dev_ctx = platform::DeviceContextPool::Instance().Get(
                   tracer->ExpectedPlace());
               TensorFromVector(list_select_idxs, *dev_ctx, idx_tensor);

               imperative::NameVarBaseMap ins = {{"X", {self}},
                                                 {"Index", {select_index}}};
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               tracer->TraceOp("index_select", ins, outs, {{"dim", 0}});
             }

1226
             return out;
1227
           })
1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291
      .def(
          "_getitem_from_offset",
          [](std::shared_ptr<imperative::VarBase> &self, const py::args &args) {
            // Reads one element of the initialized LoDTensor held by `self`
            // and returns it as a 0-d numpy array of the matching dtype.
            //   - no args: the tensor must contain exactly one element;
            //   - one arg: a flat (row-major) offset into the tensor;
            //   - otherwise: one index per dimension.
            const auto &tensor = self->Var().Get<framework::LoDTensor>();
            PADDLE_ENFORCE_EQ(
                tensor.IsInitialized(), true,
                platform::errors::InvalidArgument(
                    "Tensor of %s is Empty, please check if it has no data.",
                    self->Name()));

            const auto &tensor_dims = tensor.dims();

            std::vector<size_t> dims(tensor_dims.size());
            std::vector<size_t> strides(tensor_dims.size());

            // Row-major strides: the last dimension has stride 1.
            size_t numel = 1;
            for (int i = tensor_dims.size() - 1; i >= 0; --i) {
              strides[i] = numel;
              dims[i] = static_cast<size_t>(tensor_dims[i]);
              numel *= dims[i];
            }
            size_t offset = 0;
            if (args.empty()) {
              PADDLE_ENFORCE_EQ(
                  numel, 1,
                  platform::errors::InvalidArgument(
                      "only one element tensors can be converted to Python "
                      "scalars when no input coordinates"));
            } else if (args.size() == 1) {
              offset = args[0].cast<size_t>();
              PADDLE_ENFORCE_LT(
                  offset, numel,
                  platform::errors::InvalidArgument(
                      "index %d is out of bounds for size %d", offset, numel));
            } else {
              PADDLE_ENFORCE_EQ(args.size(), dims.size(),
                                platform::errors::InvalidArgument(
                                    "incorrect number of indices for Tensor"));

              for (size_t i = 0; i < args.size(); ++i) {
                size_t index = args[i].cast<size_t>();
                PADDLE_ENFORCE_LT(
                    index, dims[i],
                    platform::errors::InvalidArgument(
                        "index %d is out of bounds for axis %d with size %d",
                        index, i, dims[i]));
                offset += index * strides[i];
              }
            }
// Returns the element at `offset` as a 0-d numpy array when the tensor's
// dtype is `proto_type`.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.type() == proto_type) {                                         \
    std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(proto_type); \
    T b = TensorGetElement<T>(tensor, offset);                               \
    return py::array(py::dtype(py_dtype_str.c_str()), {}, {},                \
                     static_cast<void *>(&b));                               \
  }

            _ForEachDataType_(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
            PADDLE_THROW(platform::errors::Unimplemented(
                "Unsupported tensor data type: %s",
                framework::DataTypeToString(tensor.type())));
          },
          py::return_value_policy::copy)
1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313
      .def("_inplace_version",
           // Reports the inplace version counter of the underlying variable.
           // The variable must already be initialized.
           [](imperative::VarBase &self) -> uint32_t {
             auto *var = self.MutableVar();
             PADDLE_ENFORCE_EQ(var->IsInitialized(), true,
                               platform::errors::InvalidArgument(
                                   "Tensor of %s is Empty, please check if it "
                                   "has no data.",
                                   self.Name()));
             return var->CurrentInplaceVersion();
           })
      .def("_bump_inplace_version",
           // Increments the inplace version counter of this Tensor.
           [](std::shared_ptr<imperative::VarBase> &tensor) {
             // NOTE(liym27): _bump_inplace_version is only used for inplace
             // operation.
             imperative::VarBase &var_base = *tensor;
             var_base.BumpInplaceVersion();
           },
           R"DOC(
        **Notes**:
            **This API is ONLY available in Dygraph mode.**
            **This is a very low level API. Users should not use it directly. **
         Bump the version whenever the Tensor is modified through an inplace operation.
            )DOC")
1314
      .def("numpy",
           // Converts the LoDTensor held by this VarBase into a numpy array via
           // TensorToPyArray. Raises InvalidArgument if the tensor holds no data.
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        Returns a numpy array that shows the value of current Tensor.

        Returns:
            ndarray: The numpy value of current Tensor.

        Returns type:
            ndarray: dtype is same as current Tensor

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np
                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                linear = paddle.nn.Linear(32, 64)
                data = paddle.to_tensor(data)
                x = linear(data)
                print(x.numpy())
       )DOC")
1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408
      .def("detach",
           // Creates a new VarBase named "detach_<name>" that shares storage
           // and the inplace version counter with `self`. Only LoDTensor and
           // SelectedRows variables are supported; both must be initialized.
           [](const imperative::VarBase
                  &self) -> std::shared_ptr<imperative::VarBase> {
             PADDLE_ENFORCE_EQ(
                 self.Var().IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor %s has not been initialized!", self.Name()));

             PADDLE_ENFORCE_EQ(
                 self.Var().IsType<framework::LoDTensor>() ||
                     self.Var().IsType<framework::SelectedRows>(),
                 true,
                 platform::errors::InvalidArgument(
                     "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
                     self.Name()));

             auto detach_var = std::make_shared<imperative::VarBase>(
                 true, "detach_" + self.Name());

             // Carry over the metadata; no graph information is copied.
             detach_var->SetPersistable(self.Persistable());
             detach_var->SetType(self.Type());
             detach_var->SetDataType(self.DataType());

             if (self.Var().IsType<framework::LoDTensor>()) {
               const auto &origin_tensor =
                   self.Var().Get<framework::LoDTensor>();
               PADDLE_ENFORCE_EQ(
                   origin_tensor.IsInitialized(), true,
                   platform::errors::InvalidArgument(
                       "Tensor %s has not been initialized!", self.Name()));

               auto *detach_tensor =
                   detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
               detach_tensor->ShareDataWith(origin_tensor);
               // NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
               // same TensorInplaceVersion, which is used to check whether
               // inplace operations are correct.
               detach_tensor->ShareInplaceVersionCounterWith(origin_tensor);
             } else {
               const auto &origin_selected_rows =
                   self.Var().Get<framework::SelectedRows>();
               PADDLE_ENFORCE_EQ(
                   origin_selected_rows.value().IsInitialized(), true,
                   platform::errors::InvalidArgument(
                       "Tensor %s has not been initialized!", self.Name()));

               auto *detach_selected_rows =
                   detach_var->MutableVar()
                       ->GetMutable<framework::SelectedRows>();
               detach_selected_rows->set_height(origin_selected_rows.height());
               detach_selected_rows->set_rows(origin_selected_rows.rows());
               detach_selected_rows->mutable_value()->ShareDataWith(
                   origin_selected_rows.value());
               detach_selected_rows->mutable_value()
                   ->ShareInplaceVersionCounterWith(
                       origin_selected_rows.value());
             }
             VLOG(3) << "The detached Tensor(" << detach_var->Name()
                     << ") share data with " << self.Name();
             return detach_var;
           },
           py::return_value_policy::take_ownership, R"DOC(

        Returns a new Tensor, detached from the current graph.
        It shares data with the origin Tensor and never makes a copy of the Tensor data.
        In addition, the detached Tensor doesn't provide gradient propagation.

        Returns: The detached Tensor.

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(1.0, stop_gradient=False)
                detach_x = x.detach()
                detach_x[:] = 10.0
                print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                          #        [10.])
                y = x**2
                y.backward()
                print(x.grad)         # [20.0]
                print(detach_x.grad)  # None, 'stop_gradient=True' by default

                detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
                z = detach_x**3
                z.backward()

                print(x.grad)         # [20.0], detach_x is detached from x's graph, not affect each other
                print(detach_x.grad)  # [300.0], detach_x has its own graph

                # Due to sharing of data with origin Tensor, There are some unsafe operations:
                y = 2 * x
                detach_x[:] = 5.0
                y.backward()
                # It will raise Error:
                #   one of the variables needed for gradient computation has been modified by an inplace operation.

       )DOC")
      // Binds VarBase::ClearGradient directly as `clear_gradient`.
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        Only for Tensor that has gradient, normally we use this for Parameters since other temporary Tensors don't have gradients.

        The Gradient of current Tensor will be set to ``0`` .

        Returns:  None

        Examples:
             .. code-block:: python

                import paddle
                input = paddle.uniform([10, 2])
                linear = paddle.nn.Linear(2, 3)
                out = linear(input)
                out.backward()
                print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
                linear.weight.clear_gradient()
                print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
      )DOC")
Z
Zhou Wei 已提交
1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513
      .def("clone",
           // Produces a copy of `self` by tracing an "assign" op through the
           // current tracer, so the result stays connected to the graph.
           [](std::shared_ptr<imperative::VarBase> &self) {
             const auto &tensor = self->Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "%s has not been initialized", self->Name()));
             auto tracer = imperative::GetCurrentTracer();
             auto new_var = std::make_shared<imperative::VarBase>(
                 true, tracer->GenerateUniqueName(self->Name() + "_clone"));
             framework::AttributeMap attrs;
             imperative::NameVarBaseMap ins = {{"X", {self}}};
             imperative::NameVarBaseMap outs = {{"Out", {new_var}}};
             tracer->TraceOp("assign", ins, outs, attrs);
             return new_var;
           },
           py::return_value_policy::copy, R"DOC(

        Returns a new Tensor, which is a clone of origin Tensor, and it remains in the current graph.
        It will always have a Tensor copy.
        In addition, the cloned Tensor provides gradient propagation.

        Returns: The cloned Tensor.

        Examples:
            .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.0, stop_gradient=False)
              clone_x = x.clone()
              y = clone_x**2
              y.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [2.0], support gradient propagation
              print(x.stop_gradient)       # False
              print(x.grad)                # [2.0], clone_x support gradient propagation for x

              x = paddle.to_tensor(1.0)
              clone_x = x.clone()
              clone_x.stop_gradient = False
              z = clone_x**3
              z.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [3.0], support gradient propagation
              print(x.stop_gradient) # True
              print(x.grad)          # None
       )DOC")
L
Leo Chen 已提交
1514 1515 1516 1517 1518 1519
      .def("_grad_name", &imperative::VarBase::GradVarName)
      .def("_grad_value",
           // Returns the LoDTensor held by this VarBase's gradient variable.
           [](imperative::VarBase &self) {
             const auto &grad_variable = self.MutableGradVar();
             return grad_variable->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
1520 1521 1522 1523
      .def("_set_grad_type",
           // Sets the variable type of this VarBase's gradient VarBase.
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             const auto &grad_var_base = self.MutableGradVarBase();
             grad_var_base->SetType(type);
           })
1524
      .def("_grad_ivar",
           // Returns the gradient VarBase only when it exists and its tensor
           // (the LoDTensor itself, or the value tensor of a SelectedRows) is
           // initialized; otherwise returns a null pointer (None in Python).
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
                       ? grad_var->MutableVar()
                             ->GetMutable<framework::LoDTensor>()
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
               if (tensor->IsInitialized()) {
                 return grad_var;
               }
             }
             return std::shared_ptr<imperative::VarBase>(nullptr);
           },
           py::return_value_policy::copy)
C
chentianyu03 已提交
1542 1543 1544 1545
      .def("_set_grad_ivar",
           // Forwards `grad` to VarBase::SetGradVarBase on `self`.
           [](imperative::VarBase &self, imperative::VarBase &grad) {
             imperative::VarBase &target = self;
             target.SetGradVarBase(grad);
           })
1546 1547 1548 1549 1550 1551 1552 1553
      .def("_is_sparse",
           // True when the underlying variable holds SelectedRows.
           [](imperative::VarBase &self) -> bool {
             const auto &variable = self.Var();
             return variable.IsType<framework::SelectedRows>();
           })
      .def("_allreduce",
           // All-reduces this VarBase's variable in place (source and
           // destination are the same variable) when strategy.nranks_ > 1;
           // otherwise does nothing. The GIL is released during the call.
           [](imperative::VarBase &self,
              const imperative::ParallelStrategy &strategy) {
             if (strategy.nranks_ > 1) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2212
               imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
               if (!self.Var().IsType<framework::SelectedRows>()) {
                 imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
               } else {
                 PADDLE_THROW(platform::errors::Unimplemented(
                     "Imperative SelectedRows allreduce is not supported when "
                     "paddle is compiled with NCCL version lower than v2.2.12. "
                     "You can set is_sparse=False for the Layer containing "
                     "this argument, such as Embedding(is_sparse=False)."));
               }
#endif  // NCCL_VERSION_CODE
#else
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative allreduce is not supported when paddle is "
                   "not compiled with NCCL."));
#endif  // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
             }
           },
           py::call_guard<py::gil_scoped_release>())
1576 1577 1578
      .def("_register_grad_hook",
           // Wraps the Python callable `hook` in a PyVariableWrapperHook and
           // registers it on this Tensor's gradient VarBase. Returns the value
           // of AddVariableWrapperHook (presumably the id later accepted by
           // _remove_grad_hook — confirm).
           [](imperative::VarBase &self, const py::handle &hook) {
             PADDLE_ENFORCE_EQ(
                 !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
                     "Cannot register gradient hook on a Tensor that stop "
                     "gradient or without gradient."));
             return self.GradVarBase()->AddVariableWrapperHook(
                 std::make_shared<PyVariableWrapperHook>(hook.ptr()));
           })
      .def("_remove_grad_hook",
           // Removes the gradient hook identified by `hook_id` from this
           // Tensor's gradient VarBase and returns RemoveVariableWrapperHook's
           // result.
           [](imperative::VarBase &self, int64_t hook_id) {
             PADDLE_ENFORCE_EQ(
                 !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
                     "Cannot remove gradient hook on a Tensor that stop "
                     "gradient or without gradient."));
             return self.GradVarBase()->RemoveVariableWrapperHook(hook_id);
           })
1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630
      .def("_register_backward_hook",
           // Registers a no-argument Python callable as a void hook on the
           // gradient VarBase. Only allowed on leaf Tensors that require grad.
           [](imperative::VarBase &self, const py::handle &hook) {
             PADDLE_ENFORCE_EQ(
                 self.IsLeaf(), true,
                 platform::errors::InvalidArgument(
                     "Only can register backward hook for leaf Tensor."));
             PADDLE_ENFORCE_EQ(
                 !self.OverridedStopGradient() && self.HasGradVar(), true,
                 platform::errors::InvalidArgument(
                     "Cannot register backward hook on a Tensor that stop "
                     "gradient or without gradient."));
             auto py_func = PyObjectCast<std::function<void()>>(hook.ptr());
             self.GradVarBase()->AddVoidHook(
                 std::make_shared<std::function<void()>>(py_func));
           },
           R"DOC(
             Registers a backward hook for current Tensor.

             This hook will be called every time the gradient of current Tensor has been fully calculated.

             There are two differences with `_register_grad_hook`:
             1. This backward hook will be executed after the gradient accumulation completed across batches,
                but the hook registered by `_register_grad_hook` will be executed when the gradient accumulation
                is completed in the current batch.
             2. This backward hook function should have the following signature:

                  hook() -> None

                It requires no input and no return value.

             Args:
                 hook(function): A backward hook to be registered for Tensor.gradient

             Returns:
                 None
           )DOC")
1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658
      .def("cpu",
           // Returns `self` if it already lives on CPU; otherwise creates a
           // blocking copy on CPUPlace that keeps the stop_gradient flag.
           [](const std::shared_ptr<imperative::VarBase> &self) {
             if (platform::is_cpu_place(self->Place())) {
               return self;
             } else {
               auto new_var = self->NewVarBase(platform::CPUPlace(), true);
               new_var->SetOverridedStopGradient(self->OverridedStopGradient());
               return new_var;
             }
           },
           R"DOC(
        Returns a copy of this Tensor in CPU memory.

        If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              # required: gpu
              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)    # CUDAPlace(0)

              y = x.cpu()
              print(y.place)    # CPUPlace

              )DOC")
      .def("pin_memory",
           // Returns `self` if it is already in CUDA pinned memory; otherwise
           // creates a blocking copy on CUDAPinnedPlace that keeps the
           // stop_gradient flag. Throws in builds without CUDA/HIP.
           [](const std::shared_ptr<imperative::VarBase> &self) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to pinned memory in CPU version "
                 "Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#endif
             if (platform::is_cuda_pinned_place(self->Place())) {
               return self;
             } else {
               auto new_var =
                   self->NewVarBase(platform::CUDAPinnedPlace(), true);
               new_var->SetOverridedStopGradient(self->OverridedStopGradient());
               return new_var;
             }
           },
           R"DOC(
        Returns a copy of this Tensor in pin memory.

        If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              # required: gpu
              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)      # CUDAPlace(0)

              y = x.pin_memory()
              print(y.place)      # CUDAPinnedPlace

      )DOC")
      .def("cuda",
1691 1692
           [](const std::shared_ptr<imperative::VarBase> &self,
              py::handle &handle, bool blocking) {
1693
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
1694 1695 1696 1697 1698
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to GPU in CPU version Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#else
             int device_count = platform::GetCUDADeviceCount();
1699 1700
             int device_id = 0;
             if (handle == py::none()) {
1701 1702 1703
               if (platform::is_gpu_place(self->Place())) {
                 return self;
               }
1704 1705 1706 1707 1708 1709 1710
             } else {
               PyObject *py_obj = handle.ptr();
               PADDLE_ENFORCE_EQ(
                   PyCheckInteger(py_obj), true,
                   platform::errors::InvalidArgument(
                       " 'device_id' must be a positive integer"));
               device_id = py::cast<int>(handle);
1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733
             }
             PADDLE_ENFORCE_GE(
                 device_id, 0,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             PADDLE_ENFORCE_LT(
                 device_id, device_count,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             platform::CUDAPlace place = platform::CUDAPlace(device_id);
             if (platform::is_same_place(self->Place(), place)) {
               return self;
             } else {
               auto new_var = self->NewVarBase(place, blocking);
               new_var->SetOverridedStopGradient(self->OverridedStopGradient());
               return new_var;
             }
#endif
           },
1734
           py::arg("device_id") = py::none(), py::arg("blocking") = true, R"DOC(
1735 1736 1737 1738 1739 1740
        Returns a copy of this Tensor in GPU memory.

        If this Tensor is already in GPU memory and device_id is default, 
        then no copy is performed and the original Tensor is returned.
        
        Args:
1741
            device_id(int, optional): The destination GPU device id. Default: None, means current device.
1742 1743 1744 1745 1746 1747
            blocking(bool, optional): If False and the source is in pinned memory, the copy will be 
              asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False.

        Examples:
            .. code-block:: python

1748
              # required: gpu
1749 1750 1751 1752 1753 1754
              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
              print(x.place)        # CPUPlace

              y = x.cuda()
              print(y.place)        # CUDAPlace(0)
1755 1756 1757
            
              y = x.cuda(None)
              print(y.place)        # CUDAPlace(0)
1758 1759 1760 1761

              y = x.cuda(1)
              print(y.place)        # CUDAPlace(1)
       )DOC")
K
Kaipeng Deng 已提交
1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790
      .def("_share_memory",
           // Copies a CPU LoDTensor's data into a new memory-map writer
           // allocation and makes the tensor use that allocation as its holder.
           // Not supported on Windows.
           [](const std::shared_ptr<imperative::VarBase> &self) {
#ifndef _WIN32
             PADDLE_ENFORCE_EQ(
                 platform::is_cpu_place(self->Place()), true,
                 platform::errors::InvalidArgument(
                     "Sharing memory only support CPU Tensor currently"));
             auto *tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             void *src = tensor->data<void>();
             size_t nbytes =
                 tensor->numel() * framework::SizeOfType(tensor->type());
             // Allocate shared memory of the same size.
             auto holder =
                 memory::allocation::AllocateMemoryMapWriterAllocation(nbytes);
             // Track the allocation's ipc name in the mmap fd set.
             const std::string &ipc_name = holder->ipc_name();
             memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
             // Move the data over and swap the holder.
             memory::Copy(platform::CPUPlace(), holder->ptr(),
                          platform::CPUPlace(), src, nbytes);
             tensor->ResetHolder(holder);
             return *tensor;
#else
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Sharing memory in Windows OS is not supported currently"));
#endif
           },
           py::return_value_policy::reference)
1791
      // Exposes imperative::VarBase::CopyFrom as the in-place `copy_` method.
      .def("copy_", &imperative::VarBase::CopyFrom)
1792
      .def("_copy_to",
1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CPUPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             // Note(zhiqiu): Since NewVarBase may use GpuCopyAsync to
             // copy data from the tensor of self to the tensor of new varbase,
             // we need to ensure that the varbase self is not destructed until
             // the GpuCopyAsync is completed. Otherwise, the memory may be
             // freed
             // when varbase self is destructed.
             // To do that, we increase the reference count of self by 1 and
             // add a cuda event to wait the GpuCopyAsync's completion.
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
J
Jiabin Yang 已提交
1809
           py::return_value_policy::copy)
1810
      .def("_copy_to",
1811 1812 1813 1814 1815 1816 1817 1818
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CUDAPinnedPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
1819
           py::return_value_policy::copy)
1820
      .def("_copy_to",
1821 1822 1823 1824 1825 1826 1827 1828
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::XPUPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
1829
           py::return_value_policy::copy)
1830
      .def("_copy_to",
1831 1832 1833 1834 1835 1836 1837 1838
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CUDAPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
J
Jiabin Yang 已提交
1839
           py::return_value_policy::copy)
1840 1841 1842 1843 1844 1845 1846 1847 1848 1849
      .def("_copy_to",
           // Copies `self` to an NPUPlace.
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::NPUPlace &place, bool blocking) {
             auto copied = self->NewVarBase(place, blocking);
             if (blocking) {
               return copied;
             }
             // A non-blocking copy may still read from `self`; keep it alive
             // until the copy has completed.
             IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             return copied;
           },
           py::return_value_policy::copy)
C
chentianyu03 已提交
1850 1851 1852 1853 1854 1855 1856 1857 1858 1859
      .def("_copy_to",
           // Generic overload: copies `self` to any platform::Place.
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::Place &place, bool blocking) {
             auto result = self->NewVarBase(place, blocking);
             const bool is_async = !blocking;
             if (is_async) {
               // Keep `self` alive until the asynchronous copy finishes.
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return result;
           },
           py::return_value_policy::copy)
J
Jiabin Yang 已提交
1860
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
1861 1862 1863
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
L
Leo Chen 已提交
1864 1865 1866 1867 1868
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884
      .def_property_readonly(
          "shape",
          // Dims of the held LoDTensor, or of a SelectedRows' value tensor;
          // any other variable type yields an empty list.
          [](imperative::VarBase &self) {
            const auto &variable = self.Var();
            if (variable.IsType<framework::LoDTensor>()) {
              const auto &dims = variable.Get<framework::LoDTensor>().dims();
              return framework::vectorize<int>(dims);
            }
            if (variable.IsType<framework::SelectedRows>()) {
              const auto &dims =
                  variable.Get<framework::SelectedRows>().value().dims();
              return framework::vectorize<int>(dims);
            }
            VLOG(2) << "It is meaningless to get shape of "
                       "variable type "
                    << GetTypeName(self);
            return std::vector<int>();
          })
1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913
      .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
                             R"DOC(
      Whether a Tensor is leaf Tensor.

      For the Tensor whose stop_gradient is ``True`` , it will be leaf Tensor. 
      
      For the Tensor whose stop_gradient is ``False`` , it will be leaf Tensor too if it is created by user.

      Returns:
          bool: Whether a Tensor is leaf Tensor.

      Examples:
          .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.)
              print(x.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=True)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=False)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # False
       )DOC")
1914 1915 1916
      .def_property_readonly(
          "place",
          // Returns a copy of the Place this VarBase reports.
          [](imperative::VarBase &self) {
            auto place = self.Place();
            return place;
          },
          py::return_value_policy::copy)
1917 1918 1919 1920 1921 1922
      .def_property_readonly(
          "_place_str",
          // Textual form of the Place, as produced by its stream operator.
          [](imperative::VarBase &self) {
            std::ostringstream text;
            text << self.Place();
            return text.str();
          })
J
Jiabin Yang 已提交
1923
      .def_property_readonly("type", &imperative::VarBase::Type)
L
Leo Chen 已提交
1924
      .def_property_readonly("dtype", &imperative::VarBase::DataType);
1925

1926 1927 1928 1929 1930
  // NOTE(zhiqiu): set the metaclass of Layer.
  // See details: https://github.com/pybind/pybind11/pull/679
  // https://github.com/pybind/pybind11/blob/028812ae7eee307dca5f8f69d467af7b92cc41c8/tests/test_methods_and_attributes.cpp#L284
  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(
      m, "Layer", py::metaclass((PyObject *)&PyType_Type));  // NOLINT
1931
  layer.def(py::init<>())
1932 1933 1934 1935 1936
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });
1937

1938 1939 1940 1941 1942
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

L
Leo Chen 已提交
1943 1944 1945 1946 1947 1948 1949
  py::enum_<paddle::imperative::AmpLevel>(m, "AmpLevel", py::arithmetic())
      .value("O0", paddle::imperative::AmpLevel::O0)
      .value("O1", paddle::imperative::AmpLevel::O1)
      .value("O2", paddle::imperative::AmpLevel::O2)
      .value("O3", paddle::imperative::AmpLevel::O3)
      .export_values();

1950
  // Python binding for imperative::Tracer, held by std::shared_ptr.
  // Properties are thin wrappers over the Tracer getters/setters; the `trace`
  // overloads (one per place type) forward to Tracer::TraceOp.
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer", R"DOC()DOC")
      // Old-style pybind11 constructor: placement-new into the storage that
      // pybind11 provides for `self`.
      .def("__init__",
           [](imperative::Tracer &self) { new (&self) imperative::Tracer(); })
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
      .def_property("_amp_level", &imperative::Tracer::GetAmpLevel,
                    &imperative::Tracer::SetAmpLevel)
      .def_property("_has_grad", &imperative::Tracer::HasGrad,
                    &imperative::Tracer::SetHasGrad)
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          // Setter: accepts CUDAPlace, XPUPlace, CPUPlace, CUDAPinnedPlace,
          // NPUPlace or a generic platform::Place. Any other object raises
          // InvalidArgument.
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::XPUPlace>(obj)) {
              auto p = obj.cast<platform::XPUPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::NPUPlace>(obj)) {
              auto p = obj.cast<platform::NPUPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::Place>(obj)) {
              auto p = obj.cast<platform::Place *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else {
              // The adjacent literals are concatenated, so each fragment
              // must carry its own separating space.
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
                  "CPUPlace, NPUPlace "
                  "and CUDAPinnedPlace, "
                  "but got Unknown Type!"));
            }
          })
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
      .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
           py::arg("key") = "dygraph_tmp")
      .def("_set_amp_op_list",
           [](imperative::Tracer &self,
              std::unordered_set<std::string> &allow_ops,
              std::unordered_set<std::string> &block_ops) {
             // NOTE(zhiqiu): The automatic conversion in pybind11 between
             // C++ STL containers and Python set/list/dict involves a copy,
             // which prevents pass-by-reference semantics, so it is safe to
             // swap the converted sets into the global AMP op lists.
             // The reason we do not directly pass
             // std::shared_ptr<std::unordered_set<std::string>> is that
             // pybind11 forbids shared_ptr<T> where T is not a custom type.
             imperative::AmpOperators::Instance().GetMutableAllowOps()->swap(
                 allow_ops);
             imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
                 block_ops);
             VLOG(5) << "AMP operators changed, "
                     << imperative::AmpOperators::Instance();
           })
      // Returns copies of the current (allow_ops, block_ops) sets.
      .def("_get_amp_op_list",
           [](imperative::Tracer &self) {
             return std::make_tuple(
                 *(imperative::AmpOperators::Instance().GetMutableAllowOps()),
                 *(imperative::AmpOperators::Instance().GetMutableBlockOps()));
           })
      // `trace` overloads: each converts the Python name->VarBase maps while
      // still holding the GIL, then releases the GIL for Tracer::TraceOp.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::XPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::NPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });

  // define parallel context
2089 2090 2091
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
2092 2093
      .def_property(
          "nranks",
2094 2095
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
2096 2097 2098
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
2099
                    [](const imperative::ParallelStrategy &self) {
2100 2101
                      return self.local_rank_;
                    },
2102
                    [](imperative::ParallelStrategy &self, int local_rank) {
2103 2104 2105 2106
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
2107
          [](const imperative::ParallelStrategy &self) {
2108 2109
            return self.trainer_endpoints_;
          },
2110
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
2111 2112 2113
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
2114
                    [](const imperative::ParallelStrategy &self) {
2115 2116
                      return self.current_endpoint_;
                    },
2117
                    [](imperative::ParallelStrategy &self,
2118 2119 2120 2121 2122 2123 2124
                       const std::string &ep) { self.current_endpoint_ = ep; })
      .def_property(
          "nrings",
          [](const imperative::ParallelStrategy &self) { return self.nrings_; },
          [](imperative::ParallelStrategy &self, int nrings) {
            self.nrings_ = nrings;
          });
2125

2126 2127 2128 2129
  // `varbase_copy` overloads: one VarBaseCopy instantiation per place type.
  // They are registered in this order: generic Place, CPU, CUDA, XPU,
  // CUDAPinned, NPU. module_::def returns the module, so the calls chain.
  m.def("varbase_copy", &VarBaseCopy<platform::Place>)
      .def("varbase_copy", &VarBaseCopy<platform::CPUPlace>)
      .def("varbase_copy", &VarBaseCopy<platform::CUDAPlace>)
      .def("varbase_copy", &VarBaseCopy<platform::XPUPlace>)
      .def("varbase_copy", &VarBaseCopy<platform::CUDAPinnedPlace>)
      .def("varbase_copy", &VarBaseCopy<platform::NPUPlace>);

  // `dygraph_partial_grad`: builds an imperative::PartialGradEngine from the
  // given targets/grads/options, executes it, and returns its result. The
  // GIL is released for the whole call.
  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &in_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &out_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &out_grads,
         const std::vector<std::shared_ptr<imperative::VarBase>> &no_grads,
         const platform::Place &place, bool create_graph, bool retain_graph,
         bool allow_unused, bool only_inputs) {
        imperative::PartialGradEngine grad_engine(
            in_targets, out_targets, out_grads, no_grads, place, create_graph,
            retain_graph, allow_unused, only_inputs);
        grad_engine.Execute();
        return grad_engine.GetResult();
      },
      py::call_guard<py::gil_scoped_release>());

  // `dygraph_run_backward`: initializes the engine returned by
  // tracer.GetEngine() with (tensors, grad_tensors, retain_graph), then
  // executes it. The GIL is released for the whole call.
  m.def(
      "dygraph_run_backward",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &tensors,
         const std::vector<std::shared_ptr<imperative::VarBase>> &grad_tensors,
         bool retain_graph, const imperative::Tracer &tracer) {
        auto *backward_engine = tracer.GetEngine();
        backward_engine->Init(tensors, grad_tensors, retain_graph);

        VLOG(3) << "Start backward";
        backward_engine->Execute();
        VLOG(3) << "Finish backward";
      },
      py::call_guard<py::gil_scoped_release>());

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL) || defined(PADDLE_WITH_GLOO)
  // Base class for the backend-specific *ParallelContext bindings.
  py::class_<imperative::ParallelContext,
             std::shared_ptr<imperative::ParallelContext>>(m,
                                                           "ParallelContext");

  {
    // Reducer binding. prepare_for_backward releases the GIL while it runs.
    using VarBasePtrList = std::vector<std::shared_ptr<imperative::VarBase>>;
    py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>
        reducer_cls(m, "Reducer", R"DOC()DOC");
    reducer_cls.def(py::init<const VarBasePtrList &,
                             const std::vector<std::vector<size_t>> &,
                             const std::vector<bool> &,
                             std::shared_ptr<imperative::ParallelContext>,
                             const std::vector<size_t> &, bool>());
    reducer_cls.def("prepare_for_backward",
                    &imperative::Reducer::PrepareForBackward, py::arg("vars"),
                    py::call_guard<py::gil_scoped_release>());
  }

  // group_size_limits defaults to a single limit of 25 * 1024 * 1024
  // (presumably bytes, i.e. 25 MiB -- confirm in AssignGroupBySize).
  m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"),
        py::arg("is_sparse_gradient"),
        py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
        py::arg("tensor_indices") = std::vector<int64_t>{},
        py::call_guard<py::gil_scoped_release>());
#endif

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // NCCL/RCCL ParallelContext: constructed from a ParallelStrategy and a
  // CUDAPlace; `init()` calls Init(), `init_with_ring_id(ring_id)` calls
  // InitWithRingID.
  py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::NCCLParallelContext>>(
      m, "NCCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init",
           [](imperative::NCCLParallelContext &nccl_ctx) { nccl_ctx.Init(); })
      .def("init_with_ring_id",
           &imperative::NCCLParallelContext::InitWithRingID,
           py::arg("ring_id"));
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
  // BKCL (XPU) ParallelContext: constructed from a ParallelStrategy and an
  // XPUPlace; `init()` calls Init(), `init_with_ring_id(ring_id)` calls
  // InitWithRingID.
  py::class_<imperative::BKCLParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::BKCLParallelContext>>(
      m, "BKCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::XPUPlace &>())
      .def("init",
           [](imperative::BKCLParallelContext &bkcl_ctx) { bkcl_ctx.Init(); })
      .def("init_with_ring_id",
           &imperative::BKCLParallelContext::InitWithRingID,
           py::arg("ring_id"));
#endif

#if defined(PADDLE_WITH_GLOO)
  // Gloo ParallelContext: constructed from a ParallelStrategy and a CPUPlace;
  // `init()` calls Init(), `init_with_ring_id(ring_id)` calls InitWithRingID.
  py::class_<imperative::GLOOParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::GLOOParallelContext>>(
      m, "GLOOParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CPUPlace &>())
      .def("init",
           [](imperative::GLOOParallelContext &gloo_ctx) { gloo_ctx.Init(); })
      .def("init_with_ring_id",
           &imperative::GLOOParallelContext::InitWithRingID,
           py::arg("ring_id"));
#endif

  // `pylayer_apply` overloads, registered in the order CPU, CUDA, XPU,
  // CUDAPinned, NPU. Each forwards (place, cls, args, kwargs) unchanged to
  // imperative::PyLayerApply. module_::def returns the module, so the calls
  // chain.
  m.def("pylayer_apply",
        [](const platform::CPUPlace &dev_place, const py::object &cls,
           const py::args args, const py::kwargs kwargs) {
          return imperative::PyLayerApply(dev_place, cls, args, kwargs);
        })
      .def("pylayer_apply",
           [](const platform::CUDAPlace &dev_place, const py::object &cls,
              const py::args args, const py::kwargs kwargs) {
             return imperative::PyLayerApply(dev_place, cls, args, kwargs);
           })
      .def("pylayer_apply",
           [](const platform::XPUPlace &dev_place, const py::object &cls,
              const py::args args, const py::kwargs kwargs) {
             return imperative::PyLayerApply(dev_place, cls, args, kwargs);
           })
      .def("pylayer_apply",
           [](const platform::CUDAPinnedPlace &dev_place,
              const py::object &cls, const py::args args,
              const py::kwargs kwargs) {
             return imperative::PyLayerApply(dev_place, cls, args, kwargs);
           })
      .def("pylayer_apply",
           [](const platform::NPUPlace &dev_place, const py::object &cls,
              const py::args args, const py::kwargs kwargs) {
             return imperative::PyLayerApply(dev_place, cls, args, kwargs);
           });
}

}  // namespace pybind
}  // namespace paddle