Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
25f47fc0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
25f47fc0
编写于
5月 28, 2018
作者:
Q
qiaolongfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix prefetch bugs, optimize code
上级
bf869e45
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
25 addition
and
23 deletion
+25
-23
paddle/fluid/framework/selected_rows.cc
paddle/fluid/framework/selected_rows.cc
+20
-15
paddle/fluid/framework/selected_rows.h
paddle/fluid/framework/selected_rows.h
+1
-1
paddle/fluid/operators/detail/grpc_server.cc
paddle/fluid/operators/detail/grpc_server.cc
+4
-7
未找到文件。
paddle/fluid/framework/selected_rows.cc
浏览文件 @
25f47fc0
...
...
@@ -121,24 +121,29 @@ bool SelectedRows::HasKey(int64_t key) const {
}
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
SelectedRows
::
Get
(
std
::
vector
<
int64_t
>
keys
,
framework
::
Tensor
*
value
)
const
{
const
std
::
vector
<
int64_t
>&
keys
,
framework
::
Tensor
*
value
)
const
{
PADDLE_ENFORCE
(
value
->
IsInitialized
(),
"The value tensor should be initialized."
);
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
non_keys_pair
;
int64_t
value_width
=
value_
->
numel
()
/
value_
->
dims
()[
0
];
PADDLE_ENFORCE_EQ
(
value_width
,
value
->
numel
()
/
value
->
dims
()[
0
],
"output tensor should have the same shape with table "
"execpt the dims[0]."
);
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
int64_t
index
=
Index
(
keys
[
i
]);
if
(
index
==
-
1
)
{
non_keys_pair
.
push_back
(
std
::
make_pair
(
keys
[
i
],
static_cast
<
int64_t
>
(
i
)));
}
else
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
index
*
value_width
,
value_width
));
if
(
keys
.
empty
())
{
VLOG
(
3
)
<<
"keys is empty, please check data!"
;
}
else
{
int64_t
value_width
=
value_
->
numel
()
/
value_
->
dims
()[
0
];
PADDLE_ENFORCE_EQ
(
value_width
,
value
->
numel
()
/
value
->
dims
()[
0
],
"output tensor should have the same shape with table "
"except the dims[0]."
);
for
(
size_t
i
=
0
;
i
<
keys
.
size
();
++
i
)
{
int64_t
index
=
Index
(
keys
[
i
]);
if
(
index
==
-
1
)
{
non_keys_pair
.
push_back
(
std
::
make_pair
(
keys
[
i
],
static_cast
<
int64_t
>
(
i
)));
}
else
{
framework
::
VisitDataType
(
framework
::
ToDataType
(
value_
->
type
()),
TensorCopyVisitor
(
value
,
i
*
value_width
,
*
value_
.
get
(),
index
*
value_width
,
value_width
));
}
}
}
return
non_keys_pair
;
...
...
paddle/fluid/framework/selected_rows.h
浏览文件 @
25f47fc0
...
...
@@ -82,7 +82,7 @@ class SelectedRows {
* @return a list of pair which contains the non-exists key and the index in
* the value
*/
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
Get
(
std
::
vector
<
int64_t
>
keys
,
std
::
vector
<
std
::
pair
<
int64_t
,
int64_t
>>
Get
(
const
std
::
vector
<
int64_t
>&
keys
,
framework
::
Tensor
*
value
)
const
;
/*
...
...
paddle/fluid/operators/detail/grpc_server.cc
浏览文件 @
25f47fc0
...
...
@@ -177,11 +177,8 @@ class RequestPrefetch final : public RequestBase {
program_
(
program
),
prefetch_ctx_
(
prefetch_ctx
),
req_id_
(
req_id
)
{
if
(
sync_mode_
)
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
false
));
}
else
{
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
}
// prefetch always create a new sub scope
request_
.
reset
(
new
VariableResponse
(
scope
,
dev_ctx_
,
true
));
int
method_id
=
static_cast
<
int
>
(
detail
::
GrpcMethod
::
kPrefetchVariable
);
service_
->
RequestAsyncUnary
(
method_id
,
&
ctx_
,
request_
.
get
(),
&
responder_
,
cq_
,
cq_
,
...
...
@@ -198,10 +195,10 @@ class RequestPrefetch final : public RequestBase {
std
::
string
var_name
=
request_
->
OutVarname
();
VLOG
(
3
)
<<
"RequestPrefetch "
<<
var_name
;
auto
var_desc
=
program_
->
Block
(
0
).
FindVar
(
var_name
);
framework
::
Scope
*
local_scope
=
&
scope_
->
New
Scope
();
framework
::
Scope
*
local_scope
=
request_
->
GetMutableLocal
Scope
();
auto
*
var
=
local_scope
->
FindVar
(
var_name
);
InitializeVariable
(
var
,
var_desc
->
GetType
());
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
scope_
);
executor_
->
RunPreparedContext
(
prefetch_ctx_
,
local_scope
);
SerializeToByteBuffer
(
var_name
,
var
,
*
dev_ctx_
,
&
reply_
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录