// eager_method.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

#include <Python.h>
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/eager/accumulation/accumulation_node.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/string_array.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/slice_utils.h"
#include "paddle/fluid/pybind/uva_utils.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "pybind11/detail/internals.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/utils/pybind.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
#endif

// Runtime flags consumed by the bindings in this file.
// set_to_1d: read by tensor_method_numpy as the default for whether a 0-D
// tensor is returned as a 1-D array of shape [1].
PHI_DECLARE_bool(set_to_1d);
// use_stride_kernel: NOTE(review): not referenced in the portion of this file
// reviewed here; presumably read further down.
DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace pybind {

77 78
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
79
                                     const paddle::platform::Place& place,
80
                                     bool zero_copy);
81

82
extern PyTypeObject* p_tensor_type;
83

J
Jiabin Yang 已提交
84
// Converts a paddle.Tensor Python object used as a slice bound into a
// Py_ssize_t index, via GetSliceIndexFromTensor.
//
// Throws InvalidArgument when `obj` is not a paddle.Tensor, when the tensor is
// not initialized, or when its impl is not a phi::DenseTensor.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (!PyObject_TypeCheck(obj, p_tensor_type)) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }

  VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
  paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
  PADDLE_ENFORCE_EQ(
      tensor.initialized(),
      true,
      paddle::platform::errors::InvalidArgument(
          "We can only support initialized tensor in slice, however we got "
          "uninitialized tensor %s, please check your code.",
          tensor.name()));
  // The static_cast below is only valid for DenseTensor-backed tensors.
  // Reject other impls (e.g. DistTensor) instead of reinterpreting them.
  PADDLE_ENFORCE_EQ(
      tensor.is_dense_tensor(),
      true,
      paddle::platform::errors::InvalidArgument(
          "We can only support DenseTensor in slice, however tensor %s is "
          "not a DenseTensor.",
          tensor.name()));
  // Reuse the tensor converted above rather than converting `obj` again.
  return GetSliceIndexFromTensor(
      *static_cast<phi::DenseTensor*>(tensor.impl().get()));
}

// Docstring for Tensor.numpy(). The first line, "name(signature)", followed by
// a "--" line and a blank line is CPython's convention for exposing
// __text_signature__. Nothing may appear between the signature and "--".
PyDoc_STRVAR(tensor_method_numpy__doc__,  // NOLINT
             R"DOC(numpy($self, /)
--

Returns a numpy array shows the value of current Tensor.

Returns:
    ndarray, The numpy value of current Tensor, dtype is
    same as current Tensor.

Examples:
    .. code-block:: python

        import paddle

        data = paddle.uniform([30, 10, 32], dtype="float32", min=-1, max=1)
        linear = paddle.nn.Linear(32, 64)
        data = paddle.to_tensor(data)
        x = linear(data)
        print(x.numpy())
)DOC");

static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
128 129
                                     PyObject* kwargs) {
  EAGER_TRY
W
wanghuancoder 已提交
130 131
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
132 133
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
W
wanghuancoder 已提交
134 135 136 137 138
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
139 140 141 142 143
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
W
wanghuancoder 已提交
144 145 146 147 148
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
149 150
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
151
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
152 153
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
154
  size_t py_rank = tensor_dims.size();
155
  size_t numel = 1;
156
  if (py_rank == 0) {
157
    Py_ssize_t args_num = PyTuple_Size(args);
158 159
    // true by default
    bool set_to_1d = FLAGS_set_to_1d;
160 161 162 163 164 165 166
    if (args_num == (Py_ssize_t)1) {
      PyObject* obj = PyTuple_GET_ITEM(args, 0);
      if (obj == Py_False) {
        set_to_1d = false;
      }
    }
    if (set_to_1d) {
167
      // 0D Tensor hack process to 1D numpy, will remove in release 2.6
168 169 170 171 172
      VLOG(0)
          << "Warning:: 0D Tensor cannot be used as 'Tensor.numpy()[0]' . In "
             "order to avoid this problem, "
             "0D Tensor will be changed to 1D numpy currently, but it's not "
             "correct and will be "
173 174
             "removed in release 2.6. For Tensor contain only one element, "
             "Please "
175
             "modify "
176
             " 'Tensor.numpy()[0]' to 'float(Tensor)' as soon as "
177
             "possible, "
178
             "otherwise 'Tensor.numpy()[0]' will raise error in release 2.6.";
179 180 181 182
      py_rank = 1;
      py_dims[0] = 1;
      py_strides[0] = sizeof_dtype * numel;
    }
W
wanghuancoder 已提交
183 184 185 186 187 188 189 190
  } else if (self->tensor.is_dense_tensor()) {
    auto tensor_stride = self->tensor.strides();

    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * tensor_stride[i];
      numel *= py_dims[i];
    }
191 192 193 194 195 196
  } else {
    for (int i = tensor_dims.size() - 1; i >= 0; --i) {
      py_dims[i] = static_cast<size_t>(tensor_dims[i]);
      py_strides[i] = sizeof_dtype * numel;
      numel *= py_dims[i];
    }
197
  }
W
wanghuancoder 已提交
198 199

  if (!self->tensor.impl()->initialized()) {
W
wanghuancoder 已提交
200 201 202 203 204 205 206 207 208 209 210
    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        py_rank,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);

211
    if (tensor_dims.empty()) {
212 213 214
      py_dims[0] = 0;
      py_strides[0] = 0;
      PyObject* array = api.PyArray_NewFromDescr_(
215 216 217 218 219 220
          api.PyArray_Type_,
          api.PyArray_DescrFromType_(numpy_dtype),
          1,
          py_dims,
          py_strides,
          nullptr,
221 222 223 224 225
          pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
              pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
          nullptr);
      return array;
    }
W
wanghuancoder 已提交
226 227 228
    return array;
  }

W
wanghuancoder 已提交
229 230 231
  phi::DenseTensor cpu_tensor;
  platform::CPUPlace cpu_place;

232
  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
W
wanghuancoder 已提交
233
    eager_gil_scoped_release guard;
234
    platform::CPUPlace place;
235 236 237 238
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
239 240
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
241 242 243 244 245
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
246
      // deep copy
W
wanghuancoder 已提交
247 248 249 250 251
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272
#ifdef PADDLE_WITH_DISTRIBUTE
    } else if (self->tensor.is_dist_tensor()) {
      // TODO(chenweihang): deal with DistTensor as local DenseTensor now,
      // if the local DenseTensor is shard or partial, do gather or reduce?
      VLOG(6) << "Getting DistTensor's numpy value";
      auto* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
      auto& dense_tensor = dist_tensor->value();
      cpu_tensor.set_meta(dense_tensor.meta());
      // deep copy
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor.Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      // deep copy
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor.Holder()->ptr(),
                           dense_tensor.Holder()->size());
#endif
273 274 275 276
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
277 278 279 280 281
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
282
      // deep copy
W
wanghuancoder 已提交
283 284 285 286 287
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           place,
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
288 289
    }

290
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
291
  } else if (self->tensor.is_gpu()) {
W
wanghuancoder 已提交
292
    eager_gil_scoped_release guard;
293 294 295 296
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
297
    phi::DeviceContextPool::Instance().Get(self->tensor.place())->Wait();
298
#endif
299 300 301 302
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
303 304
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
305 306 307 308 309 310 311 312 313
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
#ifdef PADDLE_WITH_DISTRIBUTE
    } else if (self->tensor.is_dist_tensor()) {
      VLOG(6) << "Getting DistTensor's numpy value";
      auto* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
      auto& dense_tensor = dist_tensor->value();
      cpu_tensor.set_meta(dense_tensor.meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor.Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor.Holder()->ptr(),
                                      dense_tensor.Holder()->size(),
                                      kind);
#endif
330 331 332 333
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
334 335 336 337 338 339 340 341 342
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::platform::GpuMemcpySync(cpu_tensor.Holder()->ptr(),
                                      dense_tensor->Holder()->ptr(),
                                      dense_tensor->Holder()->size(),
                                      kind);
343
    }
344
#endif
C
Chen Weihang 已提交
345 346 347 348 349 350 351
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
352 353
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
354 355 356 357 358 359 360 361 362 363
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
364 365 366 367
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
W
wanghuancoder 已提交
368 369 370 371 372 373 374 375 376 377
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
      paddle::memory::Copy(place,
                           cpu_tensor.Holder()->ptr(),
                           dense_tensor->place(),
                           dense_tensor->Holder()->ptr(),
                           dense_tensor->Holder()->size());
C
Chen Weihang 已提交
378 379
    }
#endif
380 381
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
W
wanghuancoder 已提交
382
    eager_gil_scoped_release guard;
383
    phi::DeviceContextPool::Instance().Get(self->tensor.place())->Wait();
384 385 386 387
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
388 389
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
W
wanghuancoder 已提交
390 391 392 393 394
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
395
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
396 397 398
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
399 400 401 402
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
C
co63oc 已提交
403
      // TODO(qili93): temporary for ascend npu performance to be removed along
404
      // with npu_identity op
405
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
406 407 408 409 410
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
W
wanghuancoder 已提交
411 412 413 414 415
      cpu_tensor.set_meta(dense_tensor->meta());
      auto tmp_allocation_ptr =
          memory::Alloc(cpu_place, dense_tensor->Holder()->size());
      cpu_tensor.ResetHolder(std::shared_ptr<phi::Allocation>(
          tmp_allocation_ptr.release(), tmp_allocation_ptr.get_deleter()));
416
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
W
wanghuancoder 已提交
417 418 419
          ->MemoryCopyD2H(cpu_tensor.Holder()->ptr(),
                          dense_tensor->Holder()->ptr(),
                          dense_tensor->Holder()->size());
420 421
    }
#endif
422 423 424
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
425
    RETURN_PY_NONE
426 427
  }

W
wanghuancoder 已提交
428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447
  void* array_buffer = cpu_tensor.Holder()->ptr();
  size_t array_offset = cpu_tensor.offset();

  PyObject* base = ToPyObject(paddle::Tensor(
      std::make_shared<phi::DenseTensor>(std::move(cpu_tensor))));

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      py_rank,
      py_dims,
      py_strides,
      reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(array_buffer) +
                              array_offset),
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);

  api.PyArray_SetBaseObject_(array, base);

448 449 450 451
  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
452 453 454 455 456 457 458 459
// Implements StringTensor.numpy(). Returns a numpy unicode array of dtype
// "U<n>", where n is the length in code points of the longest string (at
// least 1).
//
// A StringTensor with no impl, or an uninitialized one, yields an empty
// unicode array of shape [0]. Only CPU StringTensors are supported; other
// places raise InvalidArgument.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];     // NOLINT
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];  // NOLINT
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();

    // Measure each string once and reuse the lengths for the copy below.
    // Starting the running maximum at 1 guarantees a non-zero item width and
    // also covers numel == 0. In that case std::max_element would return the
    // end iterator, and dereferencing it is undefined behavior.
    std::vector<size_t> unicode_lens(static_cast<size_t>(numel));
    size_t max_unicode_length = 1;
    for (int64_t i = 0; i < numel; ++i) {
      unicode_lens[i] =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      max_unicode_length = std::max(max_unicode_length, unicode_lens[i]);
    }
    VLOG(6) << "The max unicode length is " << max_unicode_length;

    // make_unique<T[]> value-initializes, so every item's unused tail is
    // already zero; no memset is needed.
    auto sp =
        std::make_unique<uint32_t[]>(max_unicode_length * numel);  // NOLINT
    auto py_array_data = sp.get();
    for (int64_t i = 0; i < numel; ++i) {
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  unicode_lens[i]);
    }
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding: returns self->tensor.initialized(), converted with
// ToPyObject.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_initialized = self->tensor.initialized();
  return ToPyObject(is_initialized);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding: true only when the tensor's impl is a phi::DenseTensor whose
// IsInitialized() returns true. Any other impl, including a null impl (the
// dynamic_pointer_cast then yields null), reports false.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const bool holds_allocation = dense != nullptr && dense->IsInitialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static void IncreaseTensorReferenceCountUntilCopyComplete(
546
    const paddle::Tensor& tensor, const platform::Place& place) {
547 548 549 550 551 552 553 554
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
C
co63oc 已提交
555
  // CUDAPinned Mem -> CUDA by cudaMemcpyAsync.
556 557 558 559 560 561 562
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

563 564
// Python binding: copies self->tensor to the place given as positional
// argument 0. Positional argument 1 is the `blocking` flag.
//
// The copy has stop_gradient set to true and inherits self's persistable flag.
// For a non-blocking copy, the source tensor is kept referenced until the copy
// completes. The GIL is released while copying.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  const auto dst_place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  const bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);

  paddle::Tensor copied;
  {
    eager_gil_scoped_release guard;
    copied = self->tensor.copy_to(dst_place, blocking);
    if (!blocking) {
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, dst_place);
    }
    auto* copied_meta = egr::EagerUtils::autograd_meta(&copied);
    copied_meta->SetStopGradient(true);
    const bool src_persistable =
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
    copied_meta->SetPersistable(src_persistable);
  }
  return ToPyObject(copied);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Docstring for reconstruct_from_. The first line must be a syntactically
// valid "name(signature)" so CPython can expose it as __text_signature__.
// The previous "($self, other/)" was missing the comma before "/".
// NOTE(review): the implementation assigns the paddle::Tensor handle
// (self->tensor = src_tensor), which presumably shares src's impl rather than
// deep-copying it. Confirm the "deep copy" wording below.
PyDoc_STRVAR(tensor_reconstruct_from___doc__,  // NOLINT
             R"DOC(reconstruct_from_($self, other, /)
--

Reconstruct the self with other Tensor. It is a deep copy of 'self = other'.

Returns:
    None.

Examples:
    .. code-block:: python

      import paddle

      t1 = paddle.to_tensor([1.0], stop_gradient=False)
      t2 = paddle.to_tensor([2.0], stop_gradient=True)

      t1.reconstruct_from_(t2)

      print(t1)
)DOC");

// Python binding: replaces self->tensor with the tensor passed as positional
// argument 0, then restores self's original name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor source = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const std::string kept_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from" << source.name() << " to "
          << kept_name;

  self->tensor = source;
  // Put back the name self had before the assignment.
  self->tensor.set_name(kept_name);

  VLOG(6) << "Finished Reconstructing Tensor from" << source.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding for in-place copy: copies the tensor passed as positional
// argument 0 into self->tensor. Positional argument 1 is the `blocking` flag.
// Returns None.
//
// If self is uninitialized, it first takes src's stop_gradient and persistable
// flags, and the copy is placed on src's place. Otherwise the copy uses self's
// current place. Nothing is copied when src is uninitialized. Copies run with
// the GIL released.
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();

  const bool dst_initialized = self->tensor.initialized();
  if (!dst_initialized) {
    eager_gil_scoped_release guard;
    auto* dst_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    auto* src_meta = egr::EagerUtils::autograd_meta(&src_tensor);
    dst_meta->SetStopGradient(src_meta->StopGradient());
    dst_meta->SetPersistable(src_meta->Persistable());
    if (src_tensor.initialized()) {
      self->tensor.copy_(src_tensor, src_tensor.place(), blocking);
    }
  } else if (src_tensor.initialized()) {
    eager_gil_scoped_release guard;
    self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
  }

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Docstring for Tensor.clone(). The "name(signature)" line followed by "--"
// follows CPython's __text_signature__ convention, so nothing may sit between
// them.
PyDoc_STRVAR(tensor_method_clone__doc__,  // NOLINT
             R"DOC(clone($self, /)
--

Returns a new Tensor, which is clone of origin Tensor, and it remains in the current graph.
It will always have a Tensor copy.
In addition, the cloned Tensor provides gradient propagation.

Returns:
    Tensor, The cloned Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor(1.0, stop_gradient=False)
        clone_x = x.clone()
        y = clone_x**2
        y.backward()
        print(clone_x.stop_gradient) # False
        print(clone_x.grad)          # [2.0], support gradient propagation
        print(x.stop_gradient)       # False
        print(x.grad)                # [2.0], clone_x support gradient propagation for x

        x = paddle.to_tensor(1.0)
        clone_x = x.clone()
        clone_x.stop_gradient = False
        z = clone_x**3
        z.backward()
        print(clone_x.stop_gradient) # False
        print(clone_x.grad)          # [3.0], support gradient propagation
        print(x.stop_gradient) # True
        print(x.grad)          # None
)DOC");

// Python binding for Tensor.clone(): returns assign_ad_func(self->tensor),
// computed with the GIL released. Raises InvalidArgument when the tensor is
// uninitialized.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned;
  {
    eager_gil_scoped_release guard;
    const bool ready = self->tensor.initialized();
    PADDLE_ENFORCE_EQ(
        ready,
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));
    cloned = assign_ad_func(self->tensor);
  }
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Docstring for Tensor.retain_grads(); the leading "name(signature)" and "--"
// lines follow CPython's __text_signature__ convention.
PyDoc_STRVAR(tensor_method_retain_grads__doc__,  // NOLINT
             R"DOC(retain_grads($self, /)
--

Enables this Tensor to have their grad populated during backward(). It is a no-op for leaf tensors.

Returns:
    None.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0, 2.0, 3.0])
      x.stop_gradient = False
      y = x + x
      y.retain_grads()
      loss = y.sum()
      loss.backward()

      print(y.grad) # [1., 1., 1.]

      x = paddle.to_tensor([1.0, 2.0, 3.0])
      x.stop_gradient = False
      y = x + x
      # y.retain_grads()
      loss = y.sum()
      loss.backward()

      print(y.grad) # None
)DOC");

// Implements Tensor.retain_grads().
//
// When the controller reports that gradients are being recorded, this ensures
// the tensor has a grad node (installing a GradNodeAccumulation if it has
// none) and registers it with RetainGradForTensor so its grad is kept after
// backward(). Otherwise it does nothing. Always returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (!egr::Controller::Instance().HasGrad()) {
    RETURN_PY_NONE
  }
  {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!meta->GetMutableGradNode()) {
      // Leading space keeps the name and message from running together.
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << " become accumulation node";
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

768
PyDoc_STRVAR(tensor_clear_gradient__doc__,  // NOLINT
W
wanghuancoder 已提交
769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797
             R"DOC(clear_gradient($self, set_to_zero=True, /)
--

Only applies to a Tensor that has a gradient. It is normally used for
Parameters, since other temporary Tensors don't have gradients.

The gradient of the current Tensor will be set to ``0`` elementwise or ``None``.

Args:
    set_to_zero (bool, optional): If set to ``True``, the gradient will be set
        to ``0`` elementwise, otherwise the gradient will be set to ``None``.
        Default: ``True``.

Returns:
    None.

Examples:
    .. code-block:: python

        import paddle
        input = paddle.uniform([10, 2])
        linear = paddle.nn.Linear(2, 3)
        out = linear(input)
        out.backward()
        print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
        linear.weight.clear_gradient()
        print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
)DOC");

798 799
// Implements Tensor.clear_gradient(set_to_zero=True).
//
// Looks up the gradient (mutable_grad for leaf tensors, the autograd meta's
// MutableGrad otherwise). If the gradient holds an impl:
//   * a SelectedRows grad with an initialized value has its rows and value
//     cleared;
//   * an initialized DenseTensor grad is zero-filled in place when
//     set_to_zero is true (leaf tensors additionally mark their accumulation
//     node "fake empty"), or has its memory holder released otherwise.
// Always returns None.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  // A single positional argument overrides the default set_to_zero=True.
  bool set_to_zero = true;
  if (PyTuple_Size(args) == 1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  const bool is_leaf = egr::EagerUtils::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    // Adjacent literals are concatenated, so each piece needs its own space.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected nullptr grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor() && grad->initialized()) {
      if (set_to_zero) {
        auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
        auto* dev_ctx =
            platform::DeviceContextPool::Instance().Get(grad_t->place());
        phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
        if (is_leaf) {
          // NOTE(review): IsLeafTensor's exact contract is not visible here;
          // guard against a leaf without a grad node instead of
          // dereferencing a null pointer.
          auto accumulation_node =
              std::static_pointer_cast<egr::GradNodeAccumulation>(
                  egr::EagerUtils::grad_node(self->tensor));
          if (accumulation_node) {
            accumulation_node->SetFakeEmpty(true);
          }
        }
      } else {
        VLOG(4) << "Gradient of " << self->tensor.name()
                << " is initialized, will be released.";
        auto dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
        dense_tensor->MoveMemoryHolder();
      }
    }
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

861 862
// Implements Tensor._zero_grads(): if the tensor's gradient is initialized,
// sets it to zero. DenseTensor grads are filled with 0 in place on their own
// device; any other kind is replaced by a zeros_like impl. Leaf tensors read
// the grad via mutable_grad, others via the autograd meta. Returns None.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  // Shared zeroing logic for both the leaf and the non-leaf path.
  auto zero_fill = [](paddle::Tensor* grad) {
    if (!grad->initialized()) {
      return;
    }
    if (grad->is_dense_tensor()) {
      auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
      phi::funcs::set_constant(*dev_ctx, t, 0.0);
    } else {
      grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
    }
  };

  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    eager_gil_scoped_release guard;
    paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
    // Adjacent literals are concatenated, so each piece needs its own space.
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected nullptr grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
    zero_fill(grad);
  } else {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    zero_fill(meta->MutableGrad());
  }

  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

906 907
// Implements Tensor._share_buffer_to(dst): makes the DenseTensor behind
// args[0] share self's buffer and data type. If dst has no impl yet, an empty
// DenseTensor is created for it first. Requires self to be initialized.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto* dst_obj = reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor& dst = dst_obj->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  if (!dst.defined()) {
    dst.set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* dst_dense = static_cast<phi::DenseTensor*>(dst.impl().get());
  dst_dense->ShareBufferWith(*src_dense);
  dst_dense->ShareDataTypeWith(*src_dense);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

930 931 932 933
// Implements Tensor._is_shared_buffer_with(other): returns whether the
// DenseTensor behind args[0] shares its buffer with self's. Returns False when
// either tensor has no impl. Requires self to be initialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor& other =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  const bool both_defined = self->tensor.defined() && other.defined();
  if (!both_defined) {
    return ToPyObject(false);
  }
  auto* self_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* other_dense = static_cast<phi::DenseTensor*>(other.impl().get());
  return ToPyObject(other_dense->IsSharedBufferWith(*self_dense));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

953 954 955 956
// Implements Tensor._share_underline_tensor_to(target): makes the Tensor
// behind args[0] point at the very same impl object as self. Requires self
// to be initialized.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  // args[0] is the receiving side: its impl gets replaced with self's.
  paddle::Tensor& target =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor;
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  target.set_impl(self->tensor.impl());
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

971 972 973 974
// Implements Tensor._is_shared_underline_tensor_with(other): returns True iff
// self and args[0] are both defined and hold the identical impl object.
// Requires args[0] to be initialized.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(other.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        other.name()));
  bool same_impl = false;
  if (self->tensor.defined() && other.defined()) {
    same_impl = self->tensor.impl().get() == other.impl().get();
  }
  return ToPyObject(same_impl);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

991 992
PyDoc_STRVAR(tensor_method_detach__doc__,  // NOLINT
             R"DOC(detach($self, /)
W
wanghuancoder 已提交
993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031
--

Returns a new Tensor, detached from the current graph.
It shares data with the original Tensor; no copy of the data is made.
In addition, the detached Tensor doesn't provide gradient propagation.

Returns:
    Tensor, The detached Tensor.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0], stop_gradient=False)
      detach_x = x.detach()
      detach_x[0] = 10.0
      print(x)  # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
                  #        [10.])
      y = x**2
      y.backward()
      print(x.grad)         # [20.0]
      print(detach_x.grad)  # None, 'stop_gradient=True' by default

      detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
      z = detach_x**3
      z.backward()

      print(x.grad)         # [20.0], detach_x is detached from x's graph, not affect each other
      print(detach_x.grad)  # [300.0], detach_x has its own graph

      # Because data is shared with the original Tensor, some operations are unsafe:
      # y = 2 * x
      # detach_x[:] = 5.0
      # y.backward()
      # It will raise Error:
      #   one of the variables needed for gradient computation has been modified by an inplace operation.
)DOC");

1032 1033
// Implements Tensor.detach(): allocates a new Python Tensor object that shares
// self's impl, has a freshly generated unique name, and copies only the
// persistable flag from self's autograd meta.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc returns raw storage, so the Tensor member is placement-constructed.
  new (&(detached->tensor)) paddle::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());
  auto src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto dst_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1060 1061 1062 1063 1064 1065 1066 1067 1068 1069
PyDoc_STRVAR(tensor_method_detach___doc__, R"DOC(detach_($self, /)
--

Detaches self from the current graph and returns self.
In addition, the detached Tensor doesn't provide gradient propagation.

Returns:
    Tensor, The detached Tensor.
)DOC");

W
wanghuancoder 已提交
1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088
// Implements Tensor.detach_(): replaces self's autograd meta with a fresh
// AutogradMeta (keeping only the persistable flag), cutting self off from the
// current graph, and returns self.
static PyObject* tensor_method_detach_(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  auto autograd_meta = std::make_shared<egr::AutogradMeta>();
  autograd_meta->SetPersistable(
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  self->tensor.set_autograd_meta(autograd_meta);

  // A CPython method result must be a new reference. `self` is only borrowed
  // here, so take a reference before returning it. Without this, the caller's
  // eventual DECREF underflows the refcount.
  PyObject* result = reinterpret_cast<PyObject*>(self);
  Py_INCREF(result);
  return result;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106
PyDoc_STRVAR(tensor_method_get_tensor__doc__, R"DOC(get_tensor($self, /)
--

Returns the underlying tensor of the original Tensor.

Returns:
    The underlying tensor.

Examples:
    .. code-block:: python

      import paddle

      x = paddle.to_tensor([1.0], stop_gradient=False)
      underline_x = x.get_tensor()
      print(underline_x) # a Dense Tensor info
)DOC");

1107 1108 1109 1110
// Implements Tensor.get_tensor(): exposes the underlying DenseTensor (or
// DistTensor in PADDLE_WITH_DISTRIBUTE builds) to Python.
// - Undefined tensor: returns an empty DenseTensor.
// - Any other impl kind: returns None.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // Mirrors the legacy Variable.get_tensor(), which created an empty tensor.
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }
  if (self->tensor.is_dense_tensor()) {
    auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
    VLOG(6) << "tensor: " << dense->IsInitialized();
    return ToPyObject(dense);
  }
#ifdef PADDLE_WITH_DISTRIBUTE
  if (self->tensor.is_dist_tensor()) {
    auto* dist =
        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
    VLOG(6) << "dist tensor: " << dist->defined();
    return ToPyObject(dist);
  }
#endif
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1135 1136 1137 1138 1139
// Returns the underlying SelectedRows object when self is defined and holds
// one. Returns None otherwise.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.defined() && self->tensor.is_selected_rows()) {
    auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    return ToPyObject(rows);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165
// For a SelectedRows tensor, wraps a copy of its value DenseTensor object in a
// new Tensor with a freshly generated unique name. Raises Fatal if self is not
// SelectedRows or the SelectedRows is not initialized.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));
  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* value = static_cast<phi::DenseTensor*>(rows->mutable_value());
  VLOG(4) << "dense_tensor: " << value->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  result.set_impl(std::make_shared<phi::DenseTensor>(*value));
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
1178 1179 1180
// Implements Tensor.__getitem__ for an index (args[0]) that contains no Tensor.
//
// ParseIndexingSlice splits the index into slice attributes, axes to squeeze
// (decrease_axis), axes to insert for `None` (none_axes) and an optional list
// selection. The result is then produced, in order, by:
//   1. slice / strided_slice (+ squeeze) when any slice axis was parsed;
//   2. unsqueeze when the index contained `None` (returns immediately);
//   3. index_select along axis 0 when the index was a list.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags;
  std::vector<int64_t> list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  // Only allocate a fresh name when an op will actually produce a new tensor.
  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // Any stride other than 1 requires strided_slice; otherwise plain slice.
    bool use_strided_slice = false;
    for (auto stride : slice_strides) {
      if (stride != 1) {
        use_strided_slice = true;
        break;
      }
    }
    std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                           decrease_axis.end());
    if (use_strided_slice) {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
      // strided_slice_ad_func takes no decrease_axis, so squeeze explicitly.
      if (!decrease_axis_tmp.empty()) {
        out = squeeze_ad_func(out, decrease_axis_tmp);
      }
    } else {
      std::vector<int64_t> slice_axes_tmp(slice_axes.begin(),
                                          slice_axes.end());
      std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                           infer_flags.end());
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    }
  }

  if (FLAGS_set_to_1d) {
    // NOTE(zoooo0820): When all axes are decreased, the output will be 1-D
    // with FLAGS_set_to_1d=True. In this case, one `None` should be pop out,
    // otherwise the output shape will be not correct.
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      VLOG(1)
          << "Warning: In Tensor '__getitem__', if the number of scalar "
             "elements in the index is equal to the rank of the Tensor, the "
             "output should be 0-D. In order to be consistent with the "
             "behavior of previous versions, it will be processed to 1-D. But "
             "it is not correct and will be removed in release 2.6. If 1-D is "
             "still wanted, please modify the index element from scalar to "
             "slice (e.g. 'x[i]' => 'x[i:i+1]'). ";
      if (!none_axes.empty()) {
        none_axes.pop_back();
      }
    }
  }
  if (!none_axes.empty()) {
    paddle::Tensor new_out;
    {
      eager_gil_scoped_release guard;
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }
      new_out = unsqueeze_ad_func(out, none_axes);
    }
    return ToPyObject(new_out);
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    if (FLAGS_use_stride_kernel && list_select_idxs.size() == 1) {
      out = index_select_strided_ad_func(self->tensor, list_select_idxs[0], 0);
    } else {
      auto select_index =
          paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
      auto idx_tensor = std::make_shared<phi::DenseTensor>();
      select_index.set_impl(idx_tensor);
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
          egr::Controller::Instance().GetExpectedPlace());
      paddle::framework::TensorFromVector(
          list_select_idxs, *dev_ctx, idx_tensor.get());
      out = index_select_ad_func(self->tensor, select_index, 0);
    }
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1329 1330
// Reads one element of self (or of a SelectedRows' value) and returns it as a
// 0-d numpy array.
//   * no args: the tensor must hold exactly one element;
//   * one arg: the value is used as the element offset, bounds-checked
//     against numel;
//   * rank args: per-axis indices, combined with the tensor's strides.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  phi::DenseTensor* ptr = nullptr;
  if (self->tensor.is_selected_rows()) {
    auto* selected_rows =
        static_cast<phi::SelectedRows*>(self->tensor.impl().get());
    ptr = static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  } else {
    ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  }
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> stride = phi::vectorize<size_t>(tensor.strides());

  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  if (PyTuple_Size(args) == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (PyTuple_Size(args) == 1) {
    // NOTE(review): unlike the multi-index path below, this flat index is not
    // mapped through `stride`; confirm this is intended for strided tensors.
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(PyTuple_Size(args),
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * stride[i];
    }
  }
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// Builds a 0-d numpy array of the matching dtype holding the selected element.
// A null array from numpy is propagated as a failure instead of being written
// through by memcpy.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];    /* NOLINT */  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank]; /* NOLINT */  \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        0,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    if (array == nullptr) {                                                  \
      return nullptr;                                                        \
    }                                                                        \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483
// Implements `Tensor.__setitem__` for eager tensors.
//
// args[0] is the index object and args[1] the value to assign.
// - When every index item is an int, slice, Ellipsis or None, the assignment
//   is lowered to the inplace `set_value_` op (fast path, autograd-aware).
// - Otherwise the tensor is round-tripped through a numpy array: numpy
//   performs the assignment and the result is written back into the tensor.
// Returns None on success; raises (via EAGER_CATCH) on unsupported dtypes or
// value types.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack returns a new reference (and increases the
  // refcount of its items), whereas an `_index` that is already a tuple is
  // only borrowed here, so only the packed tuple must be released on exit.
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments.
  // The fast set_value path only understands indices made of integers,
  // slices, Ellipsis and None.
  bool parse_index = true;
  const Py_ssize_t size = PyTuple_GET_SIZE(index_ptr);
  for (Py_ssize_t dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags;
    std::vector<int64_t> list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    // While gradients are recorded, inplace assignment into a leaf Tensor
    // that does not stop gradient is rejected.
    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::EagerUtils::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      // numpy value: convert the array to self's dtype (if needed) and copy
      // it into a fresh DenseTensor on self's place.
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      if (self->tensor.dtype() == phi::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
        if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<float>>(
              value_obj_tmp);
        }
      } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
        if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<std::complex<double>>(
              value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, complex64, complex128, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      // Python scalar: passed to set_value_ through the "values" attribute,
      // converted to self's dtype, with shape [1].
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp) ||
          PyComplex_Check(value_obj)) {
        if (self->tensor.dtype() == phi::DataType::FLOAT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() == phi::DataType::INT32) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() == phi::DataType::INT64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() == phi::DataType::BOOL) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<float>>()};
        } else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
          attrs["values"] = std::vector<paddle::experimental::Scalar>{
              value_obj_tmp.cast<std::complex<double>>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, complex64, complex128, int32, int64 or "
              "float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type name; Py_TYPE() alone is a PyTypeObject*
        // and would be formatted as a pointer, not as a string.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float, complex  or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      // When dtypes differ, try AMP auto-cast first, then fall back to an
      // explicit cast of the value to self's dtype.
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    // Slow path: let numpy perform the assignment on a host copy, then write
    // the array back into the tensor.
    auto self_numpy = TensorToPyArray(*self_tensor, true);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1687 1688
// Registers a Python gradient hook (args[0]) on this Tensor and returns the
// hook id as a Python int.
// - Leaf Tensor: the hook is attached to its GradNodeAccumulation, which is
//   created lazily when the Tensor does not stop gradient.
// - Non-leaf Tensor: the hook is attached to the Tensor's grad node.
// Raises InvalidArgument when a leaf Tensor has no GradNodeAccumulation.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id;
  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected nullptr grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // No node is created above when the Tensor stops gradient, so the cast
    // can yield nullptr; fail with a clear error instead of dereferencing it.
    PADDLE_ENFORCE(
        accumulation_grad_node != nullptr,
        platform::errors::InvalidArgument(
            "Cannot register gradient hook on leaf Tensor (%s) because it "
            "has no GradNodeAccumulation. Please make sure the Tensor does "
            "not stop gradient.",
            self->tensor.name()));

    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();

    PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1737 1738
// Removes the gradient hook whose id is args[0] from this Tensor's grad node.
// Returns a Python bool: the result of RemoveGradientHook, or False when the
// Tensor has no grad node (so no hook can be registered on it).
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  // Parse the argument first so malformed input is still reported.
  int64_t hook_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);

  if (!grad_node) {
    // Previously this dereferenced a null grad node.
    return ToPyObject(false);
  }
  return ToPyObject(grad_node->RemoveGradientHook(hook_id));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762
// Forwards this Tensor object and the first positional argument to
// ShareTensor (self first, other second); returns None.
static PyObject* tensor_inplace_assign(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "inplace assign for tensor:" << self->tensor.name();
  ShareTensor(reinterpret_cast<PyObject*>(self), PyTuple_GET_ITEM(args, 0));
  RETURN_PY_NONE;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1763
// Python docstring for the backward ("reduce") hook registration method.
PyDoc_STRVAR(tensor_method__register_reduce_hook__doc__,  // NOLINT
             R"DOC(_register_backward_hook($self, hook, /)
--

Registers a backward hook for current Tensor.

This hook will be called every time the gradient of current Tensor has been fully calculated.

There are two differences with `_register_grad_hook`:
1. This backward hook will be executed after the gradient accumulation is completed across batches,
  but the hook registered by `_register_grad_hook` will be executed when the gradient accumulation
  is completed in the current batch.
2. This backward hook function should have the following signature:

    hook() -> None

  It requires no input and no return value.

Args:
    hook(function): A backward hook to be registered for Tensor.gradient

Returns:
    None
)DOC");
1787 1788
// Registers a no-argument Python hook (args[0]) as a reduce hook on the
// GradNodeAccumulation of this leaf Tensor. Returns None.
// Raises if the Tensor is not a leaf, stops gradient, or has no
// GradNodeAccumulation.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);
  PADDLE_ENFORCE_EQ(egr::EagerUtils::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  // The two literals are concatenated, so the separating space is needed.
  PADDLE_ENFORCE(
      grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected nullptr grad_node, "
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));
  PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

  auto accumulation_grad_node =
      std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
  // Guard the downcast rather than dereferencing a possible nullptr.
  PADDLE_ENFORCE(accumulation_grad_node != nullptr,
                 paddle::platform::errors::Fatal(
                     "Leaf tensor should have had grad_node "
                     "with type: GradNodeAccumulation."));
  accumulation_grad_node->RegisterReduceHook(
      std::make_shared<PyVoidHook>(hook_func));

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1822 1823
// Resets the impl of this Tensor's gradient according to the VarType given in
// args[0]:
// - LOD_TENSOR yields an empty phi::DenseTensor.
// - SELECTED_ROWS yields an empty phi::SelectedRows.
// - Any other type leaves the gradient untouched.
// Returns None.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();
  switch (requested_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      // Other variable types are intentionally ignored.
      break;
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1839 1840
// Calls reset() on the wrapped paddle::Tensor and returns None.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1849 1850 1851 1852 1853 1854 1855 1856 1857
// Detaches the impl of the wrapped paddle::Tensor by setting it to nullptr;
// returns None.
static PyObject* tensor__clear_dataptr(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.set_impl(nullptr);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1858 1859
// Makes this Tensor's gradient share the impl of the Tensor given in args[0].
//
// When this Tensor is initialized, `src` must have the same dtype and the
// same impl type. When a gradient slot exists, `src` must be initialized.
// Returns None.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // src.dtype() and src.impl() dereference the impl; a Tensor whose impl
    // was dropped (e.g. via `_clear`, which calls reset()) would crash below.
    PADDLE_ENFORCE_EQ(src.defined(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s is undefined, cannot copy gradient "
                          "from it",
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1891

1892 1893 1894
// Returns a Tensor with the `use_gpudnn` meta flag set to args[0].
//
// If the flag already matches, this Tensor itself is returned. Otherwise a
// new DenseTensor is built via ShareDataWith from this one, given a copy of
// the meta with only `use_gpudnn` changed, and wrapped in a paddle::Tensor
// that reuses this Tensor's name and autograd meta.
// Only valid for defined DenseTensors.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  const bool use_gpudnn =
      pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* src_meta = phi::DenseTensorUtils::GetMutableMeta(src_dense);

  // Flag already has the requested value: nothing to rebuild.
  if (src_meta->use_gpudnn == use_gpudnn) {
    return ToPyObject(self->tensor);
  }

  phi::DenseTensorMeta new_meta = *src_meta;
  new_meta.use_gpudnn = use_gpudnn;

  phi::DenseTensor shared_dense;
  shared_dense.ShareDataWith(*src_dense);
  shared_dense.set_meta(new_meta);

  paddle::Tensor result(std::make_shared<phi::DenseTensor>(shared_dense),
                        self->tensor.name());
  result.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << result.name() << " set use_gpudnn = " << use_gpudnn;

  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1929 1930
// Replaces this Tensor's impl with a VariableCompatTensor holding the Vocab
// parsed from args[0]. Returns None.
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  using Vocab = paddle::framework::Vocab;
  auto vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
  // `vocab` is a local that is not used afterwards: move instead of copying
  // the whole mapping.
  *var_tensor->GetMutable<Vocab>() = std::move(vocab);
  self->tensor.set_impl(var_tensor);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Replaces this Tensor's impl with a VariableCompatTensor holding the list of
// strings parsed from args[0]. Returns None.
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  using Strings = paddle::framework::Strings;
  auto strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto var_tensor = std::make_shared<egr::VariableCompatTensor>();
  // `strings` is a local that is not used afterwards: move instead of copying
  // every element.
  *var_tensor->GetMutable<Strings>() = std::move(strings);
  self->tensor.set_impl(var_tensor);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the Vocab stored in this Tensor's VariableCompatTensor impl as a
// Python object. Raises if the Tensor is not a VariableCompatTensor.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      egr::IsVariableCompatTensor(self->tensor),
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  const auto* compat_impl =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  const auto& vocab = compat_impl->Get<paddle::framework::Vocab>();
  return ToPyObject(vocab);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996
// Python docstring for the sparse `nnz()` method; presumably bound to
// tensor_method_get_non_zero_nums in the method table (not visible here).
PyDoc_STRVAR(tensor_method_nnz__doc__,
             R"DOC(nnz($self, /)
--

Note:
    **This API is only available for SparseCooTensor or SparseCsrTensor.**

Returns the total number of non zero elements in input SparseCooTensor/SparseCsrTensor.

Returns:
    int

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.nnz()
        # 3

)DOC");

1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017
// Returns nnz() of a SparseCooTensor or SparseCsrTensor as a Python int.
// Raises for any other tensor kind.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (!self->tensor.is_sparse_coo_tensor()) {
    // Only the CSR layout remains after the check above.
    const auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    return ToPyObject(csr->nnz());
  }
  const auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  return ToPyObject(coo->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2018 2019 2020 2021 2022 2023 2024 2025 2026 2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045
// Python docstring for the sparse COO `indices()` method.
PyDoc_STRVAR(tensor_method_indices__doc__,
             R"DOC(indices($self, /)
--

Note:
    **This API is only available for SparseCooTensor.**

Returns the indices of non zero elements in input SparseCooTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.indices()
        # Tensor(shape=[2, 3], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [[0, 1, 2],
        #         [1, 2, 0]])

)DOC");

2046 2047 2048 2049 2050 2051 2052 2053 2054
// Returns a new Tensor wrapping a DenseTensor copy of the COO indices.
// Raises unless this Tensor is a SparseCooTensor.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  const auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_dense =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::Tensor result(indices_dense);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087
// Python docstring for the sparse `values()` method (COO and CSR).
PyDoc_STRVAR(tensor_method_values__doc__,
             R"DOC(values($self, /)
--

Note:
    **This API is only available for SparseCooTensor or SparseCsrTensor.**

Returns the values of non zero elements in input SparseCooTensor/SparseCsrTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.values()
        # Tensor(shape=[3], dtype=float32, place=Place(gpu:0), stop_gradient=True,
        #        [1., 2., 3.])

)DOC");

2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099
// Returns a new Tensor wrapping a DenseTensor copy of the non-zero values of
// a SparseCooTensor or SparseCsrTensor. Raises for any other tensor kind.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (!self->tensor.is_sparse_coo_tensor()) {
    // CSR layout.
    const auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    auto csr_values =
        std::make_shared<phi::DenseTensor>(csr->non_zero_elements());
    paddle::Tensor result(csr_values);
    return ToPyObject(result);
  }
  // COO layout.
  const auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto coo_values =
      std::make_shared<phi::DenseTensor>(coo->non_zero_elements());
  paddle::Tensor result(coo_values);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140
// Python docstring for the sparse CSR `crows()` method.
PyDoc_STRVAR(tensor_method_crows__doc__,
             R"DOC(crows($self, /)
--

Note:
    **This API is only available for SparseCsrTensor.**

Returns the compressed row index of non zero elements in input SparseCsrTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.crows()
        # Tensor(shape=[4], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [0, 2, 3, 5])

)DOC");

2141 2142 2143 2144 2145 2146 2147 2148 2149
// Returns a new Tensor wrapping a DenseTensor copy of the CSR compressed row
// offsets. Raises unless this Tensor is a SparseCsrTensor.
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  const auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows_dense = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  paddle::Tensor result(crows_dense);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183
// Python docstring for the sparse CSR `cols()` method.
PyDoc_STRVAR(tensor_method_cols__doc__,
             R"DOC(cols($self, /)
--

Note:
    **This API is only available for SparseCsrTensor.**

Returns the column index of non zero elements in input SparseCsrTensor.

Returns:
    DenseTensor

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.cols()
        # Tensor(shape=[5], dtype=int64, place=Place(gpu:0), stop_gradient=True,
        #        [1, 3, 2, 0, 1])

)DOC");

2184 2185 2186 2187 2188 2189 2190 2191 2192
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto sparse_csr_tensor =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
2193
  paddle::Tensor tensor(
2194 2195 2196 2197 2198
      std::make_shared<phi::DenseTensor>(sparse_csr_tensor->non_zero_cols()));
  return ToPyObject(tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215
PyDoc_STRVAR(tensor_method_is_dense__doc__, R"DOC(is_dense($self, /)
--

Whether the Tensor is a Dense Tensor.

Returns:
    Whether the Tensor is a Dense Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1.0], stop_gradient=False)
        print(x.is_dense())
)DOC");

2216 2217
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
2218 2219 2220 2221 2222 2223 2224 2225 2226
                                        PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
  return ToPyObject(self->tensor.is_dense_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2227 2228 2229 2230 2231 2232 2233 2234 2235 2236 2237 2238 2239 2240 2241 2242 2243
PyDoc_STRVAR(tensor_method_is_dist__doc__, R"DOC(is_dist($self, /)
--

Whether the Tensor is a Distributed Tensor.

Returns:
    Whether the Tensor is a Distributed Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1.0], stop_gradient=False)
        print(x.is_dist()) # False
)DOC");

L
LiYuRio 已提交
2244 2245 2246 2247 2248 2249 2250 2251 2252 2253 2254
static PyObject* tensor_method_is_dist(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
  return ToPyObject(self->tensor.is_dist_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2255 2256 2257 2258 2259 2260 2261 2262 2263 2264 2265 2266 2267 2268 2269 2270 2271 2272 2273 2274 2275 2276 2277 2278
PyDoc_STRVAR(tensor_is_sparse__doc__,
             R"DOC(is_sparse($self, /)
--

Returns whether the input Tensor is SparseCooTensor or SparseCsrTensor.

When input is SparseCooTensor/SparseCsrTensor, will return True. When input is DenseTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.is_sparse()
        # True

)DOC");
2279 2280
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
2281 2282
                                         PyObject* kwargs) {
  EAGER_TRY
2283 2284 2285
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2286 2287 2288 2289 2290
  return ToPyObject(self->tensor.is_sparse_coo_tensor() ||
                    self->tensor.is_sparse_csr_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2291 2292 2293 2294 2295 2296 2297 2298 2299 2300 2301 2302 2303 2304 2305 2306 2307 2308 2309 2310 2311 2312 2313 2314 2315
PyDoc_STRVAR(tensor_is_sparse_coo__doc__,
             R"DOC(is_sparse_coo($self, /)
--

Returns whether the input Tensor is SparseCooTensor.

When input is SparseCooTensor, will return True. When input is DenseTensor/SparseCsrTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.is_sparse_coo()
        # True

)DOC");

2316 2317
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
2318 2319
                                             PyObject* kwargs) {
  EAGER_TRY
2320 2321 2322
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2323 2324 2325 2326
  return ToPyObject(self->tensor.is_sparse_coo_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2327 2328 2329 2330 2331 2332 2333 2334 2335 2336 2337 2338 2339 2340 2341 2342 2343 2344 2345 2346 2347 2348 2349 2350 2351 2352
PyDoc_STRVAR(tensor_is_sparse_csr__doc__,
             R"DOC(is_sparse_csr($self, /)
--

Returns whether the input Tensor is SparseCsrTensor.

When input is SparseCsrTensor, will return True. When input is DenseTensor/SparseCooTensor, will return False.

Returns:
    bool

Examples:
    .. code-block:: python

        import paddle

        crows = [0, 2, 3, 5]
        cols = [1, 3, 2, 0, 1]
        values = [1, 2, 3, 4, 5]
        dense_shape = [3, 4]
        csr = paddle.sparse.sparse_csr_tensor(crows, cols, values, dense_shape)
        csr.is_sparse_csr()
        # True

)DOC");

2353 2354
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
2355 2356
                                             PyObject* kwargs) {
  EAGER_TRY
2357 2358 2359
  if (!self->tensor.defined()) {
    return ToPyObject(false);
  }
2360 2361 2362 2363
  return ToPyObject(self->tensor.is_sparse_csr_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2364 2365 2366 2367 2368 2369 2370 2371 2372 2373 2374 2375 2376 2377 2378 2379 2380 2381 2382 2383 2384 2385 2386 2387 2388 2389 2390 2391 2392 2393 2394
PyDoc_STRVAR(tensor_to_sparse_csr__doc__,
             R"DOC(to_sparse_csr($self, /)
--

Note:
    **This API is only available for DenseTensor or SparseCooTensor.**

Convert input Tensor to SparseCsrTensor.

When input is SparseCooTensor, will convert `COO` to `CSR` . When input is DenseTensor, will convert `Dense` to `CSR` .

Returns:
    SparseCsrTensor

Examples:
    .. code-block:: python

        import paddle

        indices = [[0, 1, 2], [1, 2, 0]]
        values = [1.0, 2.0, 3.0]
        dense_shape = [3, 3]
        coo = paddle.sparse.sparse_coo_tensor(indices, values, dense_shape)
        coo.to_sparse_csr()
        # Tensor(shape=[3, 3], dtype=paddle.float32, place=Place(gpu:0), stop_gradient=True,
        #        crows=[0, 1, 2, 3],
        #        cols=[1, 2, 0],
        #        values=[1., 2., 3.])

)DOC");

2395 2396
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
2397 2398 2399 2400 2401 2402 2403 2404 2405 2406 2407 2408 2409
                                             PyObject* kwargs) {
  EAGER_TRY
  auto csr_tensor = self->tensor.to_sparse_csr();
  egr::EagerUtils::autograd_meta(&csr_tensor)
      ->SetStopGradient(
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient());
  egr::EagerUtils::autograd_meta(&csr_tensor)
      ->SetPersistable(
          egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  return ToPyObject(csr_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2410 2411 2412 2413 2414 2415 2416 2417 2418 2419 2420 2421 2422 2423 2424 2425 2426 2427 2428 2429 2430 2431 2432 2433 2434 2435 2436 2437 2438 2439 2440 2441
PyDoc_STRVAR(tensor_is_same_shape__doc__,
             R"DOC(is_same_shape($self, y, /)
--

Return the results of shape comparison between two Tensors, check whether x.shape equal to y.shape.
Any two type Tensor among DenseTensor/SparseCooTensor/SparseCsrTensor are supported.

Args:
    x (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.
    y (Tensor): The input tensor. It can be DenseTensor/SparseCooTensor/SparseCsrTensor.

Returns:
    bool: True for same shape and False for different shape.

Examples:

    .. code-block:: python

        import paddle

        x = paddle.rand([2, 3, 8])
        y = paddle.rand([2, 3, 8])
        y = y.to_sparse_csr()
        z = paddle.rand([2, 5])

        x.is_same_shape(y)
        # True
        x.is_same_shape(z)
        # False

)DOC");

2442 2443 2444 2445 2446 2447 2448 2449 2450
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto other = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  return ToPyObject(self->tensor.shape() == other.shape());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2451 2452
// Python binding for `Tensor._inplace_version()`.
// Returns the tensor's current inplace-version counter as a Python int.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2461 2462
PyDoc_STRVAR(tensor_method_element_size__doc__,  // NOLINT
             R"DOC(element_size($self, /)
W
wanghuancoder 已提交
2463 2464 2465 2466 2467 2468 2469 2470 2471 2472 2473 2474 2475 2476 2477 2478 2479 2480 2481 2482 2483 2484 2485 2486 2487 2488 2489 2490
--

Returns the size in bytes of an element in the Tensor.

Returns:
    int, The size in bytes of an element in the Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor(1, dtype='bool')
        x.element_size() # 1

        x = paddle.to_tensor(1, dtype='float16')
        x.element_size() # 2

        x = paddle.to_tensor(1, dtype='float32')
        x.element_size() # 4

        x = paddle.to_tensor(1, dtype='float64')
        x.element_size() # 8

        x = paddle.to_tensor(1, dtype='complex128')
        x.element_size() # 16
)DOC");

2491 2492
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
2493 2494
                                            PyObject* kwargs) {
  EAGER_TRY
2495
  uint32_t element_size = phi::SizeOf(self->tensor.dtype());
2496 2497 2498 2499 2500

  return ToPyObject(element_size);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2501
PyDoc_STRVAR(tensor_method__bump_inplace_version__doc__,  // NOLINT
W
wanghuancoder 已提交
2502 2503 2504
             R"DOC(_bump_inplace_version($self, /)
--

2505
Note:
W
wanghuancoder 已提交
2506 2507
    **This API is ONLY available in Dygraph mode.**
    **This is a very low level API. Users should not use it directly. **
2508

W
wanghuancoder 已提交
2509 2510
  Bump the version whenever the Tensor is modified through an inplace operation.
)DOC");
2511 2512 2513 2514 2515
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  self->tensor.bump_inplace_version();
2516
  RETURN_PY_NONE
2517 2518 2519
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2520 2521 2522 2523
// Python binding for `Tensor.is_selected_rows()`.
// An undefined tensor is reported as False; otherwise the answer comes from
// paddle::Tensor::is_selected_rows().
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  const bool selected_rows =
      self->tensor.defined() && self->tensor.is_selected_rows();
  return ToPyObject(selected_rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2531 2532
// Python binding for `Tensor.rows()`.
// Returns the row indices held by a phi::SelectedRows impl. Raises a Fatal
// error when `self` is not a SelectedRows tensor.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  const bool holds_selected_rows = self->tensor.is_selected_rows();
  PADDLE_ENFORCE(holds_selected_rows,
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  auto rows_impl =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  return ToPyObject(rows_impl->rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2544 2545 2546 2547 2548 2549 2550 2551 2552 2553
// Python binding for `Tensor._reset_grad_inplace_version(set_to_zero=True)`.
// If the tensor has a defined, initialized DenseTensor gradient, resets that
// gradient's inplace version. A single positional bool argument overrides
// the default `set_to_zero = true`. Always returns None.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  bool set_to_zero = true;
  if (PyTuple_Size(args) == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool resettable = grad_slot != nullptr && grad_slot->defined() &&
                          grad_slot->is_dense_tensor() &&
                          grad_slot->initialized();
  if (resettable) {
    grad_slot->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2564 2565
// Python binding for `Tensor._share_memory()`.
// Non-Windows only: copies a CPU DenseTensor's data into a newly allocated
// memory-map writer allocation, records its ipc name in MemoryMapFdSet, and
// resets the tensor's holder to that allocation. Returns the DenseTensor.
// Raises InvalidArgument for non-CPU or non-DenseTensor tensors, and
// PermissionDenied on Windows.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get LoDTensor
  auto* t =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl()).get();
  // dynamic_pointer_cast yields nullptr for non-DenseTensor impls; report it
  // instead of dereferencing a null pointer below.
  PADDLE_ENFORCE_EQ(
      t != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Sharing memory only support DenseTensor currently"));
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  RETURN_PY_NONE
#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2603 2604
// Python binding for `Tensor._offset()`.
// Returns DenseTensor::offset() for an initialized DenseTensor. Raises
// InvalidArgument if the tensor is not a DenseTensor or is uninitialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // dynamic_pointer_cast yields nullptr for non-DenseTensor (or empty) impls;
  // fail with a Python-visible error instead of dereferencing it.
  PADDLE_ENFORCE_EQ(t != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s is not a DenseTensor, _offset() only "
                        "supports DenseTensor.",
                        self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2618 2619
// Python binding for `Tensor._grad_name()`.
// Returns the name of the tensor obtained from EagerUtils::mutable_grad.
// Raises InvalidArgument when mutable_grad returns nullptr.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool has_grad_slot = (grad_slot != nullptr);
  PADDLE_ENFORCE_EQ(
      has_grad_slot,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));
  const auto& grad_name = grad_slot->name();
  return ToPyObject(grad_name);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2633 2634
// Python binding for `Tensor._grad_value()`.
// Returns the gradient's DenseTensor, None when the gradient is undefined,
// and raises Fatal for non-dense gradients. Raises InvalidArgument when
// EagerUtils::mutable_grad returns nullptr.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(
      grad_slot != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));

  if (!grad_slot->defined()) {
    RETURN_PY_NONE
  }
  if (!grad_slot->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
    RETURN_PY_NONE  // Unreachable: PADDLE_THROW always throws.
  }
  auto* dense_grad = static_cast<phi::DenseTensor*>(grad_slot->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2659 2660
// Python binding for `Tensor._unset_fake_empty()`.
// For a leaf tensor, clears the "fake empty" flag on its GradNodeAccumulation.
// Raises InvalidArgument when EagerUtils::mutable_grad returns nullptr.
// Returns None.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(
      grad != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Detected nullptr grad. Please check if you have manually "
          "cleared the grad inside autograd_meta"));

  if (egr::EagerUtils::IsLeafTensor(self->tensor)) {
    auto accumulation_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    // NOTE(review): a leaf tensor is assumed to possibly have no grad node
    // yet. static_pointer_cast of an empty pointer stays empty, so guard the
    // call instead of dereferencing null.
    if (accumulation_node) {
      accumulation_node->SetFakeEmpty(false);
    }
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2681 2682 2683 2684 2685 2686 2687 2688 2689 2690 2691 2692 2693 2694 2695 2696 2697 2698
PyDoc_STRVAR(tensor_data_ptr__doc__,
             R"DOC(data_ptr($self, /)
--

Returns the address of the first element of current Tensor.

Returns:
    int, The address of the first element of current Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        print(x.data_ptr())
)DOC");

2699 2700 2701 2702 2703
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.initialized() && self->tensor.is_dense_tensor()) {
S
sneaxiy 已提交
2704 2705 2706 2707
    return ToPyObject(
        (int64_t)std::dynamic_pointer_cast<phi::DenseTensor>(  // NOLINT
            self->tensor.impl())
            ->data());
2708 2709 2710 2711 2712
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2713 2714 2715 2716 2717 2718 2719 2720 2721 2722 2723 2724 2725 2726 2727
// Python binding for `Tensor._grad_ivar()`.
// Returns the gradient tensor stored in the autograd meta when it exists and
// is initialized; otherwise returns None.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // `meta` may be null, so evaluate the null check before any dereference.
  // The original logged meta->Grad() before checking, which dereferenced
  // null when verbose logging was enabled.
  const bool has_grad = meta != nullptr && meta->Grad().initialized();
  VLOG(6) << meta << " initialized: " << has_grad;
  if (has_grad) {
    return ToPyObject(meta->Grad());
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2728 2729 2730 2731 2732 2733 2734 2735 2736 2737 2738 2739 2740 2741 2742 2743 2744 2745 2746
PyDoc_STRVAR(tensor_get_strides__doc__,
             R"DOC(get_strides($self, /)
--

Returns the strides of current Tensor.

Returns:
    List, the strides of current Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        print(y.get_strides())
)DOC");

W
wanghuancoder 已提交
2747 2748 2749 2750 2751 2752 2753 2754 2755 2756 2757 2758 2759 2760 2761 2762 2763 2764
static PyObject* tensor_method_strides(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  std::vector<int64_t> value;
  if (!self->tensor.defined() || !self->tensor.is_dense_tensor()) {
    return ToPyObject(value);
  }
  auto stride = self->tensor.strides();
  size_t rank = static_cast<size_t>(stride.size());
  value.resize(rank);
  for (size_t i = 0; i < rank; i++) {
    value[i] = stride[i];
  }
  return ToPyObject(value);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2765 2766 2767 2768 2769 2770 2771 2772 2773 2774 2775 2776 2777 2778 2779 2780 2781 2782 2783 2784 2785
PyDoc_STRVAR(tensor_contiguous__doc__,
             R"DOC(contiguous($self, /)
--

Returns a contiguous in memory tensor containing the same data as current Tensor.
If self tensor is already contiguous, this function returns the current Tensor.

Returns:
    Tensor, The contiguous Tensor.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        y = y.contiguous()
        print(y)
)DOC");

W
wanghuancoder 已提交
2786 2787 2788 2789 2790 2791 2792 2793 2794 2795 2796 2797 2798 2799 2800 2801 2802 2803 2804 2805 2806 2807 2808 2809
static PyObject* tensor_contiguous(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
    if (dense_tensor->meta().is_contiguous()) {
      Py_INCREF(self);
      return reinterpret_cast<PyObject*>(self);
    } else {
      eager_gil_scoped_release guard;
      return ToPyObject(
          paddle::Tensor(std::make_shared<phi::DenseTensor>(std::move(
              paddle::experimental::Trans2Contiguous(*(dense_tensor.get()))))));
    }

  } else {
    Py_INCREF(self);
    return reinterpret_cast<PyObject*>(self);
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
2810 2811 2812 2813 2814 2815 2816 2817 2818 2819 2820 2821 2822 2823 2824 2825 2826 2827
PyDoc_STRVAR(tensor_is_contiguous__doc__,
             R"DOC(is_contiguous($self, /)
--

Whether the Tensor is contiguous.

Returns:
    Bool, Whether the Tensor is contiguous.

Examples:
    .. code-block:: python

        import paddle

        x = paddle.to_tensor([1, 2, 3])
        y = x[1]
        print(y.is_contiguous())
)DOC");
W
wanghuancoder 已提交
2828 2829 2830 2831 2832 2833 2834 2835 2836 2837 2838 2839 2840 2841
static PyObject* tensor_is_contiguous(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto dense_tensor =
        std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
    return ToPyObject(dense_tensor->meta().is_contiguous());
  } else {
    return ToPyObject(true);
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

2842
#if defined(PADDLE_WITH_CUDA)
// Python binding for `Tensor._tensor_uva(device_id)` (CUDA builds only).
// Requires a CPU DenseTensor and hands it to tensor_uva() together with the
// integer device id taken from the first positional argument. Returns None.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  const bool is_dense = self->tensor.is_dense_tensor();
  PADDLE_ENFORCE_EQ(is_dense,
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  const bool on_cpu = platform::is_cpu_place(self->tensor.place());
  PADDLE_ENFORCE_EQ(on_cpu,
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));
  const int device_id = static_cast<int>(
      pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0));
  tensor_uva(static_cast<phi::DenseTensor*>(self->tensor.impl().get()),
             device_id);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
2867 2868 2869 2870 2871 2872 2873 2874 2875 2876 2877 2878
// Python binding for `StringTensor._is_string_tensor_hold_allocation()`.
// True only when `self` wraps a phi::StringTensor that reports initialized();
// any other impl type yields False.
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto string_impl =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  const bool holds_allocation =
      string_impl != nullptr && string_impl->initialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
2879

2880
// Method table for core.eager.Tensor. Every entry is registered with
// METH_VARARGS | METH_KEYWORDS. The double cast through `void (*)(void)`
// silences function-pointer-type warnings when storing a
// (self, args, kwargs) function in a PyCFunction slot. The table is
// terminated by the all-null sentinel entry.
PyMethodDef variable_methods[] = {  // NOLINT
    // --- conversion / copy ---
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_numpy__doc__},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_dense_tensor_hold_allocation",
     (PyCFunction)(void (*)(void))
         tensor_method__is_dense_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_to",
     (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"copy_",
     (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clone",
     (PyCFunction)(void (*)(void))tensor_method_clone,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_clone__doc__},
    {"reconstruct_from_",
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
     METH_VARARGS | METH_KEYWORDS,
     tensor_reconstruct_from___doc__},
    // --- gradient / type queries ---
    {"retain_grads",
     (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_retain_grads__doc__},
    {"clear_gradient",
     (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS,
     tensor_clear_gradient__doc__},
    {"is_dense",
     (PyCFunction)(void (*)(void))tensor_method_is_dense,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_is_dense__doc__},
    {"is_dist",
     (PyCFunction)(void (*)(void))tensor_method_is_dist,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_is_dist__doc__},
    {"_zero_grads",
     (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // --- buffer sharing ---
    {"_share_buffer_to",
     (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_buffer_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_underline_tensor_to",
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_underline_tensor_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"detach",
     (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_detach__doc__},
    {"detach_",
     (PyCFunction)(void (*)(void))tensor_method_detach_,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_detach___doc__},
    {"get_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_get_tensor__doc__},
    {"get_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_tensor_from_selected_rows",
     (PyCFunction)(void (*)(void))
         tensor_method__get_tensor_from_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // --- indexing ---
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_from_offset",
     (PyCFunction)(void (*)(void))tensor__getitem_from_offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__setitem_eager_tensor__",
     (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // --- hooks ---
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_inplace_assign",  // NOTE(xiongkun03): only used in sot.
     (PyCFunction)(void (*)(void))tensor_inplace_assign,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_remove_grad_hook",
     (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method__register_reduce_hook__doc__},
    {"_set_grad_type",
     (PyCFunction)(void (*)(void))tensor__set_grad_type,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_clear",
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_clear_dataptr",
     (PyCFunction)(void (*)(void))tensor__clear_dataptr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_gradient_from",
     (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_tensor_use_gpudnn",
     (PyCFunction)(void (*)(void))tensor__use_gpudnn,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    /** the methods to adapt old dygraph, will be removed in the future **/
    {"set_string_list",
     (PyCFunction)(void (*)(void))tensor_method_set_string_list,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_vocab",
     (PyCFunction)(void (*)(void))tensor_method_set_vocab,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_map_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    /***the method of sparse tensor****/
    {"nnz",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_nnz__doc__},
    {"indices",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_indices__doc__},
    {"values",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_values__doc__},
    {"crows",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_crows__doc__},
    {"cols",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_cols__doc__},
    {"is_sparse",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse,
     METH_VARARGS | METH_KEYWORDS,
     tensor_is_sparse__doc__},
    {"is_sparse_coo",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
     METH_VARARGS | METH_KEYWORDS,
     tensor_is_sparse_coo__doc__},
    {"is_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     tensor_is_sparse_csr__doc__},
    {"is_same_shape",
     (PyCFunction)(void (*)(void))tensor_method_is_same_shape,
     METH_VARARGS | METH_KEYWORDS,
     tensor_is_same_shape__doc__},
    {"to_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     tensor_to_sparse_csr__doc__},
    /***the method of sparse tensor****/
    // --- inplace versioning / misc ---
    {"element_size",
     (PyCFunction)(void (*)(void))tensor_method_element_size,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method_element_size__doc__},
    {"_inplace_version",
     (PyCFunction)(void (*)(void))tensor__inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_bump_inplace_version",
     (PyCFunction)(void (*)(void))tensor__bump_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     tensor_method__bump_inplace_version__doc__},
    {"is_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"rows",
     (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_memory",
     (PyCFunction)(void (*)(void))tensor_method__share_memory,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_offset",
     (PyCFunction)(void (*)(void))tensor__offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_name",
     (PyCFunction)(void (*)(void))tensor__grad_name,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_value",
     (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_unset_fake_empty",
     (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"data_ptr",
     (PyCFunction)(void (*)(void))tensor_data_ptr,
     METH_VARARGS | METH_KEYWORDS,
     tensor_data_ptr__doc__},
    {"_grad_ivar",
     (PyCFunction)(void (*)(void))tensor__grad_ivar,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // --- memory layout ---
    {"contiguous",
     (PyCFunction)(void (*)(void))tensor_contiguous,
     METH_VARARGS | METH_KEYWORDS,
     tensor_contiguous__doc__},
    {"is_contiguous",
     (PyCFunction)(void (*)(void))tensor_is_contiguous,
     METH_VARARGS | METH_KEYWORDS,
     tensor_is_contiguous__doc__},
    {"get_strides",
     (PyCFunction)(void (*)(void))tensor_method_strides,
     METH_VARARGS | METH_KEYWORDS,
     tensor_get_strides__doc__},
#if defined(PADDLE_WITH_CUDA)
    {"_tensor_uva",
     (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#endif
    // Sentinel: marks the end of the table for CPython.
    {nullptr, nullptr, 0, nullptr}};
3140

J
Jack Zhou 已提交
3141
// variable_methods for core.eager.StringTensor
3142
// Method table for core.eager.StringTensor. All entries use
// METH_VARARGS | METH_KEYWORDS; the table ends with the all-null sentinel.
PyMethodDef string_tensor_variable_methods[] = {  // NOLINT
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_string_tensor_hold_allocation",
     (PyCFunction)(void (*)(void))
         tensor_method__is_string_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    {nullptr, nullptr, 0, nullptr}};
J
Jack Zhou 已提交
3158

3159 3160
}  // namespace pybind
}  // namespace paddle