// eager_method.cc
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// disable numpy compile error
#include <Python.h>

#include <string>
#include <vector>

#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"

20
#include "paddle/fluid/eager/accumulation/accumulation_node.h"
21
#include "paddle/fluid/eager/api/all.h"
J
Jiabin Yang 已提交
22
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
23
#include "paddle/fluid/eager/autograd_meta.h"
24 25
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/hooks.h"
26
#include "paddle/fluid/eager/utils.h"
27
#include "paddle/fluid/framework/convert_utils.h"
28 29 30 31 32 33
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/fluid/pybind/exception.h"
J
Jiabin Yang 已提交
34
#include "paddle/fluid/pybind/slice_utils.h"
35 36 37 38
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
J
Jiabin Yang 已提交
39

40 41 42
namespace paddle {
namespace pybind {

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
namespace py = ::pybind11;

class PyTensorHook : public egr::TensorHook {
 public:
  explicit PyTensorHook(PyObject* func) : py_func_(func) {
    Py_INCREF(py_func_);
  }

  ~PyTensorHook() {
    py::gil_scoped_acquire gil;
    Py_DECREF(py_func_);
  }

  paddle::experimental::Tensor operator()(
      const paddle::experimental::Tensor& var) override {
    py::gil_scoped_acquire gil;
    VLOG(3) << "Call PyTensorHook for var " << var.name();

    PyObject* res = nullptr;
    try {
      res = PyObject_CallFunctionObjArgs(py_func_, ToPyObject(var), nullptr);
    } catch (platform::EnforceNotMet& e) {
      throw std::move(e);
    } catch (std::exception& e) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Hook function of Tensor raises an exception: %s.", e.what()));
    } catch (...) {
      PADDLE_THROW(platform::errors::Fatal(
          "Hook function of Tensor raises an unknown exception."));
    }

    PADDLE_ENFORCE_NOT_NULL(res,
                            platform::errors::Unavailable(
                                "Hook function of Tensor return a nullptr."));
    if (res == Py_None) {
      return var;
    }
    return reinterpret_cast<TensorObject*>(res)->tensor;
  }

 private:
  PyObject* py_func_;
};

class PyTensorVoidHook : public egr::TensorVoidHook {
 public:
  explicit PyTensorVoidHook(PyObject* func) : py_func_(func) {
    Py_INCREF(py_func_);
  }

  ~PyTensorVoidHook() {
    py::gil_scoped_acquire gil;
    Py_DECREF(py_func_);
  }

  void operator()() override {
    py::gil_scoped_acquire gil;
    VLOG(3) << "Call PyTensorVoidHook";

    try {
      PyObject_CallFunctionObjArgs(py_func_, nullptr);
    } catch (platform::EnforceNotMet& e) {
      throw std::move(e);
    } catch (std::exception& e) {
      PADDLE_THROW(platform::errors::Unavailable(
          "Hook function of Tensor raises an exception: %s.", e.what()));
    } catch (...) {
      PADDLE_THROW(platform::errors::Fatal(
          "Hook function of Tensor raises an unknown exception."));
    }
  }

 private:
  PyObject* py_func_;
};

119 120 121
extern void InitTensorWithNumpyValue(TensorObject* self,
                                     const pybind11::object& array,
                                     bool zero_copy);
122

123
extern PyTypeObject* p_tensor_type;
124

J
Jiabin Yang 已提交
125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
// Converts a Python eager Tensor used as a slice bound into a Py_ssize_t
// index via GetSliceIndexFromTensor.
//
// Throws InvalidArgument if `obj` is not an instance of the eager Tensor
// type, or if the tensor is uninitialized or not backed by a DenseTensor.
Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
  if (!PyObject_IsInstance(obj, reinterpret_cast<PyObject*>(p_tensor_type))) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "We should only get paddle::experimental::Tensor or VarBase in this "
        "method, when you reach this means we got another type index."));
  }
  VLOG(6) << "Call GetSliceIndexFromTensor in Eager";
  // Convert once and reuse; the tensor shares its impl with the Python
  // object.
  paddle::experimental::Tensor tensor = CastPyArg2Tensor(obj, 0);
  PADDLE_ENFORCE_EQ(
      tensor.initialized(), true,
      paddle::platform::errors::InvalidArgument(
          "We can only support initialized tensor in slice, however we got "
          "uninitialized tensor %s, please check your code.",
          tensor.name()));
  // The static_cast below is only valid for DenseTensor-backed tensors.
  PADDLE_ENFORCE_EQ(
      tensor.is_dense_tensor(), true,
      paddle::platform::errors::InvalidArgument(
          "We can only support DenseTensor in slice, however tensor %s is "
          "not a DenseTensor.",
          tensor.name()));
  return GetSliceIndexFromTensor(
      *static_cast<phi::DenseTensor*>(tensor.impl().get()));
}

// Reports whether `obj` is an instance of the eager Tensor Python type.
// NOTE(review): PyObject_IsInstance returns -1 on error, which is nonzero
// and therefore reported as true here.
bool PyCheckTensor(PyObject* obj) {
  PyObject* const tensor_type = reinterpret_cast<PyObject*>(p_tensor_type);
  const int is_instance = PyObject_IsInstance(obj, tensor_type);
  return is_instance != 0;
}

148 149 150
// Tensor.numpy(): returns a new numpy array holding a deep copy of the
// tensor's data, with C-contiguous (row-major) strides.
//
// Supported placements are CPU and, when built with PADDLE_WITH_CUDA, CUDA.
// The tensor must be initialized and backed by a phi::DenseTensor; otherwise
// InvalidArgument is raised.
static PyObject* tensor_method_numpy(TensorObject* self, PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.initialized(), true,
      platform::errors::InvalidArgument(
          "Tensor data of %s is Empty that indicates we have null tensor for "
          "now, please check if it has no data and initialize it first.",
          self->tensor.name()));

  // Validate the placement and storage type before allocating the numpy
  // array, so nothing is allocated for tensors we cannot copy from.
  bool place_supported = self->tensor.is_cpu();
#if defined(PADDLE_WITH_CUDA)
  place_supported = place_supported || self->tensor.is_cuda();
#endif
  if (!place_supported) {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Tensor.numpy() only support cpu tensor."));
  }
  auto dense_tensor =
      std::dynamic_pointer_cast<phi::DenseTensor>(self->tensor.impl());
  PADDLE_ENFORCE_NOT_NULL(
      dense_tensor.get(),
      platform::errors::InvalidArgument(
          "Tensor.numpy() only supports DenseTensor, but tensor %s is not a "
          "DenseTensor.",
          self->tensor.name()));

  auto tensor_dims = self->tensor.shape();
  auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
  auto sizeof_dtype = paddle::framework::DataTypeSize(self->tensor.type());
  Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];
  Py_intptr_t py_strides[paddle::framework::DDim::kMaxRank];
  size_t numel = 1;
  // Walk from the last dimension so each stride is the byte size of all
  // later dimensions (row-major layout).
  for (int i = static_cast<int>(tensor_dims.size()) - 1; i >= 0; --i) {
    py_dims[i] = static_cast<Py_intptr_t>(tensor_dims[i]);
    py_strides[i] = sizeof_dtype * numel;
    numel *= py_dims[i];
  }

  auto& api = pybind11::detail::npy_api::get();
  PyObject* array = api.PyArray_NewFromDescr_(
      api.PyArray_Type_, api.PyArray_DescrFromType_(numpy_dtype),
      tensor_dims.size(), py_dims, py_strides, nullptr,
      pybind11::detail::npy_api::NPY_ARRAY_ALIGNED_ |
          pybind11::detail::npy_api::NPY_ARRAY_WRITEABLE_,
      nullptr);
  if (array == nullptr) {
    // A NULL return from the CPython/numpy C API signals that a Python
    // exception is already set; propagate it.
    return nullptr;
  }
  // Own the new array until it is handed back. If a copy below throws, the
  // array is released instead of leaked.
  auto array_holder = pybind11::reinterpret_steal<pybind11::object>(array);
  void* array_data = pybind11::detail::array_proxy(array)->data;

  if (self->tensor.is_cpu()) {
    platform::CPUPlace place;
    // deep copy
    paddle::memory::Copy(place, array_data, place, dense_tensor->data(),
                         sizeof_dtype * numel);
#if defined(PADDLE_WITH_CUDA)
  } else if (self->tensor.is_cuda()) {
    paddle::platform::GpuMemcpySync(
        array_data, dense_tensor->data(),
        paddle::framework::DataTypeSize(dense_tensor->dtype()) *
            dense_tensor->numel(),
        cudaMemcpyDeviceToHost);
#endif
  }

  return array_holder.release().ptr();
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

206 207 208 209
// Tensor._is_initialized(): returns self->tensor.initialized() as a Python
// bool.
static PyObject* tensor_method__is_initialized(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  return ToPyObject(self->tensor.initialized());
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

214 215 216
// Tensor._copy_to(place, blocking): returns a copy of this tensor on the
// backend for `place`.
//
// The copy has stop_gradient set to true and inherits this tensor's
// persistable flag.
static PyObject* tensor_method__copy_to(TensorObject* self, PyObject* args,
                                        PyObject* kwargs) {
  EAGER_TRY
  auto place = CastPyArg2Place(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  auto cp_tensor =
      self->tensor.copy_to(phi::TransToPhiBackend(place), blocking);
  egr::EagerUtils::autograd_meta(&cp_tensor)->SetStopGradient(true);
  egr::EagerUtils::autograd_meta(&cp_tensor)
      ->SetPersistable(
          egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable());
  return ToPyObject(cp_tensor);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

229 230 231 232 233 234 235 236 237 238 239 240 241
// Tensor.cpu(): returns a blocking copy of this tensor on the CPU backend.
//
// The copy has stop_gradient set to true and inherits this tensor's
// persistable flag.
static PyObject* tensor_method_cpu(TensorObject* self, PyObject* args,
                                   PyObject* kwargs) {
  EAGER_TRY
  const auto cpu_backend = phi::TransToPhiBackend(phi::CPUPlace());
  auto cpu_copy = self->tensor.copy_to(cpu_backend, /*blocking=*/true);

  auto copy_meta = egr::EagerUtils::autograd_meta(&cpu_copy);
  copy_meta->SetStopGradient(true);
  const bool source_persistable =
      egr::EagerUtils::autograd_meta(&(self->tensor))->Persistable();
  copy_meta->SetPersistable(source_persistable);

  return ToPyObject(cpu_copy);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

242 243 244 245
// Tensor.reconstruct_from_(src): makes this tensor a copy of the Tensor
// object `src` (shallow Tensor assignment) while keeping this tensor's
// original name. Returns None.
static PyObject* tensor_method_reconstruct_from_(TensorObject* self,
                                                 PyObject* args,
                                                 PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor src_tensor =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  std::string orig_name = self->tensor.name();
  VLOG(6) << "Start Reconstructing Tensor from " << src_tensor.name()
          << " to " << orig_name;
  self->tensor = src_tensor;

  // Restore this tensor's own name; the assignment above copied src's name.
  self->tensor.set_name(orig_name);

  VLOG(6) << "Finished Reconstructing Tensor from " << src_tensor.name()
          << " to " << self->tensor.name();
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

263 264 265
// Tensor.copy_(src, blocking): copies `src` into this tensor at this tensor's
// inner place. Returns None.
//
// If this tensor is not yet defined, it first adopts src's stop_gradient and
// persistable flags.
static PyObject* tensor_method_copy_(TensorObject* self, PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor src_tensor =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  bool blocking = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 1), 1);
  VLOG(6) << "Start Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  if (!self->tensor.defined()) {
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetStopGradient(
            egr::EagerUtils::autograd_meta(&(src_tensor))->StopGradient());
    egr::EagerUtils::autograd_meta(&(self->tensor))
        ->SetPersistable(
            egr::EagerUtils::autograd_meta(&(src_tensor))->Persistable());
  }

  self->tensor.copy_(src_tensor, self->tensor.inner_place(), blocking);

  VLOG(6) << "Finish Copy Tensor " << src_tensor.name() << " to "
          << self->tensor.name();
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

289 290
// Tensor.retain_grads(): when egr::Controller::Instance().HasGrad() is true,
// ensures the tensor has a grad node and calls RetainGradForTensor on it.
// A GradNodeAccumulation is created if the node is missing. Returns None.
static PyObject* tensor_retain_grads(TensorObject* self, PyObject* args,
                                     PyObject* kwargs) {
  EAGER_TRY
  if (egr::Controller::Instance().HasGrad()) {
    auto meta = egr::EagerUtils::autograd_meta(&(self->tensor));
    if (!meta->GetMutableGradNode()) {
      VLOG(6) << "Make grad node of tensor: " << self->tensor.name()
              << " become accumulation node";
      meta->SetGradNode(std::make_shared<egr::GradNodeAccumulation>(meta));
    }
    egr::egr_utils_api::RetainGradForTensor(self->tensor);
  }
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

306 307
// Tensor.clear_gradient(set_to_zero=True): clears this tensor's gradient.
//
// - SelectedRows gradient: if its value is initialized, rows and value are
//   cleared.
// - DenseTensor gradient: if initialized, it is replaced by zeros_like(grad)
//   when `set_to_zero` is true; otherwise its memory holder is released.
// Returns None.
static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
                                       PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ClearGradient " << self->tensor.name();

  Py_ssize_t args_num = PyTuple_Size(args);
  bool set_to_zero = true;
  if (args_num == static_cast<Py_ssize_t>(1)) {
    // Honor the optional argument. It used to be parsed and then discarded,
    // so the gradient was always zero-filled.
    set_to_zero = CastPyArg2AttrBoolean(PyTuple_GET_ITEM(args, 0), 0);
  }

  paddle::experimental::Tensor* grad;
  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    grad = egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    grad = meta->MutableGrad();
  }

  // A gradient without an implementation has nothing to clear.
  if (grad->defined()) {
    if (grad->is_selected_rows()) {
      auto selected_rows =
          std::dynamic_pointer_cast<phi::SelectedRows>(grad->impl());
      if (selected_rows->mutable_value()->IsInitialized()) {
        selected_rows->mutable_rows()->clear();
        selected_rows->mutable_value()->clear();
      }
    } else if (grad->is_dense_tensor()) {
      if (grad->initialized()) {
        if (set_to_zero) {
          grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
        } else {
          VLOG(4) << "Gradient of " << self->tensor.name()
                  << " is initialized, will be released.";
          auto dense_tensor =
              std::dynamic_pointer_cast<phi::DenseTensor>(grad->impl());
          dense_tensor->MoveMemoryHolder();
        }
      }
    }
  }

  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

356 357
// Tensor._zero_grads(): if this tensor's gradient is initialized, replaces
// its implementation with zeros_like(grad). Returns None.
static PyObject* tensor__zero_grads(TensorObject* self, PyObject* args,
                                    PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "ZeroGrads " << self->tensor.name();

  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
    // Leaf tensor: the gradient is reached through EagerUtils::mutable_grad.
    paddle::experimental::Tensor* grad =
        egr::EagerUtils::mutable_grad(self->tensor);
    PADDLE_ENFORCE(grad != nullptr,
                   paddle::platform::errors::Fatal(
                       "Detected NULL grad. "
                       "Please check if you have manually cleared "
                       "the grad inside autograd_meta"));
    if (grad->initialized()) {
      grad->set_impl(paddle::experimental::zeros_like(*(grad)).impl());
    }
  } else {
    auto meta = egr::EagerUtils::unsafe_autograd_meta(self->tensor);
    if (meta->MutableGrad()->initialized()) {
      meta->MutableGrad()->set_impl(
          paddle::experimental::zeros_like(*(meta->MutableGrad())).impl());
    }
  }

  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

386 387 388
// Tensor._share_buffer_to(dst): makes `dst` share this tensor's data buffer
// and data type. This tensor must be initialized. If `dst` has no
// implementation yet, it is given an empty DenseTensor first. Returns None.
// NOTE(review): the argument is reinterpret_cast to TensorObject without a
// type check; callers are assumed to pass an eager Tensor.
static PyObject* tensor__share_buffer_to(TensorObject* self, PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor* dst_ptr =
      &(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  auto* src_tensor =
      static_cast<paddle::framework::Tensor*>(self->tensor.impl().get());
  // Without an impl, dst_ptr->impl().get() would be null and dereferenced
  // below.
  if (!dst_ptr->defined()) {
    dst_ptr->set_impl(std::make_shared<phi::DenseTensor>());
  }
  auto* dst_tensor =
      static_cast<paddle::framework::Tensor*>(dst_ptr->impl().get());
  dst_tensor->ShareDataWith(*src_tensor);
  dst_tensor->ShareDataTypeWith(*src_tensor);
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

407 408 409 410
// Tensor._is_shared_buffer_with(dst): returns whether dst's underlying
// tensor shares its buffer with this tensor's (IsSharedBufferWith).
// Returns False if either tensor is undefined. This tensor must be
// initialized.
static PyObject* tensor__is_shared_buffer_with(TensorObject* self,
                                               PyObject* args,
                                               PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor* dst_ptr =
      &(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  bool res = false;
  if (!self->tensor.defined() || !dst_ptr->defined()) {
    return ToPyObject(res);
  }
  auto* self_ptr =
      static_cast<paddle::framework::Tensor*>(self->tensor.impl().get());
  auto* dst_tensor =
      static_cast<paddle::framework::Tensor*>(dst_ptr->impl().get());
  res = dst_tensor->IsSharedBufferWith(*self_ptr);
  return ToPyObject(res);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

431 432 433 434
// Tensor._share_underline_tensor_to(dst): makes `dst` point at this tensor's
// implementation object (the impl is shared, not copied). This tensor must
// be initialized. Returns None.
static PyObject* tensor__share_underline_tensor_to(TensorObject* self,
                                                   PyObject* args,
                                                   PyObject* kwargs) {
  EAGER_TRY
  // The argument is the destination that receives self's impl.
  paddle::experimental::Tensor* dst_ptr =
      &(reinterpret_cast<TensorObject*>(PyTuple_GET_ITEM(args, 0))->tensor);
  PADDLE_ENFORCE_EQ(self->tensor.initialized(), true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        self->tensor.name()));
  dst_ptr->set_impl(self->tensor.impl());
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

448 449 450 451
// Tensor._is_shared_underline_tensor_with(src): returns whether this tensor
// and `src` point at the same implementation object. Returns False if either
// tensor is undefined. `src` must be initialized.
static PyObject* tensor__is_shared_underline_tensor_with(TensorObject* self,
                                                         PyObject* args,
                                                         PyObject* kwargs) {
  EAGER_TRY
  paddle::experimental::Tensor src_tensor =
      CastPyArg2Tensor(PyTuple_GET_ITEM(args, 0), 0);
  PADDLE_ENFORCE_EQ(src_tensor.initialized(), true,
                    platform::errors::InvalidArgument(
                        "Tensor %s has not been initialized! please initialize "
                        "src tensor before share_buffer_with to other.",
                        src_tensor.name()));
  bool res = false;
  if (!self->tensor.defined() || !src_tensor.defined()) {
    return ToPyObject(res);
  }
  res = (self->tensor.impl().get() == src_tensor.impl().get());
  return ToPyObject(res);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

468 469 470
// Tensor.detach(): returns a new Python Tensor object that shares this
// tensor's implementation. It gets a freshly generated unique name and its
// own autograd meta, which copies only the persistable flag from this tensor.
static PyObject* tensor_method_detach(TensorObject* self, PyObject* args,
                                      PyObject* kwargs) {
  EAGER_TRY
  PADDLE_ENFORCE_EQ(
      self->tensor.initialized(), true,
      platform::errors::InvalidArgument("Tensor %s has not been initialized!",
                                        self->tensor.name()));

  PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0);
  if (obj == nullptr) {
    PADDLE_THROW(platform::errors::Fatal(
        "tp_alloc return null, can not new a PyObject."));
  }

  auto v = reinterpret_cast<TensorObject*>(obj);
  // tp_alloc returns raw storage; construct the Tensor member in place.
  new (&(v->tensor)) paddle::experimental::Tensor();
  v->tensor.set_impl(self->tensor.impl());
  v->tensor.set_name(egr::Controller::Instance().GenerateUniqueName());
  auto autograd_meta_src = egr::EagerUtils::autograd_meta(&(self->tensor));
  auto autograd_meta = egr::EagerUtils::autograd_meta(&(v->tensor));
  autograd_meta->SetPersistable(autograd_meta_src->Persistable());

  return obj;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

494 495 496 497
// Tensor.get_tensor(): for a DenseTensor-backed tensor, returns its
// underlying LoDTensor converted via ToPyObject; otherwise returns None.
static PyObject* tensor_method_get_underline_tensor(TensorObject* self,
                                                    PyObject* args,
                                                    PyObject* kwargs) {
  EAGER_TRY
  if (self->tensor.is_dense_tensor()) {
    auto* tensor =
        static_cast<paddle::framework::LoDTensor*>(self->tensor.impl().get());
    VLOG(6) << "tensor: " << tensor->IsInitialized();
    return ToPyObject(tensor);
  }
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

J
Jiabin Yang 已提交
510 511 512
// Tensor._getitem_index_not_tensor(index): indexing with a non-Tensor index.
//
// ParseIndexingSlice decodes `index` into slice axes/starts/ends/strides,
// decrease axes, None axes and (for a list index) selected indices. The
// result is built from:
// - slice, or strided_slice when any stride != 1;
// - unsqueeze2 for None axes;
// - index_select along dim 0 for a list index.
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
                                                  PyObject* args,
                                                  PyObject* kwargs) {
  EAGER_TRY
  PyObject* _index = PyTuple_GET_ITEM(args, 0);
  VLOG(4) << "Call _getitem_index_not_tensor";
  std::vector<int> slice_axes, slice_starts, slice_ends, slice_strides,
      decrease_axis, none_axes, infer_flags, list_select_idxs;
  // if index is a list, list_select_flag will be true
  bool list_select_flag = false;
  PADDLE_ENFORCE_EQ(
      self->tensor.is_initialized(), true,
      platform::errors::InvalidArgument(
          "tensor %s has not been initialized, we can only slice initialized "
          "tensor please init it first with numpy or other tensor.",
          self->tensor.name()));
  auto tensor = static_cast<phi::DenseTensor*>(self->tensor.impl().get());
  ParseIndexingSlice(tensor, _index, &slice_axes, &slice_starts, &slice_ends,
                     &slice_strides, &decrease_axis, &none_axes, &infer_flags,
                     &list_select_idxs, &list_select_flag);

  auto out = slice_axes.empty() && !list_select_flag
                 ? self->tensor
                 : paddle::experimental::Tensor(
                       egr::Controller::Instance().GenerateUniqueName());

  if (!slice_axes.empty()) {
    framework::AttributeMap attrs = {{"axes", slice_axes},
                                     {"starts", slice_starts},
                                     {"ends", slice_ends},
                                     {"infer_flags", infer_flags},
                                     {"decrease_axis", decrease_axis}};
    std::string op_type = "slice";
    for (auto stride : slice_strides) {
      if (stride != 1) {
        op_type = "strided_slice";
        attrs.insert({"strides", slice_strides});
        attrs.erase("decrease_axis");
        break;
      }
    }
    if (op_type == "slice") {
      out = slice_dygraph_function(self->tensor, paddle::experimental::Tensor(),
                                   paddle::experimental::Tensor(),
                                   std::move(attrs));
    } else if (op_type == "strided_slice") {
      out = strided_slice_dygraph_function(self->tensor, attrs);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Slice is only support slice and strided_slice, but we got %s which "
          "is impossible, please check your code first or contact us by "
          "issue. ",
          op_type));
    }
  }

  if (!none_axes.empty()) {
    // Deal with cases when all axes are decreased.
    // After slice, the shape of out is [1], which should have been
    // [], but Paddle doesn't support scalar.
    // In order to ensure the correctness of the final shape of out,
    // one dimension of out needs to be decreased.
    // For example:
    // # x.shape: (2,3,4)
    // out = x[0, 1, 1, None] # out.shape : (1)
    if (static_cast<int>(decrease_axis.size()) == tensor->dims().size()) {
      none_axes.pop_back();
    }
    if (!none_axes.empty()) {
      // Deal with cases that decrease_axes is not empty
      // For example:
      // # x.shape: (2,3,4)
      // out = x[0, 0:2, None] # out.shape : (2, 1, 4)
      for (auto& axis : none_axes) {
        int len = 0;
        for (int da : decrease_axis) {
          if (da < axis) {
            len++;
          }
        }
        axis -= len;
      }

      paddle::experimental::Tensor new_out;
      framework::AttributeMap attrs = {{"axes", none_axes}};
      new_out = std::get<0>(unsqueeze2_dygraph_function(out, std::move(attrs)));
      return ToPyObject(new_out);
    }
  }

  // the index is a list
  if (list_select_flag) {
    auto select_index = paddle::experimental::Tensor(
        egr::Controller::Instance().GenerateUniqueName());
    auto idx_tensor = std::make_shared<phi::DenseTensor>();
    // Attach the index storage to `select_index`. Without this, index_select
    // received a tensor with no data while the indices went into an orphaned
    // DenseTensor.
    select_index.set_impl(idx_tensor);
    auto* dev_ctx = platform::DeviceContextPool::Instance().Get(
        egr::Controller::Instance().GetExpectedPlace());
    paddle::framework::TensorFromVector(list_select_idxs, *dev_ctx,
                                        idx_tensor.get());
    framework::AttributeMap attrs = {{"dim", 0}};
    out = index_select_dygraph_function(self->tensor, select_index,
                                        std::move(attrs));
  }

  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703
// Tensor._register_grad_hook(hook): wraps the Python callable `hook` in a
// PyTensorHook and registers it on this tensor's grad node, at the slot given
// by the tensor's OutRankInfo(). Returns the id reported by the grad node.
//
// For a leaf tensor the grad node must be non-null and is used as a
// GradNodeAccumulation.
static PyObject* tensor_register_grad_hook(TensorObject* self, PyObject* args,
                                           PyObject* kwargs) {
  EAGER_TRY
  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  VLOG(6) << (is_leaf ? "Register hook for leaf tensor: "
                      : "Register hook for non leaf tensor: ")
          << self->tensor.name();

  std::shared_ptr<egr::GradNodeBase> node =
      egr::EagerUtils::grad_node(self->tensor);
  if (is_leaf) {
    PADDLE_ENFORCE(
        node.get() != nullptr,
        paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                        "Leaf tensor should have had grad_node "
                                        "with type: GradNodeAccumulation."));
  }
  auto slot_info =
      egr::EagerUtils::unsafe_autograd_meta(self->tensor)->OutRankInfo();
  PyObject* py_hook = PyTuple_GET_ITEM(args, 0);
  auto hook = std::make_shared<PyTensorHook>(py_hook);

  int64_t hook_id;
  if (is_leaf) {
    auto accumulation_node =
        std::dynamic_pointer_cast<egr::GradNodeAccumulation>(node);
    hook_id = accumulation_node->RegisterGradientHook(
        slot_info.first, slot_info.second, std::move(hook));
  } else {
    hook_id = node->RegisterGradientHook(slot_info.first, slot_info.second,
                                         std::move(hook));
  }
  return ToPyObject(hook_id);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Tensor._remove_grad_hook(hook_id): removes the gradient hook with
// `hook_id` from this tensor's grad node and returns the node's result.
// Raises InvalidArgument if the tensor has no grad node.
static PyObject* tensor_remove_grad_hook(TensorObject* self, PyObject* args,
                                         PyObject* kwargs) {
  EAGER_TRY
  VLOG(6) << "Remove the registered hook for tensor: " << self->tensor.name();
  std::shared_ptr<egr::GradNodeBase> grad_node =
      egr::EagerUtils::grad_node(self->tensor);
  // grad_node() may return null; report it instead of dereferencing it.
  PADDLE_ENFORCE_NOT_NULL(
      grad_node.get(),
      platform::errors::InvalidArgument(
          "Tensor %s has no grad node, so no gradient hook can be removed "
          "from it.",
          self->tensor.name()));

  int64_t hook_id = pybind::CastPyArg2AttrLong(PyTuple_GET_ITEM(args, 0), 0);

  return ToPyObject(grad_node->RemoveGradientHook(hook_id));
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

// Tensor._register_backward_hook(hook): registers the Python callable `hook`
// (wrapped in a PyTensorVoidHook) as a reduce hook on this tensor's
// GradNodeAccumulation. Returns None.
//
// The tensor must be a leaf, must not stop gradient, and must have a grad
// node; otherwise an error is raised.
static PyObject* tensor_register_reduce_hook(TensorObject* self, PyObject* args,
                                             PyObject* kwargs) {
  EAGER_TRY
  VLOG(4) << "Register reduce hook for tensor: " << self->tensor.name();

  auto node = egr::EagerUtils::grad_node(self->tensor);

  const bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
  PADDLE_ENFORCE_EQ(is_leaf, true,
                    platform::errors::InvalidArgument(
                        "Only can register backward hook for leaf Tensor."));

  const bool needs_grad =
      !egr::EagerUtils::unsafe_autograd_meta(self->tensor)->StopGradient();
  PADDLE_ENFORCE_EQ(needs_grad, true,
                    platform::errors::InvalidArgument(
                        "Cannot register backward hook on a Tensor that stop "
                        "gradient."));

  PADDLE_ENFORCE(
      node.get() != nullptr,
      paddle::platform::errors::Fatal("Detected NULL grad_node,"
                                      "Leaf tensor should have had grad_node "
                                      "with type: GradNodeAccumulation."));

  auto reduce_hook =
      std::make_shared<PyTensorVoidHook>(PyTuple_GET_ITEM(args, 0));
  auto accumulation_node =
      std::dynamic_pointer_cast<egr::GradNodeAccumulation>(node);
  accumulation_node->RegisterReduceHook(std::move(reduce_hook));

  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

704 705 706 707 708 709 710 711 712 713 714 715 716 717 718
// Tensor._set_grad_type(var_type): resets this tensor's gradient
// implementation to a new empty phi::DenseTensor (LOD_TENSOR) or
// phi::SelectedRows (SELECTED_ROWS). Any other type leaves the gradient
// unchanged. Returns None.
static PyObject* set_grad_type(TensorObject* self, PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  auto var_type = pybind::CastPyArg2ProtoType(PyTuple_GET_ITEM(args, 0), 0);
  // Mutate the gradient stored in the autograd meta. The old
  // `auto grad_tensor = ...->Grad()` produced a copy, so set_impl never
  // reached the real gradient.
  paddle::experimental::Tensor* grad_tensor =
      egr::EagerUtils::unsafe_autograd_meta(self->tensor)->MutableGrad();
  if (var_type == framework::proto::VarType::LOD_TENSOR) {
    grad_tensor->set_impl(std::make_shared<phi::DenseTensor>());
  } else if (var_type == framework::proto::VarType::SELECTED_ROWS) {
    grad_tensor->set_impl(std::make_shared<phi::SelectedRows>());
  }
  // A Python method must return a new reference, so None needs an INCREF.
  Py_INCREF(Py_None);
  return Py_None;
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

719
// Method table for the eager Tensor Python type. Each entry maps a Python
// method name to its C implementation above. All entries accept
// positional and keyword arguments (METH_VARARGS | METH_KEYWORDS). The
// table is terminated by a NULL sentinel entry.
PyMethodDef variable_methods[] = {
    {"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_is_initialized",
     (PyCFunction)(void (*)(void))tensor_method__is_initialized,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_copy_to", (PyCFunction)(void (*)(void))tensor_method__copy_to,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"copy_", (PyCFunction)(void (*)(void))tensor_method_copy_,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"reconstruct_from_",
     (PyCFunction)(void (*)(void))tensor_method_reconstruct_from_,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"retain_grads", (PyCFunction)(void (*)(void))tensor_retain_grads,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"clear_gradient", (PyCFunction)(void (*)(void))tensor_clear_gradient,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_zero_grads", (PyCFunction)(void (*)(void))tensor__zero_grads,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_share_buffer_to", (PyCFunction)(void (*)(void))tensor__share_buffer_to,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_is_shared_buffer_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_buffer_with,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_share_underline_tensor_to",
     (PyCFunction)(void (*)(void))tensor__share_underline_tensor_to,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_is_shared_underline_tensor_with",
     (PyCFunction)(void (*)(void))tensor__is_shared_underline_tensor_with,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"detach", (PyCFunction)(void (*)(void))tensor_method_detach,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"get_tensor",
     (PyCFunction)(void (*)(void))tensor_method_get_underline_tensor,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_getitem_index_not_tensor",
     (PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_register_grad_hook",
     (PyCFunction)(void (*)(void))tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_remove_grad_hook", (PyCFunction)(void (*)(void))tensor_remove_grad_hook,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_register_backward_hook",
     (PyCFunction)(void (*)(void))tensor_register_reduce_hook,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {"_set_grad_type", (PyCFunction)(void (*)(void))set_grad_type,
     METH_VARARGS | METH_KEYWORDS, NULL},
    {NULL, NULL, 0, NULL}};

}  // namespace pybind
}  // namespace paddle