BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)

Commit 6dd9901b (unverified)
Authored by zhupengyang on Apr 11, 2020; committer: GitHub, Apr 11, 2020.

add activation ops under paddle.nn and paddle.nn.functional: ReLU, LogSoftmax (#23258)

Parent: 03073937
Showing 8 changed files with 454 additions and 37 deletions.
python/paddle/fluid/tests/unittests/test_activation_op.py   +71  -3
python/paddle/fluid/tests/unittests/test_log_softmax.py     +107 -0
python/paddle/nn/__init__.py                                +4   -4
python/paddle/nn/functional/__init__.py                     +3   -2
python/paddle/nn/functional/activation.py                   +163 -25
python/paddle/nn/layer/__init__.py                          +2   -0
python/paddle/nn/layer/activation.py                        +104 -1
python/setup.py.in                                          +0   -2
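Taken together, the patch exposes two layer classes (`paddle.nn.ReLU`, `paddle.nn.LogSoftmax`) and two functional forms (`paddle.nn.functional.relu`, `paddle.nn.functional.log_softmax`). A minimal sketch of the new surface in dygraph mode, assuming a Paddle build of this era that includes the commit (pre-2.0, fluid-style API); the expected ReLU output comes from the patch's own docstrings, the log_softmax values are approximate:

```python
import numpy as np
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F

data = np.array([-2.0, 0.0, 1.0]).astype('float32')
with fluid.dygraph.guard():
    x = fluid.dygraph.to_variable(data)
    print(nn.ReLU()(x).numpy())          # layer form:      [0. 0. 1.]
    print(F.relu(x).numpy())             # functional form: [0. 0. 1.]
    print(F.log_softmax(x, -1).numpy())  # ~[-3.3491 -1.3491 -0.3491]
```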
python/paddle/fluid/tests/unittests/test_activation_op.py
```diff
@@ -21,6 +21,8 @@ from op_test import OpTest
 from scipy.special import expit, erf
 import paddle
 import paddle.fluid as fluid
+import paddle.nn as nn
+import paddle.nn.functional as functional
 from paddle.fluid import compiler, Program, program_guard
```
```diff
@@ -759,9 +761,6 @@ class TestPow_factor_tensor(TestActivation):
         self.check_grad(['X'], 'Out')

     def test_api(self):
-        import paddle
-        import paddle.fluid as fluid
-
         input = np.random.uniform(1, 2, [11, 17]).astype("float32")
         x = fluid.layers.data(
             name="x", shape=[11, 17], append_batch_size=False, dtype="float32")
```
@@ -1003,5 +1002,74 @@ create_test_act_fp16_class(TestHardSigmoid)

The hunk appends two new test classes after the existing FP16 cases:

```python
create_test_act_fp16_class(TestSwish)
create_test_act_fp16_class(TestHardSwish)


class TestNNReluAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [10, 12]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
        self.y = self.ref_forward(self.x)

    def ref_forward(self, x):
        return np.maximum(x, 0)

    def ref_backward(self, y, dy):
        y_t = y.copy()
        y_t[y_t > 0] = 1
        return y_t * dy

    def check_api(self, place=fluid.CPUPlace(), inplace=False):
        main_program = Program()
        myrelu = nn.ReLU(inplace)
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            x.stop_gradient = False
            y = myrelu(x)
            fluid.backward.append_backward(fluid.layers.mean(y))
        exe = fluid.Executor(place)
        out = exe.run(main_program,
                      feed={'x': self.x},
                      fetch_list=[y, y.grad_name, x.grad_name])
        self.assertTrue(np.allclose(out[0], self.y))
        self.assertTrue(np.allclose(out[2], self.ref_backward(self.y, out[1])))

        with fluid.dygraph.guard(place):
            x = fluid.dygraph.to_variable(self.x)
            y = myrelu(x)
            self.assertTrue(np.allclose(y.numpy(), self.y))

    def test_check_api(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            for inplace in [True, False]:
                self.check_api(place, inplace)


class TestNNFunctionalReluAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [10, 12]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)
        self.y = self.ref_forward(self.x)

    def ref_forward(self, x):
        return np.maximum(x, 0)

    def test_check_api(self):
        main_program = Program()
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            y = functional.relu(x)
        exe = fluid.Executor(fluid.CPUPlace())
        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
        self.assertTrue(np.allclose(out[0], self.y))


if __name__ == "__main__":
    unittest.main()
```
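For readers skimming the test, the reference gradient in `ref_backward` is just the ReLU mask applied to the upstream gradient: the derivative is 1 where the output is positive and 0 elsewhere. A standalone NumPy sketch of the same identity (not part of the patch):

```python
import numpy as np

x = np.array([-1.5, 0.0, 2.0], dtype=np.float32)
y = np.maximum(x, 0)            # forward: max(x, 0)
dy = np.ones_like(y)            # upstream gradient, e.g. from a mean loss
mask = (y > 0).astype(y.dtype)  # gradient flows only where the output is positive
print(mask * dy)                # [0. 0. 1.]
```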
python/paddle/fluid/tests/unittests/test_log_softmax.py (new file, mode 100644)
```python
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.nn as nn
import paddle.nn.functional as functional


def stable_softmax(x):
    shiftx = (x - np.max(x))
    exps = np.exp(shiftx)
    return exps / np.sum(exps)


def ref_log_softmax(x, axis=None, dtype=None):
    x_t = x.copy()
    if dtype is not None:
        x_t = x_t.astype(dtype)
    if axis is None:
        axis = -1
    out = np.apply_along_axis(stable_softmax, axis, x_t)
    return np.log(out)


class TestNNLogSoftmaxAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [2, 3, 4, 5]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)

    def check_api(self, place=fluid.CPUPlace(), axis=None):
        ref_out = ref_log_softmax(self.x, axis)

        main_program = fluid.Program()
        mylogsoftmax = nn.LogSoftmax(axis)
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            y = mylogsoftmax(x)
        exe = fluid.Executor(place)
        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
        self.assertTrue(np.allclose(out[0], ref_out))

        with fluid.dygraph.guard(place):
            x = fluid.dygraph.to_variable(self.x)
            y = mylogsoftmax(x)
            self.assertTrue(np.allclose(y.numpy(), ref_out))

    def test_check_api(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            for axis in [None, 2]:
                self.check_api(place, axis)


class TestNNFunctionalLogSoftmaxAPI(unittest.TestCase):
    def setUp(self):
        self.init_data()

    def init_data(self):
        self.x_shape = [2, 3, 4, 5]
        self.x = np.random.uniform(-1, 1, self.x_shape).astype(np.float32)

    def check_api(self, place=fluid.CPUPlace(), axis=None, dtype=None):
        ref_out = ref_log_softmax(self.x, axis, dtype)
        main_program = fluid.Program()
        mylogsoftmax = nn.LogSoftmax(axis)
        with fluid.program_guard(main_program):
            x = fluid.data(name='x', shape=self.x_shape)
            y = functional.log_softmax(x, axis, dtype)
        exe = fluid.Executor(place)
        out = exe.run(main_program, feed={'x': self.x}, fetch_list=[y])
        self.assertTrue(np.allclose(out[0], ref_out))

        with fluid.dygraph.guard(place):
            x = fluid.dygraph.to_variable(self.x)
            y = functional.log_softmax(x, axis, dtype)
            self.assertTrue(np.allclose(y.numpy(), ref_out))

    def test_check_api(self):
        places = [fluid.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        for place in places:
            self.check_api(place, None, None)
            self.check_api(place, None, np.float64)


if __name__ == "__main__":
    unittest.main()
```
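The `stable_softmax` helper subtracts the maximum before exponentiating; the shift cancels in the softmax ratio, so the value is unchanged while overflow is avoided. A quick demonstration (not part of the patch):

```python
import numpy as np

x = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax overflows: exp(1000) is inf, and inf/inf is nan.
naive = np.exp(x) / np.sum(np.exp(x))

# Shifted version computes the same quantity with finite intermediates.
shifted = np.exp(x - np.max(x))
stable = shifted / np.sum(shifted)

print(naive)   # [nan nan nan], with overflow warnings
print(stable)  # [0.09003057 0.24472847 0.66524096]
```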
python/paddle/nn/__init__.py
```diff
@@ -81,10 +81,10 @@ from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFIN
 from .layer.norm import InstanceNorm  #DEFINE_ALIAS
 # from .layer.norm import SpectralNorm  #DEFINE_ALIAS
 # from .layer.activation import PReLU  #DEFINE_ALIAS
-# from .layer.activation import ReLU  #DEFINE_ALIAS
+from .layer.activation import ReLU  #DEFINE_ALIAS
 # from .layer.activation import Sigmoid  #DEFINE_ALIAS
 # from .layer.activation import Softmax  #DEFINE_ALIAS
-# from .layer.activation import LogSoftmax  #DEFINE_ALIAS
+from .layer.activation import LogSoftmax  #DEFINE_ALIAS
 # from .layer.rnn import RNNCell  #DEFINE_ALIAS
 # from .layer.rnn import GRUCell  #DEFINE_ALIAS
 # from .layer.rnn import LSTMCell  #DEFINE_ALIAS
@@ -189,7 +189,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
 # from .functional.activation import logsigmoid  #DEFINE_ALIAS
 # from .functional.activation import maxout  #DEFINE_ALIAS
 # from .functional.activation import prelu  #DEFINE_ALIAS
-# from .functional.activation import relu  #DEFINE_ALIAS
+from .functional.activation import relu  #DEFINE_ALIAS
 # from .functional.activation import relu6  #DEFINE_ALIAS
 # from .functional.activation import selu  #DEFINE_ALIAS
 # from .functional.activation import sigmoid  #DEFINE_ALIAS
@@ -201,7 +201,7 @@ from .functional.conv import conv3d_transpose #DEFINE_ALIAS
 # from .functional.activation import swish  #DEFINE_ALIAS
 # from .functional.activation import tanh_shrink  #DEFINE_ALIAS
 # from .functional.activation import thresholded_relu  #DEFINE_ALIAS
-# from .functional.activation import log_softmax  #DEFINE_ALIAS
+from .functional.activation import log_softmax  #DEFINE_ALIAS
 # from .functional.extension import add_position_encoding  #DEFINE_ALIAS
 # from .functional.extension import autoincreased_step_counter  #DEFINE_ALIAS
 # from .functional.extension import continuous_value_model  #DEFINE_ALIAS
```
python/paddle/nn/functional/__init__.py
```diff
@@ -102,6 +102,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .vision import space_to_depth  #DEFINE_ALIAS
 # from .vision import yolo_box  #DEFINE_ALIAS
 # from .vision import yolov3_loss  #DEFINE_ALIAS
+from . import activation
 # from .activation import brelu  #DEFINE_ALIAS
 # from .activation import elu  #DEFINE_ALIAS
 # from .activation import erf  #DEFINE_ALIAS
@@ -114,7 +115,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .activation import logsigmoid  #DEFINE_ALIAS
 # from .activation import maxout  #DEFINE_ALIAS
 # from .activation import prelu  #DEFINE_ALIAS
-# from .activation import relu  #DEFINE_ALIAS
+from .activation import relu  #DEFINE_ALIAS
 # from .activation import relu6  #DEFINE_ALIAS
 # from .activation import selu  #DEFINE_ALIAS
 # from .activation import sigmoid  #DEFINE_ALIAS
@@ -126,7 +127,7 @@ from .conv import conv3d_transpose #DEFINE_ALIAS
 # from .activation import swish  #DEFINE_ALIAS
 # from .activation import tanh_shrink  #DEFINE_ALIAS
 # from .activation import thresholded_relu  #DEFINE_ALIAS
-# from .activation import log_softmax  #DEFINE_ALIAS
+from .activation import log_softmax  #DEFINE_ALIAS
 # from .extension import add_position_encoding  #DEFINE_ALIAS
 # from .extension import autoincreased_step_counter  #DEFINE_ALIAS
 # from .extension import continuous_value_model  #DEFINE_ALIAS
```
python/paddle/nn/functional/activation.py
@@ -12,29 +12,167 @@

The hunk replaces the file's commented-out `__all__` placeholder with real imports, an export list, and the two new functions:

```python
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode, convert_np_dtype_to_dtype_
from ...fluid import core

# TODO: define activation functions of neural network
__all__ = [
    # 'brelu',
    # 'elu',
    # 'erf',
    # 'gelu',
    # 'hard_shrink',
    # 'hard_sigmoid',
    # 'hard_swish',
    # 'hsigmoid',
    # 'leaky_relu',
    # 'logsigmoid',
    # 'maxout',
    # 'prelu',
    'relu',
    # 'relu6',
    # 'selu',
    # 'sigmoid',
    # 'soft_relu',
    # 'softmax',
    # 'softplus',
    # 'softshrink',
    # 'softsign',
    # 'swish',
    # 'tanh_shrink',
    # 'thresholded_relu',
    'log_softmax',
]


def relu(input, inplace=False, name=None):
    """
    ReLU Activation.

    .. math::

        out = max(x, 0)

    Parameters:
        input (Variable): The input variable. A multi-dimensional Tensor with type float16, float32, or float64.
        inplace (bool, optional): If inplace is True, the input and output of ``ReLU`` are the same variable.
            Otherwise, the input and output of ``ReLU`` are different variables. Default: False. Note that if x
            is the input of more than one OP, inplace must be False.
        name (str, optional): The default value is None. Normally there is no need for users to set this property.
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Output of the relu operator, a Tensor with the same shape as the input.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn.functional as functional
            import numpy as np

            data = np.array([-2, 0, 1]).astype('float32')
            with fluid.dygraph.guard():
                data = fluid.dygraph.to_variable(data)
                res = functional.relu(data)  # [0, 0, 1]
    """

    if in_dygraph_mode():
        if inplace:
            warnings.warn(
                "Inplace on ReLU is not allowed and will be discarded in dygraph mode currently."
            )
        return core.ops.relu(input)

    helper = LayerHelper('relu', **locals())
    outs = input if inplace else helper.create_variable_for_type_inference(
        input.dtype)
    helper.append_op(type='relu', inputs={'X': [input]}, outputs={'Out': outs})
    return outs


def log_softmax(input, axis=None, dtype=None, name=None):
    """
    This operator implements the log_softmax layer. The calculation process is as follows:

    .. math::

        Out[i, j] = log(softmax(x))
                  = \\log\\left(\\frac{\\exp(X[i, j])}{\\sum_j \\exp(X[i, j])}\\right)

    Parameters:
        input (Variable): The input variable. A multi-dimensional Tensor with type float32 or float64.
        axis (int, optional): The index of the dimension to perform softmax calculations on. It should be in
            range :math:`[-1, rank-1]`, where :math:`rank` is the rank of the input variable. Default: None.
            None and -1 mean the last dimension.
        dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of the returned tensor. If specified,
            the input tensor is cast to dtype before the operation is performed. This is useful for
            preventing data type overflows. Default: None. Supported dtype: float32 or float64.
        name (str, optional): The default value is None. Normally there is no need for users to set this property.
            For more information, please refer to :ref:`api_guide_Name`.

    Returns:
        Variable: ``Tensor`` indicating the output of softmax. The data type and shape are the same as ``input``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn.functional as F
            import numpy as np

            data = np.array([[[-2.0, 3.0, -4.0, 5.0],
                              [3.0, -4.0, 5.0, -6.0],
                              [-7.0, -8.0, 8.0, 9.0]],
                             [[1.0, -2.0, -3.0, 4.0],
                              [-5.0, 6.0, 7.0, -8.0],
                              [6.0, 7.0, 8.0, 9.0]]]).astype('float32')
            with fluid.dygraph.guard():
                data = fluid.dygraph.to_variable(data)
                res = F.log_softmax(data, -1)
                # [[[ -7.1278396   -2.1278396   -9.127839    -0.12783948]
                #   [ -2.1270514   -9.127051    -0.12705144 -11.127051  ]
                #   [-16.313261   -17.313261    -1.3132617   -0.31326184]]
                #  [[ -3.0518122   -6.051812    -7.051812    -0.051812  ]
                #   [-12.313267    -1.3132664   -0.3132665  -15.313267  ]
                #   [ -3.4401896   -2.4401896   -1.4401896   -0.44018966]]]
    """

    axis = -1 if axis is None else axis
    dtype = convert_np_dtype_to_dtype_(dtype) if dtype is not None else dtype

    if in_dygraph_mode():
        outs_cast = input if dtype is None \
            else core.ops.cast(input, 'in_dtype', input.dtype, 'out_dtype', dtype)
        outs_softmax = core.ops.softmax(outs_cast, 'axis', axis, 'use_cudnn',
                                        False)
        return core.ops.log(outs_softmax)

    helper = LayerHelper("log_softmax", **locals())

    outs_cast = input
    if dtype is not None:
        outs_cast = helper.create_variable_for_type_inference(dtype)
        helper.append_op(
            type='cast',
            inputs={'X': input},
            outputs={'Out': outs_cast},
            attrs={'in_dtype': input.dtype,
                   'out_dtype': dtype})

    outs_softmax = helper.create_variable_for_type_inference(outs_cast.dtype)
    helper.append_op(
        type='softmax',
        inputs={'X': outs_cast},
        outputs={'Out': outs_softmax},
        attrs={'axis': axis,
               'use_cudnn': False})

    outs_log = helper.create_variable_for_type_inference(outs_softmax.dtype)
    helper.append_op(
        type='log', inputs={'X': outs_softmax}, outputs={'Out': outs_log})

    return outs_log
```
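The static-graph branch composes three ops: an optional cast, a softmax along `axis`, and an elementwise log. As a sanity check (not part of the patch), the same composition in NumPy reduces to the familiar `x - logsumexp(x)` form:

```python
import numpy as np

x = np.random.uniform(-1, 1, [2, 3]).astype(np.float32)

# softmax -> log along the last axis, as chained by the static-graph branch
e = np.exp(x - x.max(axis=-1, keepdims=True))
log_softmax = np.log(e / e.sum(axis=-1, keepdims=True))

# closed form: x - logsumexp(x), computed with the same max shift
lse = x.max(axis=-1, keepdims=True) + np.log(e.sum(axis=-1, keepdims=True))
assert np.allclose(log_softmax, x - lse, atol=1e-6)
```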
python/paddle/nn/layer/__init__.py
```diff
@@ -14,10 +14,12 @@
 # TODO: define activation functions of neural network
+from . import activation
 from . import loss
 from . import conv
 from . import norm

+from .activation import *
 from .loss import *
 from .conv import *
 from .norm import *
```
python/paddle/nn/layer/activation.py
@@ -12,5 +12,108 @@

The hunk replaces the file's commented-out `__all__` placeholder with real imports, an export list, and the two new layer classes:

```python
# See the License for the specific language governing permissions and
# limitations under the License.

from ...fluid.dygraph import layers
from ...fluid import core
from ...fluid.framework import in_dygraph_mode
from .. import functional

# TODO: define activation functions of neural network
__all__ = [
    # 'PReLU',
    'ReLU',
    # 'Sigmoid',
    # 'Softmax',
    'LogSoftmax',
]


class ReLU(layers.Layer):
    """
    ReLU Activation.

    .. math::

        out = max(x, 0)

    Parameters:
        inplace (bool, optional): If inplace is True, the input and output of
            ``ReLU`` are the same variable. Otherwise, the input and output of
            ``ReLU`` are different variables. Default: False. Note that if x
            is the input of more than one OP, inplace must be False.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn as nn
            import numpy as np

            data = np.array([-2, 0, 1]).astype('float32')
            my_relu = nn.ReLU()
            with fluid.dygraph.guard():
                data = fluid.dygraph.to_variable(data)
                res = my_relu(data)  # [0, 0, 1]
    """

    def __init__(self, inplace=False):
        super(ReLU, self).__init__()
        self._inplace = inplace

    def forward(self, input):
        return functional.relu(input, self._inplace)


class LogSoftmax(layers.Layer):
    """
    This operator implements the log_softmax layer. The calculation process is as follows:

    .. math::

        Out[i, j] = log(softmax(x))
                  = \\log\\left(\\frac{\\exp(X[i, j])}{\\sum_j \\exp(X[i, j])}\\right)

    Parameters:
        axis (int, optional): The index of the dimension to perform softmax calculations on. It should be in
            range :math:`[-1, rank-1]`, where :math:`rank` is the rank of the input variable. Default: None.
            None and -1 mean the last dimension.
        dtype (np.dtype|core.VarDesc.VarType|str): The desired data type of the returned tensor. If specified,
            the input tensor is cast to dtype before the operation is performed. This is useful for
            preventing data type overflows. Default: None. Supported dtype: float32 or float64.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle.nn as nn
            import numpy as np

            data = np.array([[[-2.0, 3.0, -4.0, 5.0],
                              [3.0, -4.0, 5.0, -6.0],
                              [-7.0, -8.0, 8.0, 9.0]],
                             [[1.0, -2.0, -3.0, 4.0],
                              [-5.0, 6.0, 7.0, -8.0],
                              [6.0, 7.0, 8.0, 9.0]]]).astype('float32')
            my_log_softmax = nn.LogSoftmax()
            with fluid.dygraph.guard():
                data = fluid.dygraph.to_variable(data)
                res = my_log_softmax(data)
                # [[[ -7.1278396   -2.1278396   -9.127839    -0.12783948]
                #   [ -2.1270514   -9.127051    -0.12705144 -11.127051  ]
                #   [-16.313261   -17.313261    -1.3132617   -0.31326184]]
                #  [[ -3.0518122   -6.051812    -7.051812    -0.051812  ]
                #   [-12.313267    -1.3132664   -0.3132665  -15.313267  ]
                #   [ -3.4401896   -2.4401896   -1.4401896   -0.44018966]]]
    """

    def __init__(self, axis=None):
        super(LogSoftmax, self).__init__()
        self._axis = axis

    def forward(self, input):
        return functional.log_softmax(input, self._axis)
```
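The layer classes are thin wrappers over the functional forms, so they also work when building static graphs, which is exactly what the unit tests exercise. A minimal static-graph sketch (not part of the patch), assuming a Paddle build containing this commit:

```python
import numpy as np
import paddle.fluid as fluid
import paddle.nn as nn

main_program = fluid.Program()
with fluid.program_guard(main_program):
    x = fluid.data(name='x', shape=[2, 3])
    y = nn.LogSoftmax(axis=-1)(x)  # appends softmax + log ops to the program

exe = fluid.Executor(fluid.CPUPlace())
data = np.random.uniform(-1, 1, [2, 3]).astype('float32')
out, = exe.run(main_program, feed={'x': data}, fetch_list=[y])
print(out)  # each row is a log-probability vector: exp(out).sum(-1) is ~1
```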
python/setup.py.in
```diff
@@ -105,8 +105,6 @@ write_version_py(filename='@PADDLE_BINARY_DIR@/python/paddle/version.py')
 packages=['paddle',
           'paddle.nn',
           'paddle.nn.layer',
           'paddle.libs',
           'paddle.utils',
           'paddle.dataset',
```