Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
0c4697f8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0c4697f8
编写于
8月 27, 2018
作者:
C
chenweihang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: change to enumerate by sentence
上级
4ec12496
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
48 addition
and
28 deletion
+48
-28
paddle/fluid/operators/sequence_enumerate_op.cc
paddle/fluid/operators/sequence_enumerate_op.cc
+2
-2
paddle/fluid/operators/sequence_enumerate_op.cu
paddle/fluid/operators/sequence_enumerate_op.cu
+17
-6
paddle/fluid/operators/sequence_enumerate_op.h
paddle/fluid/operators/sequence_enumerate_op.h
+8
-6
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+3
-3
python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py
...addle/fluid/tests/unittests/test_sequence_enumerate_op.py
+18
-11
未找到文件。
paddle/fluid/operators/sequence_enumerate_op.cc
浏览文件 @
0c4697f8
...
...
@@ -72,14 +72,14 @@ Examples:
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [
1, 2, 3, 4, 5
]
X.data = [
[1], [2], [3], [4], [5]
]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3,
4], [4, 5], [0
, 0]]
Out.data = [[1, 2], [2, 3], [3,
0], [4, 5], [5
, 0]]
Out.dims = [5, 2]
)DOC"
);
...
...
paddle/fluid/operators/sequence_enumerate_op.cu
浏览文件 @
0c4697f8
...
...
@@ -23,15 +23,23 @@ using platform::PADDLE_CUDA_NUM_THREADS;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
__global__
void
CalcOutPut
(
const
T
*
in_data
,
const
int64_t
in_len
,
const
int64_t
win_size
,
const
int64_t
pad_valu
e
,
T
*
out_data
)
{
__global__
void
CalcOutPut
(
const
T
*
in_data
,
const
size_t
*
in_lod
,
const
size_t
lod_len
,
const
int64_t
win_siz
e
,
const
int64_t
pad_value
,
T
*
out_data
)
{
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
index
<
in_len
)
{
if
(
index
<
in_lod
[
lod_len
-
1
])
{
int
end_idx
=
0
;
// Get LoD interval of index
for
(
int
i
=
1
;
i
<
lod_len
;
++
i
)
{
if
(
index
<
in_lod
[
i
])
{
end_idx
=
in_lod
[
i
];
break
;
}
}
for
(
size_t
i
=
0
;
i
<
win_size
;
++
i
)
{
int
word_pos
=
index
+
i
;
out_data
[
index
*
win_size
+
i
]
=
word_pos
<
in_len
?
in_data
[
word_pos
]
:
pad_value
;
word_pos
<
end_idx
?
in_data
[
word_pos
]
:
pad_value
;
}
}
}
...
...
@@ -54,13 +62,16 @@ class SequenceEnumerateOpCUDAKernel : public framework::OpKernel<T> {
/* Generate enumerate sequence set */
auto
stream
=
context
.
cuda_device_context
().
stream
();
auto
lod0
=
in_lod
[
0
];
auto
in_len
=
in
->
numel
();
auto
in_data
=
in
->
data
<
T
>
();
auto
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
// Copy LoD to GPU
const
size_t
*
dev_in_lod_ptr
=
lod0
.
CUDAData
(
context
.
GetPlace
());
// Calc output tensor
CalcOutPut
<<<
(
in_len
-
1
)
/
PADDLE_CUDA_NUM_THREADS
+
1
,
PADDLE_CUDA_NUM_THREADS
,
0
,
stream
>>>
(
in_data
,
in_len
,
win_size
,
pad_value
,
out_data
);
in_data
,
dev_in_lod_ptr
,
lod0
.
size
()
,
win_size
,
pad_value
,
out_data
);
}
};
...
...
paddle/fluid/operators/sequence_enumerate_op.h
浏览文件 @
0c4697f8
...
...
@@ -37,14 +37,16 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
"The actual input data's size mismatched with LoD information."
);
// Generate enumerate sequence set
auto
seq_length
=
in_dims
[
0
];
auto
lod0
=
in_lod
[
0
];
auto
in_data
=
in
->
data
<
T
>
();
auto
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
idx
=
0
;
idx
<
seq_length
;
++
idx
)
{
for
(
int
word_idx
=
0
;
word_idx
<
win_size
;
++
word_idx
)
{
int
word_pos
=
idx
+
word_idx
;
out_data
[
win_size
*
idx
+
word_idx
]
=
word_pos
<
seq_length
?
in_data
[
word_pos
]
:
pad_value
;
for
(
size_t
i
=
0
;
i
<
lod0
.
size
()
-
1
;
++
i
)
{
for
(
size_t
idx
=
lod0
[
i
];
idx
<
lod0
[
i
+
1
];
++
idx
)
{
for
(
int
word_idx
=
0
;
word_idx
<
win_size
;
++
word_idx
)
{
size_t
word_pos
=
idx
+
word_idx
;
out_data
[
win_size
*
idx
+
word_idx
]
=
word_pos
<
lod0
[
i
+
1
]
?
in_data
[
word_pos
]
:
pad_value
;
}
}
}
}
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
0c4697f8
...
...
@@ -5534,14 +5534,14 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
Case 1:
Input:
X.lod = [[0, 3, 5]]
X.data = [
1, 2, 3, 4, 5
]
X.data = [
[1], [2], [3], [4], [5]
]
X.dims = [5, 1]
Attrs:
win_size = 2
pad_value = 0
Output:
Out.lod = [[0, 3, 5]]
Out.data = [[1, 2], [2, 3], [3,
4], [4, 5], [0
, 0]]
Out.data = [[1, 2], [2, 3], [3,
0], [4, 5], [5
, 0]]
Out.dims = [5, 2]
Args:
...
...
@@ -5567,7 +5567,7 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
attrs
=
{
'win_size'
:
win_size
,
'pad_value'
:
pad_value
})
def
sequence_mask
(
x
,
maxlen
=
None
,
dtype
=
'int64'
,
name
=
None
):
"""
**SequenceMask Layer**
...
...
python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py
浏览文件 @
0c4697f8
...
...
@@ -19,16 +19,20 @@ import numpy as np
from
op_test
import
OpTest
def
sequence_enumerate
(
input_seq
,
win_size
,
pad_value
):
def
sequence_enumerate
(
input_seq
,
in_lod
,
win_size
,
pad_value
):
lod0
=
[
0
]
for
i
in
range
(
0
,
len
(
in_lod
[
0
])):
lod0
.
append
(
lod0
[
i
]
+
in_lod
[
0
][
i
])
out_seq
=
[]
for
idx
in
range
(
0
,
len
(
input_seq
)):
single_seq
=
[]
for
word_idx
in
range
(
win_size
):
word_pos
=
idx
+
word_idx
dat
=
input_seq
[
word_pos
]
if
word_pos
<
len
(
input_seq
)
\
for
i
in
range
(
0
,
len
(
lod0
)
-
1
):
for
idx
in
range
(
lod0
[
i
],
lod0
[
i
+
1
]):
single_seq
=
[]
for
word_idx
in
range
(
win_size
):
word_pos
=
idx
+
word_idx
dat
=
input_seq
[
word_pos
]
if
word_pos
<
lod0
[
i
+
1
]
\
else
pad_value
single_seq
.
append
(
dat
)
out_seq
.
append
(
single_seq
)
single_seq
.
append
(
dat
)
out_seq
.
append
(
single_seq
)
return
out_seq
...
...
@@ -48,7 +52,8 @@ class TestSequenceEnumerateOp(OpTest):
self
.
lod
=
[[
9
,
4
,
11
,
6
]]
self
.
win_size
=
2
self
.
pad_value
=
0
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
win_size
,
self
.
pad_value
)
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
lod
,
self
.
win_size
,
self
.
pad_value
)
self
.
out_seq
=
np
.
array
(
out_seq
).
astype
(
"int32"
)
...
...
@@ -58,7 +63,8 @@ class TesSequenceEnumerateOpInt64(TestSequenceEnumerateOp):
self
.
lod
=
[[
9
,
4
,
11
,
6
]]
self
.
win_size
=
2
self
.
pad_value
=
0
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
win_size
,
self
.
pad_value
)
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
lod
,
self
.
win_size
,
self
.
pad_value
)
self
.
out_seq
=
np
.
array
(
out_seq
).
astype
(
"int64"
)
...
...
@@ -68,7 +74,8 @@ class TestSequenceEnumerateOpMaxWinSize(TestSequenceEnumerateOp):
self
.
lod
=
[[
9
,
4
,
11
,
6
]]
self
.
win_size
=
30
self
.
pad_value
=
0
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
win_size
,
self
.
pad_value
)
out_seq
=
sequence_enumerate
(
self
.
in_seq
,
self
.
lod
,
self
.
win_size
,
self
.
pad_value
)
self
.
out_seq
=
np
.
array
(
out_seq
).
astype
(
"int32"
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录