Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ba22624d
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ba22624d
编写于
10月 29, 2018
作者:
G
gmcather
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
position encoding && log loss
test=develop
上级
e3701ad7
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
436 addition
and
0 deletion
+436
-0
paddle/fluid/API.spec
paddle/fluid/API.spec
+2
-0
paddle/fluid/operators/add_position_encoding_op.cc
paddle/fluid/operators/add_position_encoding_op.cc
+97
-0
paddle/fluid/operators/add_position_encoding_op.h
paddle/fluid/operators/add_position_encoding_op.h
+105
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+98
-0
python/paddle/fluid/tests/unittests/test_add_position_encoding_op.py
...le/fluid/tests/unittests/test_add_position_encoding_op.py
+134
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
ba22624d
...
@@ -177,6 +177,8 @@ paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, k
...
@@ -177,6 +177,8 @@ paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, k
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.hash ArgSpec(args=['input', 'hash_size', 'num_hash', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.log_loss ArgSpec(args=['input', 'label', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(0.0001, None))
paddle.fluid.layers.add_position_encoding ArgSpec(args=['input', 'alpha', 'beta', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None)
...
...
paddle/fluid/operators/add_position_encoding_op.cc
0 → 100644
浏览文件 @
ba22624d
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/add_position_encoding_op.h"
namespace
paddle
{
namespace
operators
{
class
AddPositionEncodingOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of add_position_encoding_op should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Out(Output) of add_position_encoding_op should not be null."
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
};
class
AddPositionEncodingOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Out"
),
"Out must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Out@GRAD must not be null."
);
auto
out_dims
=
ctx
->
GetInputDim
(
"Out"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
out_dims
);
}
}
};
class
AddPositionEncodingOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"Input of AddPositionEncoding operator"
);
AddOutput
(
"Out"
,
"Output of AddPositionEncoding operator"
);
AddAttr
<
float
>
(
"alpha"
,
"The scale of Original Embedding."
)
.
SetDefault
(
1.0
f
)
.
AddCustomChecker
([](
const
float
&
alpha
)
{
PADDLE_ENFORCE
(
alpha
>=
0.0
f
,
"'alpha' must be above 0.0."
);
});
AddAttr
<
float
>
(
"beta"
,
"The scale of Position Embedding."
)
.
SetDefault
(
1.0
f
)
.
AddCustomChecker
([](
const
float
&
beta
)
{
PADDLE_ENFORCE
(
beta
>=
0.0
f
,
"'beta' must be between 0.0."
);
});
AddComment
(
R"DOC(
Add Position Encoding Operator.
The add position encoding calculates the output based on the input, alpha, beta.
The size of each dimension of the parameters checked in the infer-shape.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plt
=
paddle
::
platform
;
REGISTER_OPERATOR
(
add_position_encoding
,
ops
::
AddPositionEncodingOp
,
ops
::
AddPositionEncodingOpMaker
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
add_position_encoding_grad
,
ops
::
AddPositionEncodingOpGrad
);
REGISTER_OP_CPU_KERNEL
(
add_position_encoding
,
ops
::
AddPositionEncodingKernel
<
plt
::
CPUDeviceContext
,
float
>
,
ops
::
AddPositionEncodingKernel
<
plt
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
add_position_encoding_grad
,
ops
::
AddPositionEncodingGradKernel
<
plt
::
CPUDeviceContext
,
float
>
,
ops
::
AddPositionEncodingGradKernel
<
plt
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/add_position_encoding_op.h
0 → 100644
浏览文件 @
ba22624d
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
AddPositionEncodingKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
X
=
context
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
&
x_lod
=
X
->
lod
();
auto
*
src_ptr
=
X
->
data
<
T
>
();
auto
*
Out
=
context
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
dst_ptr
=
Out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
float
alpha
=
context
.
Attr
<
float
>
(
"alpha"
);
float
beta
=
context
.
Attr
<
float
>
(
"beta"
);
auto
x_dim
=
X
->
dims
();
int
batch_size
=
0
;
int
max_seq_len
=
0
;
int
enc_size
=
0
;
if
(
x_lod
.
empty
())
{
PADDLE_ENFORCE
(
x_dim
.
size
()
==
3UL
,
"The input X of Add Position Encoding should be 3-D Tensor!"
);
batch_size
=
x_dim
[
0
];
max_seq_len
=
x_dim
[
1
];
enc_size
=
x_dim
[
2
];
}
else
{
PADDLE_ENFORCE
(
x_dim
.
size
()
==
2UL
,
"The input X of Add Position Encoding should be 2-D LoDTensor!"
);
PADDLE_ENFORCE
(
x_lod
.
size
()
==
1UL
,
"The Add Position Encoding Op only supports lod_level == 1!"
);
batch_size
=
x_lod
[
0
].
size
()
-
1
;
max_seq_len
=
-
1
;
enc_size
=
x_dim
[
1
];
}
PADDLE_ENFORCE
(
enc_size
%
2
==
0
,
"Only support even encode size!"
);
const
int
half_size
=
enc_size
/
2
;
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
const
int
max_length
=
x_lod
.
empty
()
?
max_seq_len
:
x_lod
[
0
][
i
+
1
]
-
x_lod
[
0
][
i
];
for
(
int
j
=
0
;
j
<
max_length
;
++
j
)
{
for
(
int
k
=
0
;
k
<
half_size
;
++
k
)
{
const
double
val
=
(
half_size
>
1
)
?
j
/
pow
(
10000.0
,
double
(
k
)
/
(
half_size
-
1
))
:
j
/
10000.0
;
dst_ptr
[
k
]
=
src_ptr
[
k
]
*
alpha
+
sin
(
val
)
*
beta
;
dst_ptr
[
half_size
+
k
]
=
src_ptr
[
half_size
+
k
]
*
alpha
+
cos
(
val
)
*
beta
;
}
src_ptr
+=
enc_size
;
dst_ptr
+=
enc_size
;
}
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
AddPositionEncodingGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
dOut
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
dout
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dOut
);
auto
*
dX
=
context
.
Output
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
dX
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
dx
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
dX
);
float
alpha
=
context
.
Attr
<
float
>
(
"alpha"
);
auto
*
place
=
context
.
template
device_context
<
DeviceContext
>().
eigen_device
();
dx
.
device
(
*
place
)
=
dout
*
static_cast
<
T
>
(
alpha
);
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
ba22624d
...
@@ -157,6 +157,8 @@ __all__ = [
...
@@ -157,6 +157,8 @@ __all__ = [
'sequence_reverse'
,
'sequence_reverse'
,
'affine_channel'
,
'affine_channel'
,
'hash'
,
'hash'
,
'log_loss'
,
'add_position_encoding'
,
]
]
...
@@ -7580,3 +7582,99 @@ def hash(input, hash_size, num_hash=1, name=None):
...
@@ -7580,3 +7582,99 @@ def hash(input, hash_size, num_hash=1, name=None):
attrs
=
{
'num_hash'
:
num_hash
,
attrs
=
{
'num_hash'
:
num_hash
,
'mod_by'
:
hash_size
})
'mod_by'
:
hash_size
})
return
out
return
out
def
log_loss
(
input
,
label
,
epsilon
=
1e-4
,
name
=
None
):
"""
**Negative Log Loss Layer**
This layer accepts input predictions and target label and returns the
negative log loss.
.. math::
Out = -label *
\\
log{(input +
\\
epsilon)}
- (1 - label) *
\\
log{(1 - input +
\\
epsilon)}
Args:
input (Variable|list): a 2-D tensor with shape [N x 1], where N is the
batch size. This input is a probability computed
by the previous operator.
label (Variable|list): the ground truth which is a 2-D tensor with
shape [N x 1], where N is the batch size.
epsilon (float): epsilon
name (string): the name of log_loss
Returns:
Variable: A 2-D tensor with shape [N x 1], the negative log loss.
Examples:
.. code-block:: python
prob = fluid.layers.sigmoid(net)
cost = fluid.layers.log_loss(input=prob, label=label)
"""
helper
=
LayerHelper
(
'log_loss'
,
**
locals
())
if
name
is
None
:
loss
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
else
:
loss
=
helper
.
create_variable
(
name
=
name
,
dtype
=
input
.
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
'log_loss'
,
inputs
=
{
'Predicted'
:
[
input
],
'Labels'
:
[
label
]},
outputs
=
{
'Loss'
:
[
loss
]},
attrs
=
{
'epsilon'
:
epsilon
})
return
loss
def
add_position_encoding
(
input
,
alpha
,
beta
,
name
=
None
):
"""
**Add Position Encoding Layer**
This layer accepts an input 3D-Tensor of shape [N x M x P], and return an
output Tensor of shape [N x M x P] with positional encoding value.
Refer to `Attention Is All You Need<http://arxiv.org/pdf/1706.03762.pdf>`_ .
.. math::
PE(pos, 2i) =
\\
sin{(pos / 10000^{2i / P})}
\\\\
PE(pos, 2i + 1) =
\\
cos{(pos / 10000^{2i / P})}
\\\\
Out(:, pos, i) =
\\
alpha * input(:, pos, i) +
\\
beta * PE(pos, i)
Where:
* PE(pos, 2i): the increment for the number at even position
* PE(pos, 2i + 1): the increment for the number at odd position
Args:
input (Variable): 3-D input tensor with shape [N x M x P]
alpha (float): multiple of Input Tensor
beta (float): multiple of Positional Encoding Tensor
name (string): the name of position encoding layer
Returns:
Variable: A 3-D Tensor of shape [N x M x P] with positional encoding.
Examples:
.. code-block:: python
position_tensor = fluid.layers.add_position_encoding(input=tensor)
"""
helper
=
LayerHelper
(
'add_position_encoding'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
if
name
is
None
:
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
dtype
)
else
:
out
=
helper
.
create_variable
(
name
=
name
,
dtype
=
dtype
,
persistable
=
False
)
helper
.
append_op
(
type
=
"add_position_encoding"
,
inputs
=
{
"X"
:
input
},
outputs
=
{
"Out"
:
out
},
attrs
=
{
"alpha"
:
alpha
,
"beta"
:
beta
})
return
out
python/paddle/fluid/tests/unittests/test_add_position_encoding_op.py
0 → 100644
浏览文件 @
ba22624d
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
import
math
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
class
TestAddPositionEncodingTensorOp
(
OpTest
):
"""
This class is to test the AddPositionEncodingOp
"""
def
setUp
(
self
):
"""
the prepared section for add position encoding op
"""
self
.
op_type
=
"add_position_encoding"
self
.
dtype
=
np
.
float32
self
.
init_input_output
()
self
.
inputs
=
{
'X'
:
OpTest
.
np_dtype_to_fluid_dtype
(
self
.
x
),
}
self
.
outputs
=
{
'Out'
:
self
.
out
}
self
.
attrs
=
{
'alpha'
:
self
.
alpha
,
'beta'
:
self
.
beta
}
def
test_check_output
(
self
):
"""
check the correctness of output
"""
self
.
check_output
()
def
test_check_grad
(
self
):
"""
check the correctness of grad
"""
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.005
)
def
init_input_output
(
self
):
"""
init the input and output for test cases
"""
self
.
alpha
=
0.6
self
.
beta
=
0.5
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
4
,
4
]).
astype
(
self
.
dtype
)
self
.
out
=
np
.
copy
(
self
.
x
)
batch_size
=
self
.
x
.
shape
[
0
]
max_length
=
self
.
x
.
shape
[
1
]
enc_size
=
self
.
x
.
shape
[
2
]
half_shape
=
int
(
enc_size
/
2
)
for
i
in
range
(
batch_size
):
for
j
in
range
(
max_length
):
for
k
in
range
(
half_shape
):
val
=
j
/
pow
(
10000.0
,
k
/
(
half_shape
-
1
))
if
half_shape
>
1
else
j
/
10000.0
self
.
out
[
i
,
j
,
k
]
=
\
self
.
x
[
i
,
j
,
k
]
*
self
.
alpha
+
math
.
sin
(
val
)
*
self
.
beta
self
.
out
[
i
,
j
,
half_shape
+
k
]
=
\
self
.
x
[
i
,
j
,
half_shape
+
k
]
*
self
.
alpha
+
math
.
cos
(
val
)
*
self
.
beta
class
TestAddPositionEncodingLoDTensorOp
(
OpTest
):
"""
This class is to test the AddPositionEncodingLoDTensorOp
"""
def
setUp
(
self
):
"""
the prepared section for add position encoding LoDTensor op
"""
self
.
op_type
=
"add_position_encoding"
self
.
dtype
=
np
.
float32
self
.
init_input_output
()
self
.
inputs
=
{
'X'
:
(
self
.
x
,
self
.
lod
),
}
self
.
outputs
=
{
'Out'
:
(
self
.
out
,
self
.
lod
)}
self
.
attrs
=
{
'alpha'
:
self
.
alpha
,
'beta'
:
self
.
beta
}
def
test_check_output
(
self
):
"""
check the correctness of output
"""
self
.
check_output
()
def
test_check_grad
(
self
):
"""
check the correctness of grad
"""
self
.
check_grad
([
'X'
],
'Out'
,
max_relative_error
=
0.005
)
def
init_input_output
(
self
):
"""
init the input and output for test cases
"""
self
.
alpha
=
0.6
self
.
beta
=
0.5
self
.
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
10
,
4
]).
astype
(
self
.
dtype
)
self
.
lod
=
[[
3
,
7
]]
self
.
out
=
np
.
copy
(
self
.
x
)
batch_size
=
len
(
self
.
lod
[
0
])
enc_size
=
self
.
x
.
shape
[
1
]
start
=
0
half_shape
=
int
(
enc_size
/
2
)
for
i
in
range
(
batch_size
):
max_length
=
self
.
lod
[
0
][
i
]
for
j
in
range
(
max_length
):
for
k
in
range
(
half_shape
):
val
=
j
/
pow
(
10000.0
,
k
/
(
half_shape
-
1
))
if
half_shape
>
1
else
j
/
10000.0
pos
=
start
+
j
self
.
out
[
pos
,
k
]
=
\
self
.
x
[
pos
,
k
]
*
self
.
alpha
+
math
.
sin
(
val
)
*
self
.
beta
self
.
out
[
pos
,
half_shape
+
k
]
=
\
self
.
x
[
pos
,
half_shape
+
k
]
*
self
.
alpha
+
math
.
cos
(
val
)
*
self
.
beta
start
+=
max_length
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录