Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
40d193ed
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
40d193ed
编写于
8月 20, 2020
作者:
H
hong19860320
提交者:
GitHub
8月 20, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the ReLU6, Tanhshrink, SELU, Softplus, Softshrink and Softsign for the api 2.0 (#26376)
上级
6e13e86a
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
993 addition
and
116 deletion
+993
-116
paddle/fluid/operators/activation_op.cc
paddle/fluid/operators/activation_op.cc
+30
-8
paddle/fluid/operators/activation_op.h
paddle/fluid/operators/activation_op.h
+27
-13
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-7
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+22
-20
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+302
-42
python/paddle/fluid/tests/unittests/test_selu_op.py
python/paddle/fluid/tests/unittests/test_selu_op.py
+71
-18
python/paddle/nn/__init__.py
python/paddle/nn/__init__.py
+6
-0
python/paddle/nn/functional/__init__.py
python/paddle/nn/functional/__init__.py
+1
-1
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+280
-7
python/paddle/nn/layer/activation.py
python/paddle/nn/layer/activation.py
+252
-0
未找到文件。
paddle/fluid/operators/activation_op.cc
浏览文件 @
40d193ed
...
...
@@ -317,13 +317,6 @@ $$out = x^2$$
)DOC"
;
UNUSED
constexpr
char
SoftplusDoc
[]
=
R"DOC(
Softplus Activation Operator.
$$out = \ln(1 + e^{x})$$
)DOC"
;
UNUSED
constexpr
char
SoftsignDoc
[]
=
R"DOC(
Softsign Activation Operator.
...
...
@@ -396,6 +389,36 @@ $$out = \max(x, \alpha * x)$$
}
};
class
SoftplusOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"Input of Softplus operator, an N-D Tensor, with data type "
"float32, float64 or float16."
);
AddOutput
(
"Out"
,
"Output of Softplus operator, a Tensor with shape same as input."
);
AddAttr
<
float
>
(
"beta"
,
"The value of beta for Softplus."
).
SetDefault
(
1.0
f
);
AddAttr
<
float
>
(
"threshold"
,
"The value of threshold for Softplus."
)
.
SetDefault
(
20.0
f
);
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Only used in mkldnn kernel."
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"use_cudnn"
,
"(bool, default false) Only used in cudnn kernel, need install cudnn."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
:strong:`Softplus Activation Operator`
.. math::
out = \frac{1}{\beta} * \log(1 + \exp(\beta * x)) \\
\text{For numerical stability, the implementation reverts to the linear function when :}\,x \times \beta > threshold.
)DOC"
);
}
};
class
SoftShrinkOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
...
...
@@ -672,7 +695,6 @@ REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER
(
Log
,
LogDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Log1p
,
Log1pDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Square
,
SquareDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softplus
,
SoftplusDoc
);
REGISTER_ACTIVATION_OP_MAKER
(
Softsign
,
SoftsignDoc
);
template
<
ActBwdOpFwdDeps
kDepValue
>
...
...
paddle/fluid/operators/activation_op.h
浏览文件 @
40d193ed
...
...
@@ -975,32 +975,46 @@ struct HardSwishGradFunctor : public BaseActivationFunctor<T> {
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
};
// softplus(x) = log(1 + exp(x))
// When x is a very large positive number, exp(x) may explode to inf,
// Using trick below for numerical stability
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
// For numerical stability, using the following formula instead of softplus(x) =
// log(1 + exp(x))
// softplus(x) = log(1 + exp(beta * x)) / beta when beta * x <= threshold(beta =
// 1, threshold = 20 by default), otherwise x
template
<
typename
T
>
struct
SoftplusFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
},
{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
>
void
operator
()(
Device
d
,
X
x
,
Out
out
)
{
auto
temp
=
x
.
cwiseMax
(
static_cast
<
T
>
(
0
));
// temp = max(x, 0)
out
.
device
(
d
)
=
temp
+
(((
-
temp
).
exp
()
+
(
x
-
temp
).
exp
()).
log
());
auto
x_beta
=
static_cast
<
T
>
(
beta
)
*
x
;
out
.
device
(
d
)
=
(
x_beta
>
static_cast
<
T
>
(
threshold
))
.
select
(
x
,
(
static_cast
<
T
>
(
1
)
+
x_beta
.
exp
()).
log
()
/
static_cast
<
T
>
(
beta
));
}
};
//
d(softplus(x))/dx = exp(x) / (1 + exp(x))
//
For numerical stability:
// d(softplus(x))/dx =
exp(x - max(x, 0)) / (exp(-max(x, 0)) +
//
exp(x - max(x, 0)))
//
For numerical stability, using the following formula instead of
//
d(softplus(x))/dx = 1 / (1 + exp(-x))
// d(softplus(x))/dx =
1 / (1 + exp(-beta * x)) when beta * x <= threshold(beta
//
= 1, threshold = 20 by default), otherwise x
template
<
typename
T
>
struct
SoftplusGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
beta
;
float
threshold
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"beta"
,
&
beta
},
{
"threshold"
,
&
threshold
}};
}
template
<
typename
Device
,
typename
X
,
typename
Out
,
typename
dOut
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Out
out
,
dOut
dout
,
dX
dx
)
{
auto
temp
=
x
.
cwiseMax
(
static_cast
<
T
>
(
0
));
// temp = max(x, 0)
auto
x_beta
=
static_cast
<
T
>
(
beta
)
*
x
;
dx
.
device
(
d
)
=
dout
*
((
x
-
temp
).
exp
()
/
((
-
temp
).
exp
()
+
(
x
-
temp
).
exp
()));
(
x_beta
>
static_cast
<
T
>
(
threshold
))
.
select
(
dout
,
dout
/
(
static_cast
<
T
>
(
1
)
+
(
-
x_beta
).
exp
()));
}
static
constexpr
ActBwdOpFwdDeps
FwdDeps
()
{
return
kDepX
;
}
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
40d193ed
...
...
@@ -8643,11 +8643,9 @@ def relu(x, name=None):
return out
@deprecated(since="2.0.0", update_to="paddle.nn.functional.selu")
def selu(x, scale=None, alpha=None, name=None):
"""
:alias_main: paddle.nn.functional.selu
:alias: paddle.nn.functional.selu,paddle.nn.functional.activation.selu
:old_api: paddle.fluid.layers.selu
Selu Operator.
...
...
@@ -9304,12 +9302,9 @@ def elu(x, alpha=1.0, name=None):
return out
@
templatedoc(
)
@
deprecated(since="2.0.0", update_to="paddle.nn.functional.relu6"
)
def relu6(x, threshold=6.0, name=None):
"""
:alias_main: paddle.nn.functional.relu6
:alias: paddle.nn.functional.relu6,paddle.nn.functional.activation.relu6
:old_api: paddle.fluid.layers.relu6
${comment}
...
...
python/paddle/fluid/layers/ops.py
浏览文件 @
40d193ed
...
...
@@ -20,6 +20,8 @@ from ..framework import convert_np_dtype_to_dtype_, Variable
from
..data_feeder
import
convert_dtype
,
check_variable_and_dtype
,
check_type
,
check_dtype
from
paddle.utils
import
deprecated
__deprecated_func_name__
=
{
'tanh_shrink'
:
'tanhshrink'
,
}
__activations_noattr__
=
[
'sigmoid'
,
'logsigmoid'
,
...
...
@@ -64,14 +66,20 @@ __all__ += __activations_noattr__
__all__
+=
__unary_func__
for
_OP
in
set
(
__activations_noattr__
):
_new_OP
=
_OP
if
_OP
in
__deprecated_func_name__
:
_new_OP
=
__deprecated_func_name__
[
_OP
]
func
=
generate_activation_fn
(
_OP
)
func
=
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.nn.functional.%s"
%
(
_OP
))(
func
)
since
=
"2.0.0"
,
update_to
=
"paddle.nn.functional.%s"
%
(
_
new_
OP
))(
func
)
globals
()[
_OP
]
=
func
for
_OP
in
set
(
__unary_func__
):
_new_OP
=
_OP
if
_OP
in
__deprecated_func_name__
:
_new_OP
=
__deprecated_func_name__
[
_OP
]
func
=
generate_activation_fn
(
_OP
)
func
=
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.%s"
%
(
_OP
))(
func
)
func
=
deprecated
(
since
=
"2.0.0"
,
update_to
=
"paddle.%s"
%
(
_
new_
OP
))(
func
)
globals
()[
_OP
]
=
func
add_sample_code
(
globals
()[
"sigmoid"
],
r
"""
...
...
@@ -160,16 +168,14 @@ add_sample_code(globals()["tanh_shrink"], r"""
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x_data = np.array([-0.4, -0.2, 0.1, 0.3])
x = paddle.to_variable(x_data)
out = F.tanh_shrink(x)
print(out.numpy())
# [-0.02005104 -0.00262468 0.00033201 0.00868739]
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
"""
)
...
...
@@ -401,16 +407,14 @@ add_sample_code(globals()["softplus"], r"""
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x_data = np.array([-0.4, -0.2, 0.1, 0.3])
x = paddle.to_variable(x_data)
out = F.softplus(x)
print(out.numpy())
# [0.51301525 0.59813887 0.74439666 0.85435524]
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
"""
)
...
...
@@ -418,16 +422,14 @@ add_sample_code(globals()["softsign"], r"""
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x_data = np.array([-0.4, -0.2, 0.1, 0.3])
x = paddle.to_variable(x_data)
out = F.softsign(x)
print(out.numpy())
# [-0.28571429 -0.16666667 0.09090909 0.23076923]
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
)
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
40d193ed
...
...
@@ -369,15 +369,20 @@ class TestCoshOpError(unittest.TestCase):
fluid
.
layers
.
cosh
(
x_fp16
)
class
TestTanhShrink
(
TestActivation
):
def
ref_tanhshrink
(
x
):
out
=
x
-
np
.
tanh
(
x
)
return
out
class
TestTanhshrink
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"tanh_shrink"
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
10
,
17
]).
astype
(
self
.
dtype
)
out
=
x
-
np
.
tanh
(
x
)
x
=
np
.
random
.
uniform
(
10
,
20
,
[
10
,
17
]).
astype
(
self
.
dtype
)
out
=
ref_tanhshrink
(
x
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)
}
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -386,6 +391,57 @@ class TestTanhShrink(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestTanhshrinkAPI
(
unittest
.
TestCase
):
# test paddle.nn.Tanhshrink, paddle.nn.functional.tanhshrink
def
setUp
(
self
):
self
.
x_np
=
np
.
random
.
uniform
(
10
,
20
,
[
10
,
17
]).
astype
(
np
.
float64
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
tanhshrink
(
x
)
tanhshrink
=
paddle
.
nn
.
Tanhshrink
()
out2
=
tanhshrink
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_tanhshrink
(
self
.
x_np
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
tanhshrink
(
x
)
tanhshrink
=
paddle
.
nn
.
Tanhshrink
()
out2
=
tanhshrink
(
x
)
out_ref
=
ref_tanhshrink
(
self
.
x_np
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
tanh_shrink
(
x
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_tanhshrink
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
F
.
tanhshrink
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
tanhshrink
,
x_int32
)
# support the input dtype is float16
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
tanhshrink
(
x_fp16
)
def
ref_hardshrink
(
x
,
threshold
):
out
=
np
.
copy
(
x
)
out
[(
out
>=
-
threshold
)
&
(
out
<=
threshold
)]
=
0
...
...
@@ -469,19 +525,24 @@ class TestHardShrinkAPI(unittest.TestCase):
F
.
hardshrink
(
x_fp16
)
class
TestSoftShrink
(
TestActivation
):
def
ref_softshrink
(
x
,
threshold
=
0.5
):
out
=
np
.
copy
(
x
)
out
=
(
out
<
-
threshold
)
*
(
out
+
threshold
)
+
(
out
>
threshold
)
*
(
out
-
threshold
)
return
out
class
TestSoftshrink
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"softshrink"
self
.
init_dtype
()
lambda_val
=
0.1
x
=
np
.
random
.
uniform
(
0.25
,
10
,
[
10
,
12
]).
astype
(
self
.
dtype
)
out
=
np
.
copy
(
x
)
out
=
(
out
<
-
lambda_val
)
*
(
out
+
lambda_val
)
+
(
out
>
lambda_val
)
*
(
out
-
lambda_val
)
threshold
=
0.8
self
.
attrs
=
{
'lambda'
:
lambda_val
}
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
x
=
np
.
random
.
uniform
(
0.25
,
10
,
[
10
,
12
]).
astype
(
self
.
dtype
)
out
=
ref_softshrink
(
x
,
threshold
)
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
"lambda"
:
threshold
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -490,17 +551,56 @@ class TestSoftShrink(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestSoftShrinkOpError
(
unittest
.
TestCase
):
class
TestSoftshrinkAPI
(
unittest
.
TestCase
):
# test paddle.nn.Softshrink, paddle.nn.functional.softshrink
def
setUp
(
self
):
self
.
threshold
=
0.8
self
.
x_np
=
np
.
random
.
uniform
(
0.25
,
10
,
[
10
,
12
]).
astype
(
np
.
float64
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
softshrink
(
x
,
self
.
threshold
)
softshrink
=
paddle
.
nn
.
Softshrink
(
self
.
threshold
)
out2
=
softshrink
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_softshrink
(
self
.
x_np
,
self
.
threshold
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
softshrink
(
x
,
self
.
threshold
)
softshrink
=
paddle
.
nn
.
Softshrink
(
self
.
threshold
)
out2
=
softshrink
(
x
)
out_ref
=
ref_softshrink
(
self
.
x_np
,
self
.
threshold
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
softshrink
(
x
,
self
.
threshold
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_softshrink
(
self
.
x_np
,
self
.
threshold
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
p
rogram_guard
(
Program
()):
with
p
addle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
softshrink
,
1
)
self
.
assertRaises
(
TypeError
,
F
.
softshrink
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
fluid
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
softshrink
,
x_int32
)
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
softshrink
,
x_int32
)
# support the input dtype is float16
x_fp16
=
fluid
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
fluid
.
layers
.
softshrink
(
x_fp16
)
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
softshrink
(
x_fp16
)
class
TestSqrt
(
TestActivation
,
TestParameter
):
...
...
@@ -903,20 +1003,24 @@ class TestBReluOpError(unittest.TestCase):
fluid
.
layers
.
brelu
(
x_fp16
)
def
ref_relu6
(
x
,
threshold
=
6.0
):
out
=
np
.
copy
(
x
)
out
[
np
.
abs
(
x
-
threshold
)
<
0.005
]
=
threshold
+
0.02
out
=
np
.
minimum
(
np
.
maximum
(
x
,
0
),
threshold
)
return
out
class
TestRelu6
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"relu6"
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
1
,
10
,
[
10
,
12
]).
astype
(
self
.
dtype
)
threshold
=
6.0
# The same with TestAbs
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
x
[
np
.
abs
(
x
-
threshold
)
<
0.005
]
=
threshold
+
0.02
out
=
np
.
minimum
(
np
.
maximum
(
x
,
0
),
threshold
)
out
=
ref_relu6
(
x
)
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)
}
self
.
attrs
=
{
'threshold'
:
threshold
}
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'threshold'
:
6.0
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -925,17 +1029,56 @@ class TestRelu6(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestRelu6OpError
(
unittest
.
TestCase
):
class
TestRelu6API
(
unittest
.
TestCase
):
# test paddle.nn.ReLU6, paddle.nn.functional.relu6
def
setUp
(
self
):
self
.
x_np
=
np
.
random
.
uniform
(
-
1
,
10
,
[
10
,
12
]).
astype
(
np
.
float64
)
self
.
x_np
[
np
.
abs
(
self
.
x_np
)
<
0.005
]
=
0.02
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
relu6
(
x
)
relu6
=
paddle
.
nn
.
ReLU6
()
out2
=
relu6
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_relu6
(
self
.
x_np
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
relu6
(
x
)
relu6
=
paddle
.
nn
.
ReLU6
()
out2
=
relu6
(
x
)
out_ref
=
ref_relu6
(
self
.
x_np
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
relu6
(
x
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_relu6
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
p
rogram_guard
(
Program
()):
with
p
addle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
relu6
,
1
)
self
.
assertRaises
(
TypeError
,
F
.
relu6
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
fluid
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
relu6
,
x_int32
)
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
relu6
,
x_int32
)
# support the input dtype is float16
x_fp16
=
fluid
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
fluid
.
layers
.
relu6
(
x_fp16
)
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
relu6
(
x_fp16
)
class
TestHardSwish
(
TestActivation
):
...
...
@@ -1318,16 +1461,25 @@ class TestSTanhOpError(unittest.TestCase):
fluid
.
layers
.
stanh
(
x_fp16
)
def
ref_softplus
(
x
,
beta
=
1
,
threshold
=
20
):
x_beta
=
beta
*
x
out
=
np
.
select
([
x_beta
<=
threshold
,
x_beta
>
threshold
],
[
np
.
log
(
1
+
np
.
exp
(
x_beta
))
/
beta
,
x
])
return
out
class
TestSoftplus
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"softplus"
self
.
init_dtype
()
self
.
dtype
=
np
.
float64
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
out
=
np
.
log
(
1
+
np
.
exp
(
x
))
beta
=
2
threshold
=
15
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
self
.
dtype
)
out
=
ref_softplus
(
x
,
beta
,
threshold
)
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'beta'
:
beta
,
"threshold"
:
threshold
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -1336,15 +1488,72 @@ class TestSoftplus(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestSoftplusAPI
(
unittest
.
TestCase
):
# test paddle.nn.Softplus, paddle.nn.functional.softplus
def
setUp
(
self
):
self
.
beta
=
2
self
.
threshold
=
15
self
.
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
np
.
float64
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
softplus
(
x
,
self
.
beta
,
self
.
threshold
)
softplus
=
paddle
.
nn
.
Softplus
(
self
.
beta
,
self
.
threshold
)
out2
=
softplus
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_softplus
(
self
.
x_np
,
self
.
beta
,
self
.
threshold
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
softplus
(
x
,
self
.
beta
,
self
.
threshold
)
softplus
=
paddle
.
nn
.
Softplus
(
self
.
beta
,
self
.
threshold
)
out2
=
softplus
(
x
)
out_ref
=
ref_softplus
(
self
.
x_np
,
self
.
beta
,
self
.
threshold
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
softplus
(
x
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_softplus
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
F
.
softplus
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
softplus
,
x_int32
)
# support the input dtype is float16
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
softplus
(
x_fp16
)
def
ref_softsign
(
x
):
out
=
np
.
divide
(
x
,
1
+
np
.
abs
(
x
))
return
out
class
TestSoftsign
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"softsign"
self
.
init_dtype
()
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
11
,
17
]).
astype
(
self
.
dtype
)
out
=
np
.
divide
(
x
,
1
+
np
.
abs
(
x
))
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
x
)}
x
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
self
.
dtype
)
out
=
ref_softsign
(
x
)
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'Out'
:
out
}
def
test_check_grad
(
self
):
...
...
@@ -1353,6 +1562,57 @@ class TestSoftsign(TestActivation):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestSoftsignAPI
(
unittest
.
TestCase
):
# test paddle.nn.Softsign, paddle.nn.functional.softsign
def
setUp
(
self
):
self
.
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
np
.
float64
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
softsign
(
x
)
softsign
=
paddle
.
nn
.
Softsign
()
out2
=
softsign
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_softsign
(
self
.
x_np
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
softsign
(
x
)
softsign
=
paddle
.
nn
.
Softsign
()
out2
=
softsign
(
x
)
out_ref
=
ref_softsign
(
self
.
x_np
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
softsign
(
x
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_softsign
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
F
.
softsign
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
softsign
,
x_int32
)
# support the input dtype is float16
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
softsign
(
x_fp16
)
class
TestThresholdedRelu
(
TestActivation
):
def
setUp
(
self
):
self
.
op_type
=
"thresholded_relu"
...
...
@@ -1548,9 +1808,9 @@ create_test_act_fp16_class(TestActivation)
create_test_act_fp16_class
(
TestSigmoid
)
create_test_act_fp16_class
(
TestLogSigmoid
)
create_test_act_fp16_class
(
TestTanh
)
create_test_act_fp16_class
(
TestTanh
S
hrink
)
create_test_act_fp16_class
(
TestTanh
s
hrink
)
create_test_act_fp16_class
(
TestHardShrink
)
create_test_act_fp16_class
(
TestSoft
S
hrink
)
create_test_act_fp16_class
(
TestSoft
s
hrink
)
create_test_act_fp16_class
(
TestSqrt
)
create_test_act_fp16_class
(
TestAbs
)
create_test_act_fp16_class
(
TestCeil
,
grad_check
=
False
)
...
...
python/paddle/fluid/tests/unittests/test_selu_op.py
浏览文件 @
40d193ed
...
...
@@ -17,9 +17,26 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
import
six
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
import
paddle.nn
as
nn
import
paddle.nn.functional
as
F
from
paddle.fluid
import
compiler
,
Program
,
program_guard
def
ref_selu
(
x
,
scale
=
1.0507009873554804934193349852946
,
alpha
=
1.6732632423543772848170429916717
):
out
=
np
.
copy
(
x
)
out_flat
=
out
.
flatten
()
for
i
in
range
(
out_flat
.
size
):
if
out_flat
[
i
]
<
0
:
out_flat
[
i
]
=
alpha
*
np
.
exp
(
out_flat
[
i
])
-
alpha
out_flat
[
i
]
=
scale
*
out_flat
[
i
]
out
=
out_flat
.
reshape
(
x
.
shape
)
return
out
class
SeluTest
(
OpTest
):
...
...
@@ -39,17 +56,10 @@ class SeluTest(OpTest):
# zero.
x
[
np
.
abs
(
x
)
<
0.005
]
=
0.02
x_flat
=
x
.
flatten
()
for
i
in
range
(
x_flat
.
size
):
if
x_flat
[
i
]
<
0
:
x_flat
[
i
]
=
alpha
*
np
.
exp
(
x_flat
[
i
])
-
alpha
x_flat
[
i
]
=
scale
*
x_flat
[
i
]
out_np
=
x_flat
.
reshape
(
self
.
x_shape
)
out
=
ref_selu
(
x
,
scale
,
alpha
)
self
.
inputs
=
{
'X'
:
x
}
self
.
outputs
=
{
'Out'
:
out
_np
}
self
.
outputs
=
{
'Out'
:
out
}
self
.
attrs
=
{
'alpha'
:
alpha
,
...
...
@@ -69,17 +79,60 @@ class SeluTest(OpTest):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestSeluOpError
(
unittest
.
TestCase
):
class
TestSeluAPI
(
unittest
.
TestCase
):
# test paddle.nn.SELU, paddle.nn.functional.selu
def
setUp
(
self
):
self
.
scale
=
1.5
self
.
alpha
=
2.0
self
.
x_np
=
np
.
random
.
normal
(
size
=
[
3
,
5
,
5
,
10
]).
astype
(
np
.
float64
)
# Since zero point in selu is not differentiable, avoid randomize
# zero.
self
.
x_np
[
np
.
abs
(
self
.
x_np
)
<
0.005
]
=
0.02
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out1
=
F
.
selu
(
x
,
self
.
scale
,
self
.
alpha
)
selu
=
paddle
.
nn
.
SELU
(
self
.
scale
,
self
.
alpha
)
out2
=
selu
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
ref_selu
(
self
.
x_np
,
self
.
scale
,
self
.
alpha
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_tensor
(
self
.
x_np
)
out1
=
F
.
selu
(
x
,
self
.
scale
,
self
.
alpha
)
selu
=
paddle
.
nn
.
SELU
(
self
.
scale
,
self
.
alpha
)
out2
=
selu
(
x
)
out_ref
=
ref_selu
(
self
.
x_np
,
self
.
scale
,
self
.
alpha
)
for
r
in
[
out1
,
out2
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
self
.
x_np
.
shape
,
self
.
x_np
.
dtype
)
out
=
fluid
.
layers
.
selu
(
x
,
self
.
scale
,
self
.
alpha
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
ref_selu
(
self
.
x_np
,
self
.
scale
,
self
.
alpha
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
p
rogram_guard
(
Program
()):
with
p
addle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
selu
,
1
)
self
.
assertRaises
(
TypeError
,
F
.
selu
,
1
)
# The input dtype must be float16, float32, float64.
x_int32
=
fluid
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
fluid
.
layers
.
selu
,
x_int32
)
# support the input dtype is float
32
x_fp
32
=
fluid
.
data
(
name
=
'x_fp32'
,
shape
=
[
12
,
10
],
dtype
=
'float32
'
)
fluid
.
layers
.
selu
(
x_fp32
)
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
selu
,
x_int32
)
# support the input dtype is float
16
x_fp
16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16
'
)
F
.
selu
(
x_fp16
)
if
__name__
==
"__main__"
:
...
...
python/paddle/nn/__init__.py
浏览文件 @
40d193ed
...
...
@@ -57,10 +57,16 @@ from .layer.activation import GELU
from
.layer.activation
import
Hardshrink
# from .layer.activation import PReLU #DEFINE_ALIAS
from
.layer.activation
import
ReLU
from
.layer.activation
import
ReLU6
#DEFINE_ALIAS
from
.layer.activation
import
SELU
#DEFINE_ALIAS
from
.layer.activation
import
LeakyReLU
#DEFINE_ALIAS
from
.layer.activation
import
Sigmoid
#DEFINE_ALIAS
from
.layer.activation
import
LogSigmoid
# from .layer.activation import Softmax #DEFINE_ALIAS
from
.layer.activation
import
Softplus
#DEFINE_ALIAS
from
.layer.activation
import
Softshrink
#DEFINE_ALIAS
from
.layer.activation
import
Softsign
#DEFINE_ALIAS
from
.layer.activation
import
Tanhshrink
#DEFINE_ALIAS
from
.layer.activation
import
LogSoftmax
#DEFINE_ALIAS
from
.layer.activation
import
HSigmoid
#DEFINE_ALIAS
from
.layer.common
import
BilinearTensorProduct
#DEFINE_ALIAS
...
...
python/paddle/nn/functional/__init__.py
浏览文件 @
40d193ed
...
...
@@ -47,7 +47,7 @@ from .activation import softplus #DEFINE_ALIAS
from
.activation
import
softshrink
#DEFINE_ALIAS
from
.activation
import
softsign
#DEFINE_ALIAS
from
.activation
import
swish
#DEFINE_ALIAS
from
.activation
import
tanh
_
shrink
#DEFINE_ALIAS
from
.activation
import
tanhshrink
#DEFINE_ALIAS
from
.activation
import
thresholded_relu
#DEFINE_ALIAS
from
.activation
import
log_softmax
#DEFINE_ALIAS
from
.common
import
dropout
#DEFINE_ALIAS
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
40d193ed
...
...
@@ -19,15 +19,9 @@ from ...fluid.layers import hard_sigmoid #DEFINE_ALIAS
from
...fluid.layers
import
hard_swish
#DEFINE_ALIAS
from
...fluid.layers
import
leaky_relu
#DEFINE_ALIAS
from
...fluid.layers
import
maxout
#DEFINE_ALIAS
from
...fluid.layers
import
relu6
#DEFINE_ALIAS
from
...fluid.layers
import
selu
#DEFINE_ALIAS
from
...fluid.layers
import
soft_relu
#DEFINE_ALIAS
from
...fluid.layers
import
softplus
#DEFINE_ALIAS
from
...fluid.layers
import
softshrink
#DEFINE_ALIAS
from
...fluid.layers
import
softsign
#DEFINE_ALIAS
from
...fluid.layers
import
swish
#DEFINE_ALIAS
from
...fluid.layers
import
sigmoid
#DEFINE_ALIAS
from
...fluid.layers
import
tanh_shrink
#DEFINE_ALIAS
from
...fluid.layers
import
thresholded_relu
#DEFINE_ALIAS
__all__
=
[
...
...
@@ -53,7 +47,7 @@ __all__ = [
'softsign'
,
'sigmoid'
,
'swish'
,
'tanh
_
shrink'
,
'tanhshrink'
,
'thresholded_relu'
,
'log_softmax'
]
...
...
@@ -423,6 +417,103 @@ def logsigmoid(x, name=None):
return
out
def
relu6
(
x
,
name
=
None
):
"""
relu6 activation
.. math::
\t
ext{relu6}(x) = \min(\max(0,x), 6)
Args:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
out = F.relu6(x) # [0, 0.3, 6]
"""
threshold
=
6.0
if
in_dygraph_mode
():
return
core
.
ops
.
relu6
(
x
,
'threshold'
,
threshold
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'relu6'
)
helper
=
LayerHelper
(
'relu6'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'relu6'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'threshold'
:
threshold
})
return
out
def
selu
(
x
,
scale
=
1.0507009873554804934193349852946
,
alpha
=
1.6732632423543772848170429916717
,
name
=
None
):
"""
selu activation
.. math::
\t
ext{selu}(x) = scale * (\max(0,x) + \min(0,
\a
lpha * (\exp(x) - 1))),
\\
with\,alpha=1.6732632423543772848170429916717 and
\\
scale=1.0507009873554804934193349852946
Args:
x (Tensor): The input Tensor with data type float32, float64.
scale (float, optional): The value of scale for selu. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha for selu. Default is 1.6732632423543772848170429916717
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([[0, 1],[2, 3]]))
out = F.selu(x) # [[0, 1.050701],[2.101402, 3.152103]]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
selu
(
x
,
'scale'
,
scale
,
'alpha'
,
alpha
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'selu'
)
helper
=
LayerHelper
(
'selu'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'selu'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'scale'
:
scale
,
'alpha'
:
alpha
})
return
out
def
softmax
(
x
,
axis
=-
1
,
name
=
None
):
"""
This operator implements the softmax layer. The calculation process is as follows:
...
...
@@ -539,6 +630,188 @@ def softmax(x, axis=-1, name=None):
return
paddle
.
fluid
.
layers
.
softmax
(
input
=
x
,
axis
=
axis
,
name
=
name
)
def
softplus
(
x
,
beta
=
1
,
threshold
=
20
,
name
=
None
):
"""
softplus activation
.. math::
\t
ext{softplus}(x) =
\f
rac{1}{
\b
eta} * \log(1 + \exp(
\b
eta * x))
\\
\t
ext{For numerical stability, the implementation reverts to the linear function when :}\,x
\t
imes
\b
eta > threshold.
Args:
x (Tensor): The input Tensor with data type float32, float64.
beta (float, optional): The value of beta for softplus. Default is 1
threshold (float, optional): The value of threshold for softplus. Default is 20
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softplus(x) # [0.513015, 0.598139, 0.744397, 0.854355]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
softplus
(
x
,
'beta'
,
beta
,
'threshold'
,
threshold
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'softplus'
)
helper
=
LayerHelper
(
'softplus'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'softplus'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'beta'
:
beta
,
'threshold'
:
threshold
})
return
out
def
softshrink
(
x
,
threshold
=
0.5
,
name
=
None
):
"""
softshrink activation
.. math::
\t
ext{softshrink}(x) =
\b
egin{cases}
x - threshold, &
\t
ext{ if } x > threshold
\\
x + threshold, &
\t
ext{ if } x < -threshold
\\
0, &
\t
ext{ otherwise }
\end{cases}
Args:
x (Tensor): The input Tensor with data type float32, float64.
threshold (float, optional): The value of threshold(must be no less than zero) for softplus. Default is 0.5
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
out = F.softshrink(x) # [-0.4, 0, 0, 0.3]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
softshrink
(
x
,
'lambda'
,
threshold
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'softshrink'
)
helper
=
LayerHelper
(
'softshrink'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'softshrink'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
},
attrs
=
{
'lambda'
:
threshold
})
return
out
def
softsign
(
x
,
name
=
None
):
"""
softsign activation
.. math::
\t
ext{softsign}(x) =
\f
rac{x}{1 + |x|}
Args:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.softsign(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
softsign
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'softsign'
)
helper
=
LayerHelper
(
'softsign'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'softsign'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
})
return
out
def
tanhshrink
(
x
,
name
=
None
):
"""
tanhshrink activation
.. math::
\t
ext{tanhshrink}(x) = x -
\t
ext{tanh}(x)
Args:
x (Tensor): The input Tensor with data type float32, float64.
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Returns:
A Tensor with the same data type and shape as ``x`` .
Examples:
.. code-block:: python
import paddle
import paddle.nn.functional as F
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
out = F.tanhshrink(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
tanh_shrink
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'tanhshrink'
)
helper
=
LayerHelper
(
'tanh_shrink'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'tanh_shrink'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
})
return
out
def
log_softmax
(
x
,
axis
=-
1
,
dtype
=
None
,
name
=
None
):
"""
This operator implements the log_softmax layer. The calculation process is
...
...
python/paddle/nn/layer/activation.py
浏览文件 @
40d193ed
...
...
@@ -20,9 +20,15 @@ __all__ = [
'Hardshrink'
,
# 'PReLU',
'ReLU'
,
'ReLU6'
,
'SELU'
,
'LeakyReLU'
,
'Sigmoid'
,
# 'Softmax',
'Softplus'
,
'Softshrink'
,
'Softsign'
,
'Tanhshrink'
,
'LogSigmoid'
,
'LogSoftmax'
,
'HSigmoid'
...
...
@@ -351,6 +357,91 @@ class ReLU(layers.Layer):
return
F
.
relu
(
x
,
self
.
_name
)
class
ReLU6
(
layers
.
Layer
):
"""
ReLU6 Activation
.. math::
\t
ext{ReLU6}(x) = \min(\max(0,x), 6)
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-1, 0.3, 6.5]))
m = paddle.nn.ReLU6()
out = m(x) # [0, 0.3, 6]
"""
def
__init__
(
self
,
name
=
None
):
super
(
ReLU6
,
self
).
__init__
()
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
relu6
(
x
,
self
.
_name
)
class
SELU
(
layers
.
Layer
):
"""
SELU Activation
.. math::
\t
ext{SELU}(x) = scale * (\max(0,x) + \min(0,
\a
lpha * (\exp(x) - 1))),
\\
with\,alpha=1.6732632423543772848170429916717 and
\\
scale=1.0507009873554804934193349852946
Parameters:
scale (float, optional): The value of scale for SELU. Default is 1.0507009873554804934193349852946
alpha (float, optional): The value of alpha for SELU. Default is 1.6732632423543772848170429916717
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([[0, 1],[2, 3]]))
m = paddle.nn.SELU()
out = m(x) # [[0, 1.050701],[2.101402, 3.152103]]
"""
def
__init__
(
self
,
scale
=
1.0507009873554804934193349852946
,
alpha
=
1.6732632423543772848170429916717
,
name
=
None
):
super
(
SELU
,
self
).
__init__
()
self
.
_scale
=
scale
self
.
_alpha
=
alpha
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
selu
(
x
,
self
.
_scale
,
self
.
_alpha
,
self
.
_name
)
class
LeakyReLU
(
layers
.
Layer
):
"""
Leaky ReLU Activation.
...
...
@@ -431,6 +522,167 @@ class Sigmoid(layers.Layer):
return
F
.
sigmoid
(
x
,
self
.
name
)
class
Softplus
(
layers
.
Layer
):
"""
Softplus Activation
.. math::
\t
ext{Softplus}(x) =
\f
rac{1}{
\b
eta} * \log(1 + \exp(
\b
eta * x))
\\
\t
ext{For numerical stability, the implementation reverts to the linear function when :}\,x
\t
imes
\b
eta > threshold.
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Softplus()
out = m(x) # [0.513015, 0.598139, 0.744397, 0.854355]
"""
def
__init__
(
self
,
beta
=
1
,
threshold
=
20
,
name
=
None
):
super
(
Softplus
,
self
).
__init__
()
self
.
_beta
=
beta
self
.
_threshold
=
threshold
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
softplus
(
x
,
self
.
_beta
,
self
.
_threshold
,
self
.
_name
)
class
Softshrink
(
layers
.
Layer
):
"""
Softshrink Activation
.. math::
\t
ext{Softshrink}(x) =
\b
egin{cases}
x - threshold, &
\t
ext{ if } x > threshold
\\
x + threshold, &
\t
ext{ if } x < -threshold
\\
0, &
\t
ext{ otherwise }
\end{cases}
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.9, -0.2, 0.1, 0.8]))
m = paddle.nn.Softshrink()
out = m(x) # [-0.4, 0, 0, 0.3]
"""
def
__init__
(
self
,
threshold
=
0.5
,
name
=
None
):
super
(
Softshrink
,
self
).
__init__
()
self
.
_threshold
=
threshold
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
softshrink
(
x
,
self
.
_threshold
,
self
.
_name
)
class
Softsign
(
layers
.
Layer
):
"""
Softsign Activation
.. math::
\t
ext{Softsign}(x) =
\f
rac{x}{1 + |x|}
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Softsign()
out = m(x) # [-0.285714, -0.166667, 0.0909091, 0.230769]
"""
def
__init__
(
self
,
name
=
None
):
super
(
Softsign
,
self
).
__init__
()
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
softsign
(
x
,
self
.
_name
)
class
Tanhshrink
(
layers
.
Layer
):
"""
Tanhshrink Activation
.. math::
\t
ext{Tanhshrink}(x) = x -
\t
ext{Tanh}(x)
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Tanhshrink()
out = m(x) # [-0.020051, -0.00262468, 0.000332005, 0.00868739]
"""
def
__init__
(
self
,
name
=
None
):
super
(
Tanhshrink
,
self
).
__init__
()
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
tanhshrink
(
x
,
self
.
_name
)
class
LogSigmoid
(
layers
.
Layer
):
"""
LogSigmoid Activation.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录