From 29f65225a212c843ffa81dcf87f417120f8c7ee4 Mon Sep 17 00:00:00 2001 From: WeiXin Date: Wed, 14 Apr 2021 21:27:26 -0500 Subject: [PATCH] Customizable Python Layer in Dygraph (#32130) * custom python backward * polish up the code * polish up the code * polish up the code. * Fix code format and comments. * Delete redundant files. * add unnittest. * edit unnittest. * edit unnittest. * Remove redundant header files. * Improve coverage and remove redundant code. * support saving for backward. * polish code according to comments. * Add support type for PyLayer. * Modify the DOC. * polish Doc. * polish Doc. * polish Doc. * polish Doc. * polish Doc. * polish Doc. * polish code and make the code robust. * Modify the code format. --- paddle/fluid/framework/operator.h | 1 + paddle/fluid/imperative/dygraph_grad_maker.h | 2 + paddle/fluid/imperative/layer.cc | 2 +- paddle/fluid/imperative/layer.h | 2 + paddle/fluid/imperative/py_layer_fwd.h | 172 ++++++++++ paddle/fluid/imperative/tracer.cc | 2 +- paddle/fluid/imperative/tracer.h | 2 + paddle/fluid/imperative/variable_wrapper.h | 3 + paddle/fluid/operators/CMakeLists.txt | 3 +- paddle/fluid/operators/py_layer_op.cc | 197 +++++++++++ paddle/fluid/operators/py_layer_op.h | 105 ++++++ paddle/fluid/pybind/CMakeLists.txt | 1 + paddle/fluid/pybind/imperative.cc | 62 +++- python/paddle/autograd/__init__.py | 5 +- python/paddle/autograd/py_layer.py | 318 ++++++++++++++++++ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../fluid/tests/unittests/test_pylayer_op.py | 303 +++++++++++++++++ 17 files changed, 1159 insertions(+), 22 deletions(-) create mode 100644 paddle/fluid/imperative/py_layer_fwd.h create mode 100644 paddle/fluid/operators/py_layer_op.cc create mode 100644 paddle/fluid/operators/py_layer_op.h create mode 100644 python/paddle/autograd/py_layer.py create mode 100644 python/paddle/fluid/tests/unittests/test_pylayer_op.py diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index bf27a8e37e0..3fc61581eca 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -419,6 +419,7 @@ class ExecutionContext { const RuntimeContext Context() const { return ctx_; } std::string DebugString() const { return op_.DebugString(); } + const OperatorBase& GetOp() const { return op_; } private: const OperatorBase& op_; diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index a3678404728..7fefc9ccc67 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -279,6 +279,8 @@ class TracedGradOp { void SetType(const std::string& type) { op_->SetType(type); } + const framework::OperatorBase& InnerOp() const { return op_->InnerOp(); } + void SetAttrMap(const framework::AttributeMap& attrs) { return op_->SetAttrMap(attrs); } diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 062f04c6b70..70359dc3fd2 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -406,7 +406,7 @@ void OpBase::Run(const framework::OperatorBase& op, OpBaseRunImpl(op, ins, outs, attrs, place); } -static void ClearNoNeedBufferInputs(OpBase* op) { +void ClearNoNeedBufferInputs(OpBase* op) { auto& inferer = op->Info().NoNeedBufferVarsInferer(); if (!inferer) return; auto* ins = op->GetMutableInsMap(); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index 362ba1eb70b..bbede47e364 100644 --- a/paddle/fluid/imperative/layer.h +++ 
b/paddle/fluid/imperative/layer.h @@ -286,5 +286,7 @@ std::shared_ptr CreateGradOpNode( const platform::Place& place, const std::map& inplace_map); +void ClearNoNeedBufferInputs(OpBase* op); + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/py_layer_fwd.h b/paddle/fluid/imperative/py_layer_fwd.h new file mode 100644 index 00000000000..bd132f2576f --- /dev/null +++ b/paddle/fluid/imperative/py_layer_fwd.h @@ -0,0 +1,172 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/imperative/layer.h" +#include "paddle/fluid/imperative/tracer.h" + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/type_defs.h" +#include "paddle/fluid/operators/py_layer_op.h" + +namespace paddle { +namespace imperative { + +namespace py = ::pybind11; + +bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) { + for (const auto& name_pair : ins) { + for (const auto& var_base : name_pair.second) { + if (!var_base->OverridedStopGradient()) { + PassStopGradient(outs, var_base->OverridedStopGradient()); + return true; + } + } + } + return false; +} + +std::shared_ptr CreateGradOpNode( + const std::string& type, const NameVarBaseMap& ins, + const NameVarBaseMap& outs, const framework::AttributeMap& attrs, + const platform::Place& place, + const std::map& inplace_map, + const std::shared_ptr& py_context) { + operators::PyLayerGradOpMaker maker( + type, ins, outs, attrs, inplace_map); + + maker.SetPyLayerContext(py_context); + auto grad_node = maker(); + if (grad_node && !grad_node->empty()) { + for (auto& grad_op : *grad_node) { + grad_op.SetId(OpBase::GenerateUniqueId()); + grad_op.SetPlace(place); + ClearNoNeedBufferInputs(&grad_op); + } + return grad_node; + } else { + return nullptr; + } +} + +py::object PyLayerApply(const platform::Place& place, const py::object& cls, + const py::args args, const py::kwargs kwargs) { + auto bk_function = cls.attr("_backward_function"); + auto context = bk_function(); + auto forward = cls.attr("forward"); + + auto result_forward = forward(context, *args, **kwargs); + std::shared_ptr py_layer_ctx = + std::make_shared(context.release().ptr()); + // make inputs to varbase + std::vector> input_vars; + // process args,`input_vars` only collect `imperative::VarBase` + if (!args.empty()) { + for (auto ptr = args.begin(); ptr != args.end(); ptr++) { + try { + if (Py_None != ptr->ptr()) { + auto a = ptr->cast>(); + input_vars.push_back(a); + } + } catch (py::cast_error& err) { + // Only collect Tensor type in 'args' and pass them to backward. Ignore + // other types of input temporarily. 
+ } + } + } + // process kwargs, only collect `imperative::VarBase` + if (!kwargs.empty()) { + for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) { + try { + if (Py_None != ptr->second.ptr()) { + auto a = ptr->second.cast>(); + input_vars.push_back(a); + } + } catch (py::cast_error&) { + // Only collect Tensor type in 'kwargs' and pass them to backward. + // Ignore other types of input temporarily. + } + } + } + NameVarBaseMap ins = {{"X", input_vars}}; + + std::vector> output_vars; + if (PyTuple_Check(result_forward.ptr()) || + PyList_Check(result_forward.ptr())) { + auto tuple_result = result_forward.cast(); + for (size_t i = 0; i < tuple_result.size(); i++) { + if (Py_None != tuple_result[i].ptr()) { + try { + auto temp_out = + tuple_result[i].cast>(); + output_vars.push_back(temp_out); + } catch (py::cast_error&) { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.forward` should be `Tensor`.")); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.forward` can not be `None`.")); + } + } + } else { + if (Py_None != result_forward.ptr()) { + try { + auto temp_out = + result_forward.cast>(); + output_vars.push_back(temp_out); + } catch (py::cast_error&) { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.forward` should be `Tensor`.")); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.forward` can not be `None`.")); + } + } + + NameVarBaseMap outs = {{"Out", output_vars}}; + + if (RequiredGrad(ins, outs)) { + std::map inplace_map{}; + bool if_inplace = false; + for (auto temp_ins : input_vars) { + if (if_inplace) { + break; + } + for (auto temp_outs : output_vars) { + if (temp_ins->Name() == temp_outs->Name()) { + if_inplace = true; + break; + } + } + } + if (if_inplace) { + inplace_map["X"] = "Out"; + } + + CreateGradOpNode("py_layer", ins, outs, {{}}, place, inplace_map, + py_layer_ctx); + } else { + VLOG(3) << "No Grad to track for Op: py_layer_op"; + } + + return result_forward; +} + +} // namespace imperative +} // namespace paddle diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 608cc407d5b..777cb10e075 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -38,7 +38,7 @@ void SetCurrentTracer(const std::shared_ptr& tracer) { VLOG(6) << "Set current tracer: " << g_current_tracer; } -static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) { +void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) { for (const auto& pair : outs) { for (const auto& var : pair.second) { // NOTE(zhiqiu): this happends when None output are passed from python diff --git a/paddle/fluid/imperative/tracer.h b/paddle/fluid/imperative/tracer.h index b10d1b2d0b4..8f505508782 100644 --- a/paddle/fluid/imperative/tracer.h +++ b/paddle/fluid/imperative/tracer.h @@ -130,5 +130,7 @@ void IncreaseVarbaseReferenceCountUntilCopyComplete( const std::shared_ptr& var, const platform::Place& place); +void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad); + } // namespace imperative } // namespace paddle diff --git a/paddle/fluid/imperative/variable_wrapper.h b/paddle/fluid/imperative/variable_wrapper.h index 3b23f4a6222..5fa8b89a396 100644 --- a/paddle/fluid/imperative/variable_wrapper.h +++ b/paddle/fluid/imperative/variable_wrapper.h @@ -38,6 +38,9 @@ class VariableWrapper { explicit VariableWrapper(const std::string& name) : name_(name) {} + VariableWrapper(const 
std::string& name, const framework::Variable& variable) + : var_(variable), name_(name) {} + ~VariableWrapper() { VLOG(10) << "Destruct VariableWrapper: " << Name(); } const framework::Variable& Var() const { return var_; } diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index b475f75990f..cecc70cc6dd 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -69,7 +69,7 @@ if(WITH_UNITY_BUILD) include(unity_build_rule.cmake) endif() -register_operators(EXCLUDES py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op +register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS}) op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS}) @@ -162,6 +162,7 @@ endif() cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) + cc_library(py_layer_op SRCS py_layer_op.cc DEPS op_registry python pybind) endif() if (WITH_ASCEND_CL) diff --git a/paddle/fluid/operators/py_layer_op.cc b/paddle/fluid/operators/py_layer_op.cc new file mode 100644 index 00000000000..0d5c23bed60 --- /dev/null +++ b/paddle/fluid/operators/py_layer_op.cc @@ -0,0 +1,197 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
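// Overview (descriptive summary of this new file): py_layer_op.cc implements the
// py_layer operator. RunPyObject wraps the incoming gradient Variables as VarBase
// objects, calls the Python context's `backward` method through pybind11, validates
// that the results are Tensors with the expected count, and copies them into the
// kernel outputs. PyLayerGradOpMaker<imperative::OpBase>::Apply forwards the saved
// PyLayerContext to the grad op, and the operator plus CPU/CUDA kernels are
// registered for a range of dtypes.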
+ +#include + +#include "paddle/fluid/operators/py_layer_op.h" + +namespace paddle { +namespace operators { + +namespace py = ::pybind11; + +void RunPyObject(py::object *py_object, + const std::vector &ins, + std::vector *outs) { + py::gil_scoped_acquire guard; + + auto py_function = py_object->attr("backward"); + + py::tuple inputs(ins.size()); + for (size_t i = 0; i < ins.size(); i++) { + auto in_var = ins[i]; + if (in_var != nullptr) { + auto name = paddle::string::Sprintf("generator_custom_py_layer_%d@GRAD", + static_cast(i)); + + std::shared_ptr temp_wrap = + std::make_shared(name, *in_var); + temp_wrap->InnerSetOverridedStopGradient(true); + std::shared_ptr temp_varbase = + std::make_shared(temp_wrap); + try { + inputs[i] = py::cast(temp_varbase).ptr(); + } catch (py::cast_error &) { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.backward` should be `Tensor`.")); + } + } + } + + auto py_result = py_function(*py_object, *inputs); + + if (PyTuple_Check(py_result.ptr()) || PyList_Check(py_result.ptr())) { + auto result_tuple = py_result.cast(); + if (result_tuple.size() != outs->size()) { + PADDLE_THROW(platform::errors::InvalidArgument( + "The number of outputs of `PyLayer.backward` should be %d, but " + "received %d.", + outs->size(), result_tuple.size())); + } + for (size_t i = 0; i < result_tuple.size(); i++) { + if (Py_None != result_tuple[i].ptr()) { + try { + auto result_var = + result_tuple[i].cast>(); + *(*outs)[i] = result_var->Var(); + } catch (py::cast_error &) { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.backward` should be `Tensor`.")); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.backward` can not be `None`.")); + } + } + } else { + if (Py_None != py_result.ptr()) { + try { + auto result_var = + py_result.cast>(); + *((*outs)[0]) = result_var->Var(); + } catch (py::cast_error &) { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.backward` should be `Tensor`.")); + } + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "The output of `PyLayer.backward` can not be `None`.")); + } + } +} + +void PyLayerGradOpMaker::Apply( + GradOpPtr grad_op) const { + grad_op->SetType("py_layer"); + auto &inner_op = grad_op->InnerOp(); + auto py_layer_op_const = dynamic_cast(&inner_op); + + if (py_layer_op_const) { + auto py_layer_op = const_cast(py_layer_op_const); + py_layer_op->SetPyLayerContext(py_context_); + + } else { + PADDLE_THROW(platform::errors::Fatal( + "PyLayerGradOpMaker can't cast %s to PyLayerOp*.", + typeid(&inner_op).name())); + } + + auto fwd_out_grads = this->OutputGrad("Out"); + using return_type = decltype(fwd_out_grads); + return_type bwd_ins; + + bwd_ins.insert(bwd_ins.begin(), fwd_out_grads.begin(), fwd_out_grads.end()); + + auto bwd_outs = this->InputGrad("X", false); + + grad_op->SetInput("X", bwd_ins); + grad_op->SetOutput("Out", bwd_outs); +} + +class PyLayerOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Inputs of PyLayer op.").AsDuplicable(); + AddOutput("Out", "Outputs of PyLayer op").AsDuplicable(); + AddComment(R"DOC("PyLayer Op")DOC"); + } +}; + +template +class PyLayerOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &op_ = ctx.GetOp(); + auto pylayer_op = dynamic_cast(&op_); + if (pylayer_op) { + auto py_layer_context = pylayer_op->GetPyLayerContext(); + py::object 
bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true); + auto &input_vars = ctx.MultiInputVar("X"); + auto output_vars = ctx.MultiOutputVar("Out"); + RunPyObject(&bk_ctx, input_vars, &output_vars); + + } else { + PADDLE_THROW(platform::errors::Fatal( + "PyLayerOpKernel can't cast %s to PyLayer*.", typeid(&op_).name())); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(py_layer, ops::PyLayerOp, ops::PyLayerOpMaker, + ops::PyLayerGradOpMaker, + ops::PyLayerGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + py_layer, ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel); +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL( + py_layer, ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel, + ops::PyLayerOpKernel); +#endif // PADDLE_WITH_CUDA diff --git a/paddle/fluid/operators/py_layer_op.h b/paddle/fluid/operators/py_layer_op.h new file mode 100644 index 00000000000..133435aa84d --- /dev/null +++ b/paddle/fluid/operators/py_layer_op.h @@ -0,0 +1,105 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
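// Overview (descriptive summary of this new file): py_layer_op.h declares
// PyLayerContext (holds and INCREFs the Python context object), PyLayerOp (a
// kernel-based operator with an intentionally empty InferShape), and the
// PyLayerGradOpMaker specializations: the OpDesc variant throws because static
// graph mode is not supported, while the OpBase variant builds the dygraph grad
// node and carries the PyLayerContext to the backward op.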
+ +#pragma once + +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/python_headers.h" + +namespace paddle { +namespace operators { +namespace py = ::pybind11; + +class PyLayerContext { + public: + explicit PyLayerContext(PyObject* context) : context_(context) { + Py_INCREF(context_); + } + + PyLayerContext() = delete; + + PyObject* GetMutableCtx() { return context_; } + + private: + PyObject* context_; +}; + +class PyLayerOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + VLOG(3) << "`InferShape` of `PyLayer` is an empty function, and it cannot " + "infer the shape of the output tensors."; + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + return framework::OpKernelType(data_type, ctx.device_context()); + } + + public: + void SetPyLayerContext(const std::shared_ptr& py_context) { + py_context_ = py_context; + } + const std::shared_ptr& GetPyLayerContext() const { + return py_context_; + } + + private: + std::shared_ptr py_context_; +}; + +template +class PyLayerGradOpMaker {}; +template <> +class PyLayerGradOpMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker< + paddle::framework::OpDesc>::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "`PyLayer` don't support static graph mode.")); + } +}; + +template <> +class PyLayerGradOpMaker + : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker< + paddle::imperative::OpBase>::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr grad_op) const override; + + public: + void SetPyLayerContext(const std::shared_ptr& py_context) { + py_context_ = py_context; + } + + private: + std::shared_ptr py_context_; +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index 10c79933546..b1d60193d46 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -36,6 +36,7 @@ endif(NOT WIN32) if(WITH_PYTHON) list(APPEND PYBIND_DEPS py_func_op) + list(APPEND PYBIND_DEPS py_layer_op) endif() set(PYBIND_SRCS diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 7a2ff9ff7ec..0817dc33671 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -39,6 +39,7 @@ limitations under the License. 
*/ #include "paddle/fluid/imperative/nccl_context.h" #include "paddle/fluid/imperative/partial_grad_engine.h" #include "paddle/fluid/imperative/profiler.h" +#include "paddle/fluid/imperative/py_layer_fwd.h" #include "paddle/fluid/imperative/reducer.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/type_defs.h" @@ -1339,22 +1340,28 @@ void BindImperative(py::module *m_ptr) { &imperative::VarBase::SetOverridedStopGradient) .def_property("persistable", &imperative::VarBase::Persistable, &imperative::VarBase::SetPersistable) - .def_property_readonly( - "shape", - [](imperative::VarBase &self) { - if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().dims()); - } else if (self.Var().IsType()) { - return framework::vectorize( - self.Var().Get().value().dims()); - } else { - VLOG(2) << "It is meaningless to get shape of " - "variable type " - << GetTypeName(self); - return std::vector(); - } - }) + .def_property_readonly("shape", + [](imperative::VarBase &self) { + if (self.Var().IsType()) { + return framework::vectorize( + self.Var() + .Get() + .dims()); + } else if (self.Var() + .IsType< + framework::SelectedRows>()) { + return framework::vectorize( + self.Var() + .Get() + .value() + .dims()); + } else { + VLOG(2) << "It is meaningless to get shape of " + "variable type " + << GetTypeName(self); + return std::vector(); + } + }) .def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf, R"DOC( Whether a Tensor is leaf Tensor. @@ -1643,6 +1650,29 @@ void BindImperative(py::module *m_ptr) { &imperative::BKCLParallelContext::InitWithRingID, py::arg("ring_id")); #endif + m.def("pylayer_apply", + [](const platform::CPUPlace &place, const py::object &cls, + const py::args args, const py::kwargs kwargs) { + return imperative::PyLayerApply(place, cls, args, kwargs); + }); + + m.def("pylayer_apply", + [](const platform::CUDAPlace &place, const py::object &cls, + const py::args args, const py::kwargs kwargs) { + return imperative::PyLayerApply(place, cls, args, kwargs); + }); + + m.def("pylayer_apply", + [](const platform::XPUPlace &place, const py::object &cls, + const py::args args, const py::kwargs kwargs) { + return imperative::PyLayerApply(place, cls, args, kwargs); + }); + + m.def("pylayer_apply", + [](const platform::CUDAPinnedPlace &place, const py::object &cls, + const py::args args, const py::kwargs kwargs) { + return imperative::PyLayerApply(place, cls, args, kwargs); + }); } } // namespace pybind diff --git a/python/paddle/autograd/__init__.py b/python/paddle/autograd/__init__.py index 8b3f3086a4a..71110e95817 100644 --- a/python/paddle/autograd/__init__.py +++ b/python/paddle/autograd/__init__.py @@ -16,7 +16,6 @@ from ..fluid.dygraph.base import grad #DEFINE_ALIAS from . import backward_mode from .backward_mode import backward +from .py_layer import PyLayer, PyLayerContext -__all__ = ['grad'] - -__all__ += backward_mode.__all__ +__all__ = ['grad', 'backward', 'PyLayer', 'PyLayerContext'] diff --git a/python/paddle/autograd/py_layer.py b/python/paddle/autograd/py_layer.py new file mode 100644 index 00000000000..c093565dc92 --- /dev/null +++ b/python/paddle/autograd/py_layer.py @@ -0,0 +1,318 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.fluid.framework import dygraph_only +from paddle.fluid import core +__all__ = ['PyLayer', 'PyLayerContext'] + + +class PyLayerContext(object): + """ + The object of this class is a context that is used in PyLayer to enhance the function. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a object of PyLayerContext. + y = paddle.tanh(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # ctx is a object of PyLayerContext. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + + def __init__(self): + self.container = None + + def save_for_backward(self, *tensors): + """ + Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors. + + .. note:: + This API should be called at most once, and only inside `forward`. + + Args: + tensors(list of Tensors): Tensors to be stored. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a context object that store some objects for backward. + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + + """ + self.container = tensors + + def saved_tensor(self): + """ + Get the tensors stored by ``save_for_backward``. + + Returns: + list of Tensors or None: If context contains tensors stored by `save_for_backward`, + then return these tensors, otherwise return None. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + # ctx is a context object that store some objects for backward. + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + + return self.container + + +def with_mateclass(meta, *bases): + class impl(meta): + def __new__(cls, name, temp_bases, attrs): + return meta(name, bases, attrs) + + return type.__new__(impl, "impl", (), {}) + + +class CPyLayer(object): + @classmethod + @dygraph_only + def apply(cls, *args, **kwargs): + """ + After building the custom PyLayer, run it through the ``apply``. + + Args: + *args(tuple): input of PyLayer. + **kwargs(dict): input of PyLayer. + + Returns: + tensors or other types : output of PyLayer. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x, func1, func2=paddle.square): + ctx.func = func2 + y = func1(x) + # Pass tensors to backward. 
+ ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - ctx.func(y)) + return grad + + + data = paddle.randn([2, 3], dtype="float64") + data.stop_gradient = False + # run custom Layer. + z = cus_tanh.apply(data, func1=paddle.tanh) + """ + place = paddle.fluid.framework._current_expected_place() + with paddle.fluid.dygraph.no_grad(): + return core.pylayer_apply(place, cls, *args, **kwargs) + + +class PyLayerBackward(PyLayerContext): + def backward(self, *args, **kwargs): + with paddle.fluid.dygraph.no_grad(): + return self._forward_cls.backward(*args, **kwargs) + + +class LayerMeta(type): + def __init__(cls, name, bases, attrs): + cls._backward_function = type(name + '_backward', (PyLayerBackward, ), + {"_forward_cls": cls}) + + return super(LayerMeta, cls).__init__(name, bases, attrs) + + +class PyLayer(with_mateclass(LayerMeta, CPyLayer)): + """ + Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules: + 1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod. + Their first argument should be a context and `None` can not be included in the returned result. + 2. Input of backward contains a context as the first argument, and the rest arguments are the + gradient of forward's output tensors. so the number of backward's input tensors equal to + the number of forward output tensors. If you need the forward's inputs or outputs in `backward`, + you can use `save_for_backward` to store the required tensors, and then use them in the backward. + 3. Output of backward function can only be `Tensor` or tuple/list of `Tensor`. + Output tensors of backward are the gradient of forward's input tensors, + so the number of backward's output tensors equal to the number of forward input tensors. + After building the custom Layer, run it through the `apply` method. + + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + # Inherit from PyLayer + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x, func1, func2=paddle.square): + # ctx is a context object that store some objects for backward. + ctx.func = func2 + y = func1(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + # forward has only one output, so there is only one gradient in the input of backward. + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - ctx.func(y)) + # forward has only one input, so only one gradient tensor is returned. + return grad + + + data = paddle.randn([2, 3], dtype="float64") + data.stop_gradient = False + z = cus_tanh.apply(data, func1=paddle.tanh) + z.mean().backward() + + print(data.grad) + + """ + + @staticmethod + def forward(ctx, *args, **kwargs): + """ + It is to be overloaded by subclasses. It must accept a object of `PyLayerContext` as + the first argument, followed by any number of arguments (tensors or other types). + `None` can not be included in the returned result. + + Args: + *args(tuple): input of PyLayer. + **kwargs(dict): input of PyLayer. + + Returns: + tensors or other types : output of PyLayer. + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + # Pass tensors to backward. 
+ ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + raise NotImplementedError( + "You must implement the forward function for PyLayer.") + + @staticmethod + def backward(ctx, *args, **kwargs): + """ + This is a function to calculate the gradient. It is to be overloaded by subclasses. + It must accept a object of `PyLayerContext` as the first argument, and the rest + arguments are the gradient of forward's output tensors. Output tensors of backward + are the gradient of forward's input tensors. + + Args: + *args(tuple): The gradient of forward's output tensor(s). + **kwargs(dict): The gradient of forward's output tensor(s). + + Returns: + Tensor or list of Tensors: The gradient of forward's input tensor(s). + + Examples: + .. code-block:: python + + import paddle + from paddle.autograd import PyLayer + + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + y = paddle.tanh(x) + # Pass tensors to backward. + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx, dy): + # Get the tensors passed by forward. + y, = ctx.saved_tensor() + grad = dy * (1 - paddle.square(y)) + return grad + """ + + raise NotImplementedError( + "You must implement the backward function for PyLayer.") diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 6b276574516..604d50b8ed1 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -730,6 +730,7 @@ set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120) set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120) set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120) set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250) +set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120) if (WIN32) set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 900) set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250) diff --git a/python/paddle/fluid/tests/unittests/test_pylayer_op.py b/python/paddle/fluid/tests/unittests/test_pylayer_op.py new file mode 100644 index 00000000000..89f8330fe5b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_pylayer_op.py @@ -0,0 +1,303 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
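# The cases below exercise the new dygraph PyLayer API end to end: forward with
# single and multiple outputs, keyword-argument inputs, a range of dtypes, inputs
# that do not require gradients, invalid return values from forward/backward
# (None or non-Tensor), missing forward/backward implementations, and an
# in-place usage case.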
+ +from __future__ import print_function + +import unittest +import numpy as np + +import paddle +from paddle.autograd import PyLayer + + +class TestPyLayer(unittest.TestCase): + def test_simple_pylayer_multiple_output(self): + class tanh(PyLayer): + @staticmethod + def forward(ctx, x1, x2, func1, func2=paddle.square): + ctx.func = func2 + y1 = func1(x1) + y2 = func1(x2) + ctx.save_for_backward(y1, y2) + return y1, y2 + + @staticmethod + def backward(ctx, dy1, dy2): + y1, y2 = ctx.saved_tensor() + re1 = dy1 * (1 - ctx.func(y1)) + re2 = dy2 * (1 - paddle.square(y2)) + return re1, re2 + + input1 = paddle.randn([2, 3]).astype("float64") + input2 = input1.detach().clone() + input1.stop_gradient = False + input2.stop_gradient = False + z = tanh.apply(input1, input1, paddle.tanh, paddle.square) + z = z[0] + z[1] + z.mean().backward() + + z2 = paddle.tanh(input2) + paddle.tanh(input2) + z2.mean().backward() + + self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10) + + def test_simple_pylayer_single_output(self): + class tanh(PyLayer): + @staticmethod + def forward(ctx, x1, func1, func2=paddle.square): + ctx.func = func2 + y1 = func1(x1) + ctx.save_for_backward(y1) + return y1 + + @staticmethod + def backward(ctx, dy1): + y1, = ctx.saved_tensor() + re1 = dy1 * (1 - ctx.func(y1)) + return re1 + + input1 = paddle.randn([2, 3]).astype("float64") + input2 = input1.detach().clone() + input1.stop_gradient = False + input2.stop_gradient = False + z = tanh.apply(x1=input1, func1=paddle.tanh) + z.mean().backward() + z2 = paddle.tanh(input2) + z2.mean().backward() + + self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10) + + def test_pylayer_dtype(self): + class tanh(PyLayer): + @staticmethod + def forward(ctx, x, dtype): + y = paddle.cast(x, dtype) + return y + + @staticmethod + def backward(ctx, dy1): + return dy1 + + dtypes = [ + 'bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64' + ] + for dtype in dtypes: + input1 = (paddle.randn([2, 3])) + input1.stop_gradient = False + self.assertTrue(input1.grad is None) + + z = tanh.apply(input1, dtype) + z = paddle.cast(z, "float32") + z.sum().backward() + self.assertTrue(input1.grad is not None) + + def test_pylayer_Exception_forward(self): + class Layer_None1(PyLayer): + @staticmethod + def forward(ctx, *args): + return None + + @staticmethod + def backward(ctx, *args): + return args + + input1 = paddle.randn([2, 3]).astype("float64") + with self.assertRaises(NotImplementedError): + z = Layer_None1.apply(input1) + + class Layer_None2(PyLayer): + @staticmethod + def forward(ctx, *args): + return [None, None] + + @staticmethod + def backward(ctx, *args): + return args + + input1 = paddle.randn([2, 3]).astype("float64") + with self.assertRaises(NotImplementedError): + z = Layer_None2.apply(input1) + + class Layer_one1(PyLayer): + @staticmethod + def forward(ctx, *args): + return 1 + + @staticmethod + def backward(ctx, *args): + return args + + input1 = paddle.randn([2, 3]).astype("float64") + with self.assertRaises(NotImplementedError): + z = Layer_one1.apply(input1) + + class Layer_one2(PyLayer): + @staticmethod + def forward(ctx, *args): + return [1, 2] + + @staticmethod + def backward(ctx, *args): + return args + + input1 = paddle.randn([2, 3]).astype("float64") + with self.assertRaises(NotImplementedError): + z = Layer_one2.apply(input1) + + class Layer_no_fw(PyLayer): + @staticmethod + def backward(ctx, *args): + return args + + input1 = paddle.randn([2, 3]).astype("float64") + with 
self.assertRaises(NotImplementedError): + z = Layer_no_fw.apply(input1) + + def test_pylayer_nograd(self): + class tanh(PyLayer): + @staticmethod + def forward(ctx, x1, func1, func2=paddle.square, xx=None): + ctx.func = func2 + y1 = func1(x1) + return y1 + + @staticmethod + def backward(ctx, x1, y1, dy1): + re1 = dy1 * (1 - ctx.func(y1)) + return re1 + + input1 = paddle.randn([2, 3]).astype("float64") + z = tanh.apply(input1, paddle.tanh, paddle.square) + z.mean().backward() + self.assertTrue(z.grad is None) + + def test_pylayer_Exception_bk(self): + class Layer_bk_none1(PyLayer): + @staticmethod + def forward(ctx, x): + return x * 2 + + @staticmethod + def backward(ctx, dy1): + return None + + input2 = paddle.randn([2, 3]).astype("float64") + input2.stop_gradient = False + z = Layer_bk_none1.apply(input2) + + with self.assertRaises(NotImplementedError): + with paddle.fluid.dygraph.guard(): + z.sum().backward() + + class Layer_bk_none2(PyLayer): + @staticmethod + def forward(ctx, x1, x2): + return x1 + x2 + + @staticmethod + def backward(ctx, dy1): + return None, dy1 + + input1 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + z = Layer_bk_none2.apply(input1, input1) + with self.assertRaises(NotImplementedError): + with paddle.fluid.dygraph.guard(): + z.mean().backward() + + class Layer_bk_one1(PyLayer): + @staticmethod + def forward(ctx, x): + return x + x + + @staticmethod + def backward(ctx, dy): + return 1 + + input1 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + z = Layer_bk_one1.apply(input1) + with self.assertRaises(NotImplementedError): + with paddle.fluid.dygraph.guard(): + z.mean().backward() + + class Layer_bk_one2(PyLayer): + @staticmethod + def forward(ctx, x): + return x * 2, x * 5 + + @staticmethod + def backward(ctx, *args): + return 1, 1 + + input1 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + z = Layer_bk_one1.apply(input1) + with self.assertRaises(NotImplementedError): + with paddle.fluid.dygraph.guard(): + z.mean().backward() + + class Layer_no_bk(PyLayer): + @staticmethod + def forward(ctx, x): + return x * 2, x * 5 + + input1 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + z = Layer_no_bk.apply(input1) + + with self.assertRaises(NotImplementedError): + with paddle.fluid.dygraph.guard(): + z = z[0] + z[1] + z.mean().backward() + + class Layer_bk_match(PyLayer): + @staticmethod + def forward(ctx, x): + return x * 2, x * 5 + + @staticmethod + def backward(ctx, dy1, dy2): + return dy2 * 2, dy1 * 2 + + input1 = paddle.randn([2, 3]).astype("float64") + input1.stop_gradient = False + z = Layer_bk_match.apply(input1) + with self.assertRaises(ValueError): + with paddle.fluid.dygraph.guard(): + z = z[0] + z[1] + z.mean().backward() + + def test_pylayer_inplace(self): + class cus_tanh(PyLayer): + @staticmethod + def forward(ctx, x): + return x.mean() + + @staticmethod + def backward(ctx, dy): + return dy + + for i in range(2): + data = paddle.ones([2, 3], dtype="float64") / (i + 1) + data.stop_gradient = False + data = paddle.nn.functional.relu(data) + z = paddle.tanh(data) + z = cus_tanh.apply(data) + z.backward() + self.assertTrue(data.grad is not None) + + +if __name__ == '__main__': + unittest.main() -- GitLab
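For reference, a minimal end-to-end sketch of the PyLayer API introduced by this patch, assembled from the docstrings and unit tests above. The class name CusTanh and the numeric tolerance are illustrative choices, not part of the patch.

import numpy as np
import paddle
from paddle.autograd import PyLayer


class CusTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # Save the forward output so backward can reuse it.
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        y, = ctx.saved_tensor()
        # d/dx tanh(x) = 1 - tanh(x)^2
        return dy * (1 - paddle.square(y))


x1 = paddle.randn([2, 3]).astype("float64")
x2 = x1.detach().clone()
x1.stop_gradient = False
x2.stop_gradient = False

y = CusTanh.apply(x1)
y.mean().backward()

# Reference: the built-in tanh should produce the same gradient.
paddle.tanh(x2).mean().backward()
assert np.max(np.abs(x1.grad - x2.grad)) < 1e-10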