机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit 0025e0d8 (unverified)
Authored by zhupengyang on Oct 10, 2020; committed via GitHub on Oct 10, 2020
Parent: 5098891f

refine APIs: brelu, hardsigmoid, hardswish, maxout (#27658)
Changes: 10 files changed, 685 additions (+) and 260 deletions (-).
File                                                         Additions   Deletions
paddle/fluid/operators/maxout_op.cc                          +12         -0
paddle/fluid/operators/maxout_op.h                           +7          -0
python/paddle/fluid/layers/nn.py                             +22         -35
python/paddle/fluid/tests/unittests/test_activation_op.py    +169        -88
python/paddle/fluid/tests/unittests/test_layers.py           +0          -29
python/paddle/fluid/tests/unittests/test_maxout_op.py        +94         -59
python/paddle/nn/__init__.py                                 +3          -0
python/paddle/nn/functional/__init__.py                      +2          -3
python/paddle/nn/functional/activation.py                    +193        -20
python/paddle/nn/layer/activation.py                         +183        -26
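For orientation: hardsigmoid, hardswish and maxout become native paddle.nn.functional APIs (with matching Hardsigmoid, Hardswish and Maxout layer classes), while the legacy fluid.layers entry points gain dygraph fast paths or delegate to the new implementations. A minimal dygraph sketch of the refined functional APIs, with expected values taken from the new docstrings; it assumes a Paddle build that contains this commit:

    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor([-4., 5., 1.])
    print(F.hardsigmoid(x).numpy())  # [0.        1.        0.6666667]
    print(F.hardswish(x).numpy())    # [0.        5.        0.6666667]

    # maxout reduces the channel dimension by `groups` (NCHW layout, axis=1 by default).
    feat = paddle.rand([1, 2, 3, 4])
    print(F.maxout(feat, groups=2).shape)  # [1, 1, 3, 4]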
paddle/fluid/operators/maxout_op.cc

@@ -83,6 +83,18 @@ class MaxOutOp : public framework::OperatorWithKernel {
                          "Attr(groups) of Op(maxout) should be "
                          "larger than 1. But received %d.",
                          groups));
    PADDLE_ENFORCE_EQ(
        axis == 1 || axis == -1 || axis == 3, true,
        platform::errors::InvalidArgument(
            "axis only supported 1, -1 or 3, but received axis is: %d", axis));
    PADDLE_ENFORCE_EQ(in_x_dims.size(), 4,
                      platform::errors::InvalidArgument(
                          "x's dims should be 4, but received x's dims is: %d",
                          in_x_dims.size()));
    if (axis < 0) {
      axis += in_x_dims.size();
    }
    PADDLE_ENFORCE_EQ(
        in_x_dims[axis] % groups, 0,
        platform::errors::InvalidArgument(
paddle/fluid/operators/maxout_op.h

@@ -31,6 +31,9 @@ class MaxOutKernel : public framework::OpKernel<T> {
    Tensor* out = context.Output<Tensor>("Out");
    int groups = context.template Attr<int>("groups");
    int axis = context.template Attr<int>("axis");
    if (axis < 0) {
      axis += in_x->dims().size();
    }
    math::MaxOutFunctor<DeviceContext, T> maxout_forward;
    maxout_forward(context.template device_context<DeviceContext>(), *in_x, out,

@@ -49,6 +52,10 @@ class MaxOutGradKernel : public framework::OpKernel<T> {
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    int groups = context.template Attr<int>("groups");
    int axis = context.template Attr<int>("axis");
    if (axis < 0) {
      axis += in_x->dims().size();
    }
    auto& device_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> zero;
    if (in_x_grad) {
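For reference, the computation these kernels hand off to math::MaxOutFunctor can be written out in NumPy. This is a sketch mirroring the maxout_forward_naive reference used by the updated test_maxout_op.py (reshape-based rather than the np.ndarray buffer trick), assuming a 4-D input whose channel dimension is divisible by groups:

    import numpy as np

    def maxout_ref(x, groups, axis=1):
        # NCHW (axis=1): regroup channels as (C // groups, groups) and reduce over the
        # `groups` slots; NHWC (axis=3 or -1): same idea on the last dimension.
        s0, s1, s2, s3 = x.shape
        if axis == 1:
            return x.reshape([s0, s1 // groups, groups, s2, s3]).max(axis=2)
        return x.reshape([s0, s1, s2, s3 // groups, groups]).max(axis=4)

    x = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype('float64')
    print(maxout_ref(x, groups=2, axis=1).shape)  # (2, 3, 5, 4)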
python/paddle/fluid/layers/nn.py

@@ -9592,10 +9592,6 @@ def stanh(x, scale_a=0.67, scale_b=1.7159, name=None):
@templatedoc()
def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
    """
    :alias_main: paddle.nn.functional.hard_sigmoid
    :alias: paddle.nn.functional.hard_sigmoid,paddle.nn.functional.activation.hard_sigmoid
    :old_api: paddle.fluid.layers.hard_sigmoid

    ${comment}
    Parameters:
        x (${x_type}): ${x_comment}

@@ -9613,9 +9609,15 @@ def hard_sigmoid(x, slope=0.2, offset=0.5, name=None):
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            paddle.enable_static()

            data = fluid.layers.fill_constant(shape=[3, 2], value=0.5, dtype='float32') # [[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]]
            result = fluid.layers.hard_sigmoid(data) # [[0.6, 0.6], [0.6, 0.6], [0.6, 0.6]]
    """
    if in_dygraph_mode():
        return core.ops.hard_sigmoid(x, 'slope', slope, 'offset', offset)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hard_sigmoid')

@@ -9802,10 +9804,6 @@ def prelu(x, mode, param_attr=None, name=None):
@templatedoc()
def brelu(x, t_min=0.0, t_max=24.0, name=None):
    """
    :alias_main: paddle.nn.functional.brelu
    :alias: paddle.nn.functional.brelu,paddle.nn.functional.activation.brelu
    :old_api: paddle.fluid.layers.brelu

    ${comment}
    Args:
        x(${x_type}): ${x_comment}

@@ -9821,7 +9819,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            import numpy as np
            paddle.enable_static()

            input_brelu = np.array([[-1,6],[1,15.6]])
            with fluid.dygraph.guard():

@@ -9831,6 +9831,9 @@ def brelu(x, t_min=0.0, t_max=24.0, name=None):
                #[[ 1.  6.]
                #[ 1. 10.]]
    """
    if in_dygraph_mode():
        return core.ops.brelu(x, 't_min', t_min, 't_max', t_max)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], 'brelu')
    helper = LayerHelper('brelu', **locals())

@@ -12564,13 +12567,10 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
    return out


@deprecated(since="2.0.0", update_to="paddle.nn.functional.maxout")
@templatedoc()
def maxout(x, groups, name=None, axis=1):
    """
    :alias_main: paddle.nn.functional.maxout
    :alias: paddle.nn.functional.maxout,paddle.nn.functional.activation.maxout
    :old_api: paddle.fluid.layers.maxout

    ${comment}
    Args:

@@ -12592,31 +12592,16 @@ def maxout(x, groups, name=None, axis=1):
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            paddle.enable_static()

            input = fluid.data(
                name='data',
                shape=[None, 256, 32, 32],
                dtype='float32')
            out = fluid.layers.maxout(input, groups=2)
    """
    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
    helper = LayerHelper("maxout", **locals())
    if axis not in [1, -1, 3]:
        raise ValueError(
            "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
            "Attr(axis): %s." % str(axis))
    if axis == -1:
        axis = 3

    out = helper.create_variable_for_type_inference(dtype=x.dtype)
    helper.append_op(
        type="maxout",
        inputs={"X": x},
        attrs={"groups": groups,
               "axis": axis},
        outputs={"Out": out})
    return out
    return paddle.nn.functional.maxout(**locals())


def space_to_depth(x, blocksize, name=None):

@@ -14877,10 +14862,6 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
@templatedoc()
def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
    """
    :alias_main: paddle.nn.functional.hard_swish
    :alias: paddle.nn.functional.hard_swish,paddle.nn.functional.activation.hard_swish
    :old_api: paddle.fluid.layers.hard_swish

    This operator implements the hard_swish activation function.
    Hard_swish is proposed in MobileNetV3, and performs better in computational stability and efficiency compared to swish function.
    For more details please refer to: https://arxiv.org/pdf/1905.02244.pdf

@@ -14911,7 +14892,9 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            import numpy as np
            paddle.enable_static()

            DATATYPE='float32'

@@ -14926,6 +14909,10 @@ def hard_swish(x, threshold=6.0, scale=6.0, offset=3.0, name=None):
            out, = exe.run(feed={'x':x_data}, fetch_list=[y.name])
            print(out)  # [[0.66666667, 1.66666667, 3., 4.]]
    """
    if in_dygraph_mode():
        return core.ops.hard_swish(x, 'threshold', threshold, 'scale', scale,
                                   'offset', offset)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hard_swish')
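With this change fluid.layers.maxout is deprecated and simply forwards its arguments to paddle.nn.functional.maxout, so the two entry points should agree numerically. A minimal dygraph sketch along the lines of the updated test_fluid_api case, assuming Paddle 2.0-era APIs:

    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.random.uniform(-1, 1, [2, 6, 5, 4]).astype('float32'))

    out_old = fluid.layers.maxout(x, groups=2, axis=1)  # deprecated wrapper
    out_new = F.maxout(x, groups=2, axis=1)             # refined API
    print(np.allclose(out_old.numpy(), out_new.numpy()))  # True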
python/paddle/fluid/tests/unittests/test_activation_op.py

@@ -25,10 +25,11 @@ import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid import compiler, Program, program_guard

paddle.enable_static()


class TestSqrtOpError(unittest.TestCase):
    def test_errors(self):
        paddle.enable_static()
        with program_guard(Program(), Program()):
            # The input type of sqrt op must be Variable or numpy.ndarray.
            in1 = 1

Most of the remaining small hunks in this file are one-line changes of the same kind: they drop the per-case paddle.enable_static() call from setUp()/test_errors() now that static mode is enabled once at module level. The affected classes include TestActivation, TestParameter, TestSigmoid, TestLogSigmoid, TestTanh, TestAtan, TestSinh, TestSinhOpError, TestCosh, TestCoshOpError, TestTanhshrink, TestHardShrink, TestHardShrinkAPI, TestSoftshrink, TestSqrt, TestRsqrt, TestAbs, TestCeil, TestFloor, TestCos, TestAcos, TestSin, TestAsin, TestRound, TestRelu, TestLeakyRelu, TestGeluApproximate, TestGelu, TestBRelu, TestRelu6, TestSoftRelu, TestSoftReluOpError, TestELU, TestReciprocal, TestLog, TestLog1p, TestSquare, TestPow, TestPow_factor_tensor, TestSTanh, TestSTanhOpError, TestSoftplus, TestSoftsign, TestThresholdedRelu, TestThresholdedReluOpError, TestSwish and TestSwishOpError.

@@ -1279,9 +1251,35 @@
TestBReluOpError is replaced by TestBreluAPI, which checks fluid.layers.brelu in both static and dygraph mode:

class TestBreluAPI(unittest.TestCase):
    # test paddle.fluid.layers.brelu
    def setUp(self):
        np.random.seed(1024)
        self.t_min = 0.
        self.t_max = 24.
        self.x_np = np.random.uniform(-1, 30, [10, 12]).astype('float32')
        self.out_ref = np.copy(self.x_np)
        self.out_ref[self.out_ref < self.t_min] = self.t_min
        self.out_ref[self.out_ref > self.t_max] = self.t_max
        self.out_ref = self.out_ref.astype('float32')
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_fluid_api(self):
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.static.data('X', [10, 12])
            out = paddle.fluid.layers.brelu(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
            self.assertTrue(np.allclose(self.out_ref, res[0]))

        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x_np)
        out = paddle.fluid.layers.brelu(x)
        self.assertTrue(np.allclose(self.out_ref, out.numpy()))
        paddle.enable_static()

    def test_errors(self):
        with program_guard(Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, fluid.layers.brelu, 1)

@@ -1378,9 +1375,13 @@ and @@ -1392,9 +1393,9 @@
A NumPy reference for hard_swish is introduced and TestHardSwish.setUp() now builds its expected output from it:

def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
    return (x * np.minimum(np.maximum(x + offset, 0.), threshold) /
            scale).astype(x.dtype)


class TestHardSwish(TestActivation):
    def setUp(self):
        self.op_type = 'hard_swish'
        self.init_dtype()

        # the same with TestAbs
        x[np.abs(x + offset) < 0.005] = 0.02
        x[np.abs(x - threshold + offset) < 0.005] = threshold - offset + 0.02
        out = ref_hardswish(x, threshold, scale, offset)

        self.inputs = {'X': x}
        self.attrs = {'threshold': threshold, 'scale': scale, 'offset': offset}
        self.outputs = {'Out': out}

@@ -1404,23 +1405,65 @@
TestHardSwishOpError is replaced by TestHardswishAPI, which covers the new paddle.nn.Hardswish layer and paddle.nn.functional.hardswish function as well as the legacy fluid entry point:

class TestHardswishAPI(unittest.TestCase):
    # test paddle.nn.Hardswish, paddle.nn.functional.hardswish
    def setUp(self):
        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_api(self):
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
            out1 = F.hardswish(x)
            m = paddle.nn.Hardswish()
            out2 = m(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
        out_ref = ref_hardswish(self.x_np)
        for r in res:
            self.assertTrue(np.allclose(out_ref, r))

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x_np)
        out1 = F.hardswish(x)
        m = paddle.nn.Hardswish()
        out2 = m(x)
        out_ref = ref_hardswish(self.x_np)
        for r in [out1, out2]:
            self.assertTrue(np.allclose(out_ref, r.numpy()))
        paddle.enable_static()

    def test_fluid_api(self):
        with fluid.program_guard(fluid.Program()):
            x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
            out = fluid.layers.hard_swish(x)
            exe = fluid.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
        out_ref = ref_hardswish(self.x_np)
        self.assertTrue(np.allclose(out_ref, res[0]))

        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x_np)
        out = paddle.fluid.layers.hard_swish(x)
        self.assertTrue(np.allclose(out_ref, out.numpy()))
        paddle.enable_static()

    def test_errors(self):
        with paddle.static.program_guard(paddle.static.Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, F.hardswish, 1)
            # The input dtype must be float16, float32, float64.
            x_int32 = paddle.data(name='x_int32', shape=[12, 10], dtype='int32')
            self.assertRaises(TypeError, F.hardswish, x_int32)
            # support the input dtype is float16
            x_fp16 = paddle.data(name='x_fp16', shape=[12, 10], dtype='float16')
            F.hardswish(x_fp16)

(The old error checks against fluid.layers.hard_swish that used fluid.data inside program_guard(Program()) are folded into the class above.)

@@ -1987,54 +2016,107 @@
The hard_sigmoid tests are rewritten in the same style: a NumPy reference ref_hardsigmoid() is added, TestHardSigmoid now runs in float64 with configurable slope/offset (plus TestHardSigmoidFP32 and TestHardSigmoidSlopeOffset variants), and TestHardSigmoidOpError is replaced by TestHardsigmoidAPI:

def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
    return np.maximum(np.minimum(x * slope + offset, 1.), 0.).astype(x.dtype)


class TestHardSigmoid(TestActivation):
    def setUp(self):
        self.op_type = "hard_sigmoid"
        self.init_dtype()
        self.dtype = 'float64'
        self.slope = 0.166666666666667
        self.offset = 0.5
        self.set_attrs()

        x = np.random.uniform(-5, 5, [10, 12]).astype(self.dtype)
        lower_threshold = -self.offset / self.slope
        upper_threshold = (1. - self.offset) / self.slope

        # Same reason as TestAbs
        delta = 0.005
        x[np.abs(x - lower_threshold) < delta] = lower_threshold - 0.02
        x[np.abs(x - upper_threshold) < delta] = upper_threshold - 0.02

        out = ref_hardsigmoid(x, self.slope, self.offset)

        self.attrs = {'slope': self.slope, 'offset': self.offset}
        self.inputs = {'X': x}
        self.outputs = {'Out': out}

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        self.check_grad(['X'], 'Out')

    def set_attrs(self):
        pass


class TestHardSigmoidFP32(TestHardSigmoid):
    def set_attrs(self):
        self.dtype = 'float32'


class TestHardSigmoidSlopeOffset(TestHardSigmoid):
    def set_attrs(self):
        self.slope = 0.2
        self.offset = 0.4


class TestHardsigmoidAPI(unittest.TestCase):
    # test paddle.nn.Hardsigmoid, paddle.nn.functional.hardsigmoid
    def setUp(self):
        self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(np.float64)
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

Its test_static_api, test_dygraph_api, test_fluid_api and test_errors methods mirror TestHardswishAPI above: F.hardsigmoid and paddle.nn.Hardsigmoid are compared against ref_hardsigmoid(), the legacy fluid.layers.hard_sigmoid path is compared against ref_hardsigmoid(self.x_np, 0.2, 0.5), and TypeError/float16 checks are run for both the old and the new entry points.
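The two NumPy references introduced above pin down the element-wise definitions the tests check against; evaluated at the points used in the new docstring examples they reproduce the documented outputs. A standalone sketch, no Paddle required:

    import numpy as np

    def ref_hardswish(x, threshold=6.0, scale=6.0, offset=3.0):
        return (x * np.minimum(np.maximum(x + offset, 0.), threshold) / scale).astype(x.dtype)

    def ref_hardsigmoid(x, slope=0.166666666666667, offset=0.5):
        return np.maximum(np.minimum(x * slope + offset, 1.), 0.).astype(x.dtype)

    x = np.array([-4., 5., 1.], dtype='float32')
    print(ref_hardswish(x))    # [0.        5.        0.6666667] -- matches the F.hardswish example
    print(ref_hardsigmoid(x))  # [0.        1.        0.6666667] -- matches the F.hardsigmoid example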
python/paddle/fluid/tests/unittests/test_layers.py

Three fluid-layer smoke tests are removed:

@@ -1657,21 +1657,6 @@ class TestLayer(LayerTest):
        with self.assertRaises(TypeError):
            layers.eye(num_rows=3, batch_shape=[-1])

    def test_hard_swish(self):
        with self.static_graph():
            t = layers.data(name='t', shape=[3, 3], dtype='float32')
            ret = layers.hard_swish(t)
            static_ret = self.get_static_graph_result(
                feed={'t': np.ones([3, 3], dtype='float32')}, fetch_list=[ret])[0]

        with self.dynamic_graph():
            t = np.ones([3, 3], dtype='float32')
            dy_ret = layers.hard_swish(base.to_variable(t))
            dy_ret_rlt = dy_ret.numpy()

        self.assertTrue(np.allclose(static_ret, dy_ret_rlt))

    def test_while_loop(self):
        with self.static_graph():
            i = layers.fill_constant(shape=[1], dtype='int64', value=0)

@@ -2563,13 +2548,6 @@ class TestBook(LayerTest):
            output = layers.l2_normalize(x, axis=1)
            return output

    def make_maxout(self):
        with program_guard(fluid.default_main_program(),
                           fluid.default_startup_program()):
            data = self._get_data(name='x', shape=[8, 6, 6], dtype="float32")
            output = layers.maxout(x=data, groups=2)
            return (output)

    def make_crop(self):
        with program_guard(fluid.default_main_program(),
                           fluid.default_startup_program()):

@@ -2656,13 +2634,6 @@ class TestBook(LayerTest):
                name='prelu')
            return (out)

    def make_brelu(self):
        with program_guard(fluid.default_main_program(),
                           fluid.default_startup_program()):
            input = self._get_data(name="input", shape=[16], dtype="float32")
            out = layers.brelu(input, t_min=1.0, t_max=20.0, name='brelu')
            return (out)

    def make_soft_relu(self):
        with program_guard(fluid.default_main_program(),
                           fluid.default_startup_program()):
python/paddle/fluid/tests/unittests/test_maxout_op.py

@@ -16,32 +16,43 @@ from __future__ import print_function
The module now enables static mode and seeds NumPy once, and tests the new APIs:

import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn.functional as F
from op_test import OpTest

paddle.enable_static()
np.random.seed(1)


def maxout_forward_naive(x, groups, channel_axis):
    s0, s1, s2, s3 = x.shape
    if channel_axis == 1:
        return np.ndarray([s0, s1 // groups, groups, s2, s3], \
            buffer=x, dtype=x.dtype).max(axis=2)
    return np.ndarray([s0, s1, s2, s3 // groups, groups], \
        buffer=x, dtype=x.dtype).max(axis=4)


class TestMaxOutOp(OpTest):
    def setUp(self):
        self.op_type = "maxout"
        self.dtype = 'float64'
        self.shape = [3, 6, 2, 4]
        self.groups = 2
        self.axis = 1
        self.set_attrs()

        x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out = maxout_forward_naive(x, self.groups, self.axis)

        self.inputs = {'X': x}
        self.attrs = {'groups': self.groups, 'axis': self.axis}
        self.outputs = {'Out': out}

    def set_attrs(self):
        pass

    def test_check_output(self):
        self.check_output()

(This replaces the old NHWC-first maxout_forward_naive(input, groups, channel_axis), the init_test_case-based setUp, and the import of Program/program_guard from paddle.fluid.)

@@ -49,65 +60,89 @@ class TestMaxOutOp(OpTest):
    def test_check_grad(self):
        self.check_grad(['X'], 'Out')


class TestMaxOutOpAxis0(TestMaxOutOp):
    def set_attrs(self):
        self.axis = -1


class TestMaxOutOpAxis1(TestMaxOutOp):
    def set_attrs(self):
        self.axis = 3


class TestMaxOutOpFP32(TestMaxOutOp):
    def set_attrs(self):
        self.dtype = 'float32'


class TestMaxOutOpGroups(TestMaxOutOp):
    def set_attrs(self):
        self.groups = 3


class TestMaxoutAPI(unittest.TestCase):
    # test paddle.nn.Maxout, paddle.nn.functional.maxout
    def setUp(self):
        self.x_np = np.random.uniform(-1, 1, [2, 6, 5, 4]).astype(np.float64)
        self.groups = 2
        self.axis = 1
        self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \
            else paddle.CPUPlace()

    def test_static_api(self):
        with paddle.static.program_guard(paddle.static.Program()):
            x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
            out1 = F.maxout(x, self.groups, self.axis)
            m = paddle.nn.Maxout(self.groups, self.axis)
            out2 = m(x)
            exe = paddle.static.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out1, out2])
        out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
        for r in res:
            self.assertTrue(np.allclose(out_ref, r))

    def test_dygraph_api(self):
        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x_np)
        out1 = F.maxout(x, self.groups, self.axis)
        m = paddle.nn.Maxout(self.groups, self.axis)
        out2 = m(x)
        out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
        for r in [out1, out2]:
            self.assertTrue(np.allclose(out_ref, r.numpy()))

        out3 = F.maxout(x, self.groups, -1)
        out3_ref = maxout_forward_naive(self.x_np, self.groups, -1)
        self.assertTrue(np.allclose(out3_ref, out3.numpy()))
        paddle.enable_static()

    def test_fluid_api(self):
        with fluid.program_guard(fluid.Program()):
            x = fluid.data('X', self.x_np.shape, self.x_np.dtype)
            out = fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
            exe = fluid.Executor(self.place)
            res = exe.run(feed={'X': self.x_np}, fetch_list=[out])
        out_ref = maxout_forward_naive(self.x_np, self.groups, self.axis)
        self.assertTrue(np.allclose(out_ref, res[0]))

        paddle.disable_static(self.place)
        x = paddle.to_tensor(self.x_np)
        out = paddle.fluid.layers.maxout(x, groups=self.groups, axis=self.axis)
        self.assertTrue(np.allclose(out_ref, out.numpy()))
        paddle.enable_static()


class TestMaxOutOpError(unittest.TestCase):
    def test_errors(self):
        with paddle.static.program_guard(paddle.static.Program()):
            # The input type must be Variable.
            self.assertRaises(TypeError, F.maxout, 1)
            # The input dtype must be float16, float32, float64.
            x_int32 = paddle.data(name='x_int32', shape=[2, 4, 6, 8], dtype='int32')
            self.assertRaises(TypeError, F.maxout, x_int32)

            x_float32 = paddle.data(name='x_float32', shape=[2, 4, 6, 8])
            self.assertRaises(ValueError, F.maxout, x_float32, 2, 2)


if __name__ == '__main__':

(The previous TestMaxOutOpAxis NHWC variant, the TestMaxOutOpAxisAPI static comparison of axis=1 vs axis=-1 through fluid.layers.maxout with its axis=2 ValueError check, and the fluid-only checks in the old TestMaxOutOpError are superseded by the classes above.)
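The axis handling exercised by these tests can also be checked directly: maxout over an NCHW tensor with axis=1 and over its NHWC transpose with axis=-1 should agree up to a transpose, since axis=-1 is normalized to 3. A dygraph sketch in the spirit of the removed test_axis case; exact behaviour assumes a Paddle build with the refined F.maxout:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x_nchw = paddle.rand([3, 6, 2, 2])                    # channels on axis 1
    x_nhwc = paddle.transpose(x_nchw, perm=[0, 2, 3, 1])  # channels on axis 3

    out_nchw = F.maxout(x_nchw, groups=2, axis=1)
    out_nhwc = F.maxout(x_nhwc, groups=2, axis=-1)

    print(np.allclose(out_nchw.numpy(),
                      np.transpose(out_nhwc.numpy(), (0, 3, 1, 2))))  # True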
python/paddle/nn/__init__.py

The new Hardswish, Hardsigmoid and Maxout layers are exported:

@@ -55,6 +55,7 @@ from .layer.activation import ELU #DEFINE_ALIAS
from .layer.activation import GELU  #DEFINE_ALIAS
from .layer.activation import Tanh  #DEFINE_ALIAS
from .layer.activation import Hardshrink  #DEFINE_ALIAS
from .layer.activation import Hardswish  #DEFINE_ALIAS
from .layer.activation import Hardtanh  #DEFINE_ALIAS
from .layer.activation import PReLU  #DEFINE_ALIAS
from .layer.activation import ReLU  #DEFINE_ALIAS

@@ -62,6 +63,7 @@ from .layer.activation import ReLU6 #DEFINE_ALIAS
from .layer.activation import SELU  #DEFINE_ALIAS
from .layer.activation import LeakyReLU  #DEFINE_ALIAS
from .layer.activation import Sigmoid  #DEFINE_ALIAS
from .layer.activation import Hardsigmoid  #DEFINE_ALIAS
from .layer.activation import LogSigmoid
from .layer.activation import Softmax  #DEFINE_ALIAS
from .layer.activation import Softplus  #DEFINE_ALIAS

@@ -70,6 +72,7 @@ from .layer.activation import Softsign #DEFINE_ALIAS
from .layer.activation import Tanhshrink  #DEFINE_ALIAS
from .layer.activation import LogSoftmax  #DEFINE_ALIAS
from .layer.activation import HSigmoid  #DEFINE_ALIAS
from .layer.activation import Maxout  #DEFINE_ALIAS
from .layer.common import BilinearTensorProduct  #DEFINE_ALIAS
from .layer.common import Pool2D  #DEFINE_ALIAS
from .layer.common import Pad2D  #DEFINE_ALIAS
python/paddle/nn/functional/__init__.py

The functional exports are renamed: brelu, hard_sigmoid and hard_swish are dropped in favour of hardsigmoid and hardswish:

@@ -29,14 +29,13 @@ from . import pooling
__all__ += pooling.__all__
from . import loss
__all__ += loss.__all__
from .activation import brelu  #DEFINE_ALIAS
from .activation import elu  #DEFINE_ALIAS
from .activation import erf  #DEFINE_ALIAS
from .activation import gelu  #DEFINE_ALIAS
from .activation import hardshrink  #DEFINE_ALIAS
from .activation import hardtanh  #DEFINE_ALIAS
from .activation import hard_sigmoid  #DEFINE_ALIAS
from .activation import hard_swish  #DEFINE_ALIAS
from .activation import hardsigmoid  #DEFINE_ALIAS
from .activation import hardswish  #DEFINE_ALIAS
from .activation import hsigmoid  #DEFINE_ALIAS
from .activation import leaky_relu  #DEFINE_ALIAS
from .activation import log_sigmoid  #DEFINE_ALIAS
python/paddle/nn/functional/activation.py

@@ -13,11 +13,7 @@
The fluid.layers re-exports of brelu, hard_sigmoid, hard_swish and maxout are dropped:

# TODO: define activation functions of neural network
from ...fluid.layers import brelu  #DEFINE_ALIAS
from ...fluid.layers import erf  #DEFINE_ALIAS
from ...fluid.layers import hard_sigmoid  #DEFINE_ALIAS
from ...fluid.layers import hard_swish  #DEFINE_ALIAS
from ...fluid.layers import maxout  #DEFINE_ALIAS
from ...fluid.layers import soft_relu  #DEFINE_ALIAS
from ...fluid.layers import swish  #DEFINE_ALIAS
from ...fluid.layers import sigmoid  #DEFINE_ALIAS

@@ -25,14 +21,13 @@
__all__ loses 'brelu', 'hard_sigmoid' and 'hard_swish' and gains 'hardsigmoid' and 'hardswish':

from ...tensor.math import tanh  #DEFINE_ALIAS

__all__ = [
    'brelu',
    'elu',
    'erf',
    'gelu',
    'hardshrink',
    'hardtanh',
    'hard_sigmoid',
    'hard_swish',
    'hardsigmoid',
    'hardswish',
    'hsigmoid',
    'leaky_relu',
    'log_sigmoid',

@@ -75,10 +70,10 @@, @@ -89,7 +84,7 @@ and @@ -123,16 +118,16 @@ are formatting-only touch-ups to the elu() and gelu() docstrings (the duplicated lines differ only in trailing whitespace).

@@ -265,6 +260,109 @@
Two new functional APIs are added after hardtanh():

def hardsigmoid(x, name=None):
    """
    hardsigmoid activation.

    A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
    which is much faster than sigmoid.

    .. math::

        hardsigmoid(x)=
            \\left\\{
            \\begin{aligned}
            &0, & & \\text{if } x \\leq -3 \\\\
            &1, & & \\text{if } x \\geq 3 \\\\
            &x/6 + 1/2, & & \\text{otherwise}
            \\end{aligned}
            \\right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-4., 5., 1.])
            out = F.hardsigmoid(x) # [0., 1., 0.666667]
    """
    if in_dygraph_mode():
        return core.ops.hard_sigmoid(x, 'slope', 0.1666666666666667, 'offset', 0.5)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hardsigmoid')

    helper = LayerHelper('hardsigmoid', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='hard_sigmoid',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'slope': 0.1666666666666667,
               'offset': 0.5})
    return out


def hardswish(x, name=None):
    """
    hardswish activation

    hardswish is proposed in MobileNetV3, and performs better in computational stability
    and efficiency compared to swish function. For more details please refer
    to: https://arxiv.org/pdf/1905.02244.pdf

    .. math::

        hardswish(x)=
            \\left\\{
            \\begin{aligned}
            &0, & & \\text{if } x \\leq -3 \\\\
            &x, & & \\text{if } x \\geq 3 \\\\
            &\\frac{x(x+3)}{6}, & & \\text{otherwise}
            \\end{aligned}
            \\right.

    Parameters:
        x (Tensor): The input Tensor with data type float32, float64.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type and shape as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.to_tensor([-4., 5., 1.])
            out = F.hardswish(x) # [0., 5., 0.666667]
    """
    if in_dygraph_mode():
        return core.ops.hard_swish(x)

    check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
                             'hardswish')

    helper = LayerHelper('hardswish', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(type='hard_swish', inputs={'X': x}, outputs={'Out': out})
    return out


def hsigmoid(input,
             label,
             weight,

@@ -489,7 +587,7 @@ and @@ -559,15 +657,15 @@ are formatting-only touch-ups to the prelu() and log_sigmoid() docstrings.

@@ -591,6 +689,81 @@
A native maxout() functional API is added after log_sigmoid():

def maxout(x, groups, axis=1, name=None):
    """
    maxout activation.

    Assumed the input shape is (N, Ci, H, W).
    The output shape is (N, Co, H, W).
    Then Co = Ci/groups and the operator formula is as follows:

    .. math::

        &out_{si+j} = \\max_{k} x_{gsi + sk + j} \\\\
        &g = groups \\\\
        &s = \\frac{input.size}{num\\_channels} \\\\
        &0 \\le i < \\frac{num\\_channels}{groups} \\\\
        &0 \\le j < s \\\\
        &0 \\le k < groups

    Parameters:
        x (Tensor): The input is 4-D Tensor with shape [N, C, H, W] or [N, H, W, C], the data type
            of input is float32 or float64.
        groups (int, optional): The groups number of maxout. `groups` specifies the
            index of channel dimension where maxout will be performed. This must be
            a factor of number of features. Default is 1.
        axis (int, optional): The axis along which to perform maxout calculations.
            It should be 1 when data format is NCHW, be -1 or 3 when data format
            is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
            where D is the dimensions of ``x`` . ``axis`` only supports 1, 3 or -1.
            Default is 1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        A Tensor with the same data type as ``x`` .

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn.functional as F

            x = paddle.rand([1, 2, 3, 4])
            # [[[[0.5002636  0.22272532 0.17402348 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.02879342 0.88725346 0.61093384 0.38833922]]
            #   [[0.5231306  0.03807496 0.91661984 0.15602879]
            #    [0.666127   0.616567   0.30741522 0.24044901]
            #    [0.7142536  0.7351477  0.31588817 0.23782359]]]]
            out = F.maxout(x, groups=2)
            # [[[[0.5231306  0.22272532 0.91661984 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.7142536  0.88725346 0.61093384 0.38833922]]]]
    """
    if in_dygraph_mode():
        return core.ops.maxout(x, 'groups', groups, 'axis', axis)

    check_variable_and_dtype(x, 'x', ['float32', 'float64'], 'maxout')
    if axis not in [1, -1, 3]:
        raise ValueError(
            "Attr(axis) should be 1 when data format is NCHW, -1 or 3 when data format is NHWC. Received "
            "Attr(axis): %s." % str(axis))
    if axis == -1:
        axis = 3

    helper = LayerHelper('maxout', **locals())
    out = helper.create_variable_for_type_inference(x.dtype)
    helper.append_op(
        type='maxout',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'groups': groups,
               'axis': axis})
    return out


def relu6(x, name=None):
    """
    relu6 activation

@@ -778,7 +951,7 @@ and @@ -1051,13 +1224,13 @@ are formatting-only touch-ups to the softmax() and log_softmax() docstrings.
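The docstring examples above run in dygraph mode; under the static graph the new functionals go through the usual program/executor machinery, as the API tests do. A minimal static-mode sketch, with a CPU place assumed:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.enable_static()
    x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')

    main = paddle.static.Program()
    with paddle.static.program_guard(main):
        x = paddle.static.data('X', [10, 12], 'float32')
        y1 = F.hardsigmoid(x)
        y2 = F.hardswish(x)

    exe = paddle.static.Executor(paddle.CPUPlace())
    res1, res2 = exe.run(main, feed={'X': x_np}, fetch_list=[y1, y2])
    print(res1.shape, res2.shape)  # (10, 12) (10, 12)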
python/paddle/nn/layer/activation.py

@@ -18,6 +18,7 @@, @@ -26,6 +27,7 @@ and @@ -33,6 +35,7 @@ add 'Hardswish', 'Hardsigmoid' and 'Maxout' to __all__.

@@ -50,18 +53,18 @@ through @@ -120,7 +123,7 @@ are whitespace-only touch-ups to the ELU and GELU docstrings.

@@ -184,6 +187,52 @@
A Hardswish layer is added after Hardshrink:

class Hardswish(layers.Layer):
    """
    Hardswish activation

    Hardswish is proposed in MobileNetV3, and performs better in computational stability
    and efficiency compared to swish function. For more details please refer
    to: https://arxiv.org/pdf/1905.02244.pdf

    .. math::

        Hardswish(x)=
            \\left\\{
            \\begin{aligned}
            &0, & & \\text{if } x \\leq -3 \\\\
            &x, & & \\text{if } x \\geq 3 \\\\
            &\\frac{x(x+3)}{6}, & & \\text{otherwise}
            \\end{aligned}
            \\right.

    Parameters:
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: Tensor with any shape.
        - output: Tensor with the same shape as input.

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.to_tensor([-4., 5., 1.])
            m = paddle.nn.Hardswish()
            out = m(x) # [0., 5., 0.666667]
    """

    def __init__(self, name=None):
        super(Hardswish, self).__init__()
        self._name = name

    def forward(self, x):
        return F.hardswish(x, self._name)

The hunks between @@ -240,11 +289,11 @@ and @@ -656,7 +705,7 @@ are whitespace-only touch-ups to the Hardtanh, HSigmoid, PReLU, ReLU, LeakyReLU and Sigmoid docstrings.

@@ -680,6 +729,53 @@
A Hardsigmoid layer is added after Sigmoid:

class Hardsigmoid(layers.Layer):
    """
    This interface is used to construct a callable object of the ``Hardsigmoid`` class.
    This layer calculates the `hardsigmoid` of input x.

    A 3-part piecewise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
    which is much faster than sigmoid.

    .. math::

        Hardsigmoid(x)=
            \\left\\{
            \\begin{aligned}
            &0, & & \\text{if } x \\leq -3 \\\\
            &1, & & \\text{if } x \\geq 3 \\\\
            &x/6 + 1/2, & & \\text{otherwise}
            \\end{aligned}
            \\right.

    Parameters:
        name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        x: N-D tensor, available dtype is float32, float64.

    Returns:
        A callable object of Hardsigmoid.

    Examples:
        .. code-block:: python

            import paddle

            m = paddle.nn.Hardsigmoid()
            x = paddle.to_tensor([-4., 5., 1.])
            out = m(x) # [0., 1., 0.666667]
    """

    def __init__(self, name=None):
        super(Hardsigmoid, self).__init__()
        self.name = name

    def forward(self, x):
        return F.hardsigmoid(x, self.name)

@@ -842,7 +938,7 @@ through @@ -1023,7 +1119,7 @@ are whitespace-only touch-ups to the LogSigmoid, Softmax and LogSoftmax docstrings.

@@ -1060,3 +1156,64 @@
A Maxout layer is added after LogSoftmax:

class Maxout(layers.Layer):
    """
    Maxout Activation.

    Assumed the input shape is (N, Ci, H, W).
    The output shape is (N, Co, H, W).
    Then Co = Ci/groups and the operator formula is as follows:

    .. math::

        &out_{si+j} = \max_{k} x_{gsi + sk + j} \\\\
        &g = groups \\\\
        &s = \\frac{input.size}{num\\_channels} \\\\
        &0 \\le i < \\frac{num\\_channels}{groups} \\\\
        &0 \\le j < s \\\\
        &0 \\le k < groups

    Parameters:
        groups (int, optional): The groups number of maxout. `groups` specifies the
            index of channel dimension where maxout will be performed. This must be
            a factor of number of features. Default is 1.
        axis (int, optional): The axis along which to perform maxout calculations.
            It should be 1 when data format is NCHW, be -1 or 3 when data format
            is NHWC. If ``axis`` < 0, it works the same way as :math:`axis + D` ,
            where D is the dimensions of ``x`` . Default is 1.
        name (str, optional): Name for the operation (optional, default is None).
            For more information, please refer to :ref:`api_guide_Name`.

    Shape:
        - input: :math:`(N, C_{in}, H_{in}, W_{in})`
        - output: :math:`(N, C_{out}, H_{out}, W_{out})`

    Examples:
        .. code-block:: python

            import paddle

            x = paddle.rand([1, 2, 3, 4])
            # [[[[0.5002636  0.22272532 0.17402348 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.02879342 0.88725346 0.61093384 0.38833922]]
            #   [[0.5231306  0.03807496 0.91661984 0.15602879]
            #    [0.666127   0.616567   0.30741522 0.24044901]
            #    [0.7142536  0.7351477  0.31588817 0.23782359]]]]
            m = paddle.nn.Maxout(groups=2)
            out = m(x)
            # [[[[0.5231306  0.22272532 0.91661984 0.2874594 ]
            #    [0.95313174 0.6228939  0.7129065  0.7087491 ]
            #    [0.7142536  0.88725346 0.61093384 0.38833922]]]]
    """

    def __init__(self, groups, axis=1, name=None):
        super(Maxout, self).__init__()
        self._groups = groups
        self._axis = axis
        self._name = name

    def forward(self, x):
        return F.maxout(x, self._groups, self._axis, self._name)
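The layer classes are thin wrappers over the functionals, so the two forms should produce identical results. A short dygraph sketch exercising paddle.nn.Maxout against F.maxout, with shapes following the Maxout docstring semantics:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.rand([2, 6, 5, 4])

    m = paddle.nn.Maxout(groups=2, axis=1)
    out_layer = m(x)
    out_func = F.maxout(x, groups=2, axis=1)

    print(np.allclose(out_layer.numpy(), out_func.numpy()))  # True
    print(out_layer.shape)  # [2, 3, 5, 4]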