From 157c1a28d8fdb30c699604311730a0c409ffebf8 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 30 Mar 2022 09:36:15 +0800 Subject: [PATCH] [Eager] Pylayer (#39989) * Supported Complex2Real Conversion for Eager Dygraph * Supported Complex2Real Conversion for Eager Dygraph * Enabled complex type promotion test for matmul_v2 * pylayer, test=develop * Fix CI issues * Support initializing specific grad tensors to zero for selected operators * finish forward, test=develop * create grad node finish, test=develop * Merged adj_edges_ with GradSlotMeta * Fixed monir issue * backward finish, start dbg, test=develop * Adjusted num runs * Recovered Eager performance tests configurations * Recovered Eager performance tests configurations * finish, test=develop * polish, test=develop * polish, test=develop * refine, test=develop * eager, test=develop * Adjusted performance tests configurations * Fixed Minor Issues with performance tests * [Phi] Fix macro name typo * support set_materialize_grads, test=develop * suppotr mark_non_differentiable, test=develop * support once_differentiable, test=develop * refine, test=develop * refine, test=develop * Moved out Edge from GradSlotMeta * Fixed issues from merge * Fixed typo * Addressed review comments * Fixed merge issues * Fixed minor issues * Fixed minor issue * refine, test=develop * refine, test=develop * refine, test=develop * Fixed major issues and enabled auto_prune test cases * Fixed issues from merge * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop * refine, test=develop Co-authored-by: jim19930609 Co-authored-by: Aurelius84 --- paddle/fluid/eager/CMakeLists.txt | 5 +- paddle/fluid/eager/grad_node_info.h | 1 - paddle/fluid/eager/pylayer/CMakeLists.txt | 1 + paddle/fluid/eager/pylayer/py_layer_node.cc | 159 ++++++ paddle/fluid/eager/pylayer/py_layer_node.h | 82 +++ paddle/fluid/eager/utils.cc | 33 ++ paddle/fluid/eager/utils.h | 7 + paddle/fluid/pybind/CMakeLists.txt | 4 +- paddle/fluid/pybind/eager.cc | 1 + paddle/fluid/pybind/eager.h | 20 + paddle/fluid/pybind/eager_py_layer.cc | 497 ++++++++++++++++++ paddle/fluid/pybind/eager_utils.cc | 97 +++- paddle/fluid/pybind/eager_utils.h | 17 +- python/paddle/autograd/__init__.py | 2 +- python/paddle/autograd/py_layer.py | 269 ++++++++++ .../fluid/tests/unittests/test_pylayer_op.py | 204 +++++-- 16 files changed, 1352 insertions(+), 47 deletions(-) create mode 100644 paddle/fluid/eager/pylayer/CMakeLists.txt create mode 100644 paddle/fluid/eager/pylayer/py_layer_node.cc create mode 100644 paddle/fluid/eager/pylayer/py_layer_node.h create mode 100644 paddle/fluid/pybind/eager_py_layer.cc diff --git a/paddle/fluid/eager/CMakeLists.txt b/paddle/fluid/eager/CMakeLists.txt index 31e6e9f4b6..d8089bedf9 100644 --- a/paddle/fluid/eager/CMakeLists.txt +++ b/paddle/fluid/eager/CMakeLists.txt @@ -11,8 +11,9 @@ endif() add_subdirectory(api) add_subdirectory(accumulation) add_subdirectory(custom_operator) - - +if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) + add_subdirectory(pylayer) +endif() cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulator) diff --git a/paddle/fluid/eager/grad_node_info.h b/paddle/fluid/eager/grad_node_info.h index 81470f38cc..ff4445f426 100644 --- a/paddle/fluid/eager/grad_node_info.h +++ b/paddle/fluid/eager/grad_node_info.h @@ -147,7 +147,6 @@ class GradNodeBase { size_t slot_rank); void SetGradOutMeta(const paddle::experimental::Tensor& fwd_in, size_t slot_rank); - /** * Default setters for Grad in/out meta this should be used for same special * Node which will not create by user diff --git a/paddle/fluid/eager/pylayer/CMakeLists.txt b/paddle/fluid/eager/pylayer/CMakeLists.txt new file mode 100644 index 0000000000..1e5f2dc6cc --- /dev/null +++ b/paddle/fluid/eager/pylayer/CMakeLists.txt @@ -0,0 +1 @@ +cc_library(py_layer_node SRCS py_layer_node.cc DEPS phi phi_api grad_node_info) diff --git a/paddle/fluid/eager/pylayer/py_layer_node.cc b/paddle/fluid/eager/pylayer/py_layer_node.cc new file mode 100644 index 0000000000..5008e958c5 --- /dev/null +++ b/paddle/fluid/eager/pylayer/py_layer_node.cc @@ -0,0 +1,159 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/eager/pylayer/py_layer_node.h" +#include "paddle/fluid/eager/eager_tensor.h" + +#include "paddle/phi/api/all.h" +#include "paddle/phi/core/dense_tensor.h" + +#include "paddle/fluid/platform/device_context.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/pybind/eager.h" +#include "paddle/fluid/pybind/eager_utils.h" + +#include "glog/logging.h" +#pragma GCC diagnostic ignored "-Wattributes" +#include "pybind11/pytypes.h" + +namespace egr { +std::vector> GradNodePyLayer:: +operator()( + std::vector>& grads, // NOLINT + bool create_graph) { + VLOG(3) << "Running Eager Backward Node: " << name(); + + std::vector> hooked_grads = + GradNodePyLayer::ApplyGradientHooks(grads); + + paddle::pybind::PyLayerObject* ctx = + reinterpret_cast(ctx_); + + PADDLE_ENFORCE_EQ(ctx->forward_output_tensor_is_duplicable.size(), + grads.size(), + paddle::platform::errors::InvalidArgument( + "%s's grad input size(%s) mast be equal with it's " + "forward's output size(%s).", + name(), grads.size(), + ctx->forward_output_tensor_is_duplicable.size())); + + auto backward_args = PyTuple_New(grads.size()); + for (size_t i = 0; i < grads.size(); i++) { + if (ctx->forward_output_tensor_is_duplicable[i]) { + PyObject* pylist = PyList_New((Py_ssize_t)grads[i].size()); + for (size_t j = 0; j < grads[i].size(); j++) { + if (ctx->materialize_grads && !grads[i][j].initialized()) { + paddle::experimental::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][j]); + tensor_tmp.set_impl(dense_tensor); + PyList_SET_ITEM( + pylist, static_cast(i), + paddle::pybind::ToPyObject(paddle::experimental::zeros_like( + tensor_tmp, tensor_tmp.dtype(), + forward_outputs_place_[i][j]))); + } else { + PyList_SET_ITEM(pylist, static_cast(i), + paddle::pybind::ToPyObject(grads[i][0], true)); + } + } + PyTuple_SET_ITEM(backward_args, i, pylist); + } else { + if (ctx->materialize_grads && !grads[i][0].initialized()) { + paddle::experimental::Tensor tensor_tmp; + auto dense_tensor = std::make_shared(); + dense_tensor->set_meta(forward_outputs_meta_[i][0]); + tensor_tmp.set_impl(dense_tensor); + PyTuple_SET_ITEM( + backward_args, i, + paddle::pybind::ToPyObject(paddle::experimental::zeros_like( + tensor_tmp, tensor_tmp.dtype(), forward_outputs_place_[i][0]))); + } else { + PyTuple_SET_ITEM(backward_args, i, + paddle::pybind::ToPyObject(grads[i][0], true)); + } + } + } + + VLOG(6) << "PyLayer backward args is ready, begin call user's backward " + "function..."; + + auto backward_fn = + PyObject_GetAttrString(reinterpret_cast(ctx), "backward"); + if (!backward_fn) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Get backward function faild.")); + } + auto outputs = PyObject_CallObject(backward_fn, backward_args); + if (!outputs) { + PADDLE_THROW(paddle::platform::errors::External( + pybind11::detail::error_string().c_str())); + } + + outputs_ = outputs; + + VLOG(6) << "PyLayer backward function finish..."; + + PyObject* outputs_tuple = nullptr; + if (PyTuple_Check(outputs)) { + outputs_tuple = outputs; + } else { + outputs_tuple = PyTuple_New(1); + Py_INCREF(outputs); + PyTuple_SET_ITEM(outputs_tuple, 0, outputs); + } + + size_t outputs_size = PyTuple_GET_SIZE(outputs_tuple); + + if (outputs_size > ctx->forward_input_tensor_is_duplicable.size()) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The number of outputs of `PyLayer.backward` should be %d, but " + "received %d.", + ctx->forward_input_tensor_is_duplicable.size(), outputs_size)); + } + + std::vector> grad_out; + grad_out.reserve(ctx->forward_input_tensor_is_duplicable.size()); + for (size_t i = 0; i < ctx->forward_input_tensor_is_duplicable.size(); i++) { + if (i < outputs_size) { + PyObject* obj = PyTuple_GET_ITEM(outputs_tuple, i); + if (this->OutputMeta()[i][0].IsStopGradient()) { + PADDLE_ENFORCE_EQ( + obj, Py_None, + paddle::platform::errors::InvalidArgument( + "%s's backward function should return None at %d position, " + "because it's forward Tensor's stopgradient is true.", + name(), i)); + grad_out.push_back({}); + } else { + if (ctx->forward_input_tensor_is_duplicable[i]) { + grad_out.push_back(paddle::pybind::GetTensorListFromPyObject(obj)); + } else { + grad_out.push_back({paddle::pybind::GetTensorFromPyObject(obj)}); + } + } + } else { + PADDLE_ENFORCE_EQ( + this->OutputMeta()[i][0].IsStopGradient(), true, + paddle::platform::errors::InvalidArgument( + "%s's backward function should not return empyt at %d position.", + name(), i)); + grad_out.push_back({}); + } + } + + return grad_out; +} +} // namespace egr diff --git a/paddle/fluid/eager/pylayer/py_layer_node.h b/paddle/fluid/eager/pylayer/py_layer_node.h new file mode 100644 index 0000000000..cd0a517afb --- /dev/null +++ b/paddle/fluid/eager/pylayer/py_layer_node.h @@ -0,0 +1,82 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/hooks.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/tensor_meta.h" + +namespace egr { + +class GradNodePyLayer : public GradNodeBase { + public: + GradNodePyLayer(PyObject* ctx, size_t bwd_in_slot_num, + size_t bwd_out_slot_num) + : GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) { + ctx_ = ctx; + } + + ~GradNodePyLayer() override { Py_DECREF(ctx_); }; + + virtual std::vector> operator()( + std::vector>& grads, // NOLINT + bool create_graph = false) override; + + void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; } + + bool IsTensorWrappersCleared() override { + VLOG(6) << "Do nothing here now"; + return false; + } + + std::string name() { + return "GradNodePyLayer_" + std::string(Py_TYPE(ctx_)->tp_name); + } + + // for paddle.grad get result + PyObject* GetMutableOutputs() { return outputs_; } + + void SaveForwardOutputsMeta( + const std::vector>& + outputs_tensor) { + forward_outputs_meta_.resize(outputs_tensor.size()); + forward_outputs_place_.resize(outputs_tensor.size()); + for (size_t i = 0; i < outputs_tensor.size(); i++) { + forward_outputs_meta_[i].reserve(outputs_tensor[i].size()); + forward_outputs_place_[i].reserve(outputs_tensor[i].size()); + for (auto tensor : outputs_tensor[i]) { + if (tensor->is_dense_tensor()) { + forward_outputs_meta_[i].push_back( + static_cast(tensor->impl().get())->meta()); + } else { + forward_outputs_meta_[i].emplace_back(); + } + forward_outputs_place_[i].emplace_back(tensor->inner_place()); + } + } + } + + private: + PyObject* ctx_{nullptr}; + PyObject* outputs_{nullptr}; + std::vector> forward_outputs_meta_; + std::vector> forward_outputs_place_; +}; + +} // namespace egr diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc index 3d0972783d..34dd9d8d34 100644 --- a/paddle/fluid/eager/utils.cc +++ b/paddle/fluid/eager/utils.cc @@ -90,6 +90,16 @@ std::vector EagerUtils::nullable_autograd_meta( return metas; } +std::vector EagerUtils::nullable_autograd_meta( + const std::vector& targets) { + std::vector metas; + metas.reserve(targets.size()); + for (const paddle::experimental::Tensor* t : targets) { + metas.emplace_back(nullable_autograd_meta(*t)); + } + return metas; +} + std::vector EagerUtils::autograd_meta( std::vector* targets) { std::vector ret; @@ -103,6 +113,19 @@ std::vector EagerUtils::autograd_meta( return ret; } +std::vector EagerUtils::autograd_meta( + std::vector* targets) { + std::vector ret; + ret.reserve(targets->size()); + + // for autograd_meta we can tolerent it has nullptr. + for (size_t i = 0; i < targets->size(); i++) { + auto* p_autograd_meta = autograd_meta((*targets)[i]); + ret.emplace_back(p_autograd_meta); + } + return ret; +} + std::pair EagerUtils::OutRankInfo( const paddle::experimental::Tensor& target) { return unsafe_autograd_meta(target)->OutRankInfo(); @@ -380,6 +403,16 @@ void EagerUtils::CheckAndRetainGrad( } } +void EagerUtils::CheckAndRetainGrad( + const std::vector& tensors) { + if (FLAGS_retain_grad_for_all_tensor) { + for (auto& tensor : tensors) { + VLOG(6) << "RetainGradForTensor: " << tensor->name(); + egr::egr_utils_api::RetainGradForTensor(*tensor); + } + } +} + std::shared_ptr EagerUtils::GetGradAccumulationNode( const paddle::experimental::Tensor& tensor) { auto* autograd_ptr = nullable_autograd_meta(tensor); diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h index 537d6c59c0..c7f14cd021 100644 --- a/paddle/fluid/eager/utils.h +++ b/paddle/fluid/eager/utils.h @@ -98,6 +98,9 @@ class EagerUtils { static std::vector autograd_meta( std::vector* targets); + static std::vector autograd_meta( + std::vector* targets); + static std::pair OutRankInfo( const paddle::experimental::Tensor& target); @@ -125,6 +128,8 @@ class EagerUtils { paddle::optional target); static std::vector nullable_autograd_meta( const std::vector& targets); + static std::vector nullable_autograd_meta( + const std::vector& targets); static AutogradMeta* unsafe_autograd_meta( const paddle::experimental::Tensor& target); static std::vector unsafe_autograd_meta( @@ -220,6 +225,8 @@ class EagerUtils { static void CheckAndRetainGrad(const paddle::experimental::Tensor& tensor); static void CheckAndRetainGrad( const std::vector& tensors); + static void CheckAndRetainGrad( + const std::vector& tensors); static std::shared_ptr GetGradAccumulationNode( const paddle::experimental::Tensor& tensor); diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 7b223f7ed2..52af9bb236 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -350,8 +350,8 @@ if(WITH_PYTHON) if(NOT ((NOT WITH_PYTHON) AND ON_INFER)) cc_library(paddle_eager - SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc - DEPS eager_api autograd_meta backward grad_node_info phi op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node global_utils utils python custom_operator custom_operator_node) + SRCS eager.cc eager_functions.cc eager_method.cc eager_properties.cc eager_utils.cc eager_py_layer.cc + DEPS eager_api autograd_meta backward grad_node_info phi op_function_common final_dygraph_function final_dygraph_node dygraph_function dygraph_node accumulation_node py_layer_node global_utils utils python custom_operator custom_operator_node) add_dependencies(paddle_eager eager_codegen) add_dependencies(paddle_eager eager_op_function_generator_cmd) list(APPEND PYBIND_DEPS paddle_eager) diff --git a/paddle/fluid/pybind/eager.cc b/paddle/fluid/pybind/eager.cc index 1da657ac1f..5278f371dd 100644 --- a/paddle/fluid/pybind/eager.cc +++ b/paddle/fluid/pybind/eager.cc @@ -753,6 +753,7 @@ void BindEager(pybind11::module* module) { } BindFunctions(m.ptr()); + BindEagerPyLayer(m.ptr()); BindEagerOpFunctions(&m); } diff --git a/paddle/fluid/pybind/eager.h b/paddle/fluid/pybind/eager.h index c1a869d9b8..bb55ef62ee 100644 --- a/paddle/fluid/pybind/eager.h +++ b/paddle/fluid/pybind/eager.h @@ -14,11 +14,31 @@ limitations under the License. */ #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "paddle/fluid/eager/pylayer/py_layer_node.h" +#include "paddle/phi/core/dense_tensor.h" + namespace paddle { namespace pybind { +typedef struct { + PyObject_HEAD paddle::experimental::Tensor tensor; +} TensorObject; + +typedef struct { + PyObject_HEAD + + PyObject* container; + PyObject* non_differentiable; + PyObject* dirty_tensors; + bool materialize_grads; + std::vector forward_input_tensor_is_duplicable; + std::vector forward_output_tensor_is_duplicable; + std::weak_ptr grad_node; +} PyLayerObject; + void BindEager(pybind11::module* m); void BindFunctions(PyObject* module); +void BindEagerPyLayer(PyObject* module); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/eager_py_layer.cc b/paddle/fluid/pybind/eager_py_layer.cc new file mode 100644 index 0000000000..9e9231415f --- /dev/null +++ b/paddle/fluid/pybind/eager_py_layer.cc @@ -0,0 +1,497 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +// disable numpy compile error +#include + +#include +#include +#include + +#pragma GCC diagnostic ignored "-Wattributes" +#include "pybind11/pytypes.h" + +#include "paddle/fluid/eager/accumulation/accumulation_node.h" +#include "paddle/fluid/eager/api/all.h" +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/pylayer/py_layer_node.h" +#include "paddle/fluid/eager/utils.h" +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/memory/allocation/allocator.h" +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/pybind/eager.h" +#include "paddle/fluid/pybind/eager_utils.h" +#include "paddle/fluid/pybind/exception.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/compat/convert_utils.h" +#include "paddle/phi/core/dense_tensor.h" +#include "pybind11/detail/internals.h" + +namespace paddle { +namespace pybind { + +namespace py = ::pybind11; + +PyTypeObject* p_pylayer_type; +extern PyTypeObject* p_tensor_type; + +std::set GetNonDifferentiableNames( + PyObject* obj) { + std::set result; + if (obj == nullptr) { + return result; + } + if (IsEagerTensor(obj)) { + result.insert(&reinterpret_cast(obj)->tensor); // NOLINT + } else if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyList_GetItem(obj, i))) { + result.insert( + &reinterpret_cast(PyList_GetItem(obj, i)) // NOLINT + ->tensor); + } + } + } else if (PyTuple_Check(obj)) { + Py_ssize_t len = PyTuple_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyTuple_GetItem(obj, i))) { + result.insert( + &reinterpret_cast(PyTuple_GetItem(obj, i)) // NOLINT + ->tensor); + } + } + } + return result; +} + +PyObject* PyLayerNew(PyTypeObject* type, PyObject* args, PyObject* kwargs) { + PyObject* obj = type->tp_alloc(type, 0); + if (obj) { + auto v = reinterpret_cast(obj); + v->materialize_grads = true; + new (&v->grad_node) std::weak_ptr(); + new (&v->forward_input_tensor_is_duplicable) std::vector(); + new (&v->forward_output_tensor_is_duplicable) std::vector(); + } + return obj; +} + +static void PyLayerDealloc(PyLayerObject* self) { + if (self->container) { + Py_DECREF(self->container); + } + if (self->non_differentiable) { + Py_DECREF(self->non_differentiable); + } + if (self->dirty_tensors) { + Py_DECREF(self->dirty_tensors); + } + self->grad_node.~weak_ptr(); + self->forward_input_tensor_is_duplicable.~vector(); + self->forward_output_tensor_is_duplicable.~vector(); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +PyObject* pylayer_method_name(PyObject* self, PyObject* noargs) { + EAGER_TRY + return ToPyObject( + reinterpret_cast(self)->grad_node.lock()->name()); + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +PyObject* pylayer_method_apply(PyObject* cls, PyObject* args, + PyObject* kwargs) { + EAGER_TRY + VLOG(6) << "Begin run PyLayer apply..."; + PyObject* backward_function = + PyObject_GetAttrString(cls, "_backward_function"); + if (!backward_function) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Get _backward_function faild.")); + } + PyLayerObject* ctx = reinterpret_cast( + PyObject_CallFunctionObjArgs(backward_function, nullptr)); + if (!ctx) { + return nullptr; + } + VLOG(6) << "PyLayer construct PyLayerContext finish..."; + + bool require_any_grad = false; + + size_t inputs_size = 0; + PyObject* forward_args = nullptr; + PyObject* kwargs_value_list = nullptr; + if (kwargs) { + inputs_size = PyDict_Size(kwargs); + kwargs_value_list = PyDict_Values(kwargs); + forward_args = PyTuple_New(1); + } else { + inputs_size = PyTuple_GET_SIZE(args); + forward_args = PyTuple_New(inputs_size + 1); + } + Py_INCREF(ctx); + PyTuple_SET_ITEM(forward_args, 0, reinterpret_cast(ctx)); + + std::vector> inputs_autograd_meta; + inputs_autograd_meta.reserve(inputs_size); + std::vector> inputs_tensor; + inputs_tensor.reserve(inputs_size); + ctx->forward_input_tensor_is_duplicable.clear(); + ctx->forward_input_tensor_is_duplicable.reserve(inputs_size); + for (size_t i = 0; i < inputs_size; i++) { + PyObject* obj = nullptr; + if (kwargs) { + obj = PyList_GetItem(kwargs_value_list, i); + } else { + obj = PyTuple_GET_ITEM(args, i); + } + if (IsEagerTensor(obj)) { + auto autograd_meta = egr::EagerUtils::nullable_autograd_meta( + reinterpret_cast(obj)->tensor); + inputs_autograd_meta.push_back({autograd_meta}); + inputs_tensor.push_back( + {&(reinterpret_cast(obj)->tensor)}); // NOLINT + bool stop_gradient = + autograd_meta == nullptr ? true : autograd_meta->StopGradient(); + if (!stop_gradient) { + require_any_grad = true; + } + ctx->forward_input_tensor_is_duplicable.push_back(false); + } else if (PyList_Check(obj)) { + std::vector tensors; + Py_ssize_t len = PyList_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyList_GetItem(obj, i))) { + tensors.push_back(&( + reinterpret_cast(PyList_GetItem(obj, i))->tensor)); + } + } + if (!tensors.empty()) { + auto autograd_meta = egr::EagerUtils::nullable_autograd_meta(tensors); + for (auto iter : autograd_meta) { + bool stop_gradient = iter == nullptr ? true : iter->StopGradient(); + if (!stop_gradient) { + require_any_grad = true; + } + } + inputs_autograd_meta.push_back(autograd_meta); + inputs_tensor.push_back(tensors); + ctx->forward_input_tensor_is_duplicable.push_back(true); + } + } else if (PyTuple_Check(obj)) { + std::vector tensors; + Py_ssize_t len = PyTuple_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyTuple_GetItem(obj, i))) { + tensors.push_back( + &(reinterpret_cast(PyTuple_GetItem(obj, i)) + ->tensor)); + } + } + if (!tensors.empty()) { + auto autograd_meta = egr::EagerUtils::nullable_autograd_meta(tensors); + for (auto iter : autograd_meta) { + bool stop_gradient = iter == nullptr ? true : iter->StopGradient(); + if (!stop_gradient) { + require_any_grad = true; + } + } + inputs_autograd_meta.push_back(autograd_meta); + inputs_tensor.push_back(tensors); + ctx->forward_input_tensor_is_duplicable.push_back(true); + } + } + + if (!kwargs) { + Py_INCREF(obj); + PyTuple_SET_ITEM(forward_args, i + 1, obj); + } + } + + VLOG(6) + << "PyLayer forward args is ready, begin call user's forward function..."; + // call forward + auto forward_fn = PyObject_GetAttrString(cls, "forward"); + if (!forward_fn) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "Get forward function faild.")); + } + bool trace_backward = egr::Controller::Instance().HasGrad(); + egr::Controller::Instance().SetHasGrad(false); + auto outputs = PyObject_Call(forward_fn, forward_args, kwargs); + egr::Controller::Instance().SetHasGrad(trace_backward); + if (!outputs) { + return nullptr; + } + + PyObject* outputs_tuple = nullptr; + if (PyTuple_Check(outputs)) { + outputs_tuple = outputs; + } else { + outputs_tuple = PyTuple_New(1); + Py_INCREF(outputs); + PyTuple_SET_ITEM(outputs_tuple, 0, outputs); + } + + auto outputs_size = PyTuple_GET_SIZE(outputs_tuple); + std::vector> outputs_tensor; + outputs_tensor.reserve(outputs_size); + std::vector> outputs_autograd_meta; + outputs_autograd_meta.reserve(outputs_size); + ctx->forward_output_tensor_is_duplicable.clear(); + ctx->forward_output_tensor_is_duplicable.reserve(outputs_size); + for (Py_ssize_t i = 0; i < outputs_size; i++) { + PyObject* obj = PyTuple_GET_ITEM(outputs_tuple, i); + if (IsEagerTensor(obj)) { + outputs_tensor.push_back( + {&(reinterpret_cast(obj)->tensor)}); + outputs_autograd_meta.push_back({egr::EagerUtils::autograd_meta( + &(reinterpret_cast(obj)->tensor))}); + ctx->forward_output_tensor_is_duplicable.push_back(false); + } else if (PyList_Check(obj)) { + std::vector tensors; + Py_ssize_t len = PyList_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyList_GetItem(obj, i))) { + tensors.push_back(&( + reinterpret_cast(PyList_GetItem(obj, i))->tensor)); + } + } + if (!tensors.empty()) { + outputs_tensor.push_back(tensors); + outputs_autograd_meta.push_back( + egr::EagerUtils::autograd_meta(&tensors)); + ctx->forward_output_tensor_is_duplicable.push_back(true); + } + } else if (PyTuple_Check(obj)) { + std::vector tensors; + Py_ssize_t len = PyTuple_Size(obj); + for (Py_ssize_t i = 0; i < len; i++) { + if (IsEagerTensor(PyTuple_GetItem(obj, i))) { + tensors.push_back( + &(reinterpret_cast(PyTuple_GetItem(obj, i)) + ->tensor)); + } + } + if (!tensors.empty()) { + outputs_tensor.push_back(tensors); + outputs_autograd_meta.push_back( + egr::EagerUtils::autograd_meta(&tensors)); + ctx->forward_output_tensor_is_duplicable.push_back(true); + } + } + } + + if (outputs_tensor.size() == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "At least one output of `PyLayer.forward` is a `Tensor`.")); + } + VLOG(6) << "PyLayer forward function finish..."; + + if (require_any_grad && trace_backward) { + auto non_differentiable = + GetNonDifferentiableNames(ctx->non_differentiable); + for (size_t i = 0; i < outputs_autograd_meta.size(); i++) { + for (size_t j = 0; j < outputs_autograd_meta[i].size(); j++) { + if (non_differentiable.find(outputs_tensor[i][j]) != + non_differentiable.end()) { + outputs_autograd_meta[i][j]->SetStopGradient(true); + } else { + outputs_autograd_meta[i][j]->WeakSetStopGradient(false); + } + } + } + + // TODO(pangyoki) add inplace, inplaced tensor is ctx->dirty_tensors + + auto grad_node = std::make_shared( + reinterpret_cast(ctx), outputs_autograd_meta.size(), + inputs_autograd_meta.size()); + ctx->grad_node = grad_node; + + if (ctx->materialize_grads) { + grad_node->SaveForwardOutputsMeta(outputs_tensor); + } + + for (size_t i = 0; i < inputs_autograd_meta.size(); i++) { + if (ctx->forward_input_tensor_is_duplicable[i]) { + for (auto t : inputs_tensor[i]) { + grad_node->SetGradOutMeta(*t, i); + } + grad_node->AddEdges(&inputs_autograd_meta[i], i); + } else { + grad_node->SetGradOutMeta(*inputs_tensor[i][0], i); + grad_node->AddEdges(inputs_autograd_meta[i][0], i); + } + } + + for (size_t i = 0; i < outputs_autograd_meta.size(); i++) { + if (ctx->forward_output_tensor_is_duplicable[i]) { + egr::EagerUtils::SetOutRankWithSlot(&outputs_autograd_meta[i], i); + egr::EagerUtils::SetHistory(&outputs_autograd_meta[i], grad_node); + for (auto t : outputs_tensor[i]) { + grad_node->SetGradInMeta(*t, i); + } + egr::EagerUtils::CheckAndRetainGrad(outputs_tensor[i]); + } else { + egr::EagerUtils::SetOutRankWithSlot(outputs_autograd_meta[i][0], i); + egr::EagerUtils::SetHistory(outputs_autograd_meta[i][0], grad_node); + grad_node->SetGradInMeta(*outputs_tensor[i][0], i); + egr::EagerUtils::CheckAndRetainGrad(*outputs_tensor[i][0]); + } + } + VLOG(6) << "PyLayer construct backward node finish..."; + } + + return outputs; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +PyObject* pylayer_method_register_hook(PyObject* _self, PyObject* hook) { + EAGER_TRY + return nullptr; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +PyObject* tensor_properties_get_container(PyLayerObject* self, void* closure) { + EAGER_TRY + if (self->container == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + Py_INCREF(self->container); + return self->container; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +int tensor_properties_set_container(PyLayerObject* self, PyObject* value, + void* closure) { + EAGER_TRY + Py_XINCREF(value); + Py_XDECREF(self->container); + self->container = value; + return 0; + EAGER_CATCH_AND_THROW_RETURN_ZERO +} + +PyObject* tensor_properties_get_non_differentiable(PyLayerObject* self, + void* closure) { + EAGER_TRY + if (self->non_differentiable == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + Py_INCREF(self->non_differentiable); + return self->non_differentiable; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +int tensor_properties_set_non_differentiable(PyLayerObject* self, + PyObject* value, void* closure) { + EAGER_TRY + Py_XINCREF(value); + Py_XDECREF(self->non_differentiable); + self->non_differentiable = value; + return 0; + EAGER_CATCH_AND_THROW_RETURN_ZERO +} + +PyObject* tensor_properties_get_dirty_tensors(PyLayerObject* self, + void* closure) { + EAGER_TRY + if (self->dirty_tensors == nullptr) { + Py_INCREF(Py_None); + return Py_None; + } + Py_INCREF(self->dirty_tensors); + return self->dirty_tensors; + EAGER_CATCH_AND_THROW_RETURN_NULL +} + +int tensor_properties_set_dirty_tensors(PyLayerObject* self, PyObject* value, + void* closure) { + EAGER_TRY + Py_XINCREF(value); + Py_XDECREF(self->dirty_tensors); + self->dirty_tensors = value; + return 0; + EAGER_CATCH_AND_THROW_RETURN_ZERO +} + +int tensor_properties_set_materialize_grads(PyLayerObject* self, + PyObject* value, void* closure) { + EAGER_TRY + self->materialize_grads = CastPyArg2AttrBoolean(value, 0); + return 0; + EAGER_CATCH_AND_THROW_RETURN_ZERO +} + +PyMethodDef pylayer_methods[] = { + {"name", (PyCFunction)(void (*)(void))pylayer_method_name, METH_NOARGS, + NULL}, + {"apply", (PyCFunction)(void (*)(void))pylayer_method_apply, + METH_CLASS | METH_VARARGS | METH_KEYWORDS, NULL}, + {"register_hook", (PyCFunction)(void (*)(void))pylayer_method_register_hook, + METH_O, NULL}, + {NULL, NULL, 0, NULL}}; + +struct PyGetSetDef pylayer_properties[]{ + {"container", (getter)tensor_properties_get_container, + (setter)tensor_properties_set_container, nullptr, nullptr}, + {"non_differentiable", (getter)tensor_properties_get_non_differentiable, + (setter)tensor_properties_set_non_differentiable, nullptr, nullptr}, + {"dirty_tensors", (getter)tensor_properties_get_dirty_tensors, + (setter)tensor_properties_set_dirty_tensors, nullptr, nullptr}, + {"materialize_grads", nullptr, + (setter)tensor_properties_set_materialize_grads, nullptr, nullptr}, + {nullptr, nullptr, nullptr, nullptr, nullptr}}; + +void BindEagerPyLayer(PyObject* module) { + auto heap_type = reinterpret_cast( + PyType_Type.tp_alloc(&PyType_Type, 0)); + heap_type->ht_name = ToPyObject("PyLayer"); + heap_type->ht_qualname = ToPyObject("PyLayer"); + auto type = &heap_type->ht_type; + type->tp_name = "PyLayer"; + type->tp_basicsize = sizeof(PyLayerObject); + type->tp_dealloc = (destructor)PyLayerDealloc; + type->tp_methods = pylayer_methods; + type->tp_getset = pylayer_properties; + type->tp_new = PyLayerNew; + Py_INCREF(&PyBaseObject_Type); + type->tp_base = reinterpret_cast(&PyBaseObject_Type); + type->tp_flags |= + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HEAPTYPE; +#if PY_VERSION_HEX >= 0x03050000 + type->tp_as_async = &heap_type->as_async; +#endif + p_pylayer_type = type; + + if (PyType_Ready(type) < 0) { + PADDLE_THROW(platform::errors::Fatal( + "Init Paddle error in BindEager(PyType_Ready).")); + return; + } + + Py_INCREF(type); + if (PyModule_AddObject(module, "PyLayer", reinterpret_cast(type)) < + 0) { + Py_DECREF(type); + Py_DECREF(module); + PADDLE_THROW(platform::errors::Fatal( + "Init Paddle error in BindEager(PyModule_AddObject).")); + return; + } +} + +} // namespace pybind +} // namespace paddle diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index b4d316a957..17300e5ce9 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -182,6 +182,10 @@ std::string CastPyArg2AttrString(PyObject* obj, ssize_t arg_pos) { } } +bool IsEagerTensor(PyObject* obj) { + return PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type)); +} + paddle::experimental::Tensor CastPyArg2Tensor(PyObject* obj, ssize_t arg_pos) { if (PyObject_IsInstance(obj, reinterpret_cast(p_tensor_type))) { return reinterpret_cast(obj)->tensor; @@ -434,7 +438,12 @@ PyObject* ToPyObject(const std::string& value) { return PyUnicode_FromString(value.c_str()); } -PyObject* ToPyObject(const paddle::experimental::Tensor& value) { +PyObject* ToPyObject(const paddle::experimental::Tensor& value, + bool return_py_none_if_not_initialize) { + if (return_py_none_if_not_initialize && !value.initialized()) { + Py_INCREF(Py_None); + return Py_None; + } PyObject* obj = p_tensor_type->tp_alloc(p_tensor_type, 0); if (obj) { auto v = reinterpret_cast(obj); @@ -792,6 +801,92 @@ std::vector GetTensorPtrListFromArgs( return result; } +std::vector GetTensorPtrListFromPyObject( + PyObject* obj) { + std::vector result; + + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + if (len == 0) { + PADDLE_THROW( + platform::errors::InvalidArgument("The list of Tensor is empty.")); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back( + &(reinterpret_cast(PyList_GetItem(obj, i))->tensor)); + } + } else if (PyTuple_Check(obj)) { + Py_ssize_t len = PyTuple_Size(obj); + if (len == 0) { + PADDLE_THROW( + platform::errors::InvalidArgument("The tuple of Tensor is empty.")); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back( + &(reinterpret_cast(PyTuple_GetItem(obj, i))->tensor)); + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "The PyObject must be list of Tensors, but got " + "%s", + (reinterpret_cast(obj->ob_type))->tp_name)); + } + + return result; +} + +std::vector GetTensorListFromPyObject( + PyObject* obj) { + std::vector result; + if (PyList_Check(obj)) { + Py_ssize_t len = PyList_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyList_GetItem(obj, i); + if (PyObject_IsInstance(item, + reinterpret_cast(p_tensor_type))) { + result.emplace_back(reinterpret_cast(item)->tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument must be " + "list of Tensor, but got %s at pos %d", + reinterpret_cast(item->ob_type)->tp_name, i)); + } + } + } else if (PyTuple_Check(obj)) { + Py_ssize_t len = PyTuple_Size(obj); + PyObject* item = nullptr; + for (Py_ssize_t i = 0; i < len; i++) { + item = PyTuple_GetItem(obj, i); + if (PyObject_IsInstance(item, + reinterpret_cast(p_tensor_type))) { + result.emplace_back(reinterpret_cast(item)->tensor); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument must be " + "list of Tensor, but got %s at pos %d", + reinterpret_cast(item->ob_type)->tp_name, i)); + } + } + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument must be " + "list or tuple, but got %s", + reinterpret_cast(obj->ob_type)->tp_name)); + } + return result; +} + +paddle::experimental::Tensor& GetTensorFromPyObject(PyObject* obj) { + if (!IsEagerTensor(obj)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "argument must be " + "Tensor, but got %s", + reinterpret_cast(obj->ob_type)->tp_name)); + } + return reinterpret_cast(obj)->tensor; +} + paddle::experimental::Scalar CastPyArg2Scalar(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index 13565cfe70..15d289d7bc 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -26,12 +26,10 @@ class Scope; } namespace pybind { -typedef struct { - PyObject_HEAD paddle::experimental::Tensor tensor; -} TensorObject; - int TensorDtype2NumpyDtype(phi::DataType dtype); +bool IsEagerTensor(PyObject* obj); + bool PyObject_CheckLongOrConvertToLong(PyObject** obj); bool PyObject_CheckFloatOrConvertToFloat(PyObject** obj); bool PyObject_CheckStr(PyObject* obj); @@ -63,7 +61,8 @@ PyObject* ToPyObject(float value); PyObject* ToPyObject(double value); PyObject* ToPyObject(const char* value); PyObject* ToPyObject(const std::string& value); -PyObject* ToPyObject(const paddle::experimental::Tensor& value); +PyObject* ToPyObject(const paddle::experimental::Tensor& value, + bool return_py_none_if_not_initialize = false); PyObject* ToPyObject(const paddle::experimental::Tensor& value, ssize_t value_idx, PyObject* args, ssize_t arg_idx); PyObject* ToPyObject(const std::vector& value); @@ -185,6 +184,14 @@ std::vector GetTensorPtrListFromArgs( const std::string& op_type, const std::string& arg_name, PyObject* args, ssize_t arg_idx, bool dispensable = false); +std::vector GetTensorPtrListFromPyObject( + PyObject* obj); + +std::vector GetTensorListFromPyObject( + PyObject* obj); + +paddle::experimental::Tensor& GetTensorFromPyObject(PyObject* obj); + // end of Slice related methods std::vector GetScopePtrListFromArgs( diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index 86500db659..7aab7117de 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -15,7 +15,7 @@ from ..fluid.dygraph.base import grad # noqa: F401 from . import backward_mode # noqa: F401 from .backward_mode import backward # noqa: F401 -from .py_layer import PyLayer, PyLayerContext # noqa: F401 +from .py_layer import PyLayer, PyLayerContext, EagerPyLayer, EagerPyLayerContext # noqa: F401 from ..framework import set_grad_enabled, is_grad_enabled # noqa: F401 from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 from .functional import jacobian, hessian, batch_jacobian, batch_hessian # noqa: F401 diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py index 26740dfd0f..0fb90b334f 100644 --- a/python/paddle/autograd/py_layer.py +++ b/python/paddle/autograd/py_layer.py @@ -327,3 +327,272 @@ class PyLayer(with_mateclass(LayerMeta, CPyLayer)): raise NotImplementedError( "You must implement the backward function for PyLayer.") + + +class EagerPyLayerContext(object): + def save_for_backward(self, *tensors): + """ + Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors. + + .. note:: + This API should be called at most once, and only inside `forward`. + + Args: + tensors(list of Tensors): Tensors to be stored. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a context object that store some objects for backward. + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + + """ + self.container = tensors + + def saved_tensor(self): + """ + Get the tensors stored by ``save_for_backward``. + + Returns: + list of Tensors or None: If context contains tensors stored by `save_for_backward`, + then return these tensors, otherwise return None. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a context object that store some objects for backward. + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + return self.container + + def mark_dirty(self, *args): + self.dirty_tensors = args + + def mark_non_differentiable(self, *args): + """ + Marks outputs as non-differentiable. + This should be called at most once, only from inside thethe `forward` method, + and all arguments should be tensor outputs. + + This will mark outputs as not requiring gradients, increasing the + efficiency of backward computation. You still need to accept a gradient + for each output in `backward`, but it's always going to + be a zero tensor with the same shape as the shape of a corresponding + output. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + import numpy as np + + class Tanh(PyLayer): + @staticmethod + def forward(ctx, x): + a = x + x + b = x + x + x + ctx.mark_non_differentiable(a) + return a, b + + @staticmethod + def backward(ctx, grad_a, grad_b): + assert np.equal(grad_a.numpy(), paddle.zeros([1]).numpy()) + assert np.equal(grad_b.numpy(), paddle.ones([1], dtype="float64").numpy()) + return grad_b + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + a, b = Tanh.apply(x) + b.sum().backward() + """ + self.non_differentiable = args + + def set_materialize_grads(self, value: bool): + """ + Sets whether to materialize output grad tensors. Default is True. + + This should be called only from inside the `forward` method. + + If True, undefined output grad tensors will be expanded to tensors full + of zeros prior to calling the `backward` method. + + If False, undefined output grad tensors will be None. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + import numpy as np + + class Tanh(PyLayer): + @staticmethod + def forward(ctx, x): + return x, x+x + + @staticmethod + def backward(ctx, grad, grad2): + assert np.equal(grad2.numpy(), paddle.zeros([1]).numpy()) + return grad + + class Tanh2(PyLayer): + @staticmethod + def forward(ctx, x): + ctx.set_materialize_grads(False) + return x, x+x + + @staticmethod + def backward(ctx, grad, grad2): + assert grad2==None + return grad + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + Tanh.apply(x)[0].backward() + + x2 = paddle.ones([1], dtype="float64") + x2.stop_gradient = False + Tanh2.apply(x2)[0].backward() + """ + self.materialize_grads = value + + +class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext): + def backward(self, *args): + return self._forward_cls.backward(self, *args) + + +class EagerPyLayerMeta(type): + def __init__(cls, name, bases, attrs): + cls._backward_function = type(name + '_backward', + (EagerPyLayerBackward, ), + {"_forward_cls": cls}) + + return super(EagerPyLayerMeta, cls).__init__(name, bases, attrs) + + +class EagerPyLayer( + with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, + EagerPyLayerContext)): + @staticmethod + def forward(ctx, *args, **kwargs): + """ + It is to be overloaded by subclasses. It must accept a object of `PyLayerContext` as + the first argument, followed by any number of arguments (tensors or other types). + `None` can not be included in the returned result. + + Args: + *args(tuple): input of PyLayer. + **kwargs(dict): input of PyLayer. + + Returns: + tensors or other types : output of PyLayer. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + raise NotImplementedError( + "You must implement the forward function for PyLayer.") + + @staticmethod + def backward(ctx, *args): + """ + This is a function to calculate the gradient. It is to be overloaded by subclasses. + It must accept a object of `PyLayerContext` as the first argument, and the rest + arguments are the gradient of forward's output tensors. Output tensors of backward + are the gradient of forward's input tensors. + + Args: + *args(tuple): The gradient of forward's output tensor(s). + **kwargs(dict): The gradient of forward's output tensor(s). + + Returns: + Tensor or list of Tensors: The gradient of forward's input tensor(s). + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + + raise NotImplementedError( + "You must implement the backward function for PyLayer.") + + +def once_differentiable(backward): + def wrapper(ctx, *args): + with paddle.fluid.dygraph.no_grad(): + outputs = backward(ctx, *args) + return outputs + + return wrapper diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py index 200273c606..786f4cb7a7 100644 --- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -18,7 +18,8 @@ import unittest import numpy as np import paddle -from paddle.autograd import PyLayer +from paddle.autograd import PyLayer, EagerPyLayer +from paddle.fluid.framework import _test_eager_guard, in_dygraph_mode class FakeTensor(paddle.fluid.core.VarBase): @@ -27,8 +28,8 @@ class FakeTensor(paddle.fluid.core.VarBase): class TestPyLayer(unittest.TestCase): - def test_simple_pylayer_multiple_output(self): - class tanh(PyLayer): + def func_test_simple_pylayer_multiple_output(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2, func1, func2=paddle.square): ctx.func = func2 @@ -58,8 +59,13 @@ class TestPyLayer(unittest.TestCase): self.assertTrue( np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10) - def test_simple_pylayer_return_none_with_no_grad(self): - class tanh(PyLayer): + def test_simple_pylayer_multiple_output(self): + with _test_eager_guard(): + self.func_test_simple_pylayer_multiple_output() + self.func_test_simple_pylayer_multiple_output() + + def func_test_simple_pylayer_return_none_with_no_grad(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2, func1, func2=paddle.square): ctx.func = func2 @@ -93,8 +99,13 @@ class TestPyLayer(unittest.TestCase): self.assertTrue( np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10) - def test_simple_pylayer_single_output(self): - class tanh(PyLayer): + def test_simple_pylayer_return_none_with_no_grad(self): + with _test_eager_guard(): + self.func_test_simple_pylayer_return_none_with_no_grad() + self.func_test_simple_pylayer_return_none_with_no_grad() + + def func_test_simple_pylayer_single_output(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, func1, func2=paddle.square): ctx.func = func2 @@ -120,8 +131,13 @@ class TestPyLayer(unittest.TestCase): self.assertTrue( np.max(np.abs((input1.grad.numpy() - input2.grad.numpy()))) < 1e-10) - def test_pylayer_num_output_match(self): - class tanh(PyLayer): + def test_simple_pylayer_single_output(self): + with _test_eager_guard(): + self.func_test_simple_pylayer_single_output() + self.func_test_simple_pylayer_single_output() + + def func_test_pylayer_num_output_match(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward( ctx, @@ -141,8 +157,13 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - def test_pylayer_dtype(self): - class tanh(PyLayer): + def test_pylayer_num_output_match(self): + with _test_eager_guard(): + self.func_test_pylayer_num_output_match() + self.func_test_pylayer_num_output_match() + + def func_test_pylayer_dtype(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x, dtype): y = paddle.cast(x, dtype) @@ -165,8 +186,13 @@ class TestPyLayer(unittest.TestCase): z.sum().backward() self.assertTrue(input1.grad is not None) - def test_pylayer_Exception_forward(self): - class Layer_None1(PyLayer): + def test_pylayer_dtype(self): + with _test_eager_guard(): + self.func_test_pylayer_dtype() + self.func_test_pylayer_dtype() + + def func_test_pylayer_Exception_forward(self): + class Layer_None1(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, *args): return None @@ -179,7 +205,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z = Layer_None1.apply(input1) - class Layer_None2(PyLayer): + class Layer_None2(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, *args): return [None, args[0]] @@ -192,7 +218,7 @@ class TestPyLayer(unittest.TestCase): # return None z = Layer_None2.apply(input1) - class Layer_one1(PyLayer): + class Layer_one1(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, *args): return 1 @@ -206,7 +232,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z = Layer_one1.apply(input1) - class Layer_one2(PyLayer): + class Layer_one2(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, *args): return [1, 2, args[0]] @@ -219,7 +245,7 @@ class TestPyLayer(unittest.TestCase): # return int z = Layer_one2.apply(input1) - class Layer_no_fw(PyLayer): + class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def backward(ctx, *args): return args @@ -228,8 +254,13 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(NotImplementedError): z = Layer_no_fw.apply(input1) - def test_pylayer_nograd(self): - class tanh(PyLayer): + def test_pylayer_Exception_forward(self): + with _test_eager_guard(): + self.func_test_pylayer_Exception_forward() + self.func_test_pylayer_Exception_forward() + + def func_test_pylayer_nograd(self): + class tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, func1, func2=paddle.square, xx=None): ctx.func = func2 @@ -246,8 +277,13 @@ class TestPyLayer(unittest.TestCase): z.mean().backward() self.assertTrue(z.grad is None) - def test_pylayer_Exception_bk(self): - class Layer_bk_none1(PyLayer): + def test_pylayer_nograd(self): + with _test_eager_guard(): + self.func_test_pylayer_nograd() + self.func_test_pylayer_nograd() + + def func_test_pylayer_Exception_bk(self): + class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x * 2 @@ -263,7 +299,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.sum().backward() - class Layer_bk_none2(PyLayer): + class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 + x2 @@ -279,7 +315,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_one1(PyLayer): + class Layer_bk_one1(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x + x @@ -295,7 +331,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_one2(PyLayer): + class Layer_bk_one2(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 * 2, x2 * 5 @@ -312,7 +348,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_no_bk(PyLayer): + class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x * 2, x * 5 @@ -325,7 +361,7 @@ class TestPyLayer(unittest.TestCase): z = z[0] + z[1] z.mean().backward() - class Layer_bk_match(PyLayer): + class Layer_bk_match(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x * 2, x * 5 @@ -341,8 +377,13 @@ class TestPyLayer(unittest.TestCase): z = z[0] + z[1] z.mean().backward() - def test_pylayer_bk_return_none(self): - class Layer_bk_none1(PyLayer): + def test_pylayer_Exception_bk(self): + with _test_eager_guard(): + self.func_test_pylayer_Exception_bk() + self.func_test_pylayer_Exception_bk() + + def func_test_pylayer_bk_return_none(self): + class Layer_bk_none1(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 + x2 @@ -360,7 +401,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_none2(PyLayer): + class Layer_bk_none2(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 * 2, x2 * 5 @@ -378,8 +419,13 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() + def test_pylayer_bk_return_none(self): + with _test_eager_guard(): + self.func_test_pylayer_bk_return_none() + self.func_test_pylayer_bk_return_none() + def test_pylayer_inplace(self): - class cus_tanh(PyLayer): + class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x @@ -407,7 +453,7 @@ class TestPyLayer(unittest.TestCase): self.assertTrue(data.grad is not None) def test_pylayer_inplace_and_leaf_exception(self): - class cus_pylayer_op(PyLayer): + class cus_pylayer_op(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): return x @@ -432,8 +478,8 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z = layer(data) - def test_backward_in_backward(self): - class cus_tanh(PyLayer): + def func_test_backward_in_backward(self): + class cus_tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x): temp = x.detach() @@ -457,8 +503,13 @@ class TestPyLayer(unittest.TestCase): z = paddle.tanh(data) z = cus_tanh.apply(data) - def test_return_to_tensor(self): - class Tanh(PyLayer): + def test_backward_in_backward(self): + with _test_eager_guard(): + self.func_test_backward_in_backward() + self.func_test_backward_in_backward() + + def func_test_return_to_tensor(self): + class Tanh(EagerPyLayer if in_dygraph_mode() else PyLayer): @staticmethod def forward(ctx, x1): y1 = paddle.tanh(x1) @@ -479,6 +530,89 @@ class TestPyLayer(unittest.TestCase): z, number, none_item, string_item, tensor1 = Tanh.apply(x1=input1) z.mean().backward() + def test_return_to_tensor(self): + with _test_eager_guard(): + self.func_test_return_to_tensor() + self.func_test_return_to_tensor() + + def test_materialize_grads(self): + with _test_eager_guard(): + + class Tanh(EagerPyLayer): + @staticmethod + def forward(ctx, x): + return x, x + x + + @staticmethod + def backward(ctx, grad, grad2): + self.assertEqual(grad2, paddle.zeros([1])) + return grad + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + Tanh.apply(x)[0].backward() + + def test_dont_materialize_grads(self): + with _test_eager_guard(): + + class Tanh(EagerPyLayer): + @staticmethod + def forward(ctx, x): + ctx.set_materialize_grads(False) + return x, x + x + + @staticmethod + def backward(ctx, grad, grad2): + self.assertIsNone(grad2) + return grad + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + Tanh.apply(x)[0].backward() + + def test_mark_non_differentiable(self): + with _test_eager_guard(): + + class Tanh(EagerPyLayer): + @staticmethod + def forward(ctx, x): + a = x + x + ctx.mark_non_differentiable(a) + return a + + @staticmethod + def backward(ctx, grad): + self.assertTrue(False) # should not be call + return paddle.ones([1], dtype="float64") + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + y = Tanh.apply(x) + y.sum().backward() + + def test_mark_non_differentiable2(self): + with _test_eager_guard(): + + class Tanh(EagerPyLayer): + @staticmethod + def forward(ctx, x): + a = x + x + b = x + x + x + ctx.mark_non_differentiable(a) + return a, b + + @staticmethod + def backward(ctx, grad_a, grad_b): + self.assertEqual(grad_a, paddle.zeros([1])) + self.assertEqual(grad_b, paddle.ones([1], dtype="float64")) + return grad_b + + x = paddle.ones([1], dtype="float64") + x.stop_gradient = False + a, b = Tanh.apply(x) + b.sum().backward() + self.assertEqual(x.grad, paddle.ones([1], dtype="float64")) + class TestPyLayerReturnType(unittest.TestCase): def test_forward_args_fake_tensor(self): -- GitLab