Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
e65cbd3b
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 2 年 前同步成功
通知
708
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e65cbd3b
编写于
11月 14, 2018
作者:
Q
Qiao Longfei
提交者:
GitHub
11月 14, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #14387 from jacquesqiao/lookup_sparse_table_add_test_mode
Lookup sparse table add test mode
上级
6cf8f24b
51f3838f
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
88 addition
and
14 deletion
+88
-14
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+45
-7
paddle/fluid/framework/selected_rows.h
paddle/fluid/framework/selected_rows.h
+2
-2
paddle/fluid/framework/selected_rows_test.cc
paddle/fluid/framework/selected_rows_test.cc
+8
-4
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
+6
-1
python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
...ddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
+27
-0
未找到文件。
paddle/fluid/framework/selected_rows.cc
浏览文件 @
e65cbd3b
...
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
...
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
int64_t
size_
;
int64_t
size_
;
};
};
struct
TensorFillVisitor
{
TensorFillVisitor
(
framework
::
Tensor
*
dst
,
int64_t
dst_offset
,
int64_t
size
,
float
value
)
:
dst_
(
dst
),
dst_offset_
(
dst_offset
),
size_
(
size
)
{}
template
<
typename
T
>
void
apply
()
const
{
// TODO(qiao): support other place
platform
::
CPUPlace
cpu
;
auto
*
tensor_data
=
dst_
->
mutable_data
<
T
>
(
cpu
);
auto
*
start
=
tensor_data
+
dst_offset_
;
auto
*
end
=
start
+
size_
;
std
::
fill
(
start
,
end
,
static_cast
<
T
>
(
0.0
));
}
framework
::
Tensor
*
dst_
;
int64_t
dst_offset_
;
int64_t
size_
;
};
void
SerializeToStream
(
std
::
ostream
&
os
,
const
SelectedRows
&
selected_rows
,
void
SerializeToStream
(
std
::
ostream
&
os
,
const
SelectedRows
&
selected_rows
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
const
platform
::
DeviceContext
&
dev_ctx
)
{
{
// the 1st field, uint32_t version
{
// the 1st field, uint32_t version
...
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
...
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
:
true
;
:
true
;
}
}
int64_t
SelectedRows
::
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
)
{
int64_t
SelectedRows
::
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
,
bool
is_test
)
{
if
(
is_test
)
{
auto
iter
=
id_to_index_
.
find
(
key
);
if
(
iter
==
id_to_index_
.
end
())
{
return
-
1
;
}
else
{
return
iter
->
second
;
}
}
rwlock_
->
RDLock
();
rwlock_
->
RDLock
();
auto
iter
=
id_to_index_
.
find
(
key
);
auto
iter
=
id_to_index_
.
find
(
key
);
if
(
iter
==
id_to_index_
.
end
())
{
if
(
iter
==
id_to_index_
.
end
())
{
...
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
...
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
}
}
void
SelectedRows
::
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
void
SelectedRows
::
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
bool
auto_grown
)
{
bool
auto_grown
,
bool
is_test
)
{
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
"The value tensor should be initialized."
);
"The value tensor should be initialized."
);
if
(
ids
.
numel
()
==
0
)
{
if
(
ids
.
numel
()
==
0
)
{
...
@@ -183,11 +213,19 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
...
@@ -183,11 +213,19 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
"output tensor should have the same shape with table "
"output tensor should have the same shape with table "
"except the dims[0]."
);
"except the dims[0]."
);
for
(
int
i
=
0
;
i
<
ids
.
numel
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ids
.
numel
();
++
i
)
{
int64_t
index
=
AutoGrownIndex
(
ids
.
data
<
int64_t
>
()[
i
],
auto_grown
);
auto
id
=
ids
.
data
<
int64_t
>
()[
i
];
framework
::
VisitDataType
(
int64_t
index
=
AutoGrownIndex
(
id
,
auto_grown
,
is_test
);
framework
::
ToDataType
(
value_
->
type
()),
if
(
index
<
0
)
{
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
VLOG
(
5
)
<<
"id "
<<
id
<<
" not in the table, return 0"
;
index
*
value_width
,
value_width
));
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
TensorFillVisitor
(
value
,
i
*
value_width
,
value_width
,
0.0
));
}
else
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
index
*
value_width
,
value_width
));
}
}
}
}
}
}
}
...
...
paddle/fluid/framework/selected_rows.h
浏览文件 @
e65cbd3b
...
@@ -105,7 +105,7 @@ class SelectedRows {
...
@@ -105,7 +105,7 @@ class SelectedRows {
* the value
* the value
*/
*/
void
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
void
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
bool
auto_grown
=
false
);
bool
auto_grown
=
false
,
bool
is_test
=
false
);
/*
/*
* @brief Get the index of the key from id_to_index_ map. If the key not
* @brief Get the index of the key from id_to_index_ map. If the key not
...
@@ -118,7 +118,7 @@ class SelectedRows {
...
@@ -118,7 +118,7 @@ class SelectedRows {
*
*
* @return index of the key.
* @return index of the key.
*/
*/
int64_t
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
);
int64_t
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
,
bool
is_test
=
false
);
void
SyncIndex
();
void
SyncIndex
();
...
...
paddle/fluid/framework/selected_rows_test.cc
浏览文件 @
e65cbd3b
...
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
...
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
data
[
i
*
embedding_width
+
j
]
=
static_cast
<
float
>
(
i
);
data
[
i
*
embedding_width
+
j
]
=
static_cast
<
float
>
(
i
);
}
}
}
}
ASSERT_EQ
(
table
.
AutoGrownIndex
(
10
,
true
),
0
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
10
,
true
,
false
),
0
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
,
false
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
,
false
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
6
,
true
),
2
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
6
,
true
,
false
),
2
);
for
(
int64_t
i
=
11
;
i
<
20
;
i
++
)
{
ASSERT_EQ
(
table
.
AutoGrownIndex
(
i
,
true
,
true
),
-
1
);
ASSERT_TRUE
(
!
table
.
HasKey
(
i
));
}
ASSERT_TRUE
(
table
.
HasKey
(
10
));
ASSERT_TRUE
(
table
.
HasKey
(
10
));
ASSERT_TRUE
(
table
.
HasKey
(
8
));
ASSERT_TRUE
(
table
.
HasKey
(
8
));
ASSERT_TRUE
(
table
.
HasKey
(
6
));
ASSERT_TRUE
(
table
.
HasKey
(
6
));
...
...
paddle/fluid/operators/lookup_sparse_table_op.cc
浏览文件 @
e65cbd3b
...
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
...
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
w_var
=
scope
.
FindVar
(
Input
(
"W"
));
auto
w_var
=
scope
.
FindVar
(
Input
(
"W"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"Ids"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"Ids"
));
auto
is_test
=
Attr
<
bool
>
(
"is_test"
);
PADDLE_ENFORCE
(
out_var
->
IsType
<
framework
::
LoDTensor
>
(),
PADDLE_ENFORCE
(
out_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The type of Out var should be LodTensor."
);
"The type of Out var should be LodTensor."
);
...
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
...
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
framework
::
proto
::
VarType
::
FP32
,
framework
::
proto
::
VarType
::
FP32
,
"The sparse table only support FP32"
);
"The sparse table only support FP32"
);
w_t
->
Get
(
ids_t
,
out_t
,
true
);
w_t
->
Get
(
ids_t
,
out_t
,
true
,
is_test
);
}
}
};
};
...
@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false)"
"(bool default false)"
"Whether create new value if for nonexistent key."
)
"Whether create new value if for nonexistent key."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is_test"
,
"In test mode, lookup_sparse_table will "
"return a 0 for unknown id"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Lookup Sprase Tablel Operator.
Lookup Sprase Tablel Operator.
...
...
python/paddle/fluid/tests/unittests/test_lookup_sparse_table_op.py
浏览文件 @
e65cbd3b
...
@@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest):
...
@@ -80,6 +80,33 @@ class TestLookupSpraseTable(OpTest):
assert
(
result_array2
[
3
]
==
w_array
[
6
]).
all
()
assert
(
result_array2
[
3
]
==
w_array
[
6
]).
all
()
assert
(
result_array2
[
4
]
==
w_array
[
7
]).
all
()
assert
(
result_array2
[
4
]
==
w_array
[
7
]).
all
()
# create and run lookup_table operator
test_lookup_table
=
Operator
(
"lookup_sparse_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
,
min
=-
5.0
,
max
=
10.0
,
seed
=
10
,
is_test
=
True
)
ids
=
scope
.
var
(
"Ids"
).
get_tensor
()
unknown_id
=
[
44
,
22
,
33
]
ids_array2
=
np
.
array
([
4
,
2
,
3
,
7
,
100000
]
+
unknown_id
).
astype
(
"int64"
)
ids
.
set
(
ids_array2
,
place
)
test_lookup_table
.
run
(
scope
,
place
)
result_array2
=
np
.
array
(
out_tensor
)
assert
(
result_array2
[
0
]
==
w_array
[
5
]).
all
()
assert
(
result_array2
[
1
]
==
w_array
[
1
]).
all
()
assert
(
result_array2
[
2
]
==
w_array
[
2
]).
all
()
assert
(
result_array2
[
3
]
==
w_array
[
6
]).
all
()
assert
(
result_array2
[
4
]
==
w_array
[
7
]).
all
()
for
i
in
[
5
,
6
,
7
]:
assert
np
.
all
(
result_array2
[
i
]
==
0
)
def
test_w_is_selected_rows
(
self
):
def
test_w_is_selected_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
places
=
[
core
.
CPUPlace
()]
# currently only support CPU
# currently only support CPU
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录