Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
64c139e8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
64c139e8
编写于
4月 17, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Using constructor for VarHandle
上级
64bf3df0
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
28 addition
and
49 deletion
+28
-49
paddle/fluid/framework/details/broadcast_op_handle_test.cc
paddle/fluid/framework/details/broadcast_op_handle_test.cc
+5
-14
paddle/fluid/framework/details/gather_op_handle_test.cc
paddle/fluid/framework/details/gather_op_handle_test.cc
+5
-13
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+3
-7
paddle/fluid/framework/details/ssa_graph_builder.cc
paddle/fluid/framework/details/ssa_graph_builder.cc
+5
-13
paddle/fluid/framework/details/var_handle.h
paddle/fluid/framework/details/var_handle.h
+10
-2
未找到文件。
paddle/fluid/framework/details/broadcast_op_handle_test.cc
浏览文件 @
64c139e8
...
...
@@ -77,14 +77,9 @@ struct TestBroadcastOpHandle {
local_scopes_
[
input_scope_idx
]
->
Var
(
"input"
);
op_handle_
.
reset
(
new
BroadcastOpHandle
(
local_scopes_
,
gpu_list_
));
vars_
.
emplace_back
(
new
VarHandle
());
VarHandle
*
in_var_handle
=
static_cast
<
VarHandle
*>
(
vars_
.
back
().
get
());
in_var_handle
->
place_
=
gpu_list_
[
input_scope_idx
];
in_var_handle
->
name_
=
"input"
;
in_var_handle
->
version_
=
1
;
in_var_handle
->
scope_idx_
=
input_scope_idx
;
in_var_handle
->
generated_op_
=
nullptr
;
auto
*
in_var_handle
=
new
VarHandle
(
1
,
input_scope_idx
,
"input"
,
gpu_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
// add dummy var
...
...
@@ -96,12 +91,8 @@ struct TestBroadcastOpHandle {
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
op_handle_
->
dev_ctxes_
[
gpu_list_
[
j
]]
=
ctxs_
[
j
].
get
();
vars_
.
emplace_back
(
new
VarHandle
());
VarHandle
*
out_var_handle
=
static_cast
<
VarHandle
*>
(
vars_
.
back
().
get
());
out_var_handle
->
place_
=
gpu_list_
[
j
];
out_var_handle
->
name_
=
"out"
;
out_var_handle
->
version_
=
2
;
out_var_handle
->
scope_idx_
=
j
;
VarHandle
*
out_var_handle
=
new
VarHandle
(
2
,
j
,
"out"
,
gpu_list_
[
j
]);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
}
...
...
paddle/fluid/framework/details/gather_op_handle_test.cc
浏览文件 @
64c139e8
...
...
@@ -79,13 +79,8 @@ struct TestGatherOpHandle {
// add input
for
(
size_t
j
=
0
;
j
<
gpu_list_
.
size
();
++
j
)
{
op_handle_
->
dev_ctxes_
[
gpu_list_
[
j
]]
=
ctxs_
[
j
].
get
();
vars_
.
emplace_back
(
new
VarHandle
());
VarHandle
*
in_var_handle
=
static_cast
<
VarHandle
*>
(
vars_
.
back
().
get
());
in_var_handle
->
place_
=
gpu_list_
[
j
];
in_var_handle
->
name_
=
"input"
;
in_var_handle
->
version_
=
1
;
in_var_handle
->
scope_idx_
=
j
;
in_var_handle
->
generated_op_
=
nullptr
;
auto
*
in_var_handle
=
new
VarHandle
(
1
,
j
,
"input"
,
gpu_list_
[
j
]);
vars_
.
emplace_back
(
in_var_handle
);
op_handle_
->
AddInput
(
in_var_handle
);
}
...
...
@@ -97,12 +92,9 @@ struct TestGatherOpHandle {
op_handle_
->
AddInput
(
in_dummy_var_handle
);
// add output
vars_
.
emplace_back
(
new
VarHandle
());
VarHandle
*
out_var_handle
=
static_cast
<
VarHandle
*>
(
vars_
.
back
().
get
());
out_var_handle
->
place_
=
gpu_list_
[
input_scope_idx
];
out_var_handle
->
name_
=
"out"
;
out_var_handle
->
version_
=
2
;
out_var_handle
->
scope_idx_
=
input_scope_idx
;
auto
*
out_var_handle
=
new
VarHandle
(
2
,
input_scope_idx
,
"out"
,
gpu_list_
[
input_scope_idx
]);
vars_
.
emplace_back
(
out_var_handle
);
op_handle_
->
AddOutput
(
out_var_handle
);
// add dummy var
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
64c139e8
...
...
@@ -177,13 +177,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
auto
&
prev_grad
=
vars
[
vars
.
size
()
-
1
];
op_handle
->
AddInput
(
prev_grad
.
get
());
vars
.
emplace_back
(
new
VarHandle
);
auto
&
var
=
vars
.
back
();
var
->
place_
=
p
;
var
->
name_
=
og
;
var
->
version_
=
vars
.
size
()
-
1
;
op_handle
->
AddOutput
(
var
.
get
());
auto
var
=
new
VarHandle
(
vars
.
size
()
-
1
,
i
,
og
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
#else
PADDLE_ENFORCE
(
"Not implemented"
);
...
...
paddle/fluid/framework/details/ssa_graph_builder.cc
浏览文件 @
64c139e8
...
...
@@ -54,13 +54,8 @@ VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
auto
&
var_holder
=
var_holders
[
each_var_name
];
VarHandle
*
var
=
nullptr
;
if
(
var_holder
.
empty
())
{
var_holder
.
emplace_back
(
new
VarHandle
);
auto
&
init_var
=
var_holder
[
0
];
init_var
->
place_
=
place
;
init_var
->
name_
=
each_var_name
;
init_var
->
generated_op_
=
nullptr
;
init_var
->
version_
=
0
;
var
=
init_var
.
get
();
var
=
new
VarHandle
(
0
,
place_offset
,
each_var_name
,
place
);
var_holder
.
emplace_back
(
var
);
}
else
{
var
=
var_holder
.
rbegin
()
->
get
();
}
...
...
@@ -73,12 +68,9 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
size_t
place_offset
)
{
auto
&
vars
=
graph
->
vars_
[
place_offset
][
each_var_name
];
size_t
version
=
vars
.
size
();
vars
.
emplace_back
(
new
VarHandle
());
auto
&
var
=
vars
.
back
();
var
->
version_
=
version
;
var
->
name_
=
each_var_name
;
var
->
place_
=
place
;
op_handle
->
AddOutput
(
var
.
get
());
auto
var
=
new
VarHandle
(
version
,
place_offset
,
each_var_name
,
place
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
template
<
typename
Callback
>
...
...
paddle/fluid/framework/details/var_handle.h
浏览文件 @
64c139e8
...
...
@@ -16,6 +16,7 @@
#include <sstream>
#include <string>
#include <unordered_set>
#include <utility>
#include "paddle/fluid/platform/place.h"
...
...
@@ -33,10 +34,10 @@ struct VarHandleBase {
// The operator who generate this variable. nullptr if the variable
// is a root node.
OpHandleBase
*
generated_op_
;
OpHandleBase
*
generated_op_
{
nullptr
}
;
// Operators which depend on this variable ready.
std
::
unordered_set
<
OpHandleBase
*>
pending_ops_
;
std
::
unordered_set
<
OpHandleBase
*>
pending_ops_
;
};
// VarHandle is actually a single version of Runtime Variable.
...
...
@@ -47,6 +48,13 @@ struct VarHandleBase {
struct
VarHandle
:
public
VarHandleBase
{
std
::
string
DebugString
()
const
override
;
VarHandle
(
size_t
version
,
size_t
scope_index
,
std
::
string
name
,
platform
::
Place
place
)
:
version_
(
version
),
scope_idx_
(
scope_index
),
name_
(
std
::
move
(
name
)),
place_
(
std
::
move
(
place
))
{}
// version field currently is not used, however, just store the version to
// debug easily.
size_t
version_
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录