Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ed102ea1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ed102ea1
编写于
8月 22, 2020
作者:
W
WangXi
提交者:
GitHub
8月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
【API】Add sign and tanh api (#26357)
上级
45711dad
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
202 addition
and
7 deletion
+202
-7
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+1
-4
python/paddle/fluid/layers/ops.py
python/paddle/fluid/layers/ops.py
+1
-1
python/paddle/fluid/tests/unittests/test_activation_op.py
python/paddle/fluid/tests/unittests/test_activation_op.py
+53
-0
python/paddle/fluid/tests/unittests/test_sign_op.py
python/paddle/fluid/tests/unittests/test_sign_op.py
+28
-0
python/paddle/nn/__init__.py
python/paddle/nn/__init__.py
+1
-0
python/paddle/nn/functional/__init__.py
python/paddle/nn/functional/__init__.py
+1
-0
python/paddle/nn/functional/activation.py
python/paddle/nn/functional/activation.py
+2
-0
python/paddle/nn/layer/activation.py
python/paddle/nn/layer/activation.py
+40
-0
python/paddle/tensor/math.py
python/paddle/tensor/math.py
+75
-2
未找到文件。
python/paddle/fluid/layers/nn.py
浏览文件 @
ed102ea1
...
...
@@ -14028,12 +14028,9 @@ def where(condition):
return out
@deprecated(since="2.0.0", update_to="paddle.sign")
def sign(x):
"""
:alias_main: paddle.sign
:alias: paddle.sign,paddle.tensor.sign,paddle.tensor.math.sign
:old_api: paddle.fluid.layers.sign
This OP returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero.
Args:
...
...
python/paddle/fluid/layers/ops.py
浏览文件 @
ed102ea1
...
...
@@ -28,11 +28,11 @@ __activations_noattr__ = [
'tanh_shrink'
,
'softplus'
,
'softsign'
,
'tanh'
,
]
__unary_func__
=
[
'exp'
,
'tanh'
,
'atan'
,
'sqrt'
,
'rsqrt'
,
...
...
python/paddle/fluid/tests/unittests/test_activation_op.py
浏览文件 @
ed102ea1
...
...
@@ -191,6 +191,59 @@ class TestTanh(TestActivation, TestParameter):
self
.
dtype
=
np
.
float32
class
TestTanhAPI
(
unittest
.
TestCase
):
# test paddle.tanh, paddle.nn.tanh, paddle.nn.functional.tanh
def
setUp
(
self
):
self
.
dtype
=
'float32'
self
.
x_np
=
np
.
random
.
uniform
(
-
1
,
1
,
[
10
,
12
]).
astype
(
self
.
dtype
)
self
.
place
=
paddle
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
()
\
else
paddle
.
CPUPlace
()
def
test_static_api
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
x
=
paddle
.
data
(
'X'
,
[
10
,
12
],
self
.
dtype
)
out1
=
F
.
tanh
(
x
)
th
=
paddle
.
nn
.
Tanh
()
out2
=
th
(
x
)
exe
=
paddle
.
static
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out1
,
out2
])
out_ref
=
np
.
tanh
(
self
.
x_np
)
for
r
in
res
:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
),
True
)
def
test_dygraph_api
(
self
):
paddle
.
disable_static
(
self
.
place
)
x
=
paddle
.
to_variable
(
self
.
x_np
)
out1
=
F
.
tanh
(
x
)
out2
=
paddle
.
tanh
(
x
)
th
=
paddle
.
nn
.
Tanh
()
out3
=
th
(
x
)
out_ref
=
np
.
tanh
(
self
.
x_np
)
for
r
in
[
out1
,
out2
,
out3
]:
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
r
.
numpy
()),
True
)
paddle
.
enable_static
()
def
test_fluid_api
(
self
):
with
fluid
.
program_guard
(
fluid
.
Program
()):
x
=
fluid
.
data
(
'X'
,
[
10
,
12
],
self
.
dtype
)
out
=
fluid
.
layers
.
tanh
(
x
)
exe
=
fluid
.
Executor
(
self
.
place
)
res
=
exe
.
run
(
feed
=
{
'X'
:
self
.
x_np
},
fetch_list
=
[
out
])
out_ref
=
np
.
tanh
(
self
.
x_np
)
self
.
assertEqual
(
np
.
allclose
(
out_ref
,
res
[
0
]),
True
)
def
test_errors
(
self
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
()):
# The input type must be Variable.
self
.
assertRaises
(
TypeError
,
F
.
tanh
,
1
)
# The input dtype must be float16, float32.
x_int32
=
paddle
.
data
(
name
=
'x_int32'
,
shape
=
[
12
,
10
],
dtype
=
'int32'
)
self
.
assertRaises
(
TypeError
,
F
.
tanh
,
x_int32
)
# support the input dtype is float16
x_fp16
=
paddle
.
data
(
name
=
'x_fp16'
,
shape
=
[
12
,
10
],
dtype
=
'float16'
)
F
.
tanh
(
x_fp16
)
class
TestAtan
(
TestActivation
,
TestParameter
):
def
setUp
(
self
):
self
.
op_type
=
"atan"
...
...
python/paddle/fluid/tests/unittests/test_sign_op.py
浏览文件 @
ed102ea1
...
...
@@ -17,6 +17,7 @@ from __future__ import print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
...
...
@@ -54,5 +55,32 @@ class TestSignOpError(unittest.TestCase):
fluid
.
layers
.
sign
(
input4
)
class
TestSignAPI
(
unittest
.
TestCase
):
def
test_dygraph
(
self
):
with
fluid
.
dygraph
.
guard
():
np_x
=
np
.
array
([
-
1.
,
0.
,
-
0.
,
1.2
,
1.5
],
dtype
=
'float64'
)
x
=
paddle
.
to_tensor
(
np_x
)
z
=
paddle
.
sign
(
x
)
np_z
=
z
.
numpy
()
z_expected
=
np
.
sign
(
np_x
)
self
.
assertEqual
((
np_z
==
z_expected
).
all
(),
True
)
def
test_static
(
self
):
with
program_guard
(
Program
(),
Program
()):
# The input type of sign_op must be Variable or numpy.ndarray.
input1
=
12
self
.
assertRaises
(
TypeError
,
paddle
.
tensor
.
math
.
sign
,
input1
)
# The input dtype of sign_op must be float16, float32, float64.
input2
=
fluid
.
layers
.
data
(
name
=
'input2'
,
shape
=
[
12
,
10
],
dtype
=
"int32"
)
input3
=
fluid
.
layers
.
data
(
name
=
'input3'
,
shape
=
[
12
,
10
],
dtype
=
"int64"
)
self
.
assertRaises
(
TypeError
,
paddle
.
tensor
.
math
.
sign
,
input2
)
self
.
assertRaises
(
TypeError
,
paddle
.
tensor
.
math
.
sign
,
input3
)
input4
=
fluid
.
layers
.
data
(
name
=
'input4'
,
shape
=
[
4
],
dtype
=
"float16"
)
paddle
.
sign
(
input4
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/nn/__init__.py
浏览文件 @
ed102ea1
...
...
@@ -54,6 +54,7 @@ from .decode import gather_tree #DEFINE_ALIAS
# from .input import Input #DEFINE_ALIAS
from
.layer.activation
import
ELU
from
.layer.activation
import
GELU
from
.layer.activation
import
Tanh
from
.layer.activation
import
Hardshrink
from
.layer.activation
import
Hardtanh
from
.layer.activation
import
PReLU
...
...
python/paddle/nn/functional/__init__.py
浏览文件 @
ed102ea1
...
...
@@ -50,6 +50,7 @@ from .activation import softplus #DEFINE_ALIAS
from
.activation
import
softshrink
#DEFINE_ALIAS
from
.activation
import
softsign
#DEFINE_ALIAS
from
.activation
import
swish
#DEFINE_ALIAS
from
.activation
import
tanh
#DEFINE_ALIAS
from
.activation
import
tanhshrink
#DEFINE_ALIAS
from
.activation
import
thresholded_relu
#DEFINE_ALIAS
from
.activation
import
log_softmax
#DEFINE_ALIAS
...
...
python/paddle/nn/functional/activation.py
浏览文件 @
ed102ea1
...
...
@@ -22,6 +22,7 @@ from ...fluid.layers import soft_relu #DEFINE_ALIAS
from
...fluid.layers
import
swish
#DEFINE_ALIAS
from
...fluid.layers
import
sigmoid
#DEFINE_ALIAS
from
...fluid.layers
import
thresholded_relu
#DEFINE_ALIAS
from
...tensor.math
import
tanh
#DEFINE_ALIAS
__all__
=
[
'brelu'
,
...
...
@@ -47,6 +48,7 @@ __all__ = [
'softsign'
,
'sigmoid'
,
'swish'
,
'tanh'
,
'tanhshrink'
,
'thresholded_relu'
,
'log_softmax'
,
...
...
python/paddle/nn/layer/activation.py
浏览文件 @
ed102ea1
...
...
@@ -18,6 +18,7 @@ __all__ = [
'ELU'
,
'GELU'
,
'Hardshrink'
,
'Tanh'
,
'Hardtanh'
,
'PReLU'
,
'ReLU'
,
...
...
@@ -182,6 +183,45 @@ class Hardshrink(layers.Layer):
return
F
.
hardshrink
(
x
,
self
.
_threshold
,
self
.
_name
)
class
Tanh
(
layers
.
Layer
):
"""
Tanh Activation.
.. math::
Tanh(x) =
\\
frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
Parameters:
name (str, optional): Name for the operation (optional, default is None).
For more information, please refer to :ref:`api_guide_Name`.
Shape:
- input: Tensor with any shape.
- output: Tensor with the same shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x = paddle.to_tensor(np.array([-0.4, -0.2, 0.1, 0.3]))
m = paddle.nn.Tanh()
out = m(x)
print(out.numpy())
# [-0.37994896 -0.19737532 0.09966799 0.29131261]
"""
def
__init__
(
self
,
name
=
None
):
super
(
Tanh
,
self
).
__init__
()
self
.
_name
=
name
def
forward
(
self
,
x
):
return
F
.
tanh
(
x
,
self
.
_name
)
class
Hardtanh
(
layers
.
Layer
):
"""
Hardtanh Activation
...
...
python/paddle/tensor/math.py
浏览文件 @
ed102ea1
...
...
@@ -51,14 +51,12 @@ from ..fluid.layers import reduce_sum #DEFINE_ALIAS
from
..fluid.layers
import
round
#DEFINE_ALIAS
from
..fluid.layers
import
rsqrt
#DEFINE_ALIAS
from
..fluid.layers
import
scale
#DEFINE_ALIAS
from
..fluid.layers
import
sign
#DEFINE_ALIAS
from
..fluid.layers
import
square
#DEFINE_ALIAS
from
..fluid.layers
import
stanh
#DEFINE_ALIAS
from
..fluid.layers
import
atan
#DEFINE_ALIAS
from
..fluid.layers
import
erf
#DEFINE_ALIAS
from
..fluid.layers
import
sqrt
#DEFINE_ALIAS
from
..fluid.layers
import
sin
#DEFINE_ALIAS
from
..fluid.layers
import
tanh
#DEFINE_ALIAS
from
..fluid.layers
import
increment
#DEFINE_ALIAS
from
..fluid.layers
import
multiplex
#DEFINE_ALIAS
...
...
@@ -1747,3 +1745,78 @@ def prod(x, axis=None, keepdim=False, dtype=None, name=None):
x
=
layers
.
cast
(
x
,
dtype
)
return
layers
.
reduce_prod
(
input
=
x
,
dim
=
axis
,
keep_dim
=
keepdim
,
name
=
name
)
def
sign
(
x
,
name
=
None
):
"""
This OP returns sign of every element in `x`: 1 for positive, -1 for negative and 0 for zero.
Args:
x(Tensor): The input tensor. The data type can be float16, float32 or float64.
name (str, optional): The default value is None. Normally there is no need for user to
set this property. For more information, please refer to :ref:`api_guide_Name`
Returns:
Tensor: The output sign tensor with identical shape and data type to the input :attr:`x`.
Examples:
.. code-block:: python
import numpy as np
import paddle
data = np.array([3.0, 0.0, -2.0, 1.7], dtype='float32')
paddle.disable_static()
x = paddle.to_tensor(data)
out = paddle.sign(x=x)
print(out) # [1.0, 0.0, -1.0, 1.0]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
sign
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'sign'
)
helper
=
LayerHelper
(
"sign"
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'sign'
,
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]})
return
out
def
tanh
(
x
,
name
=
None
):
"""
Tanh Activation Operator.
.. math::
out =
\\
frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}
Args:
x (Tensor): Input of Tanh operator, an N-D Tensor, with data type float32, float64 or float16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Output of Tanh operator, a Tensor with same data type and shape as input.
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
x_data = np.array([-0.4, -0.2, 0.1, 0.3])
x = paddle.to_tensor(x_data)
out = paddle.tanh(x)
print(out.numpy())
# [-0.37994896 -0.19737532 0.09966799 0.29131261]
"""
if
in_dygraph_mode
():
return
core
.
ops
.
tanh
(
x
)
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'tanh'
)
helper
=
LayerHelper
(
'tanh'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
x
.
dtype
)
helper
.
append_op
(
type
=
'tanh'
,
inputs
=
{
'X'
:
x
},
outputs
=
{
'Out'
:
out
})
return
out
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录