Commit fab4b076 authored by Double_V, committed by ceci3

support elu_op double grad (#21822)

* support elu activation double grad,test=develop

* delete the code commit in .cc,test=develop

* fix failing relu test, test=develop

* add elu double grad kernel and unit test

* add calculate dX in elu double grad functor, test=develop

* update the commit code,test=develop
Parent 0a51098a
@@ -764,6 +764,31 @@ class LeakyReluDoubleGradMaker
}
};
// elu grad: dx = dout if x > 0 else alpha * dout * x.exp()
// elu gradgrad: ddout = ddx if x > 0 else alpha * ddx * x.exp()
template <typename T>
class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
 public:
  using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  std::unique_ptr<T> Apply() const override {
    auto* op = new T();
    op->SetType("elu_grad_grad");

    op->SetInput("X", this->Input("X"));
    op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
    // X@GRAD@GRAD: ddx
    op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
    op->SetAttrMap(this->Attrs());

    // Out@GRAD@GRAD: ddout
    op->SetOutput("DX", this->InputGrad("X"));
    op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
    return std::unique_ptr<T>(op);
  }
};
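
For reference, here is a minimal scalar sketch (illustrative only, not part of the patch) of the chain rule this maker wires up; the helper names elu_d1/elu_d2 are assumptions, and the sketch assumes elu(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise:

#include <cmath>

// First derivative of ELU: 1 for x > 0, alpha * exp(x) otherwise.
double elu_d1(double x, double alpha) {
  return x > 0.0 ? 1.0 : alpha * std::exp(x);
}
// Second derivative of ELU: 0 for x > 0, alpha * exp(x) otherwise.
double elu_d2(double x, double alpha) {
  return x > 0.0 ? 0.0 : alpha * std::exp(x);
}

// elu_grad:       dx    = dout * elu_d1(x)
// elu_grad_grad:  ddout = ddx * elu_d1(x)          (needs X and DDX)
//                 dx    = ddx * dout * elu_d2(x)   (additionally needs DOut)
double ddout_ref(double x, double ddx, double alpha) {
  return ddx * elu_d1(x, alpha);
}
double dx_ref(double x, double dout, double ddx, double alpha) {
  return ddx * dout * elu_d2(x, alpha);
}

This is why elu_grad_grad consumes X, DOut, and DDX and produces DX and DDOut.
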
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
template <typename T>
@@ -984,6 +1009,34 @@ REGISTER_OP_CPU_KERNEL(
plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_OPERATOR(
    elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
                               paddle::framework::OpDesc>,
    ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
                               paddle::imperative::OpBase>,
    ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
                  ops::ActivationGradOpInplaceInference,
                  ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
                  ops::ELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
    elu_grad_grad,
    ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
    ops::ActivationDoubleGradOpInplaceInference);

REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                                            ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
@@ -47,6 +47,18 @@ REGISTER_OP_CUDA_KERNEL(
plat::CUDADeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
    elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                                            ops::ELUGradGradFunctor<float>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<double>>,
    ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
                             ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== relu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
@@ -1084,7 +1084,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
    dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
                   dout * static_cast<T>(alpha) * x.exp() *
-                      (x < static_cast<T>(0)).template cast<T>();
+                      (x <= static_cast<T>(0)).template cast<T>();
  }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
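
A small scalar illustration (not part of the patch) of why the mask changes from x < 0 to x <= 0: with the strict comparison, x == 0 matched neither the x > 0 mask nor the x < 0 mask and therefore received a zero gradient, whereas with <= it receives alpha * exp(0) = alpha:

#include <cmath>

// Old mask: at x == 0 both terms vanish, so the gradient is 0.
double elu_grad_old(double x, double dout, double alpha) {
  return dout * (x > 0.0 ? 1.0 : 0.0) +
         dout * alpha * std::exp(x) * (x < 0.0 ? 1.0 : 0.0);
}
// New mask: at x == 0 the second term fires, giving dout * alpha.
double elu_grad_new(double x, double dout, double alpha) {
  return dout * (x > 0.0 ? 1.0 : 0.0) +
         dout * alpha * std::exp(x) * (x <= 0.0 ? 1.0 : 0.0);
}
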
@@ -1405,6 +1405,39 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
  float alpha;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }
  template <typename Device>
  void operator()(const Device& dev, const framework::Tensor* X,
                  const framework::Tensor* ddX, framework::Tensor* ddOut,
                  const framework::Tensor* dOut, framework::Tensor* dX) const {
    auto* d = dev.eigen_device();
    auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
    auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));

    if (dX) {
      // dX = ddX * dOut * elu''(x), where elu''(x) = alpha * exp(x) for x < 0.
      auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
      auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
      dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
                      (x < static_cast<T>(0)).template cast<T>();
    }

    if (ddOut) {
      // ddOut = ddX * elu'(x), where elu'(x) = 1 for x > 0 and
      // alpha * exp(x) otherwise.
      auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
      ddout.device(*d) = ddx *
                         ((x > static_cast<T>(0)).template cast<T>() +
                          static_cast<T>(alpha) * x.exp() *
                              (x <= static_cast<T>(0)).template cast<T>())
                             .template cast<T>();
    }
  }

  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
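
As a quick sanity check (not part of the patch, and independent of the Paddle APIs above), the two closed-form factors used by the functor can be compared against central finite differences of a scalar ELU; the point, step size, and tolerances below are arbitrary:

#include <cassert>
#include <cmath>

double elu(double x, double alpha) {
  return x > 0.0 ? x : alpha * (std::exp(x) - 1.0);
}

int main() {
  const double alpha = 1.1, x = -0.5, h = 1e-4;
  // First derivative: should approach alpha * exp(x) for x < 0.
  const double d1 = (elu(x + h, alpha) - elu(x - h, alpha)) / (2.0 * h);
  // Second derivative: should also approach alpha * exp(x) for x < 0.
  const double d2 =
      (elu(x + h, alpha) - 2.0 * elu(x, alpha) + elu(x - h, alpha)) / (h * h);
  assert(std::abs(d1 - alpha * std::exp(x)) < 1e-5);
  assert(std::abs(d2 - alpha * std::exp(x)) < 1e-5);
  return 0;
}
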
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -1515,6 +1548,33 @@ class SquareDoubleGradKernel
}
};
template <typename DeviceContext, typename Functor>
class ELUDoubleGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  using T = typename Functor::ELEMENT_TYPE;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const framework::Tensor *X, *ddX, *dOut;
    X = ddX = dOut = nullptr;
    framework::Tensor *dX, *ddOut;
    dX = ddOut = nullptr;
    ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);

    if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
    if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());

    auto& place = ctx.template device_context<DeviceContext>();

    Functor functor;
    auto attrs = functor.GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
    functor(place, X, ddX, ddOut, dOut, dX);
  }
};
template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1688,7 +1748,6 @@ class PowGradKernel
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
-  __macro(elu, ELU, ELUFunctor, ELUGradFunctor); \
__macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
HardSigmoidGradFunctor); \
@@ -75,6 +75,30 @@ class TestLeakyReluDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestELUDoubleGradCheck(unittest.TestCase):
    @prog_scope()
    def func(self, place):
        shape = [2, 3, 7, 9]
        eps = 0.0001
        alpha = 1.1
        dtype = np.float64

        x = layers.data('x', shape, False, dtype)
        x.persistable = True
        y = layers.elu(x, alpha=alpha)
        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)

        gradient_checker.double_grad_check(
            [x], y, x_init=x_arr, place=place, eps=eps)

    def test_grad(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for p in places:
            self.func(p)
class TestSqrtDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):