imperative.cc 41.2 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/pybind/imperative.h"
16

17
#include <Python.h>
18 19 20 21
#include <pybind11/chrono.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h>
22

23
#include <memory>
24
#include <set>
J
Jiabin Yang 已提交
25
#include <string>
26
#include <unordered_map>
27
#include <unordered_set>
28
#include <utility>
J
Jiabin Yang 已提交
29
#include <vector>
30

31
#include "paddle/fluid/imperative/all_reduce.h"
32
#include "paddle/fluid/imperative/amp_auto_cast.h"
33
#include "paddle/fluid/imperative/basic_engine.h"
34
#include "paddle/fluid/imperative/data_loader.h"
35
#include "paddle/fluid/imperative/layer.h"
J
Jiabin Yang 已提交
36
#include "paddle/fluid/imperative/nccl_context.h"
37
#include "paddle/fluid/imperative/partial_grad_engine.h"
38
#include "paddle/fluid/imperative/profiler.h"
39
#include "paddle/fluid/imperative/tracer.h"
M
minqiyang 已提交
40
#include "paddle/fluid/imperative/type_defs.h"
41
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
42
#include "paddle/fluid/pybind/op_function.h"
43
#include "paddle/fluid/pybind/pybind_boost_headers.h"
L
Leo Chen 已提交
44
#include "paddle/fluid/pybind/tensor_py.h"
45

46 47 48
namespace paddle {
namespace pybind {

49 50
namespace py = ::pybind11;

51 52 53 54
class Layer : public imperative::Layer {
 public:
  using imperative::Layer::Layer;  // Inherit constructors

55 56 57 58
  std::vector<std::shared_ptr<imperative::VarBase>> Forward(
      const std::vector<std::shared_ptr<imperative::VarBase>> &inputs)
      override {
    PYBIND11_OVERLOAD(std::vector<std::shared_ptr<imperative::VarBase>>, Layer,
J
Jiabin Yang 已提交
59
                      Forward, inputs);  // NOLINT
60 61 62
  }
};

L
Leo Chen 已提交
63 64 65 66 67
// Converts a Python place object into a generic platform::Place.
// Accepted object types: CPUPlace, CUDAPlace, XPUPlace and CUDAPinnedPlace.
// Any other object raises InvalidArgument.
// Returned by plain value: a top-level `const` on a by-value return type has
// no benefit and only prevents the caller from moving the result.
static platform::Place PyObjectToPlace(const py::object &place_obj) {
  if (py::isinstance<platform::CPUPlace>(place_obj)) {
    return place_obj.cast<platform::CPUPlace>();
  } else if (py::isinstance<platform::CUDAPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPlace>();
  } else if (py::isinstance<platform::XPUPlace>(place_obj)) {
    return place_obj.cast<platform::XPUPlace>();
  } else if (py::isinstance<platform::CUDAPinnedPlace>(place_obj)) {
    return place_obj.cast<platform::CUDAPinnedPlace>();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }
}

// Placement-constructs `self` as a LOD_TENSOR VarBase and fills its tensor
// from `array` on `place`.
//   persistable   - forwarded to SetPersistable.
//   zero_copy     - forwarded to SetTensorFromPyArray.
//   name          - variable name; an empty string means "generate a unique
//                   name with prefix generated_tensor" via the current tracer.
//   stop_gradient - -1 leaves the VarBase default untouched; any other value
//                   is passed to SetOverridedStopGradient.
// Throws InvalidArgument if `place` is not a CPU/XPU/CUDA/CUDAPinned place.
static void InitTensorForVarBase(imperative::VarBase *self,
                                 const py::array &array,
                                 const platform::Place place,
                                 bool persistable = false,
                                 bool zero_copy = false, std::string name = "",
                                 int stop_gradient = -1) {
  if (name.empty()) {
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
  }
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;
  new (self) imperative::VarBase(name);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  // Dispatch on the runtime place kind to the matching typed overload.
  if (platform::is_cpu_place(place)) {
    SetTensorFromPyArray<platform::CPUPlace>(
        tensor, array, BOOST_GET_CONST(platform::CPUPlace, place), zero_copy);
  } else if (platform::is_xpu_place(place)) {
    SetTensorFromPyArray<platform::XPUPlace>(
        tensor, array, BOOST_GET_CONST(platform::XPUPlace, place), zero_copy);
  } else if (platform::is_gpu_place(place)) {
    SetTensorFromPyArray<platform::CUDAPlace>(
        tensor, array, BOOST_GET_CONST(platform::CUDAPlace, place), zero_copy);
  } else if (platform::is_cuda_pinned_place(place)) {
    SetTensorFromPyArray<platform::CUDAPinnedPlace>(
        tensor, array, BOOST_GET_CONST(platform::CUDAPinnedPlace, place),
        zero_copy);
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Place should be one of CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace"));
  }
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
  self->SetPersistable(persistable);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

// VarBase.__init__(**kwargs) implementation.
// Required key: "value" (a numpy array). Optional keys and their defaults:
// "persistable" (false), "zero_copy" (false), "name" ("" -> auto-generated),
// "stop_gradient" (-1 -> keep default) and "place" (the current tracer's
// expected place).
// Throws NotFound if "value" is missing.
static void InitVarBaseFromNumpyWithKwargs(imperative::VarBase *self,
                                           const py::kwargs &kwargs) {
  VLOG(4) << "Init VarBase from kwargs: ";
  PADDLE_ENFORCE_EQ(
      kwargs.contains("value"), true,
      platform::errors::NotFound(
          "The kwargs used to create Varbase misses argument: value"));
  auto persistable = kwargs.contains("persistable")
                         ? kwargs["persistable"].cast<bool>()
                         : false;
  // "value" is guaranteed to be present by the check above.
  auto array = kwargs["value"].cast<py::array>();
  auto zero_copy =
      kwargs.contains("zero_copy") ? kwargs["zero_copy"].cast<bool>() : false;
  auto name = kwargs.contains("name") ? kwargs["name"].cast<std::string>() : "";
  auto stop_gradient = kwargs.contains("stop_gradient")
                           ? kwargs["stop_gradient"].cast<int>()
                           : -1;
  auto default_place = imperative::GetCurrentTracer()->ExpectedPlace();
  auto place = kwargs.contains("place") ? PyObjectToPlace(kwargs["place"])
                                        : default_place;
  InitTensorForVarBase(self, array, place, persistable, zero_copy, name,
                       stop_gradient);
}
142

143 144 145
template <typename P>
static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
                                        const py::array &array, const P &place,
L
Leo Chen 已提交
146 147
                                        bool persistable = false,
                                        bool zero_copy = false,
148 149 150 151 152
                                        std::string name = "",
                                        int stop_gradient = -1) {
  VLOG(4) << "Init VarBase from Arg: ";
  // 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name , 6:
  // stop_gradient
L
Leo Chen 已提交
153
  if (name == "") {
154 155
    name =
        imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor");
L
Leo Chen 已提交
156
  }
157 158 159
  VLOG(5) << "Init Tensor as: / name: " << name
          << " / persistable: " << persistable << " / zero_copy: " << zero_copy
          << " / stop_gradient: " << stop_gradient;
L
Leo Chen 已提交
160
  new (self) imperative::VarBase(name);
161 162
  self->SetPersistable(persistable);
  auto *tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
163 164 165
  if (stop_gradient != -1) {
    self->SetOverridedStopGradient(stop_gradient);
  }
166 167 168 169 170 171
  SetTensorFromPyArray<P>(tensor, array, place, zero_copy);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor->type());
}

static void InitVarBaseFromNumpyWithArgDefault(imperative::VarBase *self,
L
Leo Chen 已提交
172
                                               const py::array &array) {
173
  VLOG(4) << "Init VarBase from numpy: ";
L
Leo Chen 已提交
174 175
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  InitTensorForVarBase(self, array, place);
176
}
177

178 179 180 181 182
// VarBase.__init__(tensor) implementation. It creates a non-persistable
// LOD_TENSOR VarBase with an auto-generated name on the current tracer's
// expected place.
// If `tensor` already lives on that place, its data is shared; otherwise
// the data is copied there with TensorCopy.
static void InitVarBaseFromTensorWithArgDefault(
    imperative::VarBase *self, const framework::LoDTensor &tensor) {
  VLOG(4) << "Init VarBase";
  auto place = imperative::GetCurrentTracer()->ExpectedPlace();
  new (self) imperative::VarBase(
      imperative::GetCurrentTracer()->GenerateUniqueName("generated_tensor"));
  self->SetPersistable(false);
  self->SetType(framework::proto::VarType::LOD_TENSOR);
  self->SetDataType(tensor.type());
  auto *new_tensor = self->MutableVar()->GetMutable<framework::LoDTensor>();
  // Same place,share data directly
  if (place == tensor.place()) {
    new_tensor->ShareDataWith(tensor);
    VLOG(4) << "Same place, do ShareDataWith";
  } else {
    framework::TensorCopy(tensor, place, new_tensor);
    VLOG(4) << "Different place, do TensorCopy";
  }
}

198 199 200 201 202
// Returns a printable name for the variable type held by `var`. It returns
// "RAW" for RAW variables, "nullptr" when the underlying variable is not
// initialized, and otherwise the framework's name for the held type.
static std::string GetTypeName(const imperative::VarBase &var) {
  if (var.Type() == framework::proto::VarType::RAW) {
    return "RAW";
  } else if (!var.Var().IsInitialized()) {
    return "nullptr";
  } else {
    return framework::ToTypeName(var.Var().Type());
  }
}
L
Leo Chen 已提交
207

208
// Maps a name to a Python object held as a non-owning py::handle (no
// reference count is taken). ConvertToNameVarBaseMap converts each value
// into a list of VarBase pointers.
using PyNameVarBaseMap = std::unordered_map<std::string, py::handle>;
209 210 211 212 213 214

// Casts a borrowed PyObject to T via pybind11. pybind11's cast_error is
// reported as Paddle's InvalidArgument error (naming T's mangled type name).
template <typename T>
static T PyObjectCast(PyObject *obj) {
  try {
    return py::cast<T>(py::handle(obj));
  } catch (const py::cast_error &) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Python object is not type of %s", typeid(T).name()));
  }
}

// NOTE(zjl): py::handle is a very light wrapper of PyObject *.
// Unlike py::object, py::handle does not change reference count of PyObject *.
//
// Converts `handle` into a vector of VarBase pointers:
//   - null or Python None       -> empty vector;
//   - list / tuple of VarBase   -> one entry per element, in order;
//   - anything else             -> cast as a single VarBase.
// A NULL element or an element that is not a VarBase raises InvalidArgument.
static std::vector<std::shared_ptr<imperative::VarBase>>
GetVarBaseListFromPyHandle(const py::handle &handle) {
  PyObject *py_obj = handle.ptr();  // get underlying PyObject
  // Python None is not nullptr in C++!
  if (!py_obj || py_obj == Py_None) {
    return {};
  }

  std::vector<std::shared_ptr<imperative::VarBase>> result;

  // Validates and casts one borrowed sequence element, appending it.
  auto append_item = [&result](PyObject *py_ivar) {
    PADDLE_ENFORCE_NOT_NULL(
        py_ivar, platform::errors::InvalidArgument("Python Object is NULL"));
    result.emplace_back(
        PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_ivar));
  };

  if (PyList_Check(py_obj)) {  // List of VarBase
    const Py_ssize_t len = PyList_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      append_item(PyList_GET_ITEM(py_obj, i));
    }
  } else if (PyTuple_Check(py_obj)) {  // Tuple of VarBase
    const Py_ssize_t len = PyTuple_GET_SIZE(py_obj);
    result.reserve(static_cast<size_t>(len));
    for (Py_ssize_t i = 0; i < len; ++i) {
      append_item(PyTuple_GET_ITEM(py_obj, i));
    }
  } else {  // VarBase
    result.emplace_back(
        PyObjectCast<std::shared_ptr<imperative::VarBase>>(py_obj));
  }

  return result;
}

J
Jiabin Yang 已提交
260 261 262
// Converts a name -> Python object map into a name -> VarBase-list map.
// Entries whose value converts to an empty list (None, empty list/tuple) are
// omitted.
// Throws InvalidArgument if a Python error is pending after the conversion.
static imperative::NameVarBaseMap ConvertToNameVarBaseMap(
    const PyNameVarBaseMap &map) {
  imperative::NameVarBaseMap result;
  for (const auto &pair : map) {
    auto var_vec = GetVarBaseListFromPyHandle(pair.second);
    if (!var_vec.empty()) {
      result.emplace(pair.first, std::move(var_vec));
    }
  }

  // Surface any Python error left pending during the conversion.
  // Constructing py::error_already_set fetches, and thereby clears, the
  // Python error indicator. Its what() contains the exception type AND its
  // message; formatting PyErr_Occurred() directly yields only the exception
  // type and leaves the error set.
  if (PyErr_Occurred()) {
    py::error_already_set pending_error;
    PADDLE_THROW(
        platform::errors::InvalidArgument("%s", pending_error.what()));
  }
  return result;
}

276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
// Returns true iff `obj` is a Python integer. Python's bool is a subclass of
// int, so it is rejected explicitly. Under Python 2 both `int` and `long`
// objects count as integers.
static bool PyCheckInteger(PyObject *obj) {
  if (PyBool_Check(obj)) {
    return false;
  }
#if PY_VERSION_HEX < 0x03000000
  return PyLong_Check(obj) || PyInt_Check(obj);
#else
  return PyLong_Check(obj);
#endif
}

// NOTE(zhiqiu): Revised version of PySlice_GetIndices. From:
// https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Objects/sliceobject.c#L103
// The original PySlice_GetIndices returned wrong results when a slice item
// contained a long int, such as arr[:180L] (root cause not identified), and
// it cannot raise an error when a slice item is a float. Hence this revised
// version, _PySlice_GetIndices. Consider switching to the more robust
// PySlice_Unpack in the future.
//
// Computes start/stop/step of slice `r` over an axis of `length` elements:
//   - a missing step defaults to 1;
//   - a missing start defaults to length - 1 (negative step) or 0;
//   - a missing stop defaults to -1 (negative step) or length;
//   - a negative start/stop is shifted by `length` once (no clamping).
// All three outputs are always written. Returns 0 on success, or -1 when
// stop > length, start >= length or step == 0.
// Throws InvalidArgument if start/stop/step is neither None nor an integer,
// or if an integer does not fit in a C long.
static int _PySlice_GetIndices(PySliceObject *r, Py_ssize_t length,
                               Py_ssize_t *start, Py_ssize_t *stop,
                               Py_ssize_t *step) {
  // Converts one non-None slice component to an integer. PyLong_AsLong
  // reports overflow by returning -1 with a Python exception set; without the
  // check below that -1 would silently be used as an index, and the pending
  // exception would leak back to the interpreter.
  auto to_index = [](PyObject *item) -> Py_ssize_t {
    if (!PyCheckInteger(item)) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Currently, VarBase.__getitem__() only allows None or integers in "
          "slice item, but received %s.",
          std::string(Py_TYPE(item)->tp_name)));
    }
    const long value = PyLong_AsLong(item);  // NOLINT(runtime/int)
    if (value == -1 && PyErr_Occurred()) {
      PyErr_Clear();
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Integer in VarBase.__getitem__() slice item is too large to be "
          "used as an index."));
    }
    return static_cast<Py_ssize_t>(value);
  };

  // Step is resolved first because the start/stop defaults depend on its sign.
  *step = (r->step == Py_None) ? 1 : to_index(r->step);
  if (r->start == Py_None) {
    *start = *step < 0 ? length - 1 : 0;
  } else {
    *start = to_index(r->start);
    if (*start < 0) *start += length;
  }
  if (r->stop == Py_None) {
    *stop = *step < 0 ? -1 : length;
  } else {
    *stop = to_index(r->stop);
    if (*stop < 0) *stop += length;
  }
  if (*stop > length) return -1;
  if (*start >= length) return -1;
  if (*step == 0) return -1;
  return 0;
}

S
songyouwei 已提交
341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364
// Parses a Python indexing expression `_index` applied to `tensor` into the
// attributes of a slice / strided_slice op.
// Supported index items: Python integers and slices, either alone or in a
// tuple. Ellipsis and None are not supported yet.
//   slice_axes / slice_starts / slice_ends / slice_strides
//                 - one entry per axis that actually needs slicing. A full
//                   `:` slice (start 0, end dim_len, step 1) adds nothing.
//   decrease_axis - axes indexed by a single integer.
//   infer_flags   - one `1` per index item.
// Throws InvalidArgument for an uninitialized tensor, more indices than the
// tensor rank, or an unsupported item type. Throws py::index_error (Python
// IndexError) for an integer index outside [-dim_len, dim_len).
static void ParseIndexingSlice(framework::LoDTensor *tensor, PyObject *_index,
                               std::vector<int> *slice_axes,
                               std::vector<int> *slice_starts,
                               std::vector<int> *slice_ends,
                               std::vector<int> *slice_strides,
                               std::vector<int> *decrease_axis,
                               std::vector<int> *infer_flags) {
  // Wrap a non-tuple index into a 1-tuple. Owning it through py::tuple
  // releases the temporary tuple on every exit path, including the throws
  // below; a manual decref at the end of the function leaked it whenever an
  // error was raised.
  py::tuple index = PyTuple_Check(_index)
                        ? py::reinterpret_borrow<py::tuple>(_index)
                        : py::make_tuple(py::handle(_index));
  PADDLE_ENFORCE_EQ(
      tensor->IsInitialized(), true,
      platform::errors::InvalidArgument("tensor has not been initialized"));
  const auto &shape = tensor->dims();
  const int rank = shape.size();
  const int size = static_cast<int>(PyTuple_GET_SIZE(index.ptr()));
  PADDLE_ENFORCE_EQ(
      size <= rank, true,
      platform::errors::InvalidArgument(
          "too many indices (%d) for tensor of dimension %d", size, rank));
  for (int dim = 0; dim < size; ++dim) {
    PyObject *slice_item = PyTuple_GetItem(index.ptr(), dim);
    PADDLE_ENFORCE_EQ(PyCheckInteger(slice_item) || PySlice_Check(slice_item),
                      true,
                      platform::errors::InvalidArgument(
                          "Currently, VarBase.__getitem__() only allows "
                          "indexing by Integers, Slices, and tuples of "
                          "these types, but received %s in %dth slice item",
                          std::string(Py_TYPE(slice_item)->tp_name), dim + 1));
    infer_flags->push_back(1);
    int dim_len = shape[dim];
    if (PyCheckInteger(slice_item)) {
      // integer, PyLong_AsLong supports both int and long
      const long index_value = PyLong_AsLong(slice_item);  // NOLINT
      // PyLong_AsLong signals overflow with -1 plus a pending Python error.
      const bool overflowed = (index_value == -1 && PyErr_Occurred());
      if (overflowed) PyErr_Clear();
      // Valid integer indices are [-dim_len, dim_len). The range is checked
      // on the full-width value so that it is validated before the
      // narrowing to int, and so that negative indices below -dim_len are
      // rejected too.
      if (overflowed || index_value < -dim_len || index_value >= dim_len) {
        const std::string index_str =
            overflowed ? std::string(py::str(py::handle(slice_item)))
                       : std::to_string(index_value);
        std::string str_error_message =
            "The starting index " + index_str +
            " of slice is out of bounds in tensor " + std::to_string(dim) +
            "-th axis, it should be in the range of [" +
            std::to_string(-dim_len) + ", " + std::to_string(dim_len) + ")";
        // py::index_error is corresponding to IndexError in Python
        // Used to indicate out of bounds access in __getitem__, __setitem__
        throw py::index_error(str_error_message);
      }
      const int start = static_cast<int>(
          index_value < 0 ? index_value + dim_len : index_value);
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(start + 1);
      slice_strides->push_back(1);
      decrease_axis->push_back(dim);
    } else {
      // slice item
      Py_ssize_t start, end, step;
      PySliceObject *p = reinterpret_cast<PySliceObject *>(slice_item);
      _PySlice_GetIndices(p, dim_len, &start, &end, &step);

      // :: or : or 0:dim_len:1
      if (start == 0 && end == dim_len && step == 1) {
        continue;
      }
      slice_axes->push_back(dim);
      slice_starts->push_back(start);
      slice_ends->push_back(end);
      slice_strides->push_back(step);
    }
  }
}

413
// Bind Methods
J
Jiabin Yang 已提交
414
void BindImperative(py::module *m_ptr) {
415 416
  auto &m = *m_ptr;

417 418
  BindOpFunctions(&m);

419 420
#ifndef _WIN32
  // Dygraph DataLoader signal handler
421 422 423 424 425 426 427 428 429 430 431 432 433
  m.def("_set_process_pids", [](int64_t key, py::object &obj) {
    PADDLE_ENFORCE_EQ(
        py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj), true,
        platform::errors::InvalidArgument(
            "The subprocess ids set in DataLoader is illegal."
            "Expected data type is tuple or list, but received %s",
            obj.get_type()));
    py::list pids = py::cast<py::list>(obj);
    std::set<pid_t> pids_set = {};
    for (size_t i = 0; i < pids.size(); i++) {
      pids_set.insert(pids[i].cast<pid_t>());
    }
    imperative::SetLoadProcessPIDs(key, pids_set);
434
  });
435 436
  m.def("_erase_process_pids",
        [](int64_t key) { imperative::EraseLoadProcessPIDs(key); });
437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508
  m.def("_set_process_signal_handler",
        []() { imperative::SetLoadProcessSignalHandler(); });
  m.def("_throw_error_if_process_failed",
        []() { imperative::ThrowErrorIfLoadProcessFailed(); });

  // Dygraph DataLoader reader process & thread related functions
  m.def(
      "_convert_to_tensor_list",
      [](py::object &obj) -> py::list {
        // 0. input data check
        PADDLE_ENFORCE(
            py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj),
            platform::errors::InvalidArgument(
                "The batch data read into DataLoader is illegal."
                "Expected data type is tuple or list, but received %s",
                obj.get_type()));
        py::list batch = py::cast<py::list>(obj);
        py::list tensors;
        for (size_t i = 0; i < batch.size(); ++i) {
          // 1. cast to python array
          auto array = batch[i].cast<py::array>();
          PADDLE_ENFORCE_NE(
              string::Sprintf("%s", array.dtype()).compare("object"), 0,
              platform::errors::InvalidArgument(
                  "Faild to convert input data to a regular ndarray.\n  * "
                  "Usually this means the input data contains nested "
                  "lists with different lengths.\n  * Check the reader "
                  "function passed to 'set_(sample/sample_list/batch)"
                  "_generator' to locate the data causes this issue."));
          // 2. construcct LoDTensor
          framework::LoDTensor t;
          SetTensorFromPyArray<platform::CPUPlace>(&t, array,
                                                   platform::CPUPlace(), true);
          // 3. allocate shared memory
          void *data_ptr = t.data<void>();
          size_t data_size = t.numel() * framework::SizeOfType(t.type());
          auto shared_writer_holder =
              memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
          // 4. maintain mmap fd set & backup ipc_name
          const std::string &ipc_name = shared_writer_holder->ipc_name();
          memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
          // 5. copy data & reset holder
          memory::Copy(platform::CPUPlace(), shared_writer_holder->ptr(),
                       platform::CPUPlace(), data_ptr, data_size);
          t.ResetHolder(shared_writer_holder);
          // 6. append to result list
          tensors.append(t);
        }
        return tensors;
      },
      py::return_value_policy::take_ownership);

  m.def("_remove_tensor_list_mmap_fds", [](py::list &tensor_list) {
    for (size_t i = 0; i < tensor_list.size(); ++i) {
      auto t = tensor_list[i].cast<framework::LoDTensor>();
      auto *mmap_writer_allocation =
          dynamic_cast<memory::allocation::MemoryMapWriterAllocation *>(
              t.Holder().get());
      PADDLE_ENFORCE_NOT_NULL(
          mmap_writer_allocation,
          platform::errors::NotFound("The shared memory of LoDTensor in "
                                     "DataLoader's child process has been "
                                     "released."));
      memory::allocation::MemoryMapFdSet::Instance().Remove(
          mmap_writer_allocation->ipc_name());
    }
  });

  m.def("_cleanup_mmap_fds",
        []() { memory::allocation::MemoryMapFdSet::Instance().Clear(); });
#endif

509 510 511 512 513
  m.def("start_imperative_gperf_profiler",
        []() { imperative::StartProfile(); });

  m.def("stop_imperative_gperf_profiler", []() { imperative::StopProfile(); });

Z
Zeng Jinle 已提交
514 515 516
  m.def("_is_dygraph_debug_enabled",
        []() { return imperative::IsDebugEnabled(); });
  m.def("_dygraph_debug_level", []() { return imperative::GetDebugLevel(); });
517 518 519 520
  m.def("_switch_tracer",
        [](const std::shared_ptr<imperative::Tracer> &tracer) {
          imperative::SetCurrentTracer(tracer);
        });
Z
Zeng Jinle 已提交
521

522
  py::class_<imperative::VarBase, std::shared_ptr<imperative::VarBase>>(
523
      m, "VarBase", R"DOC()DOC")
Z
Zeng Jinle 已提交
524
      .def_static("_alive_vars", &imperative::VarBase::AliveVarNames)
J
Jiabin Yang 已提交
525
      .def("__init__",
526 527 528
           [](imperative::VarBase &self, framework::proto::VarType::Type dtype,
              const std::vector<int> &dims, const py::handle &name,
              framework::proto::VarType::Type type, bool persistable) {
529
             VLOG(4) << "Init VarBase";
530 531 532
             std::string act_name = "";
             if (!name.ptr() || name.ptr() == Py_None) {
               act_name = imperative::GetCurrentTracer()->GenerateUniqueName(
533
                   "generated_tensor");
534 535 536 537
             } else {
               act_name = name.cast<std::string>();
             }
             new (&self) imperative::VarBase(act_name);
J
Jiabin Yang 已提交
538 539 540 541 542 543 544 545 546
             self.SetPersistable(persistable);
             self.SetType(type);
             self.SetDataType(dtype);
             if (type == framework::proto::VarType::LOD_TENSOR) {
               auto *tensor =
                   self.MutableVar()->GetMutable<framework::LoDTensor>();
               tensor->Resize(framework::make_ddim(dims));
             }
           })
547 548
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
549 550
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
551 552 553 554
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::XPUPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
555 556
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
557 558
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
559 560
      .def("__init__", &InitVarBaseFromNumpyWithArg<platform::CUDAPinnedPlace>,
           py::arg("value"), py::arg("place"), py::arg("persistable") = false,
561 562
           py::arg("zero_copy") = false, py::arg("name") = "",
           py::arg("stop_gradient") = -1)
L
Leo Chen 已提交
563
      .def("__init__", &InitVarBaseFromNumpyWithArgDefault, py::arg("value"))
564
      .def("__init__", &InitVarBaseFromTensorWithArgDefault, py::arg("tensor"))
565
      .def("__init__", &InitVarBaseFromNumpyWithKwargs)
566
      .def("__getitem__",
S
songyouwei 已提交
567
           [](std::shared_ptr<imperative::VarBase> &self, py::handle _index) {
568
             std::vector<int> slice_axes, slice_starts, slice_ends,
S
songyouwei 已提交
569 570 571 572 573 574
                 slice_strides, decrease_axis, infer_flags;
             auto tensor =
                 self->MutableVar()->GetMutable<framework::LoDTensor>();
             ParseIndexingSlice(tensor, _index.ptr(), &slice_axes,
                                &slice_starts, &slice_ends, &slice_strides,
                                &decrease_axis, &infer_flags);
575 576 577 578
             // release gil and do tracing
             py::gil_scoped_release release;
             const auto &tracer = imperative::GetCurrentTracer();
             if (slice_axes.empty()) {
S
songyouwei 已提交
579
               return self;
580
             } else {
S
songyouwei 已提交
581
               imperative::NameVarBaseMap ins = {{"Input", {self}}};
582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603
               framework::AttributeMap attrs = {
                   {"axes", slice_axes},
                   {"starts", slice_starts},
                   {"ends", slice_ends},
                   {"infer_flags", infer_flags},
                   {"decrease_axis", decrease_axis}};
               auto out = std::shared_ptr<imperative::VarBase>(
                   new imperative::VarBase(tracer->GenerateUniqueName()));
               imperative::NameVarBaseMap outs = {{"Out", {out}}};
               std::string op_type = "slice";
               for (auto stride : slice_strides) {
                 if (stride != 1) {
                   op_type = "strided_slice";
                   attrs.insert({"strides", slice_strides});
                   attrs.erase("decrease_axis");
                   break;
                 }
               }
               tracer->TraceOp(op_type, ins, outs, std::move(attrs));
               return out;
             }
           })
604 605 606 607 608 609 610
      .def("numpy",
           [](imperative::VarBase &self) -> py::array {
             const auto &tensor =
                 self.MutableVar()->Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(
                 tensor.IsInitialized(), true,
                 platform::errors::InvalidArgument(
611
                     "Tensor of %s is Empty, please check if it has no data.",
612 613 614 615 616
                     self.Name()));
             return TensorToPyArray(tensor, true);
           },
           R"DOC(
        **Notes**:
T
tianshuo78520a 已提交
617
            **This API is ONLY available in Dygraph mode**
618 619 620 621 622 623 624 625 626 627 628 629 630 631

        Returns a numpy array shows the value of current :ref:`api_guide_Variable_en`

        Returns:
            ndarray: The numpy value of current Variable.

        Returns type:
            ndarray: dtype is same as current Variable

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                from paddle.fluid.dygraph.base import to_variable
632
                from paddle.fluid.dygraph import Linear
633 634 635 636
                import numpy as np

                data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
                with fluid.dygraph.guard():
637
                    linear = Linear(32, 64)
638
                    data = to_variable(data)
639
                    x = linear(data)
640 641 642 643 644 645 646 647 648 649 650 651 652
                    print(x.numpy())

       )DOC")
      .def("detach",
           [](const imperative::VarBase &self) {
             const auto &tensor = self.Var().Get<framework::LoDTensor>();
             PADDLE_ENFORCE_EQ(tensor.IsInitialized(), true,
                               platform::errors::InvalidArgument(
                                   "%s has not been initialized", self.Name()));
             return self.NewVarBase(tensor.place(), false);
           },
           py::return_value_policy::copy, R"DOC(

653
        Returns a new Tensor, detached from the current graph.
654

655
        Returns: The detached Tensor.
656 657 658 659

        Examples:
            .. code-block:: python

660 661
                import paddle
                paddle.disable_static()
662

663 664 665 666
                linear = Linear(32, 64)
                data = paddle.uniform(shape=[30, 10, 32], -1, 1)
                x = linear(data)
                y = x.detach()
667 668 669
       )DOC")
      .def("clear_gradient", &imperative::VarBase::ClearGradient, R"DOC(

670
        Only for Tensor that has gradient, normally we use this for Parameters since other temporary Tensor doesen't has gradient.
671

672
        The Gradient of current Tensor will be set to ``0`` .
673 674 675 676 677 678

        Returns:  None

        Examples:
             .. code-block:: python

679 680 681 682 683 684 685 686 687 688 689 690 691 692
                import paddle
                paddle.disable_static()

                inputs = []
                for _ in range(10):
                    tmp = paddle.ones([2, 2])
                    tmp.stop_gradient=False
                    inputs.append(tmp)
                ret = paddle.sums(inputs2)
                loss = paddle.reduce_sum(ret)
                loss.backward()
                print("Before clear_gradient {}".format(loss.grad))
                loss.clear_gradient()
                print("After clear_gradient {}".format(loss.grad))
693
      )DOC")
L
Leo Chen 已提交
694
      .def("_run_backward",
695 696
           [](imperative::VarBase &self, const imperative::Tracer &tracer,
              bool retain_graph) {
697 698
             // TODO(jiabin): when we impl more backward execution we can
             // select them
699
             auto *engine = tracer.GetEngine();
700
             engine->Init(&self, retain_graph);
701
             VLOG(3) << "Start backward";
L
Leo Chen 已提交
702 703 704 705 706 707 708 709 710 711
             engine->Execute();
             VLOG(3) << "Finish backward";
           },
           py::call_guard<py::gil_scoped_release>())
      .def("_grad_name", &imperative::VarBase::GradVarName)
      .def("_grad_value",
           [](imperative::VarBase &self) {
             return self.MutableGradVar()->Get<framework::LoDTensor>();
           },
           py::return_value_policy::reference)
712 713 714 715
      .def("_set_grad_type",
           [](imperative::VarBase &self, framework::proto::VarType::Type type) {
             self.MutableGradVarBase()->SetType(type);
           })
716
      .def("_grad_ivar",
J
Jiabin Yang 已提交
717 718
           [](const imperative::VarBase &self) {
             auto &grad_var = self.GradVarBase();
719 720 721 722 723 724 725 726 727 728 729
             if (grad_var && grad_var->Var().IsInitialized()) {
               auto *tensor =
                   grad_var->MutableVar()->IsType<framework::LoDTensor>()
                       ? grad_var->MutableVar()
                             ->GetMutable<framework::LoDTensor>()
                       : grad_var->MutableVar()
                             ->GetMutable<framework::SelectedRows>()
                             ->mutable_value();
               if (tensor->IsInitialized()) {
                 return grad_var;
               }
J
Jiabin Yang 已提交
730
             }
731
             return std::shared_ptr<imperative::VarBase>(nullptr);
J
Jiabin Yang 已提交
732 733
           },
           py::return_value_policy::copy)
734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763
      .def("_is_sparse",
           [](imperative::VarBase &self) {
             return self.Var().IsType<framework::SelectedRows>();
           })
      .def("_allreduce",
           [](imperative::VarBase &self,
              const imperative::ParallelStrategy &strategy) {
             if (strategy.nranks_ > 1) {
#ifdef PADDLE_WITH_NCCL
#if NCCL_VERSION_CODE >= 2212
               imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
#else
               if (!self.Var().IsType<framework::SelectedRows>()) {
                 imperative::AllReduce(self.Var(), self.MutableVar(), strategy);
               } else {
                 PADDLE_THROW(platform::errors::Unimplemented(
                     "Imperative SelectedRows allreduce is not supported when "
                     "paddle is compiled with NCCL verison lower than v2.2.12. "
                     "You can set is_sparse=False for the Layer containing "
                     "this argument, such as Embedding(is_sparse=False)."));
               }
#endif  // NCCL_VERSION_CODE
#else
               PADDLE_THROW(platform::errors::Unimplemented(
                   "Imperative allreduce is not supported when paddle is "
                   "not compiled with NCCL."));
#endif  // PADDLE_WITH_NCCL
             }
           },
           py::call_guard<py::gil_scoped_release>())
764 765
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CPUPlace &place,
J
Jiabin Yang 已提交
766 767
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
768 769 770 771 772
      .def("_copy_to",
           [](const imperative::VarBase &self,
              const platform::CUDAPinnedPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
773 774 775 776
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::XPUPlace &place,
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
777 778
      .def("_copy_to",
           [](const imperative::VarBase &self, const platform::CUDAPlace &place,
J
Jiabin Yang 已提交
779 780 781
              bool blocking) { return self.NewVarBase(place, blocking); },
           py::return_value_policy::copy)
      .def("value", [](imperative::VarBase &self) { return self.MutableVar(); },
782 783 784
           py::return_value_policy::reference)
      .def_property("name", &imperative::VarBase::Name,
                    &imperative::VarBase::SetName)
L
Leo Chen 已提交
785 786 787 788 789
      .def_property("stop_gradient",
                    &imperative::VarBase::OverridedStopGradient,
                    &imperative::VarBase::SetOverridedStopGradient)
      .def_property("persistable", &imperative::VarBase::Persistable,
                    &imperative::VarBase::SetPersistable)
J
Jiabin Yang 已提交
790 791 792 793
      .def_property_readonly(
          "shape",
          [](imperative::VarBase &self) {
            if (self.Var().IsType<framework::LoDTensor>()) {
794
              return framework::vectorize<int>(
J
Jiabin Yang 已提交
795
                  self.Var().Get<framework::LoDTensor>().dims());
796 797 798
            } else if (self.Var().IsType<framework::SelectedRows>()) {
              return framework::vectorize<int>(
                  self.Var().Get<framework::SelectedRows>().value().dims());
J
Jiabin Yang 已提交
799 800 801 802 803 804
            } else {
              VLOG(2) << "It is meaningless to get shape of variable type "
                      << GetTypeName(self);
              return std::vector<int>();
            }
          })
805 806 807
      .def_property_readonly(
          "place", [](imperative::VarBase &self) { return self.Place(); },
          py::return_value_policy::copy)
J
Jiabin Yang 已提交
808
      .def_property_readonly("type", &imperative::VarBase::Type)
L
Leo Chen 已提交
809
      .def_property_readonly("dtype", &imperative::VarBase::DataType);

  // Python-visible `Layer`. The second template argument is the `Layer`
  // trampoline class declared in this file (derived from imperative::Layer).
  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
  layer.def(py::init<>());
  // `forward(inputs)` forwards the list of input VarBases to Layer::Forward
  // and hands its result back to Python.
  layer.def("forward",
            [](imperative::Layer &self,
               const std::vector<std::shared_ptr<imperative::VarBase>> &inputs) {
              return self.Forward(inputs);
            });

  py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
      .def("create_program_desc",
           &imperative::jit::ProgramDescTracer::CreateProgramDesc)
      .def("reset", &imperative::jit::ProgramDescTracer::Reset);

824
  // Python binding for imperative::Tracer; instances are held by
  // std::shared_ptr.
  py::class_<imperative::Tracer, std::shared_ptr<imperative::Tracer>>(
      m, "Tracer", R"DOC()DOC")
      // Default-construct a Tracer. py::init<>() replaces the deprecated
      // placement-new `__init__` idiom and builds the same object.
      .def(py::init<>())
      .def_property("_enable_program_desc_tracing",
                    &imperative::Tracer::IsProgramDescTracingEnabled,
                    &imperative::Tracer::SetEnableProgramDescTracing)
      .def_property("_enable_autocast", &imperative::Tracer::IsAutoCastEnabled,
                    &imperative::Tracer::SetEnableAutoCast)
      .def_property("_train_mode", &imperative::Tracer::HasGrad,
                    &imperative::Tracer::SetHasGrad)
      // Getter returns the expected place as a Python object. The setter
      // accepts CUDAPlace, XPUPlace, CPUPlace or CUDAPinnedPlace and throws
      // InvalidArgument for any other type.
      .def_property(
          "_expected_place",
          [](const imperative::Tracer &self) -> py::object {
            return py::cast(self.ExpectedPlace());
          },
          [](imperative::Tracer &self, const py::object &obj) {
            if (py::isinstance<platform::CUDAPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::XPUPlace>(obj)) {
              auto p = obj.cast<platform::XPUPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::CPUPlace>(obj)) {
              auto p = obj.cast<platform::CPUPlace *>();
              self.SetExpectedPlace(*p);
            } else if (py::isinstance<platform::CUDAPinnedPlace>(obj)) {
              auto p = obj.cast<platform::CUDAPinnedPlace *>();
              self.SetExpectedPlace(*p);
            } else {
              PADDLE_THROW(platform::errors::InvalidArgument(
                  "Incompatible Place Type: supports XPUPlace, CUDAPlace, "
                  "CPUPlace, "
                  "and CUDAPinnedPlace, "
                  "but got Unknown Type!"));
            }
          })
      // Returned with the `reference` policy: Python does not take ownership
      // of the returned tracer.
      .def("_get_program_desc_tracer",
           &imperative::Tracer::GetProgramDescTracer,
           py::return_value_policy::reference)
      .def("_generate_unique_name", &imperative::Tracer::GenerateUniqueName,
           py::arg("key") = "eager_tmp")
      .def(
          "_set_amp_op_list",
          [](imperative::Tracer &self,
             std::unordered_set<std::string> &allow_ops,
             std::unordered_set<std::string> &block_ops) {
            // NOTE(zhiqiu): The automatic conversion in pybind11 between c++
            // STL and python set/list/dict involve a copy operation that
            // prevents pass-by-reference semantics, so it is ok to swap.
            // The reason why not directly pass
            // std::shared_ptr<std::unordered_set<std::string>>
            // is that pybind11 forbids shared_ptr<T> where T is not custom type.
            imperative::AmpOperators::Instance().GetAllowOps()->swap(allow_ops);
            imperative::AmpOperators::Instance().GetBlockOps()->swap(block_ops);
          })
      .def("_get_amp_op_list",
           [](imperative::Tracer &self) {
             return std::make_tuple(
                 *(imperative::AmpOperators::Instance().GetAllowOps()),
                 *(imperative::AmpOperators::Instance().GetBlockOps()));
           })
      // `trace` has one overload per supported place type. pybind11 tries
      // overloads in registration order, so the order below is kept as-is.
      // The Python name->VarBase maps are converted while the GIL is held, and
      // the GIL is released only around TraceOp.
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::XPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CUDAPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           })
      .def("trace",
           [](imperative::Tracer &self, const std::string &type,
              const PyNameVarBaseMap &ins, const PyNameVarBaseMap &outs,
              framework::AttributeMap attrs, const platform::CPUPlace &place,
              bool trace_backward) {
             auto ins_map = ConvertToNameVarBaseMap(ins);
             auto outs_map = ConvertToNameVarBaseMap(outs);
             {
               py::gil_scoped_release release;
               self.TraceOp(type, std::move(ins_map), std::move(outs_map),
                            std::move(attrs), place, trace_backward);
             }
           });

  // define parallel context
927 928 929
  py::class_<imperative::ParallelStrategy> parallel_strategy(
      m, "ParallelStrategy", "");
  parallel_strategy.def(py::init())
930 931
      .def_property(
          "nranks",
932 933
          [](const imperative::ParallelStrategy &self) { return self.nranks_; },
          [](imperative::ParallelStrategy &self, int nranks) {
934 935 936
            self.nranks_ = nranks;
          })
      .def_property("local_rank",
937
                    [](const imperative::ParallelStrategy &self) {
938 939
                      return self.local_rank_;
                    },
940
                    [](imperative::ParallelStrategy &self, int local_rank) {
941 942 943 944
                      self.local_rank_ = local_rank;
                    })
      .def_property(
          "trainer_endpoints",
945
          [](const imperative::ParallelStrategy &self) {
946 947
            return self.trainer_endpoints_;
          },
948
          [](imperative::ParallelStrategy &self, std::vector<std::string> eps) {
949 950 951
            self.trainer_endpoints_ = eps;
          })
      .def_property("current_endpoint",
952
                    [](const imperative::ParallelStrategy &self) {
953 954
                      return self.current_endpoint_;
                    },
955 956
                    [](imperative::ParallelStrategy &self,
                       const std::string &ep) { self.current_endpoint_ = ep; });
957 958 959 960 961 962 963 964

  m.def(
      "dygraph_partial_grad",
      [](const std::vector<std::shared_ptr<imperative::VarBase>> &input_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>>
             &output_targets,
         const std::vector<std::shared_ptr<imperative::VarBase>> &output_grads,
         const std::vector<std::shared_ptr<imperative::VarBase>> &no_grad_vars,
965 966
         const platform::Place &place, bool create_graph, bool retain_graph,
         bool allow_unused, bool only_inputs) {
Z
Zeng Jinle 已提交
967 968
        imperative::PartialGradEngine engine(
            input_targets, output_targets, output_grads, no_grad_vars, place,
969
            create_graph, retain_graph, allow_unused, only_inputs);
970 971 972 973 974
        engine.Execute();
        return engine.GetResult();
      },
      py::call_guard<py::gil_scoped_release>());

975
#if defined(PADDLE_WITH_NCCL)
  // `NCCLParallelContext` is only bound when Paddle is built with NCCL. It is
  // constructed from a ParallelStrategy and a CUDAPlace; `init()` calls Init().
  py::class_<imperative::NCCLParallelContext>(m, "NCCLParallelContext")
      .def(py::init<const imperative::ParallelStrategy &,
                    const platform::CUDAPlace &>())
      .def("init", [](imperative::NCCLParallelContext &ctx) { ctx.Init(); });
#endif
}

}  // namespace pybind
}  // namespace paddle