eager_method.cc 84.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
12 13 14 15 16 17

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

18
#include <Python.h>
19 20 21 22
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
23 24

#include <string>
25
#include <unordered_map>
26 27
#include <vector>

28
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
29
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
30
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
31
#include "paddle/fluid/eager/autograd_meta.h"
32 33
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
34
#include "paddle/fluid/eager/utils.h"
35
#include "paddle/fluid/framework/convert_utils.h"
36
#include "paddle/fluid/framework/string_array.h"
37 38 39 40 41 42
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
43
#include "paddle/fluid/pybind/slice_utils.h"
44
#include "paddle/fluid/pybind/uva_utils.h"
45 46 47 48
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
49 50
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
W
wanghuancoder 已提交
51
#include "pybind11/detail/internals.h"
52 53
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
W
wanghuancoder 已提交
54
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
J
Jiabin Yang 已提交
55
#include "paddle/fluid/eager/amp_utils.h"
56
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
J
Jiabin Yang 已提交
57
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
W
wanghuancoder 已提交
58
#include "paddle/fluid/framework/python_headers.h"
W
wanghuancoder 已提交
59
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
W
wanghuancoder 已提交
60
#include "paddle/fluid/pybind/tensor_py.h"
W
wanghuancoder 已提交
61
#include "paddle/phi/core/ddim.h"
62
#include "paddle/phi/core/flags.h"
63
#include "paddle/phi/core/tensor_utils.h"
64
#include "paddle/phi/kernels/funcs/math_function.h"
J
Jiabin Yang 已提交
65

66
PHI_DECLARE_bool(set_to_1d);
67

68 69 70
namespace paddle {
namespace pybind {

71 72
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
73
                                     const paddle::platform::Place& place,
74
                                     bool zero_copy);
75

76
extern PyTypeObject* p_tensor_type;
77

J
Jiabin Yang 已提交
78 79 80
// Extracts a slice index from a Python object that holds an eager Tensor.
//
// obj: the Python object used as a slice index; it must be an instance of
//      the eager tensor type (p_tensor_type).
// Returns the index produced by GetSliceIndexFromTensor for the tensor's
// underlying implementation.
// Throws InvalidArgument when the tensor is uninitialized or when `obj` is
// not a tensor at all.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
    VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
    paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
    PADDLE_ENFORCE_EQ(
        tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in slice, however we got "
            "uninitialized tensor %s, please check your code.",
            tensor.name()));
    // Reuse the tensor that was just validated instead of converting the
    // Python argument a second time.
    return GetSliceIndexFromTensor(
        *static_cast<phi::DenseTensor*>(tensor.impl().get()));
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }
}

98 99
// Python binding for Tensor.numpy(): builds a new numpy array and deep-copies
// the tensor's data into it.
//
// self:   the tensor object being converted.
// args:   positional arguments; for a 0-D tensor, a single `False` argument
//         disables the "0-D -> 1-D" compatibility conversion.
// kwargs: unused.
//
// Returns a new reference to the numpy array, or nullptr with a Python error
// set when array allocation fails or a C++ exception is raised.
//
// Behaviour by case, as implemented below:
//   * no impl                      -> empty 1-D float array
//   * impl present, uninitialized  -> empty 1-D array for 0-D tensors,
//                                     otherwise an array of the tensor shape
//                                     whose contents are not filled in
//   * CPU / pinned / GPU / XPU / custom device -> data copied to the array
//   * any other place              -> InvalidArgument
static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
  size_t py_rank = tensor_dims.size();
  size_t numel = 1;
  if (py_rank == 0) {
    Py_ssize_t args_num = PyTuple_Size(args);
    // true by default
    bool set_to_1d = FLAGS_set_to_1d;
    if (args_num == static_cast<Py_ssize_t>(1)) {
      PyObject* obj = PyTuple_GET_ITEM(args, 0);
      if (obj == Py_False) {
        set_to_1d = false;
      }
    }
    if (set_to_1d) {
      // 0D Tensor hack process to 1D numpy, will remove in future
      VLOG(0)
          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
             "order to avoid this problem, "
             "0D Tensor will be changed to 1D numpy currently, but it's not "
             "correct and will be "
             "removed in future. For Tensor contain only one element, Please "
             "modify "
             " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
             "possible, "
             "otherwise 'Tensor.numpy()[0]' will raise error in future.";
      py_rank = 1;
      py_dims[0] = 1;
      py_strides[0] = sizeof_dtype * numel;
    }
  } else {
    // Compute row-major (C-contiguous) strides from the innermost dimension
    // outwards, accumulating the element count on the way.
    for (int i = static_cast<int>(tensor_dims.size()) - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * numel;
      numel *= py_dims[i];
    }
  }

  const bool tensor_initialized = self->tensor.impl()->initialized();

  // An uninitialized 0-D tensor is reported as an empty 1-D array. Handle it
  // before allocating the regular array so that no array is created and then
  // dropped without releasing its reference.
  if (!tensor_initialized && tensor_dims.size() == 0) {
    py_dims[0] = 0;
    py_strides[0] = 0;
    return api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
  }

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      py_rank,
      py_dims,
      py_strides,
      nullptr,
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);
  if (array == nullptr) {
    // Allocation failed and numpy has already set the Python error; do not
    // touch the array's data pointer below.
    return nullptr;
  }

  if (!tensor_initialized) {
    return array;
  }

  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
    eager_gil_scoped_release guard;
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());

      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (self->tensor.is_gpu()) {
    eager_gil_scoped_release guard;
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    }
#endif
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
    eager_gil_scoped_release guard;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // TODO(qili93): temporary for ascned npu performance to be removed along
      // with npu_identity op
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    }
#endif
  } else {
    // The array will never be handed to Python on this path; release it
    // before raising so the reference is not leaked.
    Py_DECREF(array);
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
  }

  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
322 323 324 325 326 327 328 329 330 331 332 333 334 335 336
// Python binding that converts a StringTensor into a numpy unicode array.
//
// self:   the tensor object being converted.
// args / kwargs: unused.
//
// Returns a new reference to the numpy array:
//   * an empty 1-D unicode array when the tensor has no impl or is
//     uninitialized;
//   * otherwise a "U<k>" array shaped like the tensor, where k is the longest
//     unicode length among the elements (at least 1).
// Throws InvalidArgument when the tensor is not on CPU.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();
    // Get the max unicode length of StringTensor to create numpy unicode
    // string array. An explicit loop is used so that a tensor with zero
    // elements never dereferences a one-past-the-end pointer.
    size_t max_unicode_length = 0;
    for (int64_t i = 0; i < numel; ++i) {
      const size_t unicode_len =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      if (unicode_len > max_unicode_length) {
        max_unicode_length = unicode_len;
      }
    }
    // numpy needs a non-zero item width, so all-empty input uses width 1.
    max_unicode_length = (max_unicode_length == 0) ? 1 : max_unicode_length;
    VLOG(6) << "The max unicode length is " << max_unicode_length;
    // std::make_unique<T[]> value-initializes the buffer, so every code point
    // starts as 0 and shorter strings end up zero-padded.
    auto sp = std::make_unique<uint32_t[]>(max_unicode_length * numel);
    auto py_array_data = sp.get();
    for (int64_t i = 0; i < numel; ++i) {
      auto curr_unicode_len =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  curr_unicode_len);
    }
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

392 393 394 395
// Python binding: returns the result of Tensor::initialized() for this
// tensor as a Python object. `args` and `kwargs` are unused.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_ready = self->tensor.initialized();
  return ToPyObject(is_ready);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
400 401 402 403 404 405 406 407 408 409 410 411 412 413
// Python binding: returns True only when the tensor's impl is a
// phi::DenseTensor whose IsInitialized() is true; any other impl type (or no
// impl) yields False. `args` and `kwargs` are unused.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  if (!dense) {
    return ToPyObject(false);
  }
  return ToPyObject(dense->IsInitialized());

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

414
static void IncreaseTensorReferenceCountUntilCopyComplete(
415
    const paddle::Tensor& tensor, const platform::Place& place) {
416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
  // CUDAPinned Mem -> CUDA by cudamemcpyAsync.
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

432 433
// Python binding for Tensor._copy_to(place, blocking).
//
// args[0]: destination place.
// args[1]: whether the copy is blocking.
// Returns a new tensor on `place` with stop_gradient set to true and the
// same persistable flag as the source tensor.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  auto dst_place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  const bool is_blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  paddle::Tensor copied;
  {
    // The GIL is released only for the duration of this scope.
    eager_gil_scoped_release guard;
    copied = self->tensor.copy_to(dst_place, is_blocking);
    if (!is_blocking) {
      // Non-blocking copy: keep the source tensor referenced until done.
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, dst_place);
    }
    auto* copied_meta = egr::EagerUtils::autograd_meta(&copied);
    copied_meta->SetStopGradient(true);
    const bool src_persistable =
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
    copied_meta->SetPersistable(src_persistable);
  }
  return ToPyObject(copied);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

454 455
// Python binding for Tensor.cpu(): performs a blocking copy of the tensor to
// phi::CPUPlace and returns it. The result has stop_gradient set to true and
// the same persistable flag as the source. `args` and `kwargs` are unused.
static PyObject* tensor_method_cpu(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor host_copy;
  {
    // The GIL is released only for the duration of this scope.
    eager_gil_scoped_release guard;
    host_copy = self->tensor.copy_to(phi::CPUPlace(), true);
    auto* host_meta = egr::EagerUtils::autograd_meta(&host_copy);
    host_meta->SetStopGradient(true);
    const bool src_persistable =
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
    host_meta->SetPersistable(src_persistable);
  }
  return ToPyObject(host_copy);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

471 472 473 474
// Python binding: replaces this tensor with the tensor passed as args[0]
// while keeping this tensor's original name.
//
// args[0]: the source tensor to take over.
// Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  std::string orig_name = self->tensor.name();
  // A space follows "from" so the tensor name is not glued to the word in
  // the log output.
  VLOG(6) << "Start Reconstructing Tensor from " << src_tensor.name()
          << " to " << orig_name;
  self->tensor = src_tensor;

  // Recover source name
  self->tensor.set_name(orig_name);

  VLOG(6) << "Finished Reconstructing Tensor from " << src_tensor.name()
          << " to " << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

491 492
// Python binding for Tensor.copy_(src, blocking).
//
// args[0]: source tensor.
// args[1]: whether the copy is blocking.
//
// When this tensor is already initialized, the source data is copied onto
// this tensor's place. Otherwise this tensor first takes the source's
// stop_gradient and persistable flags, and the data is copied onto the
// source's place. Nothing is copied when the source is uninitialized.
// Returns None.
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor source = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool is_blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << source.name() << " to "
          << self->tensor.name();
  if (self->tensor.initialized()) {
    if (source.initialized()) {
      eager_gil_scoped_release guard;
      self->tensor.copy_(source, self->tensor.place(), is_blocking);
    }
  } else {
    eager_gil_scoped_release guard;
    auto* dst_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    auto* src_meta = egr::EagerUtils::autograd_meta(&(source));
    dst_meta->SetStopGradient(src_meta->StopGradient());
    dst_meta->SetPersistable(src_meta->Persistable());
    if (source.initialized()) {
      self->tensor.copy_(source, source.place(), is_blocking);
    }
  }

  VLOG(6) << "Finish Copy Tensor " << source.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

524 525 526 527
// Python binding for Tensor.clone(): returns the result of assign_ad_func
// applied to this tensor. Throws InvalidArgument when the tensor is
// uninitialized. `args` and `kwargs` are unused.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned;
  {
    // The GIL is released for both the check and the assign call.
    eager_gil_scoped_release guard;
    const bool src_ready = self->tensor.initialized();
    PADDLE_ENFORCE_EQ(
        src_ready,
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));

    cloned = assign_ad_func(self->tensor);
  }
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

545 546
// Python binding for Tensor.retain_grads().
//
// When gradient recording is enabled (Controller::HasGrad()), makes sure the
// tensor has a grad node (creating a GradNodeAccumulation if it has none) and
// then calls RetainGradForTensor on it. Does nothing otherwise.
// `args` and `kwargs` are unused. Returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (egr::Controller::Instance().HasGrad()) {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!meta->GetMutableGradNode()) {
      // A leading space keeps the tensor name separated from the rest of
      // the log message.
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << " become accumulation node";
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

564 565
// Python binding for Tensor.clear_gradient([set_to_zero]).
//
// args[0] (optional): `set_to_zero`, defaults to true.
//
// Behaviour on the tensor's gradient, when it has an impl:
//   * SelectedRows grad: rows and value are cleared if the value is
//     initialized.
//   * DenseTensor grad, set_to_zero == true: the grad is filled with 0 and,
//     for leaf tensors, the accumulation node is marked as "fake empty".
//   * DenseTensor grad, set_to_zero == false: the grad's memory holder is
//     moved out (released).
// Returns None. Raises Fatal when a leaf tensor has a NULL grad pointer.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    // The message pieces carry explicit separators; adjacent string literals
    // are concatenated verbatim.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

627 628
// Python binding for Tensor._zero_grads(): sets the tensor's gradient to
// zero when that gradient is initialized.
//
// For a leaf tensor the gradient comes from EagerUtils::mutable_grad (a NULL
// pointer raises Fatal); for a non-leaf tensor it comes from the tensor's
// autograd meta. A DenseTensor gradient is zero-filled in place; any other
// gradient type has its impl replaced with the result of zeros_like.
// `args` and `kwargs` are unused. Returns None.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);

  eager_gil_scoped_release guard;
  paddle::Tensor* grad = nullptr;
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    // The message pieces carry explicit separators; adjacent string literals
    // are concatenated verbatim.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  // Both leaf and non-leaf gradients are zeroed the same way.
  if (grad->initialized()) {
    if (grad->is_dense_tensor()) {
      auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
      phi::funcs::set_constant(*dev_ctx, t, 0.0);
    } else {
      grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

672 673
// Python binding for Tensor._share_buffer_to(dst): makes the destination
// tensor (args[0]) share this tensor's buffer and data type.
//
// If the destination has no impl yet, an empty phi::DenseTensor is installed
// first. Throws InvalidArgument when this tensor is uninitialized.
// Returns None.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto* dst_object = reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor& dst = dst_object->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  if (!dst.defined()) {
    dst.set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* dst_dense = static_cast<phi::DenseTensor*>(dst.impl().get());
  dst_dense->ShareBufferWith(*src_dense);
  dst_dense->ShareDataTypeWith(*src_dense);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

696 697 698 699
// Python binding for Tensor._is_shared_buffer_with(dst): returns whether the
// tensor passed as args[0] shares its buffer with this tensor, as reported by
// DenseTensor::IsSharedBufferWith. Returns False when either tensor is
// undefined. Throws InvalidArgument when this tensor is uninitialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto* other_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor& other = other_object->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  const bool both_defined = self->tensor.defined() && other.defined();
  if (!both_defined) {
    return ToPyObject(false);
  }
  auto* self_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* other_dense = static_cast<phi::DenseTensor*>(other.impl().get());
  const bool shared = other_dense->IsSharedBufferWith(*self_dense);
  return ToPyObject(shared);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

719 720 721 722
// Python binding for Tensor._share_underline_tensor_to(dst): makes the tensor
// passed as args[0] use this tensor's impl. Throws InvalidArgument when this
// tensor is uninitialized. Returns None.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  auto* target_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor& target = target_object->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  target.set_impl(self->tensor.impl());
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

737 738 739 740
// Python binding: reports whether `self` and the tensor passed as the first
// positional argument hold the very same implementation object (pointer
// equality on impl()).
//
// Throws InvalidArgument if the argument tensor has not been initialized.
// Returns False when either tensor is undefined.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(src_tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        src_tensor.name()));
  const bool both_defined = self->tensor.defined() && src_tensor.defined();
  // Short-circuit keeps impl() from being compared for undefined tensors.
  const bool same_impl =
      both_defined && (self->tensor.impl().get() == src_tensor.impl().get());
  return ToPyObject(same_impl);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

757 758
// Python binding for `detach`: allocates a new Python tensor object that
// shares `self`'s implementation, gives it a freshly generated unique name,
// and copies the Persistable flag from `self`'s autograd meta.
//
// Throws InvalidArgument if `self` is undefined, and Fatal if the Python
// object cannot be allocated.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc only provides raw storage, so the C++ member is constructed in
  // place before it is used.
  new (&(detached->tensor)) paddle::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());

  auto src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto dst_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

785 786 787 788
// Python binding: exposes the phi::DenseTensor behind `self`.
//
// - Undefined tensor: an empty DenseTensor is handed to ToPyObject.
// - Dense tensor: the underlying DenseTensor is handed to ToPyObject.
// - Any other kind of tensor: returns None.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // The original `get_tensor` method of Variable will create a empty tensor
    // NOTE(review): `empty_tensor` is a stack local and only its address is
    // passed on; whether the returned Python object stays valid depends on
    // how ToPyObject treats the pointer - confirm.
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }
  if (!self->tensor.is_dense_tensor()) {
    RETURN_PY_NONE
  }
  auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  VLOG(6) << "tensor: " << dense->IsInitialized();
  return ToPyObject(dense);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

804 805 806 807 808
// Python binding: exposes the phi::SelectedRows behind `self`.
// Returns None when `self` is undefined or does not hold SelectedRows.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  // `defined()` is tested first so that is_selected_rows() is only queried on
  // a tensor that has an implementation.
  const bool holds_selected_rows =
      self->tensor.defined() && self->tensor.is_selected_rows();
  if (!holds_selected_rows) {
    RETURN_PY_NONE
  }
  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  return ToPyObject(rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

821 822 823 824 825 826 827 828 829 830 831 832 833 834
// Python binding: builds a new tensor from the value stored inside `self`'s
// SelectedRows. The new tensor gets a freshly generated unique name and an
// implementation created by copy-constructing a phi::DenseTensor from that
// value.
//
// Throws Fatal if `self` is not a SelectedRows tensor or if the SelectedRows
// is not initialized.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));

  auto* selected_rows =
      static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      selected_rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* dense_tensor =
      static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  VLOG(4) << "dense_tensor: " << dense_tensor->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  auto value_copy = std::make_shared<phi::DenseTensor>(*dense_tensor);
  result.set_impl(value_copy);

  return ToPyObject(result);

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
847 848 849
// Python binding behind `Tensor.__getitem__` for indices that are not
// tensors.
//
// args[0] is the index object. It is decomposed by ParseIndexingSlice into
// slice axes/starts/ends/strides, axes to drop (`decrease_axis`), positions
// of `None` entries (`none_axes`) and, when the index is a list, the selected
// indices (`list_select_idxs`, flagged by `list_select_flag`).
//
// Processing order:
//   1. If any slice axis was parsed, apply `slice` (all strides are 1) or
//      `strided_slice` (some stride differs from 1).
//   2. If the index contained `None`, unsqueeze the result and return it.
//   3. If the index is a list, apply `index_select` on axis 0 of `self`.
//
// Throws InvalidArgument if `self` is undefined.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags, list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  // With nothing to slice or select the result is `self` itself; otherwise
  // start from a fresh, uniquely named tensor that the ops below overwrite.
  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // `strided_slice` is only needed when at least one stride is not 1.
    bool use_strided_slice = false;
    for (int stride : slice_strides) {
      if (stride != 1) {
        use_strided_slice = true;
        break;
      }
    }

    if (use_strided_slice) {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
    } else {
      // slice_ad_func is called with 64-bit axes / flags.
      std::vector<int64_t> slice_axes_tmp(slice_axes.begin(),
                                          slice_axes.end());
      std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                           infer_flags.end());
      std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                             decrease_axis.end());
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    }
  }

  if (!none_axes.empty()) {
    paddle::Tensor new_out;
    {
      eager_gil_scoped_release guard;
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        // Shift each `None` position left by the number of dropped axes that
        // precede it.
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }
      new_out = unsqueeze_ad_func(out, none_axes);
    }
    return ToPyObject(new_out);
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    auto select_index =
        paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
    auto idx_tensor = std::make_shared<phi::DenseTensor>();
    select_index.set_impl(idx_tensor);
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
        egr::Controller::Instance().GetExpectedPlace());
    // Materialize the selected indices as a tensor on the expected place.
    paddle::framework::TensorFromVector(
        list_select_idxs, *dev_ctx, idx_tensor.get());
    out = index_select_ad_func(self->tensor, select_index, 0);
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

969 970
// Python binding: reads a single element of `self` and returns it as a
// one-element numpy array of the matching dtype.
//
// For a SelectedRows tensor the element is read from its value tensor;
// otherwise from the tensor's own implementation.
//
// The positional arguments select the element:
//   - no argument:    the tensor must contain exactly one element;
//   - one argument:   a flat offset, which must be < numel;
//   - otherwise:      one index per dimension; indices are combined using
//                     strides derived from the dims (last dimension has
//                     stride 1).
//
// Throws InvalidArgument if the tensor pointer is null, the tensor is not
// initialized, an index is out of range or the number of indices does not
// match the rank, and Unimplemented for a dtype not listed in
// PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE. Returns nullptr if numpy fails to
// allocate the result array.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  phi::DenseTensor* ptr = nullptr;
  if (self->tensor.is_selected_rows()) {
    auto* selected_rows =
        static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  } else {
    ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  }
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> strides(tensor_dims.size());

  // Walk the dims from the innermost outwards: strides[i] is the number of
  // elements spanned by one step along axis i, numel ends up as the total.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    strides[i] = numel;
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  if (PyTuple_Size(args) == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (PyTuple_Size(args) == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(PyTuple_Size(args),
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * strides[i];
    }
  }
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// For the dtype that matches, fetch the element at `offset`, allocate a
// one-element numpy array of that dtype and copy the element into it.
// If numpy cannot allocate the array it returns NULL; that NULL is forwarded
// to the caller instead of being dereferenced.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];                  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];               \
    py_dims[0] = 1;                                                          \
    py_strides[0] = 1;                                                       \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        1,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    if (array == nullptr) {                                                  \
      return nullptr;                                                        \
    }                                                                        \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129
// Python binding behind `Tensor.__setitem__`.
//
// args[0] is the index, args[1] the value to assign.
//
// Fast path: when every element of the index (wrapped into a tuple if it is
// not one already) is an integer, a slice, Ellipsis or None, the index is
// parsed by ParseIndexingSlice and the assignment is performed by the
// in-place `set_value_` operator. The value may be:
//   - a paddle Tensor (cast to self's dtype if the dtypes differ);
//   - a numpy array (converted to self's dtype first; supported dtypes are
//     bool, float32, float64, complex64, complex128, int32, int64);
//   - a Python float / int / bool / complex scalar (passed through the
//     "values" attribute; float16 is additionally supported here).
//
// Slow path: for any other index, `self` is converted to a numpy array, the
// assignment is done by numpy, and the result is written back into `self`.
//
// Throws InvalidArgument for unsupported dtypes or value types, and when
// gradients are enabled and `self` is a leaf tensor that does not stop
// gradient. Returns None.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  // Release the temporary tuple on every exit path, but only if one was
  // actually created above.
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed.
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags, list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::egr_utils_api::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      // Convert the numpy array to self's dtype unless it already matches.
      if (self->tensor.dtype() == phi::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
        if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<float>>(
              value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
        if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<double>>(
              value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, complex64, complex128, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp) ||
          PyComplex_Check(value_obj)) {
        if (self->tensor.dtype() == phi::DataType::FLOAT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() == phi::DataType::INT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() == phi::DataType::INT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() == phi::DataType::BOOL) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<float>>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<double>>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, complex64, complex128, int32, int64 or "
              "float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type by name; passing the PyTypeObject* itself to
        // "%s" would not produce a readable type name.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float, complex  or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        // If the dtypes still differ after the AMP cast, cast the value to
        // self's dtype.
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    auto self_numpy = TensorToPyArray(*self_tensor);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    // Write the modified numpy array back into self. An uninitialized tensor
    // has no place of its own, so a build-dependent default place is used.
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1329 1330
// Python binding: registers a Python callable (args[0]) as a gradient hook
// for `self` and returns the hook id produced by RegisterGradientHook.
//
// Leaf tensor: if the tensor does not stop gradient and has no grad node yet,
// a GradNodeAccumulation is created first; the hook is then registered on
// that accumulation node.
// Non-leaf tensor: the hook is registered on the tensor's existing grad node.
//
// Throws InvalidArgument when a leaf tensor has no autograd meta or no
// GradNodeAccumulation to attach the hook to, and Fatal when a non-leaf
// tensor has no grad node, instead of dereferencing a null pointer.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  PyObject* hook_func = PyTuple_GET_ITEM(args, 0);
  int64_t hook_id = 0;
  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }
    PADDLE_ENFORCE_NOT_NULL(
        autograd_meta,
        platform::errors::InvalidArgument(
            "Cannot register gradient hook on Tensor %s because it has no "
            "autograd meta.",
            self->tensor.name()));

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // The cast yields an empty pointer when there is no grad node (e.g. the
    // block above did not create one) or the node has a different type.
    PADDLE_ENFORCE_NOT_NULL(
        accumulation_grad_node.get(),
        platform::errors::InvalidArgument(
            "Cannot register gradient hook on leaf Tensor %s because it has "
            "no grad_node with type GradNodeAccumulation. Please check "
            "whether the Tensor stops gradient.",
            self->tensor.name()));

    auto rank_info = autograd_meta->OutRankInfo();
    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    PADDLE_ENFORCE_NOT_NULL(
        grad_node.get(),
        platform::errors::Fatal(
            "Detected NULL grad_node for non leaf Tensor %s.",
            self->tensor.name()));
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1379 1380
// Python binding: removes a previously registered gradient hook from `self`'s
// grad node. args[0] is the hook id; the result of RemoveGradientHook is
// converted to a Python object and returned.
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);

  PyObject* id_arg = PyTuple_GET_ITEM(args, 0);
  int64_t hook_id = pybind::CastPyArg2AttrLong(id_arg, 0);

  auto removed = node->RemoveGradientHook(hook_id);
  return ToPyObject(removed);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1393 1394
// Python binding (exposed as `_register_backward_hook`): attaches a Python
// callable as a reduce hook on the GradNodeAccumulation of a leaf tensor.
//
// args[0]: the Python callable; it is wrapped in a PyVoidHook.
// Raises InvalidArgument if the tensor is not a leaf or if its autograd meta
// reports StopGradient(); raises Fatal if the leaf has no grad node.
// Returns None.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  // Precondition 1: only leaf tensors accept this hook.
  PADDLE_ENFORCE_EQ(egr::egr_utils_api::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  // Precondition 2: the tensor must take part in gradient computation.
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  // Precondition 3: the leaf must already own a grad node.
  PADDLE_ENFORCE(
      grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  PyObject* py_callable = PyTuple_GET_ITEM(args, 0);
  std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node)
      ->RegisterReduceHook(std::make_shared<PyVoidHook>(py_callable));

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1428 1429
// Python binding: replaces the implementation of this tensor's gradient with
// a freshly constructed, empty holder of the requested variable type.
//
// args[0]: a proto VarType value (converted with CastPyArg2ProtoType).
//   LOD_TENSOR    -> the grad gets a new phi::DenseTensor
//   SELECTED_ROWS -> the grad gets a new phi::SelectedRows
//   anything else -> the grad is left untouched
// Returns None.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto var_type = pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();
  switch (var_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      // Other variable types are intentionally ignored.
      break;
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1445 1446
// Python binding: calls reset() on the wrapped tensor and returns None.
// No arguments are read.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1455 1456
// Python binding: makes this tensor's gradient share the implementation of
// the tensor passed as args[0].
//
// When this tensor is initialized, the source must be defined and must match
// it in data type and in implementation type; otherwise
// InvalidArgument / PreconditionNotMet is raised. When this tensor has a
// mutable grad, the source must be initialized, and the grad's impl is set to
// the source's impl. Returns None.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // The comparisons below dereference src.impl(); reject an undefined
    // source up front rather than dereferencing a null impl.
    PADDLE_ENFORCE_EQ(src.defined(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    // The gradient shares src's implementation object; nothing is copied
    // here beyond the impl handle.
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1488

1489 1490 1491
// Python binding (exposed as `_tensor_use_gpudnn`): returns a tensor whose
// dense meta has `use_gpudnn` set to the boolean given in args[0].
//
// If the flag already has the requested value, the wrapped tensor itself is
// converted and returned. Otherwise a new DenseTensor is built that shares
// data with the original (ShareDataWith) and carries a copy of the meta with
// only `use_gpudnn` changed; the result reuses the original name and
// autograd meta. Raises Fatal for tensors that are undefined or not dense.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  const bool use_gpudnn =
      pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  phi::DenseTensorMeta* src_meta =
      phi::DenseTensorUtils::GetMutableMeta(src_dense);

  // Fast path: the attribute already matches, hand back the same tensor.
  if (src_meta->use_gpudnn == use_gpudnn) {
    return ToPyObject(self->tensor);
  }

  // Clone the meta and flip only use_gpudnn; everything else is shared.
  phi::DenseTensorMeta new_meta = *src_meta;
  new_meta.use_gpudnn = use_gpudnn;

  phi::DenseTensor shared_dense;
  shared_dense.ShareDataWith(*src_dense);
  shared_dense.set_meta(new_meta);

  // Wrap the result under the original tensor's name and autograd meta.
  paddle::Tensor target_tensor(
      std::make_shared<phi::DenseTensor>(shared_dense), self->tensor.name());
  target_tensor.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << target_tensor.name()
          << " set use_gpudnn = " << use_gpudnn;

  return ToPyObject(target_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1526 1527
// Python binding: stores a Vocab (converted from args[0] with
// CastPyArg2Vocab) into a new VariableCompatTensor and installs that as this
// tensor's implementation. Returns None.
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Vocab>();
  *slot = vocab;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding: stores a list of strings (converted from args[0] with
// CastPyArg2VectorOfString) into a new VariableCompatTensor and installs that
// as this tensor's implementation. Returns None.
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Strings>();
  *slot = strings;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding: returns the Vocab held by this tensor's
// VariableCompatTensor implementation, converted to a Python object.
// Raises Fatal when the tensor is not a VariableCompatTensor.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      egr::IsVariableCompatTensor(self->tensor),
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  // The check above makes the downcast of the impl pointer valid.
  const auto* compat_impl =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  return ToPyObject(compat_impl->Get<paddle::framework::Vocab>());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588
// Python binding (exposed as `nnz`): returns nnz() of a sparse tensor.
// Raises Fatal unless the tensor is a SparseCooTensor or a SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  // Past the check, "not COO" means CSR.
  if (!self->tensor.is_sparse_coo_tensor()) {
    auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    return ToPyObject(csr->nnz());
  }
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  return ToPyObject(coo->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1589 1590 1591 1592 1593 1594 1595 1596 1597
// Python binding (exposed as `indices`): wraps a DenseTensor constructed from
// non_zero_indices() of a SparseCooTensor in a new paddle::Tensor and
// returns it. Raises Fatal for any other tensor kind.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_impl =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::Tensor tensor(indices_impl);
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding (exposed as `values`): wraps a DenseTensor constructed from
// non_zero_elements() of a sparse tensor in a new paddle::Tensor and returns
// it. Raises Fatal unless the tensor is a SparseCooTensor or SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  // Past the check, "not COO" means CSR.
  if (!self->tensor.is_sparse_coo_tensor()) {
    auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    paddle::Tensor tensor(
        std::make_shared<phi::DenseTensor>(csr->non_zero_elements()));
    return ToPyObject(tensor);
  }
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  paddle::Tensor tensor(
      std::make_shared<phi::DenseTensor>(coo->non_zero_elements()));
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding (exposed as `crows`): wraps a DenseTensor constructed from
// non_zero_crows() of a SparseCsrTensor in a new paddle::Tensor and returns
// it. Raises Fatal for any other tensor kind.
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  paddle::Tensor tensor(crows_impl);
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding (exposed as `cols`): wraps a DenseTensor constructed from
// non_zero_cols() of a SparseCsrTensor in a new paddle::Tensor and returns
// it. Raises Fatal for any other tensor kind.
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto cols_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_cols());
  paddle::Tensor tensor(cols_impl);
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1659 1660
// Python binding: True when the tensor is defined and is a dense tensor;
// an undefined tensor yields False.
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_dense_tensor() from running on an undefined tensor.
  const bool result = self->tensor.defined() && self->tensor.is_dense_tensor();
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1670 1671
// Python binding: True when the tensor is defined and is either a sparse COO
// or a sparse CSR tensor; an undefined tensor yields False.
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the kind queries from running on an undefined tensor.
  const bool result =
      self->tensor.defined() && (self->tensor.is_sparse_coo_tensor() ||
                                 self->tensor.is_sparse_csr_tensor());
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1682 1683
// Python binding: True when the tensor is defined and is a sparse COO tensor;
// an undefined tensor yields False.
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the kind query from running on an undefined tensor.
  const bool result =
      self->tensor.defined() && self->tensor.is_sparse_coo_tensor();
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1693 1694
// Python binding: True when the tensor is defined and is a sparse CSR tensor;
// an undefined tensor yields False.
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the kind query from running on an undefined tensor.
  const bool result =
      self->tensor.defined() && self->tensor.is_sparse_csr_tensor();
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1704 1705
// Python binding: converts the tensor with to_sparse_csr() and copies the
// StopGradient and Persistable flags from the source's autograd meta onto the
// result's autograd meta. Returns the converted tensor.
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto csr_tensor = self->tensor.to_sparse_csr();

  // Result meta is fetched first, then the source meta, matching the
  // evaluation order of the chained form of these calls.
  auto* dst_meta = egr::EagerUtils::autograd_meta(&csr_tensor);
  auto* src_meta = egr::EagerUtils::autograd_meta(&self->tensor);
  dst_meta->SetStopGradient(src_meta->StopGradient());

  dst_meta = egr::EagerUtils::autograd_meta(&csr_tensor);
  src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return ToPyObject(csr_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1719 1720 1721 1722 1723 1724 1725 1726 1727
// Python binding: compares shape() of this tensor with shape() of the tensor
// given as args[0] and returns the boolean result.
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool same_shape = (self->tensor.shape() == other.shape());
  return ToPyObject(same_shape);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1728 1729
// Python binding: returns current_inplace_version() of the wrapped tensor
// as a Python object (held as uint32_t before conversion).
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1738 1739
// Python binding: returns phi::SizeOf() of the tensor's dtype, narrowed to
// uint32_t before conversion to a Python object.
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const auto dtype = self->tensor.dtype();
  const uint32_t bytes_per_element = phi::SizeOf(dtype);
  return ToPyObject(bytes_per_element);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1748 1749 1750 1751 1752
// Python binding: calls bump_inplace_version() on the wrapped tensor and
// returns None. No arguments are read.
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.bump_inplace_version();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1757 1758 1759 1760
// Python binding: True when the tensor is defined and is a SelectedRows;
// an undefined tensor yields False.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the kind query from running on an undefined tensor.
  const bool result =
      self->tensor.defined() && self->tensor.is_selected_rows();
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1768 1769
// Python binding (exposed as `rows`): returns rows() of a SelectedRows
// tensor converted to a Python object. Raises Fatal for other tensor kinds.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  // The kind check above precedes the downcast of the impl pointer.
  auto rows_holder =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  return ToPyObject(rows_holder->rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1781 1782 1783 1784 1785 1786 1787 1788 1789 1790
// Python binding: calls reset_inplace_version(set_to_zero) on this tensor's
// gradient, but only when the gradient exists, is defined, is a dense tensor
// and is initialized; otherwise nothing happens.
//
// args: optionally one boolean `set_to_zero`. It is read only when exactly
// one positional argument is passed; the default is true.
// Returns None.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  const bool set_to_zero =
      (PyTuple_Size(args) == static_cast<Py_ssize_t>(1))
          ? CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0)
          : true;

  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  const bool grad_is_usable = grad != nullptr && grad->defined() &&
                              grad->is_dense_tensor() && grad->initialized();
  if (!grad_is_usable) {
    RETURN_PY_NONE
  }
  grad->reset_inplace_version(set_to_zero);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1801 1802
// Python binding: moves a CPU DenseTensor's data into a memory-mapped
// writer allocation and returns the underlying DenseTensor.
//
// Non-Windows builds: requires a CPU tensor whose impl is a DenseTensor
// (InvalidArgument otherwise). The data is copied into a new allocation
// obtained from AllocateMemoryMapWriterAllocation, the allocation's ipc name
// is recorded in MemoryMapFdSet, and the tensor's holder is replaced with
// the new allocation.
// Windows builds: always raises PermissionDenied.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get DenseTensor
  auto* t =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl()).get();
  // The cast yields null for any impl that is not a DenseTensor; fail with a
  // readable error instead of dereferencing it below.
  PADDLE_ENFORCE_EQ(t != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support DenseTensor currently"));
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  RETURN_PY_NONE

#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1840 1841
// Python binding: returns offset() of the tensor's DenseTensor
// implementation.
//
// Raises InvalidArgument when the impl is not a DenseTensor (the downcast
// would otherwise produce a null pointer that gets dereferenced) or when the
// DenseTensor is not initialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  PADDLE_ENFORCE_EQ(
      t != nullptr,
      true,
      platform::errors::InvalidArgument(
          "_offset is only supported for DenseTensor, but Tensor %s is not "
          "a DenseTensor.",
          self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1855 1856
// Python binding: returns name() of this tensor's gradient.
// Raises InvalidArgument when mutable_grad() yields a null pointer.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  // Safe to dereference: the null case was rejected just above.
  const auto& grad_name = grad->name();
  return ToPyObject(grad_name);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1869 1870
// Python binding: returns the DenseTensor behind this tensor's gradient.
//
// Raises InvalidArgument when mutable_grad() yields a null pointer.
// Returns None when the gradient is not defined. Raises Fatal when the
// gradient is defined but is not a dense tensor.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  // Nothing to expose yet.
  if (!grad->defined()) {
    RETURN_PY_NONE
  }
  // Only the dense representation can be handed out.
  if (!grad->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
  }
  auto* dense_grad = static_cast<phi::DenseTensor*>(grad->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1894 1895
// Python binding: for a leaf tensor, calls SetFakeEmpty(false) on its
// GradNodeAccumulation; for non-leaf tensors nothing is changed.
//
// Raises InvalidArgument when mutable_grad() yields a null pointer.
// A leaf tensor without a grad node is left untouched instead of
// dereferencing a null node. Returns None.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  if (is_leaf) {
    auto accumulation_grad_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    // A leaf tensor is not guaranteed to own a grad node
    // (tensor_register_reduce_hook checks for the same case), so only touch
    // the node when it exists.
    if (accumulation_grad_node != nullptr) {
      accumulation_grad_node->SetFakeEmpty(false);
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1915 1916 1917 1918 1919
// Python binding: returns the address given by DenseTensor::data() as an
// int64 for an initialized dense tensor; returns None in every other case.
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  const bool has_dense_data =
      self->tensor.initialized() && self->tensor.is_dense_tensor();
  if (!has_dense_data) {
    RETURN_PY_NONE
  }
  auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The pointer value itself is what gets exported to Python.
  return ToPyObject(reinterpret_cast<int64_t>(dense->data()));  // NOLINT
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943
// Python binding: returns this tensor's gradient when the tensor has
// autograd meta and the gradient is initialized; returns None otherwise.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // meta may be null, so evaluate the flag with the null check first and log
  // the flag rather than dereferencing meta inside the log statement.
  const bool grad_initialized = meta && meta->Grad().initialized();
  VLOG(6) << meta << " initialized: " << grad_initialized;
  if (grad_initialized) {
    return ToPyObject(meta->Grad());
  } else {
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1944
#if defined(PADDLE_WITH_CUDA)
// Python binding (exposed as `_tensor_uva`, CUDA builds only): passes the
// tensor's DenseTensor and the device id from args[0] to tensor_uva().
//
// Raises InvalidArgument unless the tensor is a dense tensor placed on CPU.
// Returns None.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));

  int device_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
  // The dense-tensor check above precedes this downcast.
  auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  tensor_uva(dense, device_id);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980
// Python binding: True when the tensor's impl is a phi::StringTensor that
// reports initialized(); False when the impl is not a StringTensor.
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto as_string_tensor =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  // A failed downcast yields null, which short-circuits to false.
  const bool holds_allocation =
      as_string_tensor != nullptr && as_string_tensor->initialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1981

1982
// Method table for the eager Tensor Python type. Each entry is
// {python name, C entry point, calling convention flags, docstring}. All
// entries accept positional and keyword arguments and carry no docstring.
// The table ends with an all-null sentinel entry.
PyMethodDef variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_dense_tensor_hold_allocation",
     (PyCFunction)(void (*)(void))
         tensor_method__is_dense_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_to",
     (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"copy_",
     (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clone",
     (PyCFunction)(void (*)(void))tensor_method_clone,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"reconstruct_from_",
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"retain_grads",
     (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clear_gradient",
     (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_dense",
     (PyCFunction)(void (*)(void))tensor_method_is_dense,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_zero_grads",
     (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_buffer_to",
     (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_buffer_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_underline_tensor_to",
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_underline_tensor_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"detach",
     (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_tensor_from_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method__get_tensor_from_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_from_offset",
     (PyCFunction)(void (*)(void))tensor__getitem_from_offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__setitem_eager_tensor__",
     (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_remove_grad_hook",
     (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_set_grad_type",
     (PyCFunction)(void (*)(void))tensor__set_grad_type,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_clear",
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_gradient_from",
     (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_tensor_use_gpudnn",
     (PyCFunction)(void (*)(void))tensor__use_gpudnn,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    /** Methods kept to adapt old dygraph; to be removed in the future. **/
    {"set_string_list",
     (PyCFunction)(void (*)(void))tensor_method_set_string_list,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_vocab",
     (PyCFunction)(void (*)(void))tensor_method_set_vocab,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_map_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    /*** Start of the sparse tensor method section. ***/
    {"nnz",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"indices",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"values",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"crows",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"cols",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_coo",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_same_shape",
     (PyCFunction)(void (*)(void))tensor_method_is_same_shape,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"to_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"element_size",
     (PyCFunction)(void (*)(void))tensor_method_element_size,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    /*** End of the sparse tensor method section. ***/
    {"_inplace_version",
     (PyCFunction)(void (*)(void))tensor__inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_bump_inplace_version",
     (PyCFunction)(void (*)(void))tensor__bump_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"rows",
     (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_memory",
     (PyCFunction)(void (*)(void))tensor_method__share_memory,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_offset",
     (PyCFunction)(void (*)(void))tensor__offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_name",
     (PyCFunction)(void (*)(void))tensor__grad_name,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_value",
     (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_unset_fake_empty",
     (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"data_ptr",
     (PyCFunction)(void (*)(void))tensor_data_ptr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_ivar",
     (PyCFunction)(void (*)(void))tensor__grad_ivar,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#if defined(PADDLE_WITH_CUDA)
    // Only available in CUDA builds, matching the guard on the function.
    {"_tensor_uva",
     (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#endif
    // Sentinel entry terminating the table.
    {nullptr, nullptr, 0, nullptr}};

J
Jack Zhou 已提交
2215 2216 2217 2218
// Method table for core.eager.StringTensor. Each entry is
// {python name, C entry point, calling convention flags, docstring}; the
// table ends with an all-null sentinel entry.
PyMethodDef string_tensor_variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_string_tensor_hold_allocation",
     (PyCFunction)(void (*)(void))
         tensor_method__is_string_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    // Sentinel entry terminating the table.
    {nullptr, nullptr, 0, nullptr}};

2233 2234
}  // namespace pybind
}  // namespace paddle