diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt
index e5cfe838c54f3434c15e695340e88d7a720e8c12..7d99a80eaeea879e912f88a0d55c6d072aab0d27 100644
--- a/paddle/fluid/eager/CMakeLists.txt
+++ b/paddle/fluid/eager/CMakeLists.txt
@@ -14,7 +14,8 @@ set(eager_deps
   grad_node_info
   grad_tensor_holder
   accumulation_node
-  custom_operator_node)
+  custom_operator_node
+  python)
 
 set(fluid_deps
   tracer
@@ -77,6 +78,10 @@ cc_library(
   autograd_meta
   hook_utils)
 
+cc_library(
+  saved_tensors_hooks
+  SRCS saved_tensors_hooks.cc
+  DEPS hook_utils)
 if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
   add_subdirectory(tests)
 endif()
diff --git a/paddle/fluid/eager/hooks.h b/paddle/fluid/eager/hooks.h
index 064c96bff380b865464f78f1c646fc9c698dcc49..f501c4acc621032fb3f9941ae29f19d571cd72cb 100644
--- a/paddle/fluid/eager/hooks.h
+++ b/paddle/fluid/eager/hooks.h
@@ -62,4 +62,18 @@ class CppVoidHook : public VoidHook {
   std::function<void()> fn_;
 };
 
+class PackHookBase {
+ public:
+  virtual ~PackHookBase() = default;
+  virtual void* operator()(const paddle::experimental::Tensor& tensor) = 0;
+  virtual void* operator()(void* py_tensor) = 0;
+};
+
+class UnPackHookBase {
+ public:
+  virtual ~UnPackHookBase() = default;
+  virtual paddle::experimental::Tensor operator()(void* packed_value) = 0;
+  virtual void* operator()(void* packed_value, void* other) = 0;
+};
+
 }  // namespace egr
diff --git a/paddle/fluid/eager/saved_tensors_hooks.cc b/paddle/fluid/eager/saved_tensors_hooks.cc
new file mode 100644
index 0000000000000000000000000000000000000000..6bd62c21611c0c4fcea842b4fc8315b0a281511f
--- /dev/null
+++ b/paddle/fluid/eager/saved_tensors_hooks.cc
@@ -0,0 +1,107 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/eager/saved_tensors_hooks.h"
+#include "paddle/fluid/eager/api/utils/global_utils.h"
+
+#if !(defined(PADDLE_NO_PYTHON) && defined(PADDLE_ON_INFERENCE))
+#include "paddle/fluid/pybind/eager.h"
+#include "paddle/fluid/pybind/eager_utils.h"
+#endif
+
+namespace egr {
+#if !(defined(PADDLE_NO_PYTHON) && defined(PADDLE_ON_INFERENCE))
+PackHook::PackHook(PyObject* hook) : hook_(hook) { Py_INCREF(hook_); }
+
+PackHook::~PackHook() {
+  ::pybind11::gil_scoped_acquire gil;
+  Py_DECREF(hook_);
+}
+
+void* PackHook::operator()(const paddle::experimental::Tensor& tensor) {
+  bool grad_tmp = egr::Controller::Instance().HasGrad();
+  egr::Controller::Instance().SetHasGrad(false);
+  ::pybind11::gil_scoped_acquire gil;
+  auto args = PyTuple_New(1);
+  PyTuple_SET_ITEM(args, 0, paddle::pybind::ToPyObject(tensor));
+  PyObject* ret = PyObject_Call(hook_, args, nullptr);
+  Py_XDECREF(args);
+  egr::Controller::Instance().SetHasGrad(grad_tmp);
+  return reinterpret_cast<void*>(ret);
+}
+
+void* PackHook::operator()(void* py_tensor) {
+  bool grad_tmp = egr::Controller::Instance().HasGrad();
+  egr::Controller::Instance().SetHasGrad(false);
+  ::pybind11::gil_scoped_acquire gil;
+  auto args = PyTuple_New(1);
+  Py_INCREF(reinterpret_cast<PyObject*>(py_tensor));
+  PyTuple_SET_ITEM(args, 0, reinterpret_cast<PyObject*>(py_tensor));
+  PyObject* ret = PyObject_Call(hook_, args, nullptr);
+  Py_XDECREF(args);
+  egr::Controller::Instance().SetHasGrad(grad_tmp);
+  return reinterpret_cast<void*>(ret);
+}
+
+UnPackHook::UnPackHook(PyObject* hook) : hook_(hook) { Py_INCREF(hook_); }
+
+UnPackHook::~UnPackHook() {
+  ::pybind11::gil_scoped_acquire gil;
+  Py_DECREF(hook_);
+}
+
+paddle::experimental::Tensor UnPackHook::operator()(void* packed_value) {
+  bool grad_tmp = egr::Controller::Instance().HasGrad();
+  egr::Controller::Instance().SetHasGrad(false);
+  ::pybind11::gil_scoped_acquire gil;
+  auto args = PyTuple_New(1);
+  Py_INCREF(reinterpret_cast<PyObject*>(packed_value));
+  PyTuple_SET_ITEM(args, 0, reinterpret_cast<PyObject*>(packed_value));
+  PyObject* ret = PyObject_Call(hook_, args, nullptr);
+  Py_XDECREF(args);
+  egr::Controller::Instance().SetHasGrad(grad_tmp);
+
+  PADDLE_ENFORCE_EQ(paddle::pybind::IsEagerTensor(ret),
+                    true,
+                    paddle::platform::errors::InvalidArgument(
+                        "paddle.autograd.saved_tensors_hooks: unpack_hook "
+                        "must return a Tensor."));
+
+  auto tensor = reinterpret_cast<paddle::pybind::TensorObject*>(ret)->tensor;
+  Py_XDECREF(ret);
+  return tensor;
+}
+
+void* UnPackHook::operator()(void* packed_value, void* other) {
+  bool grad_tmp = egr::Controller::Instance().HasGrad();
+  egr::Controller::Instance().SetHasGrad(false);
+  ::pybind11::gil_scoped_acquire gil;
+  auto args = PyTuple_New(1);
+  Py_INCREF(reinterpret_cast<PyObject*>(packed_value));
+  PyTuple_SET_ITEM(args, 0, reinterpret_cast<PyObject*>(packed_value));
+  PyObject* ret = PyObject_Call(hook_, args, nullptr);
+  Py_XDECREF(args);
+  egr::Controller::Instance().SetHasGrad(grad_tmp);
+
+  PADDLE_ENFORCE_EQ(paddle::pybind::IsEagerTensor(ret),
+                    true,
+                    paddle::platform::errors::InvalidArgument(
+                        "paddle.autograd.saved_tensors_hooks: unpack_hook "
+                        "must return a Tensor."));
+
+  return reinterpret_cast<void*>(ret);
+}
+#endif
+
+}  // namespace egr
diff --git a/paddle/fluid/eager/saved_tensors_hooks.h b/paddle/fluid/eager/saved_tensors_hooks.h
new file mode 100644
index 0000000000000000000000000000000000000000..1deb30daaa8e1ffaf4a5ac7c337a825cd344be30
--- /dev/null
+++ b/paddle/fluid/eager/saved_tensors_hooks.h
@@ -0,0 +1,97 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <Python.h>
+#include "paddle/fluid/eager/api/utils/global_utils.h"
+#include "paddle/fluid/eager/hooks.h"
+#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/core/errors.h"
+
+namespace egr {
+#if !(defined(PADDLE_NO_PYTHON) && defined(PADDLE_ON_INFERENCE))
+class PackHook : public PackHookBase {
+ public:
+  explicit PackHook(PyObject* hook);
+
+  ~PackHook();
+
+  void* operator()(const paddle::experimental::Tensor& tensor) override;
+
+  void* operator()(void* py_tensor) override;
+
+ private:
+  PyObject* hook_;
+};
+
+class UnPackHook : public UnPackHookBase {
+ public:
+  explicit UnPackHook(PyObject* hook);
+
+  ~UnPackHook();
+
+  paddle::experimental::Tensor operator()(void* packed_value) override;
+
+  void* operator()(void* packed_value, void* other) override;
+
+ private:
+  PyObject* hook_;
+};
+#endif
+
+class SavedTensorsHooks {
+ public:
+  SavedTensorsHooks() = default;
+
+  ~SavedTensorsHooks() {}
+
+  void SetHooks(PyObject* pack_hook, PyObject* unpack_hook) {
+#if !(defined(PADDLE_NO_PYTHON) && defined(PADDLE_ON_INFERENCE))
+    PADDLE_ENFORCE_EQ(pack_hook_ == nullptr && unpack_hook_ == nullptr,
+                      true,
+                      paddle::platform::errors::InvalidArgument(
+                          "paddle.autograd.saved_tensors_hooks only one pair "
+                          "of hooks is allowed at a time."));
+    pack_hook_ = std::make_shared<PackHook>(pack_hook);
+    unpack_hook_ = std::make_shared<UnPackHook>(unpack_hook);
+    is_enable_ = true;
+#endif
+  }
+
+  void ResetHooks() {
+#if !(defined(PADDLE_NO_PYTHON) && defined(PADDLE_ON_INFERENCE))
+    pack_hook_ = nullptr;
+    unpack_hook_ = nullptr;
+    is_enable_ = false;
+#endif
+  }
+
+  bool IsEnable() { return is_enable_; }
+
+  std::shared_ptr<PackHookBase> GetPackHook() { return pack_hook_; }
+  std::shared_ptr<UnPackHookBase> GetUnPackHook() { return unpack_hook_; }
+
+  static SavedTensorsHooks& GetInstance() {
+    static SavedTensorsHooks instance;
+    return instance;
+  }
+
+ private:
+  std::shared_ptr<PackHookBase> pack_hook_;
+  std::shared_ptr<UnPackHookBase> unpack_hook_;
+  bool is_enable_{false};
+};
+
+}  // namespace egr
diff --git a/paddle/fluid/eager/tensor_wrapper.h b/paddle/fluid/eager/tensor_wrapper.h
index e7994e388298d5c7b117b3a02f65bc7e98ed98ec..35a4c83257f6a7617a1aa48c58b20e9f62fa43dd 100644
--- a/paddle/fluid/eager/tensor_wrapper.h
+++ b/paddle/fluid/eager/tensor_wrapper.h
@@ -27,6 +27,7 @@
 #pragma once
 #include "paddle/fluid/eager/autograd_meta.h"
 #include "paddle/fluid/eager/grad_node_info.h"
+#include "paddle/fluid/eager/saved_tensors_hooks.h"
 #include "paddle/fluid/eager/utils.h"
 #include "paddle/phi/api/lib/utils/allocator.h"
 
@@ -69,7 +70,20 @@ class TensorWrapper {
             "Unrecognized tensor type for no_need_buffer feature"));
       }
     } else {
-      intermidiate_tensor_.set_impl(tensor.impl());
+      if (SavedTensorsHooks::GetInstance().IsEnable() &&
+          tensor.is_dense_tensor()) {
+        phi::DenseTensor* dense_tensor =
+            static_cast<phi::DenseTensor*>(tensor.impl().get());
+        intermidiate_tensor_.set_impl(
+            std::move(std::make_shared<phi::DenseTensor>(
+                std::make_shared<phi::Allocation>(nullptr, 0, tensor.place()),
+                dense_tensor->meta())));
+        auto pack_hook = SavedTensorsHooks::GetInstance().GetPackHook();
+        unpack_hook_ = SavedTensorsHooks::GetInstance().GetUnPackHook();
+        packed_value_ = reinterpret_cast<PyObject*>((*pack_hook)(tensor));
+      } else {
+        intermidiate_tensor_.set_impl(tensor.impl());
+      }
     }
 
     if (VLOG_IS_ON(7)) {
@@ -86,6 +100,29 @@
     }
   }
 
+  TensorWrapper(const TensorWrapper& other) {
+    no_need_buffer_ = other.no_need_buffer_;
+    intermidiate_tensor_ = other.intermidiate_tensor_;
+    weak_grad_node_ = other.weak_grad_node_;
+    inplace_version_snapshot_ = other.inplace_version_snapshot_;
+    packed_value_ = other.packed_value_;
+    unpack_hook_ = other.unpack_hook_;
+    Py_XINCREF(packed_value_);
+  }
+
+  TensorWrapper& operator=(const TensorWrapper& other) {
+    no_need_buffer_ = other.no_need_buffer_;
+    intermidiate_tensor_ = other.intermidiate_tensor_;
+    weak_grad_node_ = other.weak_grad_node_;
+    inplace_version_snapshot_ = other.inplace_version_snapshot_;
+    packed_value_ = other.packed_value_;
+    unpack_hook_ = other.unpack_hook_;
+    Py_XINCREF(packed_value_);
+    return *this;
+  }
+
+  ~TensorWrapper() { Py_XDECREF(packed_value_); }
+
   paddle::experimental::Tensor recover() {
     VLOG(6) << "Recover tensor: " << intermidiate_tensor_.name()
             << " for wrapper";
@@ -94,7 +131,16 @@
       return paddle::experimental::Tensor();
     }
 
-    check_inplace_version();
+    if (packed_value_ && unpack_hook_) {
+      auto tensor_unpacked =
+          (*unpack_hook_)(reinterpret_cast<void*>(packed_value_));
+      auto src_dense_tensor =
+          static_cast<phi::DenseTensor*>(tensor_unpacked.impl().get());
+      static_cast<phi::DenseTensor*>(intermidiate_tensor_.impl().get())
+          ->ResetHolder(src_dense_tensor->MoveMemoryHolder());
+    } else {
+      check_inplace_version();
+    }
 
     paddle::experimental::Tensor recovered_tensor = intermidiate_tensor_;
 
@@ -168,5 +214,7 @@
   paddle::experimental::Tensor intermidiate_tensor_;
   std::weak_ptr<egr::GradNodeBase> weak_grad_node_;
   uint32_t inplace_version_snapshot_ = 0;
+  PyObject* packed_value_{nullptr};
+  std::shared_ptr<UnPackHookBase> unpack_hook_;
 };
 }  // namespace egr
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 72885c0bbe5b7ece2dd62a721913d12e1739376f..280d83985def2ea32b1048c397d53f896adbd17d 100755
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -41,7 +41,8 @@ set(PYBIND_DEPS
   new_profiler
   auto_parallel
   jit_layer
-  jit_property)
+  jit_property
+  saved_tensors_hooks)
 
 if(WITH_PSCORE)
   set(PYBIND_DEPS ${PYBIND_DEPS} ps_service)
diff --git a/paddle/fluid/pybind/eager.h b/paddle/fluid/pybind/eager.h
index f617ead08e24329aa69bf38f5e55715caed0ba9a..8a4a42b82a253a47210ef8eacaeff79c62edad52 100644
--- a/paddle/fluid/pybind/eager.h
+++ b/paddle/fluid/pybind/eager.h
@@ -12,6 +12,7 @@ limitations under the License. */
 #include <Python.h>
 
+#include "paddle/fluid/eager/hooks.h"
 #include "paddle/fluid/eager/pylayer/py_layer_node.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "pybind11/pybind11.h"
 
@@ -28,6 +29,8 @@
 
 typedef struct {
   PyObject_HEAD PyObject* container;
+  bool container_be_packed;
+  std::shared_ptr<egr::UnPackHookBase> unpack_hook;
   PyObject* non_differentiable;
   PyObject* not_inplace_tensors;
   bool materialize_grads;
diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc
index 16a5cff031d65af64f0f69708ba1811ad008887d..956d8e5814cc0b7ef5c31b127479a38104df8d87 100644
--- a/paddle/fluid/pybind/eager_functions.cc
+++ b/paddle/fluid/pybind/eager_functions.cc
@@ -25,6 +25,7 @@ typedef SSIZE_T ssize_t;
 #include "paddle/fluid/eager/autograd_meta.h"
 #include "paddle/fluid/eager/backward.h"
 #include "paddle/fluid/eager/custom_operator/custom_operator_node.h"
+#include "paddle/fluid/eager/saved_tensors_hooks.h"
 #include "paddle/fluid/eager/utils.h"
 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/custom_operator.h"
@@ -591,6 +592,29 @@ static PyObject* eager_api_sparse_csr_tensor(PyObject* self,
   return ToPyObject(tensor);
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }
+
+static PyObject* eager_api_register_saved_tensors_hooks(PyObject* self,
+                                                        PyObject* args,
+                                                        PyObject* kwargs) {
+  EAGER_TRY
+  if (egr::Controller::Instance().HasGrad()) {
+    auto pack_hook = PyTuple_GET_ITEM(args, 0);
+    auto unpack_hook = PyTuple_GET_ITEM(args, 1);
+    egr::SavedTensorsHooks::GetInstance().SetHooks(pack_hook, unpack_hook);
+  }
+  RETURN_PY_NONE
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
+static PyObject* eager_api_reset_saved_tensors_hooks(PyObject* self,
+                                                     PyObject* args,
+                                                     PyObject* kwargs) {
+  EAGER_TRY
+  egr::SavedTensorsHooks::GetInstance().ResetHooks();
+  RETURN_PY_NONE
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 #if defined(PADDLE_WITH_CUDA)
 static PyObject* eager_api_async_read(PyObject* self,
                                       PyObject* args,
@@ -965,6 +989,14 @@ PyMethodDef variable_functions[] = {
      (PyCFunction)(void (*)(void))eager_api_sparse_csr_tensor,
      METH_VARARGS | METH_KEYWORDS,
      NULL},
+    {"register_saved_tensors_hooks",
+     (PyCFunction)(void (*)(void))eager_api_register_saved_tensors_hooks,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
+    {"reset_saved_tensors_hooks",
+     (PyCFunction)(void (*)(void))eager_api_reset_saved_tensors_hooks,
+     METH_VARARGS | METH_KEYWORDS,
+     NULL},
     /**sparse functions**/
 #if defined(PADDLE_WITH_CUDA)
     {"async_read",
diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc
index 7e25b06e80a4dd70b4fcbaa55021b199e6998580..f39dc6d74f4ebe254378c6cc14c545ffc2d4f85b 100644
--- a/paddle/fluid/pybind/eager_py_layer.cc
+++ b/paddle/fluid/pybind/eager_py_layer.cc
@@ -20,6 +20,7 @@ limitations under the License. */
*/ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/eager/pylayer/py_layer_node.h" +#include "paddle/fluid/eager/saved_tensors_hooks.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/memory/allocation/allocator.h" @@ -78,6 +79,7 @@ PyObject* PyLayerNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { if (obj) { auto v = reinterpret_cast(obj); v->materialize_grads = true; + v->container_be_packed = false; new (&v->grad_node) std::weak_ptr(); new (&v->forward_input_tensor_is_duplicable) std::vector(); new (&v->forward_output_tensor_is_duplicable) std::vector(); @@ -96,6 +98,7 @@ static void PyLayerDealloc(PyLayerObject* self) { Py_DECREF(self->not_inplace_tensors); } self->grad_node.~weak_ptr(); + self->unpack_hook = nullptr; self->forward_input_tensor_is_duplicable.~vector(); self->forward_output_tensor_is_duplicable.~vector(); Py_TYPE(self)->tp_free(reinterpret_cast(self)); @@ -455,23 +458,148 @@ PyObject* pylayer_method_apply(PyObject* cls, EAGER_CATCH_AND_THROW_RETURN_NULL } +PyObject* call_unpack_hook(PyLayerObject* self) { + auto unpack_hook = self->unpack_hook; + auto packed_value = self->container; + + auto packed_value_size = PyTuple_GET_SIZE(packed_value); + auto unpacked_value = PyTuple_New(packed_value_size); + + for (Py_ssize_t i = 0; i < packed_value_size; i++) { + PyObject* obj = PyTuple_GET_ITEM(packed_value, i); + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + auto tmp_list = PyList_New(len); + for (Py_ssize_t j = 0; j < len; j++) { + PyObject* o = PyList_GetItem(obj, j); + PyTuple_SET_ITEM(tmp_list, + j, + reinterpret_cast(((*unpack_hook)( + reinterpret_cast(o), nullptr)))); + } + PyTuple_SET_ITEM(unpacked_value, i, tmp_list); + } else if (PyTuple_Check(obj)) { + Py_ssize_t len = PyTuple_Size(obj); + auto tmp_tuple = PyTuple_New(len); + for (Py_ssize_t j = 0; j < len; j++) { + PyObject* o = PyTuple_GetItem(obj, j); + PyTuple_SET_ITEM(tmp_tuple, + j, + reinterpret_cast((*unpack_hook)( + reinterpret_cast(o), nullptr))); + } + PyTuple_SET_ITEM(unpacked_value, i, tmp_tuple); + } else { + PyTuple_SET_ITEM(unpacked_value, + i, + reinterpret_cast((*unpack_hook)( + reinterpret_cast(obj), nullptr))); + } + } + + return unpacked_value; +} + PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) { EAGER_TRY if (self->container == nullptr) { RETURN_PY_NONE; } - Py_INCREF(self->container); - return self->container; + if (self->container_be_packed) { + return call_unpack_hook(self); + } else { + Py_INCREF(self->container); + return self->container; + } EAGER_CATCH_AND_THROW_RETURN_NULL } +void call_pack_hook(PyLayerObject* self, PyObject* value) { + PyObject* saved_value = nullptr; + if (PyTuple_Check(value)) { + saved_value = value; + } else if (PyList_Check(value)) { + saved_value = PyList_AsTuple(value); + } else { + saved_value = PyTuple_New(1); + Py_INCREF(value); + PyTuple_SET_ITEM(saved_value, 0, value); + } + + auto pack_hook = egr::SavedTensorsHooks::GetInstance().GetPackHook(); + self->unpack_hook = egr::SavedTensorsHooks::GetInstance().GetUnPackHook(); + + auto saved_value_size = PyTuple_GET_SIZE(saved_value); + PyObject* packed_value = PyTuple_New(saved_value_size); + + for (Py_ssize_t i = 0; i < saved_value_size; i++) { + PyObject* obj = PyTuple_GET_ITEM(saved_value, i); + if (IsEagerTensor(obj)) { + PyTuple_SET_ITEM(packed_value, + i, + reinterpret_cast( + (*pack_hook)(reinterpret_cast(obj)))); + } else if 
+      Py_ssize_t len = PyList_Size(obj);
+      auto tmp_list = PyList_New(len);
+      for (Py_ssize_t j = 0; j < len; j++) {
+        PyObject* o = PyList_GetItem(obj, j);
+        if (IsEagerTensor(o)) {
+          PyList_SET_ITEM(tmp_list,
+                          j,
+                          reinterpret_cast<PyObject*>(
+                              (*pack_hook)(reinterpret_cast<void*>(o))));
+        } else {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "save_for_backward only supports Tensor, list of Tensor, or "
+              "tuple of Tensor."));
+        }
+      }
+      PyTuple_SET_ITEM(packed_value, i, tmp_list);
+    } else if (PyTuple_Check(obj)) {
+      Py_ssize_t len = PyTuple_Size(obj);
+      auto tmp_tuple = PyTuple_New(len);
+      for (Py_ssize_t j = 0; j < len; j++) {
+        PyObject* o = PyTuple_GetItem(obj, j);
+        if (IsEagerTensor(o)) {
+          PyTuple_SET_ITEM(tmp_tuple,
+                           j,
+                           reinterpret_cast<PyObject*>(
+                               (*pack_hook)(reinterpret_cast<void*>(o))));
+        } else {
+          PADDLE_THROW(platform::errors::InvalidArgument(
+              "save_for_backward only supports Tensor, list of Tensor, or "
+              "tuple of Tensor."));
+        }
+      }
+      PyTuple_SET_ITEM(packed_value, i, tmp_tuple);
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "save_for_backward only supports Tensor, list of Tensor, or "
+          "tuple of Tensor."));
+    }
+  }
+
+  if (!PyTuple_Check(value)) {
+    Py_XDECREF(saved_value);
+  }
+
+  Py_XDECREF(self->container);
+  self->container = packed_value;
+  self->container_be_packed = true;
+}
+
 int tensor_properties_set_container(PyLayerObject* self,
                                     PyObject* value,
                                     void* closure) {
   EAGER_TRY
-  Py_XINCREF(value);
-  Py_XDECREF(self->container);
-  self->container = value;
+  if (egr::SavedTensorsHooks::GetInstance().IsEnable()) {
+    call_pack_hook(self, value);
+  } else {
+    Py_XINCREF(value);
+    Py_XDECREF(self->container);
+    self->container = value;
+  }
   return 0;
   EAGER_CATCH_AND_THROW_RETURN_NEG
 }
diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py
index 8bc7b11368680e7f714e2b15a0033b0a21068fe2..70fc9647cd4898b94d9d415a618c3e700ccbc0ea 100644
--- a/python/paddle/autograd/__init__.py
+++ b/python/paddle/autograd/__init__.py
@@ -26,9 +26,11 @@ else:
     from .py_layer import LegacyPyLayerContext as PyLayerContext  # noqa: F401
 from ..framework import set_grad_enabled, is_grad_enabled  # noqa: F401
 from ..fluid.dygraph.base import no_grad_ as no_grad  # noqa: F401
+from .saved_tensors_hooks import saved_tensors_hooks
 
 __all__ = [  # noqa
     'backward',
     'PyLayer',
     'PyLayerContext',
+    'saved_tensors_hooks',
 ]
diff --git a/python/paddle/autograd/saved_tensors_hooks.py b/python/paddle/autograd/saved_tensors_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e0f292d68b49c9d482fd2f2fa8aca87a9bfd307
--- /dev/null
+++ b/python/paddle/autograd/saved_tensors_hooks.py
@@ -0,0 +1,111 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.fluid import core
+
+__all__ = []
+
+
+class saved_tensors_hooks():
+    """
+    Dynamic graph only. Registers a pair of pack / unpack hooks for saved tensors.
+
+    Parameters:
+        pack_hook (function): The pack hook will be called every time a forward
+            operation's input or output tensor needs to be saved for backward.
+            You can then move it to CPU memory or disk. The input of `pack_hook`
+            is the tensor to be saved. The output of `pack_hook` is the
+            information that will be stored in place of the original tensor.
+            `pack_hook` is also called for every tensor saved by
+            `PyLayerContext.save_for_backward`. If a tensor saved for backward
+            is marked as no-need-buffer, `pack_hook` will not be called.
+            `pack_hook` is only called when the tensor saved for backward is a
+            LoDTensor.
+        unpack_hook (function): The unpack hook will be called every time the
+            backward pass needs to use a saved tensor. You can then reload the
+            tensor and return it to the Paddle framework. The input of
+            `unpack_hook` is the information returned by the corresponding
+            `pack_hook`. The output of `unpack_hook` is a tensor rebuilt from
+            that information, and it must have the same contents as the
+            original tensor passed to `pack_hook`.
+
+    Returns:
+        None
+
+    Examples:
+        .. code-block:: python
+
+            # Example1
+            import paddle
+
+            def pack_hook(x):
+                print("Packing", x)
+                return x.numpy()
+
+            def unpack_hook(x):
+                print("UnPacking", x)
+                return paddle.to_tensor(x)
+
+            a = paddle.ones([3, 3])
+            b = paddle.ones([3, 3]) * 2
+            a.stop_gradient = False
+            b.stop_gradient = False
+            with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
+                y = paddle.multiply(a, b)
+            y.sum().backward()
+
+            # Example2
+            import paddle
+            from paddle.autograd import PyLayer
+
+            class cus_multiply(PyLayer):
+                @staticmethod
+                def forward(ctx, a, b):
+                    y = paddle.multiply(a, b)
+                    ctx.save_for_backward(a, b)
+                    return y
+
+                @staticmethod
+                def backward(ctx, dy):
+                    a, b = ctx.saved_tensor()
+                    grad_a = dy * b
+                    grad_b = dy * a
+                    return grad_a, grad_b
+
+            def pack_hook(x):
+                print("Packing", x)
+                return x.numpy()
+
+            def unpack_hook(x):
+                print("UnPacking", x)
+                return paddle.to_tensor(x)
+
+            a = paddle.ones([3, 3])
+            b = paddle.ones([3, 3]) * 2
+            a.stop_gradient = False
+            b.stop_gradient = False
+            with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
+                y = cus_multiply.apply(a, b)
+            y.sum().backward()
+    """
+
+    def __init__(self, pack_hook, unpack_hook):
+        self.pack_hook = pack_hook
+        self.unpack_hook = unpack_hook
+
+    def __enter__(self):
+        core.eager.register_saved_tensors_hooks(self.pack_hook,
+                                                self.unpack_hook)
+
+    def __exit__(self, *args):
+        core.eager.reset_saved_tensors_hooks()
diff --git a/python/paddle/fluid/tests/unittests/test_saved_tensors_hooks.py b/python/paddle/fluid/tests/unittests/test_saved_tensors_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d2791d3058997859745f3759ebf0d16a0faf5c6
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_saved_tensors_hooks.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle
+from paddle.autograd import PyLayer
+
+
+class TestSavedTensorsHooks(unittest.TestCase):
+
+    def test_save_for_multiply(self):
+
+        def pack_hook(x):
+            return x.numpy()
+
+        def unpack_hook(x):
+            return paddle.to_tensor(x)
+
+        a = paddle.ones([3, 3])
+        b = paddle.ones([3, 3]) * 2
+        a.stop_gradient = False
+        b.stop_gradient = False
+        with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
+            y = paddle.multiply(a, b)
+        y.sum().backward()
+
+        aa = paddle.ones([3, 3])
+        bb = paddle.ones([3, 3]) * 2
+        aa.stop_gradient = False
+        bb.stop_gradient = False
+        yy = paddle.multiply(aa, bb)
+        yy.sum().backward()
+
+        self.assertTrue(paddle.equal_all(aa.grad, a.grad))
+        self.assertTrue(paddle.equal_all(bb.grad, b.grad))
+
+    def test_save_for_pylayer(self):
+
+        class cus_multiply(PyLayer):
+
+            @staticmethod
+            def forward(ctx, a, b):
+                y = paddle.multiply(a, b)
+                ctx.save_for_backward(a, b)
+                return y
+
+            @staticmethod
+            def backward(ctx, dy):
+                a, b = ctx.saved_tensor()
+                grad_a = dy * b
+                grad_b = dy * a
+                return grad_a, grad_b
+
+        def pack_hook(x):
+            return x.numpy()
+
+        def unpack_hook(x):
+            return paddle.to_tensor(x)
+
+        a = paddle.ones([3, 3])
+        b = paddle.ones([3, 3]) * 2
+        a.stop_gradient = False
+        b.stop_gradient = False
+        with paddle.autograd.saved_tensors_hooks(pack_hook, unpack_hook):
+            y = cus_multiply.apply(a, b)
+        y.sum().backward()
+
+        aa = paddle.ones([3, 3])
+        bb = paddle.ones([3, 3]) * 2
+        aa.stop_gradient = False
+        bb.stop_gradient = False
+        yy = cus_multiply.apply(aa, bb)
+        yy.sum().backward()
+
+        self.assertTrue(paddle.equal_all(aa.grad, a.grad))
+        self.assertTrue(paddle.equal_all(bb.grad, b.grad))
+
+
+if __name__ == '__main__':
+    unittest.main()
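Beyond the numpy round-trip used in the docstring and unit test above, the typical use of these hooks is to trade transfer time for GPU memory by offloading saved activations to host memory during forward and restoring them during backward. A minimal sketch of that pattern, not part of the patch itself, assuming the existing `Tensor.cpu()` / `Tensor.cuda()` methods and the `paddle.autograd.saved_tensors_hooks` context manager added here:

# Usage sketch (not part of the patch): offload tensors saved for backward to
# CPU during forward, and bring them back to the GPU only when backward runs.
import paddle


def pack_to_cpu(x):
    # Called whenever a tensor is saved for backward; keep only a CPU copy.
    return x.cpu()


def unpack_from_cpu(x):
    # Called when backward needs the saved tensor; restore it to the GPU.
    return x.cuda()


if paddle.is_compiled_with_cuda():
    paddle.set_device('gpu')
    a = paddle.randn([1024, 1024])
    a.stop_gradient = False
    with paddle.autograd.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
        # The activation saved for relu's backward lives on the CPU
        # until backward actually needs it.
        y = paddle.nn.functional.relu(a)
    y.sum().backward()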