imperative.cc 55.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22

23
#include <memory>
24
#include <set>
J
Jiabin Yang 已提交
25
#include <string>
26
#include <unordered_map>
27
#include <unordered_set>
28
#include <utility>
J
Jiabin Yang 已提交
29
#include <vector>
30

31
#include "paddle/fluid/imperative/all_reduce.h"
32
#include "paddle/fluid/imperative/amp_auto_cast.h"
33
#include "paddle/fluid/imperative/basic_engine.h"
34
#include "paddle/fluid/imperative/data_loader.h"
35
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
36
#include "paddle/fluid/imperative/nccl_context.h"
37
#include "paddle/fluid/imperative/partial_grad_engine.h"
38
#include "paddle/fluid/imperative/profiler.h"
39
#include "paddle/fluid/imperative/reducer.h"
40
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
41
#include "paddle/fluid/imperative/type_defs.h"
42
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
43
#include "paddle/fluid/pybind/op_function.h"
44
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
45
#include "paddle/fluid/pybind/tensor_py.h"
46

47 48 49
namespace paddle {
namespace pybind {

50 51
// Short alias for the pybind11 namespace; the binding helpers in this file
// refer to pybind11 types as py::object, py::array, py::handle, etc.
namespace py = ::pybind11;

52 53 54 55
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

56 57 58 59
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
60
                      Forward, inputs);  // NOLINT
61 62 63
  }
};

L
Leo Chen 已提交
64 65 66 67 68
// Converts a Python place object into a platform::Place.
//
// Accepts instances of CPUPlace, CUDAPlace, XPUPlace, CUDAPinnedPlace or the
// generic Place (checked in that order) and throws InvalidArgument for
// anything else.
static const platform::Place PyObjectToPlace(const py::object &place_obj) {
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  }
  if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  }
  if (py::isinstance<platform::XPUPlace>(place_obj)) {
    return place_obj.cast<platform::XPUPlace>();
  }
  if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  }
  if (py::isinstance<platform::Place>(place_obj)) {
    return place_obj.cast<platform::Place>();
  }
  // None of the supported place types matched.
  PADDLE_THROW(platform::errors::InvalidArgument(
      "Place should be one of "
      "Place/CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
}

// Constructs a VarBase in the storage pointed to by `self` (placement new) and
// fills its LoDTensor from a numpy array on the requested place.
//
//   self          - raw storage in which the VarBase is constructed.
//   array         - source numpy array.
//   place         - target place; must hold a CPU/XPU/CUDA/CUDAPinned place,
//                   otherwise InvalidArgument is thrown.
//   persistable   - forwarded to VarBase::SetPersistable.
//   zero_copy     - forwarded to SetTensorFromPyArray.
//   name          - variable name; an empty name is replaced by a unique name
//                   generated by the current tracer.
//   stop_gradient - -1 leaves the stop-gradient setting untouched; any other
//                   value is forwarded to SetOverridedStopGradient.
static void InitTensorForVarBase(imperative::VarBase *self,
                                 const py::array &array,
                                 const platform::Place place,
                                 bool persistable = false,
                                 bool zero_copy = false, std::string name = "",
                                 int stop_gradient = -1) {
  if (name.empty()) {
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
  }
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;

  // `self` is uninitialized storage handed over by pybind11's __init__.
  new (self) imperative::VarBase(name);
  framework::LoDTensor *tensor =
      self->MutableVar()->GetMutable<framework::LoDTensor>();

  // Dispatch on the concrete place type held by `place`.
  if (platform::is_cpu_place(place)) {
    const auto &cpu_place = BOOST_GET_CONST(platform::CPUPlace, place);
    SetTensorFromPyArray<platform::CPUPlace>(tensor, array, cpu_place,
                                             zero_copy);
  } else if (platform::is_xpu_place(place)) {
    const auto &xpu_place = BOOST_GET_CONST(platform::XPUPlace, place);
    SetTensorFromPyArray<platform::XPUPlace>(tensor, array, xpu_place,
                                             zero_copy);
  } else if (platform::is_gpu_place(place)) {
    const auto &cuda_place = BOOST_GET_CONST(platform::CUDAPlace, place);
    SetTensorFromPyArray<platform::CUDAPlace>(tensor, array, cuda_place,
                                              zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    const auto &pinned_place =
        BOOST_GET_CONST(platform::CUDAPinnedPlace, place);
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(tensor, array,
                                                    pinned_place, zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }

  const bool override_stop_gradient = (stop_gradient != -1);
  if (override_stop_gradient) {
    self->SetOverridedStopGradient(stop_gradient);
  }
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

// Keyword-argument flavour of the VarBase constructor.
//
// Recognised keys: "value" (required, numpy array), "persistable",
// "zero_copy", "name", "stop_gradient" and "place". Missing optional keys fall
// back to false / false / "" / -1 / the current tracer's expected place.
// Throws NotFound when "value" is absent.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  VLOG(4) << "Init VarBase from kwargs: ";
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::NotFound(
          "The kwargs used to create Varbase misses argument: value"));

  bool persistable = false;
  if (kwargs.contains("persistable")) {
    persistable = kwargs["persistable"].cast<bool>();
  }

  // "value" is guaranteed to be present by the check above.
  py::array array = kwargs["value"].cast<py::array>();

  bool zero_copy = false;
  if (kwargs.contains("zero_copy")) {
    zero_copy = kwargs["zero_copy"].cast<bool>();
  }

  std::string name;
  if (kwargs.contains("name")) {
    name = kwargs["name"].cast<std::string>();
  }

  int stop_gradient = -1;
  if (kwargs.contains("stop_gradient")) {
    stop_gradient = kwargs["stop_gradient"].cast<int>();
  }

  // Start from the tracer's expected place and override it when the caller
  // passed an explicit one.
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  if (kwargs.contains("place")) {
    place = PyObjectToPlace(kwargs["place"]);
  }

  InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
                       stop_gradient);
}
146

147 148 149
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
L
Leo Chen 已提交
150 151
                                        bool persistable = false,
                                        bool zero_copy = false,
152 153 154 155 156
                                        std::string name = "",
                                        int stop_gradient = -1) {
  VLOG(4) << "Init VarBase from Arg: ";
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6:
  // stop_gradient
L
Leo Chen 已提交
157
  if (name == "") {
158 159
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
L
Leo Chen 已提交
160
  }
161 162 163
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;
L
Leo Chen 已提交
164
  new (self) imperative::VarBase(name);
165 166
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
167 168 169
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
170 171 172 173 174 175
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
176
                                               const py::array &array) {
177
  VLOG(4) << "Init VarBase from numpy: ";
L
Leo Chen 已提交
178 179
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  InitTensorForVarBase(self, array, place);
180
}
181

182 183 184 185 186
// Constructs a non-persistable LOD_TENSOR VarBase from an existing LoDTensor.
//
// The new variable gets a tracer-generated unique name. When `tensor` already
// lives on the tracer's expected place its data is shared; otherwise it is
// copied to that place with TensorCopy.
static void InitVarBaseFromTensorWithArgDefault(
    imperative::VarBase *self, const framework::LoDTensor &tensor) {
  VLOG(4) << "Init VarBase";
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();

  // `self` is uninitialized storage; construct the VarBase in place.
  new (self) imperative::VarBase(
      imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
  self->SetPersistable(false);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor.type());

  framework::LoDTensor *new_tensor =
      self->MutableVar()->GetMutable<framework::LoDTensor>();
  const bool same_place = (place == tensor.place());
  if (same_place) {
    // Same place, share data directly.
    new_tensor->ShareDataWith(tensor);
    VLOG(4) << "Same place, do ShareDataWith";
  } else {
    framework::TensorCopy(tensor, place, new_tensor);
    VLOG(4) << "Different place, do TensorCopy";
  }
}

202 203 204 205 206
// Returns a printable type name for `var`:
//   "RAW"      when the VarBase type is VarType::RAW,
//   "nullptr"  when the underlying variable is not initialized,
//   otherwise the name produced by framework::ToTypeName for the held type.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) {
    return "RAW";
  }
  if (!var.Var().IsInitialized()) {
    return "nullptr";
  }
  return framework::ToTypeName(var.Var().Type());
}
L
Leo Chen 已提交
211

212
// Maps a slot name to a py::handle. ConvertToNameVarBaseMap turns each handle
// into a list of VarBase via GetVarBaseListFromPyHandle, which accepts None, a
// single VarBase, or a list/tuple of VarBase.
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
213 214 215 216 217 218

// Casts a raw PyObject* to the C++ type T through pybind11.
// A failed cast (py::cast_error) is reported as an InvalidArgument error that
// names the requested C++ type.
template <typename T>
static T PyObjectCast(PyObject *obj) {
  const py::handle borrowed(obj);  // non-owning view, no refcount change
  try {
    return py::cast<T>(borrowed);
  } catch (const py::cast_error &) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts a Python handle into a vector of VarBase pointers:
//   - a null handle or Python None yields an empty vector;
//   - a list or tuple is converted element by element;
//   - anything else is treated as a single VarBase.
// Elements that cannot be cast make PyObjectCast throw InvalidArgument.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  using VarBasePtr = std::shared_ptr<imperative::VarBase>;
  std::vector<VarBasePtr> result;

  PyObject *py_obj = handle.ptr();  // get underlying PyObject
  // Python None is not nullptr in C++!
  if (py_obj == nullptr || py_obj == Py_None) {
    return result;
  }

  // Shared by the list and tuple branches: validate and append one element.
  auto append_item = [&result](PyObject *py_ivar) {
    PADDLE_ENFORCE_NOT_NULL(
        py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
    result.emplace_back(PyObjectCast<VarBasePtr>(py_ivar));
  };

  if (PyList_Check(py_obj)) {  // List of VarBase
    const size_t count = PyList_GET_SIZE(py_obj);
    result.reserve(count);
    for (size_t idx = 0; idx < count; ++idx) {
      append_item(PyList_GET_ITEM(py_obj, idx));
    }
  } else if (PyTuple_Check(py_obj)) {  // Tuple of VarBase
    const size_t count = PyTuple_GET_SIZE(py_obj);
    result.reserve(count);
    for (size_t idx = 0; idx < count; ++idx) {
      append_item(PyTuple_GET_ITEM(py_obj, idx));
    }
  } else {  // VarBase
    result.emplace_back(PyObjectCast<VarBasePtr>(py_obj));
  }

  return result;
}

J
Jiabin Yang 已提交
264 265 266
// Builds an imperative::NameVarBaseMap from a name -> py::handle map.
// Entries whose handle converts to an empty VarBase list are skipped. After
// the conversion, a pending Python error is reported as InvalidArgument.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (auto it = map.begin(); it != map.end(); ++it) {
    auto vars = GetVarBaseListFromPyHandle(it->second);
    if (vars.empty()) {
      continue;
    }
    result.emplace(it->first, std::move(vars));
  }

  PADDLE_ENFORCE_EQ(
      PyErr_Occurred(), nullptr,
      platform::errors::InvalidArgument(py::str(py::handle(PyErr_Occurred()))));
  return result;
}

280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344
// Returns true when `obj` is a Python integer (int, plus long on Python 2)
// and is not a bool.
static bool PyCheckInteger(PyObject *obj) {
  // bool passes the integer checks below, so rule it out first.
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyLong_Check(obj) || PyInt_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}

// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// Original PySlice_GetIndices return wrong result when
// slice_item contains long int, such as arr[:180L].
// NOT sure why this happens !!!
// Besides, PySlice_GetIndices cannot raise error when float in slice item.
// So, I make a revised version of PySlice_GetIndices, named to
// _PySlice_GetIndices. Try to use _PySlice_Unpack which is more robust than
// PySlice_GetIndices in the future.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
                               Py_ssize_t *start, Py_ssize_t *stop,
                               Py_ssize_t *step) {
  /* XXX support long ints */
  if (r->step == Py_None) {
    *step = 1;
  } else {
    if (PyCheckInteger(r->step)) {
      *step = PyLong_AsLong(r->step);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->step)->tp_name)));
    }
  }
  if (r->start == Py_None) {
    *start = *step < 0 ? length - 1 : 0;
  } else {
    if (PyCheckInteger(r->start)) {
      *start = PyLong_AsLong(r->start);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->start)->tp_name)));
    }
    if (*start < 0) *start += length;
  }
  if (r->stop == Py_None) {
    *stop = *step < 0 ? -1 : length;
  } else {
    if (PyCheckInteger(r->stop)) {
      *stop = PyLong_AsLong(r->stop);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(r->stop)->tp_name)));
    }
    if (*stop < 0) *stop += length;
  }
  if (*stop > length) return -1;
  if (*start >= length) return -1;
  if (*step == 0) return -1;
  return 0;
}

S
songyouwei 已提交
345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368
// Translates a Python index expression into slice-op attributes.
//
//   tensor        - tensor being indexed; must be initialized.
//   _index        - an integer, a slice, or a tuple of those.
//   slice_axes    - receives the axes that are actually sliced.
//   slice_starts  - receives the start of each sliced axis.
//   slice_ends    - receives the (exclusive) end of each sliced axis.
//   slice_strides - receives the step of each sliced axis.
//   decrease_axis - receives the axes indexed by a plain integer.
//   infer_flags   - receives one `1` per index item.
//
// Throws InvalidArgument for an uninitialized tensor, too many indices or an
// unsupported item type, and py::index_error (Python IndexError) when an
// integer index is outside [-dim_len, dim_len).
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                               std::vector<int> *slice_axes,
                               std::vector<int> *slice_starts,
                               std::vector<int> *slice_ends,
                               std::vector<int> *slice_strides,
                               std::vector<int> *decrease_axis,
                               std::vector<int> *infer_flags) {
  // We allow indexing by Integers, Slices, and tuples of those
  // types.
  // Ellipsis and None are not supported yet.

  // Wrap a non-tuple index into a 1-tuple. The temporary tuple is owned by a
  // py::object so it is released on every exit path, including the
  // exceptions thrown below.
  py::object wrapped_index;
  PyObject *index = _index;
  if (!PyTuple_Check(_index)) {
    wrapped_index = py::reinterpret_steal<py::object>(PyTuple_Pack(1, _index));
    if (!wrapped_index) {
      // PyTuple_Pack failed and left a Python error set.
      throw py::error_already_set();
    }
    index = wrapped_index.ptr();
  }

  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));
  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = PyTuple_GET_SIZE(index);
  PADDLE_ENFORCE_EQ(
      size <= rank, true,
      platform::errors::InvalidArgument(
          "too many indices (%d) for tensor of dimension %d", size, rank));
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index, dim);
    PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
                      true,
                      platform::errors::InvalidArgument(
                          "Currently, VarBase.__getitem__() only allows "
                          "indexing by Integers, Slices, and tuples of "
                          "these types, but received %s in %dth slice item",
                          std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
    infer_flags->push_back(1);
    int dim_len = shape[dim];
    if (PyCheckInteger(slice_item)) {
      // integer, PyLong_AsLong supports both int and long
      const long raw_index = PyLong_AsLong(slice_item);  // NOLINT
      if (raw_index == -1 && PyErr_Occurred()) {
        // The Python integer does not fit into a C long.
        throw py::error_already_set();
      }
      // Bounds are checked at full `long` width, before narrowing to int, so
      // that huge values cannot wrap into the valid range.
      const long normalized =  // NOLINT
          raw_index < 0 ? raw_index + dim_len : raw_index;
      if (normalized < 0 || normalized >= dim_len) {
        std::string str_error_message =
            "The starting index " + std::to_string(raw_index) +
            " of slice is out of bounds in tensor " + std::to_string(dim) +
            "-th axis, it shound be in the range of [" +
            std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
        // py::index_error is corresponding to IndexError in Python
        // Used to indicate out of bounds access in __getitem__, __setitem__
        throw py::index_error(str_error_message);
      }
      const int start = static_cast<int>(normalized);
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(start + 1);
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
    } else {
      // slice item
      Py_ssize_t start, end, step;
      PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
      // NOTE(review): the status returned by _PySlice_GetIndices is not
      // checked here; the resolved values are used as-is.
      _PySlice_GetIndices(p, dim_len, &start, &end, &step);

      // :: or : or 0:dim_len:1
      if (start == 0 && end == dim_len && step == 1) {
        continue;
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(end);
      slice_strides->push_back(step);
    }
  }
}

417
// Bind Methods
J
Jiabin Yang 已提交
418
void BindImperative(py::module *m_ptr) {
419 420
  auto &m = *m_ptr;

421 422
  BindOpFunctions(&m);

423 424
#ifndef _WIN32
  // Dygraph DataLoader signal handler
425 426 427 428 429 430 431 432 433 434 435 436 437
  m.def("_set_process_pids", [](int64_t key, py::object &obj) {
    PADDLE_ENFORCE_EQ(
        py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj), true,
        platform::errors::InvalidArgument(
            "The subprocess ids set in DataLoader is illegal."
            "Expected data type is tuple or list, but received %s",
            obj.get_type()));
    py::list pids = py::cast<py::list>(obj);
    std::set<pid_t> pids_set = {};
    for (size_t i = 0; i < pids.size(); i++) {
      pids_set.insert(pids[i].cast<pid_t>());
    }
    imperative::SetLoadProcessPIDs(key, pids_set);
438
  });
439 440
  m.def("_erase_process_pids",
        [](int64_t key) { imperative::EraseLoadProcessPIDs(key); });
441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512
  m.def("_set_process_signal_handler",
        []() { imperative::SetLoadProcessSignalHandler(); });
  m.def("_throw_error_if_process_failed",
        []() { imperative::ThrowErrorIfLoadProcessFailed(); });

  // Dygraph DataLoader reader process & thread related functions
  m.def(
      "_convert_to_tensor_list",
      [](py::object &obj) -> py::list {
        // 0. input data check
        PADDLE_ENFORCE(
            py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj),
            platform::errors::InvalidArgument(
                "The batch data read into DataLoader is illegal."
                "Expected data type is tuple or list, but received %s",
                obj.get_type()));
        py::list batch = py::cast<py::list>(obj);
        py::list tensors;
        for (size_t i = 0; i < batch.size(); ++i) {
          // 1. cast to python array
          auto array = batch[i].cast<py::array>();
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Faild to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data causes this issue."));
          // 2. construcct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);
          // 6. append to result list
          tensors.append(t);
        }
        return tensors;
      },
      py::return_value_policy::take_ownership);

  m.def("_remove_tensor_list_mmap_fds", [](py::list &tensor_list) {
    for (size_t i = 0; i < tensor_list.size(); ++i) {
      auto t = tensor_list[i].cast<framework::LoDTensor>();
      auto *mmap_writer_allocation =
          dynamic_cast<memory::allocation::MemoryMapWriterAllocation *>(
              t.Holder().get());
      PADDLE_ENFORCE_NOT_NULL(
          mmap_writer_allocation,
          platform::errors::NotFound("The shared memory of LoDTensor in "
                                     "DataLoader's child process has been "
                                     "released."));
      memory::allocation::MemoryMapFdSet::Instance().Remove(
          mmap_writer_allocation->ipc_name());
    }
  });

  m.def("_cleanup_mmap_fds",
        []() { memory::allocation::MemoryMapFdSet::Instance().Clear(); });
#endif

513 514 515 516 517
  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
518 519 520
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
521 522 523 524
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &tracer) {
          imperative::SetCurrentTracer(tracer);
        });
Z
Zeng Jinle 已提交
525

526
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
527
      m, "VarBase", R"DOC()DOC")
Z
Zeng Jinle 已提交
528
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
529 530 531 532 533 534 535
      .def("__init__",
           [](imperative::VarBase &self) {
             std::string name =
                 imperative::GetCurrentTracer()->GenerateUniqueName(
                     "generated_tensor");
             new (&self) imperative::VarBase(name);
           })
J
Jiabin Yang 已提交
536
      .def("__init__",
537 538 539
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
540
             VLOG(4) << "Init VarBase";
541 542 543
             std::string act_name = "";
             if (!name.ptr() || name.ptr() == Py_None) {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
544
                   "generated_tensor");
545 546 547 548
             } else {
               act_name = name.cast<std::string>();
             }
             new (&self) imperative::VarBase(act_name);
J
Jiabin Yang 已提交
549 550 551 552 553 554 555 556 557
             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type == framework::proto::VarType::LOD_TENSOR) {
               auto *tensor =
                   self.MutableVar()->GetMutable<framework::LoDTensor>();
               tensor->Resize(framework::make_ddim(dims));
             }
           })
558 559
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
560 561
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
562 563 564 565
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
566 567
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
568 569
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
570 571
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
572 573
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
L
Leo Chen 已提交
574
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
575
      .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
576
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602
      .def("__setitem__",
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index,
              py::object &value_obj) {
             auto self_tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             auto self_numpy = TensorToPyArray(*self_tensor);

             if (py::isinstance<py::array>(value_obj) ||
                 py::isinstance<py::int_>(value_obj) ||
                 py::isinstance<py::float_>(value_obj)) {
               auto value_numpy = value_obj;
               self_numpy[_index] = value_numpy;
               SetTensorFromPyArray(self_tensor, self_numpy,
                                    self_tensor->place(), true);

             } else {
               auto value =
                   value_obj.cast<std::shared_ptr<imperative::VarBase>>();
               auto value_tensor =
                   value->MutableVar()->GetMutable<framework::LoDTensor>();
               auto value_numpy = TensorToPyArray(*value_tensor);

               self_numpy[_index] = value_numpy;
               SetTensorFromPyArray(self_tensor, self_numpy,
                                    self_tensor->place(), true);
             }
603 604 605 606
             // NOTE(liym27):
             // Increase the version of VarBase self because __setitem__ is an
             // inplace operator for the VarBase self.
             self->BumpInplaceVersion();
607
           })
608
      .def("__getitem__",
S
songyouwei 已提交
609
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
610
             std::vector<int> slice_axes, slice_starts, slice_ends,
S
songyouwei 已提交
611 612 613 614 615 616
                 slice_strides, decrease_axis, infer_flags;
             auto tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                &slice_starts, &slice_ends, &slice_strides,
                                &decrease_axis, &infer_flags);
617 618 619 620
             // release gil and do tracing
             py::gil_scoped_release release;
             const auto &tracer = imperative::GetCurrentTracer();
             if (slice_axes.empty()) {
S
songyouwei 已提交
621
               return self;
622
             } else {
S
songyouwei 已提交
623
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
               auto out = std::shared_ptr<imperative::VarBase>(
                   new imperative::VarBase(tracer->GenerateUniqueName()));
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               for (auto stride : slice_strides) {
                 if (stride != 1) {
                   op_type = "strided_slice";
                   attrs.insert({"strides", slice_strides});
                   attrs.erase("decrease_axis");
                   break;
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
               return out;
             }
           })
646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667
      .def("_inplace_version",
           [](imperative::VarBase &self) -> uint32_t {
             const auto &var = self.MutableVar();
             PADDLE_ENFORCE_EQ(
                 var->IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             return var->CurrentInplaceVersion();
           })
      .def("_bump_inplace_version",
           [](std::shared_ptr<imperative::VarBase> &self) {
             // NOTE(liym27): _bump_inplace_version is only used for inplace
             // operation
             self->BumpInplaceVersion();
           },
           R"DOC(
        **Notes**:
            **This API is ONLY available in Dygraph mode.**
            **This is a very low level API. Users should not use it directly. **
         Bump the version whenever the Tensor is modified through an inplace operation.
            )DOC")
668 669 670 671 672 673 674
      // Converts the Tensor's LoDTensor payload to a numpy array via
      // TensorToPyArray. The tensor must be initialized; otherwise an
      // InvalidArgument error is reported.
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "Tensor of %s is Empty, please check if it has no data.",
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        Returns a numpy array that shows the value of the current Tensor.

        Returns:
            ndarray: The numpy value of current Tensor.

        Returns type:
            ndarray: dtype is same as current Tensor

        Examples:
            .. code-block:: python

                import paddle
                import numpy as np
                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                linear = paddle.nn.Linear(32, 64)
                data = paddle.to_tensor(data)
                x = linear(data)
                print(x.numpy())
       )DOC")
699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736
      // Creates a new VarBase named "detach_<name>" that shares the source
      // Variable's placeholder. The source must be initialized and hold
      // either a LoDTensor or SelectedRows.
      .def(
          "detach",
          [](const imperative::VarBase &self)
              -> std::shared_ptr<imperative::VarBase> {
                PADDLE_ENFORCE_EQ(
                    self.Var().IsInitialized(), true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized!", self.Name()));

                PADDLE_ENFORCE_EQ(
                    self.Var().IsType<framework::LoDTensor>() ||
                        self.Var().IsType<framework::SelectedRows>(),
                    true,
                    platform::errors::InvalidArgument(
                        "Type of Tensor[%s] must be LoDTensor or SelectedRows!",
                        self.Name()));

                auto detach_var = std::make_shared<imperative::VarBase>(
                    true, "detach_" + self.Name());

                // Carry over the metadata of the source Tensor.
                detach_var->SetPersistable(self.Persistable());
                detach_var->SetType(self.Type());
                detach_var->SetDataType(self.DataType());

                // NOTE(liym27):
                // Call Variable::SharePlaceholderWith but not
                // Tensor::ShareDataWith or Tensor::ShareBufferWith, because
                // `detach_var` should share the same TensorInplaceVersion with
                // `self`, and only SharePlaceholderWith can also share the same
                // TensorInplaceVersion, which is used to check whether inplace
                // operations are correct.
                detach_var->MutableVar()->SharePlaceholderWith(self.Var());

                VLOG(3) << "The detached Tensor(" << detach_var->Name()
                        << ") share data with " << self.Name();
                return detach_var;
              },
          py::return_value_policy::take_ownership, R"DOC(

        Returns a new Tensor, detached from the current graph.
        It shares data with the original Tensor and never makes a Tensor copy.
        In addition, the detached Tensor doesn't provide gradient propagation.

        Returns: The detached Tensor.

        Examples:
            .. code-block:: python

                import paddle

                x = paddle.to_tensor(1.0, stop_gradient=False)
                detach_x = x.detach()
                detach_x[:] = 10.0
                print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                          #        [10.])
                y = x**2
                y.backward()
                print(x.grad)         # [20.0]
                print(detach_x.grad)  # None, 'stop_gradient=True' by default

                detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
                z = detach_x**3
                z.backward()

                print(x.grad)         # [20.0], detach_x is detached from x's graph, not affect each other
                print(detach_x.grad)  # [300.0], detach_x has its own graph

                # Due to sharing of data with origin Tensor, There are some unsafe operations:
                y = 2 * x
                detach_x[:] = 5.0
                y.backward()
                # It will raise Error:
                #   one of the variables needed for gradient computation has been modified by an inplace operation.

       )DOC")
      // Binds VarBase::ClearGradient directly; the docstring below is what
      // Python users see.
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

        Only for Tensors that have a gradient. Normally this is used for Parameters, since other temporary Tensors don't have a gradient.

        The Gradient of current Tensor will be set to ``0`` .

        Returns:  None

        Examples:
             .. code-block:: python

                import paddle
                input = paddle.uniform([10, 2])
                linear = paddle.nn.Linear(2, 3)
                out = linear(input)
                out.backward()
                print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
                linear.weight.clear_gradient()
                print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
      )DOC")
Z
Zhou Wei 已提交
794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841
      // Traces an "assign" op from `self` into a freshly named VarBase and
      // returns that new VarBase. The source LoDTensor must be initialized.
      .def("clone",
           [](std::shared_ptr<imperative::VarBase> &self) {
             const auto &tensor = self->Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
                     "%s has not been initialized", self->Name()));
             auto tracer = imperative::GetCurrentTracer();
             auto new_var = std::make_shared<imperative::VarBase>(
                 true, tracer->GenerateUniqueName(self->Name() + "_clone"));
             framework::AttributeMap attrs;
             imperative::NameVarBaseMap ins = {{"X", {self}}};
             imperative::NameVarBaseMap outs = {{"Out", {new_var}}};
             // `attrs` is not used after this call, so hand it over instead
             // of copying it.
             tracer->TraceOp("assign", ins, outs, std::move(attrs));
             return new_var;
           },
           py::return_value_policy::copy, R"DOC(

        Returns a new Tensor, which is clone of origin Tensor, and it remains in the current graph.
        It will always have a Tensor copy.
        In addition, the cloned Tensor provides gradient propagation.

        Returns: The cloned Tensor.

        Examples:
            .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.0, stop_gradient=False)
              clone_x = x.clone()
              y = clone_x**2
              y.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [2.0], support gradient propagation
              print(x.stop_gradient)       # False
              print(x.grad)                # [2.0], clone_x support gradient propagation for x

              x = paddle.to_tensor(1.0)
              clone_x = x.clone()
              clone_x.stop_gradient = False
              z = clone_x**3
              z.backward()
              print(clone_x.stop_gradient) # False
              print(clone_x.grad)          # [3.0], support gradient propagation
              print(x.stop_gradient) # True
              print(x.grad)          # None
       )DOC")
L
Leo Chen 已提交
842
      // Runs backward from this Tensor using the engine owned by `tracer`.
      // The call guard releases the GIL for the whole call.
      .def("_run_backward",
           [](imperative::VarBase &self, const imperative::Tracer &tracer,
              bool retain_graph) {
             // TODO(jiabin): when we impl more backward execution we can
             // select them
             auto *backward_engine = tracer.GetEngine();
             backward_engine->Init(&self, retain_graph);
             VLOG(3) << "Start backward";
             backward_engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      // Returns the LoDTensor held by this Tensor's gradient Variable. The
      // lambda declares no return type, so the deduced return type is a
      // LoDTensor value rather than a reference.
      // NOTE(review): because the result is a temporary, the `reference`
      // policy below presumably does not apply as written -- confirm against
      // the pybind11 return-value-policy documentation.
      .def("_grad_value",
           [](imperative::VarBase &self) {
             return self.MutableGradVar()->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
860 861 862 863
      // Sets the variable type on this Tensor's gradient VarBase.
      .def("_set_grad_type",
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             self.MutableGradVarBase()->SetType(type);
           })
864
      // Returns the gradient VarBase only when it exists, its Variable is
      // initialized and its tensor payload is initialized; every other case
      // yields a null shared_ptr.
      .def("_grad_ivar",
           [](const imperative::VarBase &self) {
             auto &grad = self.GradVarBase();
             if (!grad || !grad->Var().IsInitialized()) {
               return std::shared_ptr<imperative::VarBase>(nullptr);
             }
             const auto &holder = grad->MutableVar();
             // A LoDTensor is inspected directly; otherwise the holder is
             // read as SelectedRows and its value tensor is inspected.
             auto *payload =
                 holder->IsType<framework::LoDTensor>()
                     ? holder->GetMutable<framework::LoDTensor>()
                     : holder->GetMutable<framework::SelectedRows>()
                           ->mutable_value();
             if (!payload->IsInitialized()) {
               return std::shared_ptr<imperative::VarBase>(nullptr);
             }
             return grad;
           },
           py::return_value_policy::copy)
882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911
      // True when the underlying Variable holds SelectedRows.
      .def("_is_sparse",
           [](imperative::VarBase &self) {
             const auto &holder = self.Var();
             return holder.IsType<framework::SelectedRows>();
           })
      // All-reduces this Tensor in place (same Variable as source and
      // destination) when the strategy has more than one rank; otherwise it
      // does nothing. The call guard releases the GIL for the whole call.
      .def("_allreduce",
           [](imperative::VarBase &self,
              const imperative::ParallelStrategy &strategy) {
             if (strategy.nranks_ > 1) {
#ifdef PADDLE_WITH_NCCL
#if NCCL_VERSION_CODE >= 2212
               imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
               // With an older NCCL only non-SelectedRows Variables are
               // reduced; SelectedRows is rejected.
               if (!self.Var().IsType<framework::SelectedRows>()) {
                 imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
               } else {
                 PADDLE_THROW(platform::errors::Unimplemented(
                     "Imperative SelectedRows allreduce is not supported when "
                     "paddle is compiled with NCCL version lower than v2.2.12. "
                     "You can set is_sparse=False for the Layer containing "
                     "this argument, such as Embedding(is_sparse=False)."));
               }
#endif  // NCCL_VERSION_CODE
#else
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative allreduce is not supported when paddle is "
                   "not compiled with NCCL."));
#endif  // PADDLE_WITH_NCCL
             }
           },
           py::call_guard<py::gil_scoped_release>())
912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032
      .def("cpu",
           [](const std::shared_ptr<imperative::VarBase> &self) {
             // Already on the CPU: hand back the same Tensor.
             if (platform::is_cpu_place(self->Place())) {
               return self;
             }
             // Otherwise build a copy on CPUPlace and carry over the
             // stop_gradient setting.
             auto cpu_copy = self->NewVarBase(platform::CPUPlace(), true);
             cpu_copy->SetOverridedStopGradient(self->OverridedStopGradient());
             return cpu_copy;
           },
           R"DOC(
        Returns a copy of this Tensor in CPU memory.

        If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)    # CUDAPlace(0)
              
              y = x.cpu()
              print(y.place)    # CPUPlace

              )DOC")
      .def("pin_memory",
           [](const std::shared_ptr<imperative::VarBase> &self) {
#ifndef PADDLE_WITH_CUDA
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to pinned memory in CPU version "
                 "Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#endif
             // Already in pinned memory: hand back the same Tensor.
             if (platform::is_cuda_pinned_place(self->Place())) {
               return self;
             }
             // Otherwise build a copy on CUDAPinnedPlace and carry over the
             // stop_gradient setting.
             auto pinned = self->NewVarBase(platform::CUDAPinnedPlace(), true);
             pinned->SetOverridedStopGradient(self->OverridedStopGradient());
             return pinned;
           },
           R"DOC(
        Returns a copy of this Tensor in pin memory.

        If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
              print(x.place)      # CUDAPlace(0)

              y = x.pin_memory()
              print(y.place)      # CUDAPinnedPlace

      )DOC")
      .def("cuda",
           [](const std::shared_ptr<imperative::VarBase> &self, int device_id,
              bool blocking) {
#ifndef PADDLE_WITH_CUDA
             PADDLE_THROW(platform::errors::PermissionDenied(
                 "Cannot copy this Tensor to GPU in CPU version Paddle, "
                 "Please recompile or reinstall Paddle with CUDA support."));
#else
             int device_count = platform::GetCUDADeviceCount();
             // device_id == -1 is the "not specified" default: a Tensor
             // already on a GPU is returned unchanged, anything else is
             // sent to GPU 0.
             if (device_id == -1) {
               if (platform::is_gpu_place(self->Place())) {
                 return self;
               } else {
                 device_id = 0;
               }
             }
             // The target id must lie inside [0, device_count).
             PADDLE_ENFORCE_GE(
                 device_id, 0,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             PADDLE_ENFORCE_LT(
                 device_id, device_count,
                 platform::errors::InvalidArgument(
                     "Can not copy Tensor to Invalid CUDAPlace(%d), device id "
                     "must inside [0, %d)",
                     device_id, device_count));
             platform::CUDAPlace place = platform::CUDAPlace(device_id);
             if (platform::is_same_place(self->Place(), place)) {
               return self;
             } else {
               auto new_var = self->NewVarBase(place, blocking);
               new_var->SetOverridedStopGradient(self->OverridedStopGradient());
               return new_var;
             }
#endif
           },
           py::arg("device_id") = -1, py::arg("blocking") = true, R"DOC(
        Returns a copy of this Tensor in GPU memory.

        If this Tensor is already in GPU memory and device_id is default,
        then no copy is performed and the original Tensor is returned.

        Args:
            device_id(int, optional): The destination GPU device id. Default: -1, which returns a Tensor
              that is already in GPU memory unchanged and copies any other Tensor to GPU 0.
            blocking(bool, optional): If False and the source is in pinned memory, the copy will be
              asynchronous with respect to the host. Otherwise, the argument has no effect. Default: True.

        Examples:
            .. code-block:: python

              import paddle
              x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
              print(x.place)        # CPUPlace

              y = x.cuda()
              print(y.place)        # CUDAPlace(0)

              y = x.cuda(1)
              print(y.place)        # CUDAPlace(1)
       )DOC")
1033
      .def("copy_", &imperative::VarBase::CopyFrom)
1034 1035
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
J
Jiabin Yang 已提交
1036 1037
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
1038 1039 1040 1041 1042
      .def("_copy_to",
           [](const imperative::VarBase &self,
              const platform::CUDAPinnedPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
1043 1044 1045 1046
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::XPUPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
1047 1048
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
J
Jiabin Yang 已提交
1049 1050 1051
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
1052 1053 1054
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
L
Leo Chen 已提交
1055 1056 1057 1058 1059
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
J
Jiabin Yang 已提交
1060 1061 1062 1063
      // Read-only `shape`: the dims of the held LoDTensor, or of the value
      // tensor of the held SelectedRows; an empty vector (plus a log line)
      // for any other variable type.
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            const auto &holder = self.Var();
            if (holder.IsType<framework::LoDTensor>()) {
              const auto &dense = holder.Get<framework::LoDTensor>();
              return framework::vectorize<int>(dense.dims());
            }
            if (holder.IsType<framework::SelectedRows>()) {
              const auto &rows = holder.Get<framework::SelectedRows>();
              return framework::vectorize<int>(rows.value().dims());
            }
            VLOG(2) << "It is meaningless to get shape of "
                       "variable type "
                    << GetTypeName(self);
            return std::vector<int>();
          })
1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104
      .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
                             R"DOC(
      Whether a Tensor is leaf Tensor.

      For the Tensor whose stop_gradient is ``True`` , it will be leaf Tensor. 
      
      For the Tensor whose stop_gradient is ``False`` , it will be leaf Tensor too if it is created by user.

      Returns:
          bool: Whether a Tensor is leaf Tensor.

      Examples:
          .. code-block:: python

              import paddle

              x = paddle.to_tensor(1.)
              print(x.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=True)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # True

              x = paddle.to_tensor(1., stop_gradient=False)
              y = x + 1
              print(x.is_leaf) # True
              print(y.is_leaf) # False
       )DOC")
1105 1106 1107
      .def_property_readonly(
          "place", [](imperative::VarBase &self) { return self.Place(); },
          py::return_value_policy::copy)
1108 1109 1110 1111 1112 1113
      // Read-only `_place_str`: the Tensor's place streamed into a string.
      .def_property_readonly("_place_str",
                             [](imperative::VarBase &self) {
                               std::ostringstream place_repr;
                               place_repr << self.Place();
                               return place_repr.str();
                             })
J
Jiabin Yang 已提交
1114
      .def_property_readonly("type", &imperative::VarBase::Type)
L
Leo Chen 已提交
1115
      .def_property_readonly("dtype", &imperative::VarBase::DataType);
1116 1117 1118

  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>())
1119 1120 1121 1122 1123
      .def("forward",
           [](imperative::Layer &self,
              const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
             return self.Forward(inputs);
           });
1124

1125 1126 1127 1128 1129
  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

1130
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
1131
      m, "Tracer", R"DOC()DOC")
1132
      // Default-construct the Tracer through pybind11's constructor binding
      // instead of a hand-written placement-new `__init__`, which pybind11
      // has deprecated.
      .def(py::init<>())
1134 1135 1136
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
1137 1138
      .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
                    &imperative::Tracer::SetEnableAutoCast)
1139
      .def_property("_has_grad", &imperative::Tracer::HasGrad,
1140
                    &imperative::Tracer::SetHasGrad)
1141 1142 1143 1144 1145 1146 1147 1148
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
L
Leo Chen 已提交
1149
              self.SetExpectedPlace(*p);
1150 1151 1152
            } else if (py::isinstance<platform::XPUPlace>(obj)) {
              auto p = obj.cast<platform::XPUPlace *>();
              self.SetExpectedPlace(*p);
1153 1154
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
L
Leo Chen 已提交
1155
              self.SetExpectedPlace(*p);
1156 1157
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
L
Leo Chen 已提交
1158
              self.SetExpectedPlace(*p);
1159
            } else {
L
Leo Chen 已提交
1160
              PADDLE_THROW(platform::errors::InvalidArgument(
1161 1162
                  "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
                  "CPUPlace, "
L
Leo Chen 已提交
1163 1164
                  "and CUDAPinnedPlace, "
                  "but got Unknown Type!"));
1165 1166
            }
          })
1167 1168 1169
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
1170
      .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
1171
           py::arg("key") = "dygraph_tmp")
1172 1173 1174 1175 1176
      .def(
          "_set_amp_op_list",
          [](imperative::Tracer &self,
             std::unordered_set<std::string> &allow_ops,
             std::unordered_set<std::string> &block_ops) {
1177 1178
            // NOTE(zhiqiu): The automatic conversion in pybind11 between C++
            // STL and python set/list/dict involves a copy operation that
            // prevents pass-by-reference semantics, so it is ok to swap.
            // The reason for not directly passing
            // std::shared_ptr<std::unordered_set<std::string>>
            // is that pybind11 forbids shared_ptr<T> where T is not a custom
            // type.
1185 1186 1187 1188 1189 1190 1191 1192 1193
            imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
            imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
          })
      .def("_get_amp_op_list",
           [](imperative::Tracer &self) {
             return std::make_tuple(
                 *(imperative::AmpOperators::Instance().GetAllowOps()),
                 *(imperative::AmpOperators::Instance().GetBlockOps()));
           })
1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206
      // `trace` overload taking an XPUPlace. The Python-side input and output
      // maps are converted first; the GIL is released only for the scope that
      // calls TraceOp.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::XPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
M
minqiyang 已提交
1207
      // `trace` overload taking a CUDAPlace.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             // Convert the Python-side maps before the GIL is released.
             auto input_vars = ConvertToNameVarBaseMap(ins);
             auto output_vars = ConvertToNameVarBaseMap(outs);
             {
               // The GIL stays released only for this scope.
               py::gil_scoped_release no_gil;
               self.TraceOp(type, std::move(input_vars),
                            std::move(output_vars), std::move(attrs), place,
                            trace_backward);
             }
           })
J
Jiabin Yang 已提交
1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232
      // `trace` overload taking a CPUPlace. The Python-side input and output
      // maps are converted first; the GIL is released only for the scope that
      // calls TraceOp.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });
1233 1234

  // define parallel context
1235 1236 1237
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
1238 1239
      .def_property(
          "nranks",
1240 1241
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
1242 1243 1244
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
1245
                    [](const imperative::ParallelStrategy &self) {
1246 1247
                      return self.local_rank_;
                    },
1248
                    [](imperative::ParallelStrategy &self, int local_rank) {
1249 1250 1251 1252
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
1253
          [](const imperative::ParallelStrategy &self) {
1254 1255
            return self.trainer_endpoints_;
          },
1256
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
1257 1258 1259
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
1260
                    [](const imperative::ParallelStrategy &self) {
1261 1262
                      return self.current_endpoint_;
                    },
1263
                    [](imperative::ParallelStrategy &self,
1264 1265 1266 1267 1268 1269 1270
                       const std::string &ep) { self.current_endpoint_ = ep; })
      .def_property(
          "nrings",
          [](const imperative::ParallelStrategy &self) { return self.nrings_; },
          [](imperative::ParallelStrategy &self, int nrings) {
            self.nrings_ = nrings;
          });
1271 1272 1273 1274 1275 1276 1277 1278

  // Builds a PartialGradEngine from the given targets and flags, executes it
  // and returns its result. The call guard releases the GIL for the whole
  // call.
  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>>
             &output_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &output_grads,
         const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
         const platform::Place &place, bool create_graph, bool retain_graph,
         bool allow_unused, bool only_inputs) {
        imperative::PartialGradEngine grad_engine(
            input_targets, output_targets, output_grads, no_grad_vars, place,
            create_graph, retain_graph, allow_unused, only_inputs);
        grad_engine.Execute();
        return grad_engine.GetResult();
      },
      py::call_guard<py::gil_scoped_release>());

1289
#if defined(PADDLE_WITH_NCCL)
1290 1291 1292 1293 1294 1295
  py::class_<imperative::ParallelContext,
             std::shared_ptr<imperative::ParallelContext>>(m,
                                                           "ParallelContext");
  py::class_<imperative::NCCLParallelContext, imperative::ParallelContext,
             std::shared_ptr<imperative::NCCLParallelContext>>(
      m, "NCCLParallelContext")
1296 1297 1298
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &self) { self.Init(); });
1299 1300 1301 1302 1303 1304 1305

  py::class_<imperative::Reducer, std::shared_ptr<imperative::Reducer>>(
      m, "Reducer", R"DOC()DOC")
      .def(py::init(
          [](const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
             const std::vector<std::vector<size_t>> &group_indices,
             const std::vector<bool> &is_sparse_gradient,
1306 1307
             std::shared_ptr<imperative::ParallelContext> parallel_ctx,
             const std::vector<size_t> &group_size_limits) {
1308
            return imperative::Reducer::SetInstance(
1309 1310
                vars, group_indices, is_sparse_gradient, parallel_ctx,
                group_size_limits);
1311 1312 1313 1314 1315 1316 1317
          }))
      .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward,
           py::call_guard<py::gil_scoped_release>());

  m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"),
        py::arg("is_sparse_gradient"),
        py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
1318
        py::arg("tensor_indices") = std::vector<int64_t>{},
1319
        py::call_guard<py::gil_scoped_release>());
1320
#endif
1321 1322 1323 1324
}

}  // namespace pybind
}  // namespace paddle