/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

#include <Python.h>
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "pybind11/detail/internals.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/pybind.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

PHI_DECLARE_bool(set_to_1d);
DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace pybind {

77 78
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
79
                                     const paddle::platform::Place& place,
80
                                     bool zero_copy);
81

82
extern PyTypeObject* p_tensor_type;
83

J
Jiabin Yang 已提交
84
// Converts a Python object used as a slice index into a Py_ssize_t.
// Only paddle.Tensor instances (checked against p_tensor_type) are accepted;
// anything else raises InvalidArgument. The tensor must be initialized and
// backed by a phi::DenseTensor; its value is read by GetSliceIndexFromTensor.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (PyObject_TypeCheck(obj, p_tensor_type)) {
    VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
    paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
    PADDLE_ENFORCE_EQ(
        tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in slice, however we got "
            "uninitialized tensor %s, please check your code.",
            tensor.name()));
    // The impl is static_cast to phi::DenseTensor below; verify that it really
    // is one instead of silently reinterpreting another impl type.
    PADDLE_ENFORCE_EQ(
        tensor.is_dense_tensor(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support DenseTensor in slice, however tensor %s is "
            "not a DenseTensor.",
            tensor.name()));
    // Reuse the Tensor converted above rather than converting `obj` again.
    return GetSliceIndexFromTensor(
        *static_cast<phi::DenseTensor*>(tensor.impl().get()));
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }
}

// Python docstring of Tensor.numpy(). The first line plus the "--" separator
// form the text signature, so nothing may be inserted between them.
PyDoc_STRVAR(tensor_method_numpy__doc__,  // NOLINT
             R"DOC(numpy($self, /)
--

Returns a numpy array that shows the value of the current Tensor.

Returns:
    ndarray, The numpy value of the current Tensor, whose dtype is
    the same as the current Tensor.

Examples:
    .. code-block:: python

        import paddle

        data = paddle.uniform([30, 10, 32], dtype="float32", min=-1, max=1)
        linear = paddle.nn.Linear(32, 64)
        data = paddle.to_tensor(data)
        x = linear(data)
        print(x.numpy())
)DOC");

static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
128 129
                                     PyObject* kwargs) {
  EAGER_TRY
W
wanghuancoder 已提交
130 131
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
132 133
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
W
wanghuancoder 已提交
134 135 136 137 138
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
139 140 141 142 143
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
W
wanghuancoder 已提交
144 145 146 147 148
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
149
  auto tensor_dims = self->tensor.shape();
150 151 152 153 154 155 156 157 158
#ifdef PADDLE_WITH_DISTRIBUTE
  // Now the DistTensor's numpy() return the local tensor value
  if (self->tensor.is_dist_tensor()) {
    tensor_dims = phi::vectorize(
        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
            ->value()
            .dims());
  }
#endif
159
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
160
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
161 162
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
163
  size_t py_rank = tensor_dims.size();
164
  size_t numel = 1;
165
  if (py_rank == 0) {
166
    Py_ssize_t args_num = PyTuple_Size(args);
167 168
    // true by default
    bool set_to_1d = FLAGS_set_to_1d;
169 170 171 172 173 174 175
    if (args_num == (Py_ssize_t)1) {
      PyObject* obj = PyTuple_GET_ITEM(args, 0);
      if (obj == Py_False) {
        set_to_1d = false;
      }
    }
    if (set_to_1d) {
176
      // 0D Tensor hack process to 1D numpy, will remove in release 2.6
177 178 179 180 181
      VLOG(0)
          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
             "order to avoid this problem, "
             "0D Tensor will be changed to 1D numpy currently, but it's not "
             "correct and will be "
182 183
             "removed in release 2.6. For Tensor contain only one element, "
             "Please "
184
             "modify "
185
             " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
186
             "possible, "
187
             "otherwise 'Tensor.numpy()[0]' will raise error in release 2.6.";
188 189 190 191
      py_rank = 1;
      py_dims[0] = 1;
      py_strides[0] = sizeof_dtype * numel;
    }
W
wanghuancoder 已提交
192 193 194 195 196 197 198 199
  } else if (self->tensor.is_dense_tensor()) {
    auto tensor_stride = self->tensor.strides();

    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * tensor_stride[i];
      numel *= py_dims[i];
    }
200 201 202 203 204 205
  } else {
    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * numel;
      numel *= py_dims[i];
    }
206
  }
W
wanghuancoder 已提交
207 208

  if (!self->tensor.impl()->initialized()) {
W
wanghuancoder 已提交
209 210 211 212 213 214 215 216 217 218 219
    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        py_rank,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);

220
    if (tensor_dims.empty()) {
221 222 223
      py_dims[0] = 0;
      py_strides[0] = 0;
      PyObject* array = api.PyArray_NewFromDescr_(
224 225 226 227 228 229
          api.PyArray_Type_,
          api.PyArray_DescrFromType_(numpy_dtype),
          1,
          py_dims,
          py_strides,
          nullptr,
230 231 232 233 234
          pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
              pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
          nullptr);
      return array;
    }
W
wanghuancoder 已提交
235 236 237
    return array;
  }

W
wanghuancoder 已提交
238 239 240
  phi::DenseTensor cpu_tensor;
  platform::CPUPlace cpu_place;

241
  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
W
wanghuancoder 已提交
242
    eager_gil_scoped_release guard;
243
    platform::CPUPlace place;
244 245 246 247
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
248 249
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
250 251 252 253 254
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
255
      // deep copy
W
wanghuancoder 已提交
256 257 258 259 260
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281
#ifdef PADDLE_WITH_DISTRIBUTE
    } else if (self->tensor.is_dist_tensor()) {
      // TODO(chenweihang): deal with DistTensor as local DenseTensor now,
      // if the local DenseTensor is shard or partial, do gather or reduce?
      VLOG(6) << "Getting DistTensor's numpy value";
      auto* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
      auto& dense_tensor = dist_tensor->value();
      cpu_tensor.set_meta(dense_tensor.meta());
      // deep copy
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor.Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      // deep copy
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor.Holder()->ptr(),
                           dense_tensor.Holder()->size());
#endif
282 283 284 285
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
286 287 288 289 290
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
291
      // deep copy
W
wanghuancoder 已提交
292 293 294 295 296
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
297 298
    }

299
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
300
  } else if (self->tensor.is_gpu()) {
W
wanghuancoder 已提交
301
    eager_gil_scoped_release guard;
302 303 304 305
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
306
    phi::DeviceContextPool::Instance().Get(self->tensor.place())->Wait();
307
#endif
308 309 310 311
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
312 313
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
314 315 316 317 318 319 320 321 322
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
#ifdef PADDLE_WITH_DISTRIBUTE
    } else if (self->tensor.is_dist_tensor()) {
      VLOG(6) << "Getting DistTensor's numpy value";
      auto* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
      auto& dense_tensor = dist_tensor->value();
      cpu_tensor.set_meta(dense_tensor.meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor.Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor.Holder()->ptr(),
                                      dense_tensor.Holder()->size(),
                                      kind);
#endif
339 340 341 342
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
343 344 345 346 347 348 349 350 351
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
352
    }
353
#endif
C
Chen Weihang 已提交
354 355 356 357 358 359 360
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
361 362
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
363 364 365 366 367 368 369 370 371 372
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
373 374 375 376
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
377 378 379 380 381 382 383 384 385 386
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
387 388
    }
#endif
389 390
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
W
wanghuancoder 已提交
391
    eager_gil_scoped_release guard;
392
    phi::DeviceContextPool::Instance().Get(self->tensor.place())->Wait();
393 394 395 396
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
397 398
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
399 400 401 402 403
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
404
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
405 406 407
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
408 409 410 411
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
C
co63oc 已提交
412
      // TODO(qili93): temporary for ascend npu performance to be removed along
413
      // with npu_identity op
414
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
415 416 417 418 419
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
W
wanghuancoder 已提交
420 421 422 423 424
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
425
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
426 427 428
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
429 430
    }
#endif
431 432 433
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
434
    RETURN_PY_NONE
435 436
  }

W
wanghuancoder 已提交
437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456
  void* array_buffer = cpu_tensor.Holder()->ptr();
  size_t array_offset = cpu_tensor.offset();

  PyObject* base = ToPyObject(paddle::Tensor(
      std::make_shared<phi::DenseTensor>(std::move(cpu_tensor))));

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      py_rank,
      py_dims,
      py_strides,
      reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(array_buffer) +
                              array_offset),
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);

  api.PyArray_SetBaseObject_(array, base);

457 458 459 460
  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
461 462 463 464 465 466 467 468
// Implements StringTensor.numpy(): converts a CPU StringTensor into a numpy
// fixed-width unicode array of dtype "U<n>", where n is the length (as
// reported by phi::strings::GetUnicodeStrLen) of the longest element, and at
// least 1. A missing or uninitialized impl yields an empty 1-D unicode array.
// Non-CPU tensors raise InvalidArgument.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();

    // Measure every element exactly once; the lengths are reused below when
    // converting. The width starts at 1 so a tensor whose elements are all
    // empty still gets "U1", and so that a zero-element tensor never
    // dereferences st_ptr (max_element over an empty range returns end).
    std::vector<size_t> unicode_lens(static_cast<size_t>(numel));
    size_t max_unicode_length = 1;
    for (int64_t i = 0; i < numel; ++i) {
      unicode_lens[i] = static_cast<size_t>(
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size()));
      max_unicode_length = std::max(max_unicode_length, unicode_lens[i]);
    }
    VLOG(6) << "The max unicode length is " << max_unicode_length;

    // make_unique<T[]> value-initializes the buffer, so the code points after
    // each shorter string are already zero (no memset needed).
    auto sp =
        std::make_unique<uint32_t[]>(max_unicode_length * numel);  // NOLINT
    auto py_array_data = sp.get();
    for (int64_t i = 0; i < numel; ++i) {
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  unicode_lens[i]);
    }
    // Without a base handle, py::array copies py_array_data, so `sp` may be
    // freed when this scope ends.
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Reports self->tensor.initialized(), converted with ToPyObject.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_initialized = self->tensor.initialized();
  return ToPyObject(is_initialized);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns true only when the impl is a phi::DenseTensor whose IsInitialized()
// holds; any other impl type, or a null impl, yields false.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const bool holds_allocation = (dense != nullptr) && dense->IsInitialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static void IncreaseTensorReferenceCountUntilCopyComplete(
555
    const paddle::Tensor& tensor, const platform::Place& place) {
556 557 558 559 560 561 562 563
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
C
co63oc 已提交
564
  // CUDAPinned Mem -> CUDA by cudaMemcpyAsync.
565 566 567 568 569 570 571
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

572 573
// Tensor._copy_to(place, blocking): returns self->tensor.copy_to(place,
// blocking). The copy gets stop_gradient=True and inherits self's persistable
// flag. For a non-blocking copy, self's storage is kept alive until the copy
// completes. The GIL is released while copying and updating autograd meta.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  paddle::Tensor copied;
  {
    eager_gil_scoped_release guard;
    copied = self->tensor.copy_to(place, blocking);
    if (!blocking) {
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, place);
    }
    auto* copied_meta = egr::EagerUtils::autograd_meta(&copied);
    copied_meta->SetStopGradient(true);
    copied_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  }
  return ToPyObject(copied);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python docstring of Tensor.reconstruct_from_(). The first line is the text
// signature and must be a valid Python parameter list.
PyDoc_STRVAR(tensor_reconstruct_from___doc__,  // NOLINT
             R"DOC(reconstruct_from_($self, other, /)
--

Reconstruct self from the other Tensor, i.e. 'self = other', while keeping
the original name of self.

Returns:
    None.

Examples:
    .. code-block:: python

      import paddle

      t1 = paddle.to_tensor([1.0], stop_gradient=False)
      t2 = paddle.to_tensor([2.0], stop_gradient=True)

      t1.reconstruct_from_(t2)

      print(t1)
)DOC");

// Tensor.reconstruct_from_(other): assigns `other` to self->tensor, then
// restores self's previous name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const std::string kept_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from" << src_tensor.name() << " to "
          << kept_name;

  self->tensor = src_tensor;
  // The assignment may have replaced the name as well; put ours back.
  self->tensor.set_name(kept_name);

  VLOG(6) << "Finished Reconstructing Tensor from" << src_tensor.name()
          << " to " << self->tensor.name();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Tensor.copy_(src, blocking): copies `src` into self in place.
//  * Self already initialized: if src is initialized, its data is copied onto
//    self's current place; autograd flags are left untouched.
//  * Self uninitialized: self first takes src's stop_gradient and persistable
//    flags, then (if src is initialized) src's data is copied onto src's place.
// The GIL is released around the flag updates and the copy. Returns None.
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();

  if (self->tensor.initialized()) {
    if (src_tensor.initialized()) {
      eager_gil_scoped_release guard;
      self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
    }
  } else {
    eager_gil_scoped_release guard;
    auto* dst_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    auto* src_meta = egr::EagerUtils::autograd_meta(&(src_tensor));
    dst_meta->SetStopGradient(src_meta->StopGradient());
    dst_meta->SetPersistable(src_meta->Persistable());
    if (src_tensor.initialized()) {
      self->tensor.copy_(src_tensor, src_tensor.place(), blocking);
    }
  }

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python docstring of Tensor.clone(). The first line plus the "--" separator
// form the text signature, so nothing may be inserted between them.
PyDoc_STRVAR(tensor_method_clone__doc__,  // NOLINT
             R"DOC(clone($self, /)
--

Returns a new Tensor, which is a clone of the original Tensor, and it remains in the current graph.
It will always have a Tensor copy.
In addition, the cloned Tensor provides gradient propagation.

Returns:
    Tensor, The cloned Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor(1.0, stop_gradient=False)
        clone_x = x.clone()
        y = clone_x**2
        y.backward()
        print(clone_x.stop_gradient) # False
        print(clone_x.grad)          # [2.0], support gradient propagation
        print(x.stop_gradient)       # False
        print(x.grad)                # [2.0], clone_x support gradient propagation for x

        x = paddle.to_tensor(1.0)
        clone_x = x.clone()
        clone_x.stop_gradient = False
        z = clone_x**3
        z.backward()
        print(clone_x.stop_gradient) # False
        print(clone_x.grad)          # [3.0], support gradient propagation
        print(x.stop_gradient) # True
        print(x.grad)          # None
)DOC");

// Implements `Tensor.clone()`.
//
// Produces a new Tensor via assign_ad_func(self->tensor), the call that makes
// the clone participate in gradient propagation (see the clone docstring).
// The initialization check and the assignment both run with the GIL released;
// the Python result object is only built once the GIL is held again.
// Raises InvalidArgument when the source tensor is not initialized.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned;
  {
    // Scope limits the GIL release to the non-Python work below.
    eager_gil_scoped_release guard;
    PADDLE_ENFORCE_EQ(
        self->tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));
    cloned = assign_ad_func(self->tensor);
  }
  // ToPyObject touches Python objects, so it must run with the GIL held.
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757
// Python docstring text for `Tensor.retain_grads()`, declared with
// PyDoc_STRVAR. The raw string is emitted verbatim at runtime.
PyDoc_STRVAR(tensor_method_retain_grads__doc__, R"DOC(retain_grads($self, /)
--

Enables this Tensor to have its grad populated during backward(). It is a no-op for leaf tensors.

Returns:
    None.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0, 2.0, 3.0])
      x.stop_gradient = False
      y = x + x
      y.retain_grads()
      loss = y.sum()
      loss.backward()

      print(y.grad) # [1., 1., 1.]

      x = paddle.to_tensor([1.0, 2.0, 3.0])
      x.stop_gradient = False
      y = x + x
      # y.retain_grads()
      loss = y.sum()
      loss.backward()

      print(y.grad) # None
)DOC");

758 759
// Implements `Tensor.retain_grads()`.
//
// Only acts when egr::Controller::Instance().HasGrad() is true; otherwise it
// is a no-op. When it acts (with the GIL released), it:
//   1. installs an egr::GradNodeAccumulation as the tensor's grad node if the
//      tensor has none yet, and
//   2. calls egr::egr_utils_api::RetainGradForTensor on the tensor.
// Always returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (egr::Controller::Instance().HasGrad()) {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!meta->GetMutableGradNode()) {
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << " become accumulation node";
      // The accumulation node is bound to this tensor's autograd meta.
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

777
// Python docstring text for `Tensor.clear_gradient()`, declared with
// PyDoc_STRVAR. The raw string is emitted verbatim at runtime.
PyDoc_STRVAR(tensor_clear_gradient__doc__,  // NOLINT
             R"DOC(clear_gradient($self, set_to_zero=True, /)
--

Only for Tensor that has gradient, normally we use this for Parameters since
other temporary Tensors don't have gradients.

The Gradient of current Tensor will be set to ``0`` elementwise or ``None``.

Args:
    set_to_zero (bool, optional): If set to ``True``, the gradient will be set
        to ``0`` elementwise, otherwise the gradient will be set to ``None``.
        Default: ``True``.

Returns:
    None.

Examples:
    .. code-block:: python

        import paddle
        input = paddle.uniform([10, 2])
        linear = paddle.nn.Linear(2, 3)
        out = linear(input)
        out.backward()
        print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
        linear.weight.clear_gradient()
        print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
)DOC");

807 808
// Implements `Tensor.clear_gradient(set_to_zero=True)`.
//
// Args (positional, read from `args`; `kwargs` is not consulted):
//   set_to_zero: only read when exactly one positional argument is given.
//     For an initialized dense gradient, true fills it with 0 in place and
//     false releases its memory holder.
// A SelectedRows gradient whose value is initialized has its rows and value
// cleared regardless of set_to_zero.
// Returns None. Raises Fatal if a leaf tensor reports a null grad pointer.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected nullptr grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta."));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            // NOTE(review): presumably tells the accumulation node that the
            // zeroed buffer holds no real gradient yet - confirm semantics.
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          // The moved-out holder is discarded, which drops this tensor's
          // reference to the gradient storage.
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

870 871
// Implements `Tensor._zero_grads()`.
//
// Sets this tensor's gradient to zero if the gradient is initialized:
//   - a dense gradient is filled with 0 in place via phi::funcs::set_constant;
//   - any other gradient kind is replaced by zeros_like(grad).
// Leaf tensors read the gradient through EagerUtils::mutable_grad and raise
// Fatal if it is null. Non-leaf tensors read it through their autograd meta.
// The work runs with the GIL released. Returns None.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  // Zeroes `grad` in place when it is initialized; shared by both branches.
  auto zero_fill = [](paddle::Tensor* grad) {
    if (!grad->initialized()) {
      return;
    }
    if (grad->is_dense_tensor()) {
      auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
      phi::funcs::set_constant(*dev_ctx, t, 0.0);
    } else {
      grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
    }
  };

  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    eager_gil_scoped_release guard;
    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected nullptr grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta."));
    zero_fill(grad);
  } else {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    zero_fill(meta->MutableGrad());
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

915 916
// Implements `Tensor._share_buffer_to(dst)`.
//
// Makes `dst` (positional argument 0, reinterpreted as a TensorObject without
// a Python type check) share this tensor's buffer and data type. Both sides
// are accessed as phi::DenseTensor. If `dst` has no impl yet, a new empty
// DenseTensor is created for it first.
// Raises InvalidArgument if this tensor is not initialized, or if either
// side is not a DenseTensor. Without the kind checks, the static_casts
// below would be invalid for other tensor kinds.
// Returns None.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* dst_ptr =
      &(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s is not a DenseTensor, _share_buffer_to "
                        "only supports DenseTensor.",
                        self->tensor.name()));
  auto* src_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  if (!dst_ptr->defined()) {
    dst_ptr->set_impl(std::make_shared<phi::DenseTensor>());
  }
  PADDLE_ENFORCE_EQ(dst_ptr->is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Destination tensor %s is not a DenseTensor, "
                        "_share_buffer_to only supports DenseTensor.",
                        dst_ptr->name()));
  auto* dst_tensor = static_cast<phi::DenseTensor*>(dst_ptr->impl().get());
  dst_tensor->ShareBufferWith(*src_tensor);
  dst_tensor->ShareDataTypeWith(*src_tensor);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

939 940 941 942
// Implements `Tensor._is_shared_buffer_with(other)`.
//
// `other` is positional argument 0, reinterpreted as a TensorObject without
// a Python type check. Returns False when either tensor has no impl.
// Otherwise it returns DenseTensor::IsSharedBufferWith, treating both impls
// as phi::DenseTensor.
// Raises InvalidArgument if this tensor is not initialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto* other_obj = reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* dst_ptr = &(other_obj->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  const bool both_defined = self->tensor.defined() && dst_ptr->defined();
  if (!both_defined) {
    return ToPyObject(false);
  }
  auto* mine = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* theirs = static_cast<phi::DenseTensor*>(dst_ptr->impl().get());
  const bool shared = theirs->IsSharedBufferWith(*mine);
  return ToPyObject(shared);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

962 963 964 965
// Implements `Tensor._share_underline_tensor_to(other)`.
//
// Points `other` (positional argument 0, reinterpreted as a TensorObject
// without a Python type check) at the very same impl() object this tensor
// holds, so both tensors alias one underlying tensor object.
// Raises InvalidArgument if this tensor is not initialized. Returns None.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  // Argument 0 is the *destination* that receives this tensor's impl.
  auto* target_obj = reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor& target = target_obj->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  target.set_impl(self->tensor.impl());
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

980 981 982 983
// Implements `Tensor._is_shared_underline_tensor_with(other)`.
//
// `other` (positional argument 0) is converted with CastPyArg2Tensor and must
// be initialized, otherwise InvalidArgument is raised. Returns True only when
// both tensors are defined and hold the identical impl() object.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(src_tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        src_tensor.name()));
  if (self->tensor.defined() && src_tensor.defined()) {
    // shared_ptr equality compares the managed pointers.
    const bool same_impl = (self->tensor.impl() == src_tensor.impl());
    return ToPyObject(same_impl);
  }
  return ToPyObject(false);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1000 1001
// Python docstring text for `Tensor.detach()`, declared with PyDoc_STRVAR.
// The raw string is emitted verbatim at runtime.
PyDoc_STRVAR(tensor_method_detach__doc__,  // NOLINT
             R"DOC(detach($self, /)
--

Returns a new Tensor, detached from the current graph.
It shares data with the origin Tensor and never makes a Tensor copy.
In addition, the detached Tensor doesn't provide gradient propagation.

Returns:
    Tensor, The detached Tensor.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0], stop_gradient=False)
      detach_x = x.detach()
      detach_x[0] = 10.0
      print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                #        [10.])
      y = x**2
      y.backward()
      print(x.grad)         # [20.0]
      print(detach_x.grad)  # None, 'stop_gradient=True' by default

      detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
      z = detach_x**3
      z.backward()

      print(x.grad)         # [20.0], detach_x is detached from x's graph, they do not affect each other
      print(detach_x.grad)  # [300.0], detach_x has its own graph

      # Due to sharing of data with the origin Tensor, there are some unsafe operations:
      # y = 2 * x
      # detach_x[:] = 5.0
      # y.backward()
      # It will raise Error:
      #   one of the variables needed for gradient computation has been modified by an inplace operation.
)DOC");

1041 1042
// Implements `Tensor.detach()`.
//
// Allocates a new Python Tensor object that shares this tensor's impl()
// (so the data is shared, not copied) under a freshly generated unique name.
// Of the autograd state, only the `persistable` flag is carried over.
// Raises InvalidArgument if this tensor is not defined, and Fatal if
// tp_alloc fails.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc only provides raw storage; construct the embedded Tensor in place.
  new (&(detached->tensor)) paddle::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());

  auto src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto dst_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1069 1070 1071 1072 1073 1074 1075 1076 1077 1078
// Python docstring text for `Tensor.detach_()`, the in-place counterpart of
// `Tensor.detach()`: the tensor itself is detached and then returned.
PyDoc_STRVAR(tensor_method_detach___doc__, R"DOC(detach_($self, /)
--

Detach self from the current graph, and returns self Tensor.
In addition, the detached Tensor doesn't provide gradient propagation.

Returns:
    Tensor, The detached Tensor.
)DOC");

W
wanghuancoder 已提交
1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097
// Implements `Tensor.detach_()`.
//
// Replaces this tensor's autograd meta with a brand-new egr::AutogradMeta.
// Only the `persistable` flag is copied from the old one, so the tensor no
// longer references its previous autograd state. Returns `self`.
// Raises InvalidArgument if the tensor is not defined.
static PyObject* tensor_method_detach_(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  auto autograd_meta = std::make_shared<egr::AutogradMeta>();
  autograd_meta->SetPersistable(
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  self->tensor.set_autograd_meta(autograd_meta);

  // A CPython method must return a new reference. `self` is borrowed here,
  // so take a reference before handing it back. Otherwise the caller's
  // DECREF would release an object it does not own.
  Py_INCREF(reinterpret_cast<PyObject*>(self));
  return reinterpret_cast<PyObject*>(self);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115
// Python docstring text for `Tensor.get_tensor()`.
// NOTE(review): presumably registered together with
// tensor_method_get_underline_tensor in the method table (not visible here).
PyDoc_STRVAR(tensor_method_get_tensor__doc__, R"DOC(get_tensor($self, /)
--

Returns the underline tensor in the origin Tensor.

Returns:
    Underline tensor.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0], stop_gradient=False)
      underline_x = x.get_tensor()
      print(underline_x) # a Dense Tensor info
)DOC");

1116 1117 1118 1119
// Returns the tensor object underneath this eager Tensor.
//
//   - no impl at all       -> a freshly constructed empty phi::DenseTensor
//   - DenseTensor impl      -> that phi::DenseTensor
//   - DistTensor impl       -> that phi::distributed::DistTensor, but only
//                              when built with PADDLE_WITH_DISTRIBUTE
//   - anything else         -> None
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // Keeps parity with the legacy Variable.get_tensor(), which produced an
    // empty tensor rather than None.
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }
  if (self->tensor.is_dense_tensor()) {
    auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
    VLOG(6) << "tensor: " << dense->IsInitialized();
    return ToPyObject(dense);
  }
#ifdef PADDLE_WITH_DISTRIBUTE
  if (self->tensor.is_dist_tensor()) {
    auto* dist =
        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
    VLOG(6) << "dist tensor: " << dist->defined();
    return ToPyObject(dist);
  }
#endif
  // Unsupported kinds (and DistTensor in non-distributed builds) yield None.
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1144 1145 1146 1147 1148
// Returns the phi::SelectedRows held by this tensor, or None when the tensor
// has no impl or its impl is not a SelectedRows.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  const bool has_rows =
      self->tensor.defined() && self->tensor.is_selected_rows();
  if (has_rows) {
    auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    return ToPyObject(rows);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174
// Implements `Tensor._get_tensor_from_selected_rows()`.
//
// Requires this tensor to hold an initialized phi::SelectedRows; raises Fatal
// otherwise. Returns a new Tensor with a fresh unique name whose impl is a
// phi::DenseTensor copy-constructed from the SelectedRows' value tensor.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));

  auto* selected_rows =
      static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      selected_rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* value = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  VLOG(4) << "dense_tensor: " << value->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  // The DenseTensor copy constructor is used; no data is explicitly copied here.
  result.set_impl(std::make_shared<phi::DenseTensor>(*value));
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
1187 1188 1189
// Implements Tensor.__getitem__ for an index that contains no Tensor
// (presumably ints / slices / Ellipsis / None / lists - whatever
// ParseIndexingSlice accepts).
//
// Steps:
//   1. ParseIndexingSlice converts the index into slice parameters, the axes
//      to decrease (drop), the positions of `None` (new axes) and, for list
//      indices, the selected positions.
//   2. Slicing uses slice_ad_func when every stride is 1. Otherwise it uses
//      strided_slice_ad_func followed by squeeze_ad_func over the decreased
//      axes.
//   3. If the index contains `None`, unsqueeze_ad_func inserts the new axes.
//      Each axis is first shifted left by the number of decreased axes that
//      precede it, and this step returns early.
//   4. A list index is applied with index_select, using the strided variant
//      for a single index when FLAGS_use_stride_kernel is set.
// Autograd ops run with the GIL released.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags;
  std::vector<int64_t> list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  // Without slicing or list selection the result is the tensor itself.
  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // strided_slice is only needed when some step differs from 1.
    std::string op_type = "slice";
    for (auto stride : slice_strides) {
      if (stride != 1) {
        op_type = "strided_slice";
        break;
      }
    }
    std::vector<int64_t> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
    std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                         infer_flags.end());
    std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                           decrease_axis.end());

    if (op_type == "slice") {
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    } else if (op_type == "strided_slice") {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
      // strided_slice has no decrease_axis attribute; drop the axes here.
      if (!decrease_axis_tmp.empty()) {
        out = squeeze_ad_func(out, decrease_axis_tmp);
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Slice is only support slice and strided_slice, but we got %s which "
          "is impossible, please check your code first or contact us by "
          "issue. ",
          op_type));
    }
  }

  bool set_to_1d = FLAGS_set_to_1d;

  if (set_to_1d) {
    // NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    // with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
    // otherwise the output shape will be not correct.
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      VLOG(1)
          << "Warning: In Tensor '__getitem__', if the number of scalar "
             "elements "
             "in the index is equal to the rank of the Tensor, the output "
             "should "
             "be 0-D. In order to be consistent with the behavior of previous "
             "versions, it will be processed to 1-D. But it is not correct and "
             "will be "
             "removed in release 2.6. "
             "If 1-D is still wanted, please modify the index element from "
             "scalar to slice "
             "(e.g. 'x[i]' => 'x[i:i+1]'). ";
      if (!none_axes.empty()) {
        none_axes.pop_back();
      }
    }
  }
  if (!none_axes.empty()) {
    paddle::Tensor new_out;
    {
      eager_gil_scoped_release guard;
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }
      new_out = unsqueeze_ad_func(out, none_axes);
    }
    return ToPyObject(new_out);
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    if (FLAGS_use_stride_kernel && list_select_idxs.size() == 1) {
      out = index_select_strided_ad_func(self->tensor, list_select_idxs[0], 0);
    } else {
      // Materialize the selected positions as an index tensor on the
      // controller's expected place.
      auto select_index =
          paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
      auto idx_tensor = std::make_shared<phi::DenseTensor>();
      select_index.set_impl(idx_tensor);
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
          egr::Controller::Instance().GetExpectedPlace());
      paddle::framework::TensorFromVector(
          list_select_idxs, *dev_ctx, idx_tensor.get());
      out = index_select_ad_func(self->tensor, select_index, 0);
    }
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1338 1339
// Implements `Tensor._getitem_from_offset(*indices)`: reads one element and
// returns it as a 0-d NumPy array (nd = 0) of the matching NumPy dtype.
//
// Accepted forms of `args`:
//   ()              - the tensor must have exactly one element;
//   (offset,)       - `offset` is used directly as the element offset and must
//                     be < numel;
//   (i0, ..., iN-1) - one index per dimension, each < its dim. The indices
//                     are combined with tensor.strides() into an offset.
// NOTE(review): the single-argument form does not apply strides, unlike the
// multi-index form - confirm this is intended for non-contiguous tensors.
// For a SelectedRows tensor the element is read from its value tensor.
// Raises InvalidArgument on bad input and Unimplemented for unsupported dtypes.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  phi::DenseTensor* ptr = nullptr;
  if (self->tensor.is_selected_rows()) {
    auto* selected_rows =
        static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  } else {
    ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  }
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> stride = phi::vectorize<size_t>(tensor.strides());

  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  if (PyTuple_Size(args) == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (PyTuple_Size(args) == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(PyTuple_Size(args),
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * stride[i];
    }
  }
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// For the matching dtype: read the element, then wrap a copy of it in a new
// 0-d NumPy array. PyArray_NewFromDescr_ returns NULL (with a Python error
// set) on failure, so bail out before writing into the array.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];    /* NOLINT */  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; /* NOLINT */  \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        0,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    if (array == nullptr) {                                                  \
      return nullptr;                                                        \
    }                                                                        \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack returns a NEW reference (and increments the
  // refcount of each packed item), so the temporary 1-tuple built for a
  // non-tuple index must be released by us; the scope guard below does that.
  // A tuple index is borrowed from `args` and must not be released.
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed: every item must be an integer, a
  // slice, Ellipsis or None.
  const Py_ssize_t size = PyTuple_GET_SIZE(index_ptr);
  for (Py_ssize_t dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags;
    std::vector<int64_t> list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::EagerUtils::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      // Convert the numpy array to self's dtype, then copy it into a fresh
      // DenseTensor placed on self's device.
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      if (self->tensor.dtype() == phi::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
        if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<float>>(
              value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
        if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<double>>(
              value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, complex64, complex128, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type; the scalar is passed to
      // set_value_ through the "values"/"shape" attributes.
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp) ||
          PyComplex_Check(value_obj)) {
        if (self->tensor.dtype() == phi::DataType::FLOAT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() == phi::DataType::INT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() == phi::DataType::INT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() == phi::DataType::BOOL) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<float>>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<double>>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, complex64, complex128, int32, int64 or "
              "float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type *name*; Py_TYPE() alone is a pointer and
        // would be formatted as an address.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float, complex or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    // Fallback: round-trip through numpy and let numpy do the indexing.
    auto self_numpy = TensorToPyArray(*self_tensor, true);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1696 1697
// Registers a Python gradient hook (args[0]) on this Tensor and returns the
// hook id. For leaf Tensors, the hook is attached to the GradNodeAccumulation
// (created on demand when the Tensor does not stop gradient). For non-leaf
// Tensors, it is attached to the Tensor's grad node.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id;
  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected nullptr grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // No accumulation node is created above when the leaf stops gradient, so
    // the cast can yield nullptr; raise instead of dereferencing it.
    PADDLE_ENFORCE(
        accumulation_grad_node.get() != nullptr,
        platform::errors::InvalidArgument(
            "Cannot register a gradient hook on leaf Tensor (%s) because it "
            "has no GradNodeAccumulation, please make sure the Tensor does "
            "not stop gradient.",
            self->tensor.name()));
    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1746 1747
// Removes the gradient hook whose id is args[0] from this Tensor's grad node.
// Returns a Python bool telling whether a hook was removed.
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  int64_t hook_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);

  // Hooks live on the grad node, so a Tensor without one has nothing to
  // remove; report that instead of dereferencing a null grad_node.
  if (!grad_node) {
    VLOG(6) << "No grad node for tensor: " << self->tensor.name()
            << ", hook " << hook_id << " is not removed.";
    return ToPyObject(false);
  }
  return ToPyObject(grad_node->RemoveGradientHook(hook_id));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771
// In-place assignment: makes this Tensor object share the tensor held by
// the Python object passed as args[0] (delegates to ShareTensor).
static PyObject* tensor_inplace_assign(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "inplace assign for tensor:" << self->tensor.name();
  PyObject* source = PyTuple_GET_ITEM(args, 0);
  auto* destination = reinterpret_cast<PyObject*>(self);
  ShareTensor(destination, source);
  RETURN_PY_NONE;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1772
PyDoc_STRVAR(tensor_method__register_reduce_hook__doc__,  // NOLINT
             R"DOC(_register_backward_hook($self, hook, /)
--

Registers a backward hook for current Tensor.

This hook will be called every time the gradient of current Tensor has been fully calculated.

There are two differences from `_register_grad_hook`:
1. This backward hook will be executed after the gradient accumulation has completed across batches,
  while the hook registered by `_register_grad_hook` will be executed after the gradient accumulation
  has completed in the current batch.
2. This backward hook function should have the following signature:

    hook() -> None

  It requires no input and no return value.

Args:
    hook(function): A backward hook to be registered for Tensor.gradient

Returns:
    None
)DOC");
1796 1797
// Registers a Python backward (reduce) hook, args[0], on a leaf Tensor's
// GradNodeAccumulation. Requires the Tensor to be a leaf that does not stop
// gradient.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);
  PADDLE_ENFORCE_EQ(egr::EagerUtils::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  PADDLE_ENFORCE(
      grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected nullptr grad_node, "
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));
  PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

  auto accumulation_grad_node =
      std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
  // Guard the downcast as well: a non-accumulation node would otherwise be
  // dereferenced as nullptr below.
  PADDLE_ENFORCE(
      accumulation_grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));
  accumulation_grad_node->RegisterReduceHook(
      std::make_shared<PyVoidHook>(hook_func));

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1831 1832
// Resets the gradient holder of this Tensor to an empty impl of the requested
// variable type: DenseTensor for LOD_TENSOR, SelectedRows for SELECTED_ROWS.
// Any other type leaves the gradient untouched.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  const auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto* grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();
  switch (requested_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      break;
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1848 1849
// Resets the wrapped paddle::Tensor via Tensor::reset(). Compare with
// _clear_dataptr, which only detaches the impl.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1858 1859 1860 1861 1862 1863 1864 1865 1866
// Detaches the underlying impl (sets it to nullptr) while keeping the
// paddle::Tensor object itself.
static PyObject* tensor__clear_dataptr(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.set_impl(nullptr);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1867 1868
// Makes this Tensor's gradient share the impl of the Tensor passed as
// args[0]. When this Tensor is already initialized, the source must match
// its dtype and impl type.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // The comparisons below read src.dtype() and src.impl()->type_info(),
    // i.e. they dereference src's impl; validate src first so a missing impl
    // raises InvalidArgument instead of being dereferenced.
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1900

1901 1902 1903
// Returns a Tensor that shares storage, name and autograd meta with this
// DenseTensor but has its meta's use_gpudnn flag set to args[0]. If the flag
// already has that value, this Tensor is returned as-is.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  const bool use_gpudnn =
      pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  phi::DenseTensorMeta* src_meta =
      phi::DenseTensorUtils::GetMutableMeta(src_dense);

  // Flag unchanged: nothing to rebuild.
  if (src_meta->use_gpudnn == use_gpudnn) {
    return ToPyObject(self->tensor);
  }

  // Same data, same meta except for use_gpudnn.
  phi::DenseTensorMeta new_meta = *src_meta;
  new_meta.use_gpudnn = use_gpudnn;
  phi::DenseTensor shared_dense;
  shared_dense.ShareDataWith(*src_dense);
  shared_dense.set_meta(new_meta);

  paddle::Tensor result(std::make_shared<phi::DenseTensor>(shared_dense),
                        self->tensor.name());
  result.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << result.name()
          << " set use_gpudnn = " << use_gpudnn;

  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1938 1939
// Replaces this Tensor's impl with a VariableCompatTensor that holds a copy
// of the Vocab converted from args[0].
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  using Vocab = paddle::framework::Vocab;
  const auto converted = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  Vocab* slot = holder->GetMutable<Vocab>();
  *slot = converted;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Replaces this Tensor's impl with a VariableCompatTensor that holds the list
// of strings converted from args[0].
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  using Strings = paddle::framework::Strings;
  const auto converted =
      CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  Strings* slot = holder->GetMutable<Strings>();
  *slot = converted;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the Vocab stored in this Tensor's VariableCompatTensor impl,
// converted to a Python object. Fails for any other impl type.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      egr::IsVariableCompatTensor(self->tensor),
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  const auto* compat =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  const auto& vocab = compat->Get<paddle::framework::Vocab>();
  return ToPyObject(vocab);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005
// Docstring for Tensor.nnz(). The leading "nnz($self, /)\n--\n\n" header
// follows CPython's text-signature convention for C-implemented methods.
PyDoc_STRVAR(tensor_method_nnz__doc__,
             R"DOC(nnz($self, /)
--

Note:
    **This API is only available for SparseCooTensor or SparseCsrTensor.**

Returns the total number of non zero elements in input SparseCooTensor/SparseCsrTensor.

Returns:
    int

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.nnz()
        # 3

)DOC");

2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026
// Implements Tensor.nnz(): returns the number of stored (non-zero) elements
// of a SparseCooTensor or SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (!self->tensor.is_sparse_coo_tensor()) {
    // Only CSR remains after the check above.
    auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    return ToPyObject(csr->nnz());
  }
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  return ToPyObject(coo->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054
// Docstring for Tensor.indices() (SparseCooTensor only).
PyDoc_STRVAR(tensor_method_indices__doc__,
             R"DOC(indices($self, /)
--

Note:
    **This API is only available for SparseCooTensor.**

Returns the indices of non zero elements in input SparseCooTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.indices()
        # Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [[0, 1, 2],
        #         [1, 2, 0]])

)DOC");

2055 2056 2057 2058 2059 2060 2061 2062 2063
// Implements Tensor.indices(): wraps a copy of the SparseCooTensor's
// non_zero_indices() DenseTensor in a new paddle::Tensor.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  const auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_impl =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::Tensor indices(indices_impl);
  return ToPyObject(indices);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096
// Docstring for Tensor.values() (SparseCooTensor or SparseCsrTensor).
PyDoc_STRVAR(tensor_method_values__doc__,
             R"DOC(values($self, /)
--

Note:
    **This API is only available for SparseCooTensor or SparseCsrTensor.**

Returns the values of non zero elements in input SparseCooTensor or SparseCsrTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.values()
        # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
        #        [1., 2., 3.])

)DOC");

2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108
// Implements Tensor.values(): wraps a copy of the sparse tensor's
// non_zero_elements() DenseTensor (COO or CSR) in a new paddle::Tensor.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  std::shared_ptr<phi::DenseTensor> values_impl;
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    values_impl = std::make_shared<phi::DenseTensor>(coo->non_zero_elements());
  } else {
    auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    values_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_elements());
  }
  paddle::Tensor values(values_impl);
  return ToPyObject(values);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149
// Docstring for Tensor.crows() (SparseCsrTensor only).
PyDoc_STRVAR(tensor_method_crows__doc__,
             R"DOC(crows($self, /)
--

Note:
    **This API is only available for SparseCsrTensor.**

Returns the compressed row index of non zero elements in input SparseCsrTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.crows()
        # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [0, 2, 3, 5])

)DOC");

2150 2151 2152 2153 2154 2155 2156 2157 2158
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto sparse_csr_tensor =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
2159
  paddle::Tensor tensor(
2160 2161 2162 2163 2164
      std::make_shared<phi::DenseTensor>(sparse_csr_tensor->non_zero_crows()));
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192
PyDoc_STRVAR(tensor_method_cols__doc__,
             R"DOC(cols($self, /)
--

Note:
    **This API is only available for SparseCsrTensor.**

Returns the column index of non zero elements in input SparseCsrTensor.

Returns:
    DenseTesnor

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.cols()
        # Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [1, 3, 2, 0, 1])

)DOC");

2193 2194 2195 2196 2197 2198 2199 2200 2201
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto sparse_csr_tensor =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
2202
  paddle::Tensor tensor(
2203 2204 2205 2206 2207
      std::make_shared<phi::DenseTensor>(sparse_csr_tensor->non_zero_cols()));
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224
PyDoc_STRVAR(tensor_method_is_dense__doc__, R"DOC(is_dense($self, /)
--

Whether the Tensor is a Dense Tensor.

Returns:
    Whether the Tensor is a Dense Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1.0], stop_gradient=False)
        print(x.is_dense())
)DOC");

2225 2226
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
2227 2228 2229 2230 2231 2232 2233 2234 2235
                                        PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
  return ToPyObject(self->tensor.is_dense_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2236 2237 2238 2239 2240 2241 2242 2243 2244 2245 2246 2247 2248 2249 2250 2251 2252
PyDoc_STRVAR(tensor_method_is_dist__doc__, R"DOC(is_dist($self, /)
--

Whether the Tensor is a Distributed Tensor.

Returns:
    Whether the Tensor is a Distributed Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1.0], stop_gradient=False)
        print(x.is_dist()) # False
)DOC");

L
LiYuRio 已提交
2253 2254 2255 2256 2257 2258 2259 2260 2261 2262 2263
static PyObject* tensor_method_is_dist(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
  return ToPyObject(self->tensor.is_dist_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278 2279 2280 2281 2282 2283 2284 2285 2286 2287
PyDoc_STRVAR(tensor_is_sparse__doc__,
             R"DOC(is_sparse($self, /)
--

Returns whether the input Tensor is SparseCooTensor or SparseCsrTensor.

When input is SparseCooTensor/SparseCsrTensor, will return True. When input is DenseTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.is_sparse()
        # True

)DOC");
2288 2289
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
2290 2291
                                         PyObject* kwargs) {
  EAGER_TRY
2292 2293 2294
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2295 2296 2297 2298 2299
  return ToPyObject(self->tensor.is_sparse_coo_tensor() ||
                    self->tensor.is_sparse_csr_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315 2316 2317 2318 2319 2320 2321 2322 2323 2324
PyDoc_STRVAR(tensor_is_sparse_coo__doc__,
             R"DOC(is_sparse_coo($self, /)
--

Returns whether the input Tensor is SparseCooTensor.

When input is SparseCooTensor, will return True. When input is DenseTensor/SparseCsrTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.is_sparse_coo()
        # True

)DOC");

2325 2326
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
2327 2328
                                             PyObject* kwargs) {
  EAGER_TRY
2329 2330 2331
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2332 2333 2334 2335
  return ToPyObject(self->tensor.is_sparse_coo_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352 2353 2354 2355 2356 2357 2358 2359 2360 2361
PyDoc_STRVAR(tensor_is_sparse_csr__doc__,
             R"DOC(is_sparse_csr($self, /)
--

Returns whether the input Tensor is SparseCsrTensor.

When input is SparseCsrTensor, will return True. When input is DenseTensor/SparseCooTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.is_sparse_csr()
        # True

)DOC");

2362 2363
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
2364 2365
                                             PyObject* kwargs) {
  EAGER_TRY
2366 2367 2368
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2369 2370 2371 2372
  return ToPyObject(self->tensor.is_sparse_csr_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394 2395 2396 2397 2398 2399 2400 2401 2402 2403
PyDoc_STRVAR(tensor_to_sparse_csr__doc__,
             R"DOC(to_sparse_csr($self, /)
--

Note:
    **This API is only available for DenseTensor or SparseCooTensor.**

Convert input Tensor to SparseCsrTensor.

When input is SparseCooTensor, will convert `COO` to `CSR` . When input is DenseTensor, will convert `Dense` to `CSR` .

Returns:
    SparseCsrTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.to_sparse_csr()
        # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
        #        crows=[0, 1, 2, 3],
        #        cols=[1, 2, 0],
        #        values=[1., 2., 3.])

)DOC");

2404 2405
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
2406 2407 2408 2409 2410 2411 2412 2413 2414 2415 2416 2417 2418
                                             PyObject* kwargs) {
  EAGER_TRY
  auto csr_tensor = self->tensor.to_sparse_csr();
  egr::EagerUtils::autograd_meta(&csr_tensor)
      ->SetStopGradient(
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient());
  egr::EagerUtils::autograd_meta(&csr_tensor)
      ->SetPersistable(
          egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  return ToPyObject(csr_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441 2442 2443 2444 2445 2446 2447 2448 2449 2450
PyDoc_STRVAR(tensor_is_same_shape__doc__,
             R"DOC(is_same_shape($self, y, /)
--

Return the results of shape comparison between two Tensors, check whether x.shape equal to y.shape.
Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are supported.

Args:
    x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
    y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.

Returns:
    bool: True for same shape and False for different shape.

Examples:

    .. code-block:: python

        import paddle

        x = paddle.rand([2, 3, 8])
        y = paddle.rand([2, 3, 8])
        y = y.to_sparse_csr()
        z = paddle.rand([2, 5])

        x.is_same_shape(y)
        # True
        x.is_same_shape(z)
        # False

)DOC");

2451 2452 2453 2454 2455 2456 2457 2458 2459
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  return ToPyObject(self->tensor.shape() == other.shape());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2460 2461
// Implements `Tensor._inplace_version()`: reports the tensor's current
// inplace version counter.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2470 2471
PyDoc_STRVAR(tensor_method_element_size__doc__,  // NOLINT
             R"DOC(element_size($self, /)
W
wanghuancoder 已提交
2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490 2491 2492 2493 2494 2495 2496 2497 2498 2499
--

Returns the size in bytes of an element in the Tensor.

Returns:
    int, The size in bytes of an element in the Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor(1, dtype='bool')
        x.element_size() # 1

        x = paddle.to_tensor(1, dtype='float16')
        x.element_size() # 2

        x = paddle.to_tensor(1, dtype='float32')
        x.element_size() # 4

        x = paddle.to_tensor(1, dtype='float64')
        x.element_size() # 8

        x = paddle.to_tensor(1, dtype='complex128')
        x.element_size() # 16
)DOC");

2500 2501
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
2502 2503
                                            PyObject* kwargs) {
  EAGER_TRY
2504
  uint32_t element_size = phi::SizeOf(self->tensor.dtype());
2505 2506 2507 2508 2509

  return ToPyObject(element_size);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2510
PyDoc_STRVAR(tensor_method__bump_inplace_version__doc__,  // NOLINT
W
wanghuancoder 已提交
2511 2512 2513
             R"DOC(_bump_inplace_version($self, /)
--

2514
Note:
W
wanghuancoder 已提交
2515 2516
    **This API is ONLY available in Dygraph mode.**
    **This is a very low level API. Users should not use it directly. **
2517

W
wanghuancoder 已提交
2518 2519
  Bump the version whenever the Tensor is modified through an inplace operation.
)DOC");
2520 2521 2522 2523 2524
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  self->tensor.bump_inplace_version();
2525
  RETURN_PY_NONE
2526 2527 2528
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2529 2530 2531 2532
// Implements `Tensor.is_selected_rows()`. An undefined tensor reports False.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  const bool selected_rows =
      self->tensor.defined() ? self->tensor.is_selected_rows() : false;
  return ToPyObject(selected_rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2540 2541
// Implements `Tensor.rows()`: returns the row indices stored in a
// SelectedRows tensor. Raises a Fatal error for any other kind of tensor.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  const auto& impl = self->tensor.impl();
  auto rows_holder = std::dynamic_pointer_cast<phi::SelectedRows>(impl);
  return ToPyObject(rows_holder->rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2553 2554 2555 2556 2557 2558 2559 2560 2561 2562
// Implements `Tensor._reset_grad_inplace_version(set_to_zero=True)`.
// When exactly one positional argument is given it is read as the
// `set_to_zero` flag. The reset is applied only to a gradient that exists,
// is defined, is a DenseTensor and is initialized. Returns None.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  const bool set_to_zero =
      (PyTuple_Size(args) == static_cast<Py_ssize_t>(1))
          ? CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0)
          : true;

  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  const bool resettable = grad != nullptr && grad->defined() &&
                          grad->is_dense_tensor() && grad->initialized();
  if (resettable) {
    grad->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2573 2574
// Implements `Tensor._share_memory()` (non-Windows only).
//
// Copies the data of a CPU DenseTensor into a freshly allocated memory-map
// writer allocation, records its ipc name in MemoryMapFdSet, and resets the
// tensor's holder to that allocation. On Windows it throws PermissionDenied.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get DenseTensor. The cast yields null when the impl is not a
  // DenseTensor (e.g. a sparse or SelectedRows CPU tensor); reject that
  // instead of dereferencing a null pointer below.
  auto* t =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl()).get();
  PADDLE_ENFORCE_NOT_NULL(
      t,
      platform::errors::InvalidArgument(
          "Sharing memory only support DenseTensor currently"));
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  // Unreachable; keeps every path of the function returning a value.
  RETURN_PY_NONE
#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2612 2613
// Implements `Tensor._offset()`: returns DenseTensor::offset() of the
// underlying dense tensor. Raises InvalidArgument when the tensor is not
// backed by a DenseTensor or has not been initialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields null when the impl is absent or is not a DenseTensor;
  // reject that explicitly instead of dereferencing a null pointer below.
  PADDLE_ENFORCE_NOT_NULL(
      t,
      platform::errors::InvalidArgument(
          "Tensor %s is not a DenseTensor, _offset is only supported for "
          "DenseTensor.",
          self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2627 2628
// Implements `Tensor._grad_name()`: returns the name of the tensor's
// gradient. Raises InvalidArgument if the gradient slot is null.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  const paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(
      grad != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));
  const auto& grad_name = grad->name();
  return ToPyObject(grad_name);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2642 2643
// Implements `Tensor._grad_value()`: returns the gradient's DenseTensor, or
// None when the gradient is not defined. Raises InvalidArgument if the
// gradient slot is null, and Fatal if the gradient is not a DenseTensor.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(
      grad != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));

  if (!grad->defined()) {
    RETURN_PY_NONE
  }
  if (!grad->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
    RETURN_PY_NONE
  }
  auto* dense_grad = static_cast<phi::DenseTensor*>(grad->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2668 2669
// Implements `Tensor._unset_fake_empty()`: for a leaf tensor whose grad node
// is a GradNodeAccumulation, calls SetFakeEmpty(false) on that node. For any
// other tensor it does nothing. Raises InvalidArgument if the gradient slot
// is null. Returns None.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(
      grad != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));

  bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    // NOTE(review): a leaf tensor is not guaranteed to carry a grad node, so
    // use a checked cast and skip the call when there is no accumulation
    // node instead of dereferencing a possibly-null pointer.
    auto accumulation_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    if (accumulation_node) {
      accumulation_node->SetFakeEmpty(false);
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2690 2691 2692 2693 2694 2695 2696 2697 2698 2699 2700 2701 2702 2703 2704 2705 2706 2707
PyDoc_STRVAR(tensor_data_ptr__doc__,
             R"DOC(data_ptr($self, /)
--

Returns the address of the first element of current Tensor.

Returns:
    int, The address of the first element of current Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        print(x.data_ptr())
)DOC");

2708 2709 2710 2711 2712
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.initialized() && self->tensor.is_dense_tensor()) {
S
sneaxiy 已提交
2713 2714 2715 2716
    return ToPyObject(
        (int64_t)std::dynamic_pointer_cast<phi::DenseTensor>(  // NOLINT
            self->tensor.impl())
            ->data());
2717 2718 2719 2720 2721
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2722 2723 2724 2725 2726 2727 2728 2729 2730 2731 2732 2733 2734 2735 2736
// Implements `Tensor._grad_ivar()`: returns the gradient Tensor when the
// tensor has autograd meta whose Grad() is initialized, otherwise None.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // `meta` is nullable: test it before touching Grad(), including in the
  // debug log, which would otherwise dereference a null pointer.
  const bool grad_initialized = meta && meta->Grad().initialized();
  VLOG(6) << meta << " initialized: " << grad_initialized;
  if (grad_initialized) {
    return ToPyObject(meta->Grad());
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2737 2738 2739 2740 2741 2742 2743 2744 2745 2746 2747 2748 2749 2750 2751 2752 2753 2754 2755
PyDoc_STRVAR(tensor_get_strides__doc__,
             R"DOC(get_strides($self, /)
--

Returns the strides of current Tensor.

Returns:
    List, the strides of current Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        print(y.get_strides())
)DOC");

W
wanghuancoder 已提交
2756 2757 2758 2759 2760 2761 2762 2763 2764 2765 2766 2767 2768 2769 2770 2771 2772 2773
static PyObject* tensor_method_strides(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  std::vector<int64_t> value;
  if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
    return ToPyObject(value);
  }
  auto stride = self->tensor.strides();
  size_t rank = static_cast<size_t>(stride.size());
  value.resize(rank);
  for (size_t i = 0; i < rank; i++) {
    value[i] = stride[i];
  }
  return ToPyObject(value);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2774 2775 2776 2777 2778 2779 2780 2781 2782 2783 2784 2785 2786 2787 2788 2789 2790 2791 2792 2793 2794
PyDoc_STRVAR(tensor_contiguous__doc__,
             R"DOC(contiguous($self, /)
--

Returns a contiguous in memory tensor containing the same data as current Tensor.
If self tensor is already contiguous, this function returns the current Tensor.

Returns:
    Tensor, The contiguous Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        y = y.contiguous()
        print(y)
)DOC");

W
wanghuancoder 已提交
2795 2796 2797 2798 2799 2800 2801 2802 2803 2804 2805 2806 2807 2808 2809 2810 2811 2812 2813 2814 2815 2816 2817 2818
static PyObject* tensor_contiguous(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
    if (dense_tensor->meta().is_contiguous()) {
      Py_INCREF(self);
      return reinterpret_cast<PyObject*>(self);
    } else {
      eager_gil_scoped_release guard;
      return ToPyObject(
          paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(
              paddle::experimental::Trans2Contiguous(*(dense_tensor.get()))))));
    }

  } else {
    Py_INCREF(self);
    return reinterpret_cast<PyObject*>(self);
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2819 2820 2821 2822 2823 2824 2825 2826 2827 2828 2829 2830 2831 2832 2833 2834 2835 2836
PyDoc_STRVAR(tensor_is_contiguous__doc__,
             R"DOC(is_contiguous($self, /)
--

Whether the Tensor is contiguous.

Returns:
    Bool, Whether the Tensor is contiguous.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        print(y.is_contiguous())
)DOC");
W
wanghuancoder 已提交
2837 2838 2839 2840 2841 2842 2843 2844 2845 2846 2847 2848 2849 2850
static PyObject* tensor_is_contiguous(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
    return ToPyObject(dense_tensor->meta().is_contiguous());
  } else {
    return ToPyObject(true);
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2851
#if defined(PADDLE_WITH_CUDA)
2852 2853
// Implements `Tensor._uva(device_id)` (CUDA builds only): after checking the
// tensor is a CPU DenseTensor, hands it and the given device id to
// tensor_uva(). Returns None.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));
  PyObject* py_device_id = PyTuple_GET_ITEM(args, 0);
  int device_id = pybind::CastPyArg2AttrLong(py_device_id, 0);
  tensor_uva(static_cast<phi::DenseTensor*>(self->tensor.impl().get()),
             device_id);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
2876 2877 2878 2879 2880 2881 2882 2883 2884 2885 2886 2887
// Implements `Tensor._is_string_tensor_hold_allocation()`: True only when
// the impl is a StringTensor that reports itself initialized.
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto string_tensor =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  const bool holds_allocation =
      string_tensor != nullptr && string_tensor->initialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
2888

2889
PyMethodDef variable_methods[] = {  // NOLINT
2890
    {"numpy",
2891
     (PyCFunction)(void (*)())tensor_method_numpy,
2892
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2893
     tensor_method_numpy__doc__},
2894
    {"_is_initialized",
2895
     (PyCFunction)(void (*)())tensor_method__is_initialized,
2896
     METH_VARARGS | METH_KEYWORDS,
2897
     nullptr},
W
wanghuancoder 已提交
2898
    {"_is_dense_tensor_hold_allocation",
2899 2900
     (PyCFunction)(void (*)(
         void))tensor_method__is_dense_tensor_hold_allocation,
2901
     METH_VARARGS | METH_KEYWORDS,
2902
     nullptr},
2903
    {"_copy_to",
2904
     (PyCFunction)(void (*)())tensor_method__copy_to,
2905
     METH_VARARGS | METH_KEYWORDS,
2906
     nullptr},
2907
    {"copy_",
2908
     (PyCFunction)(void (*)())tensor_method_copy_,
2909
     METH_VARARGS | METH_KEYWORDS,
2910
     nullptr},
2911
    {"clone",
2912
     (PyCFunction)(void (*)())tensor_method_clone,
2913
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2914
     tensor_method_clone__doc__},
2915
    {"reconstruct_from_",
2916
     (PyCFunction)(void (*)())tensor_method_reconstruct_from_,
2917
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2918
     tensor_reconstruct_from___doc__},
2919
    {"retain_grads",
2920
     (PyCFunction)(void (*)())tensor_retain_grads,
2921
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2922
     tensor_method_retain_grads__doc__},
2923
    {"clear_gradient",
2924
     (PyCFunction)(void (*)())tensor_clear_gradient,
2925
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2926
     tensor_clear_gradient__doc__},
2927
    {"is_dense",
2928
     (PyCFunction)(void (*)())tensor_method_is_dense,
2929
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2930
     tensor_method_is_dense__doc__},
L
LiYuRio 已提交
2931
    {"is_dist",
2932
     (PyCFunction)(void (*)())tensor_method_is_dist,
L
LiYuRio 已提交
2933
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2934
     tensor_method_is_dist__doc__},
2935
    {"_zero_grads",
2936
     (PyCFunction)(void (*)())tensor__zero_grads,
2937
     METH_VARARGS | METH_KEYWORDS,
2938
     nullptr},
2939
    {"_share_buffer_to",
2940
     (PyCFunction)(void (*)())tensor__share_buffer_to,
2941
     METH_VARARGS | METH_KEYWORDS,
2942
     nullptr},
2943
    {"_is_shared_buffer_with",
2944
     (PyCFunction)(void (*)())tensor__is_shared_buffer_with,
2945
     METH_VARARGS | METH_KEYWORDS,
2946
     nullptr},
2947
    {"_share_underline_tensor_to",
2948
     (PyCFunction)(void (*)())tensor__share_underline_tensor_to,
2949
     METH_VARARGS | METH_KEYWORDS,
2950
     nullptr},
2951
    {"_is_shared_underline_tensor_with",
2952
     (PyCFunction)(void (*)())tensor__is_shared_underline_tensor_with,
2953
     METH_VARARGS | METH_KEYWORDS,
2954
     nullptr},
2955
    {"detach",
2956
     (PyCFunction)(void (*)())tensor_method_detach,
2957
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2958
     tensor_method_detach__doc__},
W
wanghuancoder 已提交
2959 2960 2961
    {"detach_",
     (PyCFunction)(void (*)(void))tensor_method_detach_,
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2962
     tensor_method_detach___doc__},
2963
    {"get_tensor",
2964
     (PyCFunction)(void (*)())tensor_method_get_underline_tensor,
2965
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
2966
     tensor_method_get_tensor__doc__},
2967
    {"get_selected_rows",
2968
     (PyCFunction)(void (*)())tensor_method_get_underline_selected_rows,
2969
     METH_VARARGS | METH_KEYWORDS,
2970
     nullptr},
2971
    {"_get_tensor_from_selected_rows",
2972
     (PyCFunction)(void (*)())tensor_method__get_tensor_from_selected_rows,
2973
     METH_VARARGS | METH_KEYWORDS,
2974
     nullptr},
J
Jiabin Yang 已提交
2975
    {"_getitem_index_not_tensor",
2976
     (PyCFunction)(void (*)())tensor__getitem_index_not_tensor,
2977
     METH_VARARGS | METH_KEYWORDS,
2978
     nullptr},
W
wanghuancoder 已提交
2979
    {"_getitem_from_offset",
2980
     (PyCFunction)(void (*)())tensor__getitem_from_offset,
2981
     METH_VARARGS | METH_KEYWORDS,
2982
     nullptr},
W
wanghuancoder 已提交
2983
    {"__setitem_eager_tensor__",
2984
     (PyCFunction)(void (*)())tensor_method__setitem_eager_tensor,
2985
     METH_VARARGS | METH_KEYWORDS,
2986
     nullptr},
2987
    {"_register_grad_hook",
2988
     (PyCFunction)(void (*)())tensor_register_grad_hook,
2989
     METH_VARARGS | METH_KEYWORDS,
2990
     nullptr},
2991 2992 2993 2994
    {"_inplace_assign",  // NOTE(xiongkun03): only used in sot.
     (PyCFunction)(void (*)())tensor_inplace_assign,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
2995
    {"_remove_grad_hook",
2996
     (PyCFunction)(void (*)())tensor_remove_grad_hook,
2997
     METH_VARARGS | METH_KEYWORDS,
2998
     nullptr},
2999
    {"_register_backward_hook",
3000
     (PyCFunction)(void (*)())tensor_register_reduce_hook,
3001
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3002
     tensor_method__register_reduce_hook__doc__},
3003
    {"_set_grad_type",
3004
     (PyCFunction)(void (*)())tensor__set_grad_type,
3005
     METH_VARARGS | METH_KEYWORDS,
3006
     nullptr},
3007
    {"_clear",
3008
     (PyCFunction)(void (*)())tensor__clear,
3009
     METH_VARARGS | METH_KEYWORDS,
3010
     nullptr},
3011
    {"_clear_dataptr",
3012
     (PyCFunction)(void (*)())tensor__clear_dataptr,
3013
     METH_VARARGS | METH_KEYWORDS,
3014
     nullptr},
J
Jiabin Yang 已提交
3015
    {"_copy_gradient_from",
3016
     (PyCFunction)(void (*)())tensor__copy_gradient_from,
3017
     METH_VARARGS | METH_KEYWORDS,
3018
     nullptr},
3019
    {"_tensor_use_gpudnn",
3020
     (PyCFunction)(void (*)())tensor__use_gpudnn,
3021
     METH_VARARGS | METH_KEYWORDS,
3022
     nullptr},
3023 3024
    /** the methods to adapt old dygraph, will be removed in the future **/
    {"set_string_list",
3025
     (PyCFunction)(void (*)())tensor_method_set_string_list,
3026
     METH_VARARGS | METH_KEYWORDS,
3027
     nullptr},
3028
    {"set_vocab",
3029
     (PyCFunction)(void (*)())tensor_method_set_vocab,
3030
     METH_VARARGS | METH_KEYWORDS,
3031
     nullptr},
3032
    {"get_map_tensor",
3033
     (PyCFunction)(void (*)())tensor_method_get_map_tensor,
3034
     METH_VARARGS | METH_KEYWORDS,
3035
     nullptr},
3036
    /***the method of sparse tensor****/
3037
    {"nnz",
3038
     (PyCFunction)(void (*)())tensor_method_get_non_zero_nums,
3039
     METH_VARARGS | METH_KEYWORDS,
3040
     tensor_method_nnz__doc__},
3041
    {"indices",
3042
     (PyCFunction)(void (*)())tensor_method_get_non_zero_indices,
3043
     METH_VARARGS | METH_KEYWORDS,
3044
     tensor_method_indices__doc__},
3045
    {"values",
3046
     (PyCFunction)(void (*)())tensor_method_get_non_zero_elements,
3047
     METH_VARARGS | METH_KEYWORDS,
3048
     tensor_method_values__doc__},
3049
    {"crows",
3050
     (PyCFunction)(void (*)())tensor_method_get_non_zero_crows,
3051
     METH_VARARGS | METH_KEYWORDS,
3052
     tensor_method_crows__doc__},
3053
    {"cols",
3054
     (PyCFunction)(void (*)())tensor_method_get_non_zero_cols,
3055
     METH_VARARGS | METH_KEYWORDS,
3056
     tensor_method_cols__doc__},
3057
    {"is_sparse",
3058
     (PyCFunction)(void (*)())tensor_method_is_sparse,
3059
     METH_VARARGS | METH_KEYWORDS,
3060
     tensor_is_sparse__doc__},
3061
    {"is_sparse_coo",
3062
     (PyCFunction)(void (*)())tensor_method_is_sparse_coo,
3063
     METH_VARARGS | METH_KEYWORDS,
3064
     tensor_is_sparse_coo__doc__},
3065
    {"is_sparse_csr",
3066
     (PyCFunction)(void (*)())tensor_method_is_sparse_csr,
3067
     METH_VARARGS | METH_KEYWORDS,
3068
     tensor_is_sparse_csr__doc__},
3069
    {"is_same_shape",
3070
     (PyCFunction)(void (*)())tensor_method_is_same_shape,
3071
     METH_VARARGS | METH_KEYWORDS,
3072
     tensor_is_same_shape__doc__},
3073
    {"to_sparse_csr",
3074
     (PyCFunction)(void (*)())tensor_method_to_sparse_csr,
3075
     METH_VARARGS | METH_KEYWORDS,
3076 3077
     tensor_to_sparse_csr__doc__},
    /***the method of sparse tensor****/
3078
    {"element_size",
3079
     (PyCFunction)(void (*)())tensor_method_element_size,
3080
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3081
     tensor_method_element_size__doc__},
3082
    {"_inplace_version",
3083
     (PyCFunction)(void (*)())tensor__inplace_version,
3084
     METH_VARARGS | METH_KEYWORDS,
3085
     nullptr},
3086
    {"_bump_inplace_version",
3087
     (PyCFunction)(void (*)())tensor__bump_inplace_version,
3088
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3089
     tensor_method__bump_inplace_version__doc__},
3090
    {"is_selected_rows",
3091
     (PyCFunction)(void (*)())tensor_method_is_selected_rows,
3092
     METH_VARARGS | METH_KEYWORDS,
3093
     nullptr},
3094
    {"rows",
3095
     (PyCFunction)(void (*)())tensor_method_get_rows,
3096
     METH_VARARGS | METH_KEYWORDS,
3097
     nullptr},
3098
    {"_reset_grad_inplace_version",
3099
     (PyCFunction)(void (*)())tensor__reset_grad_inplace_version,
3100
     METH_VARARGS | METH_KEYWORDS,
3101
     nullptr},
3102
    {"_share_memory",
3103
     (PyCFunction)(void (*)())tensor_method__share_memory,
3104
     METH_VARARGS | METH_KEYWORDS,
3105
     nullptr},
3106
    {"_offset",
3107
     (PyCFunction)(void (*)())tensor__offset,
3108
     METH_VARARGS | METH_KEYWORDS,
3109
     nullptr},
3110
    {"_grad_name",
3111
     (PyCFunction)(void (*)())tensor__grad_name,
3112
     METH_VARARGS | METH_KEYWORDS,
3113
     nullptr},
3114
    {"_grad_value",
3115
     (PyCFunction)(void (*)())tensor__grad_value,
3116
     METH_VARARGS | METH_KEYWORDS,
3117
     nullptr},
3118
    {"_unset_fake_empty",
3119
     (PyCFunction)(void (*)())tensor__unset_fake_empty,
3120
     METH_VARARGS | METH_KEYWORDS,
3121
     nullptr},
3122
    {"data_ptr",
3123
     (PyCFunction)(void (*)())tensor_data_ptr,
3124
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3125
     tensor_data_ptr__doc__},
W
wanghuancoder 已提交
3126
    {"_grad_ivar",
3127
     (PyCFunction)(void (*)())tensor__grad_ivar,
W
wanghuancoder 已提交
3128
     METH_VARARGS | METH_KEYWORDS,
3129
     nullptr},
W
wanghuancoder 已提交
3130 3131 3132
    {"contiguous",
     (PyCFunction)(void (*)(void))tensor_contiguous,
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3133
     tensor_contiguous__doc__},
W
wanghuancoder 已提交
3134 3135 3136
    {"is_contiguous",
     (PyCFunction)(void (*)(void))tensor_is_contiguous,
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3137
     tensor_is_contiguous__doc__},
W
wanghuancoder 已提交
3138 3139 3140
    {"get_strides",
     (PyCFunction)(void (*)(void))tensor_method_strides,
     METH_VARARGS | METH_KEYWORDS,
W
wanghuancoder 已提交
3141
     tensor_get_strides__doc__},
3142
#if defined(PADDLE_WITH_CUDA)
3143
    {"_tensor_uva",
3144
     (PyCFunction)(void (*)())tensor_method__uva,
3145
     METH_VARARGS | METH_KEYWORDS,
3146
     nullptr},
3147
#endif
3148
    {nullptr, nullptr, 0, nullptr}};
3149

J
Jack Zhou 已提交
3150
// variable_methods for core.eager.StringTensor
3151
PyMethodDef string_tensor_variable_methods[] = {  // NOLINT
J
Jack Zhou 已提交
3152
    {"numpy",
3153
     (PyCFunction)(void (*)())tensor_method_numpy_for_string_tensor,
3154
     METH_VARARGS | METH_KEYWORDS,
3155
     nullptr},
J
Jack Zhou 已提交
3156
    {"_is_initialized",
3157
     (PyCFunction)(void (*)())tensor_method__is_initialized,
3158
     METH_VARARGS | METH_KEYWORDS,
3159
     nullptr},
J
Jack Zhou 已提交
3160
    {"_is_string_tensor_hold_allocation",
3161 3162
     (PyCFunction)(void (*)(
         void))tensor_method__is_string_tensor_hold_allocation,
3163
     METH_VARARGS | METH_KEYWORDS,
3164
     nullptr},
J
Jack Zhou 已提交
3165
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
3166
    {nullptr, nullptr, 0, nullptr}};
J
Jack Zhou 已提交
3167

3168 3169
}  // namespace pybind
}  // namespace paddle