Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
92e2207e
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
92e2207e
编写于
3月 13, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine doc
上级
ff09b21c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
39 addition
and
26 deletion
+39
-26
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+26
-13
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+6
-6
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+7
-7
未找到文件。
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
92e2207e
...
@@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -34,8 +34,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
ids_var_type
=
ctx
->
GetInputsVarType
(
"Ids"
).
front
();
auto
ids_var_type
=
ctx
->
GetInputsVarType
(
"Ids"
).
front
();
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// Maybe near future we will add concat_rows op.
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if
(
ids_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
if
(
ids_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
...
@@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -59,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
LookupTableOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"W"
,
AddInput
(
"W"
,
"
An
input represents embedding tensors, "
"
(Tensor) The
input represents embedding tensors, "
"which is a learnable parameter."
);
"which is a learnable parameter."
);
AddInput
(
"Ids"
,
AddInput
(
"An input with type int32 or int64 "
"Ids"
,
"contains the ids to be looked up in W. "
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"Ids must be a column vector with rank = 2. "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"The 2nd dimension size must be 1."
);
"the ids to be looked up in W and it must be a column vector with "
AddOutput
(
"Out"
,
"The lookup results, which have the same type as W."
);
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W."
);
AddOutput
(
"Out"
,
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W."
);
AddAttr
<
bool
>
(
"is_sparse"
,
AddAttr
<
bool
>
(
"is_sparse"
,
"(boolean, default false) "
"(boolean, default false) "
"Sparse update"
)
"Sparse update
.
"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"(int64, default -1) "
...
@@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -81,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
)DOC"
);
}
}
...
...
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
92e2207e
...
@@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -75,22 +75,22 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
// int tensor
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
*
ids
;
int64_t
*
ids
;
int64_t
K
;
int64_t
K
;
framework
::
Tensor
*
output_t
;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// Maybe near future we will add concat_rows op.
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
())
{
if
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
K
=
ids_t
->
numel
();
K
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
framework
::
SelectedRows
>
())
{
}
else
if
(
ids_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
framework
::
SelectedRows
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
framework
::
SelectedRows
>
(
"Ids"
);
output_t
=
context
.
Output
<
SelectedRows
>
(
"Out"
)
->
mutable_value
();
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
CUDAData
(
context
.
GetPlace
()));
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
CUDAData
(
context
.
GetPlace
()));
K
=
ids_t
->
rows
().
size
();
K
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
K
,
table_t
->
dims
()[
1
]});
output_t
->
Resize
({
K
,
table_t
->
dims
()[
1
]});
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
92e2207e
...
@@ -30,23 +30,23 @@ template <typename T>
...
@@ -30,23 +30,23 @@ template <typename T>
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
// float tensor
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
// int tensor
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
*
ids
;
int64_t
*
ids
;
int64_t
ids_numel
;
int64_t
ids_numel
;
Tensor
*
output_t
;
// ids_var_types also can be LOD_TENSOR_ARRAY, it's used as concat_rows.
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// Maybe near future we will add concat_rows op.
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
LoDTensor
>
())
{
if
(
ids_var
->
IsType
<
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
ids_numel
=
ids_t
->
numel
();
ids_numel
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
SelectedRows
>
())
{
}
else
if
(
ids_var
->
IsType
<
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
SelectedRows
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
SelectedRows
>
(
"Ids"
);
output_t
=
context
.
Output
<
SelectedRows
>
(
"Out"
)
->
mutable_value
();
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
data
());
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
data
());
ids_numel
=
ids_t
->
rows
().
size
();
ids_numel
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
ids_numel
,
table_t
->
dims
()[
1
]});
output_t
->
Resize
({
ids_numel
,
table_t
->
dims
()[
1
]});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录