eager_method.cc 78.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
12 13 14 15 16 17

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

18 19 20
#include <Python.h>

#include <string>
21
#include <unordered_map>
22 23
#include <vector>

24
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
25
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
26
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
27
#include "paddle/fluid/eager/autograd_meta.h"
28 29
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
30
#include "paddle/fluid/eager/utils.h"
31
#include "paddle/fluid/framework/convert_utils.h"
32 33 34 35 36 37
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
38
#include "paddle/fluid/pybind/slice_utils.h"
39
#include "paddle/fluid/pybind/uva_utils.h"
40 41 42 43
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
44 45
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
W
wanghuancoder 已提交
46
#include "pybind11/detail/internals.h"
47 48
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
W
wanghuancoder 已提交
49
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
J
Jiabin Yang 已提交
50
#include "paddle/fluid/eager/amp_utils.h"
51
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
J
Jiabin Yang 已提交
52
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
W
wanghuancoder 已提交
53
#include "paddle/fluid/framework/python_headers.h"
W
wanghuancoder 已提交
54
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
W
wanghuancoder 已提交
55
#include "paddle/fluid/pybind/tensor_py.h"
W
wanghuancoder 已提交
56
#include "paddle/phi/core/ddim.h"
57
#include "paddle/phi/kernels/funcs/math_function.h"
J
Jiabin Yang 已提交
58

59 60 61
namespace paddle {
namespace pybind {

62 63
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
64
                                     const paddle::platform::Place& place,
65
                                     bool zero_copy);
66

67
extern PyTypeObject* p_tensor_type;
68

J
Jiabin Yang 已提交
69 70 71 72 73
// Converts a Python object used as a slice index into a Py_ssize_t.
//
// Only eager Tensor objects (instances of p_tensor_type) are accepted. The
// tensor must be initialized; its value is read via GetSliceIndexFromTensor.
// Raises InvalidArgument for an uninitialized tensor or any other object type.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
    VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
    // Convert the Python object once and reuse it for both the validity check
    // and the value read.
    paddle::experimental::Tensor tensor = CastPyArg2Tensor(obj, 0);
    PADDLE_ENFORCE_EQ(
        tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in slice, however we got "
            "uninitialized tensor %s, please check your code.",
            tensor.name()));
    return GetSliceIndexFromTensor(
        *static_cast<phi::DenseTensor*>(tensor.impl().get()));
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::experimental::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }
}

// Reports whether `obj` is an instance of the eager Tensor Python type.
// Any non-zero result of PyObject_IsInstance is treated as true.
bool PyCheckTensor(PyObject* obj) {
  PyObject* tensor_type = reinterpret_cast<PyObject*>(p_tensor_type);
  const int instance_check = PyObject_IsInstance(obj, tensor_type);
  return instance_check != 0;
}

93 94
// Implements Tensor.numpy(): builds a new numpy array and deep-copies the
// tensor's data into it.
//
// Behavior by tensor state:
//  - no impl: returns an empty 1-D float array.
//  - impl present but !initialized(): returns an array with the tensor's shape
//    and dtype without copying any data; a rank-0 tensor yields an empty 1-D
//    array instead.
//  - CPU / GPU-pinned: copied with paddle::memory::Copy.
//  - GPU (CUDA/HIP builds) and custom device (PADDLE_WITH_CUSTOM_DEVICE
//    builds): copied device-to-host.
//  - any other place: raises InvalidArgument.
// For SelectedRows tensors the data of the underlying value tensor is copied.
// `args` and `kwargs` are unused.
static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
  auto sizeof_dtype = paddle::framework::DataTypeSize(self->tensor.type());
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
  // Contiguous row-major strides, computed from the innermost dimension out.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    py_dims[i] = static_cast<size_t>(tensor_dims[i]);
    py_strides[i] = sizeof_dtype * numel;
    numel *= py_dims[i];
  }

  const bool is_initialized = self->tensor.impl()->initialized();
  if (!is_initialized && tensor_dims.size() == 0) {
    // Handle this case before allocating the full-shape array so that no
    // array is created only to be dropped without a decref.
    py_dims[0] = 0;
    py_strides[0] = 0;
    PyObject* empty_array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return empty_array;
  }

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      tensor_dims.size(),
      py_dims,
      py_strides,
      nullptr,
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);

  if (!is_initialized) {
    // Nothing to copy: return the array with the tensor's shape as-is.
    return array;
  }

  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor = static_cast<paddle::framework::LoDTensor*>(
          selected_rows->mutable_value());

      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (self->tensor.is_gpu()) {
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor = static_cast<paddle::framework::LoDTensor*>(
          selected_rows->mutable_value());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          paddle::framework::DataTypeSize(dense_tensor->dtype()) *
              dense_tensor->numel(),
          kind);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          paddle::framework::DataTypeSize(dense_tensor->dtype()) *
              dense_tensor->numel(),
          kind);
    }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor = static_cast<paddle::framework::LoDTensor*>(
          selected_rows->mutable_value());
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              paddle::framework::DataTypeSize(dense_tensor->dtype()) *
                  dense_tensor->numel());
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              paddle::framework::DataTypeSize(dense_tensor->dtype()) *
                  dense_tensor->numel());
    }
#endif
  } else {
    // Drop our reference to the freshly created array before raising so it is
    // not leaked when the exception propagates.
    Py_XDECREF(array);
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
  }

  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
254 255 256 257 258 259 260 261 262 263 264 265 266 267 268
// Implements numpy() for StringTensor: returns a numpy unicode ("U<N>") array
// holding the tensor's strings, where N is the longest string's unicode
// length (at least 1).
//
// A tensor without impl, or whose impl is not initialized, yields an empty
// 1-D unicode array. Only CPU tensors are supported; any other place raises
// InvalidArgument. `args` and `kwargs` are unused.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();
    // Get the max unicode length of StringTensor to create numpy unicode string
    // array. A plain loop is used so that an empty tensor (numel == 0) never
    // dereferences a past-the-end element.
    size_t max_unicode_length = 0;
    for (int64_t i = 0; i < numel; ++i) {
      const size_t unicode_len =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      if (unicode_len > max_unicode_length) {
        max_unicode_length = unicode_len;
      }
    }
    // The array element width is never allowed to be 0.
    max_unicode_length = (max_unicode_length == 0) ? 1 : max_unicode_length;
    VLOG(6) << "The max unicode length is " << max_unicode_length;
    // Zero-filled staging buffer: one fixed-width slot of UCS-4 code points
    // per string.
    auto sp = std::make_unique<uint32_t[]>(max_unicode_length * numel);
    auto py_array_data = sp.get();
    memset(py_array_data, 0, max_unicode_length * numel * sizeof(uint32_t));
    for (int64_t i = 0; i < numel; ++i) {
      auto curr_unicode_len =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  curr_unicode_len);
    }
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

324 325 326 327
// Returns, as a Python object, whether the wrapped tensor reports itself as
// initialized. `args` and `kwargs` are unused.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_ready = self->tensor.initialized();
  return ToPyObject(is_ready);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
332 333 334 335 336 337 338 339 340 341 342 343 344 345
// Returns True when the tensor's impl is a phi::DenseTensor whose
// IsInitialized() is true; returns False when the impl is not a DenseTensor.
// `args` and `kwargs` are unused.
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto as_dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  if (!as_dense) {
    return ToPyObject(false);
  }
  return ToPyObject(as_dense->IsInitialized());

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363
// Registers a callback with the current tracer's garbage collector that
// captures `tensor` by value. The collector is looked up for `place` when it
// is a GPU place, otherwise for the tensor's own place.
static void IncreaseTensorReferenceCountUntilCopyComplete(
    const paddle::experimental::Tensor& tensor, const platform::Place& place) {
  auto gc_place = platform::is_gpu_place(place) ? place : tensor.place();

  auto current_tracer = egr::Controller::Instance().GetCurrentTracer();
  auto collector = current_tracer->MutableGarbageCollectorIfNotExists(gc_place);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
  // CUDAPinned Mem -> CUDA by cudamemcpyAsync.
  auto keep_alive = [tensor, gc_place]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << gc_place;
  };
  collector->DirectClearCallback(keep_alive);
}

364 365
// Implements Tensor._copy_to(place, blocking): copies the tensor to `place`
// and returns the copy as a new Python tensor.
//
// args[0] is the destination place and args[1] the blocking flag. For a
// non-blocking copy, a callback referencing the source tensor is registered
// via IncreaseTensorReferenceCountUntilCopyComplete. The copy has
// stop_gradient set to true and inherits the source's persistable flag.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  auto dst_place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  const bool is_blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);

  auto copied = self->tensor.copy_to(dst_place, is_blocking);
  if (!is_blocking) {
    IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, dst_place);
  }

  egr::EagerUtils::autograd_meta(&copied)->SetStopGradient(true);
  const auto src_persistable =
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
  egr::EagerUtils::autograd_meta(&copied)->SetPersistable(src_persistable);
  return ToPyObject(copied);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

382 383
// Implements Tensor.cpu(): returns a blocking copy of the tensor on
// phi::CPUPlace. The copy has stop_gradient set to true and inherits the
// source's persistable flag. `args` and `kwargs` are unused.
static PyObject* tensor_method_cpu(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  auto host_copy = self->tensor.copy_to(phi::CPUPlace(), true);

  egr::EagerUtils::autograd_meta(&host_copy)->SetStopGradient(true);
  const auto src_persistable =
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
  egr::EagerUtils::autograd_meta(&host_copy)->SetPersistable(src_persistable);
  return ToPyObject(host_copy);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

395 396 397 398
// Replaces this tensor with the tensor passed as args[0] while keeping this
// tensor's original name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor source =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  // Copy the name out first: the assignment below overwrites it.
  const std::string kept_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from" << source.name() << " to "
          << kept_name;

  self->tensor = source;
  // Recover source name
  self->tensor.set_name(kept_name);

  VLOG(6) << "Finished Reconstructing Tensor from" << source.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

416 417
// Implements Tensor.copy_(src, blocking): copies the tensor in args[0] into
// this tensor; args[1] is the blocking flag. Returns None.
//
// If this tensor is not initialized, it first takes over the source's
// stop_gradient and persistable flags and the copy targets the source's
// place; otherwise the copy targets this tensor's current place. Nothing is
// copied when the source is not initialized.
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor source =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool is_blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << source.name() << " to "
          << self->tensor.name();

  const bool self_ready = self->tensor.initialized();
  if (!self_ready) {
    const auto src_stop_gradient =
        egr::EagerUtils::autograd_meta(&source)->StopGradient();
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetStopGradient(src_stop_gradient);
    const auto src_persistable =
        egr::EagerUtils::autograd_meta(&source)->Persistable();
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetPersistable(src_persistable);
  }

  if (source.initialized()) {
    if (self_ready) {
      self->tensor.copy_(source, self->tensor.place(), is_blocking);
    } else {
      self->tensor.copy_(source, source.place(), is_blocking);
    }
  }

  VLOG(6) << "Finish Copy Tensor " << source.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

448 449
// Implements Tensor.retain_grads(). When gradient recording is enabled on the
// controller, gives the tensor a GradNodeAccumulation if it has no grad node
// yet, then calls RetainGradForTensor on it. Does nothing otherwise.
// Returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (!egr::Controller::Instance().HasGrad()) {
    RETURN_PY_NONE
  }

  auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  const bool has_grad_node = static_cast<bool>(autograd_meta->GetMutableGradNode());
  if (!has_grad_node) {
    VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
            << "become accumulation node";
    autograd_meta->SetGradNode(
        std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
  }
  egr::egr_utils_api::RetainGradForTensor(self->tensor);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

466 467
// Implements Tensor.clear_gradient([set_to_zero]). Returns None.
//
// `set_to_zero` defaults to true and is read from args[0] when exactly one
// positional argument is given.
//  - SelectedRows gradient: rows and value are cleared if the value is
//    initialized (regardless of set_to_zero).
//  - Initialized dense gradient with set_to_zero: filled with 0; for a leaf
//    tensor the accumulation node is additionally marked via
//    SetFakeEmpty(true).
//  - Initialized dense gradient without set_to_zero: its memory holder is
//    moved out (MoveMemoryHolder).
// Raises Fatal if a leaf tensor has no gradient slot.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == static_cast<Py_ssize_t>(1)) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::experimental::Tensor* grad = nullptr;
  bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

528 529
// Implements Tensor._zero_grads(): overwrites the tensor's gradient with
// zeros if the gradient is initialized. Returns None.
//
// The gradient is obtained through mutable_grad for leaf tensors (Fatal if
// that yields null) and through the autograd meta otherwise. A dense gradient
// is zero-filled in place; any other gradient type has its impl replaced by
// the result of zeros_like.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  // Locate the gradient tensor; only the lookup differs between leaf and
  // non-leaf tensors, the zeroing below is shared.
  paddle::experimental::Tensor* grad = nullptr;
  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->initialized()) {
    if (grad->is_dense_tensor()) {
      auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
      auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
      phi::funcs::set_constant(*dev_ctx, t, 0.0);
    } else {
      grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

572 573
// Implements Tensor._share_buffer_to(dst): makes the tensor object in args[0]
// share this tensor's buffer and data type. Returns None.
//
// This tensor must be initialized (InvalidArgument otherwise). If the
// destination is not defined, it is first given a fresh phi::DenseTensor
// impl. args[0] is reinterpreted as a TensorObject without a type check.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  PyObject* dst_obj = PyTuple_GET_ITEM(args, 0);
  paddle::experimental::Tensor* dst =
      &(reinterpret_cast<TensorObject*>(dst_obj)->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  auto* src_dense =
      static_cast<paddle::framework::Tensor*>(self->tensor.impl().get());
  if (!dst->defined()) {
    dst->set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* dst_dense = static_cast<paddle::framework::Tensor*>(dst->impl().get());

  dst_dense->ShareBufferWith(*src_dense);
  dst_dense->ShareDataTypeWith(*src_dense);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

598 599 600 601
// Implements Tensor._is_shared_buffer_with(dst): returns whether the tensor
// object in args[0] shares its buffer with this tensor.
//
// This tensor must be initialized (InvalidArgument otherwise). Returns False
// when either tensor is not defined. args[0] is reinterpreted as a
// TensorObject without a type check.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  PyObject* other_obj = PyTuple_GET_ITEM(args, 0);
  paddle::experimental::Tensor* other =
      &(reinterpret_cast<TensorObject*>(other_obj)->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  const bool both_defined = self->tensor.defined() && other->defined();
  if (!both_defined) {
    return ToPyObject(false);
  }

  auto* self_dense =
      static_cast<paddle::framework::Tensor*>(self->tensor.impl().get());
  auto* other_dense =
      static_cast<paddle::framework::Tensor*>(other->impl().get());
  const bool shared = other_dense->IsSharedBufferWith(*self_dense);
  return ToPyObject(shared);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

623 624 625 626
// Implements Tensor._share_underline_tensor_to(dst): points the tensor object
// in args[0] at this tensor's impl. Returns None.
//
// This tensor must be initialized (InvalidArgument otherwise). args[0] is
// reinterpreted as a TensorObject without a type check.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  PyObject* target_obj = PyTuple_GET_ITEM(args, 0);
  paddle::experimental::Tensor* target =
      &(reinterpret_cast<TensorObject*>(target_obj)->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));

  target->set_impl(self->tensor.impl());
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

641 642 643 644
// Implements Tensor._is_shared_underline_tensor_with(other): returns whether
// this tensor and the tensor in args[0] point at the same impl object.
//
// The argument tensor must be initialized (InvalidArgument otherwise).
// Returns False when either tensor is not defined.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor other =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(other.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        other.name()));

  const bool both_defined = self->tensor.defined() && other.defined();
  if (!both_defined) {
    return ToPyObject(false);
  }

  const bool same_impl = self->tensor.impl().get() == other.impl().get();
  return ToPyObject(same_impl);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

662 663
// Implements Tensor.detach(): allocates a new Python tensor object that
// shares this tensor's impl, has a freshly generated unique name, and copies
// this tensor's persistable flag.
//
// Raises InvalidArgument if this tensor is not initialized and Fatal if the
// Python object cannot be allocated. `args` and `kwargs` are unused.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.initialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* detached_obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (detached_obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(detached_obj);
  // tp_alloc only provides raw storage; construct the Tensor member in place.
  new (&(detached->tensor)) paddle::experimental::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());

  auto* src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto* dst_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  dst_meta->SetPersistable(src_meta->Persistable());

  return detached_obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

690 691 692 693
// Implements Tensor.get_tensor(): returns the underlying dense tensor as a
// Python object, or None when the impl is not a dense tensor. An undefined
// tensor yields an empty DenseTensor. `args` and `kwargs` are unused.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // The original `get_tensor` method of Variable will create a empty tensor
    // NOTE(review): this hands ToPyObject the address of a stack local. If
    // that overload keeps a reference instead of copying, the returned Python
    // object dangles once this function returns — confirm against
    // ToPyObject's implementation.
    phi::DenseTensor empty_tensor;
    return ToPyObject(&empty_tensor);
  }

  if (!self->tensor.is_dense_tensor()) {
    RETURN_PY_NONE
  }

  auto* dense =
      static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get());
  VLOG(6) << "tensor: " << dense->IsInitialized();
  return ToPyObject(dense);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

710 711 712 713 714
// Returns the underlying phi::SelectedRows as a Python object, or None when
// the tensor is undefined or its impl is not a SelectedRows. `args` and
// `kwargs` are unused.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_selected_rows() from being queried on an undefined
  // tensor.
  if (!self->tensor.defined() || !self->tensor.is_selected_rows()) {
    RETURN_PY_NONE
  }

  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  return ToPyObject(rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753
// Returns a new, uniquely named Python tensor whose impl is a DenseTensor
// copy-constructed from the value tensor of this SelectedRows.
//
// Raises Fatal if this tensor is not a SelectedRows or if the SelectedRows is
// not initialized. `args` and `kwargs` are unused.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));

  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  PADDLE_ENFORCE(
      rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* value =
      static_cast<paddle::framework::LoDTensor*>(rows->mutable_value());
  VLOG(1) << "dense_tensor: " << value->IsInitialized();

  paddle::experimental::Tensor result(
      egr::Controller::Instance().GenerateUniqueName());
  result.set_impl(std::make_shared<phi::DenseTensor>(*value));

  return ToPyObject(result);

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
754 755 756
// Implements Tensor indexing for an index object (args[0]) that is not a
// tensor.
//
// The index is decomposed by ParseIndexingSlice into slice parameters,
// axes to decrease, axes to insert (None) and, for a list index, the
// selected row ids. The result is then built from:
//   1. slice / strided_slice when any axis is sliced,
//   2. unsqueeze when None axes remain (returns immediately),
//   3. index_select along axis 0 when the index is a list.
// Raises InvalidArgument when the tensor is not defined.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags, list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  // When nothing is sliced and the index is not a list, the result starts
  // out as the tensor itself; otherwise as a fresh, uniquely named tensor.
  auto out = slice_axes.empty() && !list_select_flag
                 ? self->tensor
                 : paddle::experimental::Tensor(
                       egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // A single non-unit stride is enough to require strided_slice.
    bool has_non_unit_stride = false;
    for (int stride : slice_strides) {
      if (stride != 1) {
        has_non_unit_stride = true;
        break;
      }
    }

    if (has_non_unit_stride) {
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
    } else {
      // slice_ad_func is called with 64-bit axes / flags, so widen the
      // parsed int vectors first.
      std::vector<int64_t> slice_axes_tmp(slice_axes.begin(),
                                          slice_axes.end());
      std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                           infer_flags.end());
      std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                             decrease_axis.end());
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    }
  }

  if (!none_axes.empty()) {
    // Deal with cases when all axes are decreased.
    // After slice, the shape of out is [1], which should have been
    // [], but Paddle doesn't support scalar.
    // In order to ensure the correctness of the final shape of out,
    // one dimension of out needs to be decreased.
    // For example:
    // # x.shape: (2,3,4)
    // out = x[0, 1, 1, None] # out.shape : (1)
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      none_axes.pop_back();
    }
    if (!none_axes.empty()) {
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        // Shift each None axis left by the number of decreased axes that
        // precede it.
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }

      paddle::experimental::Tensor new_out;
      new_out = unsqueeze_ad_func(out, none_axes);
      return ToPyObject(new_out);
    }
  }

  // the index is a list
  if (list_select_flag) {
    auto select_index = paddle::experimental::Tensor(
        egr::Controller::Instance().GenerateUniqueName());
    auto idx_tensor = std::make_shared<phi::DenseTensor>();
    select_index.set_impl(idx_tensor);
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
        egr::Controller::Instance().GetExpectedPlace());
    paddle::framework::TensorFromVector(
        list_select_idxs, *dev_ctx, idx_tensor.get());
    out = index_select_ad_func(self->tensor, select_index, 0);
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

882 883
// Returns one element of a DenseTensor as a one-element numpy array.
//
// The element is chosen by the positional arguments in `args`:
//   - no argument:   the tensor must contain exactly one element;
//   - one argument:  an offset into the tensor's elements, checked against
//                    the element count;
//   - otherwise:     one index per tensor dimension, each checked against
//                    the size of its axis.
// Raises InvalidArgument for a null or uninitialized tensor and for
// out-of-range or wrongly counted indices, and Unimplemented when the
// tensor dtype is not in the table below.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> strides(tensor_dims.size());

  // Walk the dimensions from last to first so that strides[i] is the
  // product of all later dimension sizes (in elements); numel ends up as
  // the product of all dimension sizes.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    strides[i] = numel;
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  // The argument tuple is not modified here, so query its size only once.
  const Py_ssize_t num_args = PyTuple_Size(args);
  if (num_args == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (num_args == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(num_args,
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < num_args; ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * strides[i];
    }
  }
// Table of (C++ element type, DataType tag) pairs that can be converted.
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// For a matching dtype: read the element at `offset`, allocate a 1-D numpy
// array of length 1 with the corresponding numpy dtype, copy the element
// into it and return the array from the enclosing function.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];                  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];               \
    py_dims[0] = 1;                                                          \
    py_strides[0] = 1;                                                       \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        1,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  // Reached only when no entry of the table matched the tensor dtype.
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035
// Implements item assignment on an eager tensor: args[0] is the index and
// args[1] the value.
//
// When every component of the index is an integer, a slice, Ellipsis or
// None, the assignment is carried out by the in-place set_value operator;
// the value may then be a Tensor, a numpy array or a Python
// float/int/bool scalar. For any other index the tensor is converted to a
// numpy array, assigned through numpy indexing and written back.
// Returns None. Raises InvalidArgument for unsupported value types or
// dtypes, and when an in-place write targets a leaf tensor that requires
// gradient while gradient recording is enabled.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  // The wrapper tuple is only created (and therefore only released) when
  // the index was not a tuple already.
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed.
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags, list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::egr_utils_api::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::experimental::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      paddle::experimental::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      // Cast the numpy array to the element type of this tensor when it
      // does not have that type already.
      if (self->tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == paddle::experimental::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, int32 or int64, "
            "please check the type of tensor."));
      }

      if (!value_tensor_tmp.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
        SetTensorFromPyArray(
            static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
            value,
            platform::Place(platform::CUDAPlace(0)),
            false);
#else
        SetTensorFromPyArray(
            static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
            value,
            platform::Place(platform::CPUPlace()),
            false);
#endif
      } else {
        SetTensorFromPyArray(
            static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
            value,
            value_tensor_tmp.place(),
            false);
      }

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp)) {
        if (self->tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
          attrs["fp32_values"] =
              std::vector<float>{value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::FLOAT64) {
          attrs["fp64_values"] =
              std::vector<double>{value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::INT32) {
          attrs["int32_values"] =
              std::vector<int32_t>{value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::INT64) {
          attrs["int64_values"] =
              std::vector<int64_t>{value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::BOOL) {
          attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, int32 or int64, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type name: Py_TYPE() yields a PyTypeObject*,
        // which is not a string and must not be formatted with %s itself.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }

    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        // Both operands are cast to the dtype chosen by the AMP helper for
        // set_value before the operator is invoked.
        paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    // Fallback: perform the assignment on a numpy view of the data and
    // write the whole array back into the tensor.
    auto self_numpy = TensorToPyArray(*self_tensor);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1234 1235
// Registers the Python callable args[0] as a gradient hook of this tensor
// and returns the hook id as a Python object.
//
// For a leaf tensor that does not stop gradient, a GradNodeAccumulation is
// created first when the tensor has no grad node yet; the hook is then
// registered on that accumulation node. For a non-leaf tensor the hook is
// registered on the tensor's existing grad node.
// Raises a Fatal error instead of dereferencing a null pointer when the
// autograd meta or the required grad node is missing.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id;
  PyObject* hook_func = PyTuple_GET_ITEM(args, 0);
  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    // autograd_meta is treated as nullable above, so it must be checked
    // before it is dereferenced for the rank info.
    PADDLE_ENFORCE_NOT_NULL(
        autograd_meta,
        platform::errors::Fatal(
            "Detected NULL autograd_meta, cannot register hook for tensor "
            "%s.",
            self->tensor.name()));
    auto rank_info = autograd_meta->OutRankInfo();

    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // The cast yields null when there is no grad node (e.g. none was
    // created above) or when it is not a GradNodeAccumulation.
    PADDLE_ENFORCE(
        accumulation_grad_node.get() != nullptr,
        paddle::platform::errors::Fatal(
            "Detected NULL grad_node, Leaf tensor should have had grad_node "
            "with type: GradNodeAccumulation."));
    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    PADDLE_ENFORCE_NOT_NULL(
        autograd_meta,
        platform::errors::Fatal(
            "Detected NULL autograd_meta, cannot register hook for tensor "
            "%s.",
            self->tensor.name()));
    auto rank_info = autograd_meta->OutRankInfo();

    PADDLE_ENFORCE(
        grad_node.get() != nullptr,
        paddle::platform::errors::Fatal(
            "Detected NULL grad_node, non-leaf tensor should have had "
            "grad_node."));
    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1284 1285
// Removes the gradient hook whose id is given in args[0] from this
// tensor's grad node and returns the node's RemoveGradientHook result as a
// Python object.
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);

  PyObject* id_arg = PyTuple_GET_ITEM(args, 0);
  int64_t hook_id = pybind::CastPyArg2AttrLong(id_arg, 0);

  const auto removed = node->RemoveGradientHook(hook_id);
  return ToPyObject(removed);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1298 1299
// Registers the Python callable args[0] as a reduce hook on the
// GradNodeAccumulation of this tensor. Returns None.
// Raises InvalidArgument when the tensor is not a leaf or stops gradient,
// and a Fatal error when the tensor has no grad node.
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);

  // Preconditions are checked one after another; the first one that fails
  // decides which error is raised.
  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  PADDLE_ENFORCE_EQ(is_leaf,
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));

  const bool requires_grad =
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient();
  PADDLE_ENFORCE_EQ(
      requires_grad,
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));

  PADDLE_ENFORCE(
      node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  PyObject* callback = PyTuple_GET_ITEM(args, 0);
  auto accumulation_node =
      std::dynamic_pointer_cast<egr::GradNodeAccumulation>(node);
  accumulation_node->RegisterReduceHook(
      std::make_shared<PyVoidHook>(callback));

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1333 1334
// Replaces the implementation of this tensor's gradient with an empty
// object of the kind selected by args[0]: LOD_TENSOR gives a
// phi::DenseTensor, SELECTED_ROWS a phi::SelectedRows. Any other variable
// type leaves the gradient untouched. Returns None.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();
  switch (requested_type) {
    case framework::proto::VarType::LOD_TENSOR:
      grad->set_impl(std::make_shared<phi::DenseTensor>());
      break;
    case framework::proto::VarType::SELECTED_ROWS:
      grad->set_impl(std::make_shared<phi::SelectedRows>());
      break;
    default:
      // Other variable types are ignored.
      break;
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1350 1351
// Calls reset() on the tensor wrapped by this Python object and
// returns None.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.reset();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1360 1361
// Makes the gradient of this tensor share the implementation of the tensor
// passed as args[0]. Returns None.
//
// When this tensor is initialized, the source must be defined and agree
// with it in data type and in implementation type. When this tensor has a
// mutable gradient, the source must be initialized.
// Raises InvalidArgument for an undefined/uninitialized source and
// PreconditionNotMet for a data type or implementation type mismatch.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // The checks below dereference src.impl(), so reject an undefined
    // source first instead of dereferencing a null pointer.
    PADDLE_ENFORCE_EQ(src.defined(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  // Nothing is copied when this tensor has no mutable gradient.
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1393

1394 1395
// Stores the vocabulary passed as args[0] (a map from wide string to int)
// in a new egr::VariableCompatTensor and makes it the implementation of
// this tensor. Returns None.
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  using Vocab = std::unordered_map<std::wstring, int>;
  PyObject* vocab_arg = PyTuple_GET_ITEM(args, 0);
  auto parsed_vocab = CastPyArg2Vocab(vocab_arg, 0);

  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<Vocab>();
  *slot = parsed_vocab;

  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Stores the list of strings passed as args[0] in a new
// egr::VariableCompatTensor and makes it the implementation of this
// tensor. Returns None.
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  using Strings = std::vector<std::string>;
  PyObject* strings_arg = PyTuple_GET_ITEM(args, 0);
  auto parsed_strings = CastPyArg2Strings(strings_arg, 0);

  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<Strings>();
  *slot = parsed_strings;

  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Returns the vocabulary (map from wide string to int) held by this
// tensor's VariableCompatTensor as a Python object.
// Raises a Fatal error when the tensor is not a VariableCompatTensor.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  const bool is_compat_tensor = egr::IsVariableCompatTensor(self->tensor);
  PADDLE_ENFORCE_EQ(
      is_compat_tensor,
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));

  using Vocab = std::unordered_map<std::wstring, int>;
  const auto* holder =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  return ToPyObject(holder->Get<Vocab>());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456
// Python binding ("nnz"): returns nnz() of the underlying SparseCooTensor or
// SparseCsrTensor.
//
// Fails the enforce below for any other tensor kind. args and kwargs are not
// used.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  // COO is handled first; anything that reaches the tail is CSR because of
  // the enforce above.
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    return ToPyObject(coo->nnz());
  }
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  return ToPyObject(csr->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526
// Python binding ("indices"): wraps a DenseTensor constructed from the
// SparseCooTensor's non_zero_indices() in a new Tensor and returns it.
//
// Fails the enforce below unless the tensor is a SparseCooTensor. args and
// kwargs are not used.
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_impl =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::experimental::Tensor indices(indices_impl);
  return ToPyObject(indices);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding ("values"): wraps a DenseTensor constructed from the sparse
// tensor's non_zero_elements() in a new Tensor and returns it.
//
// Supports SparseCooTensor and SparseCsrTensor; fails the enforce below for
// any other tensor kind. args and kwargs are not used.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  // Build the dense holder per sparse format, then wrap it once.
  std::shared_ptr<phi::DenseTensor> values_impl;
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    values_impl = std::make_shared<phi::DenseTensor>(coo->non_zero_elements());
  } else {
    auto csr =
        std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
    values_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_elements());
  }
  paddle::experimental::Tensor values(values_impl);
  return ToPyObject(values);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding ("crows"): wraps a DenseTensor constructed from the
// SparseCsrTensor's non_zero_crows() in a new Tensor and returns it.
//
// Fails the enforce below unless the tensor is a SparseCsrTensor. args and
// kwargs are not used.
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  paddle::experimental::Tensor crows(crows_impl);
  return ToPyObject(crows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python binding ("cols"): wraps a DenseTensor constructed from the
// SparseCsrTensor's non_zero_cols() in a new Tensor and returns it.
//
// Fails the enforce below unless the tensor is a SparseCsrTensor. args and
// kwargs are not used.
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto cols_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_cols());
  paddle::experimental::Tensor cols(cols_impl);
  return ToPyObject(cols);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1527 1528
// Python binding ("is_dense"): returns True when the tensor is defined and
// is_dense_tensor() holds; an undefined tensor yields False.
// args and kwargs are not used.
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  const auto& tensor = self->tensor;
  // Short-circuit keeps the kind query from running on an undefined tensor.
  return ToPyObject(tensor.defined() && tensor.is_dense_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1538 1539
// Python binding ("is_sparse"): returns True when the tensor is defined and is
// either a sparse COO or a sparse CSR tensor; an undefined tensor yields
// False. args and kwargs are not used.
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const auto& tensor = self->tensor;
  // Short-circuit keeps the kind queries from running on an undefined tensor.
  return ToPyObject(
      tensor.defined() &&
      (tensor.is_sparse_coo_tensor() || tensor.is_sparse_csr_tensor()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1550 1551
// Python binding ("is_sparse_coo"): returns True when the tensor is defined
// and is_sparse_coo_tensor() holds; an undefined tensor yields False.
// args and kwargs are not used.
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const auto& tensor = self->tensor;
  // Short-circuit keeps the kind query from running on an undefined tensor.
  return ToPyObject(tensor.defined() && tensor.is_sparse_coo_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1561 1562
// Python binding ("is_sparse_csr"): returns True when the tensor is defined
// and is_sparse_csr_tensor() holds; an undefined tensor yields False.
// args and kwargs are not used.
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const auto& tensor = self->tensor;
  // Short-circuit keeps the kind query from running on an undefined tensor.
  return ToPyObject(tensor.defined() && tensor.is_sparse_csr_tensor());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1572 1573
// Python binding ("to_sparse_csr"): converts the tensor with to_sparse_csr()
// and copies the stop-gradient and persistable flags from this tensor's
// autograd meta onto the result before returning it.
// args and kwargs are not used.
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto csr_tensor = self->tensor.to_sparse_csr();
  auto* result_meta = egr::EagerUtils::autograd_meta(&csr_tensor);
  auto* source_meta = egr::EagerUtils::autograd_meta(&self->tensor);
  result_meta->SetStopGradient(source_meta->StopGradient());
  result_meta->SetPersistable(source_meta->Persistable());
  return ToPyObject(csr_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1587 1588 1589 1590 1591 1592 1593 1594 1595
// Python binding ("is_same_shape"): returns whether shape() of this tensor
// equals shape() of the tensor passed as args[0].
// kwargs is not used.
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto rhs = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool shapes_match = (self->tensor.shape() == rhs.shape());
  return ToPyObject(shapes_match);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1596 1597
// Python binding ("_inplace_version"): returns the tensor's
// current_inplace_version() as an unsigned 32-bit value.
// args and kwargs are not used.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  // Keep the explicit uint32_t so the same ToPyObject overload is selected.
  const uint32_t current_version = self->tensor.current_inplace_version();
  return ToPyObject(current_version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1606 1607
// Python binding ("element_size"): returns framework::DataTypeSize() of the
// tensor's dtype, converted to an unsigned 32-bit value.
// args and kwargs are not used.
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const auto dtype = self->tensor.dtype();
  // Keep the explicit uint32_t so the same ToPyObject overload is selected.
  const uint32_t bytes_per_element = framework::DataTypeSize(dtype);
  return ToPyObject(bytes_per_element);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1616 1617 1618 1619 1620
// Python binding ("_bump_inplace_version"): calls bump_inplace_version() on
// the tensor and returns None.
// args and kwargs are not used.
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  auto& tensor = self->tensor;
  tensor.bump_inplace_version();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1625 1626 1627 1628
// Python binding ("is_selected_rows"): returns True when the tensor is defined
// and is_selected_rows() holds; an undefined tensor yields False.
// args and kwargs are not used.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  const auto& tensor = self->tensor;
  // Short-circuit keeps the kind query from running on an undefined tensor.
  return ToPyObject(tensor.defined() && tensor.is_selected_rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1636 1637
// Python binding ("rows"): returns rows() of the underlying
// phi::SelectedRows.
//
// Fails the enforce below unless the tensor is a SelectedRows. args and
// kwargs are not used.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  auto rows_holder =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  const auto& row_ids = rows_holder->rows();
  return ToPyObject(row_ids);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1649 1650
// Python binding: returns paddle::experimental::SizeOf() of the tensor's
// dtype. args and kwargs are not used.
//
// NOTE(review): the identifier reads "methon" (presumably a typo for
// "method"); it is left as-is because the method table refers to it by this
// name.
static PyObject* tensor_methon_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const auto dtype = self->tensor.dtype();
  const auto bytes_per_element = paddle::experimental::SizeOf(dtype);
  return ToPyObject(bytes_per_element);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672
// Python binding ("_reset_grad_inplace_version"): calls
// reset_inplace_version(set_to_zero) on this tensor's gradient.
//
// set_to_zero defaults to true and is read from args[0] only when exactly one
// positional argument is passed. The reset is skipped unless the gradient
// exists, is defined, is a dense tensor and is initialized. Returns None.
// kwargs is not used.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  const bool has_single_arg = PyTuple_Size(args) == static_cast<Py_ssize_t>(1);
  const bool set_to_zero =
      has_single_arg ? CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0)
                     : true;

  auto* grad = egr::EagerUtils::mutable_grad(self->tensor);
  const bool can_reset = grad && grad->defined() && grad->is_dense_tensor() &&
                         grad->initialized();
  if (can_reset) {
    grad->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1678 1679
// Python binding ("_share_memory"): copies the tensor's data into a newly
// allocated memory-map writer allocation, makes the DenseTensor use that
// allocation as its holder, and returns the DenseTensor.
//
// Requirements (each enforced below):
//   - the tensor lives on a CPU place;
//   - the tensor's implementation is a phi::DenseTensor.
// On Windows (_WIN32) the function always throws PermissionDenied.
// args and kwargs are not used.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get LoDTensor
  // Keep the shared_ptr alive for the whole function and reject tensors whose
  // implementation is not a DenseTensor; the cast yields nullptr for those
  // and the pointer is dereferenced right below.
  auto dense_holder =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  auto* t = dense_holder.get();
  PADDLE_ENFORCE_EQ(t != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support DenseTensor currently"));
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  // Not expected to be reached after the throw above; kept so this branch
  // still ends in a return statement.
  RETURN_PY_NONE

#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1717 1718
// Python binding ("_offset"): returns offset() of the underlying
// phi::DenseTensor.
//
// Throws InvalidArgument when the tensor's implementation is not a
// DenseTensor or when the DenseTensor is not initialized.
// args and kwargs are not used.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields nullptr when the implementation is not a DenseTensor;
  // report that instead of dereferencing a null pointer below.
  PADDLE_ENFORCE_EQ(
      t != nullptr,
      true,
      platform::errors::InvalidArgument(
          "Tensor %s does not hold a DenseTensor, _offset is only available "
          "for DenseTensor.",
          self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1732 1733
// Python binding ("_grad_name"): returns name() of this tensor's gradient.
//
// Throws InvalidArgument when mutable_grad() yields a null gradient.
// args and kwargs are not used.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  auto* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  const auto& grad_name = grad->name();
  return ToPyObject(grad_name);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1747 1748
// Python binding ("_grad_value"): returns the gradient's underlying
// LoDTensor.
//
// - Throws InvalidArgument when mutable_grad() yields a null gradient.
// - Returns None when the gradient is not defined.
// - Throws Fatal when the gradient is defined but is not a dense tensor.
// args and kwargs are not used.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  auto* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  if (!grad->defined()) {
    RETURN_PY_NONE
  }
  if (!grad->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
    RETURN_PY_NONE
  }
  // Dense gradient: hand the implementation pointer to ToPyObject.
  auto* dense_grad =
      static_cast<paddle::framework::LoDTensor*>(grad->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1774 1775
// Python binding ("_unset_fake_empty"): for a leaf tensor, calls
// SetFakeEmpty(false) on its GradNodeAccumulation; does nothing for a
// non-leaf tensor. Returns None.
//
// Throws InvalidArgument when mutable_grad() yields a null gradient.
// args and kwargs are not used.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  auto* grad = egr::EagerUtils::mutable_grad(self->tensor);
  PADDLE_ENFORCE_EQ(grad != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    auto accumulation_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    accumulation_node->SetFakeEmpty(false);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1796
#if defined(PADDLE_WITH_CUDA)
// Python binding ("_tensor_uva", CUDA builds only): passes the tensor's
// LoDTensor and the device id read from args[0] to tensor_uva(). Returns None.
//
// Throws InvalidArgument unless the tensor is a dense tensor on a CPU place.
// kwargs is not used.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  PADDLE_ENFORCE_EQ(self->tensor.is_dense_tensor(),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));
  const int target_device =
      pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
  // The dense-tensor enforce above is what justifies the static downcast.
  auto* dense =
      static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get());
  tensor_uva(dense, target_device);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833
// Python binding ("_is_string_tensor_hold_allocation"): returns
// initialized() of the underlying phi::StringTensor, or False when the
// tensor's implementation is not a StringTensor.
// args and kwargs are not used.
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto string_tensor =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  // A failed cast yields nullptr; short-circuit to false in that case.
  return ToPyObject(string_tensor != nullptr && string_tensor->initialized());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1834

1835
// Method table mapping Python-visible method names to the C++ binding
// functions in this file. Every entry uses METH_VARARGS | METH_KEYWORDS and
// has no docstring; the table ends with an all-null sentinel entry.
PyMethodDef variable_methods[] = {
    {"numpy",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_numpy)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__is_initialized)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_dense_tensor_hold_allocation",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor_method__is_dense_tensor_hold_allocation)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_to",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__copy_to)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"copy_",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_copy_)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"reconstruct_from_",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_reconstruct_from_)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"retain_grads",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_retain_grads)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clear_gradient",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_clear_gradient)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_dense",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_dense)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_zero_grads",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__zero_grads)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_buffer_to",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__share_buffer_to)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_buffer_with",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__is_shared_buffer_with)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_underline_tensor_to",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__share_underline_tensor_to)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_underline_tensor_with",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor__is_shared_underline_tensor_with)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"detach",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_detach)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_tensor",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_underline_tensor)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_selected_rows",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor_method_get_underline_selected_rows)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_tensor_from_selected_rows",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor_method__get_tensor_from_selected_rows)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_index_not_tensor",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__getitem_index_not_tensor)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_from_offset",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__getitem_from_offset)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__setitem_eager_tensor__",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__setitem_eager_tensor)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_grad_hook",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_register_grad_hook)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_remove_grad_hook",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_remove_grad_hook)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_backward_hook",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_register_reduce_hook)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_set_grad_type",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__set_grad_type)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_clear",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__clear)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_gradient_from",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__copy_gradient_from)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Methods kept to adapt the old dygraph mode; to be removed in the future.
    {"set_string_list",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_set_string_list)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_vocab",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_set_vocab)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_map_tensor",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_map_tensor)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Methods of sparse tensors.
    {"nnz",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_non_zero_nums)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"indices",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_non_zero_indices)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"values",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_non_zero_elements)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"crows",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_non_zero_crows)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"cols",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_non_zero_cols)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_sparse)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_coo",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_sparse_coo)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_csr",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_sparse_csr)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_same_shape",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_same_shape)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"to_sparse_csr",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_to_sparse_csr)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"element_size",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_element_size)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Remaining tensor helpers (the sparse tensor section ends here).
    {"_inplace_version",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__inplace_version)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_bump_inplace_version",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__bump_inplace_version)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_selected_rows",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_is_selected_rows)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"rows",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method_get_rows)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // NOTE(review): "element_size" is also registered earlier in this table
    // (bound to tensor_method_element_size); confirm which binding is
    // intended and drop the other.
    {"element_size",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_methon_element_size)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_reset_grad_inplace_version",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__reset_grad_inplace_version)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_memory",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__share_memory)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_offset",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__offset)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_name",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__grad_name)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_value",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__grad_value)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_unset_fake_empty",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor__unset_fake_empty)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#if defined(PADDLE_WITH_CUDA)
    {"_tensor_uva",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__uva)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#endif
    {nullptr, nullptr, 0, nullptr}};

J
Jack Zhou 已提交
2056 2057 2058 2059
// Method table for core.eager.StringTensor. Every entry uses
// METH_VARARGS | METH_KEYWORDS and has no docstring; the table ends with an
// all-null sentinel entry.
PyMethodDef string_tensor_variable_methods[] = {
    {"numpy",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor_method_numpy_for_string_tensor)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     reinterpret_cast<PyCFunction>(
         reinterpret_cast<void (*)(void)>(tensor_method__is_initialized)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_string_tensor_hold_allocation",
     reinterpret_cast<PyCFunction>(reinterpret_cast<void (*)(void)>(
         tensor_method__is_string_tensor_hold_allocation)),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    {nullptr, nullptr, 0, nullptr}};

2074 2075
}  // namespace pybind
}  // namespace paddle