Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ead558b7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
ead558b7
编写于
3月 19, 2019
作者:
T
tensor-tang
提交者:
GitHub
3月 19, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16256 from tensor-tang/refine/seqenum
refine sequence enumerate op
上级
c7f1f3ed
50931dee
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
28 addition
and
18 deletion
+28
-18
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
+0
-7
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h
+28
-11
未找到文件。
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.cc
浏览文件 @
ead558b7
...
@@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
...
@@ -30,13 +30,6 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
"Output(X) of SequenceEnumerate operator should not be null."
);
"Output(X) of SequenceEnumerate operator should not be null."
);
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
const
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"Input(X) of SequenceEnumerate operator's rank should be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
1
,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1."
);
const
auto
win_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"win_size"
);
const
auto
win_size
=
ctx
->
Attrs
().
Get
<
int
>
(
"win_size"
);
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
win_size
});
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
win_size
});
ctx
->
ShareLoD
(
"X"
,
"Out"
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
...
...
paddle/fluid/operators/sequence_ops/sequence_enumerate_op.h
浏览文件 @
ead558b7
...
@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
...
@@ -27,30 +27,47 @@ class SequenceEnumerateKernel : public framework::OpKernel<T> {
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int
win_size
=
context
.
Attr
<
int
>
(
"win_size"
);
int
win_size
=
context
.
Attr
<
int
>
(
"win_size"
);
int
pad_value
=
context
.
Attr
<
int
>
(
"pad_value"
);
auto
pad_value
=
static_cast
<
T
>
(
context
.
Attr
<
int
>
(
"pad_value"
)
);
auto
in_dims
=
in
->
dims
();
auto
in_dims
=
in
->
dims
();
auto
in_lod
=
in
->
lod
();
auto
lod0
=
in
->
lod
()[
0
];
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
static_cast
<
uint64_t
>
(
in_dims
[
0
]),
in_lod
[
0
]
.
back
(),
static_cast
<
uint64_t
>
(
in_dims
[
0
]),
lod0
.
back
(),
"The actual input data's size mismatched with LoD information."
);
"The actual input data's size mismatched with LoD information."
);
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
2UL
,
"Input(X) of SequenceEnumerate operator's rank should be 2."
);
PADDLE_ENFORCE_EQ
(
in_dims
[
1
],
1
,
"Input(X) of SequenceEnumerate operator's 2nd "
"dimension should be 1."
);
// Generate enumerate sequence set
// Generate enumerate sequence set
auto
lod0
=
in_lod
[
0
];
auto
in_data
=
in
->
data
<
T
>
();
auto
in_data
=
in
->
data
<
T
>
();
out
->
Resize
({
in_dims
[
0
],
win_size
});
out
->
Resize
({
in_dims
[
0
],
win_size
});
out
->
set_lod
(
in
->
lod
());
auto
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
out_data
=
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
lod0
.
size
()
-
1
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
lod0
.
size
()
-
1
;
++
i
)
{
for
(
size_t
idx
=
lod0
[
i
];
idx
<
lod0
[
i
+
1
];
++
idx
)
{
int
start
=
lod0
[
i
];
for
(
int
word_idx
=
0
;
word_idx
<
win_size
;
++
word_idx
)
{
int
end
=
lod0
[
i
+
1
];
size_t
word_pos
=
idx
+
word_idx
;
int
copy_size
=
win_size
<
end
-
start
+
1
?
win_size
:
end
-
start
+
1
;
out_data
[
win_size
*
idx
+
word_idx
]
=
int
mid
=
end
+
1
-
copy_size
;
word_pos
<
lod0
[
i
+
1
]
?
in_data
[
word_pos
]
:
pad_value
;
int
pad_num
=
win_size
-
copy_size
;
copy_size
*=
sizeof
(
T
);
for
(
int
idx
=
start
;
idx
<
mid
;
++
idx
)
{
std
::
memcpy
(
out_data
,
in_data
+
idx
,
copy_size
);
out_data
+=
win_size
;
}
}
for
(
int
idx
=
mid
;
idx
<
end
;
++
idx
)
{
copy_size
-=
sizeof
(
T
);
pad_num
++
;
std
::
memcpy
(
out_data
,
in_data
+
idx
,
copy_size
);
T
*
pdata
=
out_data
+
copy_size
/
sizeof
(
T
);
for
(
int
i
=
0
;
i
<
pad_num
;
++
i
)
{
pdata
[
i
]
=
pad_value
;
}
out_data
+=
win_size
;
}
}
}
}
out
->
set_lod
(
in
->
lod
());
}
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录