imperative.cc 64.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22

23
#include <algorithm>
24
#include <memory>
25
#include <set>
J
Jiabin Yang 已提交
26
#include <string>
27
#include <unordered_map>
28
#include <unordered_set>
29
#include <utility>
J
Jiabin Yang 已提交
30
#include <vector>
31

32
#include "paddle/fluid/imperative/all_reduce.h"
33
#include "paddle/fluid/imperative/amp_auto_cast.h"
34
#include "paddle/fluid/imperative/basic_engine.h"
35
#include "paddle/fluid/imperative/bkcl_context.h"
36
#include "paddle/fluid/imperative/data_loader.h"
37
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
38
#include "paddle/fluid/imperative/nccl_context.h"
39
#include "paddle/fluid/imperative/partial_grad_engine.h"
40
#include "paddle/fluid/imperative/profiler.h"
41
#include "paddle/fluid/imperative/reducer.h"
42
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
43
#include "paddle/fluid/imperative/type_defs.h"
44
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
45
#include "paddle/fluid/pybind/op_function.h"
46
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
47
#include "paddle/fluid/pybind/tensor_py.h"
48

49 50 51
namespace paddle {
namespace pybind {

52 53
namespace py = ::pybind11;

54 55 56 57
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

58 59 60 61
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
62
                      Forward, inputs);  // NOLINT
63 64 65
  }
};

L
Leo Chen 已提交
66 67 68 69 70
// Converts a Python place object into a platform::Place.
//
// Accepts the concrete bound place types (CPUPlace, CUDAPlace, XPUPlace,
// CUDAPinnedPlace) as well as a generic Place. Any other object raises
// InvalidArgument.
static const platform::Place PyObjectToPlace(const py::object &place_obj) {
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  }
  if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  }
  if (py::isinstance<platform::XPUPlace>(place_obj)) {
    return place_obj.cast<platform::XPUPlace>();
  }
  if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  }
  if (py::isinstance<platform::Place>(place_obj)) {
    return place_obj.cast<platform::Place>();
  }
  // None of the accepted place types matched.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Place should be one of "
      "Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
}

// Placement-constructs `self` as a LOD_TENSOR VarBase whose tensor is filled
// from `array` on `place`.
//
// Args:
//   persistable:   forwarded to SetPersistable.
//   zero_copy:     forwarded to SetTensorFromPyArray.
//   name:          VarBase name; an empty string means "generate a unique
//                  name via the current tracer".
//   stop_gradient: -1 leaves the flag untouched; any other value is passed to
//                  SetOverridedStopGradient.
// Throws InvalidArgument if `place` is not a CPU/XPU/CUDA/CUDAPinned place.
static void InitTensorForVarBase(imperative::VarBase *self,
                                 const py::array &array,
                                 const platform::Place place,
                                 bool persistable = false,
                                 bool zero_copy = false, std::string name = "",
                                 int stop_gradient = -1) {
  if (name.empty()) {
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
  }
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;

  new (self) imperative::VarBase(name);  // placement-new into `self`
  framework::LoDTensor *dst =
      self->MutableVar()->GetMutable<framework::LoDTensor>();

  // Dispatch on the concrete place kind so SetTensorFromPyArray is
  // instantiated with the matching place type.
  if (platform::is_cpu_place(place)) {
    SetTensorFromPyArray<platform::CPUPlace>(
        dst, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy);
  } else if (platform::is_xpu_place(place)) {
    SetTensorFromPyArray<platform::XPUPlace>(
        dst, array, BOOST_GET_CONST(platform::XPUPlace, place), zero_copy);
  } else if (platform::is_gpu_place(place)) {
    SetTensorFromPyArray<platform::CUDAPlace>(
        dst, array, BOOST_GET_CONST(platform::CUDAPlace, place), zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(
        dst, array, BOOST_GET_CONST(platform::CUDAPinnedPlace, place),
        zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }

  if (stop_gradient != -1) self->SetOverridedStopGradient(stop_gradient);
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(dst->type());
}

// Placement-constructs `self` from Python keyword arguments.
//
// Required kwarg:
//   value         - array-like, cast to py::array.
// Optional kwargs (an omitted key and an explicit None both mean "default"):
//   persistable   - bool, default false.
//   zero_copy     - bool, default false.
//   name          - str, default "" (a unique name is then generated).
//   stop_gradient - int, default -1 (flag left untouched).
//   place         - place object, default: the current tracer's
//                   ExpectedPlace().
// Throws NotFound when `value` is missing; cast failures propagate.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  VLOG(4) << "Init VarBase from kwargs: ";
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::NotFound(
          "The kwargs used to create Varbase misses argument: value"));

  // Treat `key=None` like an omitted key. This is consistent with the
  // positional __init__ overload, which generates a name when name is None
  // (previously name/place/stop_gradient=None raised a cast error here).
  auto has_arg = [&kwargs](const char *key) {
    return kwargs.contains(key) && !kwargs[key].is_none();
  };

  // `value` is guaranteed present by the enforce above.
  auto array = kwargs["value"].cast<py::array>();
  auto persistable =
      has_arg("persistable") ? kwargs["persistable"].cast<bool>() : false;
  auto zero_copy =
      has_arg("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
  auto name =
      has_arg("name") ? kwargs["name"].cast<std::string>() : std::string();
  auto stop_gradient =
      has_arg("stop_gradient") ? kwargs["stop_gradient"].cast<int>() : -1;
  // Only ask the tracer for its default place when none was supplied.
  auto place = has_arg("place")
                   ? PyObjectToPlace(kwargs["place"])
                   : imperative::GetCurrentTracer()->ExpectedPlace();

  InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
                       stop_gradient);
}
148

149 150 151
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
L
Leo Chen 已提交
152 153
                                        bool persistable = false,
                                        bool zero_copy = false,
154 155 156 157 158
                                        std::string name = "",
                                        int stop_gradient = -1) {
  VLOG(4) << "Init VarBase from Arg: ";
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6:
  // stop_gradient
L
Leo Chen 已提交
159
  if (name == "") {
160 161
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
L
Leo Chen 已提交
162
  }
163 164
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
165
          << " / stop_gradient: " << stop_gradient << " / at " << place;
L
Leo Chen 已提交
166
  new (self) imperative::VarBase(name);
167 168
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
169 170 171
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
172 173 174 175 176 177
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
178 179
                                               const py::array &array) {
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
180
  VLOG(4) << "Init VarBase from numpy at " << place;
L
Leo Chen 已提交
181
  InitTensorForVarBase(self, array, place);
182
}
183

184 185 186 187 188
// Placement-constructs `self` as a non-persistable LOD_TENSOR VarBase with a
// generated name, backed by `tensor` on the current tracer's expected place.
//
// If `tensor` already lives on that place its data is shared
// (ShareDataWith); otherwise it is copied there with TensorCopy.
static void InitVarBaseFromTensorWithArgDefault(
    imperative::VarBase *self, const framework::LoDTensor &tensor) {
  VLOG(4) << "Init VarBase";
  const auto &tracer = imperative::GetCurrentTracer();
  auto place = tracer->ExpectedPlace();
  new (self)
      imperative::VarBase(tracer->GenerateUniqueName("generated_tensor"));
  self->SetPersistable(false);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor.type());

  auto *dst = self->MutableVar()->GetMutable<framework::LoDTensor>();
  const bool same_place = (place == tensor.place());
  if (!same_place) {
    framework::TensorCopy(tensor, place, dst);
    VLOG(4) << "Different place, do TensorCopy";
    return;
  }
  // Same place: share the underlying data directly.
  dst->ShareDataWith(tensor);
  VLOG(4) << "Same place, do ShareDataWith";
}

204 205 206 207 208
// Returns a printable type name for `var`:
//   "RAW"      if the VarBase type is RAW,
//   "nullptr"  if the underlying Variable is not initialized,
//   otherwise  framework::ToTypeName of the held type.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) return "RAW";
  if (!var.Var().IsInitialized()) return "nullptr";
  return framework::ToTypeName(var.Var().Type());
}
L
Leo Chen 已提交
213

214
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
215 216 217 218 219 220

template <typename T>
static T PyObjectCast(PyObject *obj) {
  try {
    return py::cast<T>(py::handle(obj));
  } catch (py::cast_error &) {
221 222
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts a Python argument into a vector of VarBase pointers:
//   null handle or None  -> empty vector;
//   list or tuple        -> one entry per item, each cast to VarBase;
//   anything else        -> a single entry cast to VarBase.
// Cast failures raise InvalidArgument through PyObjectCast.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  std::vector<std::shared_ptr<imperative::VarBase>> vars;
  PyObject *py_obj = handle.ptr();
  // Python None is a real object, not nullptr in C++, so check both.
  if (py_obj == nullptr || py_obj == Py_None) {
    return vars;
  }

  const bool is_sequence = PyList_Check(py_obj) || PyTuple_Check(py_obj);
  if (!is_sequence) {  // a single VarBase
    vars.emplace_back(
        PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
    return vars;
  }

  // PySequence_Fast_GET_SIZE/ITEM dispatch to the list or tuple macro, so a
  // single loop covers both container kinds.
  const size_t count = PySequence_Fast_GET_SIZE(py_obj);
  vars.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    PyObject *item = PySequence_Fast_GET_ITEM(py_obj, idx);
    PADDLE_ENFORCE_NOT_NULL(
        item, platform::errors::InvalidArgument("Python Object is NULL"));
    vars.emplace_back(PyObjectCast<std::shared_ptr<imperative::VarBase>>(item));
  }
  return vars;
}

J
Jiabin Yang 已提交
266 267 268
// Converts a name -> Python-object map into an imperative::NameVarBaseMap.
// Names whose value yields no VarBase (None, or an empty list/tuple) are
// omitted from the result.
//
// If a Python error is pending after the conversion, it is fetched (which
// clears the error indicator) and re-raised as InvalidArgument carrying the
// Python error's formatted type and message.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  if (PyErr_Occurred() != nullptr) {
    // PyErr_Occurred() returns only the exception *type*. Formatting it (as
    // the previous check did) dropped the message and left the indicator
    // set. error_already_set fetches type/value/traceback, clears the
    // indicator, and exposes the formatted error through what().
    py::error_already_set py_err;
    PADDLE_THROW(platform::errors::InvalidArgument("%s", py_err.what()));
  }
  return result;
}

282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326
// Returns true if `obj` is a Python integer: int on Python 3, int or long on
// Python 2. bool is rejected explicitly because it is an int subclass.
static bool PyCheckInteger(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyLong_Check(obj) || PyInt_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}

// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// The original PySlice_GetIndices returns a wrong result when a slice item
// contains a long int, such as arr[:180L]; the cause is unknown. It also
// cannot raise an error when a slice item is a float. This revised version is
// named _PySlice_GetIndices. In the future, consider _PySlice_Unpack, which is
// more robust than PySlice_GetIndices.
//
// Resolves slice `r` against an axis of `length` elements:
//   - step:  None -> 1.
//   - start: None -> length - 1 if step < 0, else 0. A negative value has
//            `length` added and is then clamped to >= 0.
//   - stop:  None -> -1 if step < 0, else length. A negative value has
//            `length` added and is then clamped to <= length.
// A non-integer (or bool) item raises InvalidArgument.
// Returns -1 if start >= length or step == 0, else 0. The stop > length test
// is kept from the original but cannot fire after the clamping above.
// NOTE(review): values still out of range after adding `length` are not
// clamped the way CPython's PySlice_AdjustIndices does (e.g. a[-10::-1]
// yields start 0, not -1); confirm the downstream slice op expects this.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
                               Py_ssize_t *start, Py_ssize_t *stop,
                               Py_ssize_t *step) {
  /* XXX support long ints */
  // Reads one non-None slice component, rejecting non-integers.
  auto read_item = [](PyObject *item) -> Py_ssize_t {
    if (!PyCheckInteger(item)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(item)->tp_name)));
    }
    return PyLong_AsLong(item);
  };

  *step = (r->step == Py_None) ? 1 : read_item(r->step);

  if (r->start == Py_None) {
    *start = (*step < 0) ? length - 1 : 0;
  } else {
    *start = read_item(r->start);
    if (*start < 0) *start += length;
    *start = std::max(*start, static_cast<Py_ssize_t>(0));
  }

  if (r->stop == Py_None) {
    *stop = (*step < 0) ? -1 : length;
  } else {
    *stop = read_item(r->stop);
    if (*stop < 0) *stop += length;
    *stop = std::min(*stop, length);
  }

  if (*stop > length || *start >= length || *step == 0) return -1;
  return 0;
}

S
songyouwei 已提交
349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372
// Translates a Python __getitem__ index into slice-op attributes for `tensor`.
//
// `_index` may be an integer, a slice, or a tuple of those. A non-tuple is
// treated as a 1-tuple. Ellipsis and None are not supported yet. For each
// indexed dim:
//   - 1 is appended to `infer_flags`;
//   - an integer i appends axis=dim, start=i, end=i+1, stride=1 and records
//     dim in `decrease_axis`;
//   - a slice covering the whole axis with step 1 appends nothing else;
//   - any other slice appends its resolved start/end/step.
// Throws InvalidArgument if the tensor is uninitialized, if there are more
// indices than dims, or if an item has an unsupported type. Throws
// py::index_error (IndexError in Python) for an out-of-range integer index.
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                               std::vector<int> *slice_axes,
                               std::vector<int> *slice_starts,
                               std::vector<int> *slice_ends,
                               std::vector<int> *slice_strides,
                               std::vector<int> *decrease_axis,
                               std::vector<int> *infer_flags) {
  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));

  // Wrap a non-tuple index into a 1-tuple. PyTuple_Pack returns a new
  // reference. Owning it through a py::object releases it on every exit path,
  // including the enforce/index_error throws below; the previous trailing
  // Py_DecRef leaked it whenever an error was raised.
  py::object packed_index;
  PyObject *index = _index;
  if (!PyTuple_Check(_index)) {
    packed_index = py::reinterpret_steal<py::object>(PyTuple_Pack(1, _index));
    PADDLE_ENFORCE_NOT_NULL(packed_index.ptr(),
                            platform::errors::ResourceExhausted(
                                "Failed to pack the index into a tuple."));
    index = packed_index.ptr();
  }

  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = static_cast<int>(PyTuple_GET_SIZE(index));
  PADDLE_ENFORCE_EQ(
      size <= rank, true,
      platform::errors::InvalidArgument(
          "too many indices (%d) for tensor of dimension %d", size, rank));
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index, dim);
    PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
                      true,
                      platform::errors::InvalidArgument(
                          "Currently, VarBase.__getitem__() only allows "
                          "indexing by Integers, Slices, and tuples of "
                          "these types, but received %s in %dth slice item",
                          std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
    infer_flags->push_back(1);
    int dim_len = shape[dim];
    if (PyCheckInteger(slice_item)) {
      // integer, PyLong_AsLong supports both int and long
      int start = static_cast<int>(PyLong_AsLong(slice_item));
      auto s_t = start;
      start = start < 0 ? start + dim_len : start;
      if (start >= dim_len || start < 0) {
        std::string str_error_message =
            "The starting index " + std::to_string(s_t) +
            " of slice is out of bounds in tensor " + std::to_string(dim) +
            "-th axis, it shound be in the range of [" +
            std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
        // py::index_error is corresponding to IndexError in Python
        // Used to indicate out of bounds access in __getitem__, __setitem__
        throw py::index_error(str_error_message);
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(start + 1);
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
    } else {
      // slice item
      Py_ssize_t start, end, step;
      PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
      // NOTE(review): the -1 status from _PySlice_GetIndices (start beyond
      // the axis, or step == 0) is ignored here, as before; presumably the
      // slice op validates these. Confirm.
      _PySlice_GetIndices(p, dim_len, &start, &end, &step);

      // :: or : or 0:dim_len:1
      if (start == 0 && end == dim_len && step == 1) {
        continue;
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(end);
      slice_strides->push_back(step);
    }
  }
}

421
// Bind Methods
J
Jiabin Yang 已提交
422
void BindImperative(py::module *m_ptr) {
423 424
  auto &m = *m_ptr;

425 426
  BindOpFunctions(&m);

427 428
#ifndef _WIN32
  // Dygraph DataLoader signal handler
429 430 431 432 433 434 435 436 437 438 439 440 441
  m.def("_set_process_pids", [](int64_t key, py::object &obj) {
    PADDLE_ENFORCE_EQ(
        py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj), true,
        platform::errors::InvalidArgument(
            "The subprocess ids set in DataLoader is illegal."
            "Expected data type is tuple or list, but received %s",
            obj.get_type()));
    py::list pids = py::cast<py::list>(obj);
    std::set<pid_t> pids_set = {};
    for (size_t i = 0; i < pids.size(); i++) {
      pids_set.insert(pids[i].cast<pid_t>());
    }
    imperative::SetLoadProcessPIDs(key, pids_set);
442
  });
443 444
  m.def("_erase_process_pids",
        [](int64_t key) { imperative::EraseLoadProcessPIDs(key); });
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496
  m.def("_set_process_signal_handler",
        []() { imperative::SetLoadProcessSignalHandler(); });
  m.def("_throw_error_if_process_failed",
        []() { imperative::ThrowErrorIfLoadProcessFailed(); });

  // Dygraph DataLoader reader process & thread related functions
  m.def(
      "_convert_to_tensor_list",
      [](py::object &obj) -> py::list {
        // 0. input data check
        PADDLE_ENFORCE(
            py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj),
            platform::errors::InvalidArgument(
                "The batch data read into DataLoader is illegal."
                "Expected data type is tuple or list, but received %s",
                obj.get_type()));
        py::list batch = py::cast<py::list>(obj);
        py::list tensors;
        for (size_t i = 0; i < batch.size(); ++i) {
          // 1. cast to python array
          auto array = batch[i].cast<py::array>();
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Faild to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data causes this issue."));
          // 2. construcct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);
          // 6. append to result list
          tensors.append(t);
        }
        return tensors;
      },
      py::return_value_policy::take_ownership);

K
Kaipeng Deng 已提交
497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529
  m.def("_array_to_share_memory_tensor",
        [](py::object &obj) {
          // 1. cast to python array
          auto array = obj.cast<py::array>();
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Faild to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data causes this issue."));
          // 2. construcct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);

          return t;
        },
        py::return_value_policy::take_ownership);

530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549
  m.def("_remove_tensor_list_mmap_fds", [](py::list &tensor_list) {
    for (size_t i = 0; i < tensor_list.size(); ++i) {
      auto t = tensor_list[i].cast<framework::LoDTensor>();
      auto *mmap_writer_allocation =
          dynamic_cast<memory::allocation::MemoryMapWriterAllocation *>(
              t.Holder().get());
      PADDLE_ENFORCE_NOT_NULL(
          mmap_writer_allocation,
          platform::errors::NotFound("The shared memory of LoDTensor in "
                                     "DataLoader's child process has been "
                                     "released."));
      memory::allocation::MemoryMapFdSet::Instance().Remove(
          mmap_writer_allocation->ipc_name());
    }
  });

  m.def("_cleanup_mmap_fds",
        []() { memory::allocation::MemoryMapFdSet::Instance().Clear(); });
#endif

550 551 552 553 554
  // Thin Python wrappers over imperative::StartProfile / StopProfile.
  m.def("start_imperative_gperf_profiler",
        [] { imperative::StartProfile(); });
  m.def("stop_imperative_gperf_profiler",
        [] { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
555 556 557
  // Expose the imperative debug switches to Python.
  m.def("_is_dygraph_debug_enabled",
        [] { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level",
        [] { return imperative::GetDebugLevel(); });
558 559 560 561
  // Install `new_tracer` as the current imperative tracer.
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &new_tracer) {
          imperative::SetCurrentTracer(new_tracer);
        });
Z
Zeng Jinle 已提交
562

563
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
564
      m, "VarBase", R"DOC()DOC")
Z
Zeng Jinle 已提交
565
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
566 567 568 569 570 571 572
      // VarBase(): placement-new an empty VarBase named by the current tracer.
      .def("__init__",
           [](imperative::VarBase &self) {
             const auto &tracer = imperative::GetCurrentTracer();
             std::string name = tracer->GenerateUniqueName("generated_tensor");
             new (&self) imperative::VarBase(name);
           })
J
Jiabin Yang 已提交
573
      // VarBase(dtype, dims, name, type, persistable): a null or None `name`
      // asks the current tracer for a unique one. For LOD_TENSOR variables the
      // underlying LoDTensor is resized to `dims`.
      .def("__init__",
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
             VLOG(4) << "Init VarBase";
             std::string act_name;
             if (name.ptr() != nullptr && name.ptr() != Py_None) {
               act_name = name.cast<std::string>();
             } else {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
                   "generated_tensor");
             }
             new (&self) imperative::VarBase(act_name);

             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type != framework::proto::VarType::LOD_TENSOR) {
               return;
             }
             auto *tensor =
                 self.MutableVar()->GetMutable<framework::LoDTensor>();
             tensor->Resize(framework::make_ddim(dims));
           })
595 596
      // One InitVarBaseFromNumpyWithArg overload per place type, all sharing
      // the same keyword arguments and defaults. pybind11 tries overloads in
      // registration order, so keep this order stable.
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"),
           py::arg("persistable") = false, py::arg("zero_copy") = false,
           py::arg("name") = "", py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
           py::arg("value"), py::arg("place"),
           py::arg("persistable") = false, py::arg("zero_copy") = false,
           py::arg("name") = "", py::arg("stop_gradient") = -1)
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"),
           py::arg("persistable") = false, py::arg("zero_copy") = false,
           py::arg("name") = "", py::arg("stop_gradient") = -1)
      .def("__init__",
           &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"),
           py::arg("persistable") = false, py::arg("zero_copy") = false,
           py::arg("name") = "", py::arg("stop_gradient") = -1)
L
Leo Chen 已提交
611
      // Remaining constructor overloads; see the Init* helpers for details.
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault,
           py::arg("value"))
      .def("__init__", &InitVarBaseFromTensorWithArgDefault,
           py::arg("tensor"))
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
614 615 616 617 618
      .def("__setitem__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
              py::object &value_obj) {
             // Implements `self[_index] = value_obj` in place. Uses the
             // set_value op when every index item is an int or slice and the
             // value is not a numpy array / int / float (it is then cast to a
             // VarBase); otherwise falls back to a numpy round trip.
             auto self_tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             // Normalize the index to a tuple. PyTuple_Pack returns a NEW
             // reference, so it is owned by a py::object and released when
             // this lambda returns (a raw PyObject* here leaked one tuple per
             // non-tuple assignment such as `x[0] = v`). A tuple index is
             // only borrowed.
             py::object index_tuple =
                 PyTuple_Check(_index.ptr())
                     ? py::reinterpret_borrow<py::object>(_index)
                     : py::reinterpret_steal<py::object>(
                           PyTuple_Pack(1, _index.ptr()));
             if (!index_tuple) {
               // PyTuple_Pack failed and left a Python exception set.
               throw py::error_already_set();
             }
             PyObject *index_ptr = index_tuple.ptr();
             // 1. Check arguments
             // 1.1 Check whether value obj is a tensor.
             bool value_is_tensor = true;
             bool parse_index = true;
             if (py::isinstance<py::array>(value_obj) ||
                 py::isinstance<py::int_>(value_obj) ||
                 py::isinstance<py::float_>(value_obj)) {
               value_is_tensor = false;
             }

             // 1.2 Check whether _index can be parsed.
             const Py_ssize_t size = PyTuple_GET_SIZE(index_ptr);
             for (Py_ssize_t dim = 0; dim < size; ++dim) {
               PyObject *slice_item = PyTuple_GetItem(index_ptr, dim);
               if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item))) {
                 parse_index = false;
                 break;
               }
             }

             // 2. Call op set_value to speed up if the condition is met,
             // otherwise call TensorToPyArray.
             // TODO(liym27): Try not to call TensorToPyArray because it always
             // copies data to cpu place, which reduces performance.
             if (parse_index && value_is_tensor) {
               std::vector<int> axes, starts, ends, steps, decrease_axes,
                   infer_flags;
               ParseIndexingSlice(self_tensor, index_ptr, &axes, &starts, &ends,
                                  &steps, &decrease_axes, &infer_flags);

               framework::AttributeMap attrs = {
                   {"axes", axes},
                   {"starts", starts},
                   {"ends", ends},
                   {"steps", steps},
                   {"decrease_axes", decrease_axes}};

               imperative::NameVarBaseMap ins = {{"Input", {self}}};
               imperative::NameVarBaseMap outs = {{"Out", {self}}};

               auto value_tensor =
                   value_obj.cast<std::shared_ptr<imperative::VarBase>>();
               ins.insert({"ValueTensor", {value_tensor}});

               const auto &tracer = imperative::GetCurrentTracer();
               {
                 // Release gil and do tracing
                 py::gil_scoped_release release;
                 tracer->TraceOp("set_value", ins, outs, std::move(attrs));
               }
             } else {
               auto self_numpy = TensorToPyArray(*self_tensor);

               if (value_is_tensor) {
                 auto value =
                     value_obj.cast<std::shared_ptr<imperative::VarBase>>();
                 auto value_tensor =
                     value->MutableVar()->GetMutable<framework::LoDTensor>();
                 auto value_numpy = TensorToPyArray(*value_tensor);

                 self_numpy[_index] = value_numpy;
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               } else {
                 auto value_numpy = value_obj;
                 self_numpy[_index] = value_numpy;
                 SetTensorFromPyArray(self_tensor, self_numpy,
                                      self_tensor->place(), true);
               }
             }
             // NOTE(liym27):
             // Increase the version of VarBase self because __setitem__ is an
             // inplace operator for the VarBase self.
             self->BumpInplaceVersion();
           })
697
      .def("__getitem__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
             // Implements `self[_index]` by tracing a slice / strided_slice op.
             auto tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             std::vector<int> axes, starts, ends, strides, decrease_axis,
                 infer_flags;
             // Parsing inspects Python objects, so the GIL must still be held.
             ParseIndexingSlice(tensor, _index.ptr(), &axes, &starts, &ends,
                                &strides, &decrease_axis, &infer_flags);

             // release gil and do tracing
             py::gil_scoped_release release;
             if (axes.empty()) {
               // Nothing is sliced: hand back the same VarBase.
               return self;
             }

             const auto &tracer = imperative::GetCurrentTracer();
             const bool has_non_unit_stride =
                 std::any_of(strides.begin(), strides.end(),
                             [](int stride) { return stride != 1; });

             framework::AttributeMap attrs = {{"axes", axes},
                                              {"starts", starts},
                                              {"ends", ends},
                                              {"infer_flags", infer_flags}};
             std::string op_type;
             if (has_non_unit_stride) {
               // strided_slice takes "strides" and no "decrease_axis".
               op_type = "strided_slice";
               attrs.insert({"strides", strides});
             } else {
               op_type = "slice";
               attrs.insert({"decrease_axis", decrease_axis});
             }

             auto out = std::make_shared<imperative::VarBase>(
                 tracer->GenerateUniqueName());
             imperative::NameVarBaseMap ins = {{"Input", {self}}};
             imperative::NameVarBaseMap outs = {{"Out", {out}}};
             tracer->TraceOp(op_type, ins, outs, std::move(attrs));
             return out;
           })
735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756
      .def("_inplace_version",
           [](imperative::VarBase &self) -> uint32_t {
             // Querying the version of an uninitialized variable is an error.
             const auto &var = self.MutableVar();
             PADDLE_ENFORCE_EQ(
                 var->IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             const uint32_t version = var->CurrentInplaceVersion();
             return version;
           })
      .def("_bump_inplace_version",
           [](std::shared_ptr<imperative::VarBase> &self) {
             // NOTE(liym27): only meant to be called by inplace operations.
             imperative::VarBase &target = *self;
             target.BumpInplaceVersion();
           },
           R"DOC(
        **Notes**:
            **This API is ONLY available in Dygraph mode.**
            **This is a very low level API. Users should not use it directly. **
         Bump the version whenever the Tensor is modified through an inplace operation.
            )DOC")
757 758 759 760 761 762 763
      // Python `Tensor.numpy()`: converts the held LoDTensor with
      // TensorToPyArray; raises InvalidArgument if it is not initialized.
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        Returns a numpy array that holds the value of the current Tensor.

        Returns:
            ndarray: The numpy value of current Tensor.

        Returns type:
            ndarray: dtype is same as current Tensor

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np
                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                linear = paddle.nn.Linear(32, 64)
                data = paddle.to_tensor(data)
                x = linear(data)
                print(x.numpy())
       )DOC")
788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850
      // Python `Tensor.detach()`: returns a new VarBase named
      // "detach_<name>" that shares the LoDTensor data (or the SelectedRows
      // height, rows and value) and the inplace version counter with `self`.
      .def("detach",
           [](const imperative::VarBase
                  &self) -> std::shared_ptr<imperative::VarBase> {
             PADDLE_ENFORCE_EQ(
                 self.Var().IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor %s has not been initialized!", self.Name()));

             PADDLE_ENFORCE_EQ(
                 self.Var().IsType<framework::LoDTensor>() ||
                     self.Var().IsType<framework::SelectedRows>(),
                 true,
                 platform::errors::InvalidArgument(
                     "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
                     self.Name()));

             auto detach_var = std::make_shared<imperative::VarBase>(
                 true, "detach_" + self.Name());

             detach_var->SetPersistable(self.Persistable());
             detach_var->SetType(self.Type());
             detach_var->SetDataType(self.DataType());

             if (self.Var().IsType<framework::LoDTensor>()) {
               const auto &origin_tensor =
                   self.Var().Get<framework::LoDTensor>();
               PADDLE_ENFORCE_EQ(
                   origin_tensor.IsInitialized(), true,
                   platform::errors::InvalidArgument(
                       "Tensor %s has not been initialized!", self.Name()));

               auto *detach_tensor =
                   detach_var->MutableVar()->GetMutable<framework::LoDTensor>();
               detach_tensor->ShareDataWith(origin_tensor);
               // NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
               // same TensorInplaceVersion, which is used to check whether
               // inplace operations are correct.
               detach_tensor->ShareInplaceVersionCounterWith(origin_tensor);
             } else {
               const auto &origin_selected_rows =
                   self.Var().Get<framework::SelectedRows>();
               PADDLE_ENFORCE_EQ(
                   origin_selected_rows.value().IsInitialized(), true,
                   platform::errors::InvalidArgument(
                       "Tensor %s has not been initialized!", self.Name()));

               auto *detach_selected_rows =
                   detach_var->MutableVar()
                       ->GetMutable<framework::SelectedRows>();
               detach_selected_rows->set_height(origin_selected_rows.height());
               detach_selected_rows->set_rows(origin_selected_rows.rows());
               detach_selected_rows->mutable_value()->ShareDataWith(
                   origin_selected_rows.value());
               detach_selected_rows->mutable_value()
                   ->ShareInplaceVersionCounterWith(
                       origin_selected_rows.value());
             }
             VLOG(3) << "The detached Tensor(" << detach_var->Name()
                     << ") share data with " << self.Name();
             return detach_var;
           },
           py::return_value_policy::take_ownership, R"DOC(

        Returns a new Tensor, detached from the current graph.
        It shares data with the origin Tensor, so no Tensor copy is ever made.
        In addition, the detached Tensor doesn't provide gradient propagation.

        Returns: The detached Tensor.

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(1.0, stop_gradient=False)
                detach_x = x.detach()
                detach_x[:] = 10.0
                print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                          #        [10.])
                y = x**2
                y.backward()
                print(x.grad)         # [20.0]
                print(detach_x.grad)  # None, 'stop_gradient=True' by default

                detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
                z = detach_x**3
                z.backward()

                print(x.grad)         # [20.0], detach_x is detached from x's graph, not affect each other
                print(detach_x.grad)  # [300.0], detach_x has its own graph

                # Due to sharing of data with origin Tensor, There are some unsafe operations:
                y = 2 * x
                detach_x[:] = 5.0
                y.backward()
                # It will raise Error:
                #   one of the variables needed for gradient computation has been modified by an inplace operation.

       )DOC")
      // Python `Tensor.clear_gradient()`, bound directly to
      // imperative::VarBase::ClearGradient.
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        Only for Tensor that has gradient, normally we use this for Parameters since other temporary Tensor doesn't have gradient.

        The Gradient of current Tensor will be set to ``0`` .

        Returns:  None

        Examples:
            .. code-block:: python

                import paddle
                input = paddle.uniform([10, 2])
                linear = paddle.nn.Linear(2, 3)
                out = linear(input)
                out.backward()
                print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
                linear.weight.clear_gradient()
                print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
      )DOC")
Z
Zhou Wei 已提交
908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955
      // Python `Tensor.clone()`: traces an "assign" op from `self` into a new
      // VarBase whose unique name is derived from "<name>_clone", so the copy
      // stays connected to the autograd graph.
      .def("clone",
           [](std::shared_ptr<imperative::VarBase> &self) {
             const auto &tensor = self->Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "%s has not been initialized", self->Name()));
             auto tracer = imperative::GetCurrentTracer();
             auto new_var = std::make_shared<imperative::VarBase>(
                 true, tracer->GenerateUniqueName(self->Name() + "_clone"));
             framework::AttributeMap attrs;
             imperative::NameVarBaseMap ins = {{"X", {self}}};
             imperative::NameVarBaseMap outs = {{"Out", {new_var}}};
             tracer->TraceOp("assign", ins, outs, attrs);
             return new_var;
           },
           py::return_value_policy::copy, R"DOC(

        Returns a new Tensor, which is clone of origin Tensor, and it remains in the current graph.
        It will always have a Tensor copy.
        In addition, the cloned Tensor provides gradient propagation.

        Returns: The cloned Tensor.

        Examples:
            .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.0, stop_gradient=False)
              clone_x = x.clone()
              y = clone_x**2
              y.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [2.0], support gradient propagation
              print(x.stop_gradient)       # False
              print(x.grad)                # [2.0], clone_x support gradient propagation for x

              x = paddle.to_tensor(1.0)
              clone_x = x.clone()
              clone_x.stop_gradient = False
              z = clone_x**3
              z.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [3.0], support gradient propagation
              print(x.stop_gradient) # True
              print(x.grad)          # None
       )DOC")
L
Leo Chen 已提交
956
      // Run backward from `self` using the tracer's engine (GIL released).
      .def("_run_backward",
           [](imperative::VarBase &self, const imperative::Tracer &tracer,
              bool retain_graph) {
             // TODO(jiabin): when we impl more backward execution we can
             // select them
             auto *backward_engine = tracer.GetEngine();
             backward_engine->Init(&self, retain_graph);
             VLOG(3) << "Start backward";
             backward_engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      .def("_grad_value",
           [](imperative::VarBase &self) {
             // NOTE(review): the lambda's return type is deduced by value, so
             // pybind11 receives a copy of the gradient LoDTensor object and
             // the `reference` policy below does not apply — confirm intent.
             const auto &grad_tensor =
                 self.MutableGradVar()->Get<framework::LoDTensor>();
             return grad_tensor;
           },
           py::return_value_policy::reference)
974 975 976 977
      .def("_set_grad_type",
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             // Set the variable type of this VarBase's gradient VarBase.
             const auto &grad_var = self.MutableGradVarBase();
             grad_var->SetType(type);
           })
978
      .def("_grad_ivar",
           [](const imperative::VarBase &self) {
             // Return the gradient VarBase only when it exists, its variable
             // is initialized, and its tensor (the LoDTensor itself or the
             // SelectedRows value) is initialized; otherwise return None.
             const auto &grad_var = self.GradVarBase();
             if (!grad_var || !grad_var->Var().IsInitialized()) {
               return std::shared_ptr<imperative::VarBase>(nullptr);
             }
             auto *grad_variable = grad_var->MutableVar();
             auto *tensor =
                 grad_variable->IsType<framework::LoDTensor>()
                     ? grad_variable->GetMutable<framework::LoDTensor>()
                     : grad_variable->GetMutable<framework::SelectedRows>()
                           ->mutable_value();
             if (!tensor->IsInitialized()) {
               return std::shared_ptr<imperative::VarBase>(nullptr);
             }
             return grad_var;
           },
           py::return_value_policy::copy)
996 997 998 999 1000 1001 1002 1003
      .def("_is_sparse",
           [](imperative::VarBase &self) {
             // A VarBase is "sparse" when it holds SelectedRows.
             const auto &variable = self.Var();
             return variable.IsType<framework::SelectedRows>();
           })
      .def("_allreduce",
           [](imperative::VarBase &self,
              const imperative::ParallelStrategy &strategy) {
             // A single-rank strategy has nothing to reduce.
             if (strategy.nranks_ <= 1) {
               return;
             }
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2212
             imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
             if (self.Var().IsType<framework::SelectedRows>()) {
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative SelectedRows allreduce is not supported when "
                   "paddle is compiled with NCCL verison lower than v2.2.12. "
                   "You can set is_sparse=False for the Layer containing "
                   "this argument, such as Embedding(is_sparse=False)."));
             }
             imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#endif  // NCCL_VERSION_CODE
#else
             PADDLE_THROW(platform::errors::Unimplemented(
                 "Imperative allreduce is not supported when paddle is "
                 "not compiled with NCCL."));
#endif  // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
           },
           py::call_guard<py::gil_scoped_release>())
1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053
      .def("cpu",
           [](const std::shared_ptr<imperative::VarBase> &self) {
             // Already on the host: return the same Tensor, no copy.
             if (platform::is_cpu_place(self->Place())) {
               return self;
             }
             auto cpu_var = self->NewVarBase(platform::CPUPlace(), true);
             cpu_var->SetOverridedStopGradient(self->OverridedStopGradient());
             return cpu_var;
           },
           R"DOC(
        Returns a copy of this Tensor in CPU memory.

        If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)    # CUDAPlace(0)
              
              y = x.cpu()
              print(y.place)    # CPUPlace

              )DOC")
      .def("pin_memory",
           [](const std::shared_ptr<imperative::VarBase> &self) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to pinned memory in CPU version "
                 "Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#endif
             // Already pinned: return the same Tensor, no copy.
             if (platform::is_cuda_pinned_place(self->Place())) {
               return self;
             }
             auto pinned =
                 self->NewVarBase(platform::CUDAPinnedPlace(), true);
             pinned->SetOverridedStopGradient(self->OverridedStopGradient());
             return pinned;
           },
           R"DOC(
        Returns a copy of this Tensor in pin memory.

        If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)      # CUDAPlace(0)

              y = x.pin_memory()
              print(y.place)      # CUDAPinnedPlace

      )DOC")
      // Python `Tensor.cuda(device_id=-1, blocking=True)`. With the default
      // device_id a Tensor already on a GPU is returned as-is; otherwise the
      // Tensor is copied to CUDAPlace(device_id) (GPU 0 for the default).
      .def("cuda",
           [](const std::shared_ptr<imperative::VarBase> &self, int device_id,
              bool blocking) {
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to GPU in CPU version Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#else
             int device_count = platform::GetCUDADeviceCount();
             if (device_id == -1) {
               if (platform::is_gpu_place(self->Place())) {
                 return self;
               } else {
                 device_id = 0;
               }
             }
             PADDLE_ENFORCE_GE(
                 device_id, 0,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             PADDLE_ENFORCE_LT(
                 device_id, device_count,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             platform::CUDAPlace place = platform::CUDAPlace(device_id);
             if (platform::is_same_place(self->Place(), place)) {
               return self;
             } else {
               auto new_var = self->NewVarBase(place, blocking);
               new_var->SetOverridedStopGradient(self->OverridedStopGradient());
               return new_var;
             }
#endif
           },
           py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC(
        Returns a copy of this Tensor in GPU memory.

        If this Tensor is already in GPU memory and device_id is default,
        then no copy is performed and the original Tensor is returned.

        Args:
            device_id(int, optional): The destination GPU device id. Default: -1, which returns the
              original Tensor if it is already in GPU memory and otherwise copies it to GPU 0.
            blocking(bool, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. Default: True.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
              print(x.place)        # CPUPlace

              y = x.cuda()
              print(y.place)        # CUDAPlace(0)

              y = x.cuda(1)
              print(y.place)        # CUDAPlace(1)
       )DOC")
K
Kaipeng Deng 已提交
1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175
      .def("_share_memory",
           [](const std::shared_ptr<imperative::VarBase> &self) {
#ifdef _WIN32
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Sharing memory in Windows OS is not supported currently"));
#else
             PADDLE_ENFORCE_EQ(
                 platform::is_cpu_place(self->Place()), true,
                 platform::errors::InvalidArgument(
                     "Sharing memory only support CPU Tensor currently"));
             auto *tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             // Allocate a memory-map writer buffer sized to the tensor data.
             void *src = tensor->data<void>();
             size_t nbytes =
                 tensor->numel() * framework::SizeOfType(tensor->type());
             auto shm_holder =
                 memory::allocation::AllocateMemoryMapWriterAllocation(nbytes);
             // Record the mapping's ipc name in the global MemoryMapFdSet.
             const std::string &ipc_name = shm_holder->ipc_name();
             memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
             // Copy the data over, then make the tensor use the new holder.
             memory::Copy(platform::CPUPlace(), shm_holder->ptr(),
                          platform::CPUPlace(), src, nbytes);
             tensor->ResetHolder(shm_holder);
             // NOTE(review): this returns a LoDTensor by value, so pybind11
             // moves it and the `reference` policy below has no effect.
             return *tensor;
#endif
           },
           py::return_value_policy::reference)
1176
      .def("copy_", &imperative::VarBase::CopyFrom)
1177
      .def("_copy_to",
1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CPUPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             // Note(zhiqiu): Since NewVarBase may use GpuCopyAsync to
             // copy data from the tensor of self to the tensor of new varbase,
             // we need to ensure that the varbase self is not destructed until
             // the GpuCopyAsync is completed. Otherwise, the memory may be
             // freed
             // when varbase self is destructed.
             // To do that, we increase the reference count of self by 1 and
             // add a cuda event to wait the GpuCopyAsync's completion.
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
J
Jiabin Yang 已提交
1194
           py::return_value_policy::copy)
1195
      .def("_copy_to",
1196 1197 1198 1199 1200 1201 1202 1203
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CUDAPinnedPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
1204
           py::return_value_policy::copy)
1205
      .def("_copy_to",
1206 1207 1208 1209 1210 1211 1212 1213
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::XPUPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
1214
           py::return_value_policy::copy)
1215
      .def("_copy_to",
1216 1217 1218 1219 1220 1221 1222 1223
           [](const std::shared_ptr<imperative::VarBase> &self,
              const platform::CUDAPlace &place, bool blocking) {
             auto new_var = self->NewVarBase(place, blocking);
             if (!blocking) {
               IncreaseVarbaseReferenceCountUntilCopyComplete(self, place);
             }
             return new_var;
           },
J
Jiabin Yang 已提交
1224 1225
           py::return_value_policy::copy)
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
1226 1227 1228
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
L
Leo Chen 已提交
1229 1230 1231 1232 1233
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
J
Jiabin Yang 已提交
1234 1235 1236 1237
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
1238
              return framework::vectorize<int>(
J
Jiabin Yang 已提交
1239
                  self.Var().Get<framework::LoDTensor>().dims());
1240 1241 1242
            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::SelectedRows>().value().dims());
J
Jiabin Yang 已提交
1243
            } else {
1244 1245
              VLOG(2) << "It is meaningless to get shape of "
                         "variable type "
J
Jiabin Yang 已提交
1246 1247 1248 1249
                      << GetTypeName(self);
              return std::vector<int>();
            }
          })
1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278
      .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
                             R"DOC(
      Whether a Tensor is leaf Tensor.

      For the Tensor whose stop_gradient is ``True`` , it will be leaf Tensor. 
      
      For the Tensor whose stop_gradient is ``False`` , it will be leaf Tensor too if it is created by user.

      Returns:
          bool: Whether a Tensor is leaf Tensor.

      Examples:
          .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.)
              print(x.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=True)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=False)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # False
       )DOC")
1279 1280 1281
      .def_property_readonly(
          "place", [](imperative::VarBase &self) { return self.Place(); },
          py::return_value_policy::copy)
1282 1283 1284 1285 1286 1287
      .def_property_readonly("_place_str",
                             [](imperative::VarBase &self) {
                               std::stringstream ostr;
                               ostr << self.Place();
                               return ostr.str();
                             })
J
Jiabin Yang 已提交
1288
      .def_property_readonly("type", &imperative::VarBase::Type)
L
Leo Chen 已提交
1289
      .def_property_readonly("dtype", &imperative::VarBase::DataType);
1290 1291 1292

  // imperative::Layer is registered together with the `Layer` trampoline type
  // (pybind11 alias class) so Python subclasses can stand in for it.
  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>());
  // forward(inputs): passes the VarBase list straight to Layer::Forward and
  // returns its result.
  layer.def(
      "forward",
      [](imperative::Layer &self,
         const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
        return self.Forward(inputs);
      });
1298

1299 1300 1301 1302 1303
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

1304
  // Python binding for imperative::Tracer, held by std::shared_ptr. It exposes
  // the tracing / autocast / grad switches as properties, the expected place,
  // AMP op-list accessors, and one `trace` overload per concrete place type.
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer", R"DOC()DOC")
      // py::init<>() default-constructs the Tracer. It replaces the legacy
      // placement-new `__init__` idiom (`new (&self) Tracer()`), which
      // pybind11 has deprecated since v2.2 and warns about in debug builds.
      .def(py::init<>())
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
      .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
                    &imperative::Tracer::SetEnableAutoCast)
      .def_property("_has_grad", &imperative::Tracer::HasGrad,
                    &imperative::Tracer::SetHasGrad)
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          // Setter: accepts one of the concrete place types or a generic
          // Place, and forwards it to Tracer::SetExpectedPlace.
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::XPUPlace>(obj)) {
              auto p = obj.cast<platform::XPUPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else if (py::isinstance<platform::Place>(obj)) {
              auto p = obj.cast<platform::Place *>();
              self.SetExpectedPlace(*p);
              VLOG(4) << "Tracer(" << &self << ")"
                      << " set expected place " << *p;
            } else {
              // Keep this list in sync with the isinstance branches above.
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
                  "CPUPlace, CUDAPinnedPlace and Place, "
                  "but got Unknown Type!"));
            }
          })
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
      .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
           py::arg("key") = "dygraph_tmp")
      .def("_set_amp_op_list",
           [](imperative::Tracer &self,
              std::unordered_set<std::string> &allow_ops,
              std::unordered_set<std::string> &block_ops) {
             // NOTE(zhiqiu): pybind11's automatic conversion between C++ STL
             // containers and Python set/list/dict involves a copy, which
             // prevents pass-by-reference semantics, so it is safe to swap
             // the converted sets into place.
             // The reason for not passing
             // std::shared_ptr<std::unordered_set<std::string>> directly is
             // that pybind11 forbids shared_ptr<T> where T is not a custom
             // type.
             imperative::AmpOperators::Instance().GetMutableAllowOps()->swap(
                 allow_ops);
             imperative::AmpOperators::Instance().GetMutableBlockOps()->swap(
                 block_ops);
             VLOG(4) << "AMP operators changed, "
                     << imperative::AmpOperators::Instance();
           })
      .def("_get_amp_op_list",
           [](imperative::Tracer &self) {
             // Returns copies of the current (allow_ops, block_ops) sets.
             return std::make_tuple(
                 *(imperative::AmpOperators::Instance().GetMutableAllowOps()),
                 *(imperative::AmpOperators::Instance().GetMutableBlockOps()));
           })
      // trace(type, ins, outs, attrs, place, trace_backward): one overload per
      // place type. The Python-side name->VarBase maps are converted while the
      // GIL is still held; the GIL is released only around TraceOp.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::XPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });
1423 1424

  // define parallel context
1425 1426 1427
  // Python binding for imperative::ParallelStrategy. The fields nranks_,
  // local_rank_, trainer_endpoints_, current_endpoint_ and nrings_ are exposed
  // as read/write properties named without the trailing underscore.
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
      .def_property(
          "nranks",
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
                    [](const imperative::ParallelStrategy &self) {
                      return self.local_rank_;
                    },
                    [](imperative::ParallelStrategy &self, int local_rank) {
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
          [](const imperative::ParallelStrategy &self) {
            return self.trainer_endpoints_;
          },
          // `eps` is taken by value (pybind11 builds a fresh vector from the
          // Python sequence), so move it into place instead of copying again.
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
            self.trainer_endpoints_ = std::move(eps);
          })
      .def_property("current_endpoint",
                    [](const imperative::ParallelStrategy &self) {
                      return self.current_endpoint_;
                    },
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; })
      .def_property(
          "nrings",
          [](const imperative::ParallelStrategy &self) { return self.nrings_; },
          [](imperative::ParallelStrategy &self, int nrings) {
            self.nrings_ = nrings;
          });
1461 1462 1463 1464 1465 1466 1467 1468

  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>>
             &output_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &output_grads,
         const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
1469 1470
         const platform::Place &place, bool create_graph, bool retain_graph,
         bool allow_unused, bool only_inputs) {
Z
Zeng Jinle 已提交
1471 1472
        imperative::PartialGradEngine engine(
            input_targets, output_targets, output_grads, no_grad_vars, place,
1473
            create_graph, retain_graph, allow_unused, only_inputs);
1474 1475 1476 1477 1478
        engine.Execute();
        return engine.GetResult();
      },
      py::call_guard<py::gil_scoped_release>());

1479 1480
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
    defined(PADDLE_WITH_XPU_BKCL)
  // imperative::ParallelContext is exposed as a bare type (no methods) so the
  // backend-specific contexts can name it as their pybind11 base class.
  py::class_<imperative::ParallelContext,
             std::shared_ptr<imperative::ParallelContext>>(m,
                                                           "ParallelContext");

  // Reducer: constructible from Python with the argument types listed in
  // py::init below; prepare_for_backward(vars) runs with the GIL released.
  {
    using VarBasePtrs = std::vector<std::shared_ptr<imperative::VarBase>>;
    py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>
        reducer(m, "Reducer", R"DOC()DOC");
    reducer.def(py::init<const VarBasePtrs &,
                         const std::vector<std::vector<size_t>> &,
                         const std::vector<bool> &,
                         std::shared_ptr<imperative::ParallelContext>,
                         const std::vector<size_t> &, bool>());
    reducer.def("prepare_for_backward",
                &imperative::Reducer::PrepareForBackward, py::arg("vars"),
                py::call_guard<py::gil_scoped_release>());
  }

  // assign_group_by_size: thin wrapper over imperative::AssignGroupBySize,
  // called with the GIL released.
  {
    // Default `group_size_limits`: a single limit of 25 * 1024 * 1024
    // (presumably bytes, i.e. 25 MiB; confirm against AssignGroupBySize).
    const std::vector<size_t> default_group_size_limits{25 * 1024 * 1024};
    m.def("assign_group_by_size", &imperative::AssignGroupBySize,
          py::arg("vars"), py::arg("is_sparse_gradient"),
          py::arg("group_size_limits") = default_group_size_limits,
          py::arg("tensor_indices") = std::vector<int64_t>{},
          py::call_guard<py::gil_scoped_release>());
  }
#endif
1501

1502
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  // NCCLParallelContext (a ParallelContext subclass): constructed from a
  // ParallelStrategy and a CUDAPlace; init() calls NCCLParallelContext::Init.
  {
    py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
               std::shared_ptr<imperative::NCCLParallelContext>>
        nccl_parallel_context(m, "NCCLParallelContext");
    nccl_parallel_context.def(
        py::init<const imperative::ParallelStrategy &,
                 const platform::CUDAPlace &>());
    nccl_parallel_context.def(
        "init", [](imperative::NCCLParallelContext &self) { self.Init(); });
  }
#endif

#if defined(PADDLE_WITH_XPU_BKCL)
  // BKCLParallelContext (a ParallelContext subclass): constructed from a
  // ParallelStrategy and an XPUPlace; init() calls BKCLParallelContext::Init.
  {
    py::class_<imperative::BKCLParallelContext, imperative::ParallelContext,
               std::shared_ptr<imperative::BKCLParallelContext>>
        bkcl_parallel_context(m, "BKCLParallelContext");
    bkcl_parallel_context.def(
        py::init<const imperative::ParallelStrategy &,
                 const platform::XPUPlace &>());
    bkcl_parallel_context.def(
        "init", [](imperative::BKCLParallelContext &self) { self.Init(); });
  }
#endif
1519 1520 1521 1522
}

}  // namespace pybind
}  // namespace paddle