未验证 提交 0b60b784 编写于 作者: W wanghuancoder 提交者: GitHub

delete old dygraph pylayer (#49339)

* delete old dygraph pylayer
上级 941811b2
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/prepared_operator.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/operators/py_layer_op.h"
namespace paddle {
namespace imperative {
namespace py = ::pybind11;
bool RequiredGrad(const NameVarBaseMap& ins, const NameVarBaseMap& outs) {
for (const auto& name_pair : ins) {
for (const auto& var_base : name_pair.second) {
if (!var_base->OverridedStopGradient()) {
for (const auto& pair : outs) {
for (const auto& var : pair.second) {
if (var) {
var->SetOverridedStopGradient(false);
SetForwardDataTypeOfGradVar(var);
VLOG(3) << "Set output: " << var->Name()
<< "'s OverridedStopGradient as "
<< var->OverridedStopGradient();
}
}
}
return true;
}
}
}
return false;
}
std::shared_ptr<GradOpNode> CreateGradOpNode(
const std::string& type,
const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::AttributeMap& attrs,
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map,
const std::shared_ptr<operators::PyLayerContext>& py_context) {
operators::PyLayerGradOpMaker<paddle::imperative::OpBase> maker(
type, ins, outs, attrs, inplace_map);
maker.SetPyLayerContext(py_context);
auto grad_node = maker();
if (grad_node && !grad_node->empty()) {
for (auto& grad_op : *grad_node) {
grad_op.SetId(OpBase::GenerateUniqueId());
grad_op.SetPlace(place);
ClearNoNeedBufferInputs(&grad_op);
}
return grad_node;
} else {
return nullptr;
}
}
py::object PyLayerApply(const platform::Place& place,
const py::handle& cls,
const py::args args,
const py::kwargs kwargs) {
py::gil_scoped_acquire guard;
auto bk_function = cls.attr("_backward_function");
auto context = bk_function();
auto forward = cls.attr("forward");
// make inputs to varbase
std::vector<std::shared_ptr<imperative::VarBase>> input_vars;
// process args,`input_vars` only collect `imperative::VarBase`
if (!args.empty()) {
for (auto ptr = args.begin(); ptr != args.end(); ptr++) {
// Only collect Tensor type in 'args' and pass them to backward. Ignore
// other types of input temporarily.
if (py::isinstance<imperative::VarBase>(*ptr)) {
try {
auto a = ptr->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
} else if (py::isinstance<py::tuple>(*ptr) ||
py::isinstance<py::list>(*ptr)) {
try {
auto tuple_arg = ptr->cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->ptr()->ob_type->tp_name));
}
}
}
}
// process kwargs, only collect `imperative::VarBase`
if (!kwargs.empty()) {
for (auto ptr = kwargs.begin(); ptr != kwargs.end(); ptr++) {
// Only collect Tensor type in 'kwargs' and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(*ptr->second)) {
try {
auto a = ptr->second.cast<std::shared_ptr<VarBase>>();
input_vars.push_back(a);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
} else if (py::isinstance<py::tuple>(*ptr->second) ||
py::isinstance<py::list>(*ptr->second)) {
try {
auto tuple_arg = ptr->second.cast<py::tuple>();
for (auto iter = tuple_arg.begin(); iter != tuple_arg.end(); ++iter) {
try {
auto t = iter->cast<std::shared_ptr<VarBase>>();
input_vars.push_back(t);
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, "
"the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
}
} catch (py::cast_error& err) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function contains invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
ptr->second.ptr()->ob_type->tp_name));
}
}
}
}
std::shared_ptr<operators::PyLayerContext> py_layer_ctx =
std::make_shared<operators::PyLayerContext>(context.ptr());
auto result_forward = forward(context, *args, **kwargs);
NameVarBaseMap ins = {{"X", input_vars}};
std::vector<std::shared_ptr<imperative::VarBase>> output_vars;
if (PyTuple_Check(result_forward.ptr()) ||
PyList_Check(result_forward.ptr())) {
auto tuple_result = result_forward.cast<py::tuple>();
for (size_t i = 0; i < tuple_result.size(); i++) {
// Only collect Tensor type of output and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(tuple_result[i])) {
try {
auto temp_out =
tuple_result[i].cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function returns invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
tuple_result[i].ptr()->ob_type->tp_name));
}
}
}
} else {
// Only collect Tensor type of output and pass them to backward.
// Ignore other types of input temporarily.
if (py::isinstance<imperative::VarBase>(result_forward)) {
try {
auto temp_out =
result_forward.cast<std::shared_ptr<imperative::VarBase>>();
output_vars.push_back(temp_out);
} catch (py::cast_error&) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.forward` function returns invalid argument, the `%s` "
"type argument can not be cast into `Tensor`.",
result_forward.ptr()->ob_type->tp_name));
}
}
}
if (output_vars.size() == 0) {
PADDLE_THROW(platform::errors::InvalidArgument(
"At least one output of `PyLayer.forward` is a `Tensor`."));
}
NameVarBaseMap outs = {{"Out", output_vars}};
if (RequiredGrad(ins, outs)) {
std::map<std::string, std::string> inplace_map{};
bool if_inplace = false;
for (auto temp_ins : input_vars) {
if (if_inplace) {
break;
}
for (auto temp_outs : output_vars) {
if (temp_ins->Name() == temp_outs->Name()) {
if_inplace = true;
break;
}
}
}
if (if_inplace) {
// when pylayer forward is inplace strategy, check whether tensor is leaf
for (auto& t : input_vars) {
PADDLE_ENFORCE_EQ(t->IsLeaf() && !t->OverridedStopGradient(),
false,
platform::errors::InvalidArgument(
"Leaf Var (%s) that doesn't stop gradient can't "
"use inplace strategy.",
t->Name()));
}
inplace_map["X"] = "Out";
}
CreateGradOpNode(
"py_layer", ins, outs, {{}}, place, inplace_map, py_layer_ctx);
} else {
VLOG(3) << "No Grad to track for Op: py_layer_op";
}
return result_forward;
}
} // namespace imperative
} // namespace paddle
......@@ -97,7 +97,7 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_api_utils backward_infermeta sparse_backward_infermeta)
register_operators(EXCLUDES py_layer_op py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
register_operators(EXCLUDES py_func_op warpctc_op dgc_op load_combine_op lstm_op run_program_op eye_op quantize_linear_op
recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
......@@ -206,7 +206,6 @@ cc_test(share_buffer_op_cpp_test SRCS share_buffer_op_test.cc DEPS lod_tensor de
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
cc_library(py_layer_op SRCS py_layer_op.cc DEPS op_registry python pybind)
endif()
if (WITH_ASCEND_CL)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/py_layer_op.h"
#include <vector>
namespace paddle {
namespace operators {
namespace py = ::pybind11;
void RunPyObject(py::object *py_object,
const std::vector<framework::Variable *> &ins,
std::vector<framework::Variable *> *outs) {
py::gil_scoped_acquire guard;
auto py_function = py_object->attr("backward");
py::tuple inputs(ins.size());
for (size_t i = 0; i < ins.size(); i++) {
auto in_var = ins[i];
if (in_var != nullptr) {
auto name = paddle::string::Sprintf("generator_custom_py_layer_%d@GRAD",
static_cast<int>(i));
std::shared_ptr<imperative::VariableWrapper> temp_wrap =
std::make_shared<imperative::VariableWrapper>(name, *in_var);
temp_wrap->InnerSetOverridedStopGradient(true);
std::shared_ptr<imperative::VarBase> temp_varbase =
std::make_shared<imperative::VarBase>(temp_wrap);
try {
inputs[i] = py::cast(temp_varbase).ptr();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::Unimplemented(
"The output of `PyLayer.backward` should be `Tensor`."));
}
}
}
auto py_result = py_function(*py_object, *inputs);
if (PyTuple_Check(py_result.ptr()) || PyList_Check(py_result.ptr())) {
auto result_tuple = py_result.cast<py::tuple>();
if (result_tuple.size() != outs->size()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The number of outputs of `PyLayer.backward` should be %d, but "
"received %d.",
outs->size(),
result_tuple.size()));
}
for (size_t i = 0; i < result_tuple.size(); i++) {
if ((*outs)[i] != nullptr) {
if (Py_None != result_tuple[i].ptr()) {
if (py::isinstance<imperative::VarBase>(result_tuple[i])) {
try {
auto result_var =
result_tuple[i].cast<std::shared_ptr<imperative::VarBase>>();
*(*outs)[i] = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.backward` function returns invalid argument, "
"the `%s` type argument can not be cast into `Tensor`.",
result_tuple[i].ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`, but "
"received `%s`.",
result_tuple[i].ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth input tensor of forward needs gradient and the "
"corresponding gradient cannot be None.",
i));
}
} else {
if (Py_None != result_tuple[i].ptr()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The %dth input tensor of forward do not need gradient and the "
"corresponding gradient should be `None`.",
i));
}
}
}
} else {
if (1 != outs->size()) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The number of outputs of `PyLayer.backward` should be %d, but "
"received 1.",
outs->size()));
}
if ((*outs)[0] != nullptr) {
if (Py_None != py_result.ptr()) {
if (py::isinstance<imperative::VarBase>(py_result)) {
try {
auto result_var =
py_result.cast<std::shared_ptr<imperative::VarBase>>();
*((*outs)[0]) = result_var->Var();
} catch (py::cast_error &) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The `PyLayer.backward` function returns invalid argument, the "
"`%s` type argument can not be cast into `Tensor`.",
py_result.ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The output of `PyLayer.backward` should be `Tensor`, but "
"received `%s`",
py_result.ptr()->ob_type->tp_name));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor of forward needs gradient, so the output of "
"`PyLayer.backward` can not be `None`."));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The input tensor of forward do not need gradient, so the output of "
"`PyLayer.backward` should be `None`."));
}
}
}
void PyLayerGradOpMaker<paddle::imperative::OpBase>::Apply(
GradOpPtr<paddle::imperative::OpBase> grad_op) const {
grad_op->SetType("py_layer");
auto &inner_op = grad_op->InnerOp();
auto py_layer_op_const = dynamic_cast<const PyLayerOp *>(&inner_op);
if (py_layer_op_const) {
auto py_layer_op = const_cast<PyLayerOp *>(py_layer_op_const);
py_layer_op->SetPyLayerContext(py_context_);
} else {
PADDLE_THROW(platform::errors::Fatal(
"PyLayerGradOpMaker can't cast %s to PyLayerOp*.",
typeid(&inner_op).name()));
}
auto fwd_out_grads = this->OutputGrad("Out");
using return_type = decltype(fwd_out_grads);
return_type bwd_ins;
bwd_ins.insert(bwd_ins.begin(), fwd_out_grads.begin(), fwd_out_grads.end());
auto bwd_outs = this->InputGrad("X", false);
grad_op->SetInput("X", bwd_ins);
grad_op->SetOutput("Out", bwd_outs);
}
class PyLayerOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Inputs of PyLayer op.").AsDuplicable();
AddOutput("Out", "Outputs of PyLayer op").AsDuplicable();
AddComment(R"DOC("PyLayer Op")DOC");
}
};
template <typename DeviceContext, typename T>
class PyLayerOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto &op_ = ctx.GetOp();
auto const_pylayer_op = dynamic_cast<const PyLayerOp *>(&op_);
if (const_pylayer_op) {
auto pylayer_op = const_cast<PyLayerOp *>(const_pylayer_op);
// Release contex after executing the compute
auto py_layer_context = pylayer_op->ReleasePyLayerContext();
py::object bk_ctx(py::handle(py_layer_context->GetMutableCtx()), true);
auto &input_vars = ctx.MultiInputVar("X");
auto output_vars = ctx.MultiOutputVar("Out");
RunPyObject(&bk_ctx, input_vars, &output_vars);
} else {
PADDLE_THROW(platform::errors::Fatal(
"PyLayerOpKernel can't cast %s to PyLayer*.", typeid(&op_).name()));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_layer,
ops::PyLayerOp,
ops::PyLayerOpMaker,
ops::PyLayerGradOpMaker<paddle::imperative::OpBase>,
ops::PyLayerGradOpMaker<paddle::framework::OpDesc>);
REGISTER_OP_CPU_KERNEL(
py_layer,
ops::PyLayerOpKernel<phi::CPUContext, float>,
ops::PyLayerOpKernel<phi::CPUContext, ::paddle::platform::float16>,
ops::PyLayerOpKernel<phi::CPUContext, ::paddle::platform::bfloat16>,
ops::PyLayerOpKernel<phi::CPUContext, double>,
ops::PyLayerOpKernel<phi::CPUContext, int>,
ops::PyLayerOpKernel<phi::CPUContext, int64_t>,
ops::PyLayerOpKernel<phi::CPUContext, bool>,
ops::PyLayerOpKernel<phi::CPUContext, uint8_t>,
ops::PyLayerOpKernel<phi::CPUContext, int16_t>,
ops::PyLayerOpKernel<phi::CPUContext, int8_t>,
ops::PyLayerOpKernel<phi::CPUContext, ::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<phi::CPUContext, ::paddle::platform::complex<double>>);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
REGISTER_OP_CUDA_KERNEL(
py_layer,
ops::PyLayerOpKernel<phi::GPUContext, float>,
ops::PyLayerOpKernel<phi::GPUContext, ::paddle::platform::float16>,
ops::PyLayerOpKernel<phi::GPUContext, ::paddle::platform::bfloat16>,
ops::PyLayerOpKernel<phi::GPUContext, double>,
ops::PyLayerOpKernel<phi::GPUContext, int>,
ops::PyLayerOpKernel<phi::GPUContext, int64_t>,
ops::PyLayerOpKernel<phi::GPUContext, bool>,
ops::PyLayerOpKernel<phi::GPUContext, uint8_t>,
ops::PyLayerOpKernel<phi::GPUContext, int16_t>,
ops::PyLayerOpKernel<phi::GPUContext, int8_t>,
ops::PyLayerOpKernel<phi::GPUContext, ::paddle::platform::complex<float>>,
ops::PyLayerOpKernel<phi::GPUContext, ::paddle::platform::complex<double>>);
#endif // PADDLE_WITH_CUDA || PADDLE_WITH_HIP
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <functional>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/python_headers.h"
namespace paddle {
namespace operators {
namespace py = ::pybind11;
class PyLayerContext {
public:
explicit PyLayerContext(PyObject* context) : context_(context) {
Py_INCREF(context_);
}
PyLayerContext() = delete;
PyObject* GetMutableCtx() { return context_; }
~PyLayerContext() {
py::gil_scoped_acquire guard;
Py_XDECREF(context_);
}
private:
PyObject* context_;
};
class PyLayerOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
VLOG(3) << "`InferShape` of `PyLayer` is an empty function, and it cannot "
"infer the shape of the output tensors.";
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = paddle::framework::proto::VarType::Type::VarType_Type_FP32;
return framework::OpKernelType(data_type, ctx.device_context());
}
public:
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context;
}
std::shared_ptr<PyLayerContext> ReleasePyLayerContext() {
auto temp = py_context_;
py_context_.reset();
VLOG(3) << "`py_context_` in the PyLayerOp is released.";
return temp;
}
private:
std::shared_ptr<PyLayerContext> py_context_;
};
template <typename T>
class PyLayerGradOpMaker {};
template <>
class PyLayerGradOpMaker<paddle::framework::OpDesc>
: public framework::SingleGradOpMaker<paddle::framework::OpDesc> {
public:
using framework::SingleGradOpMaker<
paddle::framework::OpDesc>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<paddle::framework::OpDesc> grad_op) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"`PyLayer` don't support static graph mode."));
}
};
template <>
class PyLayerGradOpMaker<paddle::imperative::OpBase>
: public framework::SingleGradOpMaker<paddle::imperative::OpBase> {
public:
using framework::SingleGradOpMaker<
paddle::imperative::OpBase>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<paddle::imperative::OpBase> grad_op) const override;
public:
void SetPyLayerContext(const std::shared_ptr<PyLayerContext>& py_context) {
py_context_ = py_context;
}
private:
std::shared_ptr<PyLayerContext> py_context_;
};
} // namespace operators
} // namespace paddle
......@@ -106,7 +106,6 @@ endif()
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
list(APPEND PYBIND_DEPS py_layer_op)
endif()
set(PYBIND_SRCS
......
......@@ -46,7 +46,6 @@ limitations under the License. */
#include "paddle/fluid/imperative/nccl_context.h"
#include "paddle/fluid/imperative/partial_grad_engine.h"
#include "paddle/fluid/imperative/profiler.h"
#include "paddle/fluid/imperative/py_layer_fwd.h"
#include "paddle/fluid/imperative/reducer.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/imperative/type_defs.h"
......@@ -2656,60 +2655,6 @@ void BindImperative(py::module *m_ptr) {
.def("init", [](imperative::HeterParallelContext &self) { self.Init(); });
#endif
m.def("pylayer_apply",
[](const platform::CPUPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CUDAPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::XPUPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CUDAPinnedPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::NPUPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::MLUPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
m.def("pylayer_apply",
[](const platform::CustomPlace &place,
const py::object &cls,
const py::args args,
const py::kwargs kwargs) {
return imperative::PyLayerApply(place, cls, args, kwargs);
});
#if defined(PADDLE_WITH_CUDA)
m.def(
"to_uva_tensor",
......
......@@ -17,14 +17,8 @@ from ..fluid.dygraph.base import no_grad_ as no_grad # noqa: F401
from ..framework import is_grad_enabled, set_grad_enabled # noqa: F401
from . import backward_mode # noqa: F401
from .backward_mode import backward # noqa: F401
from ..fluid.framework import _in_eager_mode_
if _in_eager_mode_:
from .py_layer import EagerPyLayer as PyLayer # noqa: F401
from .py_layer import EagerPyLayerContext as PyLayerContext # noqa: F401
else:
from .py_layer import LegacyPyLayer as PyLayer # noqa: F401
from .py_layer import LegacyPyLayerContext as PyLayerContext # noqa: F401
from .py_layer import PyLayer as PyLayer # noqa: F401
from .py_layer import PyLayerContext as PyLayerContext # noqa: F401
from .saved_tensors_hooks import saved_tensors_hooks
__all__ = [ # noqa
......
......@@ -13,116 +13,11 @@
# limitations under the License.
import paddle
from paddle.amp.auto_cast import auto_cast
from paddle.fluid import core
from paddle.fluid.dygraph.amp.auto_cast import amp_state
from paddle.fluid.framework import dygraph_only
__all__ = []
class LegacyPyLayerContext:
"""
The object of this class is a context that is used in PyLayer to enhance the function.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a object of PyLayerContext.
y = paddle.tanh(x)
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# ctx is a object of PyLayerContext.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
def __init__(self):
self.container = None
self._amp_state = amp_state()
def save_for_backward(self, *tensors):
"""
Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
Note:
This API should be called at most once, and only inside `forward`.
Args:
tensors(list of Tensors): Tensors to be stored.
Returns:
None
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a context object that store some objects for backward.
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
self.container = tensors
def saved_tensor(self):
"""
Get the tensors stored by ``save_for_backward``.
Returns:
list of Tensors or None: If context contains tensors stored by `save_for_backward`,
then return these tensors, otherwise return None.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
# ctx is a context object that store some objects for backward.
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
return self.container
def with_mateclass(meta, *bases):
class impl(meta):
def __new__(cls, name, temp_bases, attrs):
......@@ -131,212 +26,7 @@ def with_mateclass(meta, *bases):
return type.__new__(impl, "impl", (), {})
class CPyLayer:
@classmethod
@dygraph_only
def apply(cls, *args, **kwargs):
"""
After building the custom PyLayer, run it through the ``apply``.
Args:
*args(tuple): input of PyLayer.
**kwargs(dict): input of PyLayer.
Returns:
tensors or other types : output of PyLayer.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x, func1, func2=paddle.square):
ctx.func = func2
y = func1(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - ctx.func(y))
return grad
data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
# run custom Layer.
z = cus_tanh.apply(data, func1=paddle.tanh)
"""
place = paddle.fluid.framework._current_expected_place()
with paddle.fluid.dygraph.no_grad():
return core.pylayer_apply(place, cls, *args, **kwargs)
class PyLayerBackward(LegacyPyLayerContext):
def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard():
with paddle.fluid.dygraph.no_grad():
if (
self._amp_state
and 'enable' in self._amp_state
and self._amp_state['enable']
):
with auto_cast(**args[0]._amp_state):
return self._forward_cls.backward(*args, **kwargs)
else:
return self._forward_cls.backward(*args, **kwargs)
return self._forward_cls.backward(*args, **kwargs)
class LayerMeta(type):
def __init__(cls, name, bases, attrs):
cls._backward_function = type(
name + '_backward', (PyLayerBackward,), {"_forward_cls": cls}
)
return super().__init__(name, bases, attrs)
class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
"""
Build a custom `Layer` by creating subclasses. Subclasses need to follow the following rules:
1. Subclasses contain `forward` and `backward` function. Both forward and backward are @staticmethod.
Their first argument should be a context and `None` can not be included in the returned result.
2. Input of backward contains a context as the first argument, and the rest arguments are the
gradient of forward's output tensors. so the number of backward's input tensors equal to
the number of forward output tensors. If you need the forward's inputs or outputs in `backward`,
you can use `save_for_backward` to store the required tensors, and then use them in the backward.
3. Output of backward function can only be `Tensor` or tuple/list of `Tensor`.
Output tensors of backward are the gradient of forward's input tensors,
so the number of backward's output tensors equal to the number of forward input tensors.
After building the custom Layer, run it through the `apply` method.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
# Inherit from PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x, func1, func2=paddle.square):
# ctx is a context object that store some objects for backward.
ctx.func = func2
y = func1(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
# forward has only one output, so there is only one gradient in the input of backward.
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - ctx.func(y))
# forward has only one input, so only one gradient tensor is returned.
return grad
data = paddle.randn([2, 3], dtype="float64")
data.stop_gradient = False
z = cus_tanh.apply(data, func1=paddle.tanh)
z.mean().backward()
print(data.grad)
"""
@staticmethod
def forward(ctx, *args, **kwargs):
"""
It is to be overloaded by subclasses. It must accept a object of `PyLayerContext` as
the first argument, followed by any number of arguments (tensors or other types).
`None` can not be included in the returned result.
Args:
*args(tuple): input of PyLayer.
**kwargs(dict): input of PyLayer.
Returns:
tensors or other types : output of PyLayer.
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
raise NotImplementedError(
"You must implement the forward function for PyLayer."
)
@staticmethod
def backward(ctx, *args, **kwargs):
"""
This is a function to calculate the gradient. It is to be overloaded by subclasses.
It must accept a object of `PyLayerContext` as the first argument, and the rest
arguments are the gradient of forward's output tensors. Output tensors of backward
are the gradient of forward's input tensors.
Args:
*args(tuple): The gradient of forward's output tensor(s).
**kwargs(dict): The gradient of forward's output tensor(s).
Returns:
Tensor or list of Tensors: The gradient of forward's input tensor(s).
Examples:
.. code-block:: python
import paddle
from paddle.autograd import PyLayer
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
y = paddle.tanh(x)
# Pass tensors to backward.
ctx.save_for_backward(y)
return y
@staticmethod
def backward(ctx, dy):
# Get the tensors passed by forward.
y, = ctx.saved_tensor()
grad = dy * (1 - paddle.square(y))
return grad
"""
raise NotImplementedError(
"You must implement the backward function for PyLayer."
)
class EagerPyLayerContext:
class PyLayerContext:
def save_for_backward(self, *tensors):
"""
Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
......@@ -537,23 +227,21 @@ class EagerPyLayerContext:
self.materialize_grads = value
class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext):
class PyLayerBackward(core.eager.PyLayer, PyLayerContext):
def backward(self, *args):
return self._forward_cls.backward(self, *args)
class EagerPyLayerMeta(type):
class PyLayerMeta(type):
def __init__(cls, name, bases, attrs):
cls._backward_function = type(
name + '_backward', (EagerPyLayerBackward,), {"_forward_cls": cls}
name + '_backward', (PyLayerBackward,), {"_forward_cls": cls}
)
return super().__init__(name, bases, attrs)
class EagerPyLayer(
with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, EagerPyLayerContext)
):
class PyLayer(with_mateclass(PyLayerMeta, core.eager.PyLayer, PyLayerContext)):
@staticmethod
def forward(ctx, *args, **kwargs):
"""
......
......@@ -17,8 +17,7 @@ import unittest
import numpy as np
import paddle
from paddle.autograd.py_layer import EagerPyLayer, LegacyPyLayer
from paddle.fluid.framework import in_dygraph_mode
from paddle.autograd.py_layer import PyLayer
class FakeTensor(paddle.fluid.core.VarBase):
......@@ -28,7 +27,7 @@ class FakeTensor(paddle.fluid.core.VarBase):
class TestPyLayer(unittest.TestCase):
def test_simple_pylayer_multiple_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
ctx.func = func2
......@@ -60,7 +59,7 @@ class TestPyLayer(unittest.TestCase):
)
def test_simple_pylayer_return_none_with_no_grad(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, x2, func1, func2=paddle.square):
ctx.func = func2
......@@ -96,7 +95,7 @@ class TestPyLayer(unittest.TestCase):
)
def test_simple_pylayer_single_output(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square):
ctx.func = func2
......@@ -124,7 +123,7 @@ class TestPyLayer(unittest.TestCase):
)
def test_pylayer_num_output_match(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(
ctx,
......@@ -146,7 +145,7 @@ class TestPyLayer(unittest.TestCase):
z.mean().backward()
def test_pylayer_dtype(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x, dtype):
y = paddle.cast(x, dtype)
......@@ -176,7 +175,7 @@ class TestPyLayer(unittest.TestCase):
self.assertIsNotNone(input1.grad)
def test_pylayer_Exception_forward(self):
class Layer_None1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_None1(PyLayer):
@staticmethod
def forward(ctx, *args):
return None
......@@ -189,7 +188,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z = Layer_None1.apply(input1)
class Layer_None2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_None2(PyLayer):
@staticmethod
def forward(ctx, *args):
return [None, args[0]]
......@@ -202,7 +201,7 @@ class TestPyLayer(unittest.TestCase):
# return None
z = Layer_None2.apply(input1)
class Layer_one1(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_one1(PyLayer):
@staticmethod
def forward(ctx, *args):
return 1
......@@ -216,7 +215,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z = Layer_one1.apply(input1)
class Layer_one2(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_one2(PyLayer):
@staticmethod
def forward(ctx, *args):
return [1, 2, args[0]]
......@@ -229,7 +228,7 @@ class TestPyLayer(unittest.TestCase):
# return int
z = Layer_one2.apply(input1)
class Layer_no_fw(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_no_fw(PyLayer):
@staticmethod
def backward(ctx, *args):
return args
......@@ -239,7 +238,7 @@ class TestPyLayer(unittest.TestCase):
z = Layer_no_fw.apply(input1)
def test_pylayer_nograd(self):
class tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class tanh(PyLayer):
@staticmethod
def forward(ctx, x1, func1, func2=paddle.square, xx=None):
ctx.func = func2
......@@ -257,9 +256,7 @@ class TestPyLayer(unittest.TestCase):
self.assertIsNone(z.grad)
def test_pylayer_Exception_bk(self):
class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_none1(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2
......@@ -275,9 +272,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.sum().backward()
class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_none2(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + x2
......@@ -293,9 +288,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_one1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_one1(PyLayer):
@staticmethod
def forward(ctx, x):
return x + x
......@@ -311,9 +304,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_one2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_one2(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 * 2, x2 * 5
......@@ -330,7 +321,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_no_bk(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Layer_no_bk(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
......@@ -343,9 +334,7 @@ class TestPyLayer(unittest.TestCase):
z = z[0] + z[1]
z.mean().backward()
class Layer_bk_match(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_match(PyLayer):
@staticmethod
def forward(ctx, x):
return x * 2, x * 5
......@@ -362,9 +351,7 @@ class TestPyLayer(unittest.TestCase):
z.mean().backward()
def test_pylayer_bk_return_none(self):
class Layer_bk_none1(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_none1(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + x2
......@@ -382,9 +369,7 @@ class TestPyLayer(unittest.TestCase):
with self.assertRaises(ValueError):
z.mean().backward()
class Layer_bk_none2(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class Layer_bk_none2(PyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 * 2, x2 * 5
......@@ -403,7 +388,7 @@ class TestPyLayer(unittest.TestCase):
z.mean().backward()
def test_pylayer_inplace(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x
......@@ -431,7 +416,7 @@ class TestPyLayer(unittest.TestCase):
self.assertIsNotNone(data.grad)
def test_pylayer_inplace_backward_error(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x
......@@ -464,7 +449,7 @@ class TestPyLayer(unittest.TestCase):
z.backward()
def test_pylayer_inplace_backward_success_1(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x
......@@ -493,7 +478,7 @@ class TestPyLayer(unittest.TestCase):
self.assertIsNotNone(data.grad)
def test_pylayer_inplace_backward_success_2(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
return x
......@@ -522,9 +507,7 @@ class TestPyLayer(unittest.TestCase):
self.assertIsNotNone(data.grad)
def test_pylayer_inplace_and_leaf_exception(self):
class cus_pylayer_op(
EagerPyLayer if in_dygraph_mode() else LegacyPyLayer
):
class cus_pylayer_op(PyLayer):
@staticmethod
def forward(ctx, x):
return x
......@@ -550,7 +533,7 @@ class TestPyLayer(unittest.TestCase):
z = layer(data)
def test_backward_in_backward(self):
class cus_tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class cus_tanh(PyLayer):
@staticmethod
def forward(ctx, x):
temp = x.detach()
......@@ -575,7 +558,7 @@ class TestPyLayer(unittest.TestCase):
z = cus_tanh.apply(data)
def test_return_to_tensor(self):
class Tanh(EagerPyLayer if in_dygraph_mode() else LegacyPyLayer):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x1):
y1 = paddle.tanh(x1)
......@@ -597,7 +580,7 @@ class TestPyLayer(unittest.TestCase):
z.mean().backward()
def test_materialize_grads(self):
class Tanh(EagerPyLayer):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x):
ctx.mark_not_inplace(x)
......@@ -613,7 +596,7 @@ class TestPyLayer(unittest.TestCase):
Tanh.apply(x)[0].backward()
def test_dont_materialize_grads(self):
class Tanh(EagerPyLayer):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x):
ctx.mark_not_inplace(x)
......@@ -630,7 +613,7 @@ class TestPyLayer(unittest.TestCase):
Tanh.apply(x)[0].backward()
def test_mark_non_differentiable(self):
class Tanh(EagerPyLayer):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x):
a = x + x
......@@ -648,7 +631,7 @@ class TestPyLayer(unittest.TestCase):
y.sum().backward()
def test_mark_non_differentiable2(self):
class Tanh(EagerPyLayer):
class Tanh(PyLayer):
@staticmethod
def forward(ctx, x):
a = x + x
......@@ -669,106 +652,5 @@ class TestPyLayer(unittest.TestCase):
self.assertEqual(x.grad, paddle.ones([1], dtype="float64"))
class TestPyLayerReturnType(unittest.TestCase):
def test_forward_args_fake_tensor(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
y1 = FakeTensor()
return y1, x1
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = FakeTensor()
with self.assertRaises(ValueError):
y1, y2 = Tanh.apply(input1)
def test_forward_kwargs_fake_tensor(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
return x1
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = FakeTensor()
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_forward_return_fake_tensor(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
return FakeTensor()
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = paddle.randn([3, 2])
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_forward_return_fake_tensor_tuple(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
return FakeTensor(), FakeTensor()
@staticmethod
def backward(ctx, dy1, dy2):
return dy1
input1 = paddle.randn([3, 2])
with self.assertRaises(ValueError):
y = Tanh.apply(x1=input1)
def test_backward_return_fake_tensor_tuple(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1, x2):
return x1 + 1, x1 + 2
@staticmethod
def backward(ctx, dy1, dy2):
return FakeTensor(), 2
input1 = paddle.randn([3, 2])
input1.stop_gradient = False
with self.assertRaises(ValueError):
y, _ = Tanh.apply(input1, 1 + input1)
y.mean().backward()
def test_backward_return_fake_tensor(self):
class Tanh(LegacyPyLayer):
@staticmethod
def forward(ctx, x1):
return x1 + 1, x1 + 2
@staticmethod
def backward(ctx, dy1, dy2):
return FakeTensor()
input1 = paddle.randn([3, 2])
input1.stop_gradient = False
with self.assertRaises(ValueError):
y, _ = Tanh.apply(input1)
y.mean().backward()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册