Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
fd9b7bdb
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
fd9b7bdb
编写于
4月 04, 2020
作者:
Z
zhangchunle
提交者:
GitHub
4月 04, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Op (FusedEmbeddingSeqPool) error message enhancement. (#23454)
上级
16315d3d
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
50 addition
and
18 deletion
+50
-18
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
+31
-14
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+19
-4
未找到文件。
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.cc
浏览文件 @
fd9b7bdb
...
...
@@ -24,30 +24,47 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"W"
),
"Input W of FusedEmbeddingSeqPoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Ids"
),
"Input Ids of FusedEmbeddingSeqPoolOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output of FusedEmbeddingSeqPoolOp should not be null."
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"W"
),
"Input"
,
"W"
,
"FusedEmbeddingSeqPool"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Ids"
),
"Input"
,
"Ids"
,
"FusedEmbeddingSeqPool"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"FusedEmbeddingSeqPool"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
const
std
::
string
&
combiner
=
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"combiner"
);
PADDLE_ENFORCE_EQ
(
table_dims
.
size
(),
2
);
PADDLE_ENFORCE_GE
(
ids_dims
.
size
(),
1
,
"The dim size of the 'Ids' tensor must greater than 1."
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
ids_dims
.
size
()
-
1
],
1
,
"The last dimension of the 'Ids' tensor must be 1."
);
PADDLE_ENFORCE_EQ
(
table_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"The dim size of the input tensor 'W' should be 2. "
"But received W's size = %d."
,
table_dims
.
size
()));
PADDLE_ENFORCE_GE
(
ids_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The dim size of the input tensor 'Ids' should be greater "
"than or equal to 1. But received Ids's size = %d."
,
ids_dims
.
size
()));
PADDLE_ENFORCE_EQ
(
ids_dims
[
ids_dims
.
size
()
-
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"The last dimension of the input tensor 'Ids' should be 1. "
"But received Ids's size in the last dimension = %d."
,
ids_dims
[
ids_dims
.
size
()
-
1
]));
// we only support sum now
PADDLE_ENFORCE_EQ
(
combiner
,
"sum"
);
PADDLE_ENFORCE_EQ
(
combiner
,
"sum"
,
platform
::
errors
::
Unimplemented
(
"The pooling type of sequence_pool only support sum "
"now. So the 'combiner' must be 'sum'."
));
int64_t
last_dim
=
FusedEmbeddingSeqPoolLastDim
(
table_dims
,
ids_dims
);
// in compile time, the lod level of ids must be 1
framework
::
VarDesc
*
ids_desc
=
boost
::
get
<
framework
::
VarDesc
*>
(
ctx
->
GetInputVarPtrs
(
"Ids"
)[
0
]);
PADDLE_ENFORCE_EQ
(
ids_desc
->
GetLoDLevel
(),
1
);
PADDLE_ENFORCE_EQ
(
ids_desc
->
GetLoDLevel
(),
1
,
platform
::
errors
::
InvalidArgument
(
"In compile time, the LoD Level of Ids should be 1. "
"But received the LoD Level of Ids = %d."
,
ids_desc
->
GetLoDLevel
()));
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
...
...
paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
浏览文件 @
fd9b7bdb
...
...
@@ -90,8 +90,17 @@ struct EmbeddingVSumFunctor {
int64_t
idx_width
=
ids_t
->
numel
()
/
ids_lod
.
back
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
PADDLE_ENFORCE_LE
(
table_width
*
idx_width
,
out_width
);
PADDLE_ENFORCE_GT
(
ids_lod
.
size
(),
1UL
,
"The LoD[0] could NOT be empty"
);
PADDLE_ENFORCE_LE
(
table_width
*
idx_width
,
out_width
,
platform
::
errors
::
InvalidArgument
(
"table_width * idx_width should be less than or "
"equal to out_width. But received "
"table_width * idx_width = %s, out_width = %d."
,
table_width
*
idx_width
,
out_width
));
PADDLE_ENFORCE_GT
(
ids_lod
.
size
(),
1UL
,
platform
::
errors
::
InvalidArgument
(
"The tensor ids's LoD[0] should be greater than 1. "
"But received the ids's LoD[0] = %d."
,
ids_lod
.
size
()));
jit
::
emb_seq_pool_attr_t
attr
(
table_height
,
table_width
,
0
,
idx_width
,
out_width
,
jit
::
SeqPoolType
::
kSum
);
...
...
@@ -130,7 +139,10 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
const
auto
&
ids_lod
=
ids_t
->
lod
();
// in run time, the LoD of ids must be 1
PADDLE_ENFORCE_EQ
(
ids_lod
.
size
(),
1UL
,
"The LoD level of Input(Ids) must be 1"
);
platform
::
errors
::
InvalidArgument
(
"The LoD level of Input(Ids) should be 1. But "
"received Ids's LoD level = %d."
,
ids_lod
.
size
()));
int64_t
batch_size
=
ids_lod
[
0
].
size
()
-
1
;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, last_dim]
...
...
@@ -244,7 +256,10 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
const
auto
&
ids_lod
=
ids
->
lod
();
PADDLE_ENFORCE_EQ
(
ids_lod
.
size
(),
1UL
,
"The LoD level of Input(Ids) must be 1"
);
platform
::
errors
::
InvalidArgument
(
"The LoD level of Input(Ids) should be 1. But "
"received Ids's LoD level = %d."
,
ids_lod
.
size
()));
const
std
::
vector
<
uint64_t
>
offset
=
ids_lod
[
0
];
auto
len
=
ids
->
numel
();
int
idx_width
=
len
/
offset
.
back
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录