eager_method.cc 86.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
12 13 14 15 16 17

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

18
#include <Python.h>
19 20 21 22
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
23 24

#include <string>
25
#include <unordered_map>
26 27
#include <vector>

28
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
29
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
30
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
31
#include "paddle/fluid/eager/autograd_meta.h"
32 33
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
34
#include "paddle/fluid/eager/utils.h"
35
#include "paddle/fluid/framework/convert_utils.h"
36
#include "paddle/fluid/framework/string_array.h"
37 38 39 40 41 42
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
43
#include "paddle/fluid/pybind/slice_utils.h"
44
#include "paddle/fluid/pybind/uva_utils.h"
45 46 47 48
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
49 50
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
W
wanghuancoder 已提交
51
#include "pybind11/detail/internals.h"
52 53
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
W
wanghuancoder 已提交
54
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
J
Jiabin Yang 已提交
55
#include "paddle/fluid/eager/amp_utils.h"
56
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
J
Jiabin Yang 已提交
57
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
W
wanghuancoder 已提交
58
#include "paddle/fluid/framework/python_headers.h"
W
wanghuancoder 已提交
59
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
W
wanghuancoder 已提交
60
#include "paddle/fluid/pybind/tensor_py.h"
W
wanghuancoder 已提交
61
#include "paddle/phi/core/ddim.h"
62
#include "paddle/phi/core/flags.h"
63
#include "paddle/phi/core/tensor_utils.h"
64
#include "paddle/phi/kernels/funcs/math_function.h"
L
LiYuRio 已提交
65 66 67
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif
J
Jiabin Yang 已提交
68

69
PHI_DECLARE_bool(set_to_1d);
70

71 72 73
namespace paddle {
namespace pybind {

74 75
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
76
                                     const paddle::platform::Place& place,
77
                                     bool zero_copy);
78

79
extern PyTypeObject* p_tensor_type;
80

J
Jiabin Yang 已提交
81
// Converts a Python object used as a slice bound into an integer index.
//
// Only objects that are instances of the eager Tensor type (p_tensor_type)
// are accepted, and the wrapped tensor must be initialized. Any other Python
// type, or an uninitialized tensor, raises InvalidArgument.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (!PyObject_TypeCheck(obj, p_tensor_type)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }

  VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
  paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
  PADDLE_ENFORCE_EQ(
      tensor.initialized(),
      true,
      paddle::platform::errors::InvalidArgument(
          "We can only support initialized tensor in slice, however we got "
          "uninitialized tensor %s, please check your code.",
          tensor.name()));
  // Reuse the tensor that was just validated instead of converting `obj` a
  // second time.
  return GetSliceIndexFromTensor(
      *static_cast<phi::DenseTensor*>(tensor.impl().get()));
}

101 102
// Implements Tensor.numpy(): returns a new NumPy array holding a deep copy of
// the tensor's data.
//
// - A tensor without an impl yields an empty 1-D float array.
// - An uninitialized 0-D tensor yields an empty 1-D array of the tensor dtype.
// - Any other uninitialized tensor yields an array of the tensor's shape whose
//   contents are not filled in.
// - A 0-D tensor is exposed as a 1-D array of one element when the
//   `set_to_1d` flag is on, unless the first positional argument is Py_False.
//
// Data is copied from CPU / pinned memory, and from GPU, XPU or custom
// devices when the corresponding support is compiled in; any other place
// raises InvalidArgument. Returns a new reference, or nullptr with a Python
// error set.
static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
  size_t py_rank = tensor_dims.size();
  size_t numel = 1;
  if (py_rank == 0) {
    Py_ssize_t args_num = PyTuple_Size(args);
    // true by default
    bool set_to_1d = FLAGS_set_to_1d;
    if (args_num == (Py_ssize_t)1) {
      PyObject* obj = PyTuple_GET_ITEM(args, 0);
      if (obj == Py_False) {
        set_to_1d = false;
      }
    }
    if (set_to_1d) {
      // 0D Tensor hack process to 1D numpy, will remove in release 2.6
      VLOG(0)
          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
             "order to avoid this problem, "
             "0D Tensor will be changed to 1D numpy currently, but it's not "
             "correct and will be "
             "removed in release 2.6. For Tensor contain only one element, "
             "Please "
             "modify "
             " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
             "possible, "
             "otherwise 'Tensor.numpy()[0]' will raise error in release 2.6.";
      py_rank = 1;
      py_dims[0] = 1;
      py_strides[0] = sizeof_dtype * numel;
    }
  } else {
    // Row-major strides, accumulated from the innermost dimension outwards.
    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * numel;
      numel *= py_dims[i];
    }
  }

  const bool tensor_initialized = self->tensor.impl()->initialized();

  // An uninitialized 0-D tensor is reported as an empty 1-D array. This is
  // decided before the result array is allocated so that no intermediate
  // array is created and left unreleased.
  if (!tensor_initialized && tensor_dims.size() == 0) {
    py_dims[0] = 0;
    py_strides[0] = 0;
    return api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
  }

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      py_rank,
      py_dims,
      py_strides,
      nullptr,
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);
  if (array == nullptr) {
    // Allocation failed; NumPy has already set the Python error. Bail out
    // before any branch below dereferences the array.
    return nullptr;
  }

  if (!tensor_initialized) {
    // Nothing to copy: hand back the freshly allocated array as is.
    return array;
  }

  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
    eager_gil_scoped_release guard;
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());

      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (self->tensor.is_gpu()) {
    eager_gil_scoped_release guard;
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    }
#endif
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
    eager_gil_scoped_release guard;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // TODO(qili93): temporary for ascend npu performance to be removed along
      // with npu_identity op
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    }
#endif
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }

  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
// Implements numpy() for StringTensor: returns a NumPy unicode ("U<N>") array
// with the tensor's shape, where N is the longest string's unicode length
// (at least 1).
//
// A tensor without an impl, or one that is uninitialized, yields an empty 1-D
// unicode array. Only CPU tensors are supported; any other place raises
// InvalidArgument.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();

    // Compute each element's unicode length once; the values are used both
    // for sizing the numpy dtype and for the per-element conversion below.
    std::vector<size_t> unicode_lens;
    if (numel > 0) {
      unicode_lens.reserve(static_cast<size_t>(numel));
    }
    for (int64_t i = 0; i < numel; ++i) {
      unicode_lens.push_back(
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size()));
    }

    // Get the max unicode length of StringTensor to create numpy unicode
    // string array. With zero elements there is no longest string to look
    // at (std::max_element on an empty range returns the end iterator, which
    // must not be dereferenced), so the length stays 0 and is clamped below.
    size_t max_unicode_length = 0;
    if (!unicode_lens.empty()) {
      max_unicode_length =
          *std::max_element(unicode_lens.begin(), unicode_lens.end());
    }
    max_unicode_length = (max_unicode_length == 0) ? 1 : max_unicode_length;
    VLOG(6) << "The max unicode length is " << max_unicode_length;

    // Zero-filled staging buffer: one fixed-width UCS4 slot per element.
    auto sp = std::make_unique<uint32_t[]>(max_unicode_length * numel);
    auto py_array_data = sp.get();
    memset(py_array_data, 0, max_unicode_length * numel * sizeof(uint32_t));
    for (int64_t i = 0; i < numel; ++i) {
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  unicode_lens[static_cast<size_t>(i)]);
    }
    // No base object is passed, so pybind11 copies the staging buffer into
    // the array; `sp` can be released when this scope ends.
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

396 397 398 399
// Python binding: returns the result of Tensor::initialized() for the wrapped
// tensor as a Python object.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_initialized = self->tensor.initialized();
  return ToPyObject(is_initialized);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
404 405 406 407 408 409 410 411 412 413 414 415 416 417
// Python binding: returns true only when the tensor's impl is a
// phi::DenseTensor and that dense tensor reports IsInitialized(); returns
// false for any other impl type or a missing impl.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto impl_as_dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  if (!impl_as_dense) {
    return ToPyObject(false);
  }
  return ToPyObject(impl_as_dense->IsInitialized());

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

418
static void IncreaseTensorReferenceCountUntilCopyComplete(
419
    const paddle::Tensor& tensor, const platform::Place& place) {
420 421 422 423 424 425 426 427
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
C
co63oc 已提交
428
  // CUDAPinned Mem -> CUDA by cudaMemcpyAsync.
429 430 431 432 433 434 435
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

436 437
// Python binding for Tensor._copy_to(place, blocking): copies the tensor to
// `place` and returns the copy.
//
// The copy has stop_gradient set to true and inherits the source tensor's
// persistable flag. For a non-blocking copy, a reference to the source is
// registered via IncreaseTensorReferenceCountUntilCopyComplete.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PyObject* place_arg = PyTuple_GET_ITEM(args, 0);
  PyObject* blocking_arg = PyTuple_GET_ITEM(args, 1);
  auto place = CastPyArg2Place(place_arg, 0);
  bool blocking = CastPyArg2AttrBoolean(blocking_arg, 1);

  paddle::Tensor cp_tensor;
  {
    // The GIL is released only for this scope and is held again by the time
    // the Python result object is built.
    eager_gil_scoped_release guard;
    cp_tensor = self->tensor.copy_to(place, blocking);
    if (!blocking) {
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, place);
    }
    auto* copy_meta = egr::EagerUtils::autograd_meta(&cp_tensor);
    copy_meta->SetStopGradient(true);
    auto* self_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    copy_meta->SetPersistable(self_meta->Persistable());
  }
  return ToPyObject(cp_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

458 459
// Python binding for Tensor.cpu(): returns a blocking copy of the tensor on
// phi::CPUPlace. The copy has stop_gradient set to true and inherits the
// source tensor's persistable flag.
static PyObject* tensor_method_cpu(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cp_tensor;
  {
    // The GIL is released only while the copy and meta updates run.
    eager_gil_scoped_release guard;
    cp_tensor = self->tensor.copy_to(phi::CPUPlace(), true);
    auto* copy_meta = egr::EagerUtils::autograd_meta(&cp_tensor);
    copy_meta->SetStopGradient(true);
    auto* self_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    copy_meta->SetPersistable(self_meta->Persistable());
  }
  return ToPyObject(cp_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

475 476 477 478
// Python binding: replaces the wrapped tensor with the tensor passed as the
// first argument while keeping this tensor's original name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor incoming = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const std::string preserved_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from" << incoming.name() << " to "
          << preserved_name;

  // Assignment overwrites the name as well, so put the saved one back.
  self->tensor = incoming;
  self->tensor.set_name(preserved_name);

  VLOG(6) << "Finished Reconstructing Tensor from" << incoming.name()
          << " to " << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

495 496
// Python binding for Tensor.copy_(src, blocking): copies `src` into this
// tensor in place. Returns None.
//
// - Destination already initialized: data is copied to the destination's own
//   place (only when the source is initialized).
// - Destination uninitialized: stop_gradient and persistable are taken from
//   the source first, then data is copied to the source's place (only when
//   the source is initialized).
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();

  if (self->tensor.initialized()) {
    if (src_tensor.initialized()) {
      eager_gil_scoped_release guard;
      self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
    }
  } else {
    eager_gil_scoped_release guard;
    auto* dst_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    dst_meta->SetStopGradient(
        egr::EagerUtils::autograd_meta(&(src_tensor))->StopGradient());
    dst_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(src_tensor))->Persistable());
    if (src_tensor.initialized()) {
      self->tensor.copy_(src_tensor, src_tensor.place(), blocking);
    }
  }

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

528 529 530 531
// Python binding for Tensor.clone(): returns the result of assign_ad_func on
// the wrapped tensor. Raises InvalidArgument when the tensor is not
// initialized.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned;
  {
    // The GIL is released for the check and the assign call only.
    eager_gil_scoped_release guard;
    PADDLE_ENFORCE_EQ(
        self->tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));

    cloned = assign_ad_func(self->tensor);
  }
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

549 550
// Python binding for Tensor.retain_grads(): when gradient recording is
// enabled, gives the tensor a GradNodeAccumulation if it has no grad node yet
// and then calls RetainGradForTensor on it. Does nothing when gradient
// recording is disabled. Returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (!egr::Controller::Instance().HasGrad()) {
    RETURN_PY_NONE
  }

  {
    // Scoped so the GIL is held again before None is returned below.
    eager_gil_scoped_release guard;
    auto autograd = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!autograd->GetMutableGradNode()) {
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << "become accumulation node";
      autograd->SetGradNode(
          std::make_shared<egr::GradNodeAccumulation>(autograd));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

568 569
// Python binding for Tensor.clear_gradient([set_to_zero=True]). Returns None.
//
// - SelectedRows gradient: rows and value are cleared when the value is
//   initialized.
// - Dense gradient, set_to_zero == true: the gradient is filled with zeros;
//   for a leaf tensor the accumulation node is additionally marked as "fake
//   empty".
// - Dense gradient, set_to_zero == false: the gradient's memory holder is
//   moved out (released).
//
// Raises Fatal when a leaf tensor has no gradient slot.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == (Py_ssize_t)1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    // The literals are joined at compile time, so each piece carries its own
    // separator to keep the final message readable.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

631 632
// Python binding for Tensor._zero_grads(): overwrites an initialized gradient
// with zeros. Returns None.
//
// A dense gradient is zero-filled in place; any other gradient type has its
// impl replaced with the result of zeros_like. An uninitialized gradient is
// left untouched. Raises Fatal when a leaf tensor has no gradient slot.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    eager_gil_scoped_release guard;
    // Add RetainGrad as PostHook to AccumulationNode
    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
    // The literals are joined at compile time, so each piece carries its own
    // separator to keep the final message readable.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
    if (grad->initialized()) {
      if (grad->is_dense_tensor()) {
        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
        phi::funcs::set_constant(*dev_ctx, t, 0.0);
      } else {
        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
      }
    }
  } else {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    // Fetch the gradient slot once instead of re-querying the meta for each
    // use.
    paddle::Tensor* grad = meta->MutableGrad();
    if (grad->initialized()) {
      if (grad->is_dense_tensor()) {
        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
        phi::funcs::set_constant(*dev_ctx, t, 0.0);
      } else {
        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

676 677
// Python binding for Tensor._share_buffer_to(dst): makes `dst` share this
// tensor's buffer and data type. `dst` gets a fresh DenseTensor impl first
// when it is not defined. Raises InvalidArgument when this tensor is not
// initialized. Returns None.
//
// NOTE(review): both impls are static_cast to phi::DenseTensor without a type
// check — presumably callers only pass dense tensors; confirm.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto* target_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* target = &(target_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  auto* source_dense =
      static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  if (!target->defined()) {
    target->set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* target_dense = static_cast<phi::DenseTensor*>(target->impl().get());
  target_dense->ShareBufferWith(*source_dense);
  target_dense->ShareDataTypeWith(*source_dense);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

700 701 702 703
// Python binding for Tensor._is_shared_buffer_with(dst): returns whether
// `dst`'s dense tensor reports IsSharedBufferWith this tensor's dense tensor.
// Returns false when either tensor is not defined. Raises InvalidArgument
// when this tensor is not initialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto* other_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* other = &(other_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  const bool both_defined = self->tensor.defined() && other->defined();
  if (!both_defined) {
    return ToPyObject(false);
  }
  auto* self_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* other_dense = static_cast<phi::DenseTensor*>(other->impl().get());
  const bool shares_buffer = other_dense->IsSharedBufferWith(*self_dense);
  return ToPyObject(shares_buffer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

723 724 725 726
// Python binding for Tensor._share_underline_tensor_to(other): points the
// tensor passed as the first argument at this tensor's impl. Raises
// InvalidArgument when this tensor is not initialized. Returns None.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  auto* receiver_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* receiver = &(receiver_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  receiver->set_impl(self->tensor.impl());
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

741 742 743 744
// Python-facing method: reports whether this tensor and the tensor passed as
// the first positional argument point at the very same impl object.
//
// args[0]: the other tensor, converted with CastPyArg2Tensor.
// Raises InvalidArgument when the other tensor is not initialized.
// Returns False when either tensor is not defined.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(other.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        other.name()));
  bool same_impl = false;
  if (self->tensor.defined() && other.defined()) {
    // Identity comparison of the impl pointers, not a content comparison.
    same_impl = (self->tensor.impl().get() == other.impl().get());
  }
  return ToPyObject(same_impl);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

761 762
// Python-facing method: allocates a new tensor object that reuses this
// tensor's impl, gets a freshly generated unique name, and copies the
// persistable flag from this tensor's autograd meta.
//
// Raises InvalidArgument when this tensor is not defined, and Fatal when the
// Python type's tp_alloc returns null.
// Returns the newly allocated tensor object.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc only provides storage; construct the C++ member in place.
  new (&(detached->tensor)) paddle::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());

  auto src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto dst_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

789 790 791 792
// Python-facing method: exposes the impl object behind this tensor.
//
// Returns:
//   - a wrapper around an empty phi::DenseTensor when the tensor is not
//     defined;
//   - the impl as phi::DenseTensor when the tensor is a dense tensor;
//   - the impl as DistTensor when the tensor is a dist tensor and the build
//     defines PADDLE_WITH_DISTRIBUTE;
//   - None in every other case.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // The original `get_tensor` method of Variable will create a empty tensor
    // NOTE(review): the address of a stack local is handed to ToPyObject;
    // confirm that ToPyObject copies it rather than keeping a reference.
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }

  if (self->tensor.is_dense_tensor()) {
    auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
    VLOG(6) << "tensor: " << dense->IsInitialized();
    return ToPyObject(dense);
  }

  if (self->tensor.is_dist_tensor()) {
#ifdef PADDLE_WITH_DISTRIBUTE
    auto* dist = static_cast<phi::distributed::auto_parallel::DistTensor*>(
        self->tensor.impl().get());
    VLOG(6) << "dist tensor: " << dist->IsInitialized();
    return ToPyObject(dist);
#else
    RETURN_PY_NONE
#endif
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

817 818 819 820 821
// Python-facing method: exposes the impl behind this tensor as
// phi::SelectedRows.
//
// Returns the SelectedRows wrapper when the tensor is defined and is a
// selected-rows tensor; returns None otherwise.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.defined() && self->tensor.is_selected_rows()) {
    auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    return ToPyObject(rows);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

834 835 836 837 838 839 840 841 842 843 844 845 846 847
// Python-facing method: builds a new, uniquely named tensor whose impl is a
// phi::DenseTensor copy-constructed from the value held by this tensor's
// SelectedRows.
//
// Raises Fatal when this tensor is not a SelectedRows tensor or when the
// SelectedRows is not initialized.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));

  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* value = static_cast<phi::DenseTensor*>(rows->mutable_value());
  VLOG(4) << "dense_tensor: " << value->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  result.set_impl(std::make_shared<phi::DenseTensor>(*value));

  return ToPyObject(result);

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
860 861 862
// Python-facing `__getitem__` path for indices that are not tensors.
//
// args[0]: the index object; it is decoded by ParseIndexingSlice into slice
//          axes/starts/ends/strides, decreased axes, `None` axes and an
//          optional list of selected indices.
//
// Processing order:
//   1. If any slice axes were parsed, run `slice` (all strides are 1) or
//      `strided_slice` (some stride differs from 1).
//   2. If `None` axes were parsed, unsqueeze the result at those axes and
//      return immediately.
//   3. Otherwise, if the index was a list, run `index_select` along axis 0 on
//      the original tensor.
//
// Raises InvalidArgument when this tensor is not defined.
// Returns the resulting tensor object.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags, list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  // With nothing to slice and no list selection the result is this tensor
  // itself; otherwise start from a fresh, uniquely named tensor.
  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // Any stride other than 1 requires strided_slice; otherwise plain slice.
    bool use_strided_slice = false;
    for (int stride : slice_strides) {
      if (stride != 1) {
        use_strided_slice = true;
        break;
      }
    }

    if (use_strided_slice) {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
    } else {
      // slice_ad_func is called with int64_t vectors for these arguments.
      std::vector<int64_t> slice_axes_tmp(slice_axes.begin(),
                                          slice_axes.end());
      std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                           infer_flags.end());
      std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                             decrease_axis.end());
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    }
  }

  bool set_to_1d = FLAGS_set_to_1d;

  if (set_to_1d) {
    // NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    // with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
    // otherwise the output shape will be not correct.
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      VLOG(1)
          << "Warning: In Tensor '__getitem__', if the number of scalar "
             "elements "
             "in the index is equal to the rank of the Tensor, the output "
             "should "
             "be 0-D. In order to be consistent with the behavior of previous "
             "versions, it will be processed to 1-D. But it is not correct and "
             "will be "
             "removed in release 2.6. "
             "If 1-D is still wanted, please modify the index element from "
             "scalar to slice "
             "(e.g. 'x[i]' => 'x[i:i+1]'). ";
      if (!none_axes.empty()) {
        none_axes.pop_back();
      }
    }
  }
  if (!none_axes.empty()) {
    paddle::Tensor new_out;
    {
      eager_gil_scoped_release guard;
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        // Shift each `None` axis left by the number of decreased axes that
        // precede it.
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }
      new_out = unsqueeze_ad_func(out, none_axes);
    }
    return ToPyObject(new_out);
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    auto select_index =
        paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
    auto idx_tensor = std::make_shared<phi::DenseTensor>();
    select_index.set_impl(idx_tensor);
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
        egr::Controller::Instance().GetExpectedPlace());
    paddle::framework::TensorFromVector(
        list_select_idxs, *dev_ctx, idx_tensor.get());
    out = index_select_ad_func(self->tensor, select_index, 0);
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1004 1005
// Python-facing method: reads a single element of this tensor and returns it
// as a 0-dimensional numpy array of the matching dtype.
//
// The element is addressed by the positional arguments:
//   - no argument:   the tensor must hold exactly one element;
//   - one argument:  a flat offset into the tensor, checked against numel;
//   - N arguments:   one index per dimension (N must equal the rank), each
//                    checked against that dimension's size.
// Strides are computed here from the dims, last dimension varying fastest.
//
// For a SelectedRows tensor the element is read from its value tensor.
//
// Raises InvalidArgument on a null/uninitialized tensor, a wrong number of
// indices or an out-of-range index, and Unimplemented for a dtype that is not
// listed in PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE.
// Returns nullptr when numpy fails to allocate the result array.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  phi::DenseTensor* ptr = nullptr;
  if (self->tensor.is_selected_rows()) {
    auto* selected_rows =
        static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  } else {
    ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  }
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> strides(tensor_dims.size());

  // Walk the dims from last to first, accumulating the element count so that
  // strides[i] is the product of all dims after i.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    strides[i] = numel;
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }

  const Py_ssize_t num_args = PyTuple_Size(args);
  size_t offset = 0;
  if (num_args == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (num_args == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(num_args,
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < num_args; ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * strides[i];
    }
  }
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// For the matching dtype: fetch the element, allocate a 0-d numpy array of
// that dtype (py_dims/py_strides are passed with nd == 0) and copy the
// element's bytes into it. When the allocation returns null, nullptr is
// returned right away instead of writing through a null array.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];                  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];               \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        0,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    if (array == nullptr) {                                                  \
      return nullptr;                                                        \
    }                                                                        \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162
// Python-facing `__setitem__` implementation.
//
// args[0]: the index object. A non-tuple index is wrapped into a 1-tuple for
//          parsing; the wrapper is released by a scope guard on exit.
// args[1]: the value to assign: a tensor, a numpy array, or a Python
//          float/int/bool/complex scalar.
//
// Two paths:
//   - Every index item is an integer, slice, Ellipsis or None: the index is
//     parsed by ParseIndexingSlice and the assignment is performed by the
//     inplace `set_value_` operator. A scalar value is passed through the
//     "values"/"shape" attributes; a tensor/numpy value is passed as a tensor
//     (cast to this tensor's dtype when the dtypes differ).
//   - Otherwise: this tensor is converted to a numpy array, the assignment is
//     done in numpy, and the result is written back with SetTensorFromPyArray.
//
// Raises InvalidArgument for an unsupported tensor dtype, an unsupported
// value type, or (when grad is enabled) a leaf tensor that does not stop
// gradient.
// Returns None.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  // Release the temporary 1-tuple (only created when _index is not a tuple)
  // on every exit path.
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed.
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags, list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::EagerUtils::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      // Cast the numpy array to this tensor's dtype when it does not already
      // have the matching element type.
      if (self->tensor.dtype() == phi::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
        if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<float>>(
              value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
        if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<double>>(
              value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, complex64, complex128, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp) ||
          PyComplex_Check(value_obj)) {
        if (self->tensor.dtype() == phi::DataType::FLOAT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() == phi::DataType::INT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() == phi::DataType::INT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() == phi::DataType::BOOL) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<float>>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<double>>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, complex64, complex128, int32, int64 or "
              "float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type's name; Py_TYPE() itself is a PyTypeObject*
        // and is not a string.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float, complex  or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        // If the dtypes still differ after the AMP cast, cast the value to
        // this tensor's dtype.
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    auto self_numpy = TensorToPyArray(*self_tensor, true);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    // Write the modified numpy array back. An uninitialized tensor has no
    // place of its own, so a build-dependent default place is used.
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1362 1363
// Python-facing method: registers a Python callable as a gradient hook on
// this tensor's grad node and returns the hook id.
//
// args[0]: the hook callable; it is wrapped in a PyTensorHook.
//
// For a leaf tensor that does not stop gradient, a GradNodeAccumulation is
// created first when the tensor has no grad node yet, and the hook is
// registered on that accumulation node. For a non-leaf tensor the hook is
// registered on its existing grad node.
//
// Raises InvalidArgument when no suitable grad node is available, instead of
// dereferencing a null pointer.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id = 0;
  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // The cast yields null when grad_node is null or is not a
    // GradNodeAccumulation (no node is created above for a tensor that stops
    // gradient).
    PADDLE_ENFORCE_NOT_NULL(
        accumulation_grad_node.get(),
        platform::errors::InvalidArgument(
            "Leaf Tensor %s has no grad_node with type GradNodeAccumulation, "
            "cannot register gradient hook on it.",
            self->tensor.name()));
    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    PADDLE_ENFORCE_NOT_NULL(
        grad_node.get(),
        platform::errors::InvalidArgument(
            "Tensor %s has no grad_node, cannot register gradient hook on it.",
            self->tensor.name()));
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1412 1413
// Removes a gradient hook from the grad node of this tensor.
//
// args[0] is the integer hook id. The result of
// GradNodeBase::RemoveGradientHook is converted to a Python object and
// returned.
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  // Parse the argument first so that a malformed hook id is still reported
  // even when there is no grad node.
  int64_t hook_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);

  // A tensor without a grad node cannot hold any hook; report "nothing
  // removed" instead of dereferencing a null pointer.
  // NOTE(review): assumes RemoveGradientHook returns bool, so `false` is the
  // matching "not removed" value -- confirm against GradNodeBase.
  if (grad_node == nullptr) {
    VLOG(6) << "Tensor has no grad_node, no hook to remove.";
    return ToPyObject(false);
  }

  return ToPyObject(grad_node->RemoveGradientHook(hook_id));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1426 1427
// Registers a reduce hook (args[0], a Python callable) on the
// GradNodeAccumulation of a leaf tensor. Returns None.
//
// Raises when the tensor is not a leaf, when it stops gradient, or when it
// has no grad node.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);

  // Preconditions: leaf tensor, gradient enabled, grad node present.
  PADDLE_ENFORCE_EQ(egr::EagerUtils::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  PADDLE_ENFORCE(
      node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  PyObject* py_callback = PyTuple_GET_ITEM(args, 0);
  std::dynamic_pointer_cast<egr::GradNodeAccumulation>(node)
      ->RegisterReduceHook(std::make_shared<PyVoidHook>(py_callback));

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1461 1462
// Replaces the impl of this tensor's grad with an empty container whose kind
// is selected by args[0] (a proto VarType): a DenseTensor for LOD_TENSOR, a
// SelectedRows for SELECTED_ROWS. Any other type leaves the grad untouched.
// Returns None.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();

  switch (requested_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      // Other variable types are ignored.
      break;
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1478 1479
// Calls reset() on the wrapped tensor and returns None. `args` and `kwargs`
// are not read.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1488 1489 1490 1491 1492 1493 1494 1495 1496
// Sets the impl of the wrapped tensor to nullptr and returns None. `args` and
// `kwargs` are not read.
static PyObject* tensor__clear_dataptr(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.set_impl(nullptr);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1497 1498
// Makes the grad of this tensor share the impl of the tensor given in
// args[0]. Returns None.
//
// When this tensor is initialized, the source must be initialized too and
// must match it in data type and in impl type. When this tensor has a grad
// slot, the source must be initialized before its impl is installed there.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // The comparisons below read src.dtype() and src.impl()->type_info();
    // validate src first so an empty source raises instead of dereferencing
    // a null impl.
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    // The grad shares the source impl; no data is copied here.
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1530

1531 1532 1533
// Returns a tensor whose DenseTensor meta has use_gpudnn set to args[0].
//
// If the flag already has the requested value, the tensor itself is returned.
// Otherwise a new DenseTensor is built that shares the data of the original
// and carries a copy of its meta with only use_gpudnn changed; the returned
// tensor reuses the name and autograd meta of the original.
// Raises when the tensor is undefined or not a DenseTensor.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  bool requested = pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  phi::DenseTensorMeta* src_meta =
      phi::DenseTensorUtils::GetMutableMeta(src_dense);

  // Nothing to change: hand back the original tensor.
  if (src_meta->use_gpudnn == requested) {
    return ToPyObject(self->tensor);
  }

  // Copy the meta, flip only use_gpudnn, and share the underlying data.
  phi::DenseTensorMeta updated_meta = *src_meta;
  updated_meta.use_gpudnn = requested;
  phi::DenseTensor shared_dense;
  shared_dense.ShareDataWith(*src_dense);
  shared_dense.set_meta(updated_meta);

  // Wrap the result, keeping the original name and autograd meta.
  paddle::Tensor result(std::make_shared<phi::DenseTensor>(shared_dense),
                        self->tensor.name());
  result.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << result.name() << " set use_gpudnn = " << requested;

  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1568 1569
// Stores the Vocab converted from args[0] in a fresh VariableCompatTensor
// and installs that as the impl of this tensor. Returns None.
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto parsed_vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Vocab>();
  *slot = parsed_vocab;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Stores the string list converted from args[0] in a fresh
// VariableCompatTensor and installs that as the impl of this tensor.
// Returns None.
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto parsed_strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Strings>();
  *slot = parsed_strings;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the Vocab held by the VariableCompatTensor impl of this tensor as
// a Python object. Raises when the tensor is not a VariableCompatTensor.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      egr::IsVariableCompatTensor(self->tensor),
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  const auto* compat_impl =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  return ToPyObject(compat_impl->Get<paddle::framework::Vocab>());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630
// Returns nnz() of a sparse tensor (COO or CSR) as a Python object.
// Raises when the tensor is neither SparseCooTensor nor SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    return ToPyObject(coo->nnz());
  }
  // Not COO, so the check above guarantees CSR.
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  return ToPyObject(csr->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1631 1632 1633 1634 1635 1636 1637 1638 1639
// Returns the non-zero indices of a SparseCooTensor wrapped in a new
// paddle::Tensor. Raises when the tensor is not a SparseCooTensor.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_impl =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::Tensor indices(indices_impl);
  return ToPyObject(indices);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the non-zero elements of a sparse tensor (COO or CSR) wrapped in a
// new paddle::Tensor. Raises when the tensor is neither SparseCooTensor nor
// SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    paddle::Tensor values(
        std::make_shared<phi::DenseTensor>(coo->non_zero_elements()));
    return ToPyObject(values);
  }
  // Not COO, so the check above guarantees CSR.
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  paddle::Tensor values(
      std::make_shared<phi::DenseTensor>(csr->non_zero_elements()));
  return ToPyObject(values);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns non_zero_crows() of a SparseCsrTensor wrapped in a new
// paddle::Tensor. Raises when the tensor is not a SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  paddle::Tensor crows(crows_impl);
  return ToPyObject(crows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns non_zero_cols() of a SparseCsrTensor wrapped in a new
// paddle::Tensor. Raises when the tensor is not a SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto cols_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_cols());
  paddle::Tensor cols(cols_impl);
  return ToPyObject(cols);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1701 1702
// Returns True when the tensor is defined and is a DenseTensor, else False.
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_dense_tensor() from running on an undefined tensor.
  const bool is_dense =
      self->tensor.defined() && self->tensor.is_dense_tensor();
  return ToPyObject(is_dense);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

L
LiYuRio 已提交
1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722
// Returns True when the tensor is defined and is a dist tensor, else False.
static PyObject* tensor_method_is_dist(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_dist_tensor() from running on an undefined tensor.
  const bool is_dist = self->tensor.defined() && self->tensor.is_dist_tensor();
  return ToPyObject(is_dist);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1723 1724
// Returns True when the tensor is defined and is either a SparseCooTensor or
// a SparseCsrTensor, else False.
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the type queries from running on an undefined tensor.
  const bool is_sparse =
      self->tensor.defined() && (self->tensor.is_sparse_coo_tensor() ||
                                 self->tensor.is_sparse_csr_tensor());
  return ToPyObject(is_sparse);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1735 1736
// Returns True when the tensor is defined and is a SparseCooTensor, else
// False.
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the type query from running on an undefined tensor.
  const bool is_coo =
      self->tensor.defined() && self->tensor.is_sparse_coo_tensor();
  return ToPyObject(is_coo);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1746 1747
// Returns True when the tensor is defined and is a SparseCsrTensor, else
// False.
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the type query from running on an undefined tensor.
  const bool is_csr =
      self->tensor.defined() && self->tensor.is_sparse_csr_tensor();
  return ToPyObject(is_csr);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1757 1758
// Converts this tensor with to_sparse_csr() and returns the result, after
// copying the stop-gradient and persistable flags from this tensor's
// autograd meta onto the result's autograd meta.
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto converted = self->tensor.to_sparse_csr();

  // Propagate stop_gradient.
  auto* converted_meta = egr::EagerUtils::autograd_meta(&converted);
  const auto stop_gradient =
      egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient();
  converted_meta->SetStopGradient(stop_gradient);

  // Propagate persistable.
  converted_meta = egr::EagerUtils::autograd_meta(&converted);
  const auto persistable =
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
  converted_meta->SetPersistable(persistable);

  return ToPyObject(converted);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1772 1773 1774 1775 1776 1777 1778 1779 1780
// Returns whether shape() of this tensor equals shape() of the tensor given
// in args[0].
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto rhs = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool shapes_match = (self->tensor.shape() == rhs.shape());
  return ToPyObject(shapes_match);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1781 1782
// Returns current_inplace_version() of the tensor, held as a uint32_t, as a
// Python object.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1791 1792
// Returns phi::SizeOf() of the tensor dtype, held as a uint32_t, as a Python
// object.
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const uint32_t bytes_per_element = phi::SizeOf(self->tensor.dtype());
  return ToPyObject(bytes_per_element);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1801 1802 1803 1804 1805
// Calls bump_inplace_version() on the wrapped tensor and returns None.
// `args` and `kwargs` are not read.
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.bump_inplace_version();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1810 1811 1812 1813
// Returns True when the tensor is defined and is a SelectedRows, else False.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps the type query from running on an undefined tensor.
  const bool is_rows =
      self->tensor.defined() && self->tensor.is_selected_rows();
  return ToPyObject(is_rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1821 1822
// Returns rows() of the SelectedRows impl of this tensor as a Python object.
// Raises when the tensor is not a SelectedRows.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  auto rows_impl =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  return ToPyObject(rows_impl->rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1834 1835 1836 1837 1838 1839 1840 1841 1842 1843
// Calls reset_inplace_version(set_to_zero) on the grad of this tensor when
// the grad exists, is defined, is a DenseTensor and is initialized.
// set_to_zero is args[0] when exactly one positional argument is given and
// true otherwise. Returns None.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  bool set_to_zero = true;
  const Py_ssize_t positional_count = PyTuple_Size(args);
  if (positional_count == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad_tensor = egr::EagerUtils::mutable_grad(self->tensor);
  // Each condition guards the next one; evaluation order matters.
  const bool can_reset = grad_tensor && grad_tensor->defined() &&
                         grad_tensor->is_dense_tensor() &&
                         grad_tensor->initialized();
  if (can_reset) {
    grad_tensor->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1854 1855
// Moves the data of a CPU DenseTensor into a memory-mapped writer allocation
// and returns the DenseTensor as a Python object.
//
// Steps: copy the current data into a buffer obtained from
// AllocateMemoryMapWriterAllocation, record the buffer's ipc name in
// MemoryMapFdSet, then reset the tensor holder to the new buffer.
// Raises when the tensor is not on CPU or its impl is not a DenseTensor.
// On Windows the function always raises PermissionDenied.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get the underlying DenseTensor
  auto dense_impl =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields nullptr for any other impl type; fail with a clear error
  // instead of dereferencing it below.
  PADDLE_ENFORCE(dense_impl != nullptr,
                 platform::errors::InvalidArgument(
                     "Sharing memory only support DenseTensor currently"));
  auto* t = dense_impl.get();
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  // Not reached; kept so that every path of the function has a return.
  RETURN_PY_NONE

#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1893 1894
// Returns offset() of the DenseTensor impl of this tensor as a Python object.
// Raises when the impl is not a DenseTensor or the tensor is not initialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields nullptr when there is no impl or it is not a DenseTensor;
  // report that instead of dereferencing a null pointer.
  PADDLE_ENFORCE(t != nullptr,
                 platform::errors::InvalidArgument(
                     "Tensor %s is not a DenseTensor, _offset is only "
                     "supported for DenseTensor.",
                     self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1908 1909
// Returns name() of the grad of this tensor as a Python object. Raises when
// mutable_grad() yields nullptr.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_tensor = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad_tensor != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  return ToPyObject(grad_tensor->name());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1922 1923
// Returns the DenseTensor impl of the grad of this tensor as a Python object,
// or None when the grad is not defined. Raises when mutable_grad() yields
// nullptr or when the grad is defined but is not a DenseTensor.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_tensor = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad_tensor != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  // An undefined grad is reported as None.
  if (!grad_tensor->defined()) {
    RETURN_PY_NONE
  }
  // Anything other than a DenseTensor is rejected.
  if (!grad_tensor->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
  }
  auto* dense_grad =
      static_cast<phi::DenseTensor*>(grad_tensor->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1947 1948
// For a leaf tensor, calls SetFakeEmpty(false) on its GradNodeAccumulation.
// Non-leaf tensors, and leaf tensors without a grad node, are left untouched.
// Returns None. Raises when mutable_grad() yields nullptr.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    auto accumulation_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    // A leaf tensor may not have a grad node yet (the hook registration code
    // in this file handles that case explicitly); there is nothing to unset
    // then, so skip instead of dereferencing a null pointer.
    if (accumulation_node != nullptr) {
      accumulation_node->SetFakeEmpty(false);
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1968 1969 1970 1971 1972
// Returns the data() pointer of the DenseTensor impl as an int64 Python
// object, or None when the tensor is not initialized or not a DenseTensor.
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  const bool has_dense_data =
      self->tensor.initialized() && self->tensor.is_dense_tensor();
  if (!has_dense_data) {
    RETURN_PY_NONE
  }
  auto dense_impl =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The address is exposed to Python as a plain integer.
  return ToPyObject(reinterpret_cast<int64_t>(dense_impl->data()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996
// Returns the grad of this tensor as a Python object when the tensor has
// autograd meta and that grad is initialized; returns None otherwise.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // The meta may be null; check before touching it. The log statement used
  // to dereference it unconditionally.
  if (!meta) {
    VLOG(6) << "Tensor has no autograd meta, grad is not available.";
    RETURN_PY_NONE
  }
  const bool grad_initialized = meta->Grad().initialized();
  VLOG(6) << meta << " initialized: " << grad_initialized;
  if (grad_initialized) {
    return ToPyObject(meta->Grad());
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1997
#if defined(PADDLE_WITH_CUDA)
1998 1999
// Passes the DenseTensor impl of this tensor and the device id from args[0]
// to tensor_uva(). Returns None.
// Raises when the tensor is not a DenseTensor or is not on CPU.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));

  int target_device = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
  auto* dense_impl = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  tensor_uva(dense_impl, target_device);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033
// Returns initialized() of the StringTensor impl of this tensor, or False
// when the impl is not a StringTensor.
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto string_impl =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  // A failed cast means there is no StringTensor to query.
  if (!string_impl) {
    return ToPyObject(false);
  }
  return ToPyObject(string_impl->initialized());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
2034

2035
PyMethodDef variable_methods[] = {
2036 2037 2038 2039
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2040
    {"_is_initialized",
2041
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
2042 2043
     METH_VARARGS | METH_KEYWORDS,
     NULL},
W
wanghuancoder 已提交
2044
    {"_is_dense_tensor_hold_allocation",
2045 2046
     (PyCFunction)(void (*)(
         void))tensor_method__is_dense_tensor_hold_allocation,
2047 2048 2049 2050 2051 2052 2053 2054 2055 2056
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_copy_to",
     (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"copy_",
     (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2057 2058 2059 2060
    {"clone",
     (PyCFunction)(void (*)(void))tensor_method_clone,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2061
    {"reconstruct_from_",
2062
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"retain_grads",
     (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"clear_gradient",
     (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_dense",
     (PyCFunction)(void (*)(void))tensor_method_is_dense,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
L
LiYuRio 已提交
2077 2078 2079 2080
    {"is_dist",
     (PyCFunction)(void (*)(void))tensor_method_is_dist,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2081 2082 2083 2084 2085 2086 2087 2088
    {"_zero_grads",
     (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_share_buffer_to",
     (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2089
    {"_is_shared_buffer_with",
2090
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
2091 2092
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2093
    {"_share_underline_tensor_to",
2094
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
2095 2096
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2097
    {"_is_shared_underline_tensor_with",
2098
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
2099 2100 2101 2102 2103 2104
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"detach",
     (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2105
    {"get_tensor",
2106
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
2107 2108
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2109 2110
    {"get_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
2111 2112
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2113 2114 2115 2116
    {"_get_tensor_from_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method__get_tensor_from_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
J
Jiabin Yang 已提交
2117 2118
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
2119 2120
     METH_VARARGS | METH_KEYWORDS,
     NULL},
W
wanghuancoder 已提交
2121 2122
    {"_getitem_from_offset",
     (PyCFunction)(void (*)(void))tensor__getitem_from_offset,
2123 2124
     METH_VARARGS | METH_KEYWORDS,
     NULL},
W
wanghuancoder 已提交
2125 2126
    {"__setitem_eager_tensor__",
     (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
2127 2128
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2129 2130
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
2131 2132 2133 2134 2135 2136
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_remove_grad_hook",
     (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2137 2138
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
2139 2140 2141 2142 2143 2144 2145 2146 2147 2148
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_set_grad_type",
     (PyCFunction)(void (*)(void))tensor__set_grad_type,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_clear",
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2149 2150 2151 2152
    {"_clear_dataptr",
     (PyCFunction)(void (*)(void))tensor__clear_dataptr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
J
Jiabin Yang 已提交
2153 2154
    {"_copy_gradient_from",
     (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
2155 2156
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2157 2158
    {"_tensor_use_gpudnn",
     (PyCFunction)(void (*)(void))tensor__use_gpudnn,
2159 2160
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2161 2162 2163
    /** the methods to adapt old dygraph, will be removed in the future **/
    {"set_string_list",
     (PyCFunction)(void (*)(void))tensor_method_set_string_list,
2164 2165 2166 2167 2168 2169
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"set_vocab",
     (PyCFunction)(void (*)(void))tensor_method_set_vocab,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2170 2171
    {"get_map_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
2172 2173
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2174
    /***the method of sparse tensor****/
2175 2176 2177 2178
    {"nnz",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206
    {"indices",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"values",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"crows",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"cols",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse_coo",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"is_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2207 2208 2209 2210
    {"is_same_shape",
     (PyCFunction)(void (*)(void))tensor_method_is_same_shape,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2211 2212 2213 2214 2215 2216 2217 2218
    {"to_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"element_size",
     (PyCFunction)(void (*)(void))tensor_method_element_size,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2219
    /***the method of sparse tensor****/
2220 2221 2222 2223
    {"_inplace_version",
     (PyCFunction)(void (*)(void))tensor__inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2224 2225
    {"_bump_inplace_version",
     (PyCFunction)(void (*)(void))tensor__bump_inplace_version,
2226 2227
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2228 2229
    {"is_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
2230 2231 2232 2233 2234 2235
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"rows",
     (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2236 2237
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254 2255 2256 2257 2258 2259
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_share_memory",
     (PyCFunction)(void (*)(void))tensor_method__share_memory,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_offset",
     (PyCFunction)(void (*)(void))tensor__offset,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_grad_name",
     (PyCFunction)(void (*)(void))tensor__grad_name,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_grad_value",
     (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_unset_fake_empty",
     (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2260 2261 2262 2263
    {"data_ptr",
     (PyCFunction)(void (*)(void))tensor_data_ptr,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
W
wanghuancoder 已提交
2264 2265 2266 2267
    {"_grad_ivar",
     (PyCFunction)(void (*)(void))tensor__grad_ivar,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2268
#if defined(PADDLE_WITH_CUDA)
2269 2270 2271 2272
    {"_tensor_uva",
     (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
2273
#endif
2274 2275
    {NULL, NULL, 0, NULL}};

J
Jack Zhou 已提交
2276 2277 2278 2279
// variable_methods for core.eager.StringTensor
//
// Each entry follows the CPython PyMethodDef layout:
//   {python_method_name, handler, calling_convention_flags, docstring}
//
//  * METH_VARARGS | METH_KEYWORDS selects the (self, args, kwargs) calling
//    convention, so every handler listed here must have the
//    PyCFunctionWithKeywords signature.
//  * Handlers are cast to `void (*)(void)` before being cast to PyCFunction.
//    This intermediate cast is the idiom the CPython documentation uses for
//    keyword-accepting handlers to avoid function-type cast warnings.
//  * The docstring slot is NULL for every entry, i.e. no __doc__ text is
//    supplied from C++.
//  * The trailing {NULL, NULL, 0, NULL} entry is the sentinel that marks the
//    end of the table; new methods must be inserted above it.
PyMethodDef string_tensor_variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    {"_is_string_tensor_hold_allocation",
     (PyCFunction)(void (*)(
         void))tensor_method__is_string_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     NULL},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    {NULL, NULL, 0, NULL}};

2294 2295
}  // namespace pybind
}  // namespace paddle