Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1f26dce6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1f26dce6
编写于
2月 09, 2018
作者:
C
Cao Ying
提交者:
GitHub
2月 09, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8302 from guoshengCS/add-python-layernorm
Add python wrapper for layer normalization.
上级
1185a1b5
9b743b85
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
131 addition
and
14 deletion
+131
-14
doc/api/v2/fluid/layers.rst
doc/api/v2/fluid/layers.rst
+6
-0
paddle/operators/layer_norm_op.cc
paddle/operators/layer_norm_op.cc
+0
-2
python/paddle/v2/fluid/__init__.py
python/paddle/v2/fluid/__init__.py
+21
-6
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+101
-4
python/paddle/v2/fluid/nets.py
python/paddle/v2/fluid/nets.py
+3
-2
未找到文件。
doc/api/v2/fluid/layers.rst
浏览文件 @
1f26dce6
...
...
@@ -323,6 +323,12 @@ batch_norm
.. autofunction:: paddle.v2.fluid.layers.batch_norm
:noindex:
layer_norm
----------
.. autofunction:: paddle.v2.fluid.layers.layer_norm
:noindex:
beam_search_decode
------------------
...
...
paddle/operators/layer_norm_op.cc
浏览文件 @
1f26dce6
...
...
@@ -116,8 +116,6 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
// check input
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of LayerNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Scale"
),
"Input(Scale) of LayerNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Mean"
),
"Input(Mean) of LayerNormOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Variance"
),
...
...
python/paddle/v2/fluid/__init__.py
浏览文件 @
1f26dce6
...
...
@@ -29,7 +29,7 @@ import optimizer
import
learning_rate_decay
import
backward
import
regularizer
from
param_attr
import
ParamAttr
from
param_attr
import
ParamAttr
,
WeightNormParamAttr
from
data_feeder
import
DataFeeder
from
core
import
LoDTensor
,
CPUPlace
,
CUDAPlace
from
distribute_transpiler
import
DistributeTranspiler
...
...
@@ -41,11 +41,26 @@ import profiler
Tensor
=
LoDTensor
__all__
=
framework
.
__all__
+
executor
.
__all__
+
[
'io'
,
'initializer'
,
'layers'
,
'nets'
,
'optimizer'
,
'learning_rate_decay'
,
'backward'
,
'regularizer'
,
'LoDTensor'
,
'CPUPlace'
,
'CUDAPlace'
,
'Tensor'
,
'ParamAttr'
'DataFeeder'
,
'clip'
,
'SimpleDistributeTranspiler'
,
'DistributeTranspiler'
,
'memory_optimize'
,
'profiler'
'io'
,
'initializer'
,
'layers'
,
'nets'
,
'optimizer'
,
'learning_rate_decay'
,
'backward'
,
'regularizer'
,
'LoDTensor'
,
'CPUPlace'
,
'CUDAPlace'
,
'Tensor'
,
'ParamAttr'
,
'WeightNormParamAttr'
,
'DataFeeder'
,
'clip'
,
'SimpleDistributeTranspiler'
,
'DistributeTranspiler'
,
'memory_optimize'
,
'profiler'
,
]
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
1f26dce6
...
...
@@ -65,6 +65,7 @@ __all__ = [
'beam_search'
,
'row_conv'
,
'multiplex'
,
'layer_norm'
,
]
...
...
@@ -641,8 +642,8 @@ def dynamic_gru(input,
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
Returns:
Variable: The hidden state of GRU. The shape is
(T
\\
times D), and lod
\
is the same with the input.
Variable: The hidden state of GRU. The shape is
:math:`(T
\\
times D)`,
\
and lod
is the same with the input.
Examples:
.. code-block:: python
...
...
@@ -990,7 +991,7 @@ def square_error_cost(input, label, **kwargs):
label(Variable): Label tensor, has target labels.
Returns:
Variable: The tensor variable storing the element-wise squared error
Variable: The tensor variable storing the element-wise squared error
\
difference of input and label.
Examples:
...
...
@@ -1214,7 +1215,7 @@ def conv2d(input,
act(str): Activation type. Default: None
Returns:
Variable: The tensor variable storing the convolution and
Variable: The tensor variable storing the convolution and
\
non-linearity activation result.
Raises:
...
...
@@ -1565,6 +1566,102 @@ def batch_norm(input,
return
helper
.
append_activation
(
batch_norm_out
)
def
layer_norm
(
input
,
scale
=
True
,
shift
=
True
,
begin_norm_axis
=
1
,
epsilon
=
1e-05
,
param_attr
=
None
,
bias_attr
=
None
,
act
=
None
,
name
=
None
):
"""
**Layer Normalization**
Assume feature vectors exist on dimensions
:attr:`begin_norm_axis ... rank(input)` and calculate the moment statistics
along these dimensions for each feature vector :math:`a` with size
:math:`H`, then normalize each feature vector using the corresponding
statistics. After that, apply learnable gain and bias on the normalized
tensor to scale and shift if :attr:`scale` and :attr:`shift` are set.
Refer to `Layer Normalization <https://arxiv.org/pdf/1607.06450v1.pdf>`_
The formula is as follows:
.. math::
\\
mu & =
\\
frac{1}{H}
\\
sum_{i=1}^{H} a_i
\\
sigma & =
\\
sqrt{
\\
frac{1}{H}\sum_{i=1}^{H}(a_i -
\\
mu)^2}
h & = f(
\\
frac{g}{
\\
sigma}(a -
\\
mu) + b)
Args:
input(Variable): The input tensor variable.
scale(bool): Whether to learn the adaptive gain :math:`g` after
normalization.
shift(bool): Whether to learn the adaptive bias :math:`b` after
normalization.
begin_norm_axis(bool): The normalization will be performed along
dimensions from :attr:`begin_norm_axis` to :attr:`rank(input)`.
epsilon(float): The small value added to the variance to prevent
division by zero.
param_attr(ParamAttr|None): The parameter attribute for the learnable
gain :math:`g`.
bias_attr(ParamAttr|None): The parameter attribute for the learnable
bias :math:`b`.
act(str): Activation to be applied to the output of layer normalizaiton.
Returns:
Variable: A tensor variable with the same shape as the input.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[3, 32, 32], dtype='float32')
x = fluid.layers.layer_norm(input=data, begin_norm_axis=1)
"""
helper
=
LayerHelper
(
'layer_norm'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
# create intput and parameters
inputs
=
{
'X'
:
input
}
input_shape
=
input
.
shape
param_shape
=
[
reduce
(
lambda
x
,
y
:
x
*
y
,
input_shape
[
begin_norm_axis
:])]
if
scale
:
scale
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
param_shape
,
dtype
=
dtype
,
default_initializer
=
Constant
(
1.0
))
inputs
[
'Scale'
]
=
scale
if
shift
:
assert
bias_attr
is
not
False
bias
=
helper
.
create_parameter
(
attr
=
helper
.
bias_attr
,
shape
=
param_shape
,
dtype
=
dtype
,
is_bias
=
True
)
inputs
[
'Bias'
]
=
bias
# create output
mean_out
=
helper
.
create_tmp_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
variance_out
=
helper
.
create_tmp_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
layer_norm_out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
type
=
"layer_norm"
,
inputs
=
inputs
,
outputs
=
{
"Y"
:
layer_norm_out
,
"Mean"
:
mean_out
,
"Variance"
:
variance_out
,
},
attrs
=
{
"epsilon"
:
epsilon
,
"begin_norm_axis"
:
begin_norm_axis
})
return
helper
.
append_activation
(
layer_norm_out
)
def
beam_search_decode
(
ids
,
scores
,
name
=
None
):
helper
=
LayerHelper
(
'beam_search_decode'
,
**
locals
())
sentence_ids
=
helper
.
create_tmp_variable
(
dtype
=
ids
.
dtype
)
...
...
python/paddle/v2/fluid/nets.py
浏览文件 @
1f26dce6
...
...
@@ -194,7 +194,7 @@ def scaled_dot_product_attention(queries,
Returns:
Variable: A 3-D Tensor computed by multi-head scaled dot product
Variable: A 3-D Tensor computed by multi-head scaled dot product
\
attention.
Raises:
...
...
@@ -333,6 +333,7 @@ def scaled_dot_product_attention(queries,
x
=
product
,
shape
=
[
-
1
,
product
.
shape
[
-
1
]],
act
=
"softmax"
),
shape
=
product
.
shape
)
if
dropout_rate
:
weights
=
layers
.
dropout
(
x
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
weights
=
layers
.
dropout
(
weights
,
dropout_prob
=
dropout_rate
,
is_test
=
False
)
ctx_multiheads
=
layers
.
matmul
(
weights
,
v
)
return
__combine_heads
(
ctx_multiheads
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录