Unverified commit d7064f04, authored by yujun and committed by GitHub

[PaddlePaddle hackathon] + ADD CELU (#36088)

* update

* update

* update

* try make CI pass

* doc typo

* update doc string
Parent 0c31579c
......@@ -560,6 +560,28 @@ $$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$$
}
};
class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input is a multi-dimensional Tensor. The data type is "
"float32 or float64.");
AddOutput("Out",
"The output is a multi-dimensional Tensor which has same "
"dimension and data type as the ``x``.");
AddAttr<float>("alpha", "The alpha value of CELU").SetDefault(1.0f);
AddComment(R"DOC(
CELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1704.07483.
$$out = \max(0, x) + \min(0, \alpha * (e^{x/\alpha} - 1))$$
)DOC");
}
};
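As a quick sanity check of the formula documented above, the values used later in the Python docstring example (alpha=0.2) can be reproduced with plain NumPy; this is an editorial sketch, not part of the patch:

import numpy as np

def celu_ref(x, alpha=1.0):
    # out = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x / alpha) - 1))

print(celu_ref(np.array([-1.0, 6.0, 1.0, 15.6]), alpha=0.2))
# approximately [-0.19865242  6.  1.  15.6]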
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -982,6 +1004,29 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};
// celu grad: dx = dout if x > 0 else dout * (x/alpha).exp()
// celu grad_grad: ddout = ddx if x > 0 else ddx * (x/alpha).exp();
//                 dx = ddx * dout / alpha * (x/alpha).exp() if x <= 0 else 0
template <typename T>
class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("celu_grad_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
op->SetOutput("DX", this->InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
template <typename T>
......@@ -1353,6 +1398,35 @@ REGISTER_OP_CPU_KERNEL(
/* ========================================================================== */
/* ======================== celu register ============================ */
REGISTER_OPERATOR(
celu, ops::ActivationOp, ops::CELUOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(celu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::CELUDoubleGradMaker<paddle::framework::OpDesc>,
ops::CELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
celu_grad_grad,
ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
......
......@@ -1202,6 +1202,59 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T& arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};
template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout, if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <= 0
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......@@ -1341,6 +1394,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ======================== celu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
CudaCELUGradFunctor);
REGISTER_OP_CUDA_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
......
......@@ -1389,6 +1389,51 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};
template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout, if alpha < 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha < 0 and x <= 0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
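The branch table in the comment above can be verified against a central finite difference; a minimal NumPy sketch with hypothetical helper names, covering both signs of alpha:

import numpy as np

def celu_ref(x, alpha):
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x / alpha) - 1))

def celu_grad_ref(x, alpha):
    # d celu / dx = 1 if x > 0 else exp(x / alpha), for any nonzero alpha
    return np.where(x > 0, 1.0, np.exp(x / alpha))

x = np.random.uniform(-3, 3, [10, 12])
eps = 1e-6
for alpha in (1.5, -1.5):
    numeric = (celu_ref(x + eps, alpha) - celu_ref(x - eps, alpha)) / (2 * eps)
    assert np.allclose(numeric, celu_grad_ref(x, alpha), atol=1e-4)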
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
......@@ -1775,6 +1820,45 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));
if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
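In NumPy terms, the two branches above compute the following; an editorial sketch that mirrors the functor, with hypothetical names:

import numpy as np

def celu_grad_grad_ref(x, ddx, dout, alpha):
    # dx uses celu''(x) = exp(x / alpha) / alpha for x <= 0 (0 otherwise)
    dx = ddx * dout / alpha * np.exp(x / alpha) * (x <= 0)
    # ddout uses celu'(x) = 1 for x > 0, exp(x / alpha) otherwise
    ddout = ddx * ((x > 0) + np.exp(x / alpha) * (x <= 0))
    return dx, ddout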
template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
......@@ -2107,6 +2191,33 @@ class ELUDoubleGradKernel
}
};
template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, ddX, ddOut, dOut, dX);
}
};
template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
......@@ -22,6 +22,7 @@ import paddle
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
import paddle.nn.functional as F
from decorator_helper import prog_scope
......@@ -168,6 +169,32 @@ class TestELUDoubleGradCheck(unittest.TestCase):
self.func(p)
class TestCELUDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 4, 4]
eps = 1e-6
alpha = 0.2
dtype = np.float64
SEED = 0
x = layers.data('x', shape, False, dtype)
x.persistable = True
y = F.celu(x, alpha=alpha)
np.random.seed(SEED)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestSqrtDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......
......@@ -1827,6 +1827,94 @@ class TestELUAPI(unittest.TestCase):
self.elu(x_fp16)
def celu(x, alpha):
out_ref = np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x / alpha) - 1))
return out_ref.astype(x.dtype)
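For alpha == 1 the CELU formula reduces to ELU, which gives a quick cross-check of the celu reference defined just above; editorial sketch only, not part of the test suite:

import numpy as np

def elu_ref(x, alpha):
    return np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))

x_check = np.random.uniform(-3, 3, [10, 12]).astype('float64')
assert np.allclose(celu(x_check, 1.0), elu_ref(x_check, 1.0))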
class TestCELU(TestActivation):
def setUp(self):
self.op_type = "celu"
self.init_dtype()
np.random.seed(1024)
x = np.random.uniform(-3, 3, [10, 12]).astype(self.dtype)
alpha = 1.5
out = celu(x, alpha)
self.inputs = {'X': x}
self.attrs = {'alpha': alpha}
self.outputs = {'Out': out}
def test_check_grad(self):
if self.dtype == np.float16:
return
self.check_grad(['X'], 'Out')
class TestCELUAPI(unittest.TestCase):
# test paddle.nn.CELU, paddle.nn.functional.celu
def setUp(self):
np.random.seed(1024)
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.executed_api()
def executed_api(self):
self.celu = F.celu
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12])
out1 = self.celu(x, 1.5)
m = paddle.nn.CELU(1.5)
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = celu(self.x_np, 1.5)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = self.celu(x, 1.5)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.CELU(1.5)
out2 = m(x)
out_ref = celu(self.x_np, 1.5)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
out1 = self.celu(x, 0.2)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.CELU(0.2)
out2 = m(x)
out_ref = celu(self.x_np, 0.2)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, self.celu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, self.celu, x_int32)
# The alpha must not be 0.
x_fp32 = paddle.fluid.data(
name='x_fp32', shape=[10, 12], dtype='float32')
self.assertRaises(ZeroDivisionError, F.celu, x_fp32, 0)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[10, 12], dtype='float16')
self.celu(x_fp16)
class TestELUInplaceAPI(TestELUAPI):
# test paddle.nn.functional.elu_
def executed_api(self):
......@@ -2791,6 +2879,7 @@ create_test_act_fp16_class(TestBRelu)
create_test_act_fp16_class(TestRelu6)
create_test_act_fp16_class(TestSoftRelu, grad_atol=0.85)
create_test_act_fp16_class(TestELU)
create_test_act_fp16_class(TestCELU)
create_test_act_fp16_class(TestReciprocal)
create_test_act_fp16_class(TestLog)
if core.is_compiled_with_rocm():
......
......@@ -22,6 +22,9 @@ class TestLayerPrint(unittest.TestCase):
module = nn.ELU(0.2)
self.assertEqual(str(module), 'ELU(alpha=0.2)')
module = nn.CELU(0.2)
self.assertEqual(str(module), 'CELU(alpha=0.2)')
module = nn.GELU(True)
self.assertEqual(str(module), 'GELU(approximate=True)')
......
......@@ -25,6 +25,7 @@ from .clip import ClipGradByNorm # noqa: F401
from .clip import ClipGradByValue # noqa: F401
from .decode import BeamSearchDecoder # noqa: F401
from .decode import dynamic_decode # noqa: F401
from .layer.activation import CELU # noqa: F401
from .layer.activation import ELU # noqa: F401
from .layer.activation import GELU # noqa: F401
from .layer.activation import Tanh # noqa: F401
......@@ -185,6 +186,7 @@ def weight_norm(*args):
__all__ = [ #noqa
'BatchNorm',
'CELU',
'GroupNorm',
'LayerNorm',
'SpectralNorm',
......
......@@ -15,6 +15,7 @@
# TODO: import all neural network related api under this directory,
# including layers, linear, conv, rnn etc.
from .activation import celu # noqa: F401
from .activation import elu # noqa: F401
from .activation import elu_ # noqa: F401
from .activation import gelu # noqa: F401
......@@ -115,6 +116,7 @@ from ...fluid.layers import temporal_shift # noqa: F401
from .sparse_attention import sparse_attention
__all__ = [ #noqa
'celu',
'conv1d',
'conv1d_transpose',
'conv2d',
......
......@@ -31,6 +31,50 @@ from paddle import _C_ops
__all__ = []
def celu(x, alpha=1.0, name=None):
r"""
celu activation.
.. math::
celu(x) = \max(0, x) + \min(0, \alpha * (e^{x/\alpha}-1))
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
alpha (float, optional): The 'alpha' value of the CELU formulation. Default is 1.0.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
x = paddle.to_tensor([[-1., 6.], [1., 15.6]])
out = F.celu(x, alpha=0.2)
# [[-0.19865242, 6. ],
# [ 1. , 15.60000038]]
"""
if alpha == 0:
raise ZeroDivisionError("alpha cannot be 0 for celu")
if in_dygraph_mode():
return _C_ops.celu(x, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'celu')
helper = LayerHelper("celu", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
helper.append_op(
type='celu',
inputs={'X': x},
outputs={'Out': out},
attrs={'alpha': alpha})
return out
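Beyond the dygraph example in the docstring, the same API works in static-graph mode; a hedged sketch that mirrors TestCELUAPI.test_static_api earlier in this PR:

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
    x = paddle.fluid.data('X', [10, 12])  # float32 by default
    out = F.celu(x, alpha=0.2)
    exe = paddle.static.Executor(paddle.CPUPlace())
    x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
    res = exe.run(feed={'X': x_np}, fetch_list=[out])
paddle.disable_static()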
def elu(x, alpha=1.0, name=None):
r"""
elu activation.
......
......@@ -18,6 +18,7 @@ from . import rnn # noqa: F401
from . import transformer # noqa: F401
from . import container # noqa: F401
from .activation import CELU # noqa: F401
from .activation import PReLU # noqa: F401
from .activation import ReLU # noqa: F401
from .activation import ReLU6 # noqa: F401
......
......@@ -25,6 +25,48 @@ from paddle.nn import Layer
__all__ = []
class CELU(Layer):
r"""
CELU Activation.
.. math::
CELU(x) = \max(0, x) + \min(0, \alpha * (e^{x/\alpha}-1))
Parameters:
alpha (float, optional): The 'alpha' value of the CELU formulation. Default is 1.0.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([[-1., 6.], [1., 15.6]])
m = paddle.nn.CELU(0.2)
out = m(x)
# [[-0.19865242, 6. ],
# [ 1. , 15.60000038]]
"""
def __init__(self, alpha=1.0, name=None):
super(CELU, self).__init__()
self._alpha = alpha
self._name = name
def forward(self, x):
return F.celu(x, self._alpha, self._name)
def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
return 'alpha={}{}'.format(self._alpha, name_str)
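As a layer, CELU composes like any other activation module; a hedged usage sketch (the layer sizes are illustrative):

import paddle
import paddle.nn as nn

model = nn.Sequential(
    nn.Linear(4, 8),
    nn.CELU(alpha=0.2),
    nn.Linear(8, 1),
)
y = model(paddle.rand([3, 4]))  # y has shape [3, 1]
print(model)  # the CELU line renders as 'CELU(alpha=0.2)' via extra_repr above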
class ELU(Layer):
r"""
ELU Activation.
......