eager_method.cc 82.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
12 13 14 15 16 17

#if defined(_MSC_VER)
#include <BaseTsd.h>
typedef SSIZE_T ssize_t;
#endif

18
#include <Python.h>
19 20 21 22
// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
#endif
23 24

#include <string>
25
#include <unordered_map>
26 27
#include <vector>

28
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
29
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
30
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
31
#include "paddle/fluid/eager/autograd_meta.h"
32 33
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
34
#include "paddle/fluid/eager/utils.h"
35
#include "paddle/fluid/framework/convert_utils.h"
36
#include "paddle/fluid/framework/string_array.h"
37 38 39 40 41 42
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
43
#include "paddle/fluid/pybind/slice_utils.h"
44
#include "paddle/fluid/pybind/uva_utils.h"
45 46 47 48
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
49 50
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
W
wanghuancoder 已提交
51
#include "pybind11/detail/internals.h"
52 53
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
W
wanghuancoder 已提交
54
#pragma GCC diagnostic ignored "-Wmissing-field-initializers"
J
Jiabin Yang 已提交
55
#include "paddle/fluid/eager/amp_utils.h"
56
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
J
Jiabin Yang 已提交
57
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
W
wanghuancoder 已提交
58
#include "paddle/fluid/framework/python_headers.h"
W
wanghuancoder 已提交
59
#include "paddle/fluid/memory/allocation/mmap_allocator.h"
W
wanghuancoder 已提交
60
#include "paddle/fluid/pybind/tensor_py.h"
W
wanghuancoder 已提交
61
#include "paddle/phi/core/ddim.h"
62
#include "paddle/phi/core/tensor_utils.h"
63
#include "paddle/phi/kernels/funcs/math_function.h"
J
Jiabin Yang 已提交
64

65 66 67
namespace paddle {
namespace pybind {

68 69
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
70
                                     const paddle::platform::Place& place,
71
                                     bool zero_copy);
72

73
extern PyTypeObject* p_tensor_type;
74

J
Jiabin Yang 已提交
75 76 77
// Converts a Python object used as a slice index into a Py_ssize_t.
//
// Only instances of the eager Tensor type (p_tensor_type) are accepted, and
// the tensor must be initialized. Any other object raises InvalidArgument.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
    VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
    // Cast once and reuse the result for both the validity check and the
    // index extraction.
    paddle::Tensor tensor = CastPyArg2Tensor(obj, 0);
    PADDLE_ENFORCE_EQ(
        tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in slice, however we got "
            "uninitialized tensor %s, please check your code.",
            tensor.name()));
    return GetSliceIndexFromTensor(
        *static_cast<phi::DenseTensor*>(tensor.impl().get()));
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }
}

95 96
// Implements Tensor.numpy(): builds a new numpy array from the tensor.
//
// - A tensor without an impl yields a 1-D float array of shape [0].
// - A tensor whose impl is not initialized yields a 1-D array of shape [0]
//   when the tensor has rank 0, otherwise an array with the tensor's shape
//   into which no data is copied.
// - Otherwise the tensor data is deep-copied into the array from CPU /
//   pinned memory, GPU, XPU or a custom device (depending on build flags).
//
// Raises InvalidArgument when the tensor lives on any other place. Returns
// nullptr (with the Python error left as set by numpy) if the array cannot
// be allocated.
static PyObject* tensor_method_numpy(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl()) {
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_FLOAT_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }
  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
  auto sizeof_dtype = phi::SizeOf(self->tensor.type());
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];

  const bool impl_initialized = self->tensor.impl()->initialized();

  // Handle the uninitialized rank-0 case before the full-shape array is
  // allocated; allocating first and then replacing it would leak the first
  // array.
  if (!impl_initialized && tensor_dims.size() == 0) {
    py_dims[0] = 0;
    py_strides[0] = 0;
    return api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(numpy_dtype),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
  }

  // Row-major (C order) strides in bytes; numel ends up as the element count.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    py_dims[i] = static_cast<size_t>(tensor_dims[i]);
    py_strides[i] = sizeof_dtype * numel;
    numel *= py_dims[i];
  }

  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_,
      api.PyArray_DescrFromType_(numpy_dtype),
      tensor_dims.size(),
      py_dims,
      py_strides,
      nullptr,
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);
  if (array == nullptr) {
    // Allocation failed; the array data must not be dereferenced below.
    return nullptr;
  }

  if (!impl_initialized) {
    // Nothing to copy: hand back the array with the tensor's shape as is.
    return array;
  }

  if (self->tensor.is_cpu() || self->tensor.is_gpu_pinned()) {
    eager_gil_scoped_release guard;
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());

      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // deep copy
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          place,
          dense_tensor->data(),
          sizeof_dtype * numel);
    }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  } else if (self->tensor.is_gpu()) {
    eager_gil_scoped_release guard;
#if defined(PADDLE_WITH_CUDA)
    gpuMemcpyKind kind = cudaMemcpyDeviceToHost;
#elif defined(PADDLE_WITH_HIP)
    gpuMemcpyKind kind = hipMemcpyDeviceToHost;
#endif
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::platform::GpuMemcpySync(
          pybind11::detail::array_proxy(array)->data,
          dense_tensor->data(),
          phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel(),
          kind);
    }
#endif
#if defined(PADDLE_WITH_XPU)
  } else if (self->tensor.is_xpu()) {
    platform::CPUPlace place;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      paddle::memory::Copy(
          place,
          reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data),
          dense_tensor->place(),
          dense_tensor->data(),
          sizeof_dtype * numel);
    }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  } else if (self->tensor.is_custom_device()) {
    eager_gil_scoped_release guard;
    if (self->tensor.is_selected_rows()) {
      VLOG(6) << "Getting SelectedRows's numpy value";
      auto* selected_rows =
          static_cast<phi::SelectedRows*>(self->tensor.impl().get());
      auto* dense_tensor =
          static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    } else {
      VLOG(6) << "Getting DenseTensor's numpy value";
      auto dense_tensor =
          std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
      // TODO(qili93): temporary for ascend npu performance to be removed along
      // with npu_identity op
      paddle::Tensor temp_tensor(std::make_shared<phi::DenseTensor>());
      if (dense_tensor->storage_properties_initialized()) {
        temp_tensor = npu_identity_ad_func(self->tensor, -1);
        dense_tensor =
            std::dynamic_pointer_cast<phi::DenseTensor>(temp_tensor.impl());
      }
      phi::DeviceManager::GetDeviceWithPlace(self->tensor.place())
          ->MemoryCopyD2H(
              pybind11::detail::array_proxy(array)->data,
              dense_tensor->data(),
              phi::SizeOf(dense_tensor->dtype()) * dense_tensor->numel());
    }
#endif
  } else {
    // The array will never be returned; release it before raising.
    Py_DECREF(array);
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }

  return array;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jack Zhou 已提交
290 291 292 293 294 295 296 297 298 299 300 301 302 303 304
// Implements numpy() for StringTensor: returns a numpy unicode ("U<n>")
// array with the tensor's shape, where <n> is the longest string's unicode
// length (at least 1).
//
// An undefined or uninitialized tensor yields a 1-D unicode array of shape
// [0]. Raises InvalidArgument when the tensor is not on CPU.
static PyObject* tensor_method_numpy_for_string_tensor(TensorObject* self,
                                                       PyObject* args,
                                                       PyObject* kwargs) {
  EAGER_TRY
  auto& api = pybind11::detail::npy_api::get();
  if (!self->tensor.impl() || !self->tensor.impl()->initialized()) {
    VLOG(6) << "The StringTensor is uninitialized. Return the empty string "
               "numpy array.";
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
    py_dims[0] = 0;
    py_strides[0] = 0;

    PyObject* array = api.PyArray_NewFromDescr_(
        api.PyArray_Type_,
        api.PyArray_DescrFromType_(pybind11::detail::npy_api::NPY_UNICODE_),
        1,
        py_dims,
        py_strides,
        nullptr,
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
        nullptr);
    return array;
  }

  if (self->tensor.is_cpu()) {
    VLOG(6) << "Getting StringTensor's numpy value";
    auto string_tensor =
        std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
    const auto* st_ptr = string_tensor->data();
    auto numel = self->tensor.numel();
    auto tensor_dims = self->tensor.shape();
    // Get the max unicode length of StringTensor to create numpy unicode
    // string array. A plain loop is used so that an empty tensor
    // (numel == 0) never dereferences an end pointer, and each string's
    // length is computed only once.
    size_t max_unicode_length = 0;
    for (int64_t i = 0; i < numel; ++i) {
      const size_t unicode_len = static_cast<size_t>(
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size()));
      if (unicode_len > max_unicode_length) {
        max_unicode_length = unicode_len;
      }
    }
    // numpy needs a non-zero item width even if every string is empty.
    max_unicode_length = (max_unicode_length == 0) ? 1 : max_unicode_length;
    VLOG(6) << "The max unicode length is " << max_unicode_length;
    // One fixed-width slot of max_unicode_length code points per element,
    // zero padded.
    auto sp = std::make_unique<uint32_t[]>(max_unicode_length * numel);
    auto py_array_data = sp.get();
    memset(py_array_data, 0, max_unicode_length * numel * sizeof(uint32_t));
    for (int64_t i = 0; i < numel; ++i) {
      auto curr_unicode_len =
          phi::strings::GetUnicodeStrLen(st_ptr[i].data(), st_ptr[i].size());
      phi::strings::GetUnicodeStr(st_ptr[i].data(),
                                  py_array_data + i * max_unicode_length,
                                  curr_unicode_len);
    }
    py::array array(py::dtype("U" + std::to_string(max_unicode_length)),
                    tensor_dims,
                    {},
                    py_array_data);
    return array.release().ptr();
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "StringTensor.numpy() only support cpu tensor."));
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

360 361 362 363
// Returns a Python bool telling whether the wrapped tensor reports itself as
// initialized.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  const bool is_initialized = self->tensor.initialized();
  return ToPyObject(is_initialized);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
368 369 370 371 372 373 374 375 376 377 378 379 380 381
// Returns a Python bool: true only when the tensor's impl is a
// phi::DenseTensor and that dense tensor reports IsInitialized(); false for
// any other impl type (including no impl at all).
static PyObject* tensor_method__is_dense_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  const auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  const bool holds_allocation = dense != nullptr && dense->IsInitialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

382
static void IncreaseTensorReferenceCountUntilCopyComplete(
383
    const paddle::Tensor& tensor, const platform::Place& place) {
384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399
  auto place_ = platform::is_gpu_place(place) ? place : tensor.place();

  auto tracer = egr::Controller::Instance().GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(dev): This is an empty callback, the only way is to "reference"
  // inner memory Holder, so it will not be destructed until the kernels
  // launched at current stream of given place is finished, such as
  // CUDAPinned Mem -> CUDA by cudamemcpyAsync.
  auto callback = [tensor, place_]() {
    VLOG(3) << "Run callback of Tensor:" << tensor.name() << " at place "
            << place_;
  };
  gc->DirectClearCallback(callback);
}

400 401
// Implements Tensor._copy_to(place, blocking): copies the tensor to `place`
// and returns the copy as a new Python tensor.
//
// args[0] is the destination place and args[1] the blocking flag. The copy
// has stop_gradient set to true and inherits the source's persistable flag.
// For a non-blocking copy the source tensor is additionally kept referenced
// via IncreaseTensorReferenceCountUntilCopyComplete.
static PyObject* tensor_method__copy_to(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PyObject* place_arg = PyTuple_GET_ITEM(args, 0);
  auto place = CastPyArg2Place(place_arg, 0);
  PyObject* blocking_arg = PyTuple_GET_ITEM(args, 1);
  bool blocking = CastPyArg2AttrBoolean(blocking_arg, 1);

  paddle::Tensor copied;
  {
    // The copy runs without the GIL held.
    eager_gil_scoped_release guard;
    copied = self->tensor.copy_to(place, blocking);
    if (!blocking) {
      IncreaseTensorReferenceCountUntilCopyComplete(self->tensor, place);
    }
    auto* copied_meta = egr::EagerUtils::autograd_meta(&copied);
    copied_meta->SetStopGradient(true);
    copied_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  }
  return ToPyObject(copied);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

422 423
// Implements Tensor.cpu(): returns a blocking copy of the tensor on
// phi::CPUPlace. The copy has stop_gradient set to true and inherits the
// source's persistable flag.
static PyObject* tensor_method_cpu(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cpu_copy;
  {
    // The copy runs without the GIL held.
    eager_gil_scoped_release guard;
    cpu_copy = self->tensor.copy_to(phi::CPUPlace(), true);
    auto* copy_meta = egr::EagerUtils::autograd_meta(&cpu_copy);
    copy_meta->SetStopGradient(true);
    copy_meta->SetPersistable(
        egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  }
  return ToPyObject(cpu_copy);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

439 440 441 442
// Replaces this tensor with the tensor passed as args[0] while keeping this
// tensor's original name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  // Remember the name now; the assignment below overwrites it.
  std::string orig_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from " << src_tensor.name() << " to "
          << orig_name;
  self->tensor = src_tensor;

  // Recover source name
  self->tensor.set_name(orig_name);

  VLOG(6) << "Finished Reconstructing Tensor from " << src_tensor.name()
          << " to " << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

459 460
// Implements Tensor.copy_(src, blocking): copies args[0] into this tensor in
// place and returns None.
//
// - If this tensor is already initialized, the data is copied to this
//   tensor's own place (only when the source is initialized).
// - If it is not initialized, it first takes over the source's stop_gradient
//   and persistable flags, then copies to the source's place (again only
//   when the source is initialized).
static PyObject* tensor_method_copy_(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  if (self->tensor.initialized()) {
    if (src_tensor.initialized()) {
      eager_gil_scoped_release guard;
      self->tensor.copy_(src_tensor, self->tensor.place(), blocking);
    }
  } else {
    eager_gil_scoped_release guard;
    // Mirror the source's autograd flags onto the destination first.
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetStopGradient(
            egr::EagerUtils::autograd_meta(&(src_tensor))->StopGradient());
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetPersistable(
            egr::EagerUtils::autograd_meta(&(src_tensor))->Persistable());
    if (src_tensor.initialized()) {
      self->tensor.copy_(src_tensor, src_tensor.place(), blocking);
    }
  }

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

492 493 494 495
// Implements Tensor.clone(): returns the result of assign_ad_func applied to
// this tensor. Raises InvalidArgument when the tensor is not initialized.
static PyObject* tensor_method_clone(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor cloned;
  {
    // Both the check and the assign run without the GIL held.
    eager_gil_scoped_release guard;
    PADDLE_ENFORCE_EQ(
        self->tensor.initialized(),
        true,
        paddle::platform::errors::InvalidArgument(
            "We can only support initialized tensor in clone, however we got "
            "uninitialized tensor %s, please check your code.",
            self->tensor.name()));

    cloned = assign_ad_func(self->tensor);
  }
  return ToPyObject(cloned);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

513 514
// Implements Tensor.retain_grads(): when gradient computation is enabled on
// the controller, makes sure the tensor has a grad node (installing a
// GradNodeAccumulation if it has none) and calls RetainGradForTensor on it.
// Does nothing otherwise. Returns None.
static PyObject* tensor_retain_grads(TensorObject* self,
                                     PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (egr::Controller::Instance().HasGrad()) {
    eager_gil_scoped_release guard;
    auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!meta->GetMutableGradNode()) {
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << " become accumulation node";
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

532 533
// Implements Tensor.clear_gradient([set_to_zero=True]). Returns None.
//
// When exactly one positional argument is given it is read as the boolean
// `set_to_zero`; otherwise `set_to_zero` stays true.
//
// - SelectedRows gradient: rows and value are cleared if the value is
//   initialized.
// - Dense gradient (initialized): filled with zeros when `set_to_zero` is
//   true (for leaf tensors the accumulation node is additionally marked
//   fake-empty), otherwise its memory holder is released.
static PyObject* tensor_clear_gradient(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == 1) {
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::Tensor* grad = nullptr;
  bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  if (is_leaf) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  if (grad->impl()) {
    eager_gil_scoped_release guard;
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          auto* grad_t = static_cast<phi::DenseTensor*>(grad->impl().get());
          auto* dev_ctx =
              platform::DeviceContextPool::Instance().Get(grad_t->place());
          phi::funcs::set_constant(*dev_ctx, grad_t, 0.0);
          if (is_leaf) {
            std::static_pointer_cast<egr::GradNodeAccumulation>(
                egr::EagerUtils::grad_node(self->tensor))
                ->SetFakeEmpty(true);
          }
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

595 596
// Implements Tensor._zero_grads(): overwrites the tensor's gradient with
// zeros if that gradient is initialized. Returns None.
//
// The gradient is taken from mutable_grad() for leaf tensors (which must not
// be null) and from the autograd meta for non-leaf tensors. Dense gradients
// are zero-filled in place; any other gradient type gets its impl replaced
// by a zeros_like result.
static PyObject* tensor__zero_grads(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  {
    // Scoped so the GIL is held again before None is returned.
    eager_gil_scoped_release guard;
    paddle::Tensor* grad = nullptr;
    if (is_leaf) {
      grad = egr::EagerUtils::mutable_grad(self->tensor);
      PADDLE_ENFORCE(grad != nullptr,
                     paddle::platform::errors::Fatal(
                         "Detected NULL grad. "
                         "Please check if you have manually cleared "
                         "the grad inside autograd_meta"));
    } else {
      grad =
          egr::EagerUtils::unsafe_autograd_meta(self->tensor)->MutableGrad();
    }

    if (grad->initialized()) {
      if (grad->is_dense_tensor()) {
        auto* t = static_cast<phi::DenseTensor*>(grad->impl().get());
        auto* dev_ctx = platform::DeviceContextPool::Instance().Get(t->place());
        phi::funcs::set_constant(*dev_ctx, t, 0.0);
      } else {
        grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
      }
    }
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

640 641
// Makes the tensor passed as args[0] share this tensor's buffer and data
// type. Returns None.
//
// This tensor must be initialized (InvalidArgument otherwise). If the
// destination has no impl yet, an empty phi::DenseTensor is installed first.
// NOTE(review): both impls are static_cast to phi::DenseTensor without a
// type check — presumably callers only pass dense tensors; confirm.
static PyObject* tensor__share_buffer_to(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto* dst_object = reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* dst = &(dst_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  auto* src_dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  if (!dst->defined()) {
    dst->set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* dst_dense = static_cast<phi::DenseTensor*>(dst->impl().get());
  dst_dense->ShareBufferWith(*src_dense);
  dst_dense->ShareDataTypeWith(*src_dense);
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

664 665 666 667
// Returns a Python bool telling whether the tensor passed as args[0] shares
// its buffer with this tensor (per DenseTensor::IsSharedBufferWith).
//
// This tensor must be initialized (InvalidArgument otherwise). The result is
// false when either tensor is undefined.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto* other_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* other = &(other_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  if (self->tensor.defined() && other->defined()) {
    auto* self_dense =
        static_cast<phi::DenseTensor*>(self->tensor.impl().get());
    auto* other_dense = static_cast<phi::DenseTensor*>(other->impl().get());
    const bool shared = other_dense->IsSharedBufferWith(*self_dense);
    return ToPyObject(shared);
  }
  return ToPyObject(false);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

687 688 689 690
// Makes the tensor passed as args[0] use this tensor's impl. Returns None.
// This tensor must be initialized (InvalidArgument otherwise).
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  auto* target_object =
      reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0));
  paddle::Tensor* target = &(target_object->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  target->set_impl(self->tensor.impl());
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

705 706 707 708
// Returns a Python bool telling whether this tensor and the tensor passed as
// args[0] point at the very same impl object.
//
// The argument tensor must be initialized (InvalidArgument otherwise). The
// result is false when either tensor is undefined.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor src_tensor = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(src_tensor.initialized(),
                    true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        src_tensor.name()));
  if (self->tensor.defined() && src_tensor.defined()) {
    const bool same_impl =
        (self->tensor.impl().get() == src_tensor.impl().get());
    return ToPyObject(same_impl);
  }
  return ToPyObject(false);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

725 726
// Implements Tensor.detach(): allocates a new Python tensor object that
// shares this tensor's impl, carries a freshly generated unique name, and
// copies this tensor's persistable flag.
//
// Raises InvalidArgument when this tensor is undefined and Fatal when the
// Python object cannot be allocated.
static PyObject* tensor_method_detach(TensorObject* self,
                                      PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto* detached = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc only provides raw storage; construct the tensor in place.
  new (&(detached->tensor)) paddle::Tensor();
  detached->tensor.set_impl(self->tensor.impl());
  detached->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());
  auto* src_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto* detached_meta = egr::EagerUtils::autograd_meta(&(detached->tensor));
  detached_meta->SetPersistable(src_meta->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

753 754 755 756
// Converts the DenseTensor behind `self` to a Python object.
//   - `self` not defined: a default-constructed phi::DenseTensor is converted;
//   - `self` holds a dense tensor: the impl pointer is converted;
//   - otherwise: Python None is returned.
// `args` and `kwargs` are not read.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (!self->tensor.defined()) {
    // The original `get_tensor` method of Variable will create a empty tensor
    // NOTE(review): the address of a function-local object is handed to
    // ToPyObject here; confirm that this overload copies the tensor rather
    // than keeping a reference to it.
    phi::DenseTensor placeholder;
    return ToPyObject(&placeholder);
  }

  if (!self->tensor.is_dense_tensor()) {
    RETURN_PY_NONE
  }

  auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  VLOG(6) << "tensor: " << dense->IsInitialized();
  return ToPyObject(dense);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

772 773 774 775 776
// Converts the SelectedRows behind `self` to a Python object.
// Returns Python None when `self` is not defined or does not hold a
// SelectedRows. `args` and `kwargs` are not read.
static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
                                                           PyObject* args,
                                                           PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_selected_rows() from being queried on an
  // undefined tensor, exactly as the two separate checks did.
  if (!self->tensor.defined() || !self->tensor.is_selected_rows()) {
    RETURN_PY_NONE
  }

  auto* rows = static_cast<phi::SelectedRows*>(self->tensor.impl().get());
  return ToPyObject(rows);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

789 790 791 792 793 794 795 796 797 798 799 800 801 802
// Returns a new tensor whose impl is a DenseTensor copy-constructed from the
// value tensor of the SelectedRows held by `self`. The result is given a
// freshly generated unique name. `args` and `kwargs` are not read.
// Raises Fatal when `self` is not a SelectedRows or when the SelectedRows is
// not initialized.
static PyObject* tensor_method__get_tensor_from_selected_rows(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows."));

  auto* selected_rows =
      static_cast<phi::SelectedRows*>(self->tensor.impl().get());

  PADDLE_ENFORCE(
      selected_rows->initialized(),
      paddle::platform::errors::Fatal("SelectedRows must be initialized."));

  auto* rows_value =
      static_cast<phi::DenseTensor*>(selected_rows->mutable_value());
  VLOG(4) << "dense_tensor: " << rows_value->IsInitialized();

  paddle::Tensor result(egr::Controller::Instance().GenerateUniqueName());
  result.set_impl(std::make_shared<phi::DenseTensor>(*rows_value));

  return ToPyObject(result);

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
815 816 817
// Implements indexing of `self` with an index that is not itself a tensor.
//
// args[0] is the Python index object. It is parsed by ParseIndexingSlice
// into slice axes/starts/ends/strides, the axes to decrease, the positions
// of `None` entries, and (when the index is a list) the selected indices.
//
// Depending on what the parser produced:
//   - slice axes present: `slice` is applied, or `strided_slice` when any
//     stride differs from 1;
//   - `None` axes present: the (possibly sliced) result is unsqueezed at
//     those axes and returned immediately;
//   - list index: `index_select` along dim 0 is applied to `self`.
// If there are neither slice axes nor a list index, `self`'s tensor itself
// is what gets returned (or unsqueezed).
//
// Raises InvalidArgument when `self` is not defined.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags, list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  // Note(0x45f): Using defined() instead of initialized()
  // to support slice tensor which shape like [0, 0, 0].
  PADDLE_ENFORCE_EQ(
      self->tensor.defined(),
      true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor,
                     _index,
                     &slice_axes,
                     &slice_starts,
                     &slice_ends,
                     &slice_strides,
                     &decrease_axis,
                     &none_axes,
                     &infer_flags,
                     &list_select_idxs,
                     &list_select_flag);

  auto out =
      slice_axes.empty() && !list_select_flag
          ? self->tensor
          : paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    // A single non-unit stride is enough to require strided_slice.
    std::string op_type = "slice";
    for (auto stride : slice_strides) {
      if (stride != 1) {
        op_type = "strided_slice";
        break;
      }
    }
    std::vector<int64_t> slice_axes_tmp(slice_axes.begin(), slice_axes.end());
    std::vector<int64_t> infer_flags_tmp(infer_flags.begin(),
                                         infer_flags.end());
    std::vector<int64_t> decrease_axis_tmp(decrease_axis.begin(),
                                           decrease_axis.end());

    if (op_type == "slice") {
      eager_gil_scoped_release guard;
      out = slice_ad_func(self->tensor,
                          slice_axes_tmp,
                          slice_starts,
                          slice_ends,
                          infer_flags_tmp,
                          decrease_axis_tmp);
    } else if (op_type == "strided_slice") {
      eager_gil_scoped_release guard;
      out = strided_slice_ad_func(
          self->tensor, slice_axes, slice_starts, slice_ends, slice_strides);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Slice is only support slice and strided_slice, but we got %s which "
          "is impossible, please check your code first or contact us by "
          "issue. ",
          op_type));
    }
  }

  if (!none_axes.empty()) {
    // Deal with cases when all axes are decreased.
    // After slice, the shape of out is [1], which should have been
    // [], but Paddle doesn't support scalar.
    // In order to ensure the correctness of the final shape of out,
    // one dimension of out needs to be decreased.
    // For example:
    // # x.shape: (2,3,4)
    // out = x[0, 1, 1, None] # out.shape : (1)
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      none_axes.pop_back();
    }
    if (!none_axes.empty()) {
      paddle::Tensor new_out;
      {
        eager_gil_scoped_release guard;
        // Deal with cases that decrease_axes is not empty
        // For example:
        // # x.shape: (2,3,4)
        // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
        // Each None axis is shifted left by the number of decreased axes
        // that precede it.
        for (auto& axis : none_axes) {
          int len = 0;
          for (int da : decrease_axis) {
            if (da < axis) {
              len++;
            }
          }
          axis -= len;
        }
        new_out = unsqueeze_ad_func(out, none_axes);
      }
      return ToPyObject(new_out);
    }
  }

  // the index is a list
  if (list_select_flag) {
    eager_gil_scoped_release guard;
    auto select_index =
        paddle::Tensor(egr::Controller::Instance().GenerateUniqueName());
    auto idx_tensor = std::make_shared<phi::DenseTensor>();
    select_index.set_impl(idx_tensor);
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
        egr::Controller::Instance().GetExpectedPlace());
    paddle::framework::TensorFromVector(
        list_select_idxs, *dev_ctx, idx_tensor.get());
    out = index_select_ad_func(self->tensor, select_index, 0);
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

948 949
// Reads a single element of the DenseTensor behind `self` and returns it as
// a numpy array with one dimension of length 1.
//
// The element is selected by the positional arguments in `args`:
//   - no argument: the tensor must contain exactly one element, which is
//     returned;
//   - one argument: a flat offset into the tensor, which must be smaller
//     than the number of elements;
//   - otherwise: one index per dimension; each must be smaller than the
//     size of its axis. The flat offset is computed from strides derived
//     from the tensor dims with the last axis varying fastest.
//
// Raises InvalidArgument for a null impl, an uninitialized tensor, a wrong
// number of indices or an out-of-range index, and Unimplemented for a dtype
// not listed in PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE.
static PyObject* tensor__getitem_from_offset(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto ptr = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  PADDLE_ENFORCE_NOT_NULL(ptr,
                          platform::errors::InvalidArgument(
                              "%s is not a DenseTensor.", self->tensor.name()));
  const auto& tensor = *ptr;
  PADDLE_ENFORCE_EQ(
      tensor.IsInitialized(),
      true,
      platform::errors::InvalidArgument(
          "Tensor of %s is Empty, please check if it has no data.",
          self->tensor.name()));

  const auto& tensor_dims = tensor.dims();

  std::vector<size_t> dims(tensor_dims.size());
  std::vector<size_t> strides(tensor_dims.size());

  // Walk the axes from last to first: `numel` is the running product of the
  // sizes already visited, which is exactly the stride (in elements) of the
  // current axis.
  size_t numel = 1;
  for (int i = tensor_dims.size() - 1; i >= 0; --i) {
    strides[i] = numel;
    dims[i] = static_cast<size_t>(tensor_dims[i]);
    numel *= dims[i];
  }
  size_t offset = 0;
  if (PyTuple_Size(args) == 0) {
    PADDLE_ENFORCE_EQ(numel,
                      1,
                      platform::errors::InvalidArgument(
                          "only one element tensors can be converted to Python "
                          "scalars when no input coordinates"));
  } else if (PyTuple_Size(args) == 1) {
    offset = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);
    PADDLE_ENFORCE_LT(
        offset,
        numel,
        platform::errors::InvalidArgument(
            "index %d is out of bounds for size %d", offset, numel));
  } else {
    PADDLE_ENFORCE_EQ(PyTuple_Size(args),
                      dims.size(),
                      platform::errors::InvalidArgument(
                          "incorrect number of indices for Tensor"));

    for (Py_ssize_t i = 0; i < PyTuple_Size(args); ++i) {
      size_t index = CastPyArg2AttrLong(PyTuple_GET_ITEM(args, i), i);
      PADDLE_ENFORCE_LT(
          index,
          dims[i],
          platform::errors::InvalidArgument(
              "index %d is out of bounds for axis %d with size %d",
              index,
              i,
              dims[i]));
      offset += index * strides[i];
    }
  }
// X-macro listing every (C++ type, DataType) pair that can be returned.
#define PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(_) \
  _(bool, DataType::BOOL)                     \
  _(int8_t, DataType::INT8)                   \
  _(uint8_t, DataType::UINT8)                 \
  _(int16_t, DataType::INT16)                 \
  _(uint16_t, DataType::UINT16)               \
  _(int32_t, DataType::INT32)                 \
  _(uint32_t, DataType::UINT32)               \
  _(int64_t, DataType::INT64)                 \
  _(uint64_t, DataType::UINT64)               \
  _(bfloat16, DataType::BFLOAT16)             \
  _(float16, DataType::FLOAT16)               \
  _(float, DataType::FLOAT32)                 \
  _(double, DataType::FLOAT64)                \
  _(complex64, DataType::COMPLEX64)           \
  _(complex128, DataType::COMPLEX128)

// For the dtype that matches the tensor: fetch the element at `offset`,
// create a numpy array with ndim 1 and py_dims[0] == 1 of the corresponding
// numpy dtype, copy the element's bytes into it and return it.
#define TENSOR_TO_PY_SCALAR(T, proto_type)                                   \
  if (tensor.dtype() == proto_type) {                                        \
    auto numpy_dtype = TensorDtype2NumpyDtype(proto_type);                   \
    T b = paddle::pybind::TensorGetElement<T>(tensor, offset);               \
    Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];                  \
    Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];               \
    py_dims[0] = 1;                                                          \
    py_strides[0] = 1;                                                       \
    auto& api = pybind11::detail::npy_api::get();                            \
    PyObject* array = api.PyArray_NewFromDescr_(                             \
        api.PyArray_Type_,                                                   \
        api.PyArray_DescrFromType_(numpy_dtype),                             \
        1,                                                                   \
        py_dims,                                                             \
        py_strides,                                                          \
        nullptr,                                                             \
        pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |                      \
            pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,                 \
        nullptr);                                                            \
    std::memcpy(                                                             \
        reinterpret_cast<void*>(pybind11::detail::array_proxy(array)->data), \
        static_cast<void*>(&b),                                              \
        sizeof(b));                                                          \
    return array;                                                            \
  }

  PD_FOR_EACH_DENSE_TENSOR_DATA_TYPE(TENSOR_TO_PY_SCALAR);
#undef TENSOR_TO_PY_SCALAR
  // Reached only when no dtype in the list above matched.
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported tensor data type: %s", tensor.dtype()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101
// Implements item assignment on `self`.
//
// args[0] is the index and args[1] the value to assign.
//
// Two paths exist:
//   1. Every index item is an integer, a slice, Ellipsis or None: the index
//      is parsed with ParseIndexingSlice and the in-place set_value op is
//      called. The value may be a Tensor, a numpy array (cast to the dtype
//      of `self` when needed) or a Python float/int/bool scalar (passed as
//      an op attribute).
//   2. Otherwise: `self` is converted to a numpy array, the assignment is
//      performed on that array, and the array is copied back into `self`.
//
// Returns Python None. Raises InvalidArgument for unsupported dtypes or
// value types, and when gradients are enabled and `self` is a leaf tensor
// that does not stop gradient.
static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Call __setitem_eager_tensor";

  auto self_tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());

  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  PyObject* value_obj = PyTuple_GET_ITEM(args, 1);
  // NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
  // https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
  PyObject* index_ptr =
      !PyTuple_Check(_index) ? PyTuple_Pack(1, _index) : _index;
  // Release the wrapping tuple on scope exit, but only when one was created
  // above (i.e. the original index was not already a tuple).
  DEFINE_PADDLE_SCOPE_GUARD([index_ptr, &_index]() {
    if (!PyTuple_Check(_index)) {
      Py_DECREF(index_ptr);
      VLOG(4) << "Call Py_DECREF";
    }
  });

  // 1. Check arguments
  bool parse_index = true;

  // Check whether _index can be parsed.
  const int size = PyTuple_GET_SIZE(index_ptr);
  for (int dim = 0; dim < size; ++dim) {
    PyObject* slice_item = PyTuple_GetItem(index_ptr, dim);
    if (!(PyCheckInteger(slice_item) || PySlice_Check(slice_item) ||
          slice_item == Py_Ellipsis || slice_item == Py_None)) {
      parse_index = false;
      break;
    }
  }

  // 2. Call op set_value to speed up if the condition is met,
  // otherwise call TensorToPyArray.
  // TODO(liym27): Try not to call TensorToPyArray because it always
  // copies data to cpu place, which reduces performance.
  if (parse_index) {
    std::vector<int> axes, starts, ends, steps, decrease_axes, none_axes,
        infer_flags, list_select_idxs;
    // if index is a list, list_select_flag will be true
    bool list_select_flag = false;
    ParseIndexingSlice(self_tensor,
                       index_ptr,
                       &axes,
                       &starts,
                       &ends,
                       &steps,
                       &decrease_axes,
                       &none_axes,
                       &infer_flags,
                       &list_select_idxs,
                       &list_select_flag);

    framework::AttributeMap attrs = {{"axes", axes},
                                     {"starts", starts},
                                     {"ends", ends},
                                     {"steps", steps},
                                     {"decrease_axes", decrease_axes},
                                     {"none_axes", none_axes}};

    if (egr::Controller::Instance().HasGrad()) {
      PADDLE_ENFORCE_EQ(
          egr::egr_utils_api::IsLeafTensor(self->tensor) &&
              !egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient(),
          false,
          platform::errors::InvalidArgument(
              "Leaf Tensor (%s) that doesn't stop gradient can't use "
              "inplace strategy.",
              self->tensor.name()));
    }

    paddle::Tensor value_tensor;

    if (PyCheckTensor(value_obj)) {
      value_tensor = reinterpret_cast<TensorObject*>(value_obj)->tensor;
    } else if (py::isinstance<py::array>(value_obj)) {
      paddle::Tensor value_tensor_tmp(
          std::make_shared<phi::DenseTensor>(),
          egr::Controller::Instance().GenerateUniqueName());
      py::object value_obj_tmp(py::handle(value_obj), true);
      py::object value = value_obj_tmp;
      // Cast the numpy array to the dtype of `self` when it does not
      // already have that element type.
      if (self->tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
        if (!py::isinstance<py::array_t<float>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<float>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::FLOAT64) {
        if (!py::isinstance<py::array_t<double>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<double>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::INT32) {
        if (!py::isinstance<py::array_t<int32_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int32_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() ==
                 paddle::experimental::DataType::INT64) {
        if (!py::isinstance<py::array_t<int64_t>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<int64_t>(value_obj_tmp);
        }
      } else if (self->tensor.dtype() == paddle::experimental::DataType::BOOL) {
        if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
          value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
        }
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "When assign a numpy.np value to a paddle.Tensor, "
            "the data type of the paddle.Tensor must be bool, "
            "float32, float64, int32 or int64, "
            "please check the type of tensor."));
      }

      SetTensorFromPyArray(
          static_cast<phi::DenseTensor*>(value_tensor_tmp.impl().get()),
          value,
          self->tensor.place(),
          false);

      value_tensor = value_tensor_tmp;
    } else {
      py::object value_obj_tmp(py::handle(value_obj), true);
      // convert the value to self data type
      if (py::isinstance<py::float_>(value_obj_tmp) ||
          py::isinstance<py::int_>(value_obj_tmp) ||
          py::isinstance<py::bool_>(value_obj_tmp)) {
        if (self->tensor.dtype() == paddle::experimental::DataType::FLOAT32) {
          attrs["fp32_values"] =
              std::vector<float>{value_obj_tmp.cast<float>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::FLOAT64) {
          attrs["fp64_values"] =
              std::vector<double>{value_obj_tmp.cast<double>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::INT32) {
          attrs["int32_values"] =
              std::vector<int32_t>{value_obj_tmp.cast<int32_t>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::INT64) {
          attrs["int64_values"] =
              std::vector<int64_t>{value_obj_tmp.cast<int64_t>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::BOOL) {
          attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
        } else if (self->tensor.dtype() ==
                   paddle::experimental::DataType::FLOAT16) {
          attrs["fp16_values"] =
              std::vector<float>{value_obj_tmp.cast<float>()};
        } else {
          PADDLE_THROW(platform::errors::InvalidArgument(
              "When assign a value to a paddle.Tensor, "
              "the data type of the paddle.Tensor must be bool, "
              "float32, float64, int32, int64 or float16, "
              "please check the type of tensor."));
        }
        attrs["shape"] = std::vector<int64_t>{1};

      } else {
        // Report the Python type by name; passing the type object itself
        // would not format as a readable type name.
        PADDLE_THROW(platform::errors::InvalidArgument(
            "Value type error. The assign value allows "
            "numpy.ndarray, integer, float or bool, "
            "but received %s.",
            Py_TYPE(value_obj)->tp_name));
      }
    }
    {
      // Release gil and do tracing
      py::gil_scoped_release release;
      // use inplace set_value_ operator
      if (value_tensor.initialized() &&
          (self->tensor.dtype() != value_tensor.dtype())) {
        paddle::small_vector<std::vector<paddle::Tensor>,
                             egr::kSlotSmallVectorSize>
            tmps = {{self->tensor}, {value_tensor}};
        auto amp_dtype = egr::GetAmpDestDtype("set_value", tmps);
        self->tensor = egr::EagerAmpAutoCast(
            self->tensor.name(), self->tensor, amp_dtype, "set_value");
        value_tensor = egr::EagerAmpAutoCast(
            value_tensor.name(), value_tensor, amp_dtype, "set_value");
        // If the dtypes still differ after the AMP cast, cast the value to
        // the dtype of `self`.
        if (self->tensor.dtype() != value_tensor.dtype()) {
          value_tensor = cast_ad_func(value_tensor, self->tensor.dtype());
        }
      }
      self->tensor = set_value__dygraph_function(
          self->tensor, value_tensor, {}, {}, {}, attrs);
    }
    if (PyCheckTensor(value_obj)) {
      // pass the stop_gradient from value to tensor.
      // pass stop gradient should be done after CheckInplace in
      // set_value__dygraph_function.
      if (!egr::EagerUtils::autograd_meta(&value_tensor)->StopGradient() &&
          egr::EagerUtils::autograd_meta(&self->tensor)->StopGradient()) {
        egr::EagerUtils::autograd_meta(&self->tensor)->SetStopGradient(false);
      }
    }
  } else {
    auto self_numpy = TensorToPyArray(*self_tensor);
    VLOG(4) << "parse_index is false";
    if (PyCheckTensor(_index)) {
      VLOG(4) << "index is tensor";
      auto index_tensor = static_cast<phi::DenseTensor*>(
          reinterpret_cast<TensorObject*>(_index)->tensor.impl().get());
      auto index_numpy = TensorToPyArray(*index_tensor);
      self_numpy[index_numpy] = py::object(py::handle(value_obj), true);
    } else {
      VLOG(4) << "index is not tensor";
      self_numpy[_index] = py::object(py::handle(value_obj), true);
    }
    // Copy the modified numpy array back into `self`. An uninitialized
    // tensor has no place of its own, so a default place is chosen from the
    // build configuration.
    if (!self->tensor.initialized()) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CUDAPlace(0)),
                           false);
#else
      SetTensorFromPyArray(self_tensor,
                           self_numpy,
                           platform::Place(platform::CPUPlace()),
                           false);
#endif
    } else {
      SetTensorFromPyArray(
          self_tensor, self_numpy, self->tensor.place(), false);
    }
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1290 1291
// Registers args[0] as a gradient hook for `self` and returns the hook id as
// a Python object.
//
// Leaf tensor: when the tensor has autograd meta, does not stop gradient and
// has no grad node yet, a GradNodeAccumulation is created first. The hook is
// then registered on the tensor's GradNodeAccumulation.
// Non-leaf tensor: the hook is registered on the tensor's grad node.
//
// Raises InvalidArgument when the node the hook should be attached to does
// not exist, instead of dereferencing a null pointer.
static PyObject* tensor_register_grad_hook(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  int64_t hook_id = 0;
  PyObject* hook_func = PyTuple_GET_ITEM(args, 0);

  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    VLOG(6) << "Register hook for leaf tensor: " << self->tensor.name();

    auto autograd_meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);

    if (autograd_meta && !autograd_meta->StopGradient()) {
      if (!autograd_meta->GetMutableGradNode()) {
        VLOG(6) << "Detected NULL grad_node, Leaf tensor should have had "
                   "grad_node with type: GradNodeAccumulation.";
        autograd_meta->SetGradNode(
            std::make_shared<egr::GradNodeAccumulation>(autograd_meta));
      }
    }

    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    auto accumulation_grad_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node);
    // The cast yields null both when there is no grad node at all and when
    // the node has a different type; either way the hook cannot be attached.
    PADDLE_ENFORCE_NOT_NULL(
        accumulation_grad_node.get(),
        platform::errors::InvalidArgument(
            "Cannot register gradient hook for leaf Tensor %s, it has no "
            "grad_node with type: GradNodeAccumulation.",
            self->tensor.name()));

    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    hook_id = accumulation_grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));

  } else {
    VLOG(6) << "Register hook for non leaf tensor: " << self->tensor.name();
    std::shared_ptr<egr::GradNodeBase> grad_node =
        egr::EagerUtils::grad_node(self->tensor);
    PADDLE_ENFORCE_NOT_NULL(
        grad_node.get(),
        platform::errors::InvalidArgument(
            "Cannot register gradient hook for Tensor %s, it has no "
            "grad_node.",
            self->tensor.name()));

    auto rank_info =
        egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
    hook_id = grad_node->RegisterGradientHook(
        rank_info.first,
        rank_info.second,
        std::make_shared<PyTensorHook>(hook_func));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1340 1341
// Removes the gradient hook whose id is given in args[0] from the grad node
// of `self` and returns the result of RemoveGradientHook as a Python object.
static PyObject* tensor_remove_grad_hook(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);

  PyObject* id_arg = PyTuple_GET_ITEM(args, 0);
  int64_t id_to_remove = pybind::CastPyArg2AttrLong(id_arg, 0);

  auto removed = node->RemoveGradientHook(id_to_remove);
  return ToPyObject(removed);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1354 1355
// Registers args[0] as a reduce hook on the GradNodeAccumulation of `self`.
// Returns Python None.
//
// Preconditions, checked in this order:
//   - `self` is a leaf tensor (InvalidArgument otherwise);
//   - `self` does not stop gradient (InvalidArgument otherwise);
//   - `self` already has a grad node (Fatal otherwise).
static PyObject* tensor_register_reduce_hook(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);

  PADDLE_ENFORCE_EQ(egr::egr_utils_api::IsLeafTensor(self->tensor),
                    true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));
  PADDLE_ENFORCE_EQ(
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient(),
      true,
      platform::errors::InvalidArgument(
          "Cannot register backward hook on a Tensor that stop "
          "gradient."));
  PADDLE_ENFORCE(
      grad_node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  PyObject* reduce_callback = PyTuple_GET_ITEM(args, 0);
  auto reduce_hook = std::make_shared<PyVoidHook>(reduce_callback);

  std::dynamic_pointer_cast<egr::GradNodeAccumulation>(grad_node)
      ->RegisterReduceHook(reduce_hook);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1389 1390
// Replaces the impl of the gradient tensor of `self` with an empty object
// of the kind requested in args[0]:
//   - LOD_TENSOR    -> a new phi::DenseTensor
//   - SELECTED_ROWS -> a new phi::SelectedRows
// Any other variable type leaves the gradient untouched.
// Returns Python None.
static PyObject* tensor__set_grad_type(TensorObject* self,
                                       PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  auto requested_type =
      pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  auto grad = egr::EagerUtils::autograd_meta(&self->tensor)->MutableGrad();

  const bool wants_dense =
      requested_type == framework::proto::VarType::LOD_TENSOR;
  if (wants_dense) {
    grad->set_impl(std::make_shared<phi::DenseTensor>());
  } else if (requested_type == framework::proto::VarType::SELECTED_ROWS) {
    grad->set_impl(std::make_shared<phi::SelectedRows>());
  }

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1406 1407
// Calls reset() on the tensor held by `self` and returns Python None.
// `args` and `kwargs` are not read.
static PyObject* tensor__clear(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY

  self->tensor.reset();

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1416 1417
// Python method `_copy_gradient_from(src)`.
// Makes this tensor's gradient hold the same implementation object as `src`.
//
// When this tensor is initialized, `src` must be defined and must match this
// tensor in both data type and implementation type; otherwise an error is
// raised. If the tensor has a gradient slot, `src` must be initialized and its
// impl is installed into that slot. Returns None.
static PyObject* tensor__copy_gradient_from(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  auto src = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  if (self->tensor.initialized()) {
    // The comparisons below go through src's impl (src.impl()->type_info()),
    // so an undefined source must be rejected first instead of dereferencing
    // a null impl.
    PADDLE_ENFORCE_EQ(src.defined(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.dtype(),
                      src.dtype(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different data type with Tensor %s",
                          self->tensor.name(),
                          src.name()));
    PADDLE_ENFORCE_EQ(self->tensor.impl()->type_info().id(),
                      src.impl()->type_info().id(),
                      platform::errors::PreconditionNotMet(
                          "Tensor %s has different type with Tensor %s, Tensor "
                          "ShareGradientDataWith cannot be performed!",
                          self->tensor.name(),
                          src.name()));
  }
  VLOG(6) << "Tensor copy gradient from: " << src.name();
  auto* p_grad = egr::EagerUtils::mutable_grad(self->tensor);
  if (p_grad) {
    PADDLE_ENFORCE_EQ(src.initialized(),
                      true,
                      platform::errors::InvalidArgument(
                          "Tensor %s has not been initialized", src.name()));
    // Share (not copy) the source implementation with the gradient.
    p_grad->set_impl(src.impl());
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1449

1450 1451 1452
// Python method `_tensor_use_gpudnn(flag)`.
// Requires a defined DenseTensor. When the tensor's meta already carries the
// requested `use_gpudnn` value, the current tensor is returned. Otherwise a
// new tensor is returned that shares this tensor's data, name and autograd
// meta, with only `use_gpudnn` changed in its meta.
static PyObject* tensor__use_gpudnn(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.defined() && self->tensor.is_dense_tensor(),
                 paddle::platform::errors::Fatal(
                     "function _use_gpudnn is only effective for DenseTensor"));

  const bool requested =
      pybind::CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);

  auto* source_dense =
      static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  auto* source_meta = phi::DenseTensorUtils::GetMutableMeta(source_dense);

  // Flag already has the requested value: hand back the tensor unchanged.
  if (source_meta->use_gpudnn == requested) {
    return ToPyObject(self->tensor);
  }

  // Copy the meta and flip only use_gpudnn.
  phi::DenseTensorMeta patched_meta = *source_meta;
  patched_meta.use_gpudnn = requested;

  // New dense tensor that shares the data of the source one.
  phi::DenseTensor shared_dense;
  shared_dense.ShareDataWith(*source_dense);
  shared_dense.set_meta(patched_meta);

  // Wrap it under the same name and autograd meta as the source tensor.
  paddle::Tensor result(std::make_shared<phi::DenseTensor>(shared_dense),
                        self->tensor.name());
  result.set_autograd_meta(self->tensor.mutable_autograd_meta());
  VLOG(4) << "Tensor: " << result.name()
          << " set use_gpudnn = " << requested;

  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1487 1488
// Python method `set_vocab(vocab)`.
// Stores the given vocabulary in a fresh egr::VariableCompatTensor and makes
// that the implementation of this tensor. Returns None.
static PyObject* tensor_method_set_vocab(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  auto parsed_vocab = CastPyArg2Vocab(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Vocab>();
  *slot = parsed_vocab;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python method `set_string_list(strings)`.
// Stores the given list of strings in a fresh egr::VariableCompatTensor and
// makes that the implementation of this tensor. Returns None.
static PyObject* tensor_method_set_string_list(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  auto parsed_strings = CastPyArg2VectorOfString(PyTuple_GET_ITEM(args, 0), 0);
  auto holder = std::make_shared<egr::VariableCompatTensor>();
  auto* slot = holder->GetMutable<paddle::framework::Strings>();
  *slot = parsed_strings;
  self->tensor.set_impl(holder);
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python method `get_map_tensor()`.
// Only valid when the tensor holds a VariableCompatTensor; returns the Vocab
// stored in it converted to a Python object.
static PyObject* tensor_method_get_map_tensor(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  const bool holds_compat_tensor = egr::IsVariableCompatTensor(self->tensor);
  PADDLE_ENFORCE_EQ(
      holds_compat_tensor,
      true,
      paddle::platform::errors::Fatal(
          "this method is only effective for VariableCompatTensor"));
  const auto* compat =
      static_cast<const egr::VariableCompatTensor*>(self->tensor.impl().get());
  return ToPyObject(compat->Get<paddle::framework::Vocab>());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549
// Python method `nnz()`.
// Returns nnz() of the underlying SparseCooTensor or SparseCsrTensor; raises
// for any other tensor kind.
static PyObject* tensor_method_get_non_zero_nums(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  // COO layout first; everything that passes the check above and is not COO
  // is CSR.
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    return ToPyObject(coo->nnz());
  }
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  return ToPyObject(csr->nnz());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1550 1551 1552 1553 1554 1555 1556 1557 1558
// Python method `indices()`.
// Only valid for SparseCooTensor; returns a new Tensor wrapping a DenseTensor
// constructed from the COO tensor's non_zero_indices().
static PyObject* tensor_method_get_non_zero_indices(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_coo_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCooTensor"));
  auto coo =
      std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
  auto indices_impl =
      std::make_shared<phi::DenseTensor>(coo->non_zero_indices());
  paddle::Tensor result(indices_impl);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python method `values()`.
// Returns a new Tensor wrapping a DenseTensor constructed from
// non_zero_elements() of the underlying SparseCooTensor or SparseCsrTensor;
// raises for any other tensor kind.
static PyObject* tensor_method_get_non_zero_elements(TensorObject* self,
                                                     PyObject* args,
                                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(
      self->tensor.is_sparse_coo_tensor() ||
          self->tensor.is_sparse_csr_tensor(),
      paddle::platform::errors::Fatal("this method is only effective for "
                                      "SparseCooTensor or SparseCsrTensor"));
  if (self->tensor.is_sparse_coo_tensor()) {
    auto coo =
        std::dynamic_pointer_cast<phi::SparseCooTensor>(self->tensor.impl());
    auto values_impl =
        std::make_shared<phi::DenseTensor>(coo->non_zero_elements());
    paddle::Tensor result(values_impl);
    return ToPyObject(result);
  }
  // Passed the check above and is not COO, so it is CSR.
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto values_impl =
      std::make_shared<phi::DenseTensor>(csr->non_zero_elements());
  paddle::Tensor result(values_impl);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python method `crows()`.
// Only valid for SparseCsrTensor; returns a new Tensor wrapping a DenseTensor
// constructed from the CSR tensor's non_zero_crows().
static PyObject* tensor_method_get_non_zero_crows(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto crows_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_crows());
  paddle::Tensor result(crows_impl);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Python method `cols()`.
// Only valid for SparseCsrTensor; returns a new Tensor wrapping a DenseTensor
// constructed from the CSR tensor's non_zero_cols().
static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_sparse_csr_tensor(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SparseCsrTensor"));
  auto csr =
      std::dynamic_pointer_cast<phi::SparseCsrTensor>(self->tensor.impl());
  auto cols_impl = std::make_shared<phi::DenseTensor>(csr->non_zero_cols());
  paddle::Tensor result(cols_impl);
  return ToPyObject(result);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1620 1621
// Python method `is_dense()`.
// True only when the tensor is defined and is a dense tensor; an undefined
// tensor yields False.
static PyObject* tensor_method_is_dense(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  // Short-circuit keeps is_dense_tensor() from being asked of an undefined
  // tensor.
  const bool answer =
      self->tensor.defined() && self->tensor.is_dense_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1631 1632
// Python method `is_sparse()`.
// True only when the tensor is defined and is either a sparse COO or a sparse
// CSR tensor; an undefined tensor yields False.
static PyObject* tensor_method_is_sparse(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const bool answer = self->tensor.defined() &&
                      (self->tensor.is_sparse_coo_tensor() ||
                       self->tensor.is_sparse_csr_tensor());
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1643 1644
// Python method `is_sparse_coo()`.
// True only when the tensor is defined and is a sparse COO tensor; an
// undefined tensor yields False.
static PyObject* tensor_method_is_sparse_coo(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const bool answer =
      self->tensor.defined() && self->tensor.is_sparse_coo_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1654 1655
// Python method `is_sparse_csr()`.
// True only when the tensor is defined and is a sparse CSR tensor; an
// undefined tensor yields False.
static PyObject* tensor_method_is_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  const bool answer =
      self->tensor.defined() && self->tensor.is_sparse_csr_tensor();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1665 1666
// Python method `to_sparse_csr()`.
// Converts the tensor with to_sparse_csr() and copies the stop-gradient and
// persistable flags of this tensor's autograd meta onto the result.
static PyObject* tensor_method_to_sparse_csr(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto converted = self->tensor.to_sparse_csr();
  auto* source_meta = egr::EagerUtils::autograd_meta(&self->tensor);
  auto* converted_meta = egr::EagerUtils::autograd_meta(&converted);
  converted_meta->SetStopGradient(source_meta->StopGradient());
  converted_meta->SetPersistable(source_meta->Persistable());
  return ToPyObject(converted);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1680 1681 1682 1683 1684 1685 1686 1687 1688
// Python method `is_same_shape(other)`.
// Returns whether shape() of this tensor equals shape() of `other`.
static PyObject* tensor_method_is_same_shape(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  auto rhs = CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  const bool shapes_match = self->tensor.shape() == rhs.shape();
  return ToPyObject(shapes_match);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1689 1690
// Python method `_inplace_version()`.
// Returns the tensor's current_inplace_version() as a 32-bit unsigned value.
static PyObject* tensor__inplace_version(TensorObject* self,
                                         PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  const uint32_t version = self->tensor.current_inplace_version();
  return ToPyObject(version);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1699 1700
// Python method `element_size()`.
// Returns phi::SizeOf() of the tensor's dtype, converted to a 32-bit unsigned
// value.
static PyObject* tensor_method_element_size(TensorObject* self,
                                            PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  const uint32_t bytes_per_element =
      static_cast<uint32_t>(phi::SizeOf(self->tensor.dtype()));
  return ToPyObject(bytes_per_element);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1709 1710 1711 1712 1713
// Python method `_bump_inplace_version()`.
// Calls bump_inplace_version() on the wrapped tensor and returns None.
static PyObject* tensor__bump_inplace_version(TensorObject* self,
                                              PyObject* args,
                                              PyObject* kwargs) {
  EAGER_TRY
  auto& wrapped = self->tensor;
  wrapped.bump_inplace_version();
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1718 1719 1720 1721
// Python method `is_selected_rows()`.
// True only when the tensor is defined and is a SelectedRows; an undefined
// tensor yields False.
static PyObject* tensor_method_is_selected_rows(TensorObject* self,
                                                PyObject* args,
                                                PyObject* kwargs) {
  EAGER_TRY
  const bool answer =
      self->tensor.defined() && self->tensor.is_selected_rows();
  return ToPyObject(answer);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1729 1730
// Python method `rows()`.
// Only valid for SelectedRows; returns rows() of the underlying SelectedRows
// converted to a Python object.
static PyObject* tensor_method_get_rows(TensorObject* self,
                                        PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE(self->tensor.is_selected_rows(),
                 paddle::platform::errors::Fatal(
                     "this method is only effective for SelectedRows"));
  auto rows_holder =
      std::dynamic_pointer_cast<phi::SelectedRows>(self->tensor.impl());
  return ToPyObject(rows_holder->rows());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1742 1743 1744 1745 1746 1747 1748 1749 1750 1751
// Python method `_reset_grad_inplace_version([set_to_zero])`.
// `set_to_zero` defaults to true and is read from the first positional
// argument only when exactly one argument is passed. The gradient's inplace
// version is reset only if the gradient exists, is defined, is a dense tensor
// and is initialized. Returns None.
static PyObject* tensor__reset_grad_inplace_version(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  const bool set_to_zero =
      (PyTuple_Size(args) == static_cast<Py_ssize_t>(1))
          ? CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0)
          : true;

  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool resettable = grad_slot != nullptr && grad_slot->defined() &&
                          grad_slot->is_dense_tensor() &&
                          grad_slot->initialized();
  if (resettable) {
    grad_slot->reset_inplace_version(set_to_zero);
  }
  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1762 1763
// Python method `_share_memory()`.
// On non-Windows builds: requires a CPU DenseTensor, copies its data into a
// newly allocated memory-mapped writer allocation, records the allocation's
// ipc name in MemoryMapFdSet, swaps the tensor's holder for that allocation
// and returns the dense tensor. On Windows it always raises PermissionDenied.
static PyObject* tensor_method__share_memory(TensorObject* self,
                                             PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
#ifndef _WIN32
  PADDLE_ENFORCE_EQ(platform::is_cpu_place(self->tensor.place()),
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support CPU Tensor currently"));
  // 1. get the DenseTensor; the cast yields null for any other tensor kind,
  //    so reject that explicitly instead of dereferencing a null pointer.
  auto dense_holder =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  PADDLE_ENFORCE_EQ(dense_holder != nullptr,
                    true,
                    platform::errors::InvalidArgument(
                        "Sharing memory only support DenseTensor currently"));
  auto* t = dense_holder.get();
  // 2. allocate shared memory
  void* data_ptr = t->data();
  size_t data_size =
      t->numel() *
      framework::SizeOfType(framework::TransToProtoVarType(t->dtype()));
  auto shared_writer_holder =
      memory::allocation::AllocateMemoryMapWriterAllocation(data_size);
  // 3. maintain mmap fd set & backup ipc_name
  const std::string& ipc_name = shared_writer_holder->ipc_name();
  memory::allocation::MemoryMapFdSet::Instance().Insert(ipc_name);
  // 4. copy data & reset holder
  memory::Copy(platform::CPUPlace(),
               shared_writer_holder->ptr(),
               platform::CPUPlace(),
               data_ptr,
               data_size);
  t->ResetHolder(shared_writer_holder);
  return ToPyObject(t);
#else
  PADDLE_THROW(platform::errors::PermissionDenied(
      "Sharing memory in Windows OS is not supported currently"));
  RETURN_PY_NONE

#endif
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1801 1802
// Python method `_offset()`.
// Returns offset() of the underlying DenseTensor. Raises InvalidArgument when
// the tensor is not a DenseTensor or the DenseTensor is not initialized.
static PyObject* tensor__offset(TensorObject* self,
                                PyObject* args,
                                PyObject* kwargs) {
  EAGER_TRY
  auto t = std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  // The cast yields null when the impl is absent or is not a DenseTensor;
  // fail with a Python-visible error instead of dereferencing null below.
  PADDLE_ENFORCE_EQ(
      t != nullptr,
      true,
      platform::errors::InvalidArgument("Tensor %s is not a DenseTensor!",
                                        self->tensor.name()));
  PADDLE_ENFORCE_EQ(
      t->IsInitialized(),
      true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  return ToPyObject(t->offset());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1816 1817
// Python method `_grad_name()`.
// Returns the name of this tensor's gradient; raises InvalidArgument when the
// gradient pointer is null.
static PyObject* tensor__grad_name(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool grad_present = grad_slot != nullptr;
  PADDLE_ENFORCE_EQ(grad_present,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));
  return ToPyObject(grad_slot->name());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1830 1831
// Python method `_grad_value()`.
// Raises InvalidArgument when the gradient pointer is null. Returns None when
// the gradient is undefined, the gradient's DenseTensor when it is dense, and
// raises Fatal for any other gradient kind.
static PyObject* tensor__grad_value(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool grad_present = grad_slot != nullptr;
  PADDLE_ENFORCE_EQ(grad_present,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  if (!grad_slot->defined()) {
    RETURN_PY_NONE
  }
  if (!grad_slot->is_dense_tensor()) {
    PADDLE_THROW(paddle::platform::errors::Fatal(
        "this method is only supported for DenseTensor"));
  }
  auto* dense_grad = static_cast<phi::DenseTensor*>(grad_slot->impl().get());
  return ToPyObject(dense_grad);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1855 1856
// Python method `_unset_fake_empty()`.
// Raises InvalidArgument when the gradient pointer is null. For a leaf tensor
// it calls SetFakeEmpty(false) on the tensor's grad node, treated as a
// GradNodeAccumulation. Returns None.
static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                          PyObject* args,
                                          PyObject* kwargs) {
  EAGER_TRY
  paddle::Tensor* grad_slot = egr::EagerUtils::mutable_grad(self->tensor);
  const bool grad_present = grad_slot != nullptr;
  PADDLE_ENFORCE_EQ(grad_present,
                    true,
                    platform::errors::InvalidArgument(
                        "Detected NULL grad. Please check if you have manually "
                        "cleared the grad inside autograd_meta"));

  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    auto accumulation_node =
        std::static_pointer_cast<egr::GradNodeAccumulation>(
            egr::EagerUtils::grad_node(self->tensor));
    accumulation_node->SetFakeEmpty(false);
  }
  RETURN_PY_NONE
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1876 1877 1878 1879 1880
// Python method `data_ptr()`.
// For an initialized dense tensor returns the address of its data buffer as a
// 64-bit integer; otherwise returns None.
static PyObject* tensor_data_ptr(TensorObject* self,
                                 PyObject* args,
                                 PyObject* kwargs) {
  EAGER_TRY
  const bool has_dense_storage =
      self->tensor.initialized() && self->tensor.is_dense_tensor();
  if (!has_dense_storage) {
    RETURN_PY_NONE
  }
  auto dense =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  return ToPyObject(reinterpret_cast<int64_t>(dense->data()));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

W
wanghuancoder 已提交
1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904
// Python method `_grad_ivar()`.
// Returns the gradient tensor when the tensor has autograd meta and that
// meta's gradient is initialized; otherwise returns None.
static PyObject* tensor__grad_ivar(TensorObject* self,
                                   PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Get grad for tensor: " << self->tensor.name();
  auto meta = egr::EagerUtils::nullable_autograd_meta(self->tensor);
  // `meta` may be null (it is null-checked before use below), so compute the
  // flag once with the null check in place; the log statement must not
  // dereference `meta` unconditionally.
  const bool grad_initialized = meta && meta->Grad().initialized();
  VLOG(6) << meta << " initialized: " << grad_initialized;
  if (grad_initialized) {
    return ToPyObject(meta->Grad());
  } else {
    RETURN_PY_NONE
  }
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

1905
#if defined(PADDLE_WITH_CUDA)
// Python method `_tensor_uva(device_id)` (CUDA builds only).
// Requires a dense tensor placed on CPU, then forwards the underlying
// DenseTensor and the device id to tensor_uva(). Returns None.
static PyObject* tensor_method__uva(TensorObject* self,
                                    PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Running in tensor_method__uva.";
  const bool is_dense = self->tensor.is_dense_tensor();
  PADDLE_ENFORCE_EQ(is_dense,
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "DenseTensor currently."));
  const bool on_cpu = platform::is_cpu_place(self->tensor.place());
  PADDLE_ENFORCE_EQ(on_cpu,
                    true,
                    platform::errors::InvalidArgument(
                        "Unified virtual addressing only support "
                        "CPU Tensor currently."));
  const int device_id = static_cast<int>(
      pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0));
  auto* dense = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  tensor_uva(dense, device_id);

  RETURN_PY_NONE

  EAGER_CATCH_AND_THROW_RETURN_NULL
}
#endif
J
Jack Zhou 已提交
1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941
// Python method `_is_string_tensor_hold_allocation()`.
// True only when the tensor's impl is a phi::StringTensor that reports
// initialized(); False for every other impl (including none).
static PyObject* tensor_method__is_string_tensor_hold_allocation(
    TensorObject* self, PyObject* args, PyObject* kwargs) {
  EAGER_TRY
  auto as_string_tensor =
      std::dynamic_pointer_cast<phi::StringTensor>(self->tensor.impl());
  // A failed cast leaves a null pointer; short-circuit before touching it.
  const bool holds_allocation =
      as_string_tensor != nullptr && as_string_tensor->initialized();
  return ToPyObject(holds_allocation);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
1942

1943
// Method table for the eager Tensor Python type. Every entry maps a Python
// method name to its C++ implementation, is registered with
// METH_VARARGS | METH_KEYWORDS and carries no docstring. The all-null entry
// at the end terminates the table.
PyMethodDef variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_dense_tensor_hold_allocation",
     (PyCFunction)(void (*)(
         void))tensor_method__is_dense_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_to",
     (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"copy_",
     (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clone",
     (PyCFunction)(void (*)(void))tensor_method_clone,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"reconstruct_from_",
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"retain_grads",
     (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clear_gradient",
     (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_dense",
     (PyCFunction)(void (*)(void))tensor_method_is_dense,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_zero_grads",
     (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_buffer_to",
     (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_buffer_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_underline_tensor_to",
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_shared_underline_tensor_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"detach",
     (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_tensor_from_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method__get_tensor_from_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_getitem_from_offset",
     (PyCFunction)(void (*)(void))tensor__getitem_from_offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"__setitem_eager_tensor__",
     (PyCFunction)(void (*)(void))tensor_method__setitem_eager_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_remove_grad_hook",
     (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_set_grad_type",
     (PyCFunction)(void (*)(void))tensor__set_grad_type,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_clear",
     (PyCFunction)(void (*)(void))tensor__clear,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_copy_gradient_from",
     (PyCFunction)(void (*)(void))tensor__copy_gradient_from,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_tensor_use_gpudnn",
     (PyCFunction)(void (*)(void))tensor__use_gpudnn,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Methods kept to adapt the old dygraph; to be removed in the future.
    {"set_string_list",
     (PyCFunction)(void (*)(void))tensor_method_set_string_list,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_vocab",
     (PyCFunction)(void (*)(void))tensor_method_set_vocab,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_map_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_map_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // ---- Sparse tensor methods ----
    {"nnz",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_nums,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"indices",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_indices,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"values",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_elements,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"crows",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_crows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"cols",
     (PyCFunction)(void (*)(void))tensor_method_get_non_zero_cols,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_coo",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_coo,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_is_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_same_shape",
     (PyCFunction)(void (*)(void))tensor_method_is_same_shape,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"to_sparse_csr",
     (PyCFunction)(void (*)(void))tensor_method_to_sparse_csr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"element_size",
     (PyCFunction)(void (*)(void))tensor_method_element_size,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // ---- Sparse tensor methods (second marker kept from the original;
    //      presumably the end of the sparse section) ----
    {"_inplace_version",
     (PyCFunction)(void (*)(void))tensor__inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_bump_inplace_version",
     (PyCFunction)(void (*)(void))tensor__bump_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_selected_rows",
     (PyCFunction)(void (*)(void))tensor_method_is_selected_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"rows",
     (PyCFunction)(void (*)(void))tensor_method_get_rows,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_reset_grad_inplace_version",
     (PyCFunction)(void (*)(void))tensor__reset_grad_inplace_version,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_share_memory",
     (PyCFunction)(void (*)(void))tensor_method__share_memory,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_offset",
     (PyCFunction)(void (*)(void))tensor__offset,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_name",
     (PyCFunction)(void (*)(void))tensor__grad_name,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_value",
     (PyCFunction)(void (*)(void))tensor__grad_value,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_unset_fake_empty",
     (PyCFunction)(void (*)(void))tensor__unset_fake_empty,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"data_ptr",
     (PyCFunction)(void (*)(void))tensor_data_ptr,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_grad_ivar",
     (PyCFunction)(void (*)(void))tensor__grad_ivar,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#if defined(PADDLE_WITH_CUDA)
    {"_tensor_uva",
     (PyCFunction)(void (*)(void))tensor_method__uva,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
#endif
    {nullptr, nullptr, 0, nullptr}};

J
Jack Zhou 已提交
2176 2177 2178 2179
// Method table for core.eager.StringTensor. Same layout as variable_methods:
// name, implementation, METH_VARARGS | METH_KEYWORDS, no docstring, and an
// all-null terminating entry.
PyMethodDef string_tensor_variable_methods[] = {
    {"numpy",
     (PyCFunction)(void (*)(void))tensor_method_numpy_for_string_tensor,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_string_tensor_hold_allocation",
     (PyCFunction)(void (*)(
         void))tensor_method__is_string_tensor_hold_allocation,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // TODO(zhoushunjie): Need to add _copy_to, copy_ for StringTensor.
    {nullptr, nullptr, 0, nullptr}};

2194 2195
}  // namespace pybind
}  // namespace paddle