未验证 提交 29f65225 编写于 作者: W WeiXin 提交者: GitHub

Customizable Python Layer in Dygraph (#32130)

* custom python backward

* polish up the code

* polish up the code

* polish up the code.

* Fix code format and comments.

* Delete redundant files.

* add unnittest.

* edit unnittest.

* edit unnittest.

* Remove redundant header files.

* Improve coverage and remove redundant code.

* support saving for backward.

* polish code according to comments.

* Add support type for PyLayer.

* Modify the DOC.

* polish Doc.

* polish Doc.

* polish Doc.

* polish Doc.

* polish Doc.

* polish Doc.

* polish code and make the code robust.

* Modify the code format.
上级 0c037d2d
......@@ -419,6 +419,7 @@ class ExecutionContext {
const RuntimeContext Context() const { return ctx_; }
std::string DebugString() const { return op_.DebugString(); }
const OperatorBase& GetOp() const { return op_; }
private:
const OperatorBase& op_;
......
......@@ -279,6 +279,8 @@ class TracedGradOp {
void SetType(const std::string& type) { op_->SetType(type); }
const framework::OperatorBase& InnerOp() const { return op_->InnerOp(); }
void SetAttrMap(const framework::AttributeMap& attrs) {
return op_->SetAttrMap(attrs);
}
......
......@@ -406,7 +406,7 @@ void OpBase::Run(const framework::OperatorBase& op,
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, place);
}
static void ClearNoNeedBufferInputs(OpBase* op) {
void ClearNoNeedBufferInputs(OpBase* op) {
auto& inferer = op->Info().NoNeedBufferVarsInferer();
if (!inferer) return;
auto* ins = op->GetMutableInsMap();
......
......@@ -286,5 +286,7 @@ std::shared_ptr<GradOpNode> CreateGradOpNode(
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);
void ClearNoNeedBufferInputs(OpBase* op);
} // namespace imperative
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/operators/py_layer_op.h"
namespace paddle {
namespace imperative {
namespace py = ::pybind11;
bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
for (const auto& name_pair : ins) {
for (const auto& var_base : name_pair.second) {
if (!var_base->OverridedStopGradient()) {
PassStopGradient(outs, var_base->OverridedStopGradient());
return true;
}
}
}
return false;
}
std::shared_ptr<GradOpNode> CreateGradOpNode(
const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map,
const std::shared_ptr<operators::PyLayerContext>& py_context) {
operators::PyLayerGradOpMaker<paddle::imperative::OpBase> maker(
type, ins, outs, attrs, inplace_map);
maker.SetPyLayerContext(py_context);
auto grad_node = maker();
if (grad_node && !grad_node->empty()) {
for (auto& grad_op : *grad_node) {
grad_op.SetId(OpBase::GenerateUniqueId());
grad_op.SetPlace(place);
ClearNoNeedBufferInputs(&grad_op);
}
return grad_node;
} else {
return nullptr;
}
}
py::object PyLayerApply(const platform::Place& place, const py::object& cls,
const py::args args, const py::kwargs kwargs) {
auto bk_function = cls.attr("_backward_function");
auto context = bk_function();
auto forward = cls.attr("forward");
auto result_forward = forward(context, *args, **kwargs);
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.release().ptr());
// make inputs to varbase
std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
// process args,`input_vars` only collect `imperative::VarBase`
if (!args.empty()) {
for (auto ptr = args.begin(); ptr != args.end(); ptr++) {
try {
if (Py_None != ptr->ptr()) {
auto a = ptr->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
}
} catch (py::cast_error& err) {
// Only collect Tensor type in 'args' and pass them to backward. Ignore
// other types of input temporarily.
}
}
}
// process kwargs, only collect `imperative::VarBase`
if (!kwargs.empty()) {
for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) {
try {
if (Py_None != ptr->second.ptr()) {
auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
}
} catch (py::cast_error&) {
// Only collect Tensor type in 'kwargs' and pass them to backward.
// Ignore other types of input temporarily.
}
}
}
NameVarBaseMap ins = {{"X", input_vars}};
std::vector<std::shared_ptr<imperative::VarBase>> output_vars;
if (PyTuple_Check(result_forward.ptr()) ||
PyList_Check(result_forward.ptr())) {
auto tuple_result = result_forward.cast<py::tuple>();
for (size_t i = 0; i < tuple_result.size(); i++) {
if (Py_None != tuple_result[i].ptr()) {
try {
auto temp_out =
tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.forward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.forward` can not be `None`."));
}
}
} else {
if (Py_None != result_forward.ptr()) {
try {
auto temp_out =
result_forward.cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.forward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.forward` can not be `None`."));
}
}
NameVarBaseMap outs = {{"Out", output_vars}};
if (RequiredGrad(ins, outs)) {
std::map<std::string, std::string> inplace_map{};
bool if_inplace = false;
for (auto temp_ins : input_vars) {
if (if_inplace) {
break;
}
for (auto temp_outs : output_vars) {
if (temp_ins->Name() == temp_outs->Name()) {
if_inplace = true;
break;
}
}
}
if (if_inplace) {
inplace_map["X"] = "Out";
}
CreateGradOpNode("py_layer", ins, outs, {{}}, place, inplace_map,
py_layer_ctx);
} else {
VLOG(3) << "No Grad to track for Op: py_layer_op";
}
return result_forward;
}
} // namespace imperative
} // namespace paddle
......@@ -38,7 +38,7 @@ void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
VLOG(6) << "Set current tracer: " << g_current_tracer;
}
static void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
for (const auto& pair : outs) {
for (const auto& var : pair.second) {
// NOTE(zhiqiu): this happends when None output are passed from python
......
......@@ -130,5 +130,7 @@ void IncreaseVarbaseReferenceCountUntilCopyComplete(
const std::shared_ptr<imperative::VarBase>& var,
const platform::Place& place);
void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad);
} // namespace imperative
} // namespace paddle
......@@ -38,6 +38,9 @@ class VariableWrapper {
explicit VariableWrapper(const std::string& name) : name_(name) {}
VariableWrapper(const std::string& name, const framework::Variable& variable)
: var_(variable), name_(name) {}
~VariableWrapper() { VLOG(10) << "Destruct VariableWrapper: " << Name(); }
const framework::Variable& Var() const { return var_; }
......
......@@ -69,7 +69,7 @@ if(WITH_UNITY_BUILD)
include(unity_build_rule.cmake)
endif()
register_operators(EXCLUDES py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache ${OP_HEADER_DEPS})
......@@ -162,6 +162,7 @@ endif()
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
cc_library(py_layer_op SRCS py_layer_op.cc DEPS op_registry python pybind)
endif()
if (WITH_ASCEND_CL)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/operators/py_layer_op.h"
namespace paddle {
namespace operators {
namespace py = ::pybind11;
void RunPyObject(py::object *py_object,
const std::vector<framework::Variable *> &ins,
std::vector<framework::Variable *> *outs) {
py::gil_scoped_acquire guard;
auto py_function = py_object->attr("backward");
py::tuple inputs(ins.size());
for (size_t i = 0; i < ins.size(); i++) {
auto in_var = ins[i];
if (in_var != nullptr) {
auto name = paddle::string::Sprintf("generator_custom_py_layer_%d@GRAD",
static_cast<int>(i));
std::shared_ptr<imperative::VariableWrapper> temp_wrap =
std::make_shared<imperative::VariableWrapper>(name, *in_var);
temp_wrap->InnerSetOverridedStopGradient(true);
std::shared_ptr<imperative::VarBase> temp_varbase =
std::make_shared<imperative::VarBase>(temp_wrap);
try {
inputs[i] = py::cast(temp_varbase).ptr();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` should be `Tensor`."));
}
}
}
auto py_result = py_function(*py_object, *inputs);
if (PyTuple_Check(py_result.ptr()) || PyList_Check(py_result.ptr())) {
auto result_tuple = py_result.cast<py::tuple>();
if (result_tuple.size() != outs->size()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The number of outputs of `PyLayer.backward` should be %d, but "
"received %d.",
outs->size(), result_tuple.size()));
}
for (size_t i = 0; i < result_tuple.size(); i++) {
if (Py_None != result_tuple[i].ptr()) {
try {
auto result_var =
result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
*(*outs)[i] = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` can not be `None`."));
}
}
} else {
if (Py_None != py_result.ptr()) {
try {
auto result_var =
py_result.cast<std::shared_ptr<imperative::VarBase>>();
*((*outs)[0]) = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` should be `Tensor`."));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` can not be `None`."));
}
}
}
void PyLayerGradOpMaker<paddle::imperative::OpBase>::Apply(
GradOpPtr<paddle::imperative::OpBase> grad_op) const {
grad_op->SetType("py_layer");
auto &inner_op = grad_op->InnerOp();
auto py_layer_op_const = dynamic_cast<const PyLayerOp *>(&inner_op);
if (py_layer_op_const) {
auto py_layer_op = const_cast<PyLayerOp *>(py_layer_op_const);
py_layer_op->SetPyLayerContext(py_context_);
} else {
PADDLE_THROW(platform::errors::Fatal(
"PyLayerGradOpMaker can't cast %s to PyLayerOp*.",
typeid(&inner_op).name()));
}
auto fwd_out_grads = this->OutputGrad("Out");
using return_type = decltype(fwd_out_grads);
return_type bwd_ins;
bwd_ins.insert(bwd_ins.begin(), fwd_out_grads.begin(), fwd_out_grads.end());
auto bwd_outs = this->InputGrad("X", false);
grad_op->SetInput("X", bwd_ins);
grad_op->SetOutput("Out", bwd_outs);
}
class PyLayerOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Inputs of PyLayer op.").AsDuplicable();
AddOutput("Out", "Outputs of PyLayer op").AsDuplicable();
AddComment(R"DOC("PyLayer Op")DOC");
}
};
template <typename DeviceContext, typename T>
class PyLayerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &op_ = ctx.GetOp();
auto pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
if (pylayer_op) {
auto py_layer_context = pylayer_op->GetPyLayerContext();
py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true);
auto &input_vars = ctx.MultiInputVar("X");
auto output_vars = ctx.MultiOutputVar("Out");
RunPyObject(&bk_ctx, input_vars, &output_vars);
} else {
PADDLE_THROW(platform::errors::Fatal(
"PyLayerOpKernel can't cast %s to PyLayer*.", typeid(&op_).name()));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_layer, ops::PyLayerOp, ops::PyLayerOpMaker,
ops::PyLayerGradOpMaker<paddle::imperative::OpBase>,
ops::PyLayerGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OP_CPU_KERNEL(
py_layer, ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::float16>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::bfloat16>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, bool>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex64>,
ops::PyLayerOpKernel<paddle::platform::CPUDeviceContext,
::paddle::platform::complex128>);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL(
py_layer, ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::float16>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::bfloat16>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, bool>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, uint8_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int16_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext, int8_t>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex64>,
ops::PyLayerOpKernel<paddle::platform::CUDADeviceContext,
::paddle::platform::complex128>);
#endif // PADDLE_WITH_CUDA
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/python_headers.h"
namespace paddle {
namespace operators {
namespace py = ::pybind11;
class PyLayerContext {
public:
explicit PyLayerContext(PyObject* context) : context_(context) {
Py_INCREF(context_);
}
PyLayerContext() = delete;
PyObject* GetMutableCtx() { return context_; }
private:
PyObject* context_;
};
class PyLayerOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
VLOG(3) << "`InferShape` of `PyLayer` is an empty function, and it cannot "
"infer the shape of the output tensors.";
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
public:
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context;
}
const std::shared_ptr<PyLayerContext>& GetPyLayerContext() const {
return py_context_;
}
private:
std::shared_ptr<PyLayerContext> py_context_;
};
template <typename T>
class PyLayerGradOpMaker {};
template <>
class PyLayerGradOpMaker<paddle::framework::OpDesc>
: public framework::SingleGradOpMaker<paddle::framework::OpDesc> {
public:
using framework::SingleGradOpMaker<
paddle::framework::OpDesc>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<paddle::framework::OpDesc> grad_op) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"`PyLayer` don't support static graph mode."));
}
};
template <>
class PyLayerGradOpMaker<paddle::imperative::OpBase>
: public framework::SingleGradOpMaker<paddle::imperative::OpBase> {
public:
using framework::SingleGradOpMaker<
paddle::imperative::OpBase>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<paddle::imperative::OpBase> grad_op) const override;
public:
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context;
}
private:
std::shared_ptr<PyLayerContext> py_context_;
};
} // namespace operators
} // namespace paddle
......@@ -36,6 +36,7 @@ endif(NOT WIN32)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
list(APPEND PYBIND_DEPS py_layer_op)
endif()
set(PYBIND_SRCS
......
......@@ -39,6 +39,7 @@ limitations under the License. */
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/partial_grad_engine.h"
#include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/py_layer_fwd.h"
#include "paddle/fluid/imperative/reducer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
......@@ -1339,22 +1340,28 @@ void BindImperative(py::module *m_ptr) {
&imperative::VarBase::SetOverridedStopGradient)
.def_property("persistable", &imperative::VarBase::Persistable,
&imperative::VarBase::SetPersistable)
.def_property_readonly(
"shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var().Get<framework::LoDTensor>().dims());
} else if (self.Var().IsType<framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var().Get<framework::SelectedRows>().value().dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("shape",
[](imperative::VarBase &self) {
if (self.Var().IsType<framework::LoDTensor>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::LoDTensor>()
.dims());
} else if (self.Var()
.IsType<
framework::SelectedRows>()) {
return framework::vectorize<int>(
self.Var()
.Get<framework::SelectedRows>()
.value()
.dims());
} else {
VLOG(2) << "It is meaningless to get shape of "
"variable type "
<< GetTypeName(self);
return std::vector<int>();
}
})
.def_property_readonly("is_leaf", &imperative::VarBase::IsLeaf,
R"DOC(
Whether a Tensor is leaf Tensor.
......@@ -1643,6 +1650,29 @@ void BindImperative(py::module *m_ptr) {
&imperative::BKCLParallelContext::InitWithRingID,
py::arg("ring_id"));
#endif
m.def("pylayer_apply",
[](const platform::CPUPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CUDAPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::XPUPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CUDAPinnedPlace &place, const py::object &cls,
const py::args args, const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
}
} // namespace pybind
......
......@@ -16,7 +16,6 @@ from ..fluid.dygraph.base import grad #DEFINE_ALIAS
from . import backward_mode
from .backward_mode import backward
from .py_layer import PyLayer, PyLayerContext
__all__ = ['grad']
__all__ += backward_mode.__all__
__all__ = ['grad', 'backward', 'PyLayer', 'PyLayerContext']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid.framework import dygraph_only
from paddle.fluid import core
__all__ = ['PyLayer', 'PyLayerContext']
class PyLayerContext(object):
"""
The object of this class is a context that is used in PyLayer to enhance the function.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a object of PyLayerContext.
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# ctx is a object of PyLayerContext.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
def __init__(self):
self.container = None
def save_for_backward(self, *tensors):
"""
Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
.. note::
This API should be called at most once, and only inside `forward`.
Args:
tensors(list of Tensors): Tensors to be stored.
Returns:
None
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a context object that store some objects for backward.
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
self.container = tensors
def saved_tensor(self):
"""
Get the tensors stored by ``save_for_backward``.
Returns:
list of Tensors or None: If context contains tensors stored by `save_for_backward`,
then return these tensors, otherwise return None.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a context object that store some objects for backward.
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
return self.container
def with_mateclass(meta, *bases):
class impl(meta):
def __new__(cls, name, temp_bases, attrs):
return meta(name, bases, attrs)
return type.__new__(impl, "impl", (), {})
class CPyLayer(object):
@classmethod
@dygraph_only
def apply(cls, *args, **kwargs):
"""
After building the custom PyLayer, run it through the ``apply``.
Args:
*args(tuple): input of PyLayer.
**kwargs(dict): input of PyLayer.
Returns:
tensors or other types : output of PyLayer.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x, func1, func2=paddle.square):
ctx.func = func2
y = func1(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - ctx.func(y))
return grad
data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
# run custom Layer.
z = cus_tanh.apply(data, func1=paddle.tanh)
"""
place = paddle.fluid.framework._current_expected_place()
with paddle.fluid.dygraph.no_grad():
return core.pylayer_apply(place, cls, *args, **kwargs)
class PyLayerBackward(PyLayerContext):
def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.no_grad():
return self._forward_cls.backward(*args, **kwargs)
class LayerMeta(type):
def __init__(cls, name, bases, attrs):
cls._backward_function = type(name + '_backward', (PyLayerBackward, ),
{"_forward_cls": cls})
return super(LayerMeta, cls).__init__(name, bases, attrs)
class PyLayer(with_mateclass(LayerMeta, CPyLayer)):
"""
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules:
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod.
Their first argument should be a context and `None` can not be included in the returned result.
2. Input of backward contains a context as the first argument, and the rest arguments are the
gradient of forward's output tensors. so the number of backward's input tensors equal to
the number of forward output tensors. If you need the forward's inputs or outputs in `backward`,
you can use `save_for_backward` to store the required tensors, and then use them in the backward.
3. Output of backward function can only be `Tensor` or tuple/list of `Tensor`.
Output tensors of backward are the gradient of forward's input tensors,
so the number of backward's output tensors equal to the number of forward input tensors.
After building the custom Layer, run it through the `apply` method.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
# Inherit from PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x, func1, func2=paddle.square):
# ctx is a context object that store some objects for backward.
ctx.func = func2
y = func1(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
# forward has only one output, so there is only one gradient in the input of backward.
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - ctx.func(y))
# forward has only one input, so only one gradient tensor is returned.
return grad
data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
z = cus_tanh.apply(data, func1=paddle.tanh)
z.mean().backward()
print(data.grad)
"""
@staticmethod
def forward(ctx, *args, **kwargs):
"""
It is to be overloaded by subclasses. It must accept a object of `PyLayerContext` as
the first argument, followed by any number of arguments (tensors or other types).
`None` can not be included in the returned result.
Args:
*args(tuple): input of PyLayer.
**kwargs(dict): input of PyLayer.
Returns:
tensors or other types : output of PyLayer.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
raise NotImplementedError(
"You must implement the forward function for PyLayer.")
@staticmethod
def backward(ctx, *args, **kwargs):
"""
This is a function to calculate the gradient. It is to be overloaded by subclasses.
It must accept a object of `PyLayerContext` as the first argument, and the rest
arguments are the gradient of forward's output tensors. Output tensors of backward
are the gradient of forward's input tensors.
Args:
*args(tuple): The gradient of forward's output tensor(s).
**kwargs(dict): The gradient of forward's output tensor(s).
Returns:
Tensor or list of Tensors: The gradient of forward's input tensor(s).
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
raise NotImplementedError(
"You must implement the backward function for PyLayer.")
......@@ -730,6 +730,7 @@ set_tests_properties(test_trilinear_interp_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_bicubic_interp_v2_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_gather_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_static_save_load PROPERTIES TIMEOUT 250)
set_tests_properties(test_pylayer_op PROPERTIES TIMEOUT 120)
if (WIN32)
set_tests_properties(test_static_save_load_large PROPERTIES TIMEOUT 900)
set_tests_properties(test_paddle_save_load PROPERTIES TIMEOUT 250)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.autograd import PyLayer
class TestPyLayer(unittest.TestCase):
def test_simple_pylayer_multiple_output(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
ctx.func = func2
y1 = func1(x1)
y2 = func1(x2)
ctx.save_for_backward(y1, y2)
return y1, y2
@staticmethod
def backward(ctx, dy1, dy2):
y1, y2 = ctx.saved_tensor()
re1 = dy1 * (1 - ctx.func(y1))
re2 = dy2 * (1 - paddle.square(y2))
return re1, re2
input1 = paddle.randn([2, 3]).astype("float64")
input2 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False
z = tanh.apply(input1, input1, paddle.tanh, paddle.square)
z = z[0] + z[1]
z.mean().backward()
z2 = paddle.tanh(input2) + paddle.tanh(input2)
z2.mean().backward()
self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10)
def test_simple_pylayer_single_output(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square):
ctx.func = func2
y1 = func1(x1)
ctx.save_for_backward(y1)
return y1
@staticmethod
def backward(ctx, dy1):
y1, = ctx.saved_tensor()
re1 = dy1 * (1 - ctx.func(y1))
return re1
input1 = paddle.randn([2, 3]).astype("float64")
input2 = input1.detach().clone()
input1.stop_gradient = False
input2.stop_gradient = False
z = tanh.apply(x1=input1, func1=paddle.tanh)
z.mean().backward()
z2 = paddle.tanh(input2)
z2.mean().backward()
self.assertTrue(np.max(np.abs((input1.grad - input2.grad))) < 1e-10)
def test_pylayer_dtype(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x, dtype):
y = paddle.cast(x, dtype)
return y
@staticmethod
def backward(ctx, dy1):
return dy1
dtypes = [
'bool', 'float16', 'float32', 'float64', 'uint8', 'int32', 'int64'
]
for dtype in dtypes:
input1 = (paddle.randn([2, 3]))
input1.stop_gradient = False
self.assertTrue(input1.grad is None)
z = tanh.apply(input1, dtype)
z = paddle.cast(z, "float32")
z.sum().backward()
self.assertTrue(input1.grad is not None)
def test_pylayer_Exception_forward(self):
class Layer_None1(PyLayer):
@staticmethod
def forward(ctx, *args):
return None
@staticmethod
def backward(ctx, *args):
return args
input1 = paddle.randn([2, 3]).astype("float64")
with self.assertRaises(NotImplementedError):
z = Layer_None1.apply(input1)
class Layer_None2(PyLayer):
@staticmethod
def forward(ctx, *args):
return [None, None]
@staticmethod
def backward(ctx, *args):
return args
input1 = paddle.randn([2, 3]).astype("float64")
with self.assertRaises(NotImplementedError):
z = Layer_None2.apply(input1)
class Layer_one1(PyLayer):
@staticmethod
def forward(ctx, *args):
return 1
@staticmethod
def backward(ctx, *args):
return args
input1 = paddle.randn([2, 3]).astype("float64")
with self.assertRaises(NotImplementedError):
z = Layer_one1.apply(input1)
class Layer_one2(PyLayer):
@staticmethod
def forward(ctx, *args):
return [1, 2]
@staticmethod
def backward(ctx, *args):
return args
input1 = paddle.randn([2, 3]).astype("float64")
with self.assertRaises(NotImplementedError):
z = Layer_one2.apply(input1)
class Layer_no_fw(PyLayer):
@staticmethod
def backward(ctx, *args):
return args
input1 = paddle.randn([2, 3]).astype("float64")
with self.assertRaises(NotImplementedError):
z = Layer_no_fw.apply(input1)
def test_pylayer_nograd(self):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square, xx=None):
ctx.func = func2
y1 = func1(x1)
return y1
@staticmethod
def backward(ctx, x1, y1, dy1):
re1 = dy1 * (1 - ctx.func(y1))
return re1
input1 = paddle.randn([2, 3]).astype("float64")
z = tanh.apply(input1, paddle.tanh, paddle.square)
z.mean().backward()
self.assertTrue(z.grad is None)
def test_pylayer_Exception_bk(self):
class Layer_bk_none1(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2
@staticmethod
def backward(ctx, dy1):
return None
input2 = paddle.randn([2, 3]).astype("float64")
input2.stop_gradient = False
z = Layer_bk_none1.apply(input2)
with self.assertRaises(NotImplementedError):
with paddle.fluid.dygraph.guard():
z.sum().backward()
class Layer_bk_none2(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + x2
@staticmethod
def backward(ctx, dy1):
return None, dy1
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_none2.apply(input1, input1)
with self.assertRaises(NotImplementedError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
class Layer_bk_one1(PyLayer):
@staticmethod
def forward(ctx, x):
return x + x
@staticmethod
def backward(ctx, dy):
return 1
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_one1.apply(input1)
with self.assertRaises(NotImplementedError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
class Layer_bk_one2(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
@staticmethod
def backward(ctx, *args):
return 1, 1
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_one1.apply(input1)
with self.assertRaises(NotImplementedError):
with paddle.fluid.dygraph.guard():
z.mean().backward()
class Layer_no_bk(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_no_bk.apply(input1)
with self.assertRaises(NotImplementedError):
with paddle.fluid.dygraph.guard():
z = z[0] + z[1]
z.mean().backward()
class Layer_bk_match(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
@staticmethod
def backward(ctx, dy1, dy2):
return dy2 * 2, dy1 * 2
input1 = paddle.randn([2, 3]).astype("float64")
input1.stop_gradient = False
z = Layer_bk_match.apply(input1)
with self.assertRaises(ValueError):
with paddle.fluid.dygraph.guard():
z = z[0] + z[1]
z.mean().backward()
def test_pylayer_inplace(self):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x.mean()
@staticmethod
def backward(ctx, dy):
return dy
for i in range(2):
data = paddle.ones([2, 3], dtype="float64") / (i + 1)
data.stop_gradient = False
data = paddle.nn.functional.relu(data)
z = paddle.tanh(data)
z = cus_tanh.apply(data)
z.backward()
self.assertTrue(data.grad is not None)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册