Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
f8863e06
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8863e06
编写于
8月 22, 2020
作者:
Z
zhupengyang
提交者:
GitHub
8月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
leaky_relu and LeakyReLU: alpha->negative_slope (#26216)
上级
c6090660
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
187 addition
and
92 deletion
+187
-92
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+4
-4
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+16
-12
paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
+12
-12
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-17
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+84
-12
python/paddle/fluid/tests/unittests/test_layers.py
python/paddle/fluid/tests/unittests/test_layers.py
+0
-22
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+51
-1
python/paddle/nn/layer/activation.py
python/paddle/nn/layer/activation.py
+18
-12
未找到文件。
paddle/fluid/operators/activation_op.cc
浏览文件 @
f8863e06
...
...
@@ -781,8 +781,8 @@ class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};
// leaky_relu Grad: dx=dy if
y
>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if
y
>=0 else alpha * ddx
// leaky_relu Grad: dx=dy if
x
>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if
x
>=0 else alpha * ddx
template
<
typename
T
>
class
LeakyReluDoubleGradMaker
:
public
::
paddle
::
framework
::
SingleGradOpMaker
<
T
>
{
...
...
@@ -792,8 +792,8 @@ class LeakyReluDoubleGradMaker
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"leaky_relu_grad_grad"
);
// input1:
Out
op
->
SetInput
(
"
Out"
,
this
->
Input
(
"Out
"
));
// input1:
X
op
->
SetInput
(
"
X"
,
this
->
Input
(
"X
"
));
// X@GRAD@GRAD: ddx
op
->
SetInput
(
"DDX"
,
this
->
OutputGrad
(
framework
::
GradVarName
(
"X"
)));
op
->
SetAttrMap
(
this
->
Attrs
());
...
...
paddle/fluid/operators/activation_op.h
浏览文件 @
f8863e06
...
...
@@ -1084,7 +1084,11 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
const
{
out
.
device
(
d
)
=
x
.
cwiseMax
(
static_cast
<
T
>
(
alpha
)
*
x
);
if
(
alpha
<
1.
f
)
{
out
.
device
(
d
)
=
x
.
cwiseMax
(
static_cast
<
T
>
(
alpha
)
*
x
);
}
else
{
out
.
device
(
d
)
=
x
.
cwiseMin
(
static_cast
<
T
>
(
alpha
)
*
x
);
}
}
};
...
...
@@ -1098,12 +1102,12 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
const
{
auto
temp1
=
static_cast
<
T
>
(
alpha
)
*
(
out
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
auto
temp2
=
(
out
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
static_cast
<
T
>
(
alpha
)
*
(
x
<
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
auto
temp2
=
(
x
>=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
dx
.
device
(
d
)
=
dout
*
(
temp1
+
temp2
).
template
cast
<
T
>();
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDep
Out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDep
X
;
}
};
template
<
typename
T
>
...
...
@@ -1451,18 +1455,18 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
auto
*
d
=
dev
.
eigen_device
();
auto
ddx
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddX
,
"Input"
,
"DDX"
,
"LeakyReluGradGrad"
));
auto
out
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
Out
,
"Output"
,
"Out
"
,
"LeakyReluGradGrad"
));
auto
x
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
X
,
"Input"
,
"X
"
,
"LeakyReluGradGrad"
));
auto
ddout
=
framework
::
EigenVector
<
T
>::
Flatten
(
GET_DATA_SAFELY
(
ddOut
,
"Output"
,
"DOut"
,
"LeakyReluGradGrad"
));
ddout
.
device
(
*
d
)
=
ddx
*
((
out
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
+
static_cast
<
T
>
(
alpha
)
*
(
out
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>())
.
template
cast
<
T
>();
ddout
.
device
(
*
d
)
=
ddx
*
((
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
+
static_cast
<
T
>
(
alpha
)
*
(
x
<=
static_cast
<
T
>
(
0
)).
template
cast
<
T
>())
.
template
cast
<
T
>();
}
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDep
Out
;
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDep
X
;
}
};
template
<
typename
T
>
...
...
paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h
浏览文件 @
f8863e06
...
...
@@ -41,12 +41,12 @@ static void InitRandom(framework::Tensor *tensor,
template
<
typename
T
>
struct
LeakyReluGradGradEachElementFunctor
{
LeakyReluGradGradEachElementFunctor
(
const
T
*
ddx
,
const
T
*
out
,
T
alpha
,
LeakyReluGradGradEachElementFunctor
(
const
T
*
ddx
,
const
T
*
x
,
T
alpha
,
T
*
ddout
)
:
ddx_
(
ddx
),
out_
(
out
),
alpha_
(
alpha
),
ddout_
(
ddout
)
{}
:
ddx_
(
ddx
),
x_
(
x
),
alpha_
(
alpha
),
ddout_
(
ddout
)
{}
HOSTDEVICE
void
operator
()(
int
idx
)
{
if
(
out_
[
idx
]
>
0
)
{
if
(
x_
[
idx
]
>=
0
)
{
ddout_
[
idx
]
=
ddx_
[
idx
];
}
else
{
ddout_
[
idx
]
=
ddx_
[
idx
]
*
alpha_
;
...
...
@@ -54,7 +54,7 @@ struct LeakyReluGradGradEachElementFunctor {
}
const
T
*
ddx_
;
const
T
*
out
_
;
const
T
*
x
_
;
T
alpha_
;
T
*
ddout_
;
};
...
...
@@ -66,13 +66,13 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
LeakyReluGradGradFunctor
<
T
>
functor
;
functor
.
alpha
=
alpha
;
auto
&
dev_ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
);
framework
::
Tensor
*
x
=
nullptr
;
framework
::
Tensor
*
out
=
nullptr
;
framework
::
Tensor
*
dout
=
nullptr
;
framework
::
Tensor
*
dx
=
nullptr
;
framework
::
Tensor
out
;
out
.
Resize
(
dim
);
InitRandom
<
T
>
(
&
out
,
place
);
framework
::
Tensor
x
;
x
.
Resize
(
dim
);
InitRandom
<
T
>
(
&
x
,
place
);
framework
::
Tensor
ddx
;
ddx
.
Resize
(
dim
);
...
...
@@ -85,22 +85,22 @@ static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
framework
::
Tensor
ddout_actual
;
ddout_actual
.
mutable_data
<
T
>
(
dim
,
place
);
LeakyReluGradGradEachElementFunctor
<
T
>
actual_functor
(
ddx
.
data
<
T
>
(),
out
.
data
<
T
>
(),
static_cast
<
T
>
(
alpha
),
ddx
.
data
<
T
>
(),
x
.
data
<
T
>
(),
static_cast
<
T
>
(
alpha
),
ddout_actual
.
data
<
T
>
());
int64_t
limit
=
out
.
numel
();
int64_t
limit
=
x
.
numel
();
#ifdef __NVCC__
if
(
platform
::
is_gpu_place
(
place
))
{
auto
&
cuda_dev_ctx
=
dynamic_cast
<
platform
::
CUDADeviceContext
&>
(
dev_ctx
);
functor
(
cuda_dev_ctx
,
x
,
&
out
,
&
ddx
,
&
ddout
,
dout
,
dx
);
functor
(
cuda_dev_ctx
,
&
x
,
out
,
&
ddx
,
&
ddout
,
dout
,
dx
);
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
cuda_dev_ctx
,
limit
);
for_range
(
actual_functor
);
}
else
{
#endif
auto
&
cpu_dev_ctx
=
dynamic_cast
<
platform
::
CPUDeviceContext
&>
(
dev_ctx
);
functor
(
cpu_dev_ctx
,
x
,
&
out
,
&
ddx
,
&
ddout
,
dout
,
dx
);
functor
(
cpu_dev_ctx
,
&
x
,
out
,
&
ddx
,
&
ddout
,
dout
,
dx
);
platform
::
ForRange
<
platform
::
CPUDeviceContext
>
for_range
(
cpu_dev_ctx
,
limit
);
for_range
(
actual_functor
);
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
f8863e06
...
...
@@ -9772,13 +9772,10 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.leaky_relu")
@templatedoc()
def leaky_relu(x, alpha=0.02, name=None):
"""
:alias_main: paddle.nn.functional.leaky_relu
:alias: paddle.nn.functional.leaky_relu,paddle.nn.functional.activation.leaky_relu
:old_api: paddle.fluid.layers.leaky_relu
${comment}
Args:
x(${x_type}): ${x_comment}
...
...
@@ -9807,19 +9804,7 @@ def leaky_relu(x, alpha=0.02, name=None):
res_val, = exe.run(fluid.default_main_program(), feed={'x':x_i}, fetch_list=[res])
print(res_val) # [[-0.1, 2], [3, -0.4]]
"""
if in_dygraph_mode():
return core.ops.leaky_relu(x, 'alpha', alpha)
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'leaky_relu')
inputs = {'X': [x]}
attrs = {'alpha': alpha}
helper = LayerHelper('leaky_relu', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='leaky_relu', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
return paddle.nn.functional.leaky_relu(x, alpha, name)
def soft_relu(x, threshold=40.0, name=None):
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
f8863e06
...
...
@@ -903,18 +903,30 @@ class TestReluAPI(unittest.TestCase):
F
.
relu
(
x_fp16
)
def
ref_leaky_relu
(
x
,
alpha
=
0.01
):
out
=
np
.
copy
(
x
)
out
[
out
<
0
]
*=
alpha
return
out
class
TestLeakyRelu
(
TestActivation
):
def
get_alpha
(
self
):
return
0.02
def
setUp
(
self
):
self
.
op_type
=
"leaky_relu"
self
.
init_dtype
()
alpha
=
self
.
get_alpha
()
np
.
random
.
seed
(
10
)
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
# The same reason with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.0
2
out
=
np
.
maximum
(
x
,
0.02
*
x
)
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.0
5
out
=
ref_leaky_relu
(
x
,
alpha
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)
}
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
attrs
=
{
'alpha'
:
alpha
}
def
test_check_grad
(
self
):
if
self
.
dtype
==
np
.
float16
:
...
...
@@ -922,18 +934,78 @@ class TestLeakyRelu(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestLeakyReluOpError
(
unittest
.
TestCase
):
class
TestLeakyReluAlpha1
(
TestLeakyRelu
):
def
get_alpha
(
self
):
return
2
class
TestLeakyReluAlpha2
(
TestLeakyRelu
):
def
get_alpha
(
self
):
return
-
0.01
class
TestLeakyReluAlpha3
(
TestLeakyRelu
):
def
get_alpha
(
self
):
return
-
2.0
class
TestLeakyReluAPI
(
unittest
.
TestCase
):
# test paddle.nn.LeakyReLU, paddle.nn.functional.leaky_relu,
# fluid.layers.leaky_relu
def
setUp
(
self
):
self
.
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
'float32'
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
[
10
,
12
])
out1
=
F
.
leaky_relu
(
x
)
m
=
paddle
.
nn
.
LeakyReLU
()
out2
=
m
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_leaky_relu
(
self
.
x_np
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_variable
(
self
.
x_np
)
out1
=
F
.
leaky_relu
(
x
)
m
=
paddle
.
nn
.
LeakyReLU
()
out2
=
m
(
x
)
out_ref
=
ref_leaky_relu
(
self
.
x_np
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
out1
=
F
.
leaky_relu
(
x
,
0.6
)
m
=
paddle
.
nn
.
LeakyReLU
(
0.6
)
out2
=
m
(
x
)
out_ref
=
ref_leaky_relu
(
self
.
x_np
,
0.6
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
[
10
,
12
])
out
=
fluid
.
layers
.
leaky_relu
(
x
,
0.01
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_leaky_relu
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
p
rogram_guard
(
Program
()):
with
p
addle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
leaky_relu
,
1
)
self
.
assertRaises
(
TypeError
,
F
.
leaky_relu
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
fluid
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
leaky_relu
,
x_int32
)
# support the input dtype is float32
x_fp16
=
fluid
.
layers
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float32'
)
fluid
.
layers
.
leaky_relu
(
x_fp16
)
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
leaky_relu
,
x_int32
)
# support the input dtype is float16
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
leaky_relu
(
x_fp16
)
def
gelu
(
x
,
approximate
):
...
...
python/paddle/fluid/tests/unittests/test_layers.py
浏览文件 @
f8863e06
...
...
@@ -316,21 +316,6 @@ class TestLayer(LayerTest):
self
.
assertTrue
(
np
.
allclose
(
static_ret
,
dy_ret_value
))
def
test_leakyrelu
(
self
):
inputs
=
np
.
random
.
uniform
(
-
1
,
1
,
(
10
,
10
)).
astype
(
'float32'
)
with
self
.
static_graph
():
t
=
layers
.
data
(
name
=
't'
,
shape
=
[
10
,
10
],
dtype
=
'float32'
)
ret
=
layers
.
leaky_relu
(
t
,
alpha
=
0.01
)
static_ret
=
self
.
get_static_graph_result
(
feed
=
{
't'
:
inputs
},
fetch_list
=
[
ret
])[
0
]
with
self
.
dynamic_graph
():
lrelu
=
paddle
.
nn
.
LeakyReLU
(
alpha
=
0.01
)
dy_ret
=
lrelu
(
base
.
to_variable
(
inputs
))
dy_ret_value
=
dy_ret
.
numpy
()
self
.
assertTrue
(
np
.
allclose
(
static_ret
,
dy_ret_value
))
def
test_pad2d
(
self
):
with
self
.
static_graph
():
t
=
layers
.
data
(
name
=
't'
,
shape
=
[
-
1
,
3
,
5
,
5
],
dtype
=
'float32'
)
...
...
@@ -2678,13 +2663,6 @@ class TestBook(LayerTest):
out
=
layers
.
brelu
(
input
,
t_min
=
1.0
,
t_max
=
20.0
,
name
=
'brelu'
)
return
(
out
)
def
make_leaky_relu
(
self
):
with
program_guard
(
fluid
.
default_main_program
(),
fluid
.
default_startup_program
()):
input
=
self
.
_get_data
(
name
=
"input"
,
shape
=
[
16
],
dtype
=
"float32"
)
out
=
layers
.
leaky_relu
(
input
,
alpha
=
0.1
,
name
=
'leaky_relu'
)
return
(
out
)
def
make_soft_relu
(
self
):
with
program_guard
(
fluid
.
default_main_program
(),
fluid
.
default_startup_program
()):
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
f8863e06
...
...
@@ -17,7 +17,6 @@ from ...fluid.layers import brelu #DEFINE_ALIAS
from
...fluid.layers
import
erf
#DEFINE_ALIAS
from
...fluid.layers
import
hard_sigmoid
#DEFINE_ALIAS
from
...fluid.layers
import
hard_swish
#DEFINE_ALIAS
from
...fluid.layers
import
leaky_relu
#DEFINE_ALIAS
from
...fluid.layers
import
maxout
#DEFINE_ALIAS
from
...fluid.layers
import
soft_relu
#DEFINE_ALIAS
from
...fluid.layers
import
swish
#DEFINE_ALIAS
...
...
@@ -386,6 +385,57 @@ def hsigmoid(input,
return
out
def
leaky_relu
(
x
,
negative_slope
=
0.01
,
name
=
None
):
"""
leaky_relu activation
.. math:
leaky_relu(x)=
\left\{
\b
egin{aligned}
&x, & & if \ x >= 0
\\
&negative\_slope * x, & & otherwise
\\
\end{aligned}
\r
ight.
\\
Args:
x (Tensor): The input Tensor with data type float32, float64.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is 0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-2, 0, 1]))
out = F.leaky_relu(x) # [-0.02, 0., 1.]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
leaky_relu
(
x
,
'alpha'
,
negative_slope
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'leaky_relu'
)
helper
=
LayerHelper
(
'leaky_relu'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'leaky_relu'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'alpha'
:
negative_slope
})
return
out
def
prelu
(
x
,
weight
,
name
=
None
):
"""
prelu activation.
...
...
python/paddle/nn/layer/activation.py
浏览文件 @
f8863e06
...
...
@@ -558,11 +558,17 @@ class LeakyReLU(layers.Layer):
.. math:
out = max(x, alpha * x)
LeakyReLU(x)=
\left\{
\b
egin{aligned}
&x, & & if \ x >= 0
\\
&negative\_slope * x, & & otherwise
\\
\end{aligned}
\r
ight.
\\
Parameters:
alpha (float, optional): Slope of the activation function at :math:`x < 0` .
Default:
0.01.
negative_slope (float, optional): Slope of the activation function at
:math:`x < 0` . Default is
0.01.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
...
...
@@ -573,23 +579,23 @@ class LeakyReLU(layers.Layer):
Examples:
.. code-block:: python
import paddle
import numpy as np
import paddle
import numpy as np
paddle.disable_static()
paddle.disable_static()
lrelu
= paddle.nn.LeakyReLU()
x = paddle.to_tensor(np.array([-2, 0, 1], 'float32'
))
out = lrelu
(x) # [-0.02, 0., 1.]
m
= paddle.nn.LeakyReLU()
x = paddle.to_tensor(np.array([-2, 0, 1]
))
out = m
(x) # [-0.02, 0., 1.]
"""
def
__init__
(
self
,
alpha
=
1e-2
,
name
=
None
):
def
__init__
(
self
,
negative_slope
=
0.01
,
name
=
None
):
super
(
LeakyReLU
,
self
).
__init__
()
self
.
_
alpha
=
alpha
self
.
_
negative_slope
=
negative_slope
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
leaky_relu
(
x
,
self
.
_
alpha
,
self
.
_name
)
return
F
.
leaky_relu
(
x
,
self
.
_
negative_slope
,
self
.
_name
)
class
Sigmoid
(
layers
.
Layer
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录