eager_method.cc 92.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
12 13 14 15 16 17

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

18
#include <Python.h>
19 20 21 22
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
23 24

#include <string>
25
#include <unordered_map>
26 27
#include <vector>

28
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
29
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
30
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
31
#include "paddle/fluid/eager/autograd_meta.h"
32 33
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
34
#include "paddle/fluid/eager/utils.h"
35
#include "paddle/fluid/framework/convert_utils.h"
36
#include "paddle/fluid/framework/string_array.h"
37 38 39 40 41 42
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
43
#include "paddle/fluid/pybind/slice_utils.h"
44
#include "paddle/fluid/pybind/uva_utils.h"
45 46 47 48
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
49 50
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
W
wanghuancoder 已提交
51
#include "pybind11/detail/internals.h"
52 53
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
W
wanghuancoder 已提交
54
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
J
Jiabin Yang 已提交
55
#include "paddle/fluid/eager/amp_utils.h"
56
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
J
Jiabin Yang 已提交
57
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
W
wanghuancoder 已提交
58
#include "paddle/fluid/framework/python_headers.h"
W
wanghuancoder 已提交
59
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
W
wanghuancoder 已提交
60
#include "paddle/fluid/pybind/tensor_py.h"
W
wanghuancoder 已提交
61
#include "paddle/phi/api/lib/data_transform.h"
W
wanghuancoder 已提交
62
#include "paddle/phi/core/ddim.h"
63
#include "paddle/phi/core/flags.h"
64
#include "paddle/phi/core/tensor_utils.h"
65
#include "paddle/phi/kernels/funcs/math_function.h"
L
LiYuRio 已提交
66 67 68
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
J
Jiabin Yang 已提交
69

70
// Runtime flags defined in another translation unit.
// FLAGS_set_to_1d is read by tensor_method_numpy to decide whether a 0-D
// tensor is exported as a 1-element 1-D numpy array.
PHI_DECLARE_bool(set_to_1d);
DECLARE_bool(use_stride_kernel);
72

73 74 75
namespace paddle {
namespace pybind {

76 77
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
78
                                     const paddle::platform::Place& place,
79
                                     bool zero_copy);
80

81
extern PyTypeObject* p_tensor_type;
82

J
Jiabin Yang 已提交
83
// Converts a Python object used as a slice bound into a Py_ssize_t.
//
// Only eager Tensor objects (instances of p_tensor_type) are accepted. The
// tensor must be initialized and backed by a phi::DenseTensor; its value is
// read through GetSliceIndexFromTensor.
//
// Raises InvalidArgument if `obj` is not a Tensor, if the Tensor is
// uninitialized, or if its impl is not a DenseTensor.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (!PyObject_TypeCheck(obj, p_tensor_type)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }

  VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
  // Convert once and reuse the result (the object used to be converted twice).
  paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
  PADDLE_ENFORCE_EQ(
      tensor.initialized(),
      true,
      paddle::platform::errors::InvalidArgument(
          "We can only support initialized tensor in slice, however we got "
          "uninitialized tensor %s, please check your code.",
          tensor.name()));
  // The static_cast below is only valid for a DenseTensor impl; any other
  // impl type would be undefined behavior, so reject it explicitly.
  PADDLE_ENFORCE_EQ(
      tensor.is_dense_tensor(),
      true,
      paddle::platform::errors::InvalidArgument(
          "We can only support DenseTensor in slice, however tensor %s is "
          "not a DenseTensor.",
          tensor.name()));
  return GetSliceIndexFromTensor(
      *static_cast<phi::DenseTensor*>(tensor.impl().get()));
}

103 104
static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
105 106
                                     PyObject* kwargs) {
  EAGER_TRY
W
wanghuancoder 已提交
107 108 109 110 111 112 113 114 115
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
116 117 118 119 120
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
W
wanghuancoder 已提交
121 122 123 124 125
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
126 127
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
128
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
129 130
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
131
  size_t py_rank = tensor_dims.size();
132
  size_t numel = 1;
133
  if (py_rank == 0) {
134
    Py_ssize_t args_num = PyTuple_Size(args);
135 136
    // true by default
    bool set_to_1d = FLAGS_set_to_1d;
137 138 139 140 141 142 143
    if (args_num == (Py_ssize_t)1) {
      PyObject* obj = PyTuple_GET_ITEM(args, 0);
      if (obj == Py_False) {
        set_to_1d = false;
      }
    }
    if (set_to_1d) {
144
      // 0D Tensor hack process to 1D numpy, will remove in release 2.6
145 146 147 148 149
      VLOG(0)
          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
             "order to avoid this problem, "
             "0D Tensor will be changed to 1D numpy currently, but it's not "
             "correct and will be "
150 151
             "removed in release 2.6. For Tensor contain only one element, "
             "Please "
152
             "modify "
153
             " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
154
             "possible, "
155
             "otherwise 'Tensor.numpy()[0]' will raise error in release 2.6.";
156 157 158 159
      py_rank = 1;
      py_dims[0] = 1;
      py_strides[0] = sizeof_dtype * numel;
    }
W
wanghuancoder 已提交
160 161 162 163 164 165 166 167
  } else if (self->tensor.is_dense_tensor()) {
    auto tensor_stride = self->tensor.strides();

    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * tensor_stride[i];
      numel *= py_dims[i];
    }
168 169 170 171 172 173
  } else {
    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * numel;
      numel *= py_dims[i];
    }
174
  }
W
wanghuancoder 已提交
175 176

  if (!self->tensor.impl()->initialized()) {
W
wanghuancoder 已提交
177 178 179 180 181 182 183 184 185 186 187
    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        py_rank,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);

188
    if (tensor_dims.empty()) {
189 190 191
      py_dims[0] = 0;
      py_strides[0] = 0;
      PyObject* array = api.PyArray_NewFromDescr_(
192 193 194 195 196 197
          api.PyArray_Type_,
          api.PyArray_DescrFromType_(numpy_dtype),
          1,
          py_dims,
          py_strides,
          nullptr,
198 199 200 201 202
          pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
              pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
          nullptr);
      return array;
    }
W
wanghuancoder 已提交
203 204 205
    return array;
  }

W
wanghuancoder 已提交
206 207 208
  phi::DenseTensor cpu_tensor;
  platform::CPUPlace cpu_place;

209
  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
W
wanghuancoder 已提交
210
    eager_gil_scoped_release guard;
211
    platform::CPUPlace place;
212 213 214 215
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
216 217
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
218 219 220 221 222
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
223
      // deep copy
W
wanghuancoder 已提交
224 225 226 227 228
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
229 230 231 232
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
233 234 235 236 237
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
238
      // deep copy
W
wanghuancoder 已提交
239 240 241 242 243
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
244 245
    }

246
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
247
  } else if (self->tensor.is_gpu()) {
W
wanghuancoder 已提交
248
    eager_gil_scoped_release guard;
249 250 251 252 253
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
254 255 256 257
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
258 259
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
260 261 262 263 264 265 266 267 268
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
269 270 271 272
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
273 274 275 276 277 278 279 280 281
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
282
    }
283
#endif
C
Chen Weihang 已提交
284 285 286 287 288 289 290
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
291 292
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
293 294 295 296 297 298 299 300 301 302
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
303 304 305 306
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
307 308 309 310 311 312 313 314 315 316
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
317 318
    }
#endif
319 320
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
W
wanghuancoder 已提交
321
    eager_gil_scoped_release guard;
322 323 324 325
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
326 327
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
328 329 330 331 332
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
333
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
334 335 336
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
337 338 339 340
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
C
co63oc 已提交
341
      // TODO(qili93): temporary for ascend npu performance to be removed along
342
      // with npu_identity op
343
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
344 345 346 347 348
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
W
wanghuancoder 已提交
349 350 351 352 353
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
354
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
355 356 357
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
358 359
    }
#endif
360 361 362
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
363
    RETURN_PY_NONE
364 365
  }

W
wanghuancoder 已提交
366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385
  void* array_buffer = cpu_tensor.Holder()->ptr();
  size_t array_offset = cpu_tensor.offset();

  PyObject* base = ToPyObject(paddle::Tensor(
      std::make_shared<phi::DenseTensor>(std::move(cpu_tensor))));

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      py_rank,
      py_dims,
      py_strides,
      reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(array_buffer) +
                              array_offset),
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);

  api.PyArray_SetBaseObject_(array, base);

386 387 388 389
  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
390 391 392 393 394 395 396 397 398 399 400 401 402 403 404
// Implements `numpy()` for StringTensor: returns a numpy unicode array of
// dtype "U<n>", where n is the longest element's unicode length (at least 1).
//
// An absent or uninitialized impl yields an empty 1-D unicode array. Only
// CPU string tensors are supported; any other place raises InvalidArgument.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();

    // Unicode length of every element, computed once. It is used both to size
    // the numpy dtype and to decode each element. This also makes numel == 0
    // safe; the previous std::max_element call dereferenced an empty range.
    std::vector<size_t> unicode_lens(static_cast<size_t>(numel));
    size_t max_unicode_length = 0;
    for (int64_t i = 0; i < numel; ++i) {
      unicode_lens[i] =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      max_unicode_length = std::max(max_unicode_length, unicode_lens[i]);
    }
    // The dtype width must be at least 1 ("U0" is not used).
    max_unicode_length = (max_unicode_length == 0) ? 1 : max_unicode_length;
    VLOG(6) << "The max unicode length is " << max_unicode_length;

    // make_unique<T[]> value-initializes, so the padding after each shorter
    // string is already zero (no memset needed).
    auto sp = std::make_unique<uint32_t[]>(max_unicode_length * numel);
    auto* py_array_data = sp.get();
    for (int64_t i = 0; i < numel; ++i) {
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  unicode_lens[i]);
    }
    // No base handle is given, so pybind11 copies the buffer; `sp` may be
    // released when this scope ends.
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

460 461 462 463
// Returns, as a Python bool, whether `self->tensor` reports itself as
// initialized.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_initialized = self->tensor.initialized();
  return ToPyObject(is_initialized);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
468 469 470 471 472 473 474 475 476 477 478 479 480 481
// Returns True only if the tensor's impl is a phi::DenseTensor whose
// IsInitialized() is true. Any other impl type, or no impl, gives False.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const bool holds_allocation = (dense != nullptr) && dense->IsInitialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

482
static void IncreaseTensorReferenceCountUntilCopyComplete(
483
    const paddle::Tensor& tensor, const platform::Place& place) {
484 485 486 487 488 489 490 491
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
C
co63oc 已提交
492
  // CUDAPinned Mem -> CUDA by cudaMemcpyAsync.
493 494 495 496 497 498 499
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

500 501
// Returns a copy of `self->tensor` on the place given as positional arg 0.
// Positional arg 1 is `blocking`. For a non-blocking copy the source is also
// registered via IncreaseTensorReferenceCountUntilCopyComplete. The copy gets
// stop_gradient = true and inherits the source's persistable flag. The copy
// and the autograd bookkeeping run with the GIL released.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  const auto dst_place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  const bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);

  paddle::Tensor result;
  {
    eager_gil_scoped_release guard;
    result = self->tensor.copy_to(dst_place, blocking);
    if (!blocking) {
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, dst_place);
    }
    auto* result_meta = egr::EagerUtils::autograd_meta(&result);
    result_meta->SetStopGradient(true);
    result_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  }
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

522 523
// Returns a blocking copy of `self->tensor` on phi::CPUPlace. The copy gets
// stop_gradient = true and the source's persistable flag. The work runs with
// the GIL released.
static PyObject* tensor_method_cpu(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor host_copy = [self]() {
    eager_gil_scoped_release guard;
    paddle::Tensor copied =
        self->tensor.copy_to(phi::CPUPlace(), /*blocking=*/true);
    auto* copied_meta = egr::EagerUtils::autograd_meta(&copied);
    copied_meta->SetStopGradient(true);
    copied_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
    return copied;
  }();
  return ToPyObject(host_copy);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

539 540 541 542
// Assigns the Tensor given as positional arg 0 to `self->tensor` and then
// restores the name `self->tensor` had before the assignment. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor source = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  // Copy (not reference): self->tensor is overwritten below.
  const std::string kept_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from" << source.name() << " to "
          << kept_name;

  self->tensor = source;
  self->tensor.set_name(kept_name);  // Recover source name

  VLOG(6) << "Finished Reconstructing Tensor from" << source.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

559 560
// In-place copy: copies the Tensor given as positional arg 0 (`src`) into
// `self->tensor`. Positional arg 1 is `blocking`.
//
// - If `self->tensor` is uninitialized, it first takes src's stop_gradient
//   and persistable flags, then (if src is initialized) copies onto src's
//   place.
// - Otherwise, if src is initialized, it copies onto self's current place.
// Nothing is copied when src is uninitialized. Copies run with the GIL
// released. Returns None.
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();

  const bool dst_ready = self->tensor.initialized();
  if (!dst_ready) {
    eager_gil_scoped_release guard;
    auto* dst_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    auto* src_meta = egr::EagerUtils::autograd_meta(&(src_tensor));
    dst_meta->SetStopGradient(src_meta->StopGradient());
    dst_meta->SetPersistable(src_meta->Persistable());
    if (src_tensor.initialized()) {
      self->tensor.copy_(src_tensor, src_tensor.place(), blocking);
    }
  } else if (src_tensor.initialized()) {
    eager_gil_scoped_release guard;
    self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
  }

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

592 593 594 595
// Returns the result of assign_ad_func(self->tensor), computed with the GIL
// released. Raises InvalidArgument if the tensor is not initialized.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned = [self]() {
    eager_gil_scoped_release guard;
    PADDLE_ENFORCE_EQ(
        self->tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));
    return assign_ad_func(self->tensor);
  }();
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

613 614
// Calls RetainGradForTensor on `self->tensor`, but only when
// egr::Controller::HasGrad() is true; otherwise this is a no-op. If the
// tensor has no grad node yet, a GradNodeAccumulation built from its autograd
// meta is installed first. Runs with the GIL released. Returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (!egr::Controller::Instance().HasGrad()) {
    RETURN_PY_NONE
  }
  {
    eager_gil_scoped_release guard;
    auto* autograd = egr::EagerUtils::autograd_meta(&(self->tensor));
    const bool lacks_grad_node = !autograd->GetMutableGradNode();
    if (lacks_grad_node) {
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << "become accumulation node";
      autograd->SetGradNode(
          std::make_shared<egr::GradNodeAccumulation>(autograd));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

632 633
// Clears the gradient of `self->tensor`.
//
// Optional positional arg 0: set_to_zero (default true).
// - Leaf tensor: the grad comes from EagerUtils::mutable_grad and must be
//   non-null (Fatal error otherwise). Non-leaf: from the autograd meta.
// - SelectedRows grad: rows and value are cleared if the value is
//   initialized. This happens regardless of set_to_zero.
// - Initialized DenseTensor grad: filled with 0 when set_to_zero; for a leaf
//   tensor its GradNodeAccumulation is also marked fake-empty. Without
//   set_to_zero, its memory holder is moved out (released).
// The clearing runs with the GIL released. Returns None.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  const bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    // Adjacent literals are concatenated, so each piece carries its own
    // separating space (the old message read "gradPlease ... clearedthe").
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

695 696
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
697
                                    PyObject* kwargs) {
698
  EAGER_TRY
699
  VLOG(4) << "ZeroGrads " << self->tensor.name();
700

701
  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
W
wanghuancoder 已提交
702
    eager_gil_scoped_release guard;
703
    // Add RetainGrad as PostHook to AccumulationNode
704
    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
705 706 707 708 709 710
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad"
                       "Please check if you have manually cleared"
                       "the grad inside autograd_meta"));
    if (grad->initialized()) {
711 712 713 714 715 716 717
      if (grad->is_dense_tensor()) {
        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
        phi::funcs::set_constant(*dev_ctx, t, 0.0);
      } else {
        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
      }
718
    }
719
  } else {
W
wanghuancoder 已提交
720
    eager_gil_scoped_release guard;
721
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
722
    if (meta->MutableGrad()->initialized()) {
723 724 725 726 727 728 729 730 731
      if (meta->MutableGrad()->is_dense_tensor()) {
        auto* t =
            static_cast<phi::DenseTensor*>(meta->MutableGrad()->impl().get());
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
        phi::funcs::set_constant(*dev_ctx, t, 0.0);
      } else {
        meta->MutableGrad()->set_impl(
            paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
      }
732
    }
733 734
  }

735 736
  RETURN_PY_NONE

737 738 739
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

740 741
// Implements `Tensor._share_buffer_to(dst)`: makes the tensor passed as
// args[0] share self's DenseTensor buffer and data type.
// If the target has no implementation yet, an empty DenseTensor is installed
// first.
// NOTE(review): both impls are assumed to be DenseTensor; this is not checked.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  // The source (self) must hold data before its buffer can be shared.
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  paddle::Tensor& target =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
  if (!target.defined()) {
    target.set_impl(std::make_shared<phi::DenseTensor>());
  }

  auto* from = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* to = static_cast<phi::DenseTensor*>(target.impl().get());
  to->ShareBufferWith(*from);
  to->ShareDataTypeWith(*from);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

764 765 766 767
// Implements `Tensor._is_shared_buffer_with(other)`: reports whether the
// DenseTensor of args[0] shares its buffer with self's DenseTensor.
// Returns False if either tensor is undefined.
// Raises InvalidArgument if self is not initialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  const paddle::Tensor& other =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
  const bool both_defined = self->tensor.defined() && other.defined();
  if (!both_defined) {
    const bool not_shared = false;
    return ToPyObject(not_shared);
  }

  auto* mine = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* theirs = static_cast<phi::DenseTensor*>(other.impl().get());
  const bool shared = theirs->IsSharedBufferWith(*mine);
  return ToPyObject(shared);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

787 788 789 790
// Implements `Tensor._share_underline_tensor_to(dst)`: points the tensor
// passed as args[0] at the same underlying impl object as self.
// Raises InvalidArgument if self is not initialized.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  paddle::Tensor& receiver =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
  receiver.set_impl(self->tensor.impl());

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

805 806 807 808
// Implements `Tensor._is_shared_underline_tensor_with(other)`: True when self
// and args[0] hold the very same impl object (pointer identity).
// Returns False if either tensor is undefined.
// Raises InvalidArgument if args[0] is not initialized.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  const paddle::Tensor other =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(other.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        other.name()));

  bool same_impl = false;
  if (self->tensor.defined() && other.defined()) {
    same_impl = self->tensor.impl().get() == other.impl().get();
  }
  return ToPyObject(same_impl);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

825 826
// Implements `Tensor.detach()`: allocates a new Python Tensor with these
// properties:
//   - it shares self's impl;
//   - it gets a freshly generated unique name;
//   - its autograd meta copies only self's persistable flag.
// Raises InvalidArgument if self is undefined.
// Raises Fatal if tp_alloc fails.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* detached = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (detached == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* holder = reinterpret_cast<TensorObject*>(detached);
  // tp_alloc only zero-fills memory; construct the C++ member explicitly.
  new (&(holder->tensor)) paddle::Tensor();
  paddle::Tensor& out = holder->tensor;
  out.set_impl(self->tensor.impl());
  out.set_name(egr::Controller::Instance().GenerateUniqueName());

  const bool persistable =
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
  egr::EagerUtils::autograd_meta(&out)->SetPersistable(persistable);

  return detached;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871
// Implements the in-place `Tensor.detach_()`.
// Replaces self's autograd meta with a fresh AutogradMeta that keeps only the
// persistable flag, then returns self.
// Raises InvalidArgument if self is undefined.
static PyObject* tensor_method_detach_(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  auto autograd_meta = std::make_shared<egr::AutogradMeta>();
  autograd_meta->SetPersistable(
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  self->tensor.set_autograd_meta(autograd_meta);

  // A Python method must return a new reference: the interpreter will
  // Py_DECREF the result. Returning `self` without taking a reference
  // would underflow its refcount and could free it while still in use.
  PyObject* result = reinterpret_cast<PyObject*>(self);
  Py_INCREF(result);
  return result;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

872 873 874 875
// Implements `Tensor.get_tensor()`-style access to the underlying impl:
//   - undefined tensor: an empty DenseTensor is converted and returned;
//   - dense tensor: its phi::DenseTensor is returned;
//   - dist tensor (PADDLE_WITH_DISTRIBUTE builds only): its DistTensor is
//     returned;
//   - anything else: None.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // The original `get_tensor` method of Variable will create a empty tensor
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }

  if (self->tensor.is_dense_tensor()) {
    auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
    VLOG(6) << "tensor: " << dense->IsInitialized();
    return ToPyObject(dense);
  }

#ifdef PADDLE_WITH_DISTRIBUTE
  if (self->tensor.is_dist_tensor()) {
    auto* dist =
        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
    VLOG(6) << "dist tensor: " << dist->defined();
    return ToPyObject(dist);
  }
#endif

  // Any other tensor kind, and dist tensors when distribution is compiled out.
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

900 901 902 903 904
// Returns self's phi::SelectedRows impl.
// Returns None when the tensor is undefined or is not a SelectedRows.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  const bool has_rows =
      self->tensor.defined() && self->tensor.is_selected_rows();
  if (has_rows) {
    return ToPyObject(
        static_cast<phi::SelectedRows*>(self->tensor.impl().get()));
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

917 918 919 920 921 922 923 924 925 926 927 928 929 930
// Wraps the value DenseTensor of a SelectedRows tensor into a new
// paddle::Tensor with a fresh unique name.
// The DenseTensor is copy-constructed into a new shared impl.
// Raises Fatal if self is not a SelectedRows or is not initialized.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));
  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* value = static_cast<phi::DenseTensor*>(rows->mutable_value());
  VLOG(4) << "dense_tensor: " << value->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  result.set_impl(std::make_shared<phi::DenseTensor>(*value));
  return ToPyObject(result);

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
943 944 945
// Implements `Tensor.__getitem__` for indices that are not Tensors.
//
// args[0] is the Python index. ParseIndexingSlice decomposes it into:
//   - slice axes/starts/ends/strides;
//   - axes to drop (decrease_axis);
//   - axes to insert for `None` entries (none_axes);
//   - an optional list selection (list_select_idxs / list_select_flag).
//
// The result is then built in steps:
//   1. slice or strided_slice, if any slice axes were produced;
//   2. unsqueeze for `None` entries. When there are any, the unsqueezed
//      result is returned immediately;
//   3. otherwise, index_select for list indices.
// NOTE(review): because of the early return in step 2, list selection is not
// applied when `None` is also present. Confirm ParseIndexingSlice never
// produces both.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags;
  std::vector<int64_t> list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // Any non-unit stride requires strided_slice. Otherwise use plain slice,
    // which additionally applies infer_flags and decrease_axis.
    bool use_strided_slice = false;
    for (const int stride : slice_strides) {
      if (stride != 1) {
        use_strided_slice = true;
        break;
      }
    }

    if (use_strided_slice) {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
    } else {
      std::vector<int64_t> slice_axes_tmp(slice_axes.begin(),
                                          slice_axes.end());
      std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                           infer_flags.end());
      std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                             decrease_axis.end());
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    }
  }

  bool set_to_1d = FLAGS_set_to_1d;

  if (set_to_1d) {
    // NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    // with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
    // otherwise the output shape will be not correct.
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      VLOG(1)
          << "Warning: In Tensor '__getitem__', if the number of scalar "
             "elements "
             "in the index is equal to the rank of the Tensor, the output "
             "should "
             "be 0-D. In order to be consistent with the behavior of previous "
             "versions, it will be processed to 1-D. But it is not correct and "
             "will be "
             "removed in release 2.6. "
             "If 1-D is still wanted, please modify the index element from "
             "scalar to slice "
             "(e.g. 'x[i]' => 'x[i:i+1]'). ";
      if (!none_axes.empty()) {
        none_axes.pop_back();
      }
    }
  }
  if (!none_axes.empty()) {
    paddle::Tensor new_out;
    {
      eager_gil_scoped_release guard;
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }
      new_out = unsqueeze_ad_func(out, none_axes);
    }
    return ToPyObject(new_out);
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    if (FLAGS_use_stride_kernel && list_select_idxs.size() == 1) {
      out = index_select_strided_ad_func(self->tensor, list_select_idxs[0], 0);
    } else {
      auto select_index =
          paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
      auto idx_tensor = std::make_shared<phi::DenseTensor>();
      select_index.set_impl(idx_tensor);
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
          egr::Controller::Instance().GetExpectedPlace());
      paddle::framework::TensorFromVector(
          list_select_idxs, *dev_ctx, idx_tensor.get());
      out = index_select_ad_func(self->tensor, select_index, 0);
    }
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1091 1092
// Returns a single element of the tensor as a 0-d numpy array.
//
// How the element is located depends on the number of arguments:
//   - no args: the tensor must contain exactly one element;
//   - one arg: it is used directly as the element offset (must be < numel);
//   - one arg per dimension: the offset is the sum of index[i] * stride[i].
// For a SelectedRows tensor, its value DenseTensor is read.
// NOTE(review): the single-argument path does not apply strides, unlike the
// multi-index path. Confirm this is intended for non-contiguous tensors.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  phi::DenseTensor* ptr = nullptr;
  if (self->tensor.is_selected_rows()) {
    auto* selected_rows =
        static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  } else {
    ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  }
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> stride = phi::vectorize<size_t>(tensor.strides());

  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  if (PyTuple_Size(args) == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (PyTuple_Size(args) == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(PyTuple_Size(args),
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * stride[i];
    }
  }
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// Builds a 0-d numpy array holding the element at `offset`.
// If numpy fails to allocate the array, nullptr is returned with numpy's
// Python error left set, instead of memcpy-ing through a null pointer.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];                  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];               \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        0,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    if (array == nullptr) {                                                  \
      return nullptr;                                                        \
    }                                                                        \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245
// Implements `Tensor.__setitem__` for eager tensors.
// args[0] is the index and args[1] is the value.
//
// Fast path: when every element of the index is an int, slice, Ellipsis or
// None, the assignment goes through the in-place set_value_ op. In that case:
//   - the value may be a Tensor, a numpy array (converted to self's dtype), or
//     a Python scalar (passed via the "values"/"shape" attributes);
//   - with grad enabled, an in-place write to a leaf tensor that does not
//     stop gradient is rejected;
//   - a Tensor value that requires grad also clears self's stop_gradient.
//
// Slow path: self is converted with TensorToPyArray, assigned through numpy
// indexing, and written back with SetTensorFromPyArray.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed.
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags;
    std::vector<int64_t> list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::EagerUtils::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      // Convert the numpy array to self's dtype when it differs.
      if (self->tensor.dtype() == phi::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
        if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<float>>(
              value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
        if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<double>>(
              value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, complex64, complex128, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp) ||
          PyComplex_Check(value_obj)) {
        if (self->tensor.dtype() == phi::DataType::FLOAT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() == phi::DataType::INT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() == phi::DataType::INT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() == phi::DataType::BOOL) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<float>>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<double>>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, complex64, complex128, int32, int64 or "
              "float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type name; formatting the PyTypeObject* itself
        // with %s would print a pointer value.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float, complex or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    auto self_numpy = TensorToPyArray(*self_tensor, true);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1449 1450
// Registers the Python callable args[0] as a gradient hook on this tensor and
// returns the integer hook id.
// - Leaf tensor: the hook goes on its GradNodeAccumulation. If the tensor
//   requires grad but has no grad node yet, one is created here.
// - Non-leaf tensor: the hook goes on the grad node that produced it.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id;
  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // A leaf that stops gradient is not given an accumulation node above, so
    // the cast can yield nullptr. Report it instead of dereferencing it.
    PADDLE_ENFORCE(accumulation_grad_node != nullptr,
                   platform::errors::InvalidArgument(
                       "Cannot register gradient hook on leaf Tensor %s, "
                       "because it stops gradient and has no "
                       "GradNodeAccumulation.",
                       self->tensor.name()));

    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1499 1500
// Removes the gradient hook whose id is args[0] from this tensor's grad node.
// Returns the bool from GradNodeBase::RemoveGradientHook, or False when the
// tensor has no grad node (so there is nothing to remove).
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  int64_t hook_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);

  // grad_node() can be null (e.g. a leaf that stops gradient, or a tensor
  // whose state was cleared). Previously this dereferenced nullptr.
  if (grad_node == nullptr) {
    VLOG(6) << "Tensor " << self->tensor.name()
            << " has no grad node, hook " << hook_id << " not removed.";
    return ToPyObject(false);
  }

  return ToPyObject(grad_node->RemoveGradientHook(hook_id));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1513 1514
// Attaches the Python callable args[0], wrapped in a PyVoidHook, as a reduce
// hook on this tensor's GradNodeAccumulation. The tensor must be a leaf, must
// not stop gradient, and must already own a grad node.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  // Preconditions, checked in order: leaf, requires grad, has a node.
  PADDLE_ENFORCE_EQ(egr::EagerUtils::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  PADDLE_ENFORCE(
      grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  PyObject* py_callable = PyTuple_GET_ITEM(args, 0);
  std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node)
      ->RegisterReduceHook(std::make_shared<PyVoidHook>(py_callable));

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1548 1549
// Replaces this tensor's gradient implementation with an empty DenseTensor
// (LOD_TENSOR) or SelectedRows (SELECTED_ROWS), chosen by the VarType in
// args[0]. Any other type leaves the gradient untouched, although the
// autograd meta is still fetched through EagerUtils::autograd_meta.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  const auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto* grad_slot =
      egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();

  switch (requested_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad_slot->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad_slot->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      // Other VarTypes are intentionally ignored.
      break;
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1565 1566
// Calls paddle::Tensor::reset() on the wrapped tensor and returns None.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1575 1576 1577 1578 1579 1580 1581 1582 1583
// Detaches the implementation (data holder) from the wrapped tensor by
// setting it to nullptr. Other Tensor state is not touched here.
static PyObject* tensor__clear_dataptr(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor& wrapped = self->tensor;
  wrapped.set_impl(nullptr);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1584 1585
// Makes this tensor's gradient share the implementation of the tensor passed
// in args[0].
// - If this tensor is initialized, the source must be defined and match its
//   dtype and implementation type.
// - The gradient is only replaced when mutable_grad() returns a slot; in that
//   case the source must be initialized.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // src.dtype() and src.impl()->type_info() below dereference src's
    // implementation, so an undefined source must be rejected first.
    PADDLE_ENFORCE_EQ(
        src.defined(),
        true,
        platform::errors::InvalidArgument(
            "Tensor %s is undefined, cannot copy gradient from it",
            src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1617

1618 1619 1620
// Returns a tensor that shares data, name and autograd meta with this
// DenseTensor but has DenseTensorMeta::use_gpudnn set to args[0]. If the flag
// already equals args[0], this tensor itself is returned.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  const bool use_gpudnn =
      pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* source_dense =
      static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* source_meta = phi::DenseTensorUtils::GetMutableMeta(source_dense);

  // Flag already matches: no new tensor is needed.
  if (source_meta->use_gpudnn == use_gpudnn) {
    return ToPyObject(self->tensor);
  }

  // Copy every meta field, then flip only use_gpudnn.
  phi::DenseTensorMeta patched_meta = *source_meta;
  patched_meta.use_gpudnn = use_gpudnn;

  phi::DenseTensor shared_view;
  shared_view.ShareDataWith(*source_dense);
  shared_view.set_meta(patched_meta);

  paddle::Tensor result(std::make_shared<phi::DenseTensor>(shared_view),
                        self->tensor.name());
  result.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << result.name() << " set use_gpudnn = " << use_gpudnn;

  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1655 1656
// Replaces this tensor's implementation with a VariableCompatTensor whose
// payload is the Vocab converted from args[0].
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  using Vocab = paddle::framework::Vocab;
  Vocab converted = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto compat_holder = std::make_shared<egr::VariableCompatTensor>();
  Vocab* payload = compat_holder->GetMutable<Vocab>();
  *payload = converted;
  self->tensor.set_impl(compat_holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Replaces this tensor's implementation with a VariableCompatTensor whose
// payload is the framework::Strings converted from the Python list args[0].
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  using Strings = paddle::framework::Strings;
  auto converted = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto compat_holder = std::make_shared<egr::VariableCompatTensor>();
  Strings* payload = compat_holder->GetMutable<Strings>();
  *payload = converted;
  self->tensor.set_impl(compat_holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the Vocab stored in this tensor's VariableCompatTensor
// implementation. Raises if the implementation is not a VariableCompatTensor.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      egr::IsVariableCompatTensor(self->tensor),
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  const auto* compat_holder =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  const auto& vocab = compat_holder->Get<paddle::framework::Vocab>();
  return ToPyObject(vocab);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717
// Returns nnz() of this sparse tensor. Raises unless the implementation is a
// SparseCooTensor or SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  const auto impl = self->tensor.impl();
  if (!self->tensor.is_sparse_coo_tensor()) {
    // The enforce above guarantees CSR here.
    return ToPyObject(
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(impl)->nnz());
  }
  return ToPyObject(
      std::dynamic_pointer_cast<phi::SparseCooTensor>(impl)->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1718 1719 1720 1721 1722 1723 1724 1725 1726
// Returns the non-zero indices of a SparseCooTensor, copied into a new
// DenseTensor-backed paddle::Tensor.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  auto coo = std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices = std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  return ToPyObject(paddle::Tensor(indices));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the non-zero values of a SparseCooTensor or SparseCsrTensor, copied
// into a new DenseTensor-backed paddle::Tensor.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  const auto impl = self->tensor.impl();
  if (!self->tensor.is_sparse_coo_tensor()) {
    // The enforce above guarantees CSR here.
    auto csr = std::dynamic_pointer_cast<phi::SparseCsrTensor>(impl);
    auto values = std::make_shared<phi::DenseTensor>(csr->non_zero_elements());
    return ToPyObject(paddle::Tensor(values));
  }
  auto coo = std::dynamic_pointer_cast<phi::SparseCooTensor>(impl);
  auto values = std::make_shared<phi::DenseTensor>(coo->non_zero_elements());
  return ToPyObject(paddle::Tensor(values));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the compressed row pointers (crows) of a SparseCsrTensor, copied
// into a new DenseTensor-backed paddle::Tensor.
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr = std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  return ToPyObject(paddle::Tensor(crows));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the column indices of a SparseCsrTensor, copied into a new
// DenseTensor-backed paddle::Tensor.
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr = std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto cols = std::make_shared<phi::DenseTensor>(csr->non_zero_cols());
  return ToPyObject(paddle::Tensor(cols));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1788 1789
// True iff the tensor is defined and backed by a DenseTensor.
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit: is_dense_tensor() is only queried on a defined tensor.
  const bool answer = self->tensor.defined() && self->tensor.is_dense_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

L
LiYuRio 已提交
1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809
// True iff the tensor is defined and is a distributed tensor.
static PyObject* tensor_method_is_dist(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit: is_dist_tensor() is only queried on a defined tensor.
  const bool answer = self->tensor.defined() && self->tensor.is_dist_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1810 1811
// True iff the tensor is defined and backed by a SparseCooTensor or a
// SparseCsrTensor.
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const paddle::Tensor& t = self->tensor;
  const bool answer =
      t.defined() && (t.is_sparse_coo_tensor() || t.is_sparse_csr_tensor());
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1822 1823
// True iff the tensor is defined and backed by a SparseCooTensor.
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const bool answer =
      self->tensor.defined() && self->tensor.is_sparse_coo_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1833 1834
// True iff the tensor is defined and backed by a SparseCsrTensor.
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const bool answer =
      self->tensor.defined() && self->tensor.is_sparse_csr_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1844 1845
// Converts this tensor with to_sparse_csr() and copies the stop_gradient and
// persistable flags from the source tensor's autograd meta onto the result.
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto csr_tensor = self->tensor.to_sparse_csr();
  auto* source_meta = egr::EagerUtils::autograd_meta(&self->tensor);
  auto* result_meta = egr::EagerUtils::autograd_meta(&csr_tensor);
  result_meta->SetStopGradient(source_meta->StopGradient());
  result_meta->SetPersistable(source_meta->Persistable());
  return ToPyObject(csr_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1859 1860 1861 1862 1863 1864 1865 1866 1867
// True iff this tensor's shape() equals the shape() of the tensor in args[0].
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const auto other_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool shapes_match = other_tensor.shape() == self->tensor.shape();
  return ToPyObject(shapes_match);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1868 1869
// Returns the tensor's current inplace version counter.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1878 1879
// Returns the size in bytes of one element of this tensor's dtype
// (phi::SizeOf), narrowed to uint32_t.
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const auto dtype = self->tensor.dtype();
  return ToPyObject(static_cast<uint32_t>(phi::SizeOf(dtype)));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1888 1889 1890 1891 1892
// Increments the tensor's inplace version counter and returns None.
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor& target = self->tensor;
  target.bump_inplace_version();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1897 1898 1899 1900
// True iff the tensor is defined and backed by SelectedRows.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  const bool answer = self->tensor.defined() && self->tensor.is_selected_rows();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1908 1909
// Returns the row indices held by a SelectedRows tensor; raises otherwise.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  const auto rows_impl =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  const auto& row_ids = rows_impl->rows();
  return ToPyObject(row_ids);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1921 1922 1923 1924 1925 1926 1927 1928 1929 1930
// Resets the inplace version of this tensor's gradient, when that gradient is
// a defined, initialized DenseTensor. With exactly one argument, args[0]
// gives `set_to_zero`; otherwise it defaults to true.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  const bool set_to_zero =
      (PyTuple_Size(args) == 1)
          ? CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0)
          : true;

  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  const bool can_reset = grad != nullptr && grad->defined() &&
                         grad->is_dense_tensor() && grad->initialized();
  if (can_reset) {
    grad->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1941 1942
// Moves this CPU DenseTensor onto shared memory and returns the DenseTensor.
// Steps:
//   1. allocate a writer allocation via AllocateMemoryMapWriterAllocation;
//   2. register its ipc name in MemoryMapFdSet;
//   3. copy the data into the allocation;
//   4. make the tensor hold the new allocation.
// Raises on non-CPU or non-dense tensors, and always raises on Windows.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get LoDTensor
  auto* t =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl()).get();
  // A non-dense implementation (e.g. SelectedRows) makes the cast yield
  // nullptr; reject it instead of dereferencing it below.
  PADDLE_ENFORCE(t != nullptr,
                 platform::errors::InvalidArgument(
                     "Sharing memory only support DenseTensor currently"));
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  RETURN_PY_NONE
#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1980 1981
// Returns DenseTensor::offset() of this tensor. Raises if the tensor is not
// backed by a DenseTensor or is not initialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields nullptr for undefined or non-dense tensors; previously
  // t->IsInitialized() then dereferenced it.
  PADDLE_ENFORCE(t != nullptr,
                 platform::errors::InvalidArgument(
                     "Tensor %s is not a DenseTensor, _offset is only "
                     "supported for DenseTensor.",
                     self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1995 1996
// Returns the name of this tensor's gradient tensor. Raises if
// mutable_grad() returns nullptr.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  const auto& grad_name = grad->name();
  return ToPyObject(grad_name);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2009 2010
// Returns this tensor's gradient as a DenseTensor, or None when the gradient
// is undefined.
// Raises if mutable_grad() returns nullptr, or if the gradient is defined but
// not a DenseTensor.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  if (!grad->defined()) {
    RETURN_PY_NONE
  }
  if (!grad->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
    RETURN_PY_NONE  // unreachable; keeps every path returning a value
  }
  auto* dense_grad = static_cast<phi::DenseTensor*>(grad->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2034 2035
// Clears the fake-empty flag on the GradNodeAccumulation of a leaf tensor.
// Raises if mutable_grad() returns nullptr. For non-leaf tensors, and for
// leaves without a grad node, nothing is changed.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    auto accumulation_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    // A leaf's grad node may be null (e.g. when it stops gradient). In that
    // case there is no flag to clear; previously this was a null dereference.
    if (accumulation_node) {
      accumulation_node->SetFakeEmpty(false);
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2055 2056 2057 2058 2059
// Returns the raw data address of an initialized DenseTensor as an int64,
// or None for any other tensor.
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  const bool has_dense_data =
      self->tensor.initialized() && self->tensor.is_dense_tensor();
  if (!has_dense_data) {
    RETURN_PY_NONE
  }
  auto dense = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const auto address = reinterpret_cast<int64_t>(dense->data());
  return ToPyObject(address);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083
// Returns this tensor's gradient when the tensor has autograd meta and the
// gradient is initialized; otherwise returns None.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // nullable_autograd_meta may return nullptr. Check it before any use,
  // including the verbose log below, which previously dereferenced it
  // unconditionally.
  if (meta == nullptr) {
    VLOG(6) << "Tensor " << self->tensor.name() << " has no autograd meta";
    RETURN_PY_NONE
  }
  VLOG(6) << meta << " initialized: " << meta->Grad().initialized();
  if (meta->Grad().initialized()) {
    return ToPyObject(meta->Grad());
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139
// Returns the strides of a defined DenseTensor as a list of int64, or an
// empty list for any other tensor.
static PyObject* tensor_method_strides(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  std::vector<int64_t> stride_list;
  const bool is_defined_dense =
      self->tensor.defined() && self->tensor.is_dense_tensor();
  if (is_defined_dense) {
    auto dims = self->tensor.strides();
    stride_list.resize(static_cast<size_t>(dims.size()));
    for (size_t axis = 0; axis < stride_list.size(); ++axis) {
      stride_list[axis] = dims[axis];
    }
  }
  return ToPyObject(stride_list);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns a contiguous version of this tensor.
// - Non-dense tensors, and dense tensors whose meta is already contiguous,
//   return `self` with a new reference.
// - Otherwise a new tensor is built with Trans2Contiguous.
// Only the conversion runs with the GIL released. Previously the guard was
// still alive while ToPyObject created the Python object, which calls the
// Python C API without holding the GIL.
static PyObject* tensor_contiguous(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
    if (dense_tensor->meta().is_contiguous()) {
      Py_INCREF(self);
      return reinterpret_cast<PyObject*>(self);
    }

    paddle::Tensor contiguous_tensor;
    {
      // Release the GIL only for the pure C++ conversion.
      eager_gil_scoped_release guard;
      contiguous_tensor = paddle::Tensor(std::make_shared<phi::DenseTensor>(
          paddle::experimental::Trans2Contiguous(*dense_tensor)));
    }
    // GIL re-acquired here; safe to create the Python object.
    return ToPyObject(contiguous_tensor);
  }

  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// For a DenseTensor, reports whether its meta is contiguous; every other
// tensor kind is reported as contiguous.
static PyObject* tensor_is_contiguous(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.is_dense_tensor()) {
    return ToPyObject(true);
  }
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const bool contiguous = dense->meta().is_contiguous();
  return ToPyObject(contiguous);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2140
#if defined(PADDLE_WITH_CUDA)
// CUDA builds only: calls tensor_uva on this CPU DenseTensor, using the device
// id given in args[0]. Raises for non-dense or non-CPU tensors.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));
  const int device_id =
      pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
  auto* host_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  tensor_uva(host_dense, device_id);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  // Python binding for StringTensor._is_string_tensor_hold_allocation().
  // Returns False when the tensor's impl is not a phi::StringTensor, because
  // the dynamic cast then yields null. Otherwise it returns the StringTensor's
  // initialized() flag.
  const auto string_impl =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  const bool holds_allocation =
      string_impl != nullptr && string_impl->initialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
2177

2178
// Python method table for eager Tensor objects (the StringTensor counterpart
// is string_tensor_variable_methods below).
//
// Every binding has the (self, args, kwargs) signature and is registered with
// METH_VARARGS | METH_KEYWORDS and no docstring. The double cast through
// void (*)(void) is the usual idiom for storing a keyword-taking function in
// PyMethodDef::ml_meth (typed PyCFunction) without cast-function-type warnings.
// The table must end with the all-NULL sentinel entry.
PyMethodDef variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_dense_tensor_hold_allocation",
     (PyCFunction)(void (*)(
         void))tensor_method__is_dense_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_copy_to",
     (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"copy_",
     (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"clone",
     (PyCFunction)(void (*)(void))tensor_method_clone,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"reconstruct_from_",
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"retain_grads",
     (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"clear_gradient",
     (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_dense",
     (PyCFunction)(void (*)(void))tensor_method_is_dense,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_dist",
     (PyCFunction)(void (*)(void))tensor_method_is_dist,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_zero_grads",
     (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_share_buffer_to",
     (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_shared_buffer_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_share_underline_tensor_to",
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_shared_underline_tensor_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"detach",
     (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"detach_",
     (PyCFunction)(void (*)(void))tensor_method_detach_,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"get_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"get_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_get_tensor_from_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method__get_tensor_from_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_getitem_from_offset",
     (PyCFunction)(void (*)(void))tensor__getitem_from_offset,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"__setitem_eager_tensor__",
     (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_remove_grad_hook",
     (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_set_grad_type",
     (PyCFunction)(void (*)(void))tensor__set_grad_type,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_clear",
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_clear_dataptr",
     (PyCFunction)(void (*)(void))tensor__clear_dataptr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_copy_gradient_from",
     (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_tensor_use_gpudnn",
     (PyCFunction)(void (*)(void))tensor__use_gpudnn,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    /** the methods to adapt old dygraph, will be removed in the future **/
    {"set_string_list",
     (PyCFunction)(void (*)(void))tensor_method_set_string_list,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"set_vocab",
     (PyCFunction)(void (*)(void))tensor_method_set_vocab,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"get_map_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    /*** begin: methods of sparse tensor ***/
    {"nnz",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"indices",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"values",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"crows",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"cols",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse_coo",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_same_shape",
     (PyCFunction)(void (*)(void))tensor_method_is_same_shape,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"to_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"element_size",
     (PyCFunction)(void (*)(void))tensor_method_element_size,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    /*** end: methods of sparse tensor ***/
    {"_inplace_version",
     (PyCFunction)(void (*)(void))tensor__inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_bump_inplace_version",
     (PyCFunction)(void (*)(void))tensor__bump_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"rows",
     (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_share_memory",
     (PyCFunction)(void (*)(void))tensor_method__share_memory,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_offset",
     (PyCFunction)(void (*)(void))tensor__offset,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_grad_name",
     (PyCFunction)(void (*)(void))tensor__grad_name,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_grad_value",
     (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_unset_fake_empty",
     (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"data_ptr",
     (PyCFunction)(void (*)(void))tensor_data_ptr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_grad_ivar",
     (PyCFunction)(void (*)(void))tensor__grad_ivar,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"contiguous",
     (PyCFunction)(void (*)(void))tensor_contiguous,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_contiguous",
     (PyCFunction)(void (*)(void))tensor_is_contiguous,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"get_strides",
     (PyCFunction)(void (*)(void))tensor_method_strides,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
#if defined(PADDLE_WITH_CUDA)
    // Only available in CUDA builds, matching the #if around its definition.
    {"_tensor_uva",
     (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
#endif
    {NULL, NULL, 0, NULL}};

J
Jack Zhou 已提交
2435 2436 2437
// variable_methods for core.eager.StringTensor
//
// Python method table for StringTensor objects. Each binding has the
// (self, args, kwargs) signature, is registered with
// METH_VARARGS | METH_KEYWORDS and has no docstring. The double cast through
// void (*)(void) stores it in PyMethodDef::ml_meth without cast-function-type
// warnings. The table must end with the all-NULL sentinel entry.
PyMethodDef string_tensor_variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_string_tensor_hold_allocation",
     (PyCFunction)(void (*)(
         void))tensor_method__is_string_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    {NULL, NULL, 0, NULL}};

2453 2454
}  // namespace pybind
}  // namespace paddle