From 0b60b7845246773d839db12b7cd621a9e6bdb9be Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 28 Dec 2022 08:41:07 +0800 Subject: [PATCH] delete old dygraph pylayer (#49339) * delete old dygraph pylayer --- paddle/fluid/imperative/py_layer_fwd.h | 257 -------------- paddle/fluid/operators/CMakeLists.txt | 3 +- paddle/fluid/operators/py_layer_op.cc | 239 ------------- paddle/fluid/operators/py_layer_op.h | 113 ------ paddle/fluid/pybind/CMakeLists.txt | 1 - paddle/fluid/pybind/imperative.cc | 55 --- python/paddle/autograd/__init__.py | 10 +- python/paddle/autograd/py_layer.py | 322 +----------------- .../fluid/tests/unittests/test_pylayer_op.py | 180 ++-------- 9 files changed, 39 insertions(+), 1141 deletions(-) delete mode 100644 paddle/fluid/imperative/py_layer_fwd.h delete mode 100644 paddle/fluid/operators/py_layer_op.cc delete mode 100644 paddle/fluid/operators/py_layer_op.h diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h deleted file mode 100644 index 4bdb43a440..0000000000 --- a/paddle/fluid/imperative/py_layer_fwd.h +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/type_defs.h" -#include "paddle/fluid/imperative/layer.h" -#include "paddle/fluid/imperative/prepared_operator.h" -#include "paddle/fluid/imperative/tracer.h" -#include "paddle/fluid/operators/py_layer_op.h" - -namespace paddle { -namespace imperative { - -namespace py = ::pybind11; - -bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) { - for (const auto& name_pair : ins) { - for (const auto& var_base : name_pair.second) { - if (!var_base->OverridedStopGradient()) { - for (const auto& pair : outs) { - for (const auto& var : pair.second) { - if (var) { - var->SetOverridedStopGradient(false); - SetForwardDataTypeOfGradVar(var); - VLOG(3) << "Set output: " << var->Name() - << "'s OverridedStopGradient as " - << var->OverridedStopGradient(); - } - } - } - return true; - } - } - } - return false; -} - -std::shared_ptr CreateGradOpNode( - const std::string& type, - const NameVarBaseMap& ins, - const NameVarBaseMap& outs, - const framework::AttributeMap& attrs, - const platform::Place& place, - const std::map& inplace_map, - const std::shared_ptr& py_context) { - operators::PyLayerGradOpMaker maker( - type, ins, outs, attrs, inplace_map); - - maker.SetPyLayerContext(py_context); - auto grad_node = maker(); - if (grad_node && !grad_node->empty()) { - for (auto& grad_op : *grad_node) { - grad_op.SetId(OpBase::GenerateUniqueId()); - grad_op.SetPlace(place); - ClearNoNeedBufferInputs(&grad_op); - } - return grad_node; - } else { - return nullptr; - } -} - -py::object PyLayerApply(const platform::Place& place, - const py::handle& cls, - const py::args args, - const py::kwargs kwargs) { - py::gil_scoped_acquire guard; - auto bk_function 
= cls.attr("_backward_function"); - auto context = bk_function(); - auto forward = cls.attr("forward"); - - // make inputs to varbase - std::vector> input_vars; - // process args,`input_vars` only collect `imperative::VarBase` - if (!args.empty()) { - for (auto ptr = args.begin(); ptr != args.end(); ptr++) { - // Only collect Tensor type in 'args' and pass them to backward. Ignore - // other types of input temporarily. - if (py::isinstance(*ptr)) { - try { - auto a = ptr->cast>(); - input_vars.push_back(a); - } catch (py::cast_error& err) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->ptr()->ob_type->tp_name)); - } - } else if (py::isinstance(*ptr) || - py::isinstance(*ptr)) { - try { - auto tuple_arg = ptr->cast(); - for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { - try { - auto t = iter->cast>(); - input_vars.push_back(t); - } catch (py::cast_error& err) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, " - "the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->ptr()->ob_type->tp_name)); - } - } - } catch (py::cast_error& err) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->ptr()->ob_type->tp_name)); - } - } - } - } - // process kwargs, only collect `imperative::VarBase` - if (!kwargs.empty()) { - for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) { - // Only collect Tensor type in 'kwargs' and pass them to backward. - // Ignore other types of input temporarily. - if (py::isinstance(*ptr->second)) { - try { - auto a = ptr->second.cast>(); - input_vars.push_back(a); - } catch (py::cast_error&) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->second.ptr()->ob_type->tp_name)); - } - } else if (py::isinstance(*ptr->second) || - py::isinstance(*ptr->second)) { - try { - auto tuple_arg = ptr->second.cast(); - for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) { - try { - auto t = iter->cast>(); - input_vars.push_back(t); - } catch (py::cast_error& err) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, " - "the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->second.ptr()->ob_type->tp_name)); - } - } - } catch (py::cast_error& err) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function contains invalid argument, the " - "`%s` type argument can not be cast into `Tensor`.", - ptr->second.ptr()->ob_type->tp_name)); - } - } - } - } - - std::shared_ptr py_layer_ctx = - std::make_shared(context.ptr()); - auto result_forward = forward(context, *args, **kwargs); - NameVarBaseMap ins = {{"X", input_vars}}; - - std::vector> output_vars; - if (PyTuple_Check(result_forward.ptr()) || - PyList_Check(result_forward.ptr())) { - auto tuple_result = result_forward.cast(); - for (size_t i = 0; i < tuple_result.size(); i++) { - // Only collect Tensor type of output and pass them to backward. - // Ignore other types of input temporarily. 
- if (py::isinstance(tuple_result[i])) { - try { - auto temp_out = - tuple_result[i].cast>(); - output_vars.push_back(temp_out); - } catch (py::cast_error&) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function returns invalid argument, the " - "`%s` type argument can not be cast into `Tensor`.", - tuple_result[i].ptr()->ob_type->tp_name)); - } - } - } - } else { - // Only collect Tensor type of output and pass them to backward. - // Ignore other types of input temporarily. - if (py::isinstance(result_forward)) { - try { - auto temp_out = - result_forward.cast>(); - output_vars.push_back(temp_out); - } catch (py::cast_error&) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.forward` function returns invalid argument, the `%s` " - "type argument can not be cast into `Tensor`.", - result_forward.ptr()->ob_type->tp_name)); - } - } - } - if (output_vars.size() == 0) { - PADDLE_THROW(platform::errors::InvalidArgument( - "At least one output of `PyLayer.forward` is a `Tensor`.")); - } - - NameVarBaseMap outs = {{"Out", output_vars}}; - - if (RequiredGrad(ins, outs)) { - std::map inplace_map{}; - bool if_inplace = false; - for (auto temp_ins : input_vars) { - if (if_inplace) { - break; - } - for (auto temp_outs : output_vars) { - if (temp_ins->Name() == temp_outs->Name()) { - if_inplace = true; - break; - } - } - } - if (if_inplace) { - // when pylayer forward is inplace strategy, check whether tensor is leaf - for (auto& t : input_vars) { - PADDLE_ENFORCE_EQ(t->IsLeaf() && !t->OverridedStopGradient(), - false, - platform::errors::InvalidArgument( - "Leaf Var (%s) that doesn't stop gradient can't " - "use inplace strategy.", - t->Name())); - } - - inplace_map["X"] = "Out"; - } - - CreateGradOpNode( - "py_layer", ins, outs, {{}}, place, inplace_map, py_layer_ctx); - } else { - VLOG(3) << "No Grad to track for Op: py_layer_op"; - } - - return result_forward; -} - -} // namespace imperative -} // namespace paddle diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 4aeb3d6b74..901d8741c0 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -97,7 +97,7 @@ endif() set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta) -register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op +register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS}) @@ -206,7 +206,6 @@ cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor de cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) - cc_library(py_layer_op SRCS py_layer_op.cc DEPS op_registry python pybind) endif() if (WITH_ASCEND_CL) diff --git a/paddle/fluid/operators/py_layer_op.cc b/paddle/fluid/operators/py_layer_op.cc deleted file mode 100644 index 9c13934ccd..0000000000 --- a/paddle/fluid/operators/py_layer_op.cc +++ /dev/null @@ -1,239 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/fluid/operators/py_layer_op.h" - -#include - -namespace paddle { -namespace operators { - -namespace py = ::pybind11; - -void RunPyObject(py::object *py_object, - const std::vector &ins, - std::vector *outs) { - py::gil_scoped_acquire guard; - - auto py_function = py_object->attr("backward"); - - py::tuple inputs(ins.size()); - for (size_t i = 0; i < ins.size(); i++) { - auto in_var = ins[i]; - if (in_var != nullptr) { - auto name = paddle::string::Sprintf("generator_custom_py_layer_%d@GRAD", - static_cast(i)); - - std::shared_ptr temp_wrap = - std::make_shared(name, *in_var); - temp_wrap->InnerSetOverridedStopGradient(true); - std::shared_ptr temp_varbase = - std::make_shared(temp_wrap); - try { - inputs[i] = py::cast(temp_varbase).ptr(); - } catch (py::cast_error &) { - PADDLE_THROW(platform::errors::Unimplemented( - "The output of `PyLayer.backward` should be `Tensor`.")); - } - } - } - - auto py_result = py_function(*py_object, *inputs); - - if (PyTuple_Check(py_result.ptr()) || PyList_Check(py_result.ptr())) { - auto result_tuple = py_result.cast(); - if (result_tuple.size() != outs->size()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The number of outputs of `PyLayer.backward` should be %d, but " - "received %d.", - outs->size(), - result_tuple.size())); - } - for (size_t i = 0; i < result_tuple.size(); i++) { - if ((*outs)[i] != nullptr) { - if (Py_None != result_tuple[i].ptr()) { - if (py::isinstance(result_tuple[i])) { - try { - auto result_var = - result_tuple[i].cast>(); - *(*outs)[i] = result_var->Var(); - } catch (py::cast_error &) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.backward` function returns invalid argument, " - "the `%s` type argument can not be cast into `Tensor`.", - result_tuple[i].ptr()->ob_type->tp_name)); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The output of `PyLayer.backward` should be `Tensor`, but " - "received `%s`.", - result_tuple[i].ptr()->ob_type->tp_name)); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The %dth input tensor of forward needs gradient and the " - "corresponding gradient cannot be None.", - i)); - } - } else { - if (Py_None != result_tuple[i].ptr()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The %dth input tensor of forward do not need gradient and the " - "corresponding gradient should be `None`.", - i)); - } - } - } - } else { - if (1 != outs->size()) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The number of outputs of `PyLayer.backward` should be %d, but " - "received 1.", - outs->size())); - } - if ((*outs)[0] != nullptr) { - if (Py_None != py_result.ptr()) { - if (py::isinstance(py_result)) { - try { - auto result_var = - py_result.cast>(); - *((*outs)[0]) = result_var->Var(); - } catch (py::cast_error &) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The `PyLayer.backward` function returns invalid argument, the " - "`%s` type argument can not 
be cast into `Tensor`.", - py_result.ptr()->ob_type->tp_name)); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The output of `PyLayer.backward` should be `Tensor`, but " - "received `%s`", - py_result.ptr()->ob_type->tp_name)); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The input tensor of forward needs gradient, so the output of " - "`PyLayer.backward` can not be `None`.")); - } - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "The input tensor of forward do not need gradient, so the output of " - "`PyLayer.backward` should be `None`.")); - } - } -} - -void PyLayerGradOpMaker::Apply( - GradOpPtr grad_op) const { - grad_op->SetType("py_layer"); - auto &inner_op = grad_op->InnerOp(); - auto py_layer_op_const = dynamic_cast(&inner_op); - - if (py_layer_op_const) { - auto py_layer_op = const_cast(py_layer_op_const); - py_layer_op->SetPyLayerContext(py_context_); - - } else { - PADDLE_THROW(platform::errors::Fatal( - "PyLayerGradOpMaker can't cast %s to PyLayerOp*.", - typeid(&inner_op).name())); - } - - auto fwd_out_grads = this->OutputGrad("Out"); - using return_type = decltype(fwd_out_grads); - return_type bwd_ins; - - bwd_ins.insert(bwd_ins.begin(), fwd_out_grads.begin(), fwd_out_grads.end()); - - auto bwd_outs = this->InputGrad("X", false); - - grad_op->SetInput("X", bwd_ins); - grad_op->SetOutput("Out", bwd_outs); -} - -class PyLayerOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "Inputs of PyLayer op.").AsDuplicable(); - AddOutput("Out", "Outputs of PyLayer op").AsDuplicable(); - AddComment(R"DOC("PyLayer Op")DOC"); - } -}; - -template -class PyLayerOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto &op_ = ctx.GetOp(); - auto const_pylayer_op = dynamic_cast(&op_); - if (const_pylayer_op) { - auto pylayer_op = const_cast(const_pylayer_op); - - // Release contex after executing the compute - auto py_layer_context = pylayer_op->ReleasePyLayerContext(); - py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true); - auto &input_vars = ctx.MultiInputVar("X"); - auto output_vars = ctx.MultiOutputVar("Out"); - RunPyObject(&bk_ctx, input_vars, &output_vars); - - } else { - PADDLE_THROW(platform::errors::Fatal( - "PyLayerOpKernel can't cast %s to PyLayer*.", typeid(&op_).name())); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OPERATOR(py_layer, - ops::PyLayerOp, - ops::PyLayerOpMaker, - ops::PyLayerGradOpMaker, - ops::PyLayerGradOpMaker); - -REGISTER_OP_CPU_KERNEL( - py_layer, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel>, - ops::PyLayerOpKernel>); -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) -REGISTER_OP_CUDA_KERNEL( - py_layer, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel, - ops::PyLayerOpKernel>, - ops::PyLayerOpKernel>); -#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP diff --git a/paddle/fluid/operators/py_layer_op.h b/paddle/fluid/operators/py_layer_op.h deleted file mode 100644 index 
ea048ee9e5..0000000000 --- a/paddle/fluid/operators/py_layer_op.h +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/python_headers.h" - -namespace paddle { -namespace operators { -namespace py = ::pybind11; - -class PyLayerContext { - public: - explicit PyLayerContext(PyObject* context) : context_(context) { - Py_INCREF(context_); - } - - PyLayerContext() = delete; - - PyObject* GetMutableCtx() { return context_; } - ~PyLayerContext() { - py::gil_scoped_acquire guard; - Py_XDECREF(context_); - } - - private: - PyObject* context_; -}; - -class PyLayerOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { - VLOG(3) << "`InferShape` of `PyLayer` is an empty function, and it cannot " - "infer the shape of the output tensors."; - } - - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - auto data_type = paddle::framework::proto::VarType::Type::VarType_Type_FP32; - return framework::OpKernelType(data_type, ctx.device_context()); - } - - public: - void SetPyLayerContext(const std::shared_ptr& py_context) { - py_context_ = py_context; - } - std::shared_ptr ReleasePyLayerContext() { - auto temp = py_context_; - py_context_.reset(); - VLOG(3) << "`py_context_` in the PyLayerOp is released."; - return temp; - } - - private: - std::shared_ptr py_context_; -}; - -template -class PyLayerGradOpMaker {}; -template <> -class PyLayerGradOpMaker - : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker< - paddle::framework::OpDesc>::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override { - PADDLE_THROW(platform::errors::Unimplemented( - "`PyLayer` don't support static graph mode.")); - } -}; - -template <> -class PyLayerGradOpMaker - : public framework::SingleGradOpMaker { - public: - using framework::SingleGradOpMaker< - paddle::imperative::OpBase>::SingleGradOpMaker; - - protected: - void Apply(GradOpPtr grad_op) const override; - - public: - void SetPyLayerContext(const std::shared_ptr& py_context) { - py_context_ = py_context; - } - - private: - std::shared_ptr py_context_; -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 55accf7e04..eebb8dd843 100755 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -106,7 +106,6 @@ endif() if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) - list(APPEND PYBIND_DEPS py_layer_op) endif() set(PYBIND_SRCS diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7adec4dca2..ba18e32f1d 100644 --- 
a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -46,7 +46,6 @@ limitations under the License. */ #include "paddle/fluid/imperative/nccl_context.h" #include "paddle/fluid/imperative/partial_grad_engine.h" #include "paddle/fluid/imperative/profiler.h" -#include "paddle/fluid/imperative/py_layer_fwd.h" #include "paddle/fluid/imperative/reducer.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" @@ -2656,60 +2655,6 @@ void BindImperative(py::module *m_ptr) { .def("init", [](imperative::HeterParallelContext &self) { self.Init(); }); #endif - m.def("pylayer_apply", - [](const platform::CPUPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - - m.def("pylayer_apply", - [](const platform::CUDAPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - - m.def("pylayer_apply", - [](const platform::XPUPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - - m.def("pylayer_apply", - [](const platform::CUDAPinnedPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - - m.def("pylayer_apply", - [](const platform::NPUPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - m.def("pylayer_apply", - [](const platform::MLUPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - m.def("pylayer_apply", - [](const platform::CustomPlace &place, - const py::object &cls, - const py::args args, - const py::kwargs kwargs) { - return imperative::PyLayerApply(place, cls, args, kwargs); - }); - #if defined(PADDLE_WITH_CUDA) m.def( "to_uva_tensor", diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index c05967ecd6..bf4d3e117c 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -17,14 +17,8 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401 from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401 from . import backward_mode # noqa: F401 from .backward_mode import backward # noqa: F401 -from ..fluid.framework import _in_eager_mode_ - -if _in_eager_mode_: - from .py_layer import EagerPyLayer as PyLayer # noqa: F401 - from .py_layer import EagerPyLayerContext as PyLayerContext # noqa: F401 -else: - from .py_layer import LegacyPyLayer as PyLayer # noqa: F401 - from .py_layer import LegacyPyLayerContext as PyLayerContext # noqa: F401 +from .py_layer import PyLayer as PyLayer # noqa: F401 +from .py_layer import PyLayerContext as PyLayerContext # noqa: F401 from .saved_tensors_hooks import saved_tensors_hooks __all__ = [ # noqa diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py index 39892ca5c6..e187ea171d 100644 --- a/python/paddle/autograd/py_layer.py +++ b/python/paddle/autograd/py_layer.py @@ -13,116 +13,11 @@ # limitations under the License. 
import paddle -from paddle.amp.auto_cast import auto_cast from paddle.fluid import core -from paddle.fluid.dygraph.amp.auto_cast import amp_state -from paddle.fluid.framework import dygraph_only __all__ = [] -class LegacyPyLayerContext: - """ - The object of this class is a context that is used in PyLayer to enhance the function. - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - # ctx is a object of PyLayerContext. - y = paddle.tanh(x) - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # ctx is a object of PyLayerContext. - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - """ - - def __init__(self): - self.container = None - self._amp_state = amp_state() - - def save_for_backward(self, *tensors): - """ - Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors. - - Note: - This API should be called at most once, and only inside `forward`. - - Args: - tensors(list of Tensors): Tensors to be stored. - - Returns: - None - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - # ctx is a context object that store some objects for backward. - y = paddle.tanh(x) - # Pass tensors to backward. - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - - """ - self.container = tensors - - def saved_tensor(self): - """ - Get the tensors stored by ``save_for_backward``. - - Returns: - list of Tensors or None: If context contains tensors stored by `save_for_backward`, - then return these tensors, otherwise return None. - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - # ctx is a context object that store some objects for backward. - y = paddle.tanh(x) - # Pass tensors to backward. - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - """ - - return self.container - - def with_mateclass(meta, *bases): class impl(meta): def __new__(cls, name, temp_bases, attrs): @@ -131,212 +26,7 @@ def with_mateclass(meta, *bases): return type.__new__(impl, "impl", (), {}) -class CPyLayer: - @classmethod - @dygraph_only - def apply(cls, *args, **kwargs): - """ - After building the custom PyLayer, run it through the ``apply``. - - Args: - *args(tuple): input of PyLayer. - **kwargs(dict): input of PyLayer. - - Returns: - tensors or other types : output of PyLayer. - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x, func1, func2=paddle.square): - ctx.func = func2 - y = func1(x) - # Pass tensors to backward. - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - ctx.func(y)) - return grad - - - data = paddle.randn([2, 3], dtype="float64") - data.stop_gradient = False - # run custom Layer. 
- z = cus_tanh.apply(data, func1=paddle.tanh) - """ - place = paddle.fluid.framework._current_expected_place() - with paddle.fluid.dygraph.no_grad(): - return core.pylayer_apply(place, cls, *args, **kwargs) - - -class PyLayerBackward(LegacyPyLayerContext): - def backward(self, *args, **kwargs): - with paddle.fluid.dygraph.guard(): - with paddle.fluid.dygraph.no_grad(): - if ( - self._amp_state - and 'enable' in self._amp_state - and self._amp_state['enable'] - ): - with auto_cast(**args[0]._amp_state): - return self._forward_cls.backward(*args, **kwargs) - else: - - return self._forward_cls.backward(*args, **kwargs) - return self._forward_cls.backward(*args, **kwargs) - - -class LayerMeta(type): - def __init__(cls, name, bases, attrs): - cls._backward_function = type( - name + '_backward', (PyLayerBackward,), {"_forward_cls": cls} - ) - - return super().__init__(name, bases, attrs) - - -class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)): - """ - Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules: - 1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod. - Their first argument should be a context and `None` can not be included in the returned result. - 2. Input of backward contains a context as the first argument, and the rest arguments are the - gradient of forward's output tensors. so the number of backward's input tensors equal to - the number of forward output tensors. If you need the forward's inputs or outputs in `backward`, - you can use `save_for_backward` to store the required tensors, and then use them in the backward. - 3. Output of backward function can only be `Tensor` or tuple/list of `Tensor`. - Output tensors of backward are the gradient of forward's input tensors, - so the number of backward's output tensors equal to the number of forward input tensors. - After building the custom Layer, run it through the `apply` method. - - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - # Inherit from PyLayer - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x, func1, func2=paddle.square): - # ctx is a context object that store some objects for backward. - ctx.func = func2 - y = func1(x) - # Pass tensors to backward. - ctx.save_for_backward(y) - return y - - @staticmethod - # forward has only one output, so there is only one gradient in the input of backward. - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - ctx.func(y)) - # forward has only one input, so only one gradient tensor is returned. - return grad - - - data = paddle.randn([2, 3], dtype="float64") - data.stop_gradient = False - z = cus_tanh.apply(data, func1=paddle.tanh) - z.mean().backward() - - print(data.grad) - - """ - - @staticmethod - def forward(ctx, *args, **kwargs): - """ - It is to be overloaded by subclasses. It must accept a object of `PyLayerContext` as - the first argument, followed by any number of arguments (tensors or other types). - `None` can not be included in the returned result. - - Args: - *args(tuple): input of PyLayer. - **kwargs(dict): input of PyLayer. - - Returns: - tensors or other types : output of PyLayer. - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - y = paddle.tanh(x) - # Pass tensors to backward. 
- ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - """ - raise NotImplementedError( - "You must implement the forward function for PyLayer." - ) - - @staticmethod - def backward(ctx, *args, **kwargs): - """ - This is a function to calculate the gradient. It is to be overloaded by subclasses. - It must accept a object of `PyLayerContext` as the first argument, and the rest - arguments are the gradient of forward's output tensors. Output tensors of backward - are the gradient of forward's input tensors. - - Args: - *args(tuple): The gradient of forward's output tensor(s). - **kwargs(dict): The gradient of forward's output tensor(s). - - Returns: - Tensor or list of Tensors: The gradient of forward's input tensor(s). - - Examples: - .. code-block:: python - - import paddle - from paddle.autograd import PyLayer - - class cus_tanh(PyLayer): - @staticmethod - def forward(ctx, x): - y = paddle.tanh(x) - # Pass tensors to backward. - ctx.save_for_backward(y) - return y - - @staticmethod - def backward(ctx, dy): - # Get the tensors passed by forward. - y, = ctx.saved_tensor() - grad = dy * (1 - paddle.square(y)) - return grad - """ - - raise NotImplementedError( - "You must implement the backward function for PyLayer." - ) - - -class EagerPyLayerContext: +class PyLayerContext: def save_for_backward(self, *tensors): """ Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors. @@ -537,23 +227,21 @@ class EagerPyLayerContext: self.materialize_grads = value -class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext): +class PyLayerBackward(core.eager.PyLayer, PyLayerContext): def backward(self, *args): return self._forward_cls.backward(self, *args) -class EagerPyLayerMeta(type): +class PyLayerMeta(type): def __init__(cls, name, bases, attrs): cls._backward_function = type( - name + '_backward', (EagerPyLayerBackward,), {"_forward_cls": cls} + name + '_backward', (PyLayerBackward,), {"_forward_cls": cls} ) return super().__init__(name, bases, attrs) -class EagerPyLayer( - with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, EagerPyLayerContext) -): +class PyLayer(with_mateclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)): @staticmethod def forward(ctx, *args, **kwargs): """ diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py index c44d83df70..5af29898ce 100644 --- a/python/paddle/fluid/tests/unittests/test_pylayer_op.py +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -17,8 +17,7 @@ import unittest import numpy as np import paddle -from paddle.autograd.py_layer import EagerPyLayer, LegacyPyLayer -from paddle.fluid.framework import in_dygraph_mode +from paddle.autograd.py_layer import PyLayer class FakeTensor(paddle.fluid.core.VarBase): @@ -28,7 +27,7 @@ class FakeTensor(paddle.fluid.core.VarBase): class TestPyLayer(unittest.TestCase): def test_simple_pylayer_multiple_output(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward(ctx, x1, x2, func1, func2=paddle.square): ctx.func = func2 @@ -60,7 +59,7 @@ class TestPyLayer(unittest.TestCase): ) def test_simple_pylayer_return_none_with_no_grad(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward(ctx, x1, x2, func1, 
func2=paddle.square): ctx.func = func2 @@ -96,7 +95,7 @@ class TestPyLayer(unittest.TestCase): ) def test_simple_pylayer_single_output(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward(ctx, x1, func1, func2=paddle.square): ctx.func = func2 @@ -124,7 +123,7 @@ class TestPyLayer(unittest.TestCase): ) def test_pylayer_num_output_match(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward( ctx, @@ -146,7 +145,7 @@ class TestPyLayer(unittest.TestCase): z.mean().backward() def test_pylayer_dtype(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward(ctx, x, dtype): y = paddle.cast(x, dtype) @@ -176,7 +175,7 @@ class TestPyLayer(unittest.TestCase): self.assertIsNotNone(input1.grad) def test_pylayer_Exception_forward(self): - class Layer_None1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_None1(PyLayer): @staticmethod def forward(ctx, *args): return None @@ -189,7 +188,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z = Layer_None1.apply(input1) - class Layer_None2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_None2(PyLayer): @staticmethod def forward(ctx, *args): return [None, args[0]] @@ -202,7 +201,7 @@ class TestPyLayer(unittest.TestCase): # return None z = Layer_None2.apply(input1) - class Layer_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_one1(PyLayer): @staticmethod def forward(ctx, *args): return 1 @@ -216,7 +215,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z = Layer_one1.apply(input1) - class Layer_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_one2(PyLayer): @staticmethod def forward(ctx, *args): return [1, 2, args[0]] @@ -229,7 +228,7 @@ class TestPyLayer(unittest.TestCase): # return int z = Layer_one2.apply(input1) - class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_no_fw(PyLayer): @staticmethod def backward(ctx, *args): return args @@ -239,7 +238,7 @@ class TestPyLayer(unittest.TestCase): z = Layer_no_fw.apply(input1) def test_pylayer_nograd(self): - class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class tanh(PyLayer): @staticmethod def forward(ctx, x1, func1, func2=paddle.square, xx=None): ctx.func = func2 @@ -257,9 +256,7 @@ class TestPyLayer(unittest.TestCase): self.assertIsNone(z.grad) def test_pylayer_Exception_bk(self): - class Layer_bk_none1( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_none1(PyLayer): @staticmethod def forward(ctx, x): return x * 2 @@ -275,9 +272,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.sum().backward() - class Layer_bk_none2( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_none2(PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 + x2 @@ -293,9 +288,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_one1( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_one1(PyLayer): @staticmethod def forward(ctx, x): return x + x @@ -311,9 +304,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_one2( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class 
Layer_bk_one2(PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 * 2, x2 * 5 @@ -330,7 +321,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Layer_no_bk(PyLayer): @staticmethod def forward(ctx, x): return x * 2, x * 5 @@ -343,9 +334,7 @@ class TestPyLayer(unittest.TestCase): z = z[0] + z[1] z.mean().backward() - class Layer_bk_match( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_match(PyLayer): @staticmethod def forward(ctx, x): return x * 2, x * 5 @@ -362,9 +351,7 @@ class TestPyLayer(unittest.TestCase): z.mean().backward() def test_pylayer_bk_return_none(self): - class Layer_bk_none1( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_none1(PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 + x2 @@ -382,9 +369,7 @@ class TestPyLayer(unittest.TestCase): with self.assertRaises(ValueError): z.mean().backward() - class Layer_bk_none2( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class Layer_bk_none2(PyLayer): @staticmethod def forward(ctx, x1, x2): return x1 * 2, x2 * 5 @@ -403,7 +388,7 @@ class TestPyLayer(unittest.TestCase): z.mean().backward() def test_pylayer_inplace(self): - class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class cus_tanh(PyLayer): @staticmethod def forward(ctx, x): return x @@ -431,7 +416,7 @@ class TestPyLayer(unittest.TestCase): self.assertIsNotNone(data.grad) def test_pylayer_inplace_backward_error(self): - class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class cus_tanh(PyLayer): @staticmethod def forward(ctx, x): return x @@ -464,7 +449,7 @@ class TestPyLayer(unittest.TestCase): z.backward() def test_pylayer_inplace_backward_success_1(self): - class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class cus_tanh(PyLayer): @staticmethod def forward(ctx, x): return x @@ -493,7 +478,7 @@ class TestPyLayer(unittest.TestCase): self.assertIsNotNone(data.grad) def test_pylayer_inplace_backward_success_2(self): - class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class cus_tanh(PyLayer): @staticmethod def forward(ctx, x): return x @@ -522,9 +507,7 @@ class TestPyLayer(unittest.TestCase): self.assertIsNotNone(data.grad) def test_pylayer_inplace_and_leaf_exception(self): - class cus_pylayer_op( - EagerPyLayer if in_dygraph_mode() else LegacyPyLayer - ): + class cus_pylayer_op(PyLayer): @staticmethod def forward(ctx, x): return x @@ -550,7 +533,7 @@ class TestPyLayer(unittest.TestCase): z = layer(data) def test_backward_in_backward(self): - class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class cus_tanh(PyLayer): @staticmethod def forward(ctx, x): temp = x.detach() @@ -575,7 +558,7 @@ class TestPyLayer(unittest.TestCase): z = cus_tanh.apply(data) def test_return_to_tensor(self): - class Tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer): + class Tanh(PyLayer): @staticmethod def forward(ctx, x1): y1 = paddle.tanh(x1) @@ -597,7 +580,7 @@ class TestPyLayer(unittest.TestCase): z.mean().backward() def test_materialize_grads(self): - class Tanh(EagerPyLayer): + class Tanh(PyLayer): @staticmethod def forward(ctx, x): ctx.mark_not_inplace(x) @@ -613,7 +596,7 @@ class TestPyLayer(unittest.TestCase): Tanh.apply(x)[0].backward() def test_dont_materialize_grads(self): - class Tanh(EagerPyLayer): + class Tanh(PyLayer): @staticmethod def forward(ctx, x): 
                 ctx.mark_not_inplace(x)
@@ -630,7 +613,7 @@
         Tanh.apply(x)[0].backward()

     def test_mark_non_differentiable(self):
-        class Tanh(EagerPyLayer):
+        class Tanh(PyLayer):
             @staticmethod
             def forward(ctx, x):
                 a = x + x
@@ -648,7 +631,7 @@
         y.sum().backward()

     def test_mark_non_differentiable2(self):
-        class Tanh(EagerPyLayer):
+        class Tanh(PyLayer):
             @staticmethod
             def forward(ctx, x):
                 a = x + x
@@ -669,106 +652,5 @@
         self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))


-class TestPyLayerReturnType(unittest.TestCase):
-    def test_forward_args_fake_tensor(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1):
-                y1 = FakeTensor()
-                return y1, x1
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-                return dy1
-
-        input1 = FakeTensor()
-
-        with self.assertRaises(ValueError):
-            y1, y2 = Tanh.apply(input1)
-
-    def test_forward_kwargs_fake_tensor(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1):
-
-                return x1
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-                return dy1
-
-        input1 = FakeTensor()
-
-        with self.assertRaises(ValueError):
-            y = Tanh.apply(x1=input1)
-
-    def test_forward_return_fake_tensor(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1):
-
-                return FakeTensor()
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-                return dy1
-
-        input1 = paddle.randn([3, 2])
-
-        with self.assertRaises(ValueError):
-            y = Tanh.apply(x1=input1)
-
-    def test_forward_return_fake_tensor_tuple(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1):
-
-                return FakeTensor(), FakeTensor()
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-                return dy1
-
-        input1 = paddle.randn([3, 2])
-
-        with self.assertRaises(ValueError):
-            y = Tanh.apply(x1=input1)
-
-    def test_backward_return_fake_tensor_tuple(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1, x2):
-                return x1 + 1, x1 + 2
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-
-                return FakeTensor(), 2
-
-        input1 = paddle.randn([3, 2])
-        input1.stop_gradient = False
-
-        with self.assertRaises(ValueError):
-            y, _ = Tanh.apply(input1, 1 + input1)
-            y.mean().backward()
-
-    def test_backward_return_fake_tensor(self):
-        class Tanh(LegacyPyLayer):
-            @staticmethod
-            def forward(ctx, x1):
-                return x1 + 1, x1 + 2
-
-            @staticmethod
-            def backward(ctx, dy1, dy2):
-                return FakeTensor()
-
-        input1 = paddle.randn([3, 2])
-        input1.stop_gradient = False
-
-        with self.assertRaises(ValueError):
-            y, _ = Tanh.apply(input1)
-            y.mean().backward()
-
-
 if __name__ == '__main__':
     unittest.main()
-- 
GitLab
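
After this change, `paddle.autograd.PyLayer` is the single user-facing API: the legacy `py_layer` C++ operator, the `pylayer_apply` pybind bindings, and the `LegacyPyLayer`/`EagerPyLayer` split are all removed, and the eager path is the only one left. A minimal sketch of the surviving interface, adapted from the docstring examples kept in python/paddle/autograd/py_layer.py (it assumes eager/dygraph mode, which after this patch is the only supported mode for PyLayer):

    import paddle
    from paddle.autograd import PyLayer

    class cus_tanh(PyLayer):
        @staticmethod
        def forward(ctx, x):
            y = paddle.tanh(x)
            # Stash tensors needed by backward; retrieve them with ctx.saved_tensor().
            ctx.save_for_backward(y)
            return y

        @staticmethod
        def backward(ctx, dy):
            # backward receives one gradient per forward output and must return
            # one gradient per Tensor input of forward.
            (y,) = ctx.saved_tensor()
            return dy * (1 - paddle.square(y))  # d/dx tanh(x) = 1 - tanh(x)^2

    data = paddle.randn([2, 3], dtype="float64")
    data.stop_gradient = False
    z = cus_tanh.apply(data)  # custom layers still run through apply()
    z.mean().backward()
    print(data.grad)

Subclasses must still implement both `forward` and `backward` as `@staticmethod`s taking a context object as the first argument, and the gradient counts must match forward's Tensor inputs/outputs; the retained unit tests above exercise exactly these invariants.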