Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8d205c85
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8d205c85
编写于
11月 13, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add is_test for lookup_sparse_table
上级
9a6e2392
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
60 addition
and
14 deletion
+60
-14
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+44
-7
paddle/fluid/framework/selected_rows.h
paddle/fluid/framework/selected_rows.h
+2
-2
paddle/fluid/framework/selected_rows_test.cc
paddle/fluid/framework/selected_rows_test.cc
+8
-4
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
+6
-1
未找到文件。
paddle/fluid/framework/selected_rows.cc
浏览文件 @
8d205c85
...
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
...
@@ -63,6 +63,26 @@ struct TensorCopyVisitor {
int64_t
size_
;
int64_t
size_
;
};
};
struct
TensorFillVisitor
{
TensorFillVisitor
(
framework
::
Tensor
*
dst
,
int64_t
dst_offset
,
int64_t
size
,
float
value
)
:
dst_
(
dst
),
dst_offset_
(
dst_offset
),
size_
(
size
)
{}
template
<
typename
T
>
void
apply
()
const
{
// TODO(Yancey1989): support other place
platform
::
CPUPlace
cpu
;
auto
*
tensor_data
=
dst_
->
mutable_data
<
T
>
(
cpu
);
auto
*
start
=
tensor_data
+
dst_offset_
;
auto
*
end
=
start
+
size_
;
std
::
fill
(
start
,
end
,
static_cast
<
T
>
(
0.0
));
}
framework
::
Tensor
*
dst_
;
int64_t
dst_offset_
;
int64_t
size_
;
};
void
SerializeToStream
(
std
::
ostream
&
os
,
const
SelectedRows
&
selected_rows
,
void
SerializeToStream
(
std
::
ostream
&
os
,
const
SelectedRows
&
selected_rows
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
const
platform
::
DeviceContext
&
dev_ctx
)
{
{
// the 1st field, uint32_t version
{
// the 1st field, uint32_t version
...
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
...
@@ -120,7 +140,17 @@ bool SelectedRows::HasKey(int64_t key) const {
:
true
;
:
true
;
}
}
int64_t
SelectedRows
::
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
)
{
int64_t
SelectedRows
::
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
,
bool
is_test
)
{
if
(
is_test
)
{
auto
iter
=
id_to_index_
.
find
(
key
);
if
(
iter
==
id_to_index_
.
end
())
{
return
-
1
;
}
else
{
return
iter
->
second
;
}
}
rwlock_
->
RDLock
();
rwlock_
->
RDLock
();
auto
iter
=
id_to_index_
.
find
(
key
);
auto
iter
=
id_to_index_
.
find
(
key
);
if
(
iter
==
id_to_index_
.
end
())
{
if
(
iter
==
id_to_index_
.
end
())
{
...
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
...
@@ -172,7 +202,7 @@ void SelectedRows::SyncIndex() {
}
}
void
SelectedRows
::
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
void
SelectedRows
::
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
bool
auto_grown
)
{
bool
auto_grown
,
bool
is_test
)
{
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
"The value tensor should be initialized."
);
"The value tensor should be initialized."
);
if
(
ids
.
numel
()
==
0
)
{
if
(
ids
.
numel
()
==
0
)
{
...
@@ -183,11 +213,18 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
...
@@ -183,11 +213,18 @@ void SelectedRows::Get(const framework::Tensor& ids, framework::Tensor* value,
"output tensor should have the same shape with table "
"output tensor should have the same shape with table "
"except the dims[0]."
);
"except the dims[0]."
);
for
(
int
i
=
0
;
i
<
ids
.
numel
();
++
i
)
{
for
(
int
i
=
0
;
i
<
ids
.
numel
();
++
i
)
{
int64_t
index
=
AutoGrownIndex
(
ids
.
data
<
int64_t
>
()[
i
],
auto_grown
);
int64_t
index
=
framework
::
VisitDataType
(
AutoGrownIndex
(
ids
.
data
<
int64_t
>
()[
i
],
auto_grown
,
is_test
);
framework
::
ToDataType
(
value_
->
type
()),
if
(
index
<
0
)
{
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
framework
::
VisitDataType
(
index
*
value_width
,
value_width
));
framework
::
ToDataType
(
value_
->
type
()),
TensorFillVisitor
(
value
,
i
*
value_width
,
value_width
,
0.0
));
}
else
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
index
*
value_width
,
value_width
));
}
}
}
}
}
}
}
...
...
paddle/fluid/framework/selected_rows.h
浏览文件 @
8d205c85
...
@@ -105,7 +105,7 @@ class SelectedRows {
...
@@ -105,7 +105,7 @@ class SelectedRows {
* the value
* the value
*/
*/
void
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
void
Get
(
const
framework
::
Tensor
&
ids
,
framework
::
Tensor
*
value
,
bool
auto_grown
=
false
);
bool
auto_grown
=
false
,
bool
is_test
=
false
);
/*
/*
* @brief Get the index of the key from id_to_index_ map. If the key not
* @brief Get the index of the key from id_to_index_ map. If the key not
...
@@ -118,7 +118,7 @@ class SelectedRows {
...
@@ -118,7 +118,7 @@ class SelectedRows {
*
*
* @return index of the key.
* @return index of the key.
*/
*/
int64_t
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
);
int64_t
AutoGrownIndex
(
int64_t
key
,
bool
auto_grown
,
bool
is_test
=
false
);
void
SyncIndex
();
void
SyncIndex
();
...
...
paddle/fluid/framework/selected_rows_test.cc
浏览文件 @
8d205c85
...
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
...
@@ -84,10 +84,14 @@ TEST(SelectedRows, SparseTable) {
data
[
i
*
embedding_width
+
j
]
=
static_cast
<
float
>
(
i
);
data
[
i
*
embedding_width
+
j
]
=
static_cast
<
float
>
(
i
);
}
}
}
}
ASSERT_EQ
(
table
.
AutoGrownIndex
(
10
,
true
),
0
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
10
,
true
,
false
),
0
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
,
false
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
8
,
true
,
false
),
1
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
6
,
true
),
2
);
ASSERT_EQ
(
table
.
AutoGrownIndex
(
6
,
true
,
false
),
2
);
for
(
int64_t
i
=
11
;
i
<
20
;
i
++
)
{
ASSERT_EQ
(
table
.
AutoGrownIndex
(
i
,
true
,
true
),
-
1
);
ASSERT_TRUE
(
!
table
.
HasKey
(
i
));
}
ASSERT_TRUE
(
table
.
HasKey
(
10
));
ASSERT_TRUE
(
table
.
HasKey
(
10
));
ASSERT_TRUE
(
table
.
HasKey
(
8
));
ASSERT_TRUE
(
table
.
HasKey
(
8
));
ASSERT_TRUE
(
table
.
HasKey
(
6
));
ASSERT_TRUE
(
table
.
HasKey
(
6
));
...
...
paddle/fluid/operators/lookup_sparse_table_op.cc
浏览文件 @
8d205c85
...
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
...
@@ -45,6 +45,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
auto
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
out_var
=
scope
.
FindVar
(
Output
(
"Out"
));
auto
w_var
=
scope
.
FindVar
(
Input
(
"W"
));
auto
w_var
=
scope
.
FindVar
(
Input
(
"W"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"Ids"
));
auto
ids_var
=
scope
.
FindVar
(
Input
(
"Ids"
));
auto
is_test
=
Attr
<
bool
>
(
"is_test"
);
PADDLE_ENFORCE
(
out_var
->
IsType
<
framework
::
LoDTensor
>
(),
PADDLE_ENFORCE
(
out_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The type of Out var should be LodTensor."
);
"The type of Out var should be LodTensor."
);
...
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
...
@@ -65,7 +66,7 @@ class LookupSparseTableOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
framework
::
proto
::
VarType
::
FP32
,
framework
::
proto
::
VarType
::
FP32
,
"The sparse table only support FP32"
);
"The sparse table only support FP32"
);
w_t
->
Get
(
ids_t
,
out_t
,
true
);
w_t
->
Get
(
ids_t
,
out_t
,
true
,
is_test
);
}
}
};
};
...
@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -91,6 +92,10 @@ class LookupSparseTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool default false)"
"(bool default false)"
"Whether create new value if for nonexistent key."
)
"Whether create new value if for nonexistent key."
)
.
SetDefault
(
true
);
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"is_test"
,
"In test mode, lookup_sparse_table will "
"return a default value for unknown id"
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Lookup Sprase Tablel Operator.
Lookup Sprase Tablel Operator.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录