未验证 提交 436808c6 编写于 作者: Z zhupengyang 提交者: GitHub

elu support alpha < 0 (#37316) (#37437)

上级 58a51130
......@@ -560,6 +560,22 @@ $$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$
}
};
template <typename T>
class ELUGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("elu_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("Out", this->Output("Out"));
op->SetInput("X", this->Input("X"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -1233,13 +1249,11 @@ REGISTER_OP_CPU_KERNEL(
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_OPERATOR(
elu, ops::ActivationOp, ops::ELUOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::ELUGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu, ops::ActivationOp, ops::ELUOpMaker,
ops::ActivationOpInferVarType,
ops::ELUGradOpMaker<paddle::framework::OpDesc>,
ops::ELUGradOpMaker<paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(elu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::ELUDoubleGradMaker<paddle::framework::OpDesc>,
......@@ -1249,7 +1263,14 @@ REGISTER_OPERATOR(
ops::ActivationOpDoubleGrad<ops::ELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(elu, ELU, ELUFunctor, ELUGradFunctor);
REGISTER_OP_CPU_KERNEL(elu,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<float>>,
ops::ActivationKernel<paddle::platform::CPUDeviceContext,
ops::ELUFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
elu_grad, ops::ELUGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ELUGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CPUDeviceContext,
ops::ELUGradGradFunctor<float>>,
......
......@@ -1161,11 +1161,12 @@ struct CudaELUFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}
// elu(x) = max(0, x) + min(0, alpha * (exp(x) - 1))
// elu(x) = x, if x > 0
// elu(x) = alpha * (e^x - 1), if x <= 0
__device__ __forceinline__ T operator()(const T& arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
CT res = x > zero ? x : temp;
return static_cast<T>(res);
}
};
......@@ -1174,34 +1175,84 @@ template <typename T>
struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <=0
// case 1: alpha >= 0
// dx = dout, if out > 0
// dx = dout * (out + alpha), if out <= 0
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_out) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType a = static_cast<MPType>(alpha);
MPType out_pos = static_cast<MPType>(out > zero);
MPType out_neg = static_cast<MPType>(out <= zero);
return static_cast<T>(dout * (out_pos + out_neg * (out + a)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct CudaELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// case 2: alpha < 0
// dx = dout, if x > 0
// dx = dout * (out + alpha), if x <=0
__device__ __forceinline__ T operator()(const T& arg_dout, const T& arg_out,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout * (temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * a * exp(x) +
temp_a_neg * temp_x_pos * (one + a * exp(x))));
MPType x_pos = static_cast<MPType>(x > zero);
MPType x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(dout * (x_pos + x_neg * (out + a)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename DeviceContext, typename T>
class ELUGradCudaKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out = ctx.Input<framework::Tensor>("Out");
auto* x = ctx.Input<framework::Tensor>("X");
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
const float alpha = ctx.Attr<float>("alpha");
auto& dev_ctx = ctx.device_context<DeviceContext>();
std::vector<const framework::Tensor*> ins = {d_out, out};
std::vector<framework::Tensor*> outs = {d_x};
if (alpha > 0) {
CudaELUGradFunctor<T> functor;
functor.alpha = alpha;
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
} else {
CudaELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
ins.push_back(x);
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, functor);
}
}
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......@@ -1330,7 +1381,17 @@ REGISTER_OP_CUDA_KERNEL(
/* ========================================================================== */
/* ======================== elu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(elu, ELU, CudaELUFunctor, CudaELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
elu, ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<float>>,
ops::ActivationCudaKernel<paddle::platform::CUDADeviceContext,
ops::CudaELUFunctor<double>>,
ops::ActivationCudaKernel<plat::CUDADeviceContext,
ops::CudaELUFunctor<plat::float16>>);
REGISTER_OP_CUDA_KERNEL(
elu_grad, ops::ELUGradCudaKernel<plat::CUDADeviceContext, float>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, double>,
ops::ELUGradCudaKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
elu_grad_grad, ops::ELUDoubleGradKernel<plat::CUDADeviceContext,
......
......@@ -1311,25 +1311,70 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
temp_a_neg * temp_x_pos;
// case 1: alpha >= 0
// dx = dout, if out > 0
// dx = dout * (out + alpha), if out <= 0
dx.device(d) = (out > static_cast<T>(0))
.select(dout, dout * (out + static_cast<T>(alpha)));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct ELUGradNegativeAlphaFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
// case 2: alpha < 0
// dx = dout, if x > 0
// dx = dout * (out + alpha), if x <=0
dx.device(d) = (x > static_cast<T>(0))
.select(dout, dout * static_cast<T>(alpha) * x.exp());
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename DeviceContext, typename T>
class ELUGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
const float alpha = context.Attr<float>("alpha");
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "elu_grad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "elu_grad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "dOut", "elu_grad"));
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "dX", "elu_grad"));
auto* place =
context.template device_context<DeviceContext>().eigen_device();
if (alpha > 0) {
ELUGradFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x, out, dout, dx);
} else {
ELUGradNegativeAlphaFunctor<T> functor;
functor.alpha = alpha;
functor(*place, x, out, dout, dx);
}
}
};
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
......
......@@ -104,7 +104,7 @@ class InplaceABNActivation {
auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
x.device(d) = (y * temp1 + temp2).template cast<T>();
ELUGradFunctor<T> functor;
ELUGradNegativeAlphaFunctor<T> functor;
compute(ctx, &functor, d, x, y, dy, dx);
} else {
PADDLE_THROW(
......
......@@ -1742,7 +1742,7 @@ class TestSoftReluOpError(unittest.TestCase):
def elu(x, alpha):
out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
out_ref = np.where(x > 0, x, alpha * (np.exp(x) - 1))
return out_ref.astype(x.dtype)
......@@ -1753,7 +1753,7 @@ class TestELU(TestActivation):
np.random.seed(1024)
x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
alpha = 1.
alpha = self.get_alpha()
out = elu(x, alpha)
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
......@@ -1766,6 +1766,14 @@ class TestELU(TestActivation):
return
self.check_grad(['X'], 'Out')
def get_alpha(self):
return 1.
class TestELUAlpha(TestELU):
def get_alpha(self):
return -0.2
class TestELUAPI(unittest.TestCase):
# test paddle.nn.ELU, paddle.nn.functional.elu
......@@ -1832,6 +1840,12 @@ class TestELUInplaceAPI(TestELUAPI):
def executed_api(self):
self.elu = F.elu_
def test_alpha_error(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
self.assertRaises(Exception, F.elu_, x, -0.2)
paddle.enable_static()
class TestReciprocal(TestActivation):
def setUp(self):
......
......@@ -37,7 +37,13 @@ def elu(x, alpha=1.0, name=None):
.. math::
elu(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
elu(x)=
\left\{
\begin{array}{lcl}
x,& &\text{if } \ x > 0 \\
alpha * (e^{x} - 1),& &\text{if } \ x <= 0
\end{array}
\right.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
......@@ -80,6 +86,7 @@ def elu_(x, alpha=1.0, name=None):
Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_elu`.
"""
assert alpha >= 0., "elu_ only support alpha >= 0, please use elu instead."
return _C_ops.elu_(x, 'alpha', alpha)
......
......@@ -31,7 +31,13 @@ class ELU(Layer):
.. math::
ELU(x) = max(0, x) + min(0, \alpha * (e^{x}-1))
ELU(x)=
\left\{
\begin{array}{lcl}
x,& &\text{if } \ x > 0 \\
alpha * (e^{x} - 1),& &\text{if } \ x <= 0
\end{array}
\right.
Parameters:
alpha (float, optional): The 'alpha' value of the ELU formulation. Default is 1.0.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册