Unverified commit 29c28e2f, authored by Jiabin Yang, committed by GitHub

support more custom vjp (#52533)

Parent 5257a79e
...@@ -23,10 +23,12 @@ limitations under the License. */
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/common_infer_shape_functions.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/backends/dynload/port.h" #include "paddle/phi/backends/dynload/port.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/backward.h" #include "paddle/phi/infermeta/backward.h"
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
namespace paddle { namespace paddle {
...@@ -80,6 +82,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> { ...@@ -80,6 +82,22 @@ class ActivationGradOpMaker : public framework::SingleGradOpMaker<T> {
} }
} }
}; };
// Composite (primitive-op) backward builder for hard_swish: at static-graph
// build time it lowers the hard_swish gradient to primitive ops via
// prim::hardswish_grad instead of calling the fused grad kernel.
class HardSwishCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
 public:
  using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;

 protected:
  void Apply() override {
    paddle::Tensor x = this->GetSingleForwardInput("X");
    paddle::Tensor out_grad = this->GetSingleOutputGrad("Out");
    paddle::Tensor dx = this->GetSingleInputGrad("X");
    auto* dx_ptr = this->GetOutputPtr(&dx);
    std::string dx_name = this->GetOutputName(dx);
    VLOG(6) << "Running hardswish_grad composite func";
    prim::hardswish_grad<prim::DescTensor>(x, out_grad, dx_ptr);
    this->RecoverOutputName(dx, dx_name);
  }
};
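
For reference (not part of the commit itself), this is the math the maker lowers to via prim::hardswish_grad, matching the composite implementation later in this diff:

\[
\mathrm{hardswish}(x) = x \cdot \frac{\min(\max(x+3,\,0),\,6)}{6},
\qquad
\frac{d}{dx}\,\mathrm{hardswish}(x) =
\begin{cases}
0, & x < -3 \\
\dfrac{x}{3} + \dfrac{1}{2}, & -3 \le x \le 3 \\
1, & x > 3
\end{cases}
\]

so x_grad = out_grad * hardswish'(x) elementwise; both boundary points take the middle branch, exactly as the composite below does.
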
phi::KernelKey GetKernelType(const framework::ExecutionContext& ctx,
                             const framework::OperatorWithKernel& oper,
...@@ -400,6 +418,25 @@ namespace plat = paddle::platform;
                  ops::ActivationOpGrad,                                 \
                  ops::ActivationGradOpInplaceInferer);
#define REGISTER_ACTIVATION_OP_WITH_COMP(                                 \
    KERNEL_TYPE, OP_NAME, functor, grad_functor)                          \
  REGISTER_OPERATOR(                                                      \
      KERNEL_TYPE,                                                        \
      ops::ActivationOp,                                                  \
      ops::OP_NAME##OpMaker,                                              \
      ops::ActivationOpInferVarType,                                      \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),     \
                                 paddle::framework::OpDesc>,              \
      ops::ActivationGradOpMaker<ops::grad_functor<float>::FwdDeps(),     \
                                 paddle::imperative::OpBase>,             \
      ops::OP_NAME##CompositeGradOpMaker,                                 \
      std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(),    \
                       ops::ActFwdInplaceInferer,                         \
                       void>::type);                                      \
  REGISTER_OPERATOR(KERNEL_TYPE##_grad,                                   \
                    ops::ActivationOpGrad,                                \
                    ops::ActivationGradOpInplaceInferer);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name) \
...@@ -416,10 +453,10 @@ REGISTER_ACTIVATION_CPU_KERNEL(soft_relu, SoftRelu)
REGISTER_ACTIVATION_OP(relu6, Relu6, Relu6Functor, Relu6GradFunctor);
REGISTER_ACTIVATION_OP(mish, Mish, MishFunctor, MishGradFunctor);
-REGISTER_ACTIVATION_OP(hard_swish,
-                       HardSwish,
-                       HardSwishFunctor,
-                       HardSwishGradFunctor);
+REGISTER_ACTIVATION_OP_WITH_COMP(hard_swish,
+                                 HardSwish,
+                                 HardSwishFunctor,
+                                 HardSwishGradFunctor);
REGISTER_ACTIVATION_OP(swish, Swish, SwishFunctor, SwishGradFunctor);
/* ========================== register checkpoint ===========================*/
...
...@@ -32,6 +32,42 @@ using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArrayBase<paddle::Tensor>;
// This function should have as same signature as phi, which defined in
// paddle/phi/api/backward/backward_api.h
template <typename T>
void hardswish_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    auto offset = full<T>(phi::vectorize(x.dims()), 3.0, x.dtype());
    auto condition = less_equal<T>(x, offset);
    auto tmp1 = where<T>(condition, out_grad * ((x / 3.0) + 0.5), out_grad);
    auto res = where<T>(
        less_than<T>(x, full<T>(phi::vectorize(x.dims()), -3.0, x.dtype())),
        full<T>(phi::vectorize(x.dims()), 0.0, x.dtype()),
        tmp1);
    set_output<T>(res, x_grad);
  }
}
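
A minimal, framework-free C++ sketch of the elementwise rule the where() chain above encodes; hardswish_grad_scalar is an illustrative name, not part of the Paddle API:

#include <cstdio>

// Elementwise rule encoded by the composite above (illustrative only):
//   x < -3        -> 0
//   -3 <= x <= 3  -> dout * (x / 3 + 0.5)
//   x > 3         -> dout
double hardswish_grad_scalar(double x, double dout) {
  if (x < -3.0) return 0.0;
  if (x <= 3.0) return dout * (x / 3.0 + 0.5);
  return dout;
}

int main() {
  const double xs[] = {-4.0, -3.0, 0.0, 3.0, 4.0};
  for (double x : xs) {
    std::printf("x = %+.1f  dL/dx = %+.3f\n", x, hardswish_grad_scalar(x, 1.0));
  }
  return 0;
}

Note that at x = 3 the composite keeps the left-hand slope x/3 + 0.5, which the sketch reproduces.
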
template <typename T>
void leaky_relu_grad(const Tensor& out,
                     const Tensor& out_grad,
                     float negative_slope,
                     Tensor* x_grad) {
  if (x_grad) {
    auto condition = greater_than<T>(
        out, full<T>(phi::vectorize(out.dims()), 0.0, out.dtype()));
    auto res = where<T>(condition, out_grad, out_grad * negative_slope);
    set_output<T>(res, x_grad);
  }
}
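
The composite above expresses the LeakyReLU derivative through the forward output; assuming negative_slope > 0 (so out and x share a sign), the rule it implements is:

\[
\frac{\partial L}{\partial x} =
\begin{cases}
\text{out\_grad}, & \text{out} > 0 \\
\text{negative\_slope} \cdot \text{out\_grad}, & \text{out} \le 0
\end{cases}
\]
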
template <typename T>
void silu_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
    auto sigmoid = 1.0 / (1.0 + exp<T>(-x));
    auto res = out_grad * sigmoid * (1.0 + x * (1.0 - sigmoid));
    set_output<T>(res, x_grad);
  }
}
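
For completeness, the factor used above follows from the product rule on silu(x) = x * sigmoid(x), with sigma(x) = 1/(1+e^{-x}):

\[
\frac{d}{dx}\bigl(x\,\sigma(x)\bigr)
= \sigma(x) + x\,\sigma(x)\bigl(1-\sigma(x)\bigr)
= \sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr),
\qquad
\text{x\_grad} = \text{out\_grad}\cdot\sigma(x)\bigl(1 + x\,(1-\sigma(x))\bigr).
\]
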
template <typename T>
void relu_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
  if (x_grad) {
...
...@@ -376,7 +376,7 @@ class CompositeGradOpMakerBase {
  }
  std::vector<framework::VarDesc*> MultiInputGrad(
-      const std::string& name, bool drop_empty_grad = true) const {
+      const std::string& name) const {
    std::vector<std::string> ret_val;
    std::vector<framework::VarDesc*> input_grads;
    auto var_names = this->MultiForwardInputVarName(name);
...@@ -393,39 +393,7 @@
            return framework::kEmptyVarName;
          }
        });
-    if (!drop_empty_grad) {
+    for (const auto& name : ret_val) {
      for (const auto& name : ret_val) {
        if (original_block_->HasVar(name)) {
          // Copy Var from original block to active block, or create a new one.
          CopyVarFromOrig(name);
          input_grads.emplace_back(
              StaticCompositeContext::Instance().GetBlock()->FindVar(name));
        } else {
          input_grads.emplace_back(
              StaticCompositeContext::Instance().GetBlock()->Var(name));
        }
      }
      return input_grads;
    }
    PADDLE_ENFORCE_LE(
        var_names.size(),
        1UL,
        platform::errors::Unavailable(
            "BUG from operator developer:"
            " for input argument with a list of variables, "
            " drop_empty_grad is not allowed because it makes"
            " the correspondence bewteen a variable and its gradient"
            " ambiguous."));
    std::vector<std::string> dropped_ret_val;
    dropped_ret_val.reserve(ret_val.size());
    std::copy_if(
        ret_val.begin(),
        ret_val.end(),
        std::back_inserter(dropped_ret_val),
        [](const std::string& str) { return str != framework::kEmptyVarName; });
    for (const auto& name : dropped_ret_val) {
      // TODO(jiabin): Will this cause fill zeros error?
      if (original_block_->HasVar(name)) {
        // Copy Var from original block to active block, or create a new one.
        CopyVarFromOrig(name);
...@@ -437,6 +405,15 @@
      }
    }
    return input_grads;
    PADDLE_ENFORCE_LE(
        var_names.size(),
        1UL,
        platform::errors::Unavailable(
            "BUG from operator developer:"
            " for input argument with a list of variables, "
            " drop_empty_grad is not allowed because it makes"
            " the correspondence bewteen a variable and its gradient"
            " ambiguous."));
  }
  std::vector<framework::VarDesc*> MultiOutputGrad(
...
...@@ -829,6 +829,7 @@
  kernel :
    func : leaky_relu_grad
  backward : leaky_relu_double_grad
  composite: leaky_relu_grad(x, out_grad, negative_slope, x_grad)
  inplace : (out_grad -> x_grad)
- backward_op : lerp_grad
...@@ -1461,6 +1462,7 @@
    param : [x]
  kernel :
    func : silu_grad
  composite : silu_grad(x, out_grad, x_grad)
  inplace : (out_grad -> x_grad)
- backward_op : sin_double_grad
...