未验证 提交 f8863e06 编写于 作者: Z zhupengyang 提交者: GitHub

leaky_relu and LeakyReLU: alpha->negative_slope (#26216)

上级 c6090660
......@@ -781,8 +781,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
// leaky_relu Grad: dx=dy if x>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if x>=0 else alpha * ddx
template <typename T>
class LeakyReluDoubleGradMaker
: public ::paddle::framework::SingleGradOpMaker<T> {
......@@ -792,8 +792,8 @@ class LeakyReluDoubleGradMaker
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("leaky_relu_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
// input1: X
op->SetInput("X", this->Input("X"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());
......
......@@ -1084,7 +1084,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
if (alpha < 1.f) {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
} else {
out.device(d) = x.cwiseMin(static_cast<T>(alpha) * x);
}
}
};
......@@ -1098,12 +1102,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 =
static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
auto temp2 = (out > static_cast<T>(0)).template cast<T>();
static_cast<T>(alpha) * (x < static_cast<T>(0)).template cast<T>();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......@@ -1451,18 +1455,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "LeakyReluGradGrad"));
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
ddout.device(*d) = ddx *
((out > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) *
(out <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
ddout.device(*d) =
ddx *
((x > static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * (x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};
template <typename T>
......
......@@ -41,12 +41,12 @@ static void InitRandom(framework::Tensor *tensor,
template <typename T>
struct LeakyReluGradGradEachElementFunctor {
LeakyReluGradGradEachElementFunctor(const T *ddx, const T *out, T alpha,
LeakyReluGradGradEachElementFunctor(const T *ddx, const T *x, T alpha,
T *ddout)
: ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
: ddx_(ddx), x_(x), alpha_(alpha), ddout_(ddout) {}
HOSTDEVICE void operator()(int idx) {
if (out_[idx] > 0) {
if (x_[idx] >= 0) {
ddout_[idx] = ddx_[idx];
} else {
ddout_[idx] = ddx_[idx] * alpha_;
......@@ -54,7 +54,7 @@ struct LeakyReluGradGradEachElementFunctor {
}
const T *ddx_;
const T *out_;
const T *x_;
T alpha_;
T *ddout_;
};
......@@ -66,13 +66,13 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
LeakyReluGradGradFunctor<T> functor;
functor.alpha = alpha;
auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place);
framework::Tensor *x = nullptr;
framework::Tensor *out = nullptr;
framework::Tensor *dout = nullptr;
framework::Tensor *dx = nullptr;
framework::Tensor out;
out.Resize(dim);
InitRandom<T>(&out, place);
framework::Tensor x;
x.Resize(dim);
InitRandom<T>(&x, place);
framework::Tensor ddx;
ddx.Resize(dim);
......@@ -85,22 +85,22 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
framework::Tensor ddout_actual;
ddout_actual.mutable_data<T>(dim, place);
LeakyReluGradGradEachElementFunctor<T> actual_functor(
ddx.data<T>(), out.data<T>(), static_cast<T>(alpha),
ddx.data<T>(), x.data<T>(), static_cast<T>(alpha),
ddout_actual.data<T>());
int64_t limit = out.numel();
int64_t limit = x.numel();
#ifdef __NVCC__
if (platform::is_gpu_place(place)) {
auto &cuda_dev_ctx = dynamic_cast<platform::CUDADeviceContext &>(dev_ctx);
functor(cuda_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
functor(cuda_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CUDADeviceContext> for_range(cuda_dev_ctx,
limit);
for_range(actual_functor);
} else {
#endif
auto &cpu_dev_ctx = dynamic_cast<platform::CPUDeviceContext &>(dev_ctx);
functor(cpu_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
functor(cpu_dev_ctx, &x, out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CPUDeviceContext> for_range(cpu_dev_ctx,
limit);
for_range(actual_functor);
......
......@@ -9772,13 +9772,10 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.leaky_relu")
@templatedoc()
def leaky_relu(x, alpha=0.02, name=None):
"""
:alias_main: paddle.nn.functional.leaky_relu
:alias: paddle.nn.functional.leaky_relu,paddle.nn.functional.activation.leaky_relu
:old_api: paddle.fluid.layers.leaky_relu
${comment}
Args:
x(${x_type}): ${x_comment}
......@@ -9807,19 +9804,7 @@ def leaky_relu(x, alpha=0.02, name=None):
res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
print(res_val) # [[-0.1, 2], [3, -0.4]]
"""
if in_dygraph_mode():
return core.ops.leaky_relu(x, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'leaky_relu')
inputs = {'X': [x]}
attrs = {'alpha': alpha}
helper = LayerHelper('leaky_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
return paddle.nn.functional.leaky_relu(x, alpha, name)
def soft_relu(x, threshold=40.0, name=None):
......
......@@ -903,18 +903,30 @@ class TestReluAPI(unittest.TestCase):
F.relu(x_fp16)
def ref_leaky_relu(x, alpha=0.01):
out = np.copy(x)
out[out < 0] *= alpha
return out
class TestLeakyRelu(TestActivation):
def get_alpha(self):
return 0.02
def setUp(self):
self.op_type = "leaky_relu"
self.init_dtype()
alpha = self.get_alpha()
np.random.seed(10)
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
out = np.maximum(x, 0.02 * x)
x[np.abs(x) < 0.005] = 0.05
out = ref_leaky_relu(x, alpha)
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {'alpha': alpha}
def test_check_grad(self):
if self.dtype == np.float16:
......@@ -922,18 +934,78 @@ class TestLeakyRelu(TestActivation):
self.check_grad(['X'], 'Out')
class TestLeakyReluOpError(unittest.TestCase):
class TestLeakyReluAlpha1(TestLeakyRelu):
def get_alpha(self):
return 2
class TestLeakyReluAlpha2(TestLeakyRelu):
def get_alpha(self):
return -0.01
class TestLeakyReluAlpha3(TestLeakyRelu):
def get_alpha(self):
return -2.0
class TestLeakyReluAPI(unittest.TestCase):
# test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu,
# fluid.layers.leaky_relu
def setUp(self):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', [10, 12])
out1 = F.leaky_relu(x)
m = paddle.nn.LeakyReLU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
out_ref = ref_leaky_relu(self.x_np)
for r in res:
self.assertEqual(np.allclose(out_ref, r), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_variable(self.x_np)
out1 = F.leaky_relu(x)
m = paddle.nn.LeakyReLU()
out2 = m(x)
out_ref = ref_leaky_relu(self.x_np)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
out1 = F.leaky_relu(x, 0.6)
m = paddle.nn.LeakyReLU(0.6)
out2 = m(x)
out_ref = ref_leaky_relu(self.x_np, 0.6)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
paddle.enable_static()
def test_fluid_api(self):
with fluid.program_guard(fluid.Program()):
x = fluid.data('X', [10, 12])
out = fluid.layers.leaky_relu(x, 0.01)
exe = fluid.Executor(self.place)
res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
out_ref = ref_leaky_relu(self.x_np)
self.assertEqual(np.allclose(out_ref, res[0]), True)
def test_errors(self):
with program_guard(Program()):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, fluid.layers.leaky_relu, 1)
self.assertRaises(TypeError, F.leaky_relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = fluid.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, fluid.layers.leaky_relu, x_int32)
# support the input dtype is float32
x_fp16 = fluid.layers.data(
name='x_fp16', shape=[12, 10], dtype='float32')
fluid.layers.leaky_relu(x_fp16)
x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.leaky_relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
F.leaky_relu(x_fp16)
def gelu(x, approximate):
......
......@@ -316,21 +316,6 @@ class TestLayer(LayerTest):
self.assertTrue(np.allclose(static_ret, dy_ret_value))
def test_leakyrelu(self):
inputs = np.random.uniform(-1, 1, (10, 10)).astype('float32')
with self.static_graph():
t = layers.data(name='t', shape=[10, 10], dtype='float32')
ret = layers.leaky_relu(t, alpha=0.01)
static_ret = self.get_static_graph_result(
feed={'t': inputs}, fetch_list=[ret])[0]
with self.dynamic_graph():
lrelu = paddle.nn.LeakyReLU(alpha=0.01)
dy_ret = lrelu(base.to_variable(inputs))
dy_ret_value = dy_ret.numpy()
self.assertTrue(np.allclose(static_ret, dy_ret_value))
def test_pad2d(self):
with self.static_graph():
t = layers.data(name='t', shape=[-1, 3, 5, 5], dtype='float32')
......@@ -2678,13 +2663,6 @@ class TestBook(LayerTest):
out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
return (out)
def make_leaky_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = self._get_data(name="input", shape=[16], dtype="float32")
out = layers.leaky_relu(input, alpha=0.1, name='leaky_relu')
return (out)
def make_soft_relu(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
......@@ -17,7 +17,6 @@ from ...fluid.layers import brelu #DEFINE_ALIAS
from ...fluid.layers import erf #DEFINE_ALIAS
from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS
from ...fluid.layers import hard_swish #DEFINE_ALIAS
from ...fluid.layers import leaky_relu #DEFINE_ALIAS
from ...fluid.layers import maxout #DEFINE_ALIAS
from ...fluid.layers import soft_relu #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS
......@@ -386,6 +385,57 @@ def hsigmoid(input,
return out
def leaky_relu(x, negative_slope=0.01, name=None):
"""
leaky_relu activation
.. math:
leaky_relu(x)=
\left\{
\begin{aligned}
&x, & & if \ x >= 0 \\
&negative\_slope * x, & & otherwise \\
\end{aligned}
\right. \\
Args:
x (Tensor): The input Tensor with data type float32, float64.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-2, 0, 1]))
out = F.leaky_relu(x) # [-0.02, 0., 1.]
"""
if in_dygraph_mode():
return core.ops.leaky_relu(x, 'alpha', negative_slope)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'leaky_relu')
helper = LayerHelper('leaky_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='leaky_relu',
inputs={'X': x},
outputs={'Out': out},
attrs={'alpha': negative_slope})
return out
def prelu(x, weight, name=None):
"""
prelu activation.
......
......@@ -558,11 +558,17 @@ class LeakyReLU(layers.Layer):
.. math:
out = max(x, alpha * x)
LeakyReLU(x)=
\left\{
\begin{aligned}
&x, & & if \ x >= 0 \\
&negative\_slope * x, & & otherwise \\
\end{aligned}
\right. \\
Parameters:
alpha (float, optional): Slope of the activation function at :math:`x < 0` .
Default: 0.01.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
......@@ -573,23 +579,23 @@ class LeakyReLU(layers.Layer):
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle
import numpy as np
paddle.disable_static()
paddle.disable_static()
lrelu = paddle.nn.LeakyReLU()
x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'))
out = lrelu(x) # [-0.02, 0., 1.]
m = paddle.nn.LeakyReLU()
x = paddle.to_tensor(np.array([-2, 0, 1]))
out = m(x) # [-0.02, 0., 1.]
"""
def __init__(self, alpha=1e-2, name=None):
def __init__(self, negative_slope=0.01, name=None):
super(LeakyReLU, self).__init__()
self._alpha = alpha
self._negative_slope = negative_slope
self._name = name
def forward(self, x):
return F.leaky_relu(x, self._alpha, self._name)
return F.leaky_relu(x, self._negative_slope, self._name)
class Sigmoid(layers.Layer):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册