Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
788c600e
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
788c600e
编写于
3月 14, 2018
作者:
C
chengduo
提交者:
GitHub
3月 14, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8932 from chengduoZH/feature/add_concat_rows
Enhance look_up_table op
上级
28078969
a43eee40
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
129 addition
and
23 deletion
+129
-23
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+31
-13
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+22
-4
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+27
-6
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+49
-0
未找到文件。
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
788c600e
...
@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
...
@@ -33,8 +33,16 @@ class LookupTableOp : public framework::OperatorWithKernel {
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
auto
ids_var_type
=
ctx
->
GetInputsVarType
(
"Ids"
).
front
();
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W
// and it must be a column vector with rank = 2 while the 2nd dimension
// size must be 1, when Ids's type is SelectedRows, the rows of Ids
// contains the ids to be looked up in W;
if
(
ids_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
PADDLE_ENFORCE_EQ
(
ids_dims
.
size
(),
2
);
PADDLE_ENFORCE_EQ
(
ids_dims
[
1
],
1
);
}
ctx
->
SetOutputDim
(
"Out"
,
{
ids_dims
[
0
],
table_dims
[
1
]});
ctx
->
SetOutputDim
(
"Out"
,
{
ids_dims
[
0
],
table_dims
[
1
]});
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
...
@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -54,17 +62,22 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
LookupTableOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
LookupTableOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"W"
,
AddInput
(
"W"
,
"
An
input represents embedding tensors, "
"
(Tensor) The
input represents embedding tensors, "
"which is a learnable parameter."
);
"which is a learnable parameter."
);
AddInput
(
"Ids"
,
AddInput
(
"An input with type int32 or int64 "
"Ids"
,
"contains the ids to be looked up in W. "
"(Tensor or SelectedRows) Ids's type can be Tensor or "
"Ids must be a column vector with rank = 2. "
"SelectedRows, when Ids's type is Tensor, this tensor contains "
"The 2nd dimension size must be 1."
);
"the ids to be looked up in W and it must be a column vector with "
AddOutput
(
"Out"
,
"The lookup results, which have the same type as W."
);
"rank = 2 while the 2nd dimension size must be 1; when Ids's type is "
"SelectedRows, the rows of Ids contains the ids to be looked up "
"in W."
);
AddOutput
(
"Out"
,
"(Tensor or SelectedRows) The lookup results, which have the "
"same type as W."
);
AddAttr
<
bool
>
(
"is_sparse"
,
AddAttr
<
bool
>
(
"is_sparse"
,
"(boolean, default false) "
"(boolean, default false) "
"Sparse update"
)
"Sparse update
.
"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"(int64, default -1) "
...
@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -76,10 +89,15 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
Lookup Table Operator.
Lookup Table Operator.
This operator is used to perform lookups on the parameter W,
This operator is used to perform lookups on the parameter W,
then concatenated into a dense tensor.
then concatenated into a dense or sparse tensor.
The type of Ids(Input) is SelectedRows, Tensor or LoDTensor, when Ids's
type is SelectedRows, the rows of Ids contains the ids to be looked up in W;
when Ids's type is Tensor, this tensor contains the ids to be looked up in W
and it must be a column vector with rank = 2 while the 2nd dimension size must be 1,
at this time, Ids can carry the LoD (Level of Details) information, or not, and
the output only shares the LoD information with input Ids.
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
)DOC"
);
}
}
...
...
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
788c600e
...
@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -74,14 +74,32 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
*
ids
;
int64_t
K
;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
K
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
framework
::
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
framework
::
SelectedRows
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
CUDAData
(
context
.
GetPlace
()));
K
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
K
,
table_t
->
dims
()[
1
]});
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Ids"
);
}
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
ids_t
->
numel
();
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
788c600e
...
@@ -22,6 +22,7 @@ limitations under the License. */
...
@@ -22,6 +22,7 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
SelectedRows
=
framework
::
SelectedRows
;
using
SelectedRows
=
framework
::
SelectedRows
;
...
@@ -29,25 +30,45 @@ template <typename T>
...
@@ -29,25 +30,45 @@ template <typename T>
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
class
LookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
// float tensor
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
ids_var
=
context
.
InputVar
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
Tensor
*
output_t
=
context
.
Output
<
Tensor
>
(
"Out"
);
int64_t
*
ids
;
int64_t
ids_numel
;
// The type of Ids(Input) is SelectedRows or LoDTensor, when Ids's type
// is LoDTensor, this tensor contains the ids to be looked up in W;
// when Ids's type is SelectedRows, the rows of Ids contains the
// ids to be looked up in W.
if
(
ids_var
->
IsType
<
LoDTensor
>
())
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
data
<
int64_t
>
());
ids_numel
=
ids_t
->
numel
();
}
else
if
(
ids_var
->
IsType
<
SelectedRows
>
())
{
auto
*
ids_t
=
context
.
Input
<
SelectedRows
>
(
"Ids"
);
ids
=
const_cast
<
int64_t
*>
(
ids_t
->
rows
().
data
());
ids_numel
=
ids_t
->
rows
().
size
();
output_t
->
Resize
({
ids_numel
,
table_t
->
dims
()[
1
]});
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Ids"
);
}
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int
N
=
table_t
->
dims
()[
0
];
int
N
=
table_t
->
dims
()[
0
];
int
D
=
table_t
->
dims
()[
1
];
int
D
=
table_t
->
dims
()[
1
];
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
padding_idx
==
-
1
)
{
if
(
padding_idx
==
-
1
)
{
for
(
int64_t
i
=
0
;
i
<
ids_
t
->
numel
()
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ids_
numel
;
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
}
else
{
}
else
{
for
(
int64_t
i
=
0
;
i
<
ids_
t
->
numel
()
;
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ids_
numel
;
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
}
else
{
}
else
{
...
...
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
浏览文件 @
788c600e
...
@@ -15,6 +15,8 @@
...
@@ -15,6 +15,8 @@
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
from
op_test
import
OpTest
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
class
TestLookupTableOp
(
OpTest
):
class
TestLookupTableOp
(
OpTest
):
...
@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
...
@@ -47,5 +49,52 @@ class TestLookupTableOpWithPadding(TestLookupTableOp):
pass
pass
class
TestLookupTableIdsIsSelectedRows
(
OpTest
):
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
# create and initialize Variable
height
=
10
rows
=
[
0
,
4
,
4
,
7
]
row_numel
=
12
# create and initialize W Variable
W
=
scope
.
var
(
'W'
).
get_tensor
()
W_array
=
np
.
full
((
height
,
row_numel
),
1.0
).
astype
(
"float32"
)
for
i
in
range
(
height
):
W_array
[
i
]
*=
i
W
.
set
(
W_array
,
place
)
# create and initialize Ids Variable
ids_selected_rows
=
scope
.
var
(
'Ids'
).
get_selected_rows
()
ids_selected_rows
.
set_height
(
len
(
rows
))
ids_selected_rows
.
set_rows
(
rows
)
np_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
ids_tensor
=
ids_selected_rows
.
get_tensor
()
ids_tensor
.
set
(
np_array
,
place
)
# create Out Variable
Out
=
scope
.
var
(
'Out'
).
get_selected_rows
()
# create and run lookup_table operator
concat_rows_op
=
Operator
(
"lookup_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
)
concat_rows_op
.
run
(
scope
,
place
)
# get result from Out
Out_tensor
=
Out
.
get_tensor
()
result_array
=
np
.
array
(
Out_tensor
)
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for
idx
,
row
in
enumerate
(
rows
):
assert
(
row
==
result_array
[
idx
]).
all
()
def
test_concat_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
if
core
.
is_compiled_with_cuda
():
places
.
append
(
core
.
CUDAPlace
(
0
))
for
place
in
places
:
self
.
check_with_place
(
place
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录